view predict.py @ 2:ea2cccb9f73e draft

planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit c3685ed6a70b47012b62b95a2a3db062bd3b7475
author iuc
date Thu, 05 Jan 2023 14:27:54 +0000
parents 9b12bc1b1e2c
children 302332b914ef
line wrap: on
line source

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Credits: Grigorii Sukhorukov, Macha Nikolski
import argparse
import os
from pathlib import Path

import numpy as np
import pandas as pd
from Bio import SeqIO
from joblib import load
from models import model_10, model_5, model_7
from utils import preprocess as pp

# Environment settings for the TensorFlow runtime, applied in one step.
#   CUDA_VISIBLE_DEVICES : set to an empty value (presumably to hide GPUs so inference runs on CPU)
#   TF_XLA_FLAGS         : passes the --tf_xla_cpu_global_jit flag
#   TF_CPP_MIN_LOG_LEVEL : 0 all printed, 1 I not printed, 2 I and W not printed, 3 nothing printed
# NOTE(review): these are set after the project imports above; if those modules
# import TensorFlow at import time, confirm the settings still take effect.
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "",
    "TF_XLA_FLAGS": "--tf_xla_cpu_global_jit",
    'TF_CPP_MIN_LOG_LEVEL': '3',
})


def predict_nn(ds_path, nn_weights_path, length, use_10, batch_size=256):
    """
    Breaks down contigs into fragments
    and uses pretrained neural networks to give predictions for fragments

    ds_path: path to the fasta file with contigs
    nn_weights_path: path to the folder containing the files model_<k>_<length>.h5
    length: fragment length; also selects which weight files are loaded (int)
    use_10: whether model_10 is run in addition to model_5 and model_7 (True/False)
    batch_size: batch size passed to model.predict (int)

    Returns a pandas DataFrame with one row per fragment: contig id, contig length,
    index of the fragment within its contig and the columns pred_plant_<k>,
    pred_vir_<k>, pred_bact_<k> for every model that was run.

    Raises Exception when ds_path does not exist and ValueError when there is
    nothing to predict on (no sequences read, or no fragment produced).
    """
    try:
        seqs_ = list(SeqIO.parse(ds_path, "fasta"))
    except FileNotFoundError as err:
        # the exception type is kept for existing callers; chaining keeps the real cause visible
        raise Exception("test dataset was not found. Change ds variable") from err
    if not seqs_:
        raise ValueError("All sequences were smaller than length of the model")

    model_sizes = [5, 7, 10] if use_10 else [5, 7]
    classes = ("plant", "vir", "bact")  # order of the outputs read from each prediction

    out_table = {"id": [], "length": [], "fragment": []}
    for s in model_sizes:
        for class_name in classes:
            out_table[f"pred_{class_name}_{s}"] = []

    test_fragments = []
    test_fragments_rc = []
    for seq in seqs_:
        # contigs are fragmented one by one so that every fragment can be traced back to its contig
        fragments_, fragments_rc, _ = pp.fragmenting([seq], length, max_gap=0.8,
                                                     sl_wind_step=int(length / 2))
        test_fragments.extend(fragments_)
        test_fragments_rc.extend(fragments_rc)
        n_fragments = len(fragments_)
        out_table["id"].extend([seq.id] * n_fragments)
        out_table["length"].extend([len(seq.seq)] * n_fragments)
        out_table["fragment"].extend(range(n_fragments))
    if not test_fragments:
        # fail with a clear message instead of handing empty input to the networks
        raise ValueError("All sequences were smaller than length of the model")

    test_encoded = pp.one_hot_encode(test_fragments)
    test_encoded_rc = pp.one_hot_encode(test_fragments_rc)

    architectures = {5: model_5, 7: model_7, 10: model_10}
    models = [architectures[s].model(length) for s in model_sizes]
    for model, s in zip(models, model_sizes):
        model.load_weights(Path(nn_weights_path, f"model_{s}_{length}.h5"))
        prediction = model.predict([test_encoded, test_encoded_rc], batch_size)
        for position, class_name in enumerate(classes):
            out_table[f"pred_{class_name}_{s}"].extend(list(prediction[..., position]))

    return pd.DataFrame(out_table)


