changeset 21:e7f1b552a695 draft

planemo upload for repository https://github.com/usegalaxy-au/tools-au commit 628c9fdcb77489063145a2307b6bb6a450416dd6-dirty
author galaxy-australia
date Tue, 29 Oct 2024 02:15:36 +0000
parents 6ab1a261520a
children 3f188450ca4f
files alphafold.xml macro_output.xml scripts/outputs.py
diffstat 3 files changed, 141 insertions(+), 38 deletions(-) [+]
line wrap: on
line diff
--- a/alphafold.xml	Sun Jul 28 20:09:55 2024 +0000
+++ b/alphafold.xml	Tue Oct 29 02:15:36 2024 +0000
@@ -3,7 +3,7 @@
     <macros>
       <token name="@TOOL_VERSION@">2.3.2</token>
       <token name="@TOOL_MINOR_VERSION@">2.3</token>
-      <token name="@VERSION_SUFFIX@">0</token>
+      <token name="@VERSION_SUFFIX@">1</token>
       <import>macro_output.xml</import>
       <import>macro_test_output.xml</import>
     </macros>
@@ -118,9 +118,7 @@
 $outputs.model_pkls
 $outputs.pae_csv
 $outputs.plots
-#if $model_preset.selection == 'multimer':
---multimer
-#end if
+$outputs.plot_msa
 
 ## HTML output
 && mkdir -p '${ html.files_path }'
@@ -135,7 +133,7 @@
     ]]></command>
     <inputs>
         <conditional name="fasta_or_text">
-            <param name="input_mode" type="select" label="Fasta Input" help="Protein sequence(s) to fold. Input can be fasta file from history, or text. Sequence must be valid IUPAC amino acid characters. If multiple-sequence FASTA file provided, multimer mode must be selected.">
+            <param name="input_mode" type="select" label="Fasta Input" help="Protein sequence(s) to fold. Input can be fasta file from history, or text. Sequence must be valid IUPAC amino acid characters. We recommend submitting sequences with a maximum length of 3000AA, because run time scales exponentially with sequence length. If multiple-sequence FASTA file provided, multimer mode must be selected.">
                 <option value="history">Use fasta from history</option>
                 <option value="textbox">Paste sequence into textbox</option>
             </param>
@@ -236,18 +234,27 @@
                 help="A two-panel plot showing pLDDT against residue position (left) and PAE (paired-alignment error) as a heatmap image with residue numbers running along vertical and horizontal axes and color at each pixel indicating PAE value for the corresponding pair of residues. (right). PAE heatmap is only produced with monomer_ptm and multimer model presets."
             />
             <param
+                name="plot_msa"
+                type="boolean"
+                checked="false"
+                truevalue="--plot-msa"
+                falsevalue=""
+                label="MSA sequence coverage plot"
+                help="A heatmap showing sequence coverage across the multiple sequence alignment (MSA). This plot can help you understand if regions of low confidence are due to poor sequence coverage."
+            />
+            <param
                 name="confidence_scores"
                 type="boolean"
                 checked="false"
                 label="Per-model confidence scores"
-                help="A tabular file showing average confidence score for each model (predicted template modelling (PTM) score; interface PTM is incorporated into this score for multimer predictions)."
+                help="A tabular file showing average confidence score for each model. The monomer preset is scored in pLDDT, the monomer_ptm preset is scored in predicted template modelling (PTM) and the multimer preset is scored in PTM+IPTM (interface PTM)."
             />
             <param
                 name="plddts"
                 type="boolean"
                 checked="false"
                 label="Per-residue confidence scores"
-                truevalue="--plddts"
+                truevalue="--confidence-scores"
                 falsevalue=""
                 help="Alphafold produces a pLDDT score between 0-100 for each residue in the folded models. High scores represent high confidence in placement for the residue, while low scoring residues have lower confidence. This output is a tabular file with five rows (one for each output PDB model), with each column providing a pLDDT score for a single residue."
             />
@@ -291,6 +298,7 @@
         <data name="html" format="html" label="${tool.name} on ${on_string}: Visualization" />
         <!-- Optional outputs -->
         <expand macro="output_plddts" />
+        <expand macro="output_msa_plot" />
         <expand macro="output_confidence_scores" />
         <expand macro="output_pickles" />
         <expand macro="output_pae_csv" />
@@ -464,6 +472,17 @@
     |
     |
 
