comparison: gen_extra_outputs.py @ 9:3bd420ec162d (draft)

planemo upload for repository https://github.com/usegalaxy-au/tools-au commit 7726c3cba165bdc8fc6366ec0ce6596e55657468

author:   galaxy-australia
date:     Tue, 13 Sep 2022 22:04:12 +0000
parents:  7ae9d78b06f5
children: c0e71cb2bd1b
comparing 8:ca90d17ff51b with 9:3bd420ec162d
1 | 1 """Generate additional output files not produced by AlphaFold.""" |
2 | 2 |
3 import json | 3 import json |
4 import pickle | 4 import pickle |
5 import argparse | 5 import argparse |
6 from typing import Any, Dict, List | 6 from typing import Any, Dict, List |
7 | |
8 # Keys for accessing confidence data from JSON/pkl files | |
9 # They change depending on whether the run was monomer or multimer | |
10 CONTEXT_KEY = { | |
11 'monomer': 'plddts', | |
12 'multimer': 'iptm+ptm', | |
13 } | |
7 | 14 |
8 | 15 |
9 class Settings: | 16 class Settings: |
10 """parses then keeps track of program settings""" | 17 """parses then keeps track of program settings""" |
11 def __init__(self): | 18 def __init__(self): |
12 self.workdir = None | 19 self.workdir = None |
13 self.output_confidence_scores = True | 20 self.output_confidence_scores = True |
14 self.output_residue_scores = False | 21 self.output_residue_scores = False |
22 self.is_multimer = False | |
15 | 23 |
16 def parse_settings(self) -> None: | 24 def parse_settings(self) -> None: |
17 parser = argparse.ArgumentParser() | 25 parser = argparse.ArgumentParser() |
18 parser.add_argument( | 26 parser.add_argument( |
19 "workdir", | 27 "workdir", |
20 help="alphafold output directory", | 28 help="alphafold output directory", |
21 type=str | 29 type=str |
22 ) | 30 ) |
23 parser.add_argument( | 31 parser.add_argument( |
24 "-p", | 32 "-p", |
25 "--plddts", | 33 "--plddts", |
26 help="output per-residue confidence scores (pLDDTs)", | 34 help="output per-residue confidence scores (pLDDTs)", |
35 action="store_true" | |
36 ) | |
37 parser.add_argument( | |
38 "--multimer", | |
39 help="parse output from AlphaFold multimer", | |
27 action="store_true" | 40 action="store_true" |
28 ) | 41 ) |
29 args = parser.parse_args() | 42 args = parser.parse_args() |
30 self.workdir = args.workdir.rstrip('/') | 43 self.workdir = args.workdir.rstrip('/') |
31 self.output_residue_scores = args.plddts | 44 self.output_residue_scores = args.plddts |
45 self.is_multimer = False | |
46 self.is_multimer = args.multimer | |
32 | 47 |
33 | 48 |
34 class ExecutionContext: | 49 class ExecutionContext: |
35 """uses program settings to get paths to files etc""" | 50 """uses program settings to get paths to files etc""" |
36 def __init__(self, settings: Settings): | 51 def __init__(self, settings: Settings): |
40 def ranking_debug(self) -> str: | 55 def ranking_debug(self) -> str: |
41 return f'{self.settings.workdir}/ranking_debug.json' | 56 return f'{self.settings.workdir}/ranking_debug.json' |
42 | 57 |
43 @property | 58 @property |
44 def model_pkls(self) -> List[str]: | 59 def model_pkls(self) -> List[str]: |
45 return [f'{self.settings.workdir}/result_model_{i}.pkl' | 60 ext = '.pkl' |
46 for i in range(1, 6)] | 61 if self.settings.is_multimer: |
62 ext = '_multimer.pkl' | |
63 return [ | |
64 f'{self.settings.workdir}/result_model_{i}{ext}' | |
65 for i in range(1, 6) | |
66 ] | |
47 | 67 |
48 @property | 68 @property |
49 def model_conf_score_output(self) -> str: | 69 def model_conf_score_output(self) -> str: |
50 return f'{self.settings.workdir}/model_confidence_scores.tsv' | 70 return f'{self.settings.workdir}/model_confidence_scores.tsv' |
51 | 71 |
54 return f'{self.settings.workdir}/plddts.tsv' | 74 return f'{self.settings.workdir}/plddts.tsv' |
55 | 75 |
56 | 76 |
57 class FileLoader: | 77 class FileLoader: |
58 """loads file data for use by other classes""" | 78 """loads file data for use by other classes""" |
79 | |
59 def __init__(self, context: ExecutionContext): | 80 def __init__(self, context: ExecutionContext): |
60 self.context = context | 81 self.context = context |
61 | 82 |
83 @property | |
84 def confidence_key(self) -> str: | |
85 """Return the correct key for confidence data.""" | |
86 if self.context.settings.is_multimer: | |
87 return CONTEXT_KEY['multimer'] | |
88 return CONTEXT_KEY['monomer'] | |
89 | |
62 def get_model_mapping(self) -> Dict[str, int]: | 90 def get_model_mapping(self) -> Dict[str, int]: |
63 data = self.load_ranking_debug() | 91 data = self.load_ranking_debug() |
64 return {name: int(rank) + 1 | 92 return {name: int(rank) + 1 |
65 for (rank, name) in enumerate(data['order'])} | 93 for (rank, name) in enumerate(data['order'])} |
66 | 94 |
67 def get_conf_scores(self) -> Dict[str, float]: | 95 def get_conf_scores(self) -> Dict[str, float]: |
68 data = self.load_ranking_debug() | 96 data = self.load_ranking_debug() |
69 return {name: float(f'{score:.2f}') | 97 return { |
70 for name, score in data['plddts'].items()} | 98 name: float(f'{score:.2f}') |
99 for name, score in data[self.confidence_key].items() | |
100 } | |
71 | 101 |
72 def load_ranking_debug(self) -> Dict[str, Any]: | 102 def load_ranking_debug(self) -> Dict[str, Any]: |
73 with open(self.context.ranking_debug, 'r') as fp: | 103 with open(self.context.ranking_debug, 'r') as fp: |
74 return json.load(fp) | 104 return json.load(fp) |
75 | 105 |
76 def get_model_plddts(self) -> Dict[str, List[float]]: | 106 def get_model_plddts(self) -> Dict[str, List[float]]: |
77 plddts: Dict[str, List[float]] = {} | 107 plddts: Dict[str, List[float]] = {} |
78 model_pkls = self.context.model_pkls | 108 model_pkls = self.context.model_pkls |
79 for i in range(5): | 109 for i in range(len(model_pkls)): |
80 pklfile = model_pkls[i] | 110 pklfile = model_pkls[i] |
81 with open(pklfile, 'rb') as fp: | 111 with open(pklfile, 'rb') as fp: |
82 data = pickle.load(fp) | 112 data = pickle.load(fp) |
83 plddts[f'model_{i+1}'] = [float(f'{x:.2f}') for x in data['plddt']] | 113 plddts[f'model_{i+1}'] = [ |
114 float(f'{x:.2f}') | |
115 for x in data['plddt'] | |
116 ] | |
84 return plddts | 117 return plddts |
85 | 118 |
86 | 119 |
87 class OutputGenerator: | 120 class OutputGenerator: |
88 """generates the output data we are interested in creating""" | 121 """generates the output data we are interested in creating""" |
92 def gen_conf_scores(self): | 125 def gen_conf_scores(self): |
93 mapping = self.loader.get_model_mapping() | 126 mapping = self.loader.get_model_mapping() |
94 scores = self.loader.get_conf_scores() | 127 scores = self.loader.get_conf_scores() |
95 ranked = list(scores.items()) | 128 ranked = list(scores.items()) |
96 ranked.sort(key=lambda x: x[1], reverse=True) | 129 ranked.sort(key=lambda x: x[1], reverse=True) |
97 return {f'model_{mapping[name]}': score | 130 return {f'model_{mapping[name]}': score |
98 for name, score in ranked} | 131 for name, score in ranked} |
99 | 132 |
100 def gen_residue_scores(self) -> Dict[str, List[float]]: | 133 def gen_residue_scores(self) -> Dict[str, List[float]]: |
101 mapping = self.loader.get_model_mapping() | 134 mapping = self.loader.get_model_mapping() |
102 model_plddts = self.loader.get_model_plddts() | 135 model_plddts = self.loader.get_model_plddts() |
103 return {f'model_{mapping[name]}': plddts | 136 return {f'model_{mapping[name]}': plddts |
104 for name, plddts in model_plddts.items()} | 137 for name, plddts in model_plddts.items()} |
105 | 138 |
106 | 139 |
107 class OutputWriter: | 140 class OutputWriter: |
108 """writes generated data to files""" | 141 """writes generated data to files""" |
112 def write_conf_scores(self, data: Dict[str, float]) -> None: | 145 def write_conf_scores(self, data: Dict[str, float]) -> None: |
113 outfile = self.context.model_conf_score_output | 146 outfile = self.context.model_conf_score_output |
114 with open(outfile, 'w') as fp: | 147 with open(outfile, 'w') as fp: |
115 for model, score in data.items(): | 148 for model, score in data.items(): |
116 fp.write(f'{model}\t{score}\n') | 149 fp.write(f'{model}\t{score}\n') |
117 | 150 |
118 def write_residue_scores(self, data: Dict[str, List[float]]) -> None: | 151 def write_residue_scores(self, data: Dict[str, List[float]]) -> None: |
119 outfile = self.context.plddt_output | 152 outfile = self.context.plddt_output |
120 model_plddts = list(data.items()) | 153 model_plddts = list(data.items()) |
121 model_plddts.sort() | 154 model_plddts.sort() |
122 | 155 |
131 # setup | 164 # setup |
132 settings = Settings() | 165 settings = Settings() |
133 settings.parse_settings() | 166 settings.parse_settings() |
134 context = ExecutionContext(settings) | 167 context = ExecutionContext(settings) |
135 loader = FileLoader(context) | 168 loader = FileLoader(context) |
136 | 169 |
137 # generate & write outputs | 170 # generate & write outputs |
138 generator = OutputGenerator(loader) | 171 generator = OutputGenerator(loader) |
139 writer = OutputWriter(context) | 172 writer = OutputWriter(context) |
140 | 173 |
141 # confidence scores | 174 # confidence scores |
142 conf_scores = generator.gen_conf_scores() | 175 conf_scores = generator.gen_conf_scores() |
143 writer.write_conf_scores(conf_scores) | 176 writer.write_conf_scores(conf_scores) |
144 | 177 |
145 # per-residue plddts | 178 # per-residue plddts |
146 if settings.output_residue_scores: | 179 if settings.output_residue_scores: |
147 residue_scores = generator.gen_residue_scores() | 180 residue_scores = generator.gen_residue_scores() |
148 writer.write_residue_scores(residue_scores) | 181 writer.write_residue_scores(residue_scores) |
149 | 182 |
150 | 183 |
151 if __name__ == '__main__': | 184 if __name__ == '__main__': |
152 main() | 185 main() |
153 | |
154 | |
155 |
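
Below, a minimal sketch of how the new multimer mode behaves, assuming the classes above are importable and using a hypothetical AlphaFold output directory; it drives Settings through argv exactly as the command line would (e.g. python gen_extra_outputs.py /path/to/alphafold/output --multimer --plddts):

    import sys

    # hypothetical argv for illustration only; the path is a placeholder
    sys.argv = ['gen_extra_outputs.py', '/path/to/alphafold/output', '--multimer']

    settings = Settings()
    settings.parse_settings()
    context = ExecutionContext(settings)

    # with --multimer the pickle filenames gain the '_multimer' suffix
    print(context.model_pkls)
    # ['/path/to/alphafold/output/result_model_1_multimer.pkl', ...
    #  '/path/to/alphafold/output/result_model_5_multimer.pkl']

    # ...and ranking scores are read from 'iptm+ptm' instead of 'plddts'
    print(FileLoader(context).confidence_key)  # 'iptm+ptm'

Nothing in this sketch touches the filesystem, so it illustrates the path and key selection without needing a real run directory.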