comparison gen_extra_outputs.py @ 9:3bd420ec162d draft

planemo upload for repository https://github.com/usegalaxy-au/tools-au commit 7726c3cba165bdc8fc6366ec0ce6596e55657468
author galaxy-australia
date Tue, 13 Sep 2022 22:04:12 +0000
parents 7ae9d78b06f5
children c0e71cb2bd1b
comparison
equal deleted inserted replaced
8:ca90d17ff51b 9:3bd420ec162d
1 1 """Generate additional output files not produced by AlphaFold."""
2 2
3 import json 3 import json
4 import pickle 4 import pickle
5 import argparse 5 import argparse
6 from typing import Any, Dict, List 6 from typing import Any, Dict, List
7
# Keys for accessing confidence data from JSON/pkl files.
# They change depending on whether the run was monomer or multimer:
# monomer runs expose per-model scores under 'plddts', multimer runs
# under 'iptm+ptm' (see FileLoader.confidence_key for the selection).
CONTEXT_KEY = {
    'monomer': 'plddts',
    'multimer': 'iptm+ptm',
}
7 14
8 15
class Settings:
    """Parses then keeps track of program settings."""

    def __init__(self):
        # AlphaFold output directory; populated by parse_settings().
        self.workdir = None
        # Model-level confidence scores TSV is always written.
        self.output_confidence_scores = True
        # Per-residue pLDDT TSV is written only when --plddts is passed.
        self.output_residue_scores = False
        # Parse multimer-style outputs only when --multimer is passed.
        self.is_multimer = False

    def parse_settings(self) -> None:
        """Read command-line arguments into this Settings instance."""
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "workdir",
            help="alphafold output directory",
            type=str
        )
        parser.add_argument(
            "-p",
            "--plddts",
            help="output per-residue confidence scores (pLDDTs)",
            action="store_true"
        )
        parser.add_argument(
            "--multimer",
            help="parse output from AlphaFold multimer",
            action="store_true"
        )
        args = parser.parse_args()
        # Strip a trailing slash so later f-string path joins don't double up.
        self.workdir = args.workdir.rstrip('/')
        self.output_residue_scores = args.plddts
        # Fixed: removed the redundant dead store `self.is_multimer = False`
        # that immediately preceded this assignment.
        self.is_multimer = args.multimer
33 48
34 class ExecutionContext: 49 class ExecutionContext:
35 """uses program settings to get paths to files etc""" 50 """uses program settings to get paths to files etc"""
36 def __init__(self, settings: Settings): 51 def __init__(self, settings: Settings):
40 def ranking_debug(self) -> str: 55 def ranking_debug(self) -> str:
41 return f'{self.settings.workdir}/ranking_debug.json' 56 return f'{self.settings.workdir}/ranking_debug.json'
42 57
43 @property 58 @property
44 def model_pkls(self) -> List[str]: 59 def model_pkls(self) -> List[str]:
45 return [f'{self.settings.workdir}/result_model_{i}.pkl' 60 ext = '.pkl'
46 for i in range(1, 6)] 61 if self.settings.is_multimer:
62 ext = '_multimer.pkl'
63 return [
64 f'{self.settings.workdir}/result_model_{i}{ext}'
65 for i in range(1, 6)
66 ]
47 67
48 @property 68 @property
49 def model_conf_score_output(self) -> str: 69 def model_conf_score_output(self) -> str:
50 return f'{self.settings.workdir}/model_confidence_scores.tsv' 70 return f'{self.settings.workdir}/model_confidence_scores.tsv'
51 71
54 return f'{self.settings.workdir}/plddts.tsv' 74 return f'{self.settings.workdir}/plddts.tsv'
55 75
56 76
class FileLoader:
    """Loads file data for use by other classes."""

    def __init__(self, context: 'ExecutionContext'):
        self.context = context

    @property
    def confidence_key(self) -> str:
        """Return the ranking_debug key that holds confidence scores.

        AlphaFold names this key differently for monomer ('plddts') and
        multimer ('iptm+ptm') runs.
        """
        if self.context.settings.is_multimer:
            return CONTEXT_KEY['multimer']
        return CONTEXT_KEY['monomer']

    def get_model_mapping(self) -> Dict[str, int]:
        """Map model names to their 1-based rank from ranking_debug.json."""
        data = self.load_ranking_debug()
        return {name: int(rank) + 1
                for (rank, name) in enumerate(data['order'])}

    def get_conf_scores(self) -> Dict[str, float]:
        """Return per-model confidence scores rounded to 2 decimal places."""
        data = self.load_ranking_debug()
        return {
            name: float(f'{score:.2f}')
            for name, score in data[self.confidence_key].items()
        }

    def load_ranking_debug(self) -> Dict[str, Any]:
        """Load and return the contents of ranking_debug.json."""
        with open(self.context.ranking_debug, 'r') as fp:
            return json.load(fp)

    def get_model_plddts(self) -> Dict[str, List[float]]:
        """Load per-residue pLDDTs for each model from its pickle file.

        NOTE(review): pickle.load on the result files — assumed trusted
        since they are AlphaFold's own artifacts.
        """
        plddts: Dict[str, List[float]] = {}
        # enumerate(..., start=1) replaces indexing by range(len(...));
        # same iteration order and the same 1-based model numbering.
        for model_num, pklfile in enumerate(self.context.model_pkls, start=1):
            with open(pklfile, 'rb') as fp:
                data = pickle.load(fp)
            plddts[f'model_{model_num}'] = [
                float(f'{x:.2f}')
                for x in data['plddt']
            ]
        return plddts
