diff predict.py @ 0:457fd8fd681a draft

planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
author iuc
date Wed, 09 Nov 2022 12:19:26 +0000
parents
children 9b12bc1b1e2c
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/predict.py	Wed Nov 09 12:19:26 2022 +0000
@@ -0,0 +1,194 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Credits: Grigorii Sukhorukov, Macha Nikolski
+import argparse
+import os
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+from Bio import SeqIO
+from joblib import load
+from models import model_5, model_7
+from utils import preprocess as pp
+
+os.environ["CUDA_VISIBLE_DEVICES"] = ""
+os.environ["TF_XLA_FLAGS"] = "--tf_xla_cpu_global_jit"
+# loglevel : 0 all printed, 1 I not printed, 2 I and W not printed, 3 nothing printed
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+
+
+def predict_nn(ds_path, nn_weights_path, length, batch_size=256):
+    """
+    Breaks down contigs into fragments
+    and uses pretrained neural networks to give predictions for fragments
+    """
+    try:
+        seqs_ = list(SeqIO.parse(ds_path, "fasta"))
+    except FileNotFoundError:
+        raise Exception("test dataset was not found. Change ds variable")
+    out_table = {
+        "id": [],
+        "length": [],
+        "fragment": [],
+        "pred_plant_5": [],
+        "pred_vir_5": [],
+        "pred_bact_5": [],
+        "pred_plant_7": [],
+        "pred_vir_7": [],
+        "pred_bact_7": [],
+        # "pred_plant_10": [],
+        # "pred_vir_10": [],
+        # "pred_bact_10": [],
+    }
+    if not seqs_:
+        raise ValueError("All sequences were smaller than length of the model")
+    test_fragments = []
+    test_fragments_rc = []
+    for seq in seqs_:
+        fragments_, fragments_rc, _ = pp.fragmenting([seq], length, max_gap=0.8,
+                                                     sl_wind_step=int(length / 2))
+        test_fragments.extend(fragments_)
+        test_fragments_rc.extend(fragments_rc)
+        for j in range(len(fragments_)):
+            out_table["id"].append(seq.id)
+            out_table["length"].append(len(seq.seq))
+            out_table["fragment"].append(j)
+    test_encoded = pp.one_hot_encode(test_fragments)
+    test_encoded_rc = pp.one_hot_encode(test_fragments_rc)
+    # for model, s in zip([model_5.model(length), model_7.model(length), model_10.model(length)], [5, 7, 10]):
+    for model, s in zip([model_5.model(length), model_7.model(length)], [5, 7]):
+        model.load_weights(Path(nn_weights_path, f"model_{s}_{length}.h5"))
+        prediction = model.predict([test_encoded, test_encoded_rc], batch_size)
+        out_table[f"pred_plant_{s}"].extend(list(prediction[..., 0]))
+        out_table[f"pred_vir_{s}"].extend(list(prediction[..., 1]))
+        out_table[f"pred_bact_{s}"].extend(list(prediction[..., 2]))
+    return pd.DataFrame(out_table)
+
+
+def predict_rf(df, rf_weights_path, length):
+    """
+    Using predictions by predict_nn and weights of a trained RF classifier gives a single prediction for a fragment
+    """
+
+    clf = load(Path(rf_weights_path, f"RF_{length}.joblib"))
+    X = df[["pred_plant_5", "pred_vir_5", "pred_plant_7", "pred_vir_7"]]
+    # X = ["pred_plant_5", "pred_vir_5", "pred_plant_7", "pred_vir_7", "pred_plant_10", "pred_vir_10", ]]
+    y_pred = clf.predict(X)
+    mapping = {0: "plant", 1: "virus", 2: "bacteria"}
+    df["RF_decision"] = np.vectorize(mapping.get)(y_pred)
+    prob_classes = clf.predict_proba(X)
+    df["RF_pred_plant"] = prob_classes[..., 0]
+    df["RF_pred_vir"] = prob_classes[..., 1]
+    df["RF_pred_bact"] = prob_classes[..., 2]
+    return df
+
+
+def predict_contigs(df):
+    """
+    Based on predictions of predict_rf for fragments gives a final prediction for the whole contig
+    """
+    df = (
+        df.groupby(["id", "length", 'RF_decision'], sort=False)
+        .size()
+        .unstack(fill_value=0)
+    )
+    df = df.reset_index()
+    df = df.reindex(['length', 'id', 'virus', 'plant', 'bacteria'], axis=1)
+    conditions = [
+        (df['virus'] > df['plant']) & (df['virus'] > df['bacteria']),
+        (df['plant'] > df['virus']) & (df['plant'] > df['bacteria']),
+        (df['bacteria'] >= df['plant']) & (df['bacteria'] >= df['virus']),
+    ]
+    choices = ['virus', 'plant', 'bacteria']
+    df['decision'] = np.select(conditions, choices, default='bacteria')
+    df = df.loc[:, ['length', 'id', 'virus', 'plant', 'bacteria', 'decision']]
+    df = df.rename(columns={'virus': '# viral fragments', 'bacteria': '# bacterial fragments', 'plant': '# plant fragments'})
+    df['# viral / # total'] = (df['# viral fragments'] / (df['# viral fragments'] + df['# bacterial fragments'] + df['# plant fragments'])).round(3)
+    df['# viral / # total * length'] = df['# viral / # total'] * df['length']
+    df = df.sort_values(by='# viral / # total * length', ascending=False)
+    return df
+
+
+def predict(test_ds, weights, out_path, return_viral, limit):
+    """Predicts viral contigs from the fasta file
+
+    test_ds: path to the input file with contigs in fasta format (str or list of str)
+    weights: path to the folder containing weights for NN and RF modules trained on 500 and 1000 fragment lengths (str)
+    out_path: path to the folder to store predictions (str)
+    return_viral: whether to return contigs annotated as viral in separate fasta file (True/False)
+    limit: Do predictions only for contigs > l. We suggest l=750. (int)
+    """
+    test_ds = test_ds
+    if isinstance(test_ds, list):
+        pass
+    elif isinstance(test_ds, str):
+        test_ds = [test_ds]
+    else:
+        raise ValueError('test_ds was incorrectly assigned in the config file')
+
+    assert Path(test_ds[0]).exists(), f'{test_ds[0]} does not exist'
+    assert Path(weights).exists(), f'{weights} does not exist'
+    assert isinstance(limit, int), 'limit should be an integer'
+    Path(out_path).mkdir(parents=True, exist_ok=True)
+
+    for ts in test_ds:
+        dfs_fr = []
+        dfs_cont = []
+        for l_ in 500, 1000:
+            # print(f'starting prediction for {Path(ts).name} for fragment length {l_}')
+            df = predict_nn(
+                ds_path=ts,
+                nn_weights_path=weights,
+                length=l_,
+            )
+            print(df)
+            df = predict_rf(
+                df=df,
+                rf_weights_path=weights,
+                length=l_,
+            )
+            df = df.round(3)
+            dfs_fr.append(df)
+            df = predict_contigs(df)
+            dfs_cont.append(df)
+            # print('prediction finished')
+        df_500 = dfs_fr[0][(dfs_fr[0]['length'] >= limit) & (dfs_fr[0]['length'] < 1500)]
+        df_1000 = dfs_fr[1][(dfs_fr[1]['length'] >= 1500)]
+        df = pd.concat([df_1000, df_500], ignore_index=True)
+        pred_fr = Path(out_path, 'predicted_fragments.csv')
+        df.to_csv(pred_fr)
+
+        df_500 = dfs_cont[0][(dfs_cont[0]['length'] >= limit) & (dfs_cont[0]['length'] < 1500)]
+        df_1000 = dfs_cont[1][(dfs_cont[1]['length'] >= 1500)]
+        df = pd.concat([df_1000, df_500], ignore_index=True)
+        pred_contigs = Path(out_path, 'predicted.csv')
+        df.to_csv(pred_contigs)
+
+        if return_viral:
+            viral_ids = list(df[df["decision"] == "virus"]["id"])
+            seqs_ = list(SeqIO.parse(ts, "fasta"))
+            viral_seqs = [s_ for s_ in seqs_ if s_.id in viral_ids]
+            SeqIO.write(viral_seqs, Path(out_path, 'viral.fasta'), 'fasta')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--test_ds", help="path to the input file with contigs in fasta format (str or list of str)")
+    parser.add_argument("--weights", help="path to the folder containing weights for NN and RF modules trained on 500 and 1000 fragment lengths (str)")
+    parser.add_argument("--out_path", help="path to the folder to store predictions (str)")
+    parser.add_argument("--return_viral", help="whether to return contigs annotated as viral in separate fasta file (True/False)")
+    parser.add_argument("--limit", help="Do predictions only for contigs > l. We suggest l=750. (int)", type=int)
+
+    args = parser.parse_args()
+    if args.test_ds:
+        test_ds = args.test_ds
+    if args.weights:
+        weights = args.weights
+    if args.out_path:
+        out_path = args.out_path
+    if args.return_viral:
+        return_viral = args.return_viral
+    if args.limit:
+        limit = args.limit
+    predict(test_ds, weights, out_path, return_viral, limit)