diff scripts/outputs.py @ 16:f9eb041c518c draft

planemo upload for repository https://github.com/usegalaxy-au/tools-au commit ee77734f1800350fa2a6ef28b2b8eade304a456f-dirty
author galaxy-australia
date Mon, 03 Apr 2023 01:00:42 +0000
parents
children 6ab1a261520a
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/scripts/outputs.py	Mon Apr 03 01:00:42 2023 +0000
@@ -0,0 +1,324 @@
+"""Generate additional output files not produced by AlphaFold.
+
+Currently this includes:
+- model confidence scores
+- per-residue confidence scores (pLDDTs - optional)
+- model_*.pkl files renamed with rank order (optional)
+- PAE matrices extracted to CSV (optional)
+- pLDDT/PAE plots (optional)
+- relax metrics re-keyed by rank
+
+N.B. This script has broken between AlphaFold releases because the output
+directory structure changes slightly across minor versions. It will likely
+need updating with future releases of AlphaFold.
+
+This code is more complex than you might expect because the output file
+locations vary considerably with the run parameters, so several output
+paths have to be determined dynamically.
+"""
+
+import argparse
+import json
+import os
+import pickle as pk
+import shutil
+from pathlib import Path
+from typing import List
+
+from matplotlib import pyplot as plt
+
+# Output file paths
+OUTPUT_DIR = 'extra'
+OUTPUTS = {
+    'model_pkl': OUTPUT_DIR + '/ranked_{rank}.pkl',
+    'model_pae': OUTPUT_DIR + '/pae_ranked_{rank}.csv',
+    'model_plot': OUTPUT_DIR + '/ranked_{rank}.png',
+    'model_confidence_scores': OUTPUT_DIR + '/model_confidence_scores.tsv',
+    'plddts': OUTPUT_DIR + '/plddts.tsv',
+    'relax': OUTPUT_DIR + '/relax_metrics_ranked.json',
+}
+
+# Keys for accessing confidence data from JSON/pkl files
+# They change depending on whether the run was monomer or multimer
+PLDDT_KEY = {
+    'monomer': 'plddts',
+    'multimer': 'iptm+ptm',
+}
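+
+# Illustrative ranking_debug.json shapes (model names and values are
+# examples only):
+#   monomer:  {"plddts": {"model_1_pred_0": 87.2, ...}, "order": [...]}
+#   multimer: {"iptm+ptm": {"model_1_multimer_v3_pred_0": 0.82, ...},
+#              "order": [...]}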
+
+
+class Settings:
+    """Parse and store settings/config."""
+    def __init__(self):
+        self.workdir = None
+        self.output_confidence_scores = True
+        self.output_residue_scores = False
+        self.is_multimer = False
+        self.parse()
+
+    def parse(self) -> None:
+        parser = argparse.ArgumentParser()
+        parser.add_argument(
+            "workdir",
+            help="alphafold output directory",
+            type=str
+        )
+        parser.add_argument(
+            "-p",
+            "--plddts",
+            help="output per-residue confidence scores (pLDDTs)",
+            action="store_true"
+        )
+        parser.add_argument(
+            "-m",
+            "--multimer",
+            help="parse output from AlphaFold multimer",
+            action="store_true"
+        )
+        parser.add_argument(
+            "--pkl",
+            help="rename model pkl outputs with rank order",
+            action="store_true"
+        )
+        parser.add_argument(
+            "--pae",
+            help="extract PAE from pkl files to CSV format",
+            action="store_true"
+        )
+        parser.add_argument(
+            "--plot",
+            help="Plot pLDDT and PAE for each model",
+            action="store_true"
+        )
+        args = parser.parse_args()
+        self.workdir = Path(args.workdir.rstrip('/'))
+        self.output_residue_scores = args.plddts
+        self.output_model_pkls = args.pkl
+        self.output_model_plots = args.plot
+        self.output_pae = args.pae
+        self.is_multimer = args.multimer
+        self.output_dir = self.workdir / OUTPUT_DIR
+        os.makedirs(self.output_dir, exist_ok=True)
+
+
+class ExecutionContext:
+    """Collect file paths etc."""
+    def __init__(self, settings: Settings):
+        self.settings = settings
+        if settings.is_multimer:
+            self.plddt_key = PLDDT_KEY['multimer']
+        else:
+            self.plddt_key = PLDDT_KEY['monomer']
+
+    def get_model_key(self, ix: int) -> str:
+        """Return json key for model index.
+
+        The key format changed between minor AlphaFold versions so this
+        function determines the correct key.
+        """
+        with open(self.ranking_debug) as f:
+            data = json.load(f)
+        model_keys = list(data[self.plddt_key].keys())
+        for k in model_keys:
+            if k.startswith(f"model_{ix}_"):
+                return k
+        raise KeyError(
+            f'Could not find key for index={ix} in'
+            ' ranking_debug.json')
+
+    @property
+    def ranking_debug(self) -> Path:
+        return self.settings.workdir / 'ranking_debug.json'
+
+    @property
+    def relax_metrics(self) -> Path:
+        return self.settings.workdir / 'relax_metrics.json'
+
+    @property
+    def relax_metrics_ranked(self) -> Path:
+        return self.settings.workdir / 'relax_metrics_ranked.json'
+
+    @property
+    def model_pkl_paths(self) -> List[Path]:
+        return sorted([
+            self.settings.workdir / f
+            for f in os.listdir(self.settings.workdir)
+            if f.startswith('result_model_') and f.endswith('.pkl')
+        ])
+
+
+class ResultModelPrediction:
+    """Load and manipulate data from result_model_*.pkl files."""
+    def __init__(self, path: Path, context: ExecutionContext):
+        self.context = context
+        self.path = path
+        self.name = os.path.basename(path).replace('result_', '').split('.')[0]
+        with open(path, 'rb') as f:
+            self.data = pk.load(f)
+
+    @property
+    def plddts(self) -> List[float]:
+        """Return pLDDT scores for each residue."""
+        return list(self.data['plddt'])
+
+
+class ResultRanking:
+    """Load and manipulate data from ranking_debug.json file."""
+
+    def __init__(self, context: ExecutionContext):
+        self.path = context.ranking_debug
+        self.context = context
+        with open(self.path, 'r') as f:
+            self.data = json.load(f)
+
+    @property
+    def order(self) -> List[str]:
+        """Return ordered list of model indexes."""
+        return self.data['order']
+
+    def get_plddt_for_rank(self, rank: int) -> float:
+        """Return the confidence score for the given 1-indexed rank."""
+        return self.data[self.context.plddt_key][self.data['order'][rank - 1]]
+
+    def get_rank_for_model(self, model_name: str) -> int:
+        """Return 0-indexed rank for given model name.
+
+        Model names are taken from the result_model_*.pkl file names.
+        """
+        return self.data['order'].index(model_name)
+
+
+def write_confidence_scores(ranking: ResultRanking, context: ExecutionContext):
+    """Write per-model confidence scores."""
+    path = context.settings.workdir / OUTPUTS['model_confidence_scores']
+    with open(path, 'w') as f:
+        for rank in range(1, 6):
+            score = ranking.get_plddt_for_rank(rank)
+            f.write(f'ranked_{rank - 1}\t{score:.2f}\n')
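+
+# Example model_confidence_scores.tsv content (tab-separated; values are
+# illustrative):
+#   ranked_0    92.41
+#   ranked_1    88.07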
+
+
+def write_per_residue_scores(
+    ranking: ResultRanking,
+    context: ExecutionContext,
+):
+    """Write per-residue plddts for each model.
+
+    A row of plddt values is written for each model in tabular format.
+    """
+    model_plddts = {}
+    for path in context.model_pkl_paths:
+        model = ResultModelPrediction(path, context)
+        rank = ranking.get_rank_for_model(model.name)
+        model_plddts[rank] = model.plddts
+
+    path = context.settings.workdir / OUTPUTS['plddts']
+    with open(path, 'w') as f:
+        for i in sorted(model_plddts):
+            row = [f'ranked_{i}'] + [
+                str(x) for x in model_plddts[i]
+            ]
+            f.write('\t'.join(row) + '\n')
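+
+# Example plddts.tsv row (tab-separated; values are illustrative):
+#   ranked_0    67.31    71.92    74.80    ...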
+
+
+def rename_model_pkls(ranking: ResultRanking, context: ExecutionContext):
+    """Rename model.pkl files so the rank order is implicit."""
+    for path in context.model_pkl_paths:
+        model = ResultModelPrediction(path, context)
+        rank = ranking.get_rank_for_model(model.name)
+        new_path = (
+            context.settings.workdir
+            / OUTPUTS['model_pkl'].format(rank=rank)
+        )
+        shutil.copyfile(path, new_path)
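+
+# e.g. result_model_3_multimer_v3_pred_0.pkl -> extra/ranked_0.pkl when that
+# model ranked first (file name illustrative).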
+
+
+def extract_pae_to_csv(ranking: ResultRanking, context: ExecutionContext):
+    """Extract predicted alignment error matrix from pickle files.
+
+    Creates a CSV file for each of five ranked models.
+    """
+    for path in context.model_pkl_paths:
+        model = ResultModelPrediction(path, context)
+        rank = ranking.get_rank_for_model(model.name)
+        with open(path, 'rb') as f:
+            data = pk.load(f)
+        if 'predicted_aligned_error' not in data:
+            print("Skipping PAE output"
+                  f" - not found in {path}."
+                  " Running with model_preset=monomer?")
+            return
+        pae = data['predicted_aligned_error']
+        out_path = (
+            context.settings.workdir
+            / OUTPUTS['model_pae'].format(rank=rank)
+        )
+        with open(out_path, 'w') as f:
+            for row in pae:
+                f.write(','.join([str(x) for x in row]) + '\n')
+
+
+def rekey_relax_metrics(ranking: ResultRanking, context: ExecutionContext):
+    """Replace keys in relax_metrics.json with 0-indexed rank."""
+    with open(context.relax_metrics) as f:
+        data = json.load(f)
+    for k in list(data.keys()):
+        rank = ranking.get_rank_for_model(k)
+        data[f'ranked_{rank}'] = data.pop(k)
+    new_path = context.settings.workdir / OUTPUTS['relax']
+    with open(new_path, 'w') as f:
+        json.dump(data, f)
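+
+# e.g. a relax_metrics.json key like "model_1_pred_0" becomes "ranked_2" if
+# that model ranked third (key name illustrative).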
+
+
+def plddt_pae_plots(ranking: ResultRanking, context: ExecutionContext):
+    """Generate a pLDDT + PAE plot for each model."""
+    for path in context.model_pkl_paths:
+        num_plots = 2
+        model = ResultModelPrediction(path, context)
+        rank = ranking.get_rank_for_model(model.name)
+        png_path = (
+            context.settings.workdir
+            / OUTPUTS['model_plot'].format(rank=rank)
+        )
+        plddts = model.data['plddt']
+        if 'predicted_aligned_error' in model.data:
+            pae = model.data['predicted_aligned_error']
+            max_pae = model.data['max_predicted_aligned_error']
+        else:
+            num_plots = 1
+
+        plt.figure(figsize=[8 * num_plots, 6])
+        plt.subplot(1, num_plots, 1)
+        plt.plot(plddts)
+        plt.title('Predicted LDDT')
+        plt.xlabel('Residue')
+        plt.ylabel('pLDDT')
+
+        if num_plots == 2:
+            plt.subplot(1, 2, 2)
+            plt.imshow(pae, vmin=0., vmax=max_pae, cmap='Greens_r')
+            plt.colorbar(fraction=0.046, pad=0.04)
+            plt.title('Predicted Aligned Error')
+            plt.xlabel('Scored residue')
+            plt.ylabel('Aligned residue')
+
+        plt.savefig(png_path)
+        plt.close()  # avoid accumulating open figures across the loop
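+
+# Each PNG contains a pLDDT line plot and, when PAE data is available, a
+# second panel with the PAE heatmap.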
+
+
+def main():
+    """Parse output files and generate additional output files."""
+    settings = Settings()
+    context = ExecutionContext(settings)
+    ranking = ResultRanking(context)
+    write_confidence_scores(ranking, context)
+    rekey_relax_metrics(ranking, context)
+
+    # Optional outputs
+    if settings.output_model_pkls:
+        rename_model_pkls(ranking, context)
+    if settings.output_model_plots:
+        plddt_pae_plots(ranking, context)
+    if settings.output_pae:
+        # Only created by monomer_ptm and multimer models
+        extract_pae_to_csv(ranking, context)
+    if settings.output_residue_scores:
+        write_per_residue_scores(ranking, context)
+
+
+if __name__ == '__main__':
+    main()