def predict_rf(df, rf_weights_path, length, use_10):
    """
    Turns the neural network outputs produced by predict_nn into a single
    decision per fragment with a trained RF classifier

    df: table of fragment predictions; it is extended in place and returned
    rf_weights_path: path to the folder containing RF_<length>.joblib
    length: fragment length used to pick the classifier file (int)
    use_10: whether the pred_*_10 columns are part of the classifier input (True/False)

    Adds the columns RF_decision (plant / virus / bacteria) and
    RF_pred_plant, RF_pred_vir, RF_pred_bact (columns 0, 1, 2 of predict_proba).
    """
    # only the plant and virus scores of each network are used as features
    feature_cols = ["pred_plant_5", "pred_vir_5", "pred_plant_7", "pred_vir_7"]
    if use_10:
        feature_cols += ["pred_plant_10", "pred_vir_10"]

    rf_model = load(Path(rf_weights_path, f"RF_{length}.joblib"))
    features = df[feature_cols]

    labels = rf_model.predict(features)
    label_names = {0: "plant", 1: "virus", 2: "bacteria"}
    df["RF_decision"] = np.vectorize(label_names.get)(labels)

    probabilities = rf_model.predict_proba(features)
    for position, column in enumerate(("RF_pred_plant", "RF_pred_vir", "RF_pred_bact")):
        df[column] = probabilities[..., position]
    return df


def predict_contigs(df):
    """
    Aggregates the per-fragment decisions of predict_rf into one call per contig

    df: table with the columns id, length and RF_decision (one row per fragment)

    Returns one row per contig with the number of viral, plant and bacterial
    fragments, the decision, the share of viral fragments (rounded to 3 digits)
    and that share multiplied by the contig length. Rows are sorted by the last
    value, largest first.
    """
    classes = ['virus', 'plant', 'bacteria']

    # contig x class table with the number of fragments assigned to every class
    per_class = df.groupby(["id", "length", 'RF_decision'], sort=False).size()
    counts = per_class.unstack(fill_value=0).reset_index()
    # a class that was never predicted gets a column filled with zeros
    counts = counts.reindex(['length', 'id'] + classes, axis=1).fillna(value=0)

    n_vir = counts['virus']
    n_plant = counts['plant']
    n_bact = counts['bacteria']
    viral_majority = (n_vir > n_plant) & (n_vir > n_bact)
    plant_majority = (n_plant > n_vir) & (n_plant > n_bact)
    # virus and plant need a strict majority; everything else, ties included, is called bacteria
    counts['decision'] = np.where(
        viral_majority, 'virus', np.where(plant_majority, 'plant', 'bacteria')
    )

    report = counts.loc[:, ['length', 'id'] + classes + ['decision']]
    report = report.rename(columns={'virus': '# viral fragments', 'bacteria': '# bacterial fragments', 'plant': '# plant fragments'})
    n_total = report['# viral fragments'] + report['# bacterial fragments'] + report['# plant fragments']
    report['# viral / # total'] = (report['# viral fragments'] / n_total).round(3)
    report['# viral / # total * length'] = report['# viral / # total'] * report['length']
    return report.sort_values(by='# viral / # total * length', ascending=False)


def _merge_by_contig_length(df_500, df_1000, limit):
    """
    Keeps rows of df_500 for contigs with limit <= length < 1500 and rows of
    df_1000 for contigs with length >= 1500; the longer contigs come first
    """
    short_contigs = df_500[(df_500['length'] >= limit) & (df_500['length'] < 1500)]
    long_contigs = df_1000[df_1000['length'] >= 1500]
    return pd.concat([long_contigs, short_contigs], ignore_index=True)


