Mercurial > repos > galaxy-australia > alphafold2
comparison scripts/outputs.py @ 16:f9eb041c518c draft
planemo upload for repository https://github.com/usegalaxy-au/tools-au commit ee77734f1800350fa2a6ef28b2b8eade304a456f-dirty
author | galaxy-australia |
---|---|
date | Mon, 03 Apr 2023 01:00:42 +0000 |
parents | |
children | 6ab1a261520a |
comparison
equal
deleted
inserted
replaced
15:a58f7eb0df2c | 16:f9eb041c518c |
---|---|
1 """Generate additional output files not produced by AlphaFold. | |
2 | |
3 Currently this includes: | |
4 - model confidence scores | |
5 - per-residue confidence scores (pLDDTs - optional output) | |
6 - model_*.pkl files renamed with rank order | |
7 | |
8 N.B. There have been issues with this script breaking between AlphaFold | |
9 versions due to minor changes in the output directory structure across minor | |
10 versions. It will likely need updating with future releases of AlphaFold. | |
11 | |
12 This code is more complex than you might expect due to the output files | |
13 'moving around' considerably, depending on run parameters. You will see that | |
14 several output paths are determined dynamically. | |
15 """ | |
16 | |
17 import argparse | |
18 import json | |
19 import os | |
20 import pickle as pk | |
21 import shutil | |
22 from pathlib import Path | |
23 from typing import List | |
24 | |
25 from matplotlib import pyplot as plt | |
26 | |
# Directory (relative to the AlphaFold workdir) for the extra outputs,
# and templates for each file written into it. '{rank}' placeholders are
# filled via str.format at write time.
OUTPUT_DIR = 'extra'
OUTPUTS = {
    'model_pkl': f'{OUTPUT_DIR}/ranked_{{rank}}.pkl',
    'model_pae': f'{OUTPUT_DIR}/pae_ranked_{{rank}}.csv',
    'model_plot': f'{OUTPUT_DIR}/ranked_{{rank}}.png',
    'model_confidence_scores': f'{OUTPUT_DIR}/model_confidence_scores.tsv',
    'plddts': f'{OUTPUT_DIR}/plddts.tsv',
    'relax': f'{OUTPUT_DIR}/relax_metrics_ranked.json',
}

# Key under which ranking_debug.json stores per-model confidence data;
# it differs between monomer and multimer model presets.
PLDDT_KEY = {
    'monomer': 'plddts',
    'multimer': 'iptm+ptm',
}
44 | |
45 | |
class Settings:
    """Parse command-line arguments and store run configuration."""

    def __init__(self):
        # Defaults; parse() overwrites these from the command line.
        self.workdir = None
        self.output_confidence_scores = True
        self.output_residue_scores = False
        self.is_multimer = False
        self.parse()

    def parse(self) -> None:
        """Read CLI arguments and create the extra-output directory."""
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "workdir",
            help="alphafold output directory",
            type=str,
        )
        parser.add_argument(
            "-p", "--plddts",
            help="output per-residue confidence scores (pLDDTs)",
            action="store_true",
        )
        parser.add_argument(
            "-m", "--multimer",
            help="parse output from AlphaFold multimer",
            action="store_true",
        )
        parser.add_argument(
            "--pkl",
            help="rename model pkl outputs with rank order",
            action="store_true",
        )
        parser.add_argument(
            "--pae",
            help="extract PAE from pkl files to CSV format",
            action="store_true",
        )
        parser.add_argument(
            "--plot",
            help="Plot pLDDT and PAE for each model",
            action="store_true",
        )
        args = parser.parse_args()
        # Normalise the workdir (strip trailing slash) and wrap as a Path.
        self.workdir = Path(args.workdir.rstrip('/'))
        self.output_residue_scores = args.plddts
        self.output_model_pkls = args.pkl
        self.output_model_plots = args.plot
        self.output_pae = args.pae
        self.is_multimer = args.multimer
        self.output_dir = self.workdir / OUTPUT_DIR
        os.makedirs(self.output_dir, exist_ok=True)
