scripts/outputs.py @ 16:f9eb041c518c (draft)

planemo upload for repository https://github.com/usegalaxy-au/tools-au commit ee77734f1800350fa2a6ef28b2b8eade304a456f-dirty
author: galaxy-australia
date: Mon, 03 Apr 2023 01:00:42 +0000
1 """Generate additional output files not produced by AlphaFold.
2
3 Currently this is includes:
4 - model confidence scores
5 - per-residue confidence scores (pLDDTs - optional output)
6 - model_*.pkl files renamed with rank order
7
8 N.B. There have been issues with this script breaking between AlphaFold
9 versions due to minor changes in the output directory structure across minor
10 versions. It will likely need updating with future releases of AlphaFold.
11
12 This code is more complex than you might expect due to the output files
13 'moving around' considerably, depending on run parameters. You will see that
14 several output paths are determined dynamically.
15 """

import argparse
import json
import os
import pickle as pk
import shutil
from pathlib import Path
from typing import List

from matplotlib import pyplot as plt

# Output file paths
OUTPUT_DIR = 'extra'
OUTPUTS = {
    'model_pkl': OUTPUT_DIR + '/ranked_{rank}.pkl',
    'model_pae': OUTPUT_DIR + '/pae_ranked_{rank}.csv',
    'model_plot': OUTPUT_DIR + '/ranked_{rank}.png',
    'model_confidence_scores': OUTPUT_DIR + '/model_confidence_scores.tsv',
    'plddts': OUTPUT_DIR + '/plddts.tsv',
    'relax': OUTPUT_DIR + '/relax_metrics_ranked.json',
}
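
# Each template above is rendered per model rank, e.g.
# OUTPUTS['model_pkl'].format(rank=0) -> 'extra/ranked_0.pkl'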

# Keys for accessing confidence data from JSON/pkl files.
# They change depending on whether the run was monomer or multimer.
PLDDT_KEY = {
    'monomer': 'plddts',
    'multimer': 'iptm+ptm',
}
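
# For illustration, ranking_debug.json looks roughly like this (model names
# and scores are examples only):
#
#   monomer:  {"plddts":   {"model_1_pred_0": 92.1, ...}, "order": [...]}
#   multimer: {"iptm+ptm": {"model_1_multimer_v3_pred_0": 0.85, ...},
#              "order": [...]}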


class Settings:
    """Parse and store settings/config."""
    def __init__(self):
        self.workdir = None
        self.output_confidence_scores = True
        self.output_residue_scores = False
        self.output_model_pkls = False
        self.output_model_plots = False
        self.output_pae = False
        self.is_multimer = False
        self.parse()

    def parse(self) -> None:
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "workdir",
            help="alphafold output directory",
            type=str
        )
        parser.add_argument(
            "-p",
            "--plddts",
            help="output per-residue confidence scores (pLDDTs)",
            action="store_true"
        )
        parser.add_argument(
            "-m",
            "--multimer",
            help="parse output from AlphaFold multimer",
            action="store_true"
        )
        parser.add_argument(
            "--pkl",
            help="rename model pkl outputs with rank order",
            action="store_true"
        )
        parser.add_argument(
            "--pae",
            help="extract PAE from pkl files to CSV format",
            action="store_true"
        )
        parser.add_argument(
            "--plot",
            help="plot pLDDT and PAE for each model",
            action="store_true"
        )
        args = parser.parse_args()
        self.workdir = Path(args.workdir.rstrip('/'))
        self.output_residue_scores = args.plddts
        self.output_model_pkls = args.pkl
        self.output_model_plots = args.plot
        self.output_pae = args.pae
        self.is_multimer = args.multimer
        self.output_dir = self.workdir / OUTPUT_DIR
        os.makedirs(self.output_dir, exist_ok=True)
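
# Typical invocation (the output path here is hypothetical):
#   python outputs.py ./alphafold_output -p -m --pkl --pae --plot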


class ExecutionContext:
    """Collect file paths etc."""
    def __init__(self, settings: Settings):
        self.settings = settings
        if settings.is_multimer:
            self.plddt_key = PLDDT_KEY['multimer']
        else:
            self.plddt_key = PLDDT_KEY['monomer']

    def get_model_key(self, ix: int) -> str:
        """Return json key for model index.

        The key format changed between minor AlphaFold versions so this
        function determines the correct key.
        """
        with open(self.ranking_debug) as f:
            data = json.load(f)
        model_keys = list(data[self.plddt_key].keys())
        for k in model_keys:
            if k.startswith(f"model_{ix}_"):
                return k
        # Raise (not return) the error so a missing key fails loudly
        raise KeyError(
            f'Could not find key for index={ix} in'
            ' ranking_debug.json')
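
    # Keys observed in ranking_debug.json look like "model_1_pred_0"
    # (monomer) or "model_1_multimer_v3_pred_0" (multimer); both match the
    # startswith(f"model_{ix}_") test above.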

    @property
    def ranking_debug(self) -> Path:
        return self.settings.workdir / 'ranking_debug.json'

    @property
    def relax_metrics(self) -> Path:
        return self.settings.workdir / 'relax_metrics.json'

    @property
    def relax_metrics_ranked(self) -> Path:
        return self.settings.workdir / 'relax_metrics_ranked.json'

    @property
    def model_pkl_paths(self) -> List[Path]:
        return sorted([
            self.settings.workdir / f
            for f in os.listdir(self.settings.workdir)
            if f.startswith('result_model_') and f.endswith('.pkl')
        ])


class ResultModelPrediction:
    """Load and manipulate data from result_model_*.pkl files."""
    def __init__(self, path: str, context: ExecutionContext):
        self.context = context
        self.path = path
        self.name = os.path.basename(path).replace('result_', '').split('.')[0]
        # Bind the file handle to 'f' rather than shadowing 'path'
        with open(path, 'rb') as f:
            self.data = pk.load(f)

    @property
    def plddts(self) -> List[float]:
        """Return pLDDT scores for each residue."""
        return list(self.data['plddt'])
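
# Note: result_model_*.pkl files unpickle to dicts. The keys this script
# relies on are 'plddt' and, for ptm/multimer presets only,
# 'predicted_aligned_error' and 'max_predicted_aligned_error'.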


class ResultRanking:
    """Load and manipulate data from the ranking_debug.json file."""

    def __init__(self, context: ExecutionContext):
        self.path = context.ranking_debug
        self.context = context
        with open(self.path, 'r') as f:
            self.data = json.load(f)

    @property
    def order(self) -> List[str]:
        """Return the ordered list of model names, best first."""
        return self.data['order']

    def get_plddt_for_rank(self, rank: int) -> float:
        """Return the confidence score for the model at the given rank.

        Ranks are 1-indexed.
        """
        return self.data[self.context.plddt_key][self.data['order'][rank - 1]]

    def get_rank_for_model(self, model_name: str) -> int:
        """Return 0-indexed rank for given model name.

        Model names are expressed in result_model_*.pkl file names.
        """
        return self.data['order'].index(model_name)


def write_confidence_scores(ranking: ResultRanking, context: ExecutionContext):
    """Write per-model confidence scores."""
    path = context.settings.workdir / OUTPUTS['model_confidence_scores']
    with open(path, 'w') as f:
        # Iterate over every ranked model rather than assuming exactly five
        for rank in range(1, len(ranking.order) + 1):
            score = ranking.get_plddt_for_rank(rank)
            f.write(f'ranked_{rank - 1}\t{score:.2f}\n')
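
# Example model_confidence_scores.tsv content (values illustrative; monomer
# runs report pLDDT ~0-100, multimer runs report iptm+ptm ~0-1):
#   ranked_0    92.48
#   ranked_1    88.30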


def write_per_residue_scores(
    ranking: ResultRanking,
    context: ExecutionContext,
):
    """Write per-residue pLDDT values for each model.

    A row of pLDDT values is written for each model in tabular format.
    """
    model_plddts = {}
    for path in context.model_pkl_paths:
        model = ResultModelPrediction(path, context)
        rank = ranking.get_rank_for_model(model.name)
        model_plddts[rank] = model.plddts

    path = context.settings.workdir / OUTPUTS['plddts']
    with open(path, 'w') as f:
        for rank in sorted(model_plddts):
            row = [f'ranked_{rank}'] + [
                str(x) for x in model_plddts[rank]
            ]
            f.write('\t'.join(row) + '\n')
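
# Example plddts.tsv row (one row per model, one tab-separated column per
# residue; values illustrative):
#   ranked_0    92.31   94.07   88.52   ...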


def rename_model_pkls(ranking: ResultRanking, context: ExecutionContext):
    """Copy model .pkl files to names that make their rank order explicit."""
    for path in context.model_pkl_paths:
        model = ResultModelPrediction(path, context)
        rank = ranking.get_rank_for_model(model.name)
        new_path = (
            context.settings.workdir
            / OUTPUTS['model_pkl'].format(rank=rank)
        )
        shutil.copyfile(path, new_path)


def extract_pae_to_csv(ranking: ResultRanking, context: ExecutionContext):
    """Extract the predicted aligned error matrix from pickle files.

    Creates a CSV file for each ranked model.
    """
    for path in context.model_pkl_paths:
        model = ResultModelPrediction(path, context)
        rank = ranking.get_rank_for_model(model.name)
        # model.data was already unpickled by ResultModelPrediction; reuse it
        # rather than loading the file a second time.
        if 'predicted_aligned_error' not in model.data:
            print("Skipping PAE output"
                  f" - not found in {path}."
                  " Running with model_preset=monomer?")
            return
        pae = model.data['predicted_aligned_error']
        out_path = (
            context.settings.workdir
            / OUTPUTS['model_pae'].format(rank=rank)
        )
        with open(out_path, 'w') as f:
            for row in pae:
                f.write(','.join([str(x) for x in row]) + '\n')
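
# Each CSV holds an N x N matrix (N = total residues): the value at row i,
# column j is the expected error (Angstroms) in residue j's position if the
# prediction and true structure were aligned at residue i.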


def rekey_relax_metrics(ranking: ResultRanking, context: ExecutionContext):
    """Replace keys in relax_metrics.json with 0-indexed rank."""
    with open(context.relax_metrics) as f:
        data = json.load(f)
    for k in list(data.keys()):
        rank = ranking.get_rank_for_model(k)
        data[f'ranked_{rank}'] = data.pop(k)
    new_path = context.settings.workdir / OUTPUTS['relax']
    with open(new_path, 'w') as f:
        json.dump(data, f)
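
# Example rekeying of relax_metrics.json (model name and fields
# illustrative):
#   {"model_1_pred_0": {"remaining_violations_count": 0.0}}
#     -> {"ranked_0": {"remaining_violations_count": 0.0}}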


def plddt_pae_plots(ranking: ResultRanking, context: ExecutionContext):
    """Generate a pLDDT + PAE plot for each model."""
    for path in context.model_pkl_paths:
        num_plots = 2
        model = ResultModelPrediction(path, context)
        rank = ranking.get_rank_for_model(model.name)
        png_path = (
            context.settings.workdir
            / OUTPUTS['model_plot'].format(rank=rank)
        )
        plddts = model.data['plddt']
        if 'predicted_aligned_error' in model.data:
            pae = model.data['predicted_aligned_error']
            max_pae = model.data['max_predicted_aligned_error']
        else:
            # No PAE in this pkl (e.g. model_preset=monomer): plot pLDDT only
            num_plots = 1

        plt.figure(figsize=[8 * num_plots, 6])
        plt.subplot(1, num_plots, 1)
        plt.plot(plddts)
        plt.title('Predicted LDDT')
        plt.xlabel('Residue')
        plt.ylabel('pLDDT')

        if num_plots == 2:
            plt.subplot(1, 2, 2)
            plt.imshow(pae, vmin=0., vmax=max_pae, cmap='Greens_r')
            plt.colorbar(fraction=0.046, pad=0.04)
            plt.title('Predicted Aligned Error')
            plt.xlabel('Scored residue')
            plt.ylabel('Aligned residue')

        plt.savefig(png_path)
        # Close the figure so memory doesn't accumulate across models
        plt.close()


def main():
    """Parse output files and generate additional output files."""
    settings = Settings()
    context = ExecutionContext(settings)
    ranking = ResultRanking(context)
    write_confidence_scores(ranking, context)
    rekey_relax_metrics(ranking, context)

    # Optional outputs
    if settings.output_model_pkls:
        rename_model_pkls(ranking, context)
    if settings.output_model_plots:
        plddt_pae_plots(ranking, context)
    if settings.output_pae:
        # PAE is only produced by monomer_ptm and multimer presets
        extract_pae_to_csv(ranking, context)
    if settings.output_residue_scores:
        write_per_residue_scores(ranking, context)


if __name__ == '__main__':
    main()