diff gen_extra_outputs.py @ 13:c0e71cb2bd1b draft

planemo upload for repository https://github.com/usegalaxy-au/tools-au commit 830b5bbf9c5375e714e1b7b9a3e8eec1e584e6b2
author galaxy-australia
date Wed, 12 Oct 2022 22:25:20 +0000
parents 3bd420ec162d
children
line wrap: on
line diff
--- a/gen_extra_outputs.py	Fri Sep 16 06:14:06 2022 +0000
+++ b/gen_extra_outputs.py	Wed Oct 12 22:25:20 2022 +0000
@@ -35,6 +35,7 @@
             action="store_true"
         )
         parser.add_argument(
+            "-m",
             "--multimer",
             help="parse output from AlphaFold multimer",
             action="store_true"
@@ -51,6 +52,12 @@
     def __init__(self, settings: Settings):
         self.settings = settings
 
+    def get_model_key(self, ix):
+        """Return json key for model index."""
+        if self.settings.is_multimer:
+            return f'model_{ix}_multimer'
+        return f'model_{ix}'
+
     @property
     def ranking_debug(self) -> str:
         return f'{self.settings.workdir}/ranking_debug.json'
@@ -110,7 +117,7 @@
             pklfile = model_pkls[i]
             with open(pklfile, 'rb') as fp:
                 data = pickle.load(fp)
-            plddts[f'model_{i+1}'] = [
+            plddts[self.context.get_model_key(i+1)] = [
                 float(f'{x:.2f}')
                 for x in data['plddt']
             ]
@@ -121,20 +128,25 @@
     """generates the output data we are interested in creating"""
     def __init__(self, loader: FileLoader):
         self.loader = loader
+        self.context = loader.context
 
     def gen_conf_scores(self):
         mapping = self.loader.get_model_mapping()
         scores = self.loader.get_conf_scores()
         ranked = list(scores.items())
         ranked.sort(key=lambda x: x[1], reverse=True)
-        return {f'model_{mapping[name]}': score
-                for name, score in ranked}
+        return {
+            self.context.get_model_key(mapping[name]): score
+            for name, score in ranked
+        }
 
     def gen_residue_scores(self) -> Dict[str, List[float]]:
         mapping = self.loader.get_model_mapping()
         model_plddts = self.loader.get_model_plddts()
-        return {f'model_{mapping[name]}': plddts
-                for name, plddts in model_plddts.items()}
+        return {
+            self.context.get_model_key(mapping[name]): plddts
+            for name, plddts in model_plddts.items()
+        }
 
 
 class OutputWriter: