Mercurial > repos > galaxy-australia > alphafold2
comparison docker/alphafold/run_alphafold.py @ 1:6c92e000d684 draft
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
author | galaxy-australia |
---|---|
date | Tue, 01 Mar 2022 02:53:05 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
0:7ae9d78b06f5 | 1:6c92e000d684 |
---|---|
1 # Copyright 2021 DeepMind Technologies Limited | |
2 # | |
3 # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 # you may not use this file except in compliance with the License. | |
5 # You may obtain a copy of the License at | |
6 # | |
7 # http://www.apache.org/licenses/LICENSE-2.0 | |
8 # | |
9 # Unless required by applicable law or agreed to in writing, software | |
10 # distributed under the License is distributed on an "AS IS" BASIS, | |
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 # See the License for the specific language governing permissions and | |
13 # limitations under the License. | |
14 | |
15 """Full AlphaFold protein structure prediction script.""" | |
16 import json | |
17 import os | |
18 import pathlib | |
19 import pickle | |
20 import random | |
21 import shutil | |
22 import sys | |
23 import time | |
24 from typing import Dict, Union, Optional | |
25 | |
26 from absl import app | |
27 from absl import flags | |
28 from absl import logging | |
29 from alphafold.common import protein | |
30 from alphafold.common import residue_constants | |
31 from alphafold.data import pipeline | |
32 from alphafold.data import pipeline_multimer | |
33 from alphafold.data import templates | |
34 from alphafold.data.tools import hhsearch | |
35 from alphafold.data.tools import hmmsearch | |
36 from alphafold.model import config | |
37 from alphafold.model import model | |
38 from alphafold.relax import relax | |
39 import numpy as np | |
40 | |
41 from alphafold.model import data | |
42 # Internal import (7716). | |
43 | |
# Emit INFO-level progress messages from the data pipeline and model runners.
logging.set_verbosity(logging.INFO)

# --- Input / output flags. ---
flags.DEFINE_list(
    'fasta_paths', None, 'Paths to FASTA files, each containing a prediction '
    'target that will be folded one after another. If a FASTA file contains '
    'multiple sequences, then it will be folded as a multimer. Paths should be '
    'separated by commas. All FASTA paths must have a unique basename as the '
    'basename is used to name the output directories for each prediction.')
flags.DEFINE_list(
    'is_prokaryote_list', None, 'Optional for multimer system, not used by the '
    'single chain system. This list should contain a boolean for each fasta '
    'specifying true where the target complex is from a prokaryote, and false '
    'where it is not, or where the origin is unknown. These values determine '
    'the pairing method for the MSA.')

flags.DEFINE_string('data_dir', None, 'Path to directory of supporting data.')
flags.DEFINE_string('output_dir', None, 'Path to a directory that will '
                    'store the results.')
# --- External tool binaries; defaults are resolved from PATH at import time.
flags.DEFINE_string('jackhmmer_binary_path', shutil.which('jackhmmer'),
                    'Path to the JackHMMER executable.')
flags.DEFINE_string('hhblits_binary_path', shutil.which('hhblits'),
                    'Path to the HHblits executable.')
flags.DEFINE_string('hhsearch_binary_path', shutil.which('hhsearch'),
                    'Path to the HHsearch executable.')
flags.DEFINE_string('hmmsearch_binary_path', shutil.which('hmmsearch'),
                    'Path to the hmmsearch executable.')
flags.DEFINE_string('hmmbuild_binary_path', shutil.which('hmmbuild'),
                    'Path to the hmmbuild executable.')
flags.DEFINE_string('kalign_binary_path', shutil.which('kalign'),
                    'Path to the Kalign executable.')
# --- Genetic and template database paths; which ones must be set is
# --- cross-checked against db_preset/model_preset in main().
flags.DEFINE_string('uniref90_database_path', None, 'Path to the Uniref90 '
                    'database for use by JackHMMER.')
flags.DEFINE_string('mgnify_database_path', None, 'Path to the MGnify '
                    'database for use by JackHMMER.')
flags.DEFINE_string('bfd_database_path', None, 'Path to the BFD '
                    'database for use by HHblits.')
flags.DEFINE_string('small_bfd_database_path', None, 'Path to the small '
                    'version of BFD used with the "reduced_dbs" preset.')
flags.DEFINE_string('uniclust30_database_path', None, 'Path to the Uniclust30 '
                    'database for use by HHblits.')
flags.DEFINE_string('uniprot_database_path', None, 'Path to the Uniprot '
                    'database for use by JackHMMer.')
