comparison predict.py @ 1:9b12bc1b1e2c draft
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit cea5a324cffd684e4cb08438fe5a305cfaf0f73a
| author | iuc |
|---|---|
| date | Wed, 30 Nov 2022 17:31:52 +0000 |
| parents | 457fd8fd681a |
| children | ea2cccb9f73e |
--- predict.py (0:457fd8fd681a)
+++ predict.py (1:9b12bc1b1e2c)
@@ -7,20 +7,20 @@
 
 import numpy as np
 import pandas as pd
 from Bio import SeqIO
 from joblib import load
-from models import model_5, model_7
+from models import model_10, model_5, model_7
 from utils import preprocess as pp
 
 os.environ["CUDA_VISIBLE_DEVICES"] = ""
 os.environ["TF_XLA_FLAGS"] = "--tf_xla_cpu_global_jit"
 # loglevel : 0 all printed, 1 I not printed, 2 I and W not printed, 3 nothing printed
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
 
 
-def predict_nn(ds_path, nn_weights_path, length, batch_size=256):
+def predict_nn(ds_path, nn_weights_path, length, use_10, batch_size=256):
     """
     Breaks down contigs into fragments
     and uses pretrained neural networks to give predictions for fragments
     """
     try:
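The hunk above imports the optional third CNN, `model_10`, and adds a `use_10` flag to `predict_nn`. As a minimal sketch of the selection pattern the rest of the diff builds on (the `select_models` helper and the `builders` mapping are illustrative stand-ins for the imported `models.model_5.model` / `model_7.model` / `model_10.model` builders):

```python
# Illustrative sketch, not VirHunter code: pick which CNN builders join the
# ensemble based on the use_10 flag. Each builder maps a fragment length to
# a compiled model, mirroring models.model_5.model(length) etc.
def select_models(length, use_10, builders):
    sizes = [5, 7] + ([10] if use_10 else [])
    return [(builders[s](length), s) for s in sizes]

# Example with stand-in builders:
fake_builders = {s: (lambda length, s=s: f"model_{s}({length})") for s in (5, 7, 10)}
print(select_models(1000, use_10=True, builders=fake_builders))
# [('model_5(1000)', 5), ('model_7(1000)', 7), ('model_10(1000)', 10)]
```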
35 "pred_vir_5": [], | 35 "pred_vir_5": [], |
36 "pred_bact_5": [], | 36 "pred_bact_5": [], |
37 "pred_plant_7": [], | 37 "pred_plant_7": [], |
38 "pred_vir_7": [], | 38 "pred_vir_7": [], |
39 "pred_bact_7": [], | 39 "pred_bact_7": [], |
40 # "pred_plant_10": [], | |
41 # "pred_vir_10": [], | |
42 # "pred_bact_10": [], | |
43 } | 40 } |
41 if use_10: | |
42 out_table_ = { | |
43 "pred_plant_10": [], | |
44 "pred_vir_10": [], | |
45 "pred_bact_10": [], | |
46 } | |
47 out_table.update(out_table_) | |
44 if not seqs_: | 48 if not seqs_: |
45 raise ValueError("All sequences were smaller than length of the model") | 49 raise ValueError("All sequences were smaller than length of the model") |
46 test_fragments = [] | 50 test_fragments = [] |
47 test_fragments_rc = [] | 51 test_fragments_rc = [] |
48 for seq in seqs_: | 52 for seq in seqs_: |
54 out_table["id"].append(seq.id) | 58 out_table["id"].append(seq.id) |
55 out_table["length"].append(len(seq.seq)) | 59 out_table["length"].append(len(seq.seq)) |
56 out_table["fragment"].append(j) | 60 out_table["fragment"].append(j) |
57 test_encoded = pp.one_hot_encode(test_fragments) | 61 test_encoded = pp.one_hot_encode(test_fragments) |
58 test_encoded_rc = pp.one_hot_encode(test_fragments_rc) | 62 test_encoded_rc = pp.one_hot_encode(test_fragments_rc) |
59 # for model, s in zip([model_5.model(length), model_7.model(length), model_10.model(length)], [5, 7, 10]): | 63 if use_10: |
60 for model, s in zip([model_5.model(length), model_7.model(length)], [5, 7]): | 64 zipped_models = zip([model_5.model(length), model_7.model(length), model_10.model(length)], [5, 7, 10]) |
65 else: | |
66 zipped_models = zip([model_5.model(length), model_7.model(length)], [5, 7]) | |
67 for model, s in zipped_models: | |
61 model.load_weights(Path(nn_weights_path, f"model_{s}_{length}.h5")) | 68 model.load_weights(Path(nn_weights_path, f"model_{s}_{length}.h5")) |
62 prediction = model.predict([test_encoded, test_encoded_rc], batch_size) | 69 prediction = model.predict([test_encoded, test_encoded_rc], batch_size) |
63 out_table[f"pred_plant_{s}"].extend(list(prediction[..., 0])) | 70 out_table[f"pred_plant_{s}"].extend(list(prediction[..., 0])) |
64 out_table[f"pred_vir_{s}"].extend(list(prediction[..., 1])) | 71 out_table[f"pred_vir_{s}"].extend(list(prediction[..., 1])) |
65 out_table[f"pred_bact_{s}"].extend(list(prediction[..., 2])) | 72 out_table[f"pred_bact_{s}"].extend(list(prediction[..., 2])) |
73 | |
66 return pd.DataFrame(out_table) | 74 return pd.DataFrame(out_table) |
67 | 75 |
68 | 76 |
69 def predict_rf(df, rf_weights_path, length): | 77 def predict_rf(df, rf_weights_path, length, use_10): |
70 """ | 78 """ |
71 Using predictions by predict_nn and weights of a trained RF classifier gives a single prediction for a fragment | 79 Using predictions by predict_nn and weights of a trained RF classifier gives a single prediction for a fragment |
72 """ | 80 """ |
73 | 81 |
74 clf = load(Path(rf_weights_path, f"RF_{length}.joblib")) | 82 clf = load(Path(rf_weights_path, f"RF_{length}.joblib")) |
75 X = df[["pred_plant_5", "pred_vir_5", "pred_plant_7", "pred_vir_7"]] | 83 if use_10: |
76 # X = ["pred_plant_5", "pred_vir_5", "pred_plant_7", "pred_vir_7", "pred_plant_10", "pred_vir_10", ]] | 84 X = df[ |
85 ["pred_plant_5", "pred_vir_5", "pred_plant_7", "pred_vir_7", "pred_plant_10", "pred_vir_10", ]] | |
86 else: | |
87 X = df[ | |
88 ["pred_plant_5", "pred_vir_5", "pred_plant_7", "pred_vir_7", ]] | |
77 y_pred = clf.predict(X) | 89 y_pred = clf.predict(X) |
78 mapping = {0: "plant", 1: "virus", 2: "bacteria"} | 90 mapping = {0: "plant", 1: "virus", 2: "bacteria"} |
79 df["RF_decision"] = np.vectorize(mapping.get)(y_pred) | 91 df["RF_decision"] = np.vectorize(mapping.get)(y_pred) |
80 prob_classes = clf.predict_proba(X) | 92 prob_classes = clf.predict_proba(X) |
81 df["RF_pred_plant"] = prob_classes[..., 0] | 93 df["RF_pred_plant"] = prob_classes[..., 0] |
@@ -87,13 +99,11 @@
 def predict_contigs(df):
     """
     Based on predictions of predict_rf for fragments gives a final prediction for the whole contig
     """
     df = (
-        df.groupby(["id", "length", 'RF_decision'], sort=False)
-        .size()
-        .unstack(fill_value=0)
+        df.groupby(["id", "length", 'RF_decision'], sort=False).size().unstack(fill_value=0)
     )
     df = df.reset_index()
     df = df.reindex(['length', 'id', 'virus', 'plant', 'bacteria'], axis=1)
     conditions = [
         (df['virus'] > df['plant']) & (df['virus'] > df['bacteria']),
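`predict_contigs` only had its groupby chain collapsed onto one line; behavior is unchanged. Since the `unstack` step is easy to misread, here is a tiny self-contained demonstration of what it produces (data invented):

```python
import pandas as pd

# Per-fragment decisions for two toy contigs.
frags = pd.DataFrame({
    "id": ["c1", "c1", "c1", "c2", "c2"],
    "length": [1500, 1500, 1500, 900, 900],
    "RF_decision": ["virus", "virus", "plant", "plant", "plant"],
})

counts = (
    frags.groupby(["id", "length", "RF_decision"], sort=False)
    .size()                  # fragments per (contig, decision) pair
    .unstack(fill_value=0)   # one count column per decision class
    .reset_index()
)
print(counts)
# c1 has 2 virus / 1 plant fragments, c2 has 2 plant fragments. Classes that
# never occur (here: bacteria) get no column, which is why the real code
# reindexes to ['length', 'id', 'virus', 'plant', 'bacteria'] afterwards.
```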
@@ -129,26 +139,28 @@
 
     assert Path(test_ds[0]).exists(), f'{test_ds[0]} does not exist'
     assert Path(weights).exists(), f'{weights} does not exist'
     assert isinstance(limit, int), 'limit should be an integer'
     Path(out_path).mkdir(parents=True, exist_ok=True)
-
+    use_10 = Path(weights, 'model_10_500.h5').exists()
     for ts in test_ds:
         dfs_fr = []
         dfs_cont = []
         for l_ in 500, 1000:
             # print(f'starting prediction for {Path(ts).name} for fragment length {l_}')
             df = predict_nn(
                 ds_path=ts,
                 nn_weights_path=weights,
                 length=l_,
+                use_10=use_10
             )
             print(df)
             df = predict_rf(
                 df=df,
                 rf_weights_path=weights,
                 length=l_,
+                use_10=use_10
             )
             df = df.round(3)
             dfs_fr.append(df)
             df = predict_contigs(df)
             dfs_cont.append(df)
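The driver decides `use_10` once, by probing the weights folder for the kernel-10 file, then threads the flag through both stages. A sketch of that probe-and-thread shape (`detect_use_10` is an invented name; the commented calls stand in for the real `predict_nn`/`predict_rf`):

```python
from pathlib import Path

# Sketch of the capability probe added above: the 500 bp kernel-10 weight
# file acts as a sentinel for whether the optional third model was shipped.
def detect_use_10(weights):
    return Path(weights, "model_10_500.h5").exists()

# The driver would then thread the flag through both stages:
#   use_10 = detect_use_10(weights)
#   df = predict_nn(..., use_10=use_10)
#   df = predict_rf(..., use_10=use_10)
```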
@@ -176,11 +188,11 @@
     parser = argparse.ArgumentParser()
     parser.add_argument("--test_ds", help="path to the input file with contigs in fasta format (str or list of str)")
     parser.add_argument("--weights", help="path to the folder containing weights for NN and RF modules trained on 500 and 1000 fragment lengths (str)")
     parser.add_argument("--out_path", help="path to the folder to store predictions (str)")
     parser.add_argument("--return_viral", help="whether to return contigs annotated as viral in separate fasta file (True/False)")
-    parser.add_argument("--limit", help="Do predictions only for contigs > l. We suggest l=750. (int)", type=int)
+    parser.add_argument("--limit", help="Do predictions only for contigs > l. We suggest l=750. (int)", type=int, default=750)
 
     args = parser.parse_args()
     if args.test_ds:
         test_ds = args.test_ds
     if args.weights:
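Finally, `--limit` gains `default=750`, so the `assert isinstance(limit, int)` guard no longer fails when the flag is omitted (with no default, argparse yields `None`, and `isinstance(None, int)` is false). A quick self-contained check of that argparse behavior:

```python
import argparse

p = argparse.ArgumentParser()
p.add_argument("--limit", type=int, default=750)

print(p.parse_args([]).limit)                   # 750 (None before this commit)
print(p.parse_args(["--limit", "1000"]).limit)  # 1000, explicit value still wins
```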