comparison predict.py @ 1:9b12bc1b1e2c draft

planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit cea5a324cffd684e4cb08438fe5a305cfaf0f73a
author iuc
date Wed, 30 Nov 2022 17:31:52 +0000
parents 457fd8fd681a
children ea2cccb9f73e
comparison 0:457fd8fd681a -> 1:9b12bc1b1e2c
@@ -7,20 +7,20 @@
 
 import numpy as np
 import pandas as pd
 from Bio import SeqIO
 from joblib import load
-from models import model_5, model_7
+from models import model_10, model_5, model_7
 from utils import preprocess as pp
 
 os.environ["CUDA_VISIBLE_DEVICES"] = ""
 os.environ["TF_XLA_FLAGS"] = "--tf_xla_cpu_global_jit"
 # loglevel : 0 all printed, 1 I not printed, 2 I and W not printed, 3 nothing printed
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
 
 
-def predict_nn(ds_path, nn_weights_path, length, batch_size=256):
+def predict_nn(ds_path, nn_weights_path, length, use_10, batch_size=256):
     """
     Breaks down contigs into fragments
     and uses pretrained neural networks to give predictions for fragments
     """
     try:
35 "pred_vir_5": [], 35 "pred_vir_5": [],
36 "pred_bact_5": [], 36 "pred_bact_5": [],
37 "pred_plant_7": [], 37 "pred_plant_7": [],
38 "pred_vir_7": [], 38 "pred_vir_7": [],
39 "pred_bact_7": [], 39 "pred_bact_7": [],
40 # "pred_plant_10": [],
41 # "pred_vir_10": [],
42 # "pred_bact_10": [],
43 } 40 }
41 if use_10:
42 out_table_ = {
43 "pred_plant_10": [],
44 "pred_vir_10": [],
45 "pred_bact_10": [],
46 }
47 out_table.update(out_table_)
44 if not seqs_: 48 if not seqs_:
45 raise ValueError("All sequences were smaller than length of the model") 49 raise ValueError("All sequences were smaller than length of the model")
46 test_fragments = [] 50 test_fragments = []
47 test_fragments_rc = [] 51 test_fragments_rc = []
48 for seq in seqs_: 52 for seq in seqs_:
54 out_table["id"].append(seq.id) 58 out_table["id"].append(seq.id)
55 out_table["length"].append(len(seq.seq)) 59 out_table["length"].append(len(seq.seq))
56 out_table["fragment"].append(j) 60 out_table["fragment"].append(j)
57 test_encoded = pp.one_hot_encode(test_fragments) 61 test_encoded = pp.one_hot_encode(test_fragments)
58 test_encoded_rc = pp.one_hot_encode(test_fragments_rc) 62 test_encoded_rc = pp.one_hot_encode(test_fragments_rc)
59 # for model, s in zip([model_5.model(length), model_7.model(length), model_10.model(length)], [5, 7, 10]): 63 if use_10:
60 for model, s in zip([model_5.model(length), model_7.model(length)], [5, 7]): 64 zipped_models = zip([model_5.model(length), model_7.model(length), model_10.model(length)], [5, 7, 10])
65 else:
66 zipped_models = zip([model_5.model(length), model_7.model(length)], [5, 7])
67 for model, s in zipped_models:
61 model.load_weights(Path(nn_weights_path, f"model_{s}_{length}.h5")) 68 model.load_weights(Path(nn_weights_path, f"model_{s}_{length}.h5"))
62 prediction = model.predict([test_encoded, test_encoded_rc], batch_size) 69 prediction = model.predict([test_encoded, test_encoded_rc], batch_size)
63 out_table[f"pred_plant_{s}"].extend(list(prediction[..., 0])) 70 out_table[f"pred_plant_{s}"].extend(list(prediction[..., 0]))
64 out_table[f"pred_vir_{s}"].extend(list(prediction[..., 1])) 71 out_table[f"pred_vir_{s}"].extend(list(prediction[..., 1]))
65 out_table[f"pred_bact_{s}"].extend(list(prediction[..., 2])) 72 out_table[f"pred_bact_{s}"].extend(list(prediction[..., 2]))
73
66 return pd.DataFrame(out_table) 74 return pd.DataFrame(out_table)
67 75
68 76
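
The CNN branches above consume one-hot encoded fragments and their reverse complements (test_encoded, test_encoded_rc). The encoder lives in utils.preprocess, which this changeset does not touch, so the sketch below is only an illustration of the usual A/C/G/T -> 4-channel encoding; the name one_hot_encode_sketch and the channel order are assumptions, not the project's actual implementation.

# Illustrative only: a minimal one-hot encoder for equal-length DNA fragments,
# assuming 4 channels in A, C, G, T order. The real pp.one_hot_encode may differ.
import numpy as np

def one_hot_encode_sketch(fragments):
    mapping = {"A": 0, "C": 1, "G": 2, "T": 3}
    encoded = np.zeros((len(fragments), len(fragments[0]), 4), dtype=np.float32)
    for i, frag in enumerate(fragments):
        for j, base in enumerate(frag.upper()):
            if base in mapping:  # ambiguous bases (e.g. N) stay all-zero
                encoded[i, j, mapping[base]] = 1.0
    return encoded

print(one_hot_encode_sketch(["ACGTN"]).shape)  # (1, 5, 4)
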
@@ -69,13 +77,17 @@
-def predict_rf(df, rf_weights_path, length):
+def predict_rf(df, rf_weights_path, length, use_10):
     """
     Using predictions by predict_nn and weights of a trained RF classifier gives a single prediction for a fragment
     """
 
     clf = load(Path(rf_weights_path, f"RF_{length}.joblib"))
-    X = df[["pred_plant_5", "pred_vir_5", "pred_plant_7", "pred_vir_7"]]
-    # X = ["pred_plant_5", "pred_vir_5", "pred_plant_7", "pred_vir_7", "pred_plant_10", "pred_vir_10", ]]
+    if use_10:
+        X = df[
+            ["pred_plant_5", "pred_vir_5", "pred_plant_7", "pred_vir_7", "pred_plant_10", "pred_vir_10", ]]
+    else:
+        X = df[
+            ["pred_plant_5", "pred_vir_5", "pred_plant_7", "pred_vir_7", ]]
     y_pred = clf.predict(X)
     mapping = {0: "plant", 1: "virus", 2: "bacteria"}
     df["RF_decision"] = np.vectorize(mapping.get)(y_pred)
     prob_classes = clf.predict_proba(X)
     df["RF_pred_plant"] = prob_classes[..., 0]
@@ -87,13 +99,11 @@
 def predict_contigs(df):
     """
     Based on predictions of predict_rf for fragments gives a final prediction for the whole contig
     """
     df = (
-        df.groupby(["id", "length", 'RF_decision'], sort=False)
-        .size()
-        .unstack(fill_value=0)
+        df.groupby(["id", "length", 'RF_decision'], sort=False).size().unstack(fill_value=0)
     )
     df = df.reset_index()
     df = df.reindex(['length', 'id', 'virus', 'plant', 'bacteria'], axis=1)
     conditions = [
         (df['virus'] > df['plant']) & (df['virus'] > df['bacteria']),
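
predict_contigs turns fragment-level RF_decision values into per-contig counts with groupby(...).size().unstack(fill_value=0), now written on one line. A tiny self-contained pandas example of that reshaping step; the contig ids, lengths and decisions are invented for illustration.

# Counting fragment-level decisions per contig with groupby/size/unstack.
import pandas as pd

df = pd.DataFrame({
    "id": ["c1", "c1", "c1", "c2", "c2"],
    "length": [1500, 1500, 1500, 900, 900],
    "RF_decision": ["virus", "virus", "plant", "bacteria", "bacteria"],
})
counts = (
    df.groupby(["id", "length", "RF_decision"], sort=False).size().unstack(fill_value=0)
)
# One row per contig; reindex fixes the column order used downstream.
counts = counts.reset_index().reindex(["length", "id", "virus", "plant", "bacteria"], axis=1)
print(counts)  # c1 -> 2 virus / 1 plant fragments, c2 -> 2 bacteria fragments
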
@@ -129,26 +139,28 @@
 
     assert Path(test_ds[0]).exists(), f'{test_ds[0]} does not exist'
     assert Path(weights).exists(), f'{weights} does not exist'
     assert isinstance(limit, int), 'limit should be an integer'
     Path(out_path).mkdir(parents=True, exist_ok=True)
-
+    use_10 = Path(weights, 'model_10_500.h5').exists()
     for ts in test_ds:
         dfs_fr = []
         dfs_cont = []
         for l_ in 500, 1000:
             # print(f'starting prediction for {Path(ts).name} for fragment length {l_}')
             df = predict_nn(
                 ds_path=ts,
                 nn_weights_path=weights,
                 length=l_,
+                use_10=use_10
             )
             print(df)
             df = predict_rf(
                 df=df,
                 rf_weights_path=weights,
                 length=l_,
+                use_10=use_10
             )
             df = df.round(3)
             dfs_fr.append(df)
             df = predict_contigs(df)
             dfs_cont.append(df)
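
The driver now enables the third CNN only when its weights are present, keyed on model_10_500.h5 alone. A small sketch of the weights-folder layout this implies, derived from the model_{s}_{length}.h5 / RF_{length}.joblib naming used above; the folder name "weights" and the exact set of shipped files are assumptions.

# Which weight files the prediction loop reaches for, per fragment length.
# Presence of model_10 weights is exactly what drives use_10.
from pathlib import Path

weights = Path("weights")  # hypothetical weights folder
use_10 = (weights / "model_10_500.h5").exists()
sizes = [5, 7, 10] if use_10 else [5, 7]
for length in (500, 1000):
    expected = [f"model_{s}_{length}.h5" for s in sizes] + [f"RF_{length}.joblib"]
    missing = [name for name in expected if not (weights / name).exists()]
    print(length, "all present" if not missing else "missing:", missing)
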
@@ -176,11 +188,11 @@
     parser = argparse.ArgumentParser()
     parser.add_argument("--test_ds", help="path to the input file with contigs in fasta format (str or list of str)")
     parser.add_argument("--weights", help="path to the folder containing weights for NN and RF modules trained on 500 and 1000 fragment lengths (str)")
     parser.add_argument("--out_path", help="path to the folder to store predictions (str)")
     parser.add_argument("--return_viral", help="whether to return contigs annotated as viral in separate fasta file (True/False)")
-    parser.add_argument("--limit", help="Do predictions only for contigs > l. We suggest l=750. (int)", type=int)
+    parser.add_argument("--limit", help="Do predictions only for contigs > l. We suggest l=750. (int)", type=int, default=750)
 
     args = parser.parse_args()
     if args.test_ds:
         test_ds = args.test_ds
     if args.weights:
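
The only CLI change is a default for --limit. A quick standalone check of what that default means, recreating just this one argument rather than the full parser:

# Demonstrates the new default: omitting --limit now yields 750.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--limit", help="Do predictions only for contigs > l. We suggest l=750. (int)",
                    type=int, default=750)
print(parser.parse_args([]).limit)                   # 750
print(parser.parse_args(["--limit", "1500"]).limit)  # 1500
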