flags.DEFINE_string('pdb70_database_path', None, 'Path to the PDB70 '
                    'database for use by HHsearch.')
flags.DEFINE_string('pdb_seqres_database_path', None, 'Path to the PDB '
                    'seqres database for use by hmmsearch.')
flags.DEFINE_string('template_mmcif_dir', None, 'Path to a directory with '
                    'template mmCIF structures, each named <pdb_id>.cif')
flags.DEFINE_string('max_template_date', None, 'Maximum template release date '
                    'to consider. Important if folding historical test sets.')
flags.DEFINE_string('obsolete_pdbs_path', None, 'Path to file containing a '
                    'mapping from obsolete PDB IDs to the PDB IDs of their '
                    'replacements.')
# --- Run configuration. ---
flags.DEFINE_enum('db_preset', 'full_dbs',
                  ['full_dbs', 'reduced_dbs'],
                  'Choose preset MSA database configuration - '
                  'smaller genetic database config (reduced_dbs) or '
                  'full genetic database config (full_dbs)')
flags.DEFINE_enum('model_preset', 'monomer',
                  ['monomer', 'monomer_casp14', 'monomer_ptm', 'multimer'],
                  'Choose preset model configuration - the monomer model, '
                  'the monomer model with extra ensembling, monomer model with '
                  'pTM head, or multimer model')
flags.DEFINE_boolean('benchmark', False, 'Run multiple JAX model evaluations '
                     'to obtain a timing that excludes the compilation time, '
                     'which should be more indicative of the time required for '
                     'inferencing many proteins.')
flags.DEFINE_integer('random_seed', None, 'The random seed for the data '
                     'pipeline. By default, this is randomly generated. Note '
                     'that even if this is set, Alphafold may still not be '
                     'deterministic, because processes like GPU inference are '
                     'nondeterministic.')
flags.DEFINE_boolean('use_precomputed_msas', False, 'Whether to read MSAs that '
                     'have been written to disk. WARNING: This will not check '
                     'if the sequence, database or configuration have changed.')

FLAGS = flags.FLAGS

# Cap on template hits forwarded to the template featurizers.
MAX_TEMPLATE_HITS = 20
# Parameters passed to relax.AmberRelaxation in main().
# NOTE(review): max_iterations=0 appears to mean "no minimization iteration
# limit" in AmberRelaxation — confirm against the relax module.
RELAX_MAX_ITERATIONS = 0
RELAX_ENERGY_TOLERANCE = 2.39
RELAX_STIFFNESS = 10.0
RELAX_EXCLUDE_RESIDUES = []
RELAX_MAX_OUTER_ITERATIONS = 3
128 | |
129 | |
def _check_flag(flag_name: str,
                other_flag_name: str,
                should_be_set: bool):
  """Validates that a flag's set/unset state matches what another flag implies.

  Raises:
    ValueError: If `flag_name` is set when it should not be (or vice versa)
      given the current value of `other_flag_name`.
  """
  is_set = bool(FLAGS[flag_name].value)
  if is_set == should_be_set:
    return
  verb = 'be' if should_be_set else 'not be'
  other_value = FLAGS[other_flag_name].value
  raise ValueError(f'{flag_name} must {verb} set when running with '
                   f'"--{other_flag_name}={other_value}".')