98 | |
99 | |
class ExecutionContext:
    """Collect file paths and run config shared by the output writers."""

    def __init__(self, settings: 'Settings'):
        self.settings = settings
        # Confidence data in ranking_debug.json is keyed differently
        # depending on the model preset (monomer vs multimer).
        if settings.is_multimer:
            self.plddt_key = PLDDT_KEY['multimer']
        else:
            self.plddt_key = PLDDT_KEY['monomer']

    def get_model_key(self, ix: int) -> str:
        """Return the ranking_debug.json key for model index *ix*.

        The key format changed between minor AlphaFold versions so this
        function determines the correct key by prefix match.

        Raises:
            KeyError: if no key matches the given index.
        """
        with open(self.ranking_debug) as f:
            data = json.load(f)
        model_keys = list(data[self.plddt_key].keys())
        for k in model_keys:
            if k.startswith(f"model_{ix}_"):
                return k
        # BUG FIX: this error was previously *returned* rather than raised,
        # handing callers a KeyError instance as if it were a valid key.
        raise KeyError(
            f'Could not find key for index={ix} in'
            ' ranking_debug.json')

    @property
    def ranking_debug(self) -> Path:
        """Path to AlphaFold's ranking_debug.json in the workdir."""
        return self.settings.workdir / 'ranking_debug.json'

    @property
    def relax_metrics(self) -> Path:
        """Path to AlphaFold's relax_metrics.json in the workdir."""
        return self.settings.workdir / 'relax_metrics.json'

    @property
    def relax_metrics_ranked(self) -> Path:
        """Path to the rank-keyed relax metrics output file."""
        return self.settings.workdir / 'relax_metrics_ranked.json'

    @property
    def model_pkl_paths(self) -> List[Path]:
        """Sorted paths of the result_model_*.pkl files in the workdir."""
        return sorted(
            self.settings.workdir / f
            for f in os.listdir(self.settings.workdir)
            if f.startswith('result_model_') and f.endswith('.pkl')
        )
144 | |
145 | |
class ResultModelPrediction:
    """Load and manipulate data from a result_model_*.pkl file."""

    def __init__(self, path: str, context: 'ExecutionContext'):
        self.context = context
        self.path = path
        # e.g. 'result_model_1_pred_0.pkl' -> 'model_1_pred_0'
        self.name = os.path.basename(path).replace('result_', '').split('.')[0]
        # FIX (idiom): don't shadow the *path* argument with the open
        # file handle; keep them as distinct names.
        with open(path, 'rb') as fh:
            self.data = pk.load(fh)

    @property
    def plddts(self) -> List[float]:
        """Return pLDDT scores for each residue."""
        return list(self.data['plddt'])
159 | |
160 | |
class ResultRanking:
    """Read and query the ranking_debug.json produced by AlphaFold."""

    def __init__(self, context: ExecutionContext):
        self.path = context.ranking_debug
        self.context = context
        with open(self.path, 'r') as f:
            self.data = json.load(f)

    @property
    def order(self) -> List[str]:
        """Model names ordered from best to worst rank."""
        return self.data['order']

    def get_plddt_for_rank(self, rank: int) -> List[float]:
        """Look up the confidence score for the model at *rank* (1-based)."""
        model_name = self.data['order'][rank - 1]
        return self.data[self.context.plddt_key][model_name]

    def get_rank_for_model(self, model_name: str) -> int:
        """Return 0-indexed rank for given model name.

        Model names are expressed in result_model_*.pkl file names.
        """
        return self.data['order'].index(model_name)
185 | |
186 | |
def write_confidence_scores(ranking: ResultRanking, context: ExecutionContext):
    """Write per-model confidence scores as TSV, one row per ranked model.

    FIX: previously hard-coded to exactly five models (range(1, 6)),
    which raised IndexError for runs producing fewer models and silently
    truncated runs producing more. Now writes one row per model listed
    in ranking_debug.json.
    """
    path = context.settings.workdir / OUTPUTS['model_confidence_scores']
    with open(path, 'w') as f:
        for rank in range(1, len(ranking.order) + 1):
            score = ranking.get_plddt_for_rank(rank)
            # Output rows are 0-indexed to match AlphaFold's ranked_* files.
            f.write(f'ranked_{rank - 1}\t{score:.2f}\n')
194 | |
195 | |
def write_per_residue_scores(
    ranking: ResultRanking,
    context: ExecutionContext,
):
    """Write per-residue plddts for each model.

    A row of plddt values is written for each model in tabular format,
    ordered by model rank (best first).
    """
    # Map 0-indexed rank -> per-residue pLDDT list.
    model_plddts = {}
    for path in context.model_pkl_paths:
        model = ResultModelPrediction(path, context)
        rank = ranking.get_rank_for_model(model.name)
        model_plddts[rank] = model.plddts

    out_path = context.settings.workdir / OUTPUTS['plddts']
    with open(out_path, 'w') as f:
        for rank in sorted(model_plddts):
            row = [f'ranked_{rank}'] + [str(x) for x in model_plddts[rank]]
            f.write('\t'.join(row) + '\n')