86 119
87 class OutputGenerator: 120 class OutputGenerator:
88 """generates the output data we are interested in creating""" 121 """generates the output data we are interested in creating"""
92 def gen_conf_scores(self): 125 def gen_conf_scores(self):
93 mapping = self.loader.get_model_mapping() 126 mapping = self.loader.get_model_mapping()
94 scores = self.loader.get_conf_scores() 127 scores = self.loader.get_conf_scores()
95 ranked = list(scores.items()) 128 ranked = list(scores.items())
96 ranked.sort(key=lambda x: x[1], reverse=True) 129 ranked.sort(key=lambda x: x[1], reverse=True)
97 return {f'model_{mapping[name]}': score 130 return {f'model_{mapping[name]}': score
98 for name, score in ranked} 131 for name, score in ranked}
99 132
100 def gen_residue_scores(self) -> Dict[str, List[float]]: 133 def gen_residue_scores(self) -> Dict[str, List[float]]:
101 mapping = self.loader.get_model_mapping() 134 mapping = self.loader.get_model_mapping()
102 model_plddts = self.loader.get_model_plddts() 135 model_plddts = self.loader.get_model_plddts()
103 return {f'model_{mapping[name]}': plddts 136 return {f'model_{mapping[name]}': plddts
104 for name, plddts in model_plddts.items()} 137 for name, plddts in model_plddts.items()}
105 138
106 139
107 class OutputWriter: 140 class OutputWriter:
108 """writes generated data to files""" 141 """writes generated data to files"""
112 def write_conf_scores(self, data: Dict[str, float]) -> None: 145 def write_conf_scores(self, data: Dict[str, float]) -> None:
113 outfile = self.context.model_conf_score_output 146 outfile = self.context.model_conf_score_output
114 with open(outfile, 'w') as fp: 147 with open(outfile, 'w') as fp:
115 for model, score in data.items(): 148 for model, score in data.items():
116 fp.write(f'{model}\t{score}\n') 149 fp.write(f'{model}\t{score}\n')
117 150
118 def write_residue_scores(self, data: Dict[str, List[float]]) -> None: 151 def write_residue_scores(self, data: Dict[str, List[float]]) -> None:
119 outfile = self.context.plddt_output 152 outfile = self.context.plddt_output
120 model_plddts = list(data.items()) 153 model_plddts = list(data.items())
121 model_plddts.sort() 154 model_plddts.sort()
122 155
131 # setup 164 # setup
132 settings = Settings() 165 settings = Settings()
133 settings.parse_settings() 166 settings.parse_settings()
134 context = ExecutionContext(settings) 167 context = ExecutionContext(settings)
135 loader = FileLoader(context) 168 loader = FileLoader(context)
136 169
137 # generate & write outputs 170 # generate & write outputs
138 generator = OutputGenerator(loader) 171 generator = OutputGenerator(loader)
139 writer = OutputWriter(context) 172 writer = OutputWriter(context)
140 173
141 # confidence scores 174 # confidence scores
142 conf_scores = generator.gen_conf_scores() 175 conf_scores = generator.gen_conf_scores()
143 writer.write_conf_scores(conf_scores) 176 writer.write_conf_scores(conf_scores)
144 177
145 # per-residue plddts 178 # per-residue plddts
146 if settings.output_residue_scores: 179 if settings.output_residue_scores:
147 residue_scores = generator.gen_residue_scores() 180 residue_scores = generator.gen_residue_scores()
148 writer.write_residue_scores(residue_scores) 181 writer.write_residue_scores(residue_scores)
149 182
150 183
if __name__ == '__main__':
    # Run only when executed as a script, not when imported as a module.
    main()
153
154
155