137 | |
138 | |
def predict_structure(
    fasta_path: str,
    fasta_name: str,
    output_dir_base: str,
    data_pipeline: Union[pipeline.DataPipeline, pipeline_multimer.DataPipeline],
    model_runners: Dict[str, model.RunModel],
    amber_relaxer: relax.AmberRelaxation,
    benchmark: bool,
    random_seed: int,
    is_prokaryote: Optional[bool] = None):
  """Predicts structure using AlphaFold for the given sequence.

  Builds MSA/template features once, runs every model in `model_runners`,
  optionally Amber-relaxes each prediction, and writes all artifacts
  (features.pkl, result_*.pkl, unrelaxed/relaxed/ranked PDBs, JSON metadata)
  under `<output_dir_base>/<fasta_name>/`.

  Args:
    fasta_path: Path to the FASTA file with the target sequence(s).
    fasta_name: Basename used to name the per-target output directory.
    output_dir_base: Directory under which per-target outputs are created.
    data_pipeline: Feature-building pipeline (monomer or multimer variant).
    model_runners: Mapping of model name -> RunModel; all of them are run.
    amber_relaxer: Relaxer applied to each prediction; falsy to skip relax.
    benchmark: If True, predict a second time per model to time inference
      without JAX compilation overhead.
    random_seed: Base seed; each model derives its own seed from it.
    is_prokaryote: Forwarded to the pipeline only when not None (multimer
      MSA pairing hint).
  """
  logging.info('Predicting %s', fasta_name)
  timings = {}
  output_dir = os.path.join(output_dir_base, fasta_name)
  if not os.path.exists(output_dir):
    os.makedirs(output_dir)
  msa_output_dir = os.path.join(output_dir, 'msas')
  if not os.path.exists(msa_output_dir):
    os.makedirs(msa_output_dir)

  # Get features. Only pass is_prokaryote when given: the monomer pipeline's
  # process() does not accept that keyword.
  t_0 = time.time()
  if is_prokaryote is None:
    feature_dict = data_pipeline.process(
        input_fasta_path=fasta_path,
        msa_output_dir=msa_output_dir)
  else:
    feature_dict = data_pipeline.process(
        input_fasta_path=fasta_path,
        msa_output_dir=msa_output_dir,
        is_prokaryote=is_prokaryote)
  timings['features'] = time.time() - t_0

  # Write out features as a pickled dictionary.
  features_output_path = os.path.join(output_dir, 'features.pkl')
  with open(features_output_path, 'wb') as f:
    pickle.dump(feature_dict, f, protocol=4)

  unrelaxed_pdbs = {}
  relaxed_pdbs = {}
  ranking_confidences = {}

  # Run the models.
  num_models = len(model_runners)
  for model_index, (model_name, model_runner) in enumerate(
      model_runners.items()):
    logging.info('Running model %s on %s', model_name, fasta_name)
    t_0 = time.time()
    # Distinct per-model seed derived from the shared base seed.
    model_random_seed = model_index + random_seed * num_models
    processed_feature_dict = model_runner.process_features(
        feature_dict, random_seed=model_random_seed)
    timings[f'process_features_{model_name}'] = time.time() - t_0

    t_0 = time.time()
    prediction_result = model_runner.predict(processed_feature_dict,
                                             random_seed=model_random_seed)
    t_diff = time.time() - t_0
    timings[f'predict_and_compile_{model_name}'] = t_diff
    logging.info(
        'Total JAX model %s on %s predict time (includes compilation time, see --benchmark): %.1fs',
        model_name, fasta_name, t_diff)

    if benchmark:
      # Second run with the same seed: compilation is cached, so this timing
      # reflects pure inference cost.
      t_0 = time.time()
      model_runner.predict(processed_feature_dict,
                           random_seed=model_random_seed)
      t_diff = time.time() - t_0
      timings[f'predict_benchmark_{model_name}'] = t_diff
      logging.info(
          'Total JAX model %s on %s predict time (excludes compilation time): %.1fs',
          model_name, fasta_name, t_diff)

    plddt = prediction_result['plddt']
    ranking_confidences[model_name] = prediction_result['ranking_confidence']

    # Save the model outputs.
    result_output_path = os.path.join(output_dir, f'result_{model_name}.pkl')
    with open(result_output_path, 'wb') as f:
      pickle.dump(prediction_result, f, protocol=4)

    # Add the predicted LDDT in the b-factor column.
    # Note that higher predicted LDDT value means higher model confidence.
    plddt_b_factors = np.repeat(
        plddt[:, None], residue_constants.atom_type_num, axis=-1)
    unrelaxed_protein = protein.from_prediction(
        features=processed_feature_dict,
        result=prediction_result,
        b_factors=plddt_b_factors,
        remove_leading_feature_dimension=not model_runner.multimer_mode)

    unrelaxed_pdbs[model_name] = protein.to_pdb(unrelaxed_protein)
    unrelaxed_pdb_path = os.path.join(output_dir, f'unrelaxed_{model_name}.pdb')
    with open(unrelaxed_pdb_path, 'w') as f:
      f.write(unrelaxed_pdbs[model_name])

    if amber_relaxer:
      # Relax the prediction.
      t_0 = time.time()
      relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
      timings[f'relax_{model_name}'] = time.time() - t_0

      relaxed_pdbs[model_name] = relaxed_pdb_str

      # Save the relaxed PDB.
      relaxed_output_path = os.path.join(
          output_dir, f'relaxed_{model_name}.pdb')
      with open(relaxed_output_path, 'w') as f:
        f.write(relaxed_pdb_str)

  # Rank by model confidence and write out relaxed PDBs in rank order.
  # ranked_0.pdb is the most confident model.
  ranked_order = []
  for idx, (model_name, _) in enumerate(
      sorted(ranking_confidences.items(), key=lambda x: x[1], reverse=True)):
    ranked_order.append(model_name)
    ranked_output_path = os.path.join(output_dir, f'ranked_{idx}.pdb')
    with open(ranked_output_path, 'w') as f:
      if amber_relaxer:
        f.write(relaxed_pdbs[model_name])
      else:
        f.write(unrelaxed_pdbs[model_name])

  ranking_output_path = os.path.join(output_dir, 'ranking_debug.json')
  with open(ranking_output_path, 'w') as f:
    # 'iptm' presence is probed on the LAST model's result; the multimer
    # preset yields 'iptm+ptm' confidences, monomer presets yield 'plddts'.
    label = 'iptm+ptm' if 'iptm' in prediction_result else 'plddts'
    f.write(json.dumps(
        {label: ranking_confidences, 'order': ranked_order}, indent=4))

  logging.info('Final timings for %s: %s', fasta_name, timings)

  timings_output_path = os.path.join(output_dir, 'timings.json')
  with open(timings_output_path, 'w') as f:
    f.write(json.dumps(timings, indent=4))