+    *MSA sequence coverage plot (optional)*
+
+    | A plot in PNG format showing:
+    | a) Per-position sequence identity to query as a heatmap
+    | b) Per-position sequence coverage as a line plot
+    |
+    | This plot can help you understand if regions of low confidence are due to poor sequence coverage, rather than
+    | limitations of the model or intrinsically unstable regions.
+    |
+    |
+
     *Model predicted-alignment error matrix (pae_ranked_n.csv)*
 
     | Per-model predicted-alignment error (PAE) matrix - only available with the ``monomer_ptm`` and ``multimer`` model presets.
--- a/macro_output.xml	Sun Jul 28 20:09:55 2024 +0000
+++ b/macro_output.xml	Tue Oct 29 02:15:36 2024 +0000
@@ -172,6 +172,17 @@
         </data>
     </xml>
 
+    <xml name="output_msa_plot">
+        <data
+            name="output_msa_plot"
+            format="png"
+            from_work_dir="output/alphafold/extra/msa_coverage.png"
+            label="${tool.name} on ${on_string}: MSA plot"
+        >
+            <filter>outputs['plot_msa']</filter>
+        </data>
+    </xml>
+
     <xml name="output_plddts">
         <data
             name="output_plddts"
--- a/scripts/outputs.py	Sun Jul 28 20:09:55 2024 +0000
+++ b/scripts/outputs.py	Tue Oct 29 02:15:36 2024 +0000
@@ -16,11 +16,12 @@
 
 import argparse
 import json
+import numpy as np
 import os
 import pickle as pk
 import shutil
 from pathlib import Path
-from typing import List
+from typing import Dict, List
 
 from matplotlib import pyplot as plt
 
@@ -33,13 +34,7 @@
     'model_confidence_scores': OUTPUT_DIR + '/model_confidence_scores.tsv',
     'plddts': OUTPUT_DIR + '/plddts.tsv',
     'relax': OUTPUT_DIR + '/relax_metrics_ranked.json',
-}
-
-# Keys for accessing confidence data from JSON/pkl files
-# They change depending on whether the run was monomer or multimer
-PLDDT_KEY = {
-    'monomer': 'plddts',
-    'multimer': 'iptm+ptm',
+    'msa': OUTPUT_DIR + '/msa_coverage.png',
 }
 
 HTML_PATH = Path(__file__).parent / "alphafold.html"
@@ -49,6 +44,20 @@
     'class="btn disabled" id="btn-ranked_{rank}" disabled')
 
 
+class PLDDT_KEY:
+    """Dict keys for accessing confidence data from JSON/pkl files.
+    Changes depending on which model preset was used.
+    """
+    monomer = 'plddts'
+    multimer = 'iptm+ptm'
+
+
+class PRESETS:  # model preset names, as returned by _sniff_model_preset
+    monomer = 'monomer'
+    monomer_ptm = 'monomer_ptm'
+    multimer = 'multimer'
+
+
 class Settings:
     """Parse and store settings/config."""
     def __init__(self):
@@ -63,54 +72,63 @@
         parser.add_argument(
             "workdir",
             help="alphafold output directory",
-            type=str
+            type=str,
         )
         parser.add_argument(
-            "-p",
-            "--plddts",
+            "-s",
+            "--confidence-scores",
             help="output per-residue confidence scores (pLDDTs)",
-            action="store_true"
-        )
-        parser.add_argument(
-            "-m",
-            "--multimer",
-            help="parse output from AlphaFold multimer",
-            action="store_true"
+            action="store_true",
         )
         parser.add_argument(
             "--pkl",
             help="rename model pkl outputs with rank order",
-            action="store_true"
+            action="store_true",
         )
         parser.add_argument(
             "--pae",
             help="extract PAE from pkl files to CSV format",
-            action="store_true"
+            action="store_true",
         )
         parser.add_argument(
             "--plot",
             help="Plot pLDDT and PAE for each model",
-            action="store_true"
+            action="store_true",
+        )
+        parser.add_argument(
+            "--plot-msa",
+            help="Plot multiple-sequence alignment coverage as a heatmap",
+            action="store_true",
         )
         args = parser.parse_args()
         self.workdir = Path(args.workdir.rstrip('/'))
-        self.output_residue_scores = args.plddts
+        self.output_residue_scores = args.confidence_scores
         self.output_model_pkls = args.pkl
         self.output_model_plots = args.plot
         self.output_pae = args.pae
-        self.is_multimer = args.multimer
+        self.plot_msa = args.plot_msa
+        self.model_preset = self._sniff_model_preset()
         self.output_dir = self.workdir / OUTPUT_DIR
         os.makedirs(self.output_dir, exist_ok=True)
 
