Mercurial > repos > iuc > virhunter
comparison predict.py @ 0:457fd8fd681a draft
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
| author | iuc |
|---|---|
| date | Wed, 09 Nov 2022 12:19:26 +0000 |
| parents | |
| children | 9b12bc1b1e2c |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:457fd8fd681a |
|---|---|
| 1 #!/usr/bin/env python | |
| 2 # -*- coding: utf-8 -*- | |
| 3 # Credits: Grigorii Sukhorukov, Macha Nikolski | |
| 4 import argparse | |
| 5 import os | |
| 6 from pathlib import Path | |
| 7 | |
| 8 import numpy as np | |
| 9 import pandas as pd | |
| 10 from Bio import SeqIO | |
| 11 from joblib import load | |
| 12 from models import model_5, model_7 | |
| 13 from utils import preprocess as pp | |
| 14 | |
# Runtime configuration for TensorFlow, applied through environment variables.
# NOTE(review): these are assigned after `models` has been imported above; if
# that import already loads TensorFlow, confirm the settings still take effect.
os.environ.update(
    {
        # An empty device list hides every CUDA device from the process.
        "CUDA_VISIBLE_DEVICES": "",
        "TF_XLA_FLAGS": "--tf_xla_cpu_global_jit",
        # Log level: 0 = print everything, 1 = hide INFO,
        # 2 = hide INFO and WARNING, 3 = print nothing.
        "TF_CPP_MIN_LOG_LEVEL": "3",
    }
)
| 19 | |
| 20 | |
def predict_nn(ds_path, nn_weights_path, length, batch_size=256):
    """
    Breaks down contigs into fragments
    and uses pretrained neural networks to give predictions for fragments

    ds_path: path to the fasta file with contigs
    nn_weights_path: folder containing the model_5_{length}.h5 and model_7_{length}.h5 weight files
    length: fragment length; also selects which weight files are loaded
    batch_size: batch size handed to model.predict

    Returns a pandas DataFrame with one row per fragment: the contig "id",
    the contig "length", the "fragment" index inside the contig and the
    pred_plant_*/pred_vir_*/pred_bact_* outputs of the two networks.

    Raises Exception when ds_path does not exist and ValueError when no
    sequence could be read from it.
    """
    try:
        seqs_ = list(SeqIO.parse(ds_path, "fasta"))
    except FileNotFoundError as err:
        # Chain the original error so the offending path stays in the traceback.
        raise Exception("test dataset was not found. Change ds variable") from err
    # This fires when parsing yielded no records at all.
    if not seqs_:
        raise ValueError("All sequences were smaller than length of the model")

    out_table = {
        "id": [],
        "length": [],
        "fragment": [],
        "pred_plant_5": [],
        "pred_vir_5": [],
        "pred_bact_5": [],
        "pred_plant_7": [],
        "pred_vir_7": [],
        "pred_bact_7": [],
    }
    test_fragments = []
    test_fragments_rc = []
    for seq in seqs_:
        # Fragment one contig at a time so every fragment can be traced back
        # to its contig; the window advances by half a fragment length.
        fragments_, fragments_rc, _ = pp.fragmenting(
            [seq], length, max_gap=0.8, sl_wind_step=int(length / 2)
        )
        test_fragments.extend(fragments_)
        test_fragments_rc.extend(fragments_rc)
        n_fragments = len(fragments_)
        out_table["id"].extend([seq.id] * n_fragments)
        out_table["length"].extend([len(seq.seq)] * n_fragments)
        out_table["fragment"].extend(range(n_fragments))

    test_encoded = pp.one_hot_encode(test_fragments)
    test_encoded_rc = pp.one_hot_encode(test_fragments_rc)
    for model, s in zip([model_5.model(length), model_7.model(length)], [5, 7]):
        model.load_weights(Path(nn_weights_path, f"model_{s}_{length}.h5"))
        prediction = model.predict([test_encoded, test_encoded_rc], batch_size)
        # The last axis of the prediction is read as (plant, virus, bacteria).
        out_table[f"pred_plant_{s}"].extend(list(prediction[..., 0]))
        out_table[f"pred_vir_{s}"].extend(list(prediction[..., 1]))
        out_table[f"pred_bact_{s}"].extend(list(prediction[..., 2]))
    return pd.DataFrame(out_table)
