Mercurial > repos > iuc > virhunter
comparison predict.py @ 1:9b12bc1b1e2c draft
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit cea5a324cffd684e4cb08438fe5a305cfaf0f73a
| author | iuc |
|---|---|
| date | Wed, 30 Nov 2022 17:31:52 +0000 |
| parents | 457fd8fd681a |
| children | ea2cccb9f73e |
comparison
equal
deleted
inserted
replaced
| 0:457fd8fd681a | 1:9b12bc1b1e2c |
|---|---|
| 7 | 7 |
| 8 import numpy as np | 8 import numpy as np |
| 9 import pandas as pd | 9 import pandas as pd |
| 10 from Bio import SeqIO | 10 from Bio import SeqIO |
| 11 from joblib import load | 11 from joblib import load |
| 12 from models import model_5, model_7 | 12 from models import model_10, model_5, model_7 |
| 13 from utils import preprocess as pp | 13 from utils import preprocess as pp |
| 14 | 14 |
| 15 os.environ["CUDA_VISIBLE_DEVICES"] = "" | 15 os.environ["CUDA_VISIBLE_DEVICES"] = "" |
| 16 os.environ["TF_XLA_FLAGS"] = "--tf_xla_cpu_global_jit" | 16 os.environ["TF_XLA_FLAGS"] = "--tf_xla_cpu_global_jit" |
| 17 # loglevel : 0 all printed, 1 I not printed, 2 I and W not printed, 3 nothing printed | 17 # loglevel : 0 all printed, 1 I not printed, 2 I and W not printed, 3 nothing printed |
| 18 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' | 18 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' |
| 19 | 19 |
| 20 | 20 |
| 21 def predict_nn(ds_path, nn_weights_path, length, batch_size=256): | 21 def predict_nn(ds_path, nn_weights_path, length, use_10, batch_size=256): |
| 22 """ | 22 """ |
| 23 Breaks down contigs into fragments | 23 Breaks down contigs into fragments |
| 24 and uses pretrained neural networks to give predictions for fragments | 24 and uses pretrained neural networks to give predictions for fragments |
| 25 """ | 25 """ |
| 26 try: | 26 try: |
| 35 "pred_vir_5": [], | 35 "pred_vir_5": [], |
| 36 "pred_bact_5": [], | 36 "pred_bact_5": [], |
| 37 "pred_plant_7": [], | 37 "pred_plant_7": [], |
| 38 "pred_vir_7": [], | 38 "pred_vir_7": [], |
| 39 "pred_bact_7": [], | 39 "pred_bact_7": [], |
| 40 # "pred_plant_10": [], | |
| 41 # "pred_vir_10": [], | |
| 42 # "pred_bact_10": [], | |
| 43 } | 40 } |
| 41 if use_10: | |
| 42 out_table_ = { | |
| 43 "pred_plant_10": [], | |
| 44 "pred_vir_10": [], | |
| 45 "pred_bact_10": [], | |
| 46 } | |
| 47 out_table.update(out_table_) | |
| 44 if not seqs_: | 48 if not seqs_: |
| 45 raise ValueError("All sequences were smaller than length of the model") | 49 raise ValueError("All sequences were smaller than length of the model") |
| 46 test_fragments = [] | 50 test_fragments = [] |
| 47 test_fragments_rc = [] | 51 test_fragments_rc = [] |
| 48 for seq in seqs_: | 52 for seq in seqs_: |
| 54 out_table["id"].append(seq.id) | 58 out_table["id"].append(seq.id) |
| 55 out_table["length"].append(len(seq.seq)) | 59 out_table["length"].append(len(seq.seq)) |
| 56 out_table["fragment"].append(j) | 60 out_table["fragment"].append(j) |
| 57 test_encoded = pp.one_hot_encode(test_fragments) | 61 test_encoded = pp.one_hot_encode(test_fragments) |
| 58 test_encoded_rc = pp.one_hot_encode(test_fragments_rc) | 62 test_encoded_rc = pp.one_hot_encode(test_fragments_rc) |
| 59 # for model, s in zip([model_5.model(length), model_7.model(length), model_10.model(length)], [5, 7, 10]): | 63 if use_10: |
| 60 for model, s in zip([model_5.model(length), model_7.model(length)], [5, 7]): | 64 zipped_models = zip([model_5.model(length), model_7.model(length), model_10.model(length)], [5, 7, 10]) |
| 65 else: | |
| 66 zipped_models = zip([model_5.model(length), model_7.model(length)], [5, 7]) | |
| 67 for model, s in zipped_models: | |
| 61 model.load_weights(Path(nn_weights_path, f"model_{s}_{length}.h5")) | 68 model.load_weights(Path(nn_weights_path, f"model_{s}_{length}.h5")) |
| 62 prediction = model.predict([test_encoded, test_encoded_rc], batch_size) | 69 prediction = model.predict([test_encoded, test_encoded_rc], batch_size) |
| 63 out_table[f"pred_plant_{s}"].extend(list(prediction[..., 0])) | 70 out_table[f"pred_plant_{s}"].extend(list(prediction[..., 0])) |
| 64 out_table[f"pred_vir_{s}"].extend(list(prediction[..., 1])) | 71 out_table[f"pred_vir_{s}"].extend(list(prediction[..., 1])) |
| 65 out_table[f"pred_bact_{s}"].extend(list(prediction[..., 2])) | 72 out_table[f"pred_bact_{s}"].extend(list(prediction[..., 2])) |
| 73 | |
| 66 return pd.DataFrame(out_table) | 74 return pd.DataFrame(out_table) |
| 67 | 75 |
| 68 | 76 |
| 69 def predict_rf(df, rf_weights_path, length): | 77 def predict_rf(df, rf_weights_path, length, use_10): |
| 70 """ | 78 """ |
| 71 Using predictions by predict_nn and weights of a trained RF classifier gives a single prediction for a fragment | 79 Using predictions by predict_nn and weights of a trained RF classifier gives a single prediction for a fragment |
| 72 """ | 80 """ |
| 73 | 81 |
| 74 clf = load(Path(rf_weights_path, f"RF_{length}.joblib")) | 82 clf = load(Path(rf_weights_path, f"RF_{length}.joblib")) |
| 75 X = df[["pred_plant_5", "pred_vir_5", "pred_plant_7", "pred_vir_7"]] | 83 if use_10: |
| 76 # X = ["pred_plant_5", "pred_vir_5", "pred_plant_7", "pred_vir_7", "pred_plant_10", "pred_vir_10", ]] | 84 X = df[ |
| 85 ["pred_plant_5", "pred_vir_5", "pred_plant_7", "pred_vir_7", "pred_plant_10", "pred_vir_10", ]] | |
| 86 else: | |
| 87 X = df[ | |
| 88 ["pred_plant_5", "pred_vir_5", "pred_plant_7", "pred_vir_7", ]] | |
| 77 y_pred = clf.predict(X) | 89 y_pred = clf.predict(X) |
| 78 mapping = {0: "plant", 1: "virus", 2: "bacteria"} | 90 mapping = {0: "plant", 1: "virus", 2: "bacteria"} |
| 79 df["RF_decision"] = np.vectorize(mapping.get)(y_pred) | 91 df["RF_decision"] = np.vectorize(mapping.get)(y_pred) |
| 80 prob_classes = clf.predict_proba(X) | 92 prob_classes = clf.predict_proba(X) |
| 81 df["RF_pred_plant"] = prob_classes[..., 0] | 93 df["RF_pred_plant"] = prob_classes[..., 0] |
| 87 def predict_contigs(df): | 99 def predict_contigs(df): |
| 88 """ | 100 """ |
| 89 Based on predictions of predict_rf for fragments gives a final prediction for the whole contig | 101 Based on predictions of predict_rf for fragments gives a final prediction for the whole contig |
| 90 """ | 102 """ |
| 91 df = ( | 103 df = ( |
| 92 df.groupby(["id", "length", 'RF_decision'], sort=False) | 104 df.groupby(["id", "length", 'RF_decision'], sort=False).size().unstack(fill_value=0) |
| 93 .size() | |
| 94 .unstack(fill_value=0) | |
| 95 ) | 105 ) |
| 96 df = df.reset_index() | 106 df = df.reset_index() |
| 97 df = df.reindex(['length', 'id', 'virus', 'plant', 'bacteria'], axis=1) | 107 df = df.reindex(['length', 'id', 'virus', 'plant', 'bacteria'], axis=1) |
| 98 conditions = [ | 108 conditions = [ |
| 99 (df['virus'] > df['plant']) & (df['virus'] > df['bacteria']), | 109 (df['virus'] > df['plant']) & (df['virus'] > df['bacteria']), |
| 129 | 139 |
| 130 assert Path(test_ds[0]).exists(), f'{test_ds[0]} does not exist' | 140 assert Path(test_ds[0]).exists(), f'{test_ds[0]} does not exist' |
| 131 assert Path(weights).exists(), f'{weights} does not exist' | 141 assert Path(weights).exists(), f'{weights} does not exist' |
| 132 assert isinstance(limit, int), 'limit should be an integer' | 142 assert isinstance(limit, int), 'limit should be an integer' |
| 133 Path(out_path).mkdir(parents=True, exist_ok=True) | 143 Path(out_path).mkdir(parents=True, exist_ok=True) |
| 134 | 144 use_10 = Path(weights, 'model_10_500.h5').exists() |
| 135 for ts in test_ds: | 145 for ts in test_ds: |
| 136 dfs_fr = [] | 146 dfs_fr = [] |
| 137 dfs_cont = [] | 147 dfs_cont = [] |
| 138 for l_ in 500, 1000: | 148 for l_ in 500, 1000: |
| 139 # print(f'starting prediction for {Path(ts).name} for fragment length {l_}') | 149 # print(f'starting prediction for {Path(ts).name} for fragment length {l_}') |
| 140 df = predict_nn( | 150 df = predict_nn( |
| 141 ds_path=ts, | 151 ds_path=ts, |
| 142 nn_weights_path=weights, | 152 nn_weights_path=weights, |
| 143 length=l_, | 153 length=l_, |
| 154 use_10=use_10 | |
| 144 ) | 155 ) |
| 145 print(df) | 156 print(df) |
| 146 df = predict_rf( | 157 df = predict_rf( |
| 147 df=df, | 158 df=df, |
| 148 rf_weights_path=weights, | 159 rf_weights_path=weights, |
| 149 length=l_, | 160 length=l_, |
| 161 use_10=use_10 | |
| 150 ) | 162 ) |
| 151 df = df.round(3) | 163 df = df.round(3) |
| 152 dfs_fr.append(df) | 164 dfs_fr.append(df) |
| 153 df = predict_contigs(df) | 165 df = predict_contigs(df) |
| 154 dfs_cont.append(df) | 166 dfs_cont.append(df) |
| 176 parser = argparse.ArgumentParser() | 188 parser = argparse.ArgumentParser() |
| 177 parser.add_argument("--test_ds", help="path to the input file with contigs in fasta format (str or list of str)") | 189 parser.add_argument("--test_ds", help="path to the input file with contigs in fasta format (str or list of str)") |
| 178 parser.add_argument("--weights", help="path to the folder containing weights for NN and RF modules trained on 500 and 1000 fragment lengths (str)") | 190 parser.add_argument("--weights", help="path to the folder containing weights for NN and RF modules trained on 500 and 1000 fragment lengths (str)") |
| 179 parser.add_argument("--out_path", help="path to the folder to store predictions (str)") | 191 parser.add_argument("--out_path", help="path to the folder to store predictions (str)") |
| 180 parser.add_argument("--return_viral", help="whether to return contigs annotated as viral in separate fasta file (True/False)") | 192 parser.add_argument("--return_viral", help="whether to return contigs annotated as viral in separate fasta file (True/False)") |
| 181 parser.add_argument("--limit", help="Do predictions only for contigs > l. We suggest l=750. (int)", type=int) | 193 parser.add_argument("--limit", help="Do predictions only for contigs > l. We suggest l=750. (int)", type=int, default=750) |
| 182 | 194 |
| 183 args = parser.parse_args() | 195 args = parser.parse_args() |
| 184 if args.test_ds: | 196 if args.test_ds: |
| 185 test_ds = args.test_ds | 197 test_ds = args.test_ds |
| 186 if args.weights: | 198 if args.weights: |
