view predict.py @ 1:9b12bc1b1e2c draft
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit cea5a324cffd684e4cb08438fe5a305cfaf0f73a
author      iuc
date        Wed, 30 Nov 2022 17:31:52 +0000
parents     457fd8fd681a
children    ea2cccb9f73e
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Credits: Grigorii Sukhorukov, Macha Nikolski
import argparse
import os
from pathlib import Path

import numpy as np
import pandas as pd
from Bio import SeqIO
from joblib import load
from models import model_10, model_5, model_7
from utils import preprocess as pp

os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["TF_XLA_FLAGS"] = "--tf_xla_cpu_global_jit"
# TF log level: 0 prints everything, 1 hides INFO, 2 hides INFO and WARNING, 3 prints nothing
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


def predict_nn(ds_path, nn_weights_path, length, use_10, batch_size=256):
    """
    Breaks down contigs into fragments and uses pretrained neural networks
    to give predictions for fragments
    """
    try:
        seqs_ = list(SeqIO.parse(ds_path, "fasta"))
    except FileNotFoundError:
        raise Exception("test dataset was not found. Change ds variable")
    out_table = {
        "id": [],
        "length": [],
        "fragment": [],
        "pred_plant_5": [],
        "pred_vir_5": [],
        "pred_bact_5": [],
        "pred_plant_7": [],
        "pred_vir_7": [],
        "pred_bact_7": [],
    }
    if use_10:
        out_table.update({
            "pred_plant_10": [],
            "pred_vir_10": [],
            "pred_bact_10": [],
        })
    if not seqs_:
        raise ValueError("All sequences were smaller than length of the model")
    test_fragments = []
    test_fragments_rc = []
    for seq in seqs_:
        # sliding window over each contig with a step of half the fragment length
        fragments_, fragments_rc, _ = pp.fragmenting([seq], length, max_gap=0.8,
                                                     sl_wind_step=int(length / 2))
        test_fragments.extend(fragments_)
        test_fragments_rc.extend(fragments_rc)
        for j in range(len(fragments_)):
            out_table["id"].append(seq.id)
            out_table["length"].append(len(seq.seq))
            out_table["fragment"].append(j)
    test_encoded = pp.one_hot_encode(test_fragments)
    test_encoded_rc = pp.one_hot_encode(test_fragments_rc)
    if use_10:
        zipped_models = zip([model_5.model(length), model_7.model(length), model_10.model(length)], [5, 7, 10])
    else:
        zipped_models = zip([model_5.model(length), model_7.model(length)], [5, 7])
    for model, s in zipped_models:
        model.load_weights(Path(nn_weights_path, f"model_{s}_{length}.h5"))
        prediction = model.predict([test_encoded, test_encoded_rc], batch_size)
        out_table[f"pred_plant_{s}"].extend(list(prediction[..., 0]))
        out_table[f"pred_vir_{s}"].extend(list(prediction[..., 1]))
        out_table[f"pred_bact_{s}"].extend(list(prediction[..., 2]))
    return pd.DataFrame(out_table)


def predict_rf(df, rf_weights_path, length, use_10):
    """
    Using predictions by predict_nn and weights of a trained RF classifier
    gives a single prediction for a fragment
    """
    clf = load(Path(rf_weights_path, f"RF_{length}.joblib"))
    if use_10:
        X = df[[
            "pred_plant_5",
            "pred_vir_5",
            "pred_plant_7",
            "pred_vir_7",
            "pred_plant_10",
            "pred_vir_10",
        ]]
    else:
        X = df[[
            "pred_plant_5",
            "pred_vir_5",
            "pred_plant_7",
            "pred_vir_7",
        ]]
    y_pred = clf.predict(X)
    mapping = {0: "plant", 1: "virus", 2: "bacteria"}
    df["RF_decision"] = np.vectorize(mapping.get)(y_pred)
    prob_classes = clf.predict_proba(X)
    df["RF_pred_plant"] = prob_classes[..., 0]
    df["RF_pred_vir"] = prob_classes[..., 1]
    df["RF_pred_bact"] = prob_classes[..., 2]
    return df


def predict_contigs(df):
    """
    Based on predictions of predict_rf for fragments
    gives a final prediction for the whole contig
    """
    df = (
        df.groupby(["id", "length", 'RF_decision'], sort=False)
        .size()
        .unstack(fill_value=0)
    )
    df = df.reset_index()
    df = df.reindex(['length', 'id', 'virus', 'plant', 'bacteria'], axis=1)
    conditions = [
        (df['virus'] > df['plant']) & (df['virus'] > df['bacteria']),
        (df['plant'] > df['virus']) & (df['plant'] > df['bacteria']),
        (df['bacteria'] >= df['plant']) & (df['bacteria'] >= df['virus']),
    ]
    choices = ['virus', 'plant', 'bacteria']
    df['decision'] = np.select(conditions, choices, default='bacteria')
    df = df.loc[:, ['length', 'id', 'virus', 'plant', 'bacteria', 'decision']]
    df = df.rename(columns={'virus': '# viral fragments',
                            'bacteria': '# bacterial fragments',
                            'plant': '# plant fragments'})
    df['# viral / # total'] = (df['# viral fragments'] / (df['# viral fragments'] + df['# bacterial fragments'] + df['# plant fragments'])).round(3)
    df['# viral / # total * length'] = df['# viral / # total'] * df['length']
    df = df.sort_values(by='# viral / # total * length', ascending=False)
    return df
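
# Editor's note (illustrative, not part of the original script): np.select()
# in predict_contigs evaluates the conditions in order, so a contig whose
# fragments split into virus=4, plant=3, bacteria=1 satisfies the first
# condition and is labelled 'virus'. An exact virus/plant tie such as
# virus=3, plant=3, bacteria=1 fails all three conditions (3 > 3 is False
# both ways, and 1 >= 3 is False), so it falls through to the default and
# the contig is labelled 'bacteria'.
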
def predict(test_ds, weights, out_path, return_viral, limit):
    """Predicts viral contigs from the fasta file

    test_ds: path to the input file with contigs in fasta format (str or list of str)
    weights: path to the folder containing weights for NN and RF modules trained on 500 and 1000 fragment lengths (str)
    out_path: path to the folder to store predictions (str)
    return_viral: whether to return contigs annotated as viral in a separate fasta file (True/False)
    limit: do predictions only for contigs longer than this; we suggest 750 (int)
    """
    if isinstance(test_ds, str):
        test_ds = [test_ds]
    elif not isinstance(test_ds, list):
        raise ValueError('test_ds was incorrectly assigned in the config file')

    assert Path(test_ds[0]).exists(), f'{test_ds[0]} does not exist'
    assert Path(weights).exists(), f'{weights} does not exist'
    assert isinstance(limit, int), 'limit should be an integer'
    Path(out_path).mkdir(parents=True, exist_ok=True)
    # the third network (model_10) is optional; it is used only if its weights are present
    use_10 = Path(weights, 'model_10_500.h5').exists()

    for ts in test_ds:
        dfs_fr = []
        dfs_cont = []
        for l_ in 500, 1000:
            # print(f'starting prediction for {Path(ts).name} for fragment length {l_}')
            df = predict_nn(
                ds_path=ts,
                nn_weights_path=weights,
                length=l_,
                use_10=use_10,
            )
            print(df)
            df = predict_rf(
                df=df,
                rf_weights_path=weights,
                length=l_,
                use_10=use_10,
            )
            df = df.round(3)
            dfs_fr.append(df)
            df = predict_contigs(df)
            dfs_cont.append(df)
            # print('prediction finished')
        # contigs in [limit, 1500) use the 500-fragment models, longer ones the 1000-fragment models
        df_500 = dfs_fr[0][(dfs_fr[0]['length'] >= limit) & (dfs_fr[0]['length'] < 1500)]
        df_1000 = dfs_fr[1][dfs_fr[1]['length'] >= 1500]
        df = pd.concat([df_1000, df_500], ignore_index=True)
        pred_fr = Path(out_path, 'predicted_fragments.csv')
        df.to_csv(pred_fr)

        df_500 = dfs_cont[0][(dfs_cont[0]['length'] >= limit) & (dfs_cont[0]['length'] < 1500)]
        df_1000 = dfs_cont[1][dfs_cont[1]['length'] >= 1500]
        df = pd.concat([df_1000, df_500], ignore_index=True)
        pred_contigs = Path(out_path, 'predicted.csv')
        df.to_csv(pred_contigs)

        if return_viral:
            viral_ids = list(df[df["decision"] == "virus"]["id"])
            seqs_ = list(SeqIO.parse(ts, "fasta"))
            viral_seqs = [s_ for s_ in seqs_ if s_.id in viral_ids]
            SeqIO.write(viral_seqs, Path(out_path, 'viral.fasta'), 'fasta')
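
# Editor's sketch (not part of the original script): predict() can also be
# called directly from Python instead of through the CLI below. The paths
# are hypothetical placeholders; the weights folder is expected to contain
# model_{5,7}_{500,1000}.h5 (optionally model_10_*.h5) and
# RF_{500,1000}.joblib, as loaded by predict_nn and predict_rf above.
#
#     predict(
#         test_ds='contigs.fasta',    # a single path or a list of paths
#         weights='weights/',         # folder with NN and RF weights
#         out_path='predictions/',    # created if it does not exist
#         return_viral=True,          # also write viral.fasta
#         limit=750,                  # skip contigs shorter than this
#     )
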
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--test_ds", help="path to the input file with contigs in fasta format (str or list of str)")
    parser.add_argument("--weights", help="path to the folder containing weights for NN and RF modules trained on 500 and 1000 fragment lengths (str)")
    parser.add_argument("--out_path", help="path to the folder to store predictions (str)")
    parser.add_argument("--return_viral", help="whether to return contigs annotated as viral in a separate fasta file (True/False)")
    parser.add_argument("--limit", help="do predictions only for contigs longer than this; we suggest 750 (int)", type=int, default=750)
    args = parser.parse_args()
    predict(args.test_ds, args.weights, args.out_path, args.return_viral, args.limit)
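A typical invocation of the script above might look like this (a sketch; the file and folder names are placeholders):

    python predict.py --test_ds contigs.fasta --weights weights/ --out_path predictions/ --return_viral True --limit 750

Note that --return_viral is parsed as a plain string, so any non-empty value (including "False") is truthy and triggers writing of viral.fasta. predicted.csv and predicted_fragments.csv are always written to --out_path.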