+    def _sniff_model_preset(self) -> str:
+        """Return the PRESETS name sniffed from relax_metrics.json."""
+        # Read the file once: a second f.read() on the same handle returns
+        # '' so the '_ptm_' check could never match.
+        text = (self.workdir / 'relax_metrics.json').read_text()
+        if '_multimer_' in text:
+            return PRESETS.multimer
+        return PRESETS.monomer_ptm if '_ptm_' in text else PRESETS.monomer
+
 
 class ExecutionContext:
     """Collect file paths etc."""
     def __init__(self, settings: Settings):
         self.settings = settings
-        if settings.is_multimer:
-            self.plddt_key = PLDDT_KEY['multimer']
+        if settings.model_preset == PRESETS.multimer:
+            self.plddt_key = PLDDT_KEY.multimer
         else:
-            self.plddt_key = PLDDT_KEY['monomer']
+            self.plddt_key = PLDDT_KEY.monomer  # monomer and monomer_ptm
 
     def get_model_key(self, ix: int) -> str:
         """Return json key for model index.
@@ -192,11 +210,30 @@
 
 def write_confidence_scores(ranking: ResultRanking, context: ExecutionContext):
     """Write per-model confidence scores."""
-    path = context.settings.workdir / OUTPUTS['model_confidence_scores']
-    with open(path, 'w') as f:
-        for rank in range(1, len(context.model_pkl_paths) + 1):
-            score = ranking.get_plddt_for_rank(rank)
-            f.write(f'ranked_{rank - 1}\t{score:.2f}\n')
+    outfile = context.settings.workdir / OUTPUTS['model_confidence_scores']
+    scores: Dict[int, list] = {}  # keyed by the int parsed from the pkl name
+    header = ['model', context.plddt_key]
+
+    for i, path in enumerate(context.model_pkl_paths):
+        rank = int(path.name.split('model_')[-1][0])  # first digit only
+        scores_ls = [ranking.get_plddt_for_rank(rank)]
+        with open(path, 'rb') as f:
+            data = pk.load(f)
+        if 'ptm' in data:
+            scores_ls.append(data['ptm'])
+            if i == 0:  # header columns come from the first pkl only
+                header += ['ptm']
+        if 'iptm' in data:
+            scores_ls.append(data['iptm'])
+            if i == 0:
+                header += ['iptm']
+        scores[rank] = scores_ls  # a repeated rank overwrites the earlier row
+
+    with open(outfile, 'w') as f:
+        f.write('\t'.join(header) + '\n')
+        for rank, score_ls in scores.items():
+            row = [f"ranked_{rank - 1}"] + [str(x) for x in score_ls]
+            f.write('\t'.join(row) + '\n')
 
 
 def write_per_residue_scores(
@@ -304,6 +341,40 @@
             plt.ylabel('Aligned residue')
 
         plt.savefig(png_path)
+        plt.close()
+
+
+def plot_msa(wdir: Path, dpi: int = 150):
+    """Save an MSA sequence-coverage heatmap to wdir / OUTPUTS['msa'].
+
+    Rows are sorted by identity to the first MSA row; gap cells become NaN.
+    """
+    with open(wdir / 'features.pkl', 'rb') as handle:
+        msa = pk.load(handle).get('msa')
+    if msa is None:
+        print("Could not plot MSA coverage - 'msa' key not found in"
+              " features.pkl")
+        return
+
+    gap_id = 21  # code compared against to find gaps
+    identity = np.array(msa[0] == msa).mean(-1)
+    order = identity.argsort()
+    present = msa != gap_id
+    heat = np.where(present, 1.0, np.nan)[order] * identity[order, None]
+
+    plt.figure(figsize=(6, 4))
+    plt.title("Sequence coverage")
+    plt.imshow(heat, interpolation='nearest', aspect='auto',
+               cmap="rainbow_r", vmin=0, vmax=1, origin='lower')
+    plt.plot(present.sum(0), color='black')  # per-position non-gap count
+    plt.xlim(-0.5, msa.shape[1] - 0.5)
+    plt.ylim(-0.5, msa.shape[0] - 0.5)
+    plt.colorbar(label="Sequence identity to query")
+    plt.xlabel("Positions")
+    plt.ylabel("Sequences")
+    plt.tight_layout()
+    plt.savefig(wdir / OUTPUTS['msa'], dpi=dpi)
+    plt.close()
 
 
 def template_html(context: ExecutionContext):
@@ -341,6 +412,8 @@
         extract_pae_to_csv(ranking, context)
     if settings.output_residue_scores:
         write_per_residue_scores(ranking, context)
+    if settings.plot_msa:
+        plot_msa(context.settings.workdir)
 
 
 if __name__ == '__main__':