diff gen_extra_outputs.py @ 9:3bd420ec162d draft
planemo upload for repository https://github.com/usegalaxy-au/tools-au commit 7726c3cba165bdc8fc6366ec0ce6596e55657468
author   | galaxy-australia
date     | Tue, 13 Sep 2022 22:04:12 +0000
parents  | 7ae9d78b06f5
children | c0e71cb2bd1b
--- a/gen_extra_outputs.py	Fri Aug 19 00:29:16 2022 +0000
+++ b/gen_extra_outputs.py	Tue Sep 13 22:04:12 2022 +0000
@@ -1,10 +1,17 @@
-
+"""Generate additional output files not produced by AlphaFold."""
 
 import json
 import pickle
 import argparse
 from typing import Any, Dict, List
 
+# Keys for accessing confidence data from JSON/pkl files
+# They change depending on whether the run was monomer or multimer
+CONTEXT_KEY = {
+    'monomer': 'plddts',
+    'multimer': 'iptm+ptm',
+}
+
 
 class Settings:
     """parses then keeps track of program settings"""
@@ -12,23 +19,31 @@
         self.workdir = None
         self.output_confidence_scores = True
         self.output_residue_scores = False
+        self.is_multimer = False
 
     def parse_settings(self) -> None:
         parser = argparse.ArgumentParser()
         parser.add_argument(
-            "workdir", 
-            help="alphafold output directory", 
+            "workdir",
+            help="alphafold output directory",
             type=str
-            )
+        )
         parser.add_argument(
             "-p",
             "--plddts",
-            help="output per-residue confidence scores (pLDDTs)", 
+            help="output per-residue confidence scores (pLDDTs)",
+            action="store_true"
+        )
+        parser.add_argument(
+            "--multimer",
+            help="parse output from AlphaFold multimer",
             action="store_true"
         )
         args = parser.parse_args()
         self.workdir = args.workdir.rstrip('/')
         self.output_residue_scores = args.plddts
+        self.is_multimer = False
+        self.is_multimer = args.multimer
 
 
 class ExecutionContext:
@@ -42,8 +57,13 @@
 
     @property
     def model_pkls(self) -> List[str]:
-        return [f'{self.settings.workdir}/result_model_{i}.pkl'
-                for i in range(1, 6)]
+        ext = '.pkl'
+        if self.settings.is_multimer:
+            ext = '_multimer.pkl'
+        return [
+            f'{self.settings.workdir}/result_model_{i}{ext}'
+            for i in range(1, 6)
+        ]
 
     @property
     def model_conf_score_output(self) -> str:
@@ -56,18 +76,28 @@
 
 class FileLoader:
     """loads file data for use by other classes"""
+
     def __init__(self, context: ExecutionContext):
         self.context = context
 
+    @property
+    def confidence_key(self) -> str:
+        """Return the correct key for confidence data."""
+        if self.context.settings.is_multimer:
+            return CONTEXT_KEY['multimer']
+        return CONTEXT_KEY['monomer']
+
     def get_model_mapping(self) -> Dict[str, int]:
         data = self.load_ranking_debug()
-        return {name: int(rank) + 1 
+        return {name: int(rank) + 1
                 for (rank, name) in enumerate(data['order'])}
 
     def get_conf_scores(self) -> Dict[str, float]:
         data = self.load_ranking_debug()
-        return {name: float(f'{score:.2f}')
-                for name, score in data['plddts'].items()}
+        return {
+            name: float(f'{score:.2f}')
+            for name, score in data[self.confidence_key].items()
+        }
 
     def load_ranking_debug(self) -> Dict[str, Any]:
         with open(self.context.ranking_debug, 'r') as fp:
@@ -76,11 +106,14 @@
     def get_model_plddts(self) -> Dict[str, List[float]]:
         plddts: Dict[str, List[float]] = {}
         model_pkls = self.context.model_pkls
-        for i in range(5):
+        for i in range(len(model_pkls)):
             pklfile = model_pkls[i]
             with open(pklfile, 'rb') as fp:
                 data = pickle.load(fp)
-                plddts[f'model_{i+1}'] = [float(f'{x:.2f}') for x in data['plddt']]
+                plddts[f'model_{i+1}'] = [
+                    float(f'{x:.2f}')
+                    for x in data['plddt']
+                ]
         return plddts
 
 
@@ -94,13 +127,13 @@
         scores = self.loader.get_conf_scores()
         ranked = list(scores.items())
         ranked.sort(key=lambda x: x[1], reverse=True)
-        return {f'model_{mapping[name]}': score 
+        return {f'model_{mapping[name]}': score
                 for name, score in ranked}
 
     def gen_residue_scores(self) -> Dict[str, List[float]]:
         mapping = self.loader.get_model_mapping()
         model_plddts = self.loader.get_model_plddts()
-        return {f'model_{mapping[name]}': plddts 
+        return {f'model_{mapping[name]}': plddts
                 for name, plddts in model_plddts.items()}
 
 
@@ -114,7 +147,7 @@
         with open(outfile, 'w') as fp:
             for model, score in data.items():
                 fp.write(f'{model}\t{score}\n')
-    
+
     def write_residue_scores(self, data: Dict[str, List[float]]) -> None:
         outfile = self.context.plddt_output
         model_plddts = list(data.items())
@@ -133,23 +166,20 @@
     settings.parse_settings()
     context = ExecutionContext(settings)
    loader = FileLoader(context)
-    
+
     # generate & write outputs
     generator = OutputGenerator(loader)
     writer = OutputWriter(context)
-    
+
     # confidence scores
     conf_scores = generator.gen_conf_scores()
     writer.write_conf_scores(conf_scores)
-    
+
     # per-residue plddts
    if settings.output_residue_scores:
         residue_scores = generator.gen_residue_scores()
         writer.write_residue_scores(residue_scores)
-    
+
 
 if __name__ == '__main__':
     main()
-
-
-
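
A minimal usage sketch of the script after this change. It assumes AlphaFold's standard output layout in the working directory (a ranking_debug.json plus five result_model pickles); the exact filenames are inferred from CONTEXT_KEY and model_pkls above, and the paths shown are placeholders.

    # Assumed inputs, per CONTEXT_KEY and model_pkls:
    #   monomer:  <workdir>/ranking_debug.json scores keyed by 'plddts';
    #             pickles named result_model_1.pkl .. result_model_5.pkl
    #   multimer: scores keyed by 'iptm+ptm'; pickles named
    #             result_model_1_multimer.pkl .. result_model_5_multimer.pkl

    # monomer run, with per-residue pLDDT output enabled
    python gen_extra_outputs.py /path/to/alphafold_output --plddts

    # multimer run: the new flag switches both the confidence-score key
    # and the pickle filename suffix
    python gen_extra_outputs.py /path/to/alphafold_output --plddts --multimer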