def predict(test_ds, weights, out_path, return_viral, limit):
    """Predicts viral contigs from the fasta file

    test_ds: path to the input file with contigs in fasta format (str or list of str)
    weights: path to the folder containing weights for NN and RF modules trained on 500 and 1000 fragment lengths (str)
    out_path: path to the folder to store predictions (str)
    return_viral: whether to return contigs annotated as viral in separate fasta file (True/False)
    limit: contigs shorter than limit are left out of the reports. We suggest limit=750. (int)

    Writes predicted_fragments.csv and predicted.csv to out_path, plus
    viral.fasta when return_viral is true. Returns nothing.

    Raises ValueError if test_ds is neither a str nor a list and
    AssertionError if the first input file or the weights folder is missing
    or limit is not an int.
    """
    if isinstance(test_ds, str):
        test_ds = [test_ds]
    elif not isinstance(test_ds, list):
        raise ValueError('test_ds was incorrectly assigned in the config file')

    assert Path(test_ds[0]).exists(), f'{test_ds[0]} does not exist'
    assert Path(weights).exists(), f'{weights} does not exist'
    assert isinstance(limit, int), 'limit should be an integer'
    Path(out_path).mkdir(parents=True, exist_ok=True)
    # model_10 is only used when its weights were shipped with the others
    use_10 = Path(weights, 'model_10_500.h5').exists()
    # NOTE: every dataset writes to the same file names in out_path, so with
    # several inputs only the results of the last one remain on disk.
    for ts in test_ds:
        fragments_by_length = {}
        contigs_by_length = {}
        for l_ in 500, 1000:
            df = predict_nn(
                ds_path=ts,
                nn_weights_path=weights,
                length=l_,
                use_10=use_10
            )
            df = predict_rf(
                df=df,
                rf_weights_path=weights,
                length=l_,
                use_10=use_10
            )
            df = df.round(3)
            fragments_by_length[l_] = df
            contigs_by_length[l_] = predict_contigs(df)

        fragments = _merge_by_contig_length(fragments_by_length[500], fragments_by_length[1000], limit)
        fragments.to_csv(Path(out_path, 'predicted_fragments.csv'))

        contigs = _merge_by_contig_length(contigs_by_length[500], contigs_by_length[1000], limit)
        contigs.to_csv(Path(out_path, 'predicted.csv'))

        if return_viral:
            # a set keeps the membership test constant-time for large assemblies
            viral_ids = set(contigs[contigs["decision"] == "virus"]["id"])
            viral_seqs = [s_ for s_ in SeqIO.parse(ts, "fasta") if s_.id in viral_ids]
            SeqIO.write(viral_seqs, Path(out_path, 'viral.fasta'), 'fasta')


if __name__ == '__main__':
    def _str_to_bool(value):
        """Converts the text given on the command line into a real boolean."""
        normalized = value.strip().lower()
        if normalized in ("true", "t", "yes", "y", "1"):
            return True
        if normalized in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError(f"expected True or False, got {value!r}")

    # Command line entry point: every option is handed to predict() unchanged.
    # The three paths are required so that a missing option is reported by
    # argparse instead of surfacing later as an unbound variable.
    parser = argparse.ArgumentParser()
    parser.add_argument("--test_ds", required=True, help="path to the input file with contigs in fasta format (str or list of str)")
    parser.add_argument("--weights", required=True, help="path to the folder containing weights for NN and RF modules trained on 500 and 1000 fragment lengths (str)")
    parser.add_argument("--out_path", required=True, help="path to the folder to store predictions (str)")
    # parsed explicitly: the plain string "False" would otherwise count as true
    parser.add_argument("--return_viral", type=_str_to_bool, default=False, help="whether to return contigs annotated as viral in separate fasta file (True/False)")
    parser.add_argument("--limit", help="Do predictions only for contigs > l. We suggest l=750. (int)", type=int, default=750)

    args = parser.parse_args()
    predict(args.test_ds, args.weights, args.out_path, args.return_viral, args.limit)