217 | |
218 | |
def rename_model_pkls(ranking: ResultRanking, context: ExecutionContext):
    """Copy each result_model_*.pkl to a rank-named file under OUTPUT_DIR."""
    for src_path in context.model_pkl_paths:
        prediction = ResultModelPrediction(src_path, context)
        model_rank = ranking.get_rank_for_model(prediction.name)
        dest_path = (
            context.settings.workdir
            / OUTPUTS['model_pkl'].format(rank=model_rank)
        )
        shutil.copyfile(src_path, dest_path)
229 | |
230 | |
def extract_pae_to_csv(ranking: ResultRanking, context: ExecutionContext):
    """Extract predicted aligned error matrix from pickle files.

    Creates a CSV file for each ranked model. PAE data is only present
    when run with model_preset=monomer_ptm or multimer; if it is absent
    from the first file checked, processing stops with a message.
    """
    for path in context.model_pkl_paths:
        model = ResultModelPrediction(path, context)
        rank = ranking.get_rank_for_model(model.name)
        # FIX: ResultModelPrediction has already unpickled this file into
        # model.data — no need to open and unpickle it a second time.
        if 'predicted_aligned_error' not in model.data:
            print("Skipping PAE output"
                  f" - not found in {path}."
                  " Running with model_preset=monomer?")
            return
        pae = model.data['predicted_aligned_error']
        out_path = (
            context.settings.workdir
            / OUTPUTS['model_pae'].format(rank=rank)
        )
        with open(out_path, 'w') as f:
            for row in pae:
                f.write(','.join(str(x) for x in row) + '\n')
254 | |
255 | |
def rekey_relax_metrics(ranking: ResultRanking, context: ExecutionContext):
    """Replace keys in relax_metrics.json with 0-indexed rank."""
    with open(context.relax_metrics) as f:
        metrics = json.load(f)
    # Snapshot the keys before mutating the dict.
    for model_name in list(metrics.keys()):
        model_rank = ranking.get_rank_for_model(model_name)
        metrics[f'ranked_{model_rank}'] = metrics.pop(model_name)
    out_path = context.settings.workdir / OUTPUTS['relax']
    with open(out_path, 'w') as f:
        json.dump(metrics, f)
266 | |
267 | |
def plddt_pae_plots(ranking: ResultRanking, context: ExecutionContext):
    """Generate a pLDDT (+ PAE, when available) plot for each model.

    The PAE panel is only drawn when the pickle contains
    'predicted_aligned_error' (monomer_ptm / multimer presets).
    """
    for path in context.model_pkl_paths:
        model = ResultModelPrediction(path, context)
        rank = ranking.get_rank_for_model(model.name)
        png_path = (
            context.settings.workdir
            / OUTPUTS['model_plot'].format(rank=rank)
        )
        plddts = model.data['plddt']
        if 'predicted_aligned_error' in model.data:
            num_plots = 2
            pae = model.data['predicted_aligned_error']
            max_pae = model.data['max_predicted_aligned_error']
        else:
            num_plots = 1

        plt.figure(figsize=[8 * num_plots, 6])
        plt.subplot(1, num_plots, 1)
        plt.plot(plddts)
        plt.title('Predicted LDDT')
        plt.xlabel('Residue')
        plt.ylabel('pLDDT')

        if num_plots == 2:
            plt.subplot(1, 2, 2)
            plt.imshow(pae, vmin=0., vmax=max_pae, cmap='Greens_r')
            plt.colorbar(fraction=0.046, pad=0.04)
            plt.title('Predicted Aligned Error')
            plt.xlabel('Scored residue')
            plt.ylabel('Aligned residue')

        plt.savefig(png_path)
        # BUG FIX: close each figure after saving — otherwise figures
        # accumulate across loop iterations and matplotlib leaks memory
        # (and warns once >20 figures are open).
        plt.close()
301 | |
302 | |
def main():
    """Parse AlphaFold output files and generate the extra output files."""
    settings = Settings()
    context = ExecutionContext(settings)
    ranking = ResultRanking(context)

    # Always produced.
    write_confidence_scores(ranking, context)
    rekey_relax_metrics(ranking, context)

    # Optional outputs, gated by CLI flags (order preserved).
    optional_tasks = (
        (settings.output_model_pkls, rename_model_pkls),
        (settings.output_model_plots, plddt_pae_plots),
        # PAE is only created by monomer_ptm and multimer models.
        (settings.output_pae, extract_pae_to_csv),
        (settings.output_residue_scores, write_per_residue_scores),
    )
    for enabled, task in optional_tasks:
        if enabled:
            task(ranking, context)


if __name__ == '__main__':
    main()