| 67 | |
| 68 | |
def predict_rf(df, rf_weights_path, length):
    """
    Turns the neural network outputs produced by predict_nn into a single
    prediction per fragment with a trained RF classifier

    df: DataFrame holding the pred_plant_5, pred_vir_5, pred_plant_7 and pred_vir_7 columns
    rf_weights_path: folder containing RF_{length}.joblib
    length: fragment length used to pick the classifier file

    The passed DataFrame gets the RF_decision, RF_pred_plant, RF_pred_vir and
    RF_pred_bact columns added in place and is returned.
    """
    feature_columns = ["pred_plant_5", "pred_vir_5", "pred_plant_7", "pred_vir_7"]
    label_names = {0: "plant", 1: "virus", 2: "bacteria"}
    probability_columns = ("RF_pred_plant", "RF_pred_vir", "RF_pred_bact")

    # NOTE: joblib.load unpickles the file, so only point it at trusted weights.
    classifier = load(Path(rf_weights_path, f"RF_{length}.joblib"))
    features = df[feature_columns]

    predicted_labels = classifier.predict(features)
    df["RF_decision"] = np.vectorize(label_names.get)(predicted_labels)

    probabilities = classifier.predict_proba(features)
    for class_index, column in enumerate(probability_columns):
        df[column] = probabilities[..., class_index]
    return df
| 85 | |
| 86 | |
def predict_contigs(df):
    """
    Based on predictions of predict_rf for fragments gives a final prediction for the whole contig

    df: per-fragment DataFrame with the "id", "length" and "RF_decision" columns

    Returns one row per contig with the number of viral, plant and bacterial
    fragments, the majority "decision", the share of viral fragments rounded
    to 3 decimals and that share multiplied by the contig length. Rows are
    sorted by the last value, largest first.

    A contig is called "virus" or "plant" only when that class has strictly
    more fragments than each of the other two; every other case, ties
    included, is called "bacteria".
    """
    counts = (
        df.groupby(["id", "length", 'RF_decision'], sort=False)
        .size()
        .unstack(fill_value=0)
        .reset_index()
    )
    # A class that no fragment was assigned to has no column after unstack.
    # Create it with zeros: a NaN column would make every comparison below
    # False and push the affected contigs to the 'bacteria' default.
    counts = counts.reindex(
        ['length', 'id', 'virus', 'plant', 'bacteria'], axis=1, fill_value=0
    )
    conditions = [
        (counts['virus'] > counts['plant']) & (counts['virus'] > counts['bacteria']),
        (counts['plant'] > counts['virus']) & (counts['plant'] > counts['bacteria']),
        (counts['bacteria'] >= counts['plant']) & (counts['bacteria'] >= counts['virus']),
    ]
    choices = ['virus', 'plant', 'bacteria']
    counts['decision'] = np.select(conditions, choices, default='bacteria')
    counts = counts.loc[:, ['length', 'id', 'virus', 'plant', 'bacteria', 'decision']]
    counts = counts.rename(
        columns={
            'virus': '# viral fragments',
            'bacteria': '# bacterial fragments',
            'plant': '# plant fragments',
        }
    )
    total = (
        counts['# viral fragments']
        + counts['# bacterial fragments']
        + counts['# plant fragments']
    )
    counts['# viral / # total'] = (counts['# viral fragments'] / total).round(3)
    counts['# viral / # total * length'] = counts['# viral / # total'] * counts['length']
    return counts.sort_values(by='# viral / # total * length', ascending=False)
| 111 | |
| 112 | |
def _select_by_length(df_500, df_1000, limit):
    """Keeps 500-model rows for limit <= length < 1500 and 1000-model rows for length >= 1500."""
    short = df_500[(df_500['length'] >= limit) & (df_500['length'] < 1500)]
    long_ = df_1000[df_1000['length'] >= 1500]
    return pd.concat([long_, short], ignore_index=True)


