Mercurial > repos > galaxy-australia > alphafold2
comparison scripts/outputs.py @ 24:31f648b7555a draft
planemo upload for repository https://github.com/usegalaxy-au/tools-au commit 44db277529c0e189149235cf60a627193a792fba
author | galaxy-australia |
---|---|
date | Sat, 05 Jul 2025 03:56:38 +0000 |
parents | 2891385d6ace |
children |
comparison
equal
deleted
inserted
replaced
23:2891385d6ace | 24:31f648b7555a |
---|---|
18 import json | 18 import json |
19 import numpy as np | 19 import numpy as np |
20 import os | 20 import os |
21 import pickle as pk | 21 import pickle as pk |
22 import shutil | 22 import shutil |
23 import zipfile | |
24 from matplotlib import pyplot as plt | |
25 from pathlib import Path | 23 from pathlib import Path |
26 from typing import Dict, List | 24 from typing import Dict, List |
27 | 25 |
26 from matplotlib import pyplot as plt | |
27 | |
28 # Output file paths | |
28 OUTPUT_DIR = 'extra' | 29 OUTPUT_DIR = 'extra' |
29 OUTPUTS = { | 30 OUTPUTS = { |
30 'model_pkl': OUTPUT_DIR + '/ranked_{rank}.pkl', | 31 'model_pkl': OUTPUT_DIR + '/ranked_{rank}.pkl', |
31 'model_pae': OUTPUT_DIR + '/pae_ranked_{rank}.csv', | 32 'model_pae': OUTPUT_DIR + '/pae_ranked_{rank}.csv', |
32 'model_plot': OUTPUT_DIR + '/ranked_{rank}.png', | 33 'model_plot': OUTPUT_DIR + '/ranked_{rank}.png', |
61 """Parse and store settings/config.""" | 62 """Parse and store settings/config.""" |
62 def __init__(self): | 63 def __init__(self): |
63 self.workdir = None | 64 self.workdir = None |
64 self.output_confidence_scores = True | 65 self.output_confidence_scores = True |
65 self.output_residue_scores = False | 66 self.output_residue_scores = False |
67 self.is_multimer = False | |
66 self.parse() | 68 self.parse() |
67 | 69 |
68 def parse(self) -> None: | 70 def parse(self) -> None: |
69 parser = argparse.ArgumentParser() | 71 parser = argparse.ArgumentParser() |
70 parser.add_argument( | 72 parser.add_argument( |
94 action="store_true", | 96 action="store_true", |
95 ) | 97 ) |
96 parser.add_argument( | 98 parser.add_argument( |
97 "--plot-msa", | 99 "--plot-msa", |
98 help="Plot multiple-sequence alignment coverage as a heatmap", | 100 help="Plot multiple-sequence alignment coverage as a heatmap", |
99 action="store_true", | |
100 ) | |
101 parser.add_argument( | |
102 "--msa", | |
103 help="Collect multiple-sequence alignments as ZIP archives", | |
104 action="store_true", | |
105 ) | |
106 parser.add_argument( | |
107 "--msa_only", | |
108 help="Alphafold generated MSA files only - skip all other outputs", | |
109 action="store_true", | 101 action="store_true", |
110 ) | 102 ) |
111 args = parser.parse_args() | 103 args = parser.parse_args() |
112 self.workdir = Path(args.workdir.rstrip('/')) | 104 self.workdir = Path(args.workdir.rstrip('/')) |
113 self.output_residue_scores = args.confidence_scores | 105 self.output_residue_scores = args.confidence_scores |
114 self.output_model_pkls = args.pkl | 106 self.output_model_pkls = args.pkl |
115 self.output_model_plots = args.plot | 107 self.output_model_plots = args.plot |
116 self.output_pae = args.pae | 108 self.output_pae = args.pae |
117 self.plot_msa = args.plot_msa | 109 self.plot_msa = args.plot_msa |
118 self.collect_msas = args.msa | |
119 self.model_preset = self._sniff_model_preset() | 110 self.model_preset = self._sniff_model_preset() |
120 self.is_multimer = self.model_preset == PRESETS.multimer | |
121 self.output_dir = self.workdir / OUTPUT_DIR | 111 self.output_dir = self.workdir / OUTPUT_DIR |
122 self.msa_only = args.msa_only | |
123 os.makedirs(self.output_dir, exist_ok=True) | 112 os.makedirs(self.output_dir, exist_ok=True) |
124 | 113 |
125 def _sniff_model_preset(self) -> bool: | 114 def _sniff_model_preset(self) -> bool: |
126 """Check if the run was multimer or monomer.""" | 115 """Check if the run was multimer or monomer.""" |
127 for path in self.workdir.glob('*.pkl'): | 116 for path in self.workdir.glob('*.pkl'): |
129 if '_multimer_' in path.name: | 118 if '_multimer_' in path.name: |
130 return PRESETS.multimer | 119 return PRESETS.multimer |
131 if '_ptm_' in path.name: | 120 if '_ptm_' in path.name: |
132 return PRESETS.monomer_ptm | 121 return PRESETS.monomer_ptm |
133 return PRESETS.monomer | 122 return PRESETS.monomer |
134 return PRESETS.monomer | |
135 | 123 |
136 | 124 |
137 class ExecutionContext: | 125 class ExecutionContext: |
138 """Collect file paths etc.""" | 126 """Collect file paths etc.""" |
139 def __init__(self, settings: Settings): | 127 def __init__(self, settings: Settings): |
140 self.settings = settings | 128 self.settings = settings |
141 if settings.is_multimer: | 129 if settings.model_preset == PRESETS.multimer: |
142 self.plddt_key = PLDDT_KEY.multimer | 130 self.plddt_key = PLDDT_KEY.multimer |
143 else: | 131 else: |
144 self.plddt_key = PLDDT_KEY.monomer | 132 self.plddt_key = PLDDT_KEY.monomer |
145 | 133 |
146 def get_model_key(self, ix: int) -> str: | 134 def get_model_key(self, ix: int) -> str: |
209 """Return ordered list of model indexes.""" | 197 """Return ordered list of model indexes.""" |
210 return self.data['order'] | 198 return self.data['order'] |
211 | 199 |
212 def get_plddt_for_rank(self, rank: int) -> List[float]: | 200 def get_plddt_for_rank(self, rank: int) -> List[float]: |
213 """Get pLDDT score for model instance.""" | 201 """Get pLDDT score for model instance.""" |
214 return self.data[self.context.plddt_key][self.data['order'][rank - 1]] | 202 return self.data[self.context.plddt_key][self.data['order'][rank]] |
215 | 203 |
216 def get_rank_for_model(self, model_name: str) -> int: | 204 def get_rank_for_model(self, model_name: str) -> int: |
217 """Return 0-indexed rank for given model name. | 205 """Return 0-indexed rank for given model name. |
218 | 206 |
219 Model names are expressed in result_model_*.pkl file names. | 207 Model names are expressed in result_model_*.pkl file names. |
226 outfile = context.settings.workdir / OUTPUTS['model_confidence_scores'] | 214 outfile = context.settings.workdir / OUTPUTS['model_confidence_scores'] |
227 scores: Dict[str, list] = {} | 215 scores: Dict[str, list] = {} |
228 header = ['model', context.plddt_key] | 216 header = ['model', context.plddt_key] |
229 | 217 |
230 for i, path in enumerate(context.model_pkl_paths): | 218 for i, path in enumerate(context.model_pkl_paths): |
231 rank = int(path.name.split('model_')[-1][0]) | 219 model_name = 'model_' + path.stem.split('model_')[1] |
220 rank = ranking.get_rank_for_model(model_name) | |
232 scores_ls = [ranking.get_plddt_for_rank(rank)] | 221 scores_ls = [ranking.get_plddt_for_rank(rank)] |
233 with open(path, 'rb') as f: | 222 with open(path, 'rb') as f: |
234 data = pk.load(f) | 223 data = pk.load(f) |
235 if 'ptm' in data: | 224 if 'ptm' in data: |
236 scores_ls.append(data['ptm']) | 225 scores_ls.append(data['ptm']) |
242 header += ['iptm'] | 231 header += ['iptm'] |
243 scores[rank] = scores_ls | 232 scores[rank] = scores_ls |
244 | 233 |
245 with open(outfile, 'w') as f: | 234 with open(outfile, 'w') as f: |
246 f.write('\t'.join(header) + '\n') | 235 f.write('\t'.join(header) + '\n') |
247 for rank, score_ls in scores.items(): | 236 for rank in sorted(scores): |
248 row = [f"ranked_{rank - 1}"] + [str(x) for x in score_ls] | 237 score_ls = scores[rank] |
238 row = [f"ranked_{rank}"] + [str(x) for x in score_ls] | |
249 f.write('\t'.join(row) + '\n') | 239 f.write('\t'.join(row) + '\n') |
250 | 240 |
251 | 241 |
252 def write_per_residue_scores( | 242 def write_per_residue_scores( |
253 ranking: ResultRanking, | 243 ranking: ResultRanking, |
388 plt.tight_layout() | 378 plt.tight_layout() |
389 plt.savefig(wdir / OUTPUTS['msa'], dpi=dpi) | 379 plt.savefig(wdir / OUTPUTS['msa'], dpi=dpi) |
390 plt.close() | 380 plt.close() |
391 | 381 |
392 | 382 |
393 def collect_msas(settings: Settings): | |
394 """Collect MSA files into ZIP archive(s).""" | |
395 | |
396 def zip_dir(directory: Path, is_multimer: bool, name: str): | |
397 chain_id = directory.with_suffix('.zip').stem | |
398 msa_dir = settings.output_dir / 'msas' | |
399 msa_dir.mkdir(exist_ok=True) | |
400 zip_name = ( | |
401 f"MSA-{chain_id}-{name}.zip" | |
402 if is_multimer | |
403 else f"MSA-{name}.zip") | |
404 zip_path = msa_dir / zip_name | |
405 with zipfile.ZipFile(zip_path, 'w') as z: | |
406 for path in directory.glob('*'): | |
407 z.write(path, path.name) | |
408 | |
409 print("Collecting MSA archives...") | |
410 chain_names = get_input_sequence_ids( | |
411 settings.workdir.parent.parent / 'alphafold.fasta') | |
412 msa_dir = settings.workdir / 'msas' | |
413 is_multimer = (msa_dir / 'A').exists() | |
414 if is_multimer: | |
415 msa_dirs = sorted([ | |
416 path for path in msa_dir.glob('*') | |
417 if path.is_dir() | |
418 ]) | |
419 for i, path in enumerate(msa_dirs): | |
420 zip_dir(path, is_multimer, chain_names[i]) | |
421 else: | |
422 zip_dir(msa_dir, is_multimer, chain_names[0]) | |
423 | |
424 | |
425 def get_input_sequence_ids(fasta_file: Path) -> List[str]: | |
426 """Read headers from the input FASTA file. | |
427 Split them to get a sequence ID and truncate to 20 chars max. | |
428 """ | |
429 headers = [] | |
430 for line in fasta_file.read_text().split('\n'): | |
431 if line.startswith('>'): | |
432 seq_id = line[1:].split(' ')[0] | |
433 seq_id_trunc = seq_id[:20].strip() | |
434 if len(seq_id) > 20: | |
435 seq_id_trunc += '...' | |
436 headers.append(seq_id_trunc) | |
437 return headers | |
438 | |
439 | |
440 def template_html(context: ExecutionContext): | 383 def template_html(context: ExecutionContext): |
441 """Template HTML file. | 384 """Template HTML file. |
442 | 385 |
443 Remove buttons that are redundant with limited model outputs. | 386 Remove buttons that are redundant with limited model outputs. |
444 """ | 387 """ |
454 | 397 |
455 | 398 |
456 def main(): | 399 def main(): |
457 """Parse output files and generate additional output files.""" | 400 """Parse output files and generate additional output files.""" |
458 settings = Settings() | 401 settings = Settings() |
459 if not settings.msa_only: | 402 context = ExecutionContext(settings) |
460 context = ExecutionContext(settings) | 403 ranking = ResultRanking(context) |
461 ranking = ResultRanking(context) | 404 write_confidence_scores(ranking, context) |
462 write_confidence_scores(ranking, context) | 405 rekey_relax_metrics(ranking, context) |
463 rekey_relax_metrics(ranking, context) | 406 template_html(context) |
464 template_html(context) | 407 |
465 | 408 # Optional outputs |
466 # Optional outputs | 409 if settings.output_model_pkls: |
467 if settings.output_model_pkls: | 410 rename_model_pkls(ranking, context) |
468 rename_model_pkls(ranking, context) | 411 if settings.output_model_plots: |
469 if settings.output_model_plots: | 412 plddt_pae_plots(ranking, context) |
470 plddt_pae_plots(ranking, context) | 413 if settings.output_pae: |
471 if settings.output_pae: | 414 # Only created by monomer_ptm and multimer models |
472 # Only created by monomer_ptm and multimer models | 415 extract_pae_to_csv(ranking, context) |
473 extract_pae_to_csv(ranking, context) | 416 if settings.output_residue_scores: |
474 if settings.output_residue_scores: | 417 write_per_residue_scores(ranking, context) |
475 write_per_residue_scores(ranking, context) | 418 if settings.plot_msa: |
476 if settings.plot_msa: | 419 plot_msa(context.settings.workdir) |
477 plot_msa(settings.workdir) | |
478 if settings.collect_msas or settings.msa_only: | |
479 collect_msas(settings) | |
480 | 420 |
481 | 421 |
482 if __name__ == '__main__': | 422 if __name__ == '__main__': |
483 main() | 423 main() |