271 | |
272 | |
def main(argv):
  """Validates flags, builds the pipeline and models, folds each FASTA.

  Args:
    argv: Remaining command-line arguments after absl flag parsing; only the
      program name is allowed.

  Raises:
    app.UsageError: If extra positional arguments are given.
    ValueError: If a required tool binary cannot be found, if database flags
      are inconsistent with the chosen presets, if FASTA basenames are not
      unique, or if --is_prokaryote_list is malformed.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Fail fast if any external tool binary was not found on PATH or via flags.
  for tool_name in (
      'jackhmmer', 'hhblits', 'hhsearch', 'hmmsearch', 'hmmbuild', 'kalign'):
    if not FLAGS[f'{tool_name}_binary_path'].value:
      raise ValueError(f'Could not find path to the "{tool_name}" binary. Make '
                       'sure it is installed on your system.')

  # db_preset selects either the small BFD or the full BFD + Uniclust30 pair;
  # exactly one of the two groups must be configured.
  use_small_bfd = FLAGS.db_preset == 'reduced_dbs'
  _check_flag('small_bfd_database_path', 'db_preset',
              should_be_set=use_small_bfd)
  _check_flag('bfd_database_path', 'db_preset',
              should_be_set=not use_small_bfd)
  _check_flag('uniclust30_database_path', 'db_preset',
              should_be_set=not use_small_bfd)

  # Multimer uses hmmsearch over pdb_seqres + Uniprot; monomer uses
  # HHsearch over PDB70.
  run_multimer_system = 'multimer' in FLAGS.model_preset
  _check_flag('pdb70_database_path', 'model_preset',
              should_be_set=not run_multimer_system)
  _check_flag('pdb_seqres_database_path', 'model_preset',
              should_be_set=run_multimer_system)
  _check_flag('uniprot_database_path', 'model_preset',
              should_be_set=run_multimer_system)

  # Only the CASP14 preset uses extra ensembling.
  if FLAGS.model_preset == 'monomer_casp14':
    num_ensemble = 8
  else:
    num_ensemble = 1

  # Check for duplicate FASTA file names.
  fasta_names = [pathlib.Path(p).stem for p in FLAGS.fasta_paths]
  if len(fasta_names) != len(set(fasta_names)):
    raise ValueError('All FASTA paths must have a unique basename.')

  # Check that is_prokaryote_list has same number of elements as fasta_paths,
  # and convert to bool.
  if FLAGS.is_prokaryote_list:
    if len(FLAGS.is_prokaryote_list) != len(FLAGS.fasta_paths):
      raise ValueError('--is_prokaryote_list must either be omitted or match '
                       'length of --fasta_paths.')
    is_prokaryote_list = []
    for s in FLAGS.is_prokaryote_list:
      if s in ('true', 'false'):
        is_prokaryote_list.append(s == 'true')
      else:
        raise ValueError('--is_prokaryote_list must contain comma separated '
                         'true or false values.')
  else:  # Default is_prokaryote to False.
    is_prokaryote_list = [False] * len(fasta_names)

  # Template search backend depends on the model preset (see checks above).
  if run_multimer_system:
    template_searcher = hmmsearch.Hmmsearch(
        binary_path=FLAGS.hmmsearch_binary_path,
        hmmbuild_binary_path=FLAGS.hmmbuild_binary_path,
        database_path=FLAGS.pdb_seqres_database_path)
    template_featurizer = templates.HmmsearchHitFeaturizer(
        mmcif_dir=FLAGS.template_mmcif_dir,
        max_template_date=FLAGS.max_template_date,
        max_hits=MAX_TEMPLATE_HITS,
        kalign_binary_path=FLAGS.kalign_binary_path,
        release_dates_path=None,
        obsolete_pdbs_path=FLAGS.obsolete_pdbs_path)
  else:
    template_searcher = hhsearch.HHSearch(
        binary_path=FLAGS.hhsearch_binary_path,
        databases=[FLAGS.pdb70_database_path])
    template_featurizer = templates.HhsearchHitFeaturizer(
        mmcif_dir=FLAGS.template_mmcif_dir,
        max_template_date=FLAGS.max_template_date,
        max_hits=MAX_TEMPLATE_HITS,
        kalign_binary_path=FLAGS.kalign_binary_path,
        release_dates_path=None,
        obsolete_pdbs_path=FLAGS.obsolete_pdbs_path)

  # The monomer pipeline is built unconditionally: the multimer pipeline
  # wraps it to process each chain.
  monomer_data_pipeline = pipeline.DataPipeline(
      jackhmmer_binary_path=FLAGS.jackhmmer_binary_path,
      hhblits_binary_path=FLAGS.hhblits_binary_path,
      uniref90_database_path=FLAGS.uniref90_database_path,
      mgnify_database_path=FLAGS.mgnify_database_path,
      bfd_database_path=FLAGS.bfd_database_path,
      uniclust30_database_path=FLAGS.uniclust30_database_path,
      small_bfd_database_path=FLAGS.small_bfd_database_path,
      template_searcher=template_searcher,
      template_featurizer=template_featurizer,
      use_small_bfd=use_small_bfd,
      use_precomputed_msas=FLAGS.use_precomputed_msas)

  if run_multimer_system:
    data_pipeline = pipeline_multimer.DataPipeline(
        monomer_data_pipeline=monomer_data_pipeline,
        jackhmmer_binary_path=FLAGS.jackhmmer_binary_path,
        uniprot_database_path=FLAGS.uniprot_database_path,
        use_precomputed_msas=FLAGS.use_precomputed_msas)
  else:
    data_pipeline = monomer_data_pipeline

  # Build one RunModel per model in the preset, each with its own config
  # and parameters. The ensemble count lives at a different config path for
  # multimer vs monomer models.
  model_runners = {}
  model_names = config.MODEL_PRESETS[FLAGS.model_preset]
  for model_name in model_names:
    model_config = config.model_config(model_name)
    if run_multimer_system:
      model_config.model.num_ensemble_eval = num_ensemble
    else:
      model_config.data.eval.num_ensemble = num_ensemble
    model_params = data.get_model_haiku_params(
        model_name=model_name, data_dir=FLAGS.data_dir)
    model_runner = model.RunModel(model_config, model_params)
    model_runners[model_name] = model_runner

  logging.info('Have %d models: %s', len(model_runners),
               list(model_runners.keys()))

  amber_relaxer = relax.AmberRelaxation(
      max_iterations=RELAX_MAX_ITERATIONS,
      tolerance=RELAX_ENERGY_TOLERANCE,
      stiffness=RELAX_STIFFNESS,
      exclude_residues=RELAX_EXCLUDE_RESIDUES,
      max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS)

  random_seed = FLAGS.random_seed
  if random_seed is None:
    # Cap the seed so model_index + random_seed * num_models (computed in
    # predict_structure) stays within sys.maxsize.
    random_seed = random.randrange(sys.maxsize // len(model_names))
  logging.info('Using random seed %d for the data pipeline', random_seed)

  # Predict structure for each of the sequences.
  for i, fasta_path in enumerate(FLAGS.fasta_paths):
    # The prokaryote hint is only meaningful for the multimer pipeline.
    is_prokaryote = is_prokaryote_list[i] if run_multimer_system else None
    fasta_name = fasta_names[i]
    predict_structure(
        fasta_path=fasta_path,
        fasta_name=fasta_name,
        output_dir_base=FLAGS.output_dir,
        data_pipeline=data_pipeline,
        model_runners=model_runners,
        amber_relaxer=amber_relaxer,
        benchmark=FLAGS.benchmark,
        random_seed=random_seed,
        is_prokaryote=is_prokaryote)
413 | |
414 | |
if __name__ == '__main__':
  # These flags have no usable defaults; abort with a clear error before
  # doing any work if any of them is missing.
  required_flags = (
      'fasta_paths',
      'output_dir',
      'data_dir',
      'uniref90_database_path',
      'mgnify_database_path',
      'template_mmcif_dir',
      'max_template_date',
      'obsolete_pdbs_path',
  )
  flags.mark_flags_as_required(list(required_flags))
  app.run(main)