def predict(test_ds, weights, out_path, return_viral, limit):
    """Predicts viral contigs from the fasta file

    test_ds: path to the input file with contigs in fasta format (str or list of str)
    weights: path to the folder containing weights for NN and RF modules trained on 500 and 1000 fragment lengths (str)
    out_path: path to the folder to store predictions (str)
    return_viral: whether to return contigs annotated as viral in separate fasta file (True/False)
    limit: Do predictions only for contigs with length >= limit. We suggest limit=750. (int)

    Writes predicted_fragments.csv and predicted.csv into out_path and, when
    return_viral is truthy, viral.fasta as well.

    Raises ValueError when test_ds is neither a str nor a non-empty list and
    AssertionError when an input path or the weights folder does not exist or
    limit is not an int.
    """
    if isinstance(test_ds, str):
        test_ds = [test_ds]
    elif not isinstance(test_ds, list):
        raise ValueError('test_ds was incorrectly assigned in the config file')
    if not test_ds:
        raise ValueError('test_ds is empty: no input file was given')

    # Raised explicitly (instead of `assert`) so the checks survive `python -O`;
    # the exception type is kept. Every input is checked, not only the first one.
    for ts in test_ds:
        if not Path(ts).exists():
            raise AssertionError(f'{ts} does not exist')
    if not Path(weights).exists():
        raise AssertionError(f'{weights} does not exist')
    if not isinstance(limit, int):
        raise AssertionError('limit should be an integer')
    Path(out_path).mkdir(parents=True, exist_ok=True)

    # NOTE: the output file names are fixed, so with several input files each
    # iteration overwrites the results of the previous one.
    for ts in test_ds:
        dfs_fr = []
        dfs_cont = []
        for l_ in 500, 1000:
            df = predict_nn(
                ds_path=ts,
                nn_weights_path=weights,
                length=l_,
            )
            print(df)
            df = predict_rf(
                df=df,
                rf_weights_path=weights,
                length=l_,
            )
            df = df.round(3)
            dfs_fr.append(df)
            dfs_cont.append(predict_contigs(df))

        # Fragment-level table: each contig is reported by one model only,
        # chosen by its length.
        df = _select_by_length(dfs_fr[0], dfs_fr[1], limit)
        pred_fr = Path(out_path, 'predicted_fragments.csv')
        df.to_csv(pred_fr)

        # Contig-level table, selected the same way.
        df = _select_by_length(dfs_cont[0], dfs_cont[1], limit)
        pred_contigs = Path(out_path, 'predicted.csv')
        df.to_csv(pred_contigs)

        if return_viral:
            # A set keeps the membership test constant-time per sequence.
            viral_ids = set(df[df["decision"] == "virus"]["id"])
            viral_seqs = [s_ for s_ in SeqIO.parse(ts, "fasta") if s_.id in viral_ids]
            SeqIO.write(viral_seqs, Path(out_path, 'viral.fasta'), 'fasta')
| 173 | |
| 174 | |
def str_to_bool(value):
    """Converts a command line value such as "True" or "False" into a bool.

    Needed because argparse hands options over as strings and any non-empty
    string, "False" included, is truthy.
    Raises argparse.ArgumentTypeError for values that are not recognised.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # The three paths are required: without them predict cannot run, and a
    # missing option previously surfaced as a NameError at the call below.
    parser.add_argument("--test_ds", required=True, help="path to the input file with contigs in fasta format (str or list of str)")
    parser.add_argument("--weights", required=True, help="path to the folder containing weights for NN and RF modules trained on 500 and 1000 fragment lengths (str)")
    parser.add_argument("--out_path", required=True, help="path to the folder to store predictions (str)")
    parser.add_argument("--return_viral", type=str_to_bool, default=False, help="whether to return contigs annotated as viral in separate fasta file (True/False)")
    # 750 is the value suggested by the help text; an explicit 0 is accepted too.
    parser.add_argument("--limit", type=int, default=750, help="Do predictions only for contigs > l. We suggest l=750. (int)")

    args = parser.parse_args()
    predict(args.test_ds, args.weights, args.out_path, args.return_viral, args.limit)
