Mercurial > repos > bgruening > sklearn_train_test_eval
comparison model_prediction.py @ 15:2eb5c017958d draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author | bgruening |
---|---|
date | Wed, 09 Aug 2023 13:15:27 +0000 |
parents | caf7d2b71a48 |
children |
comparison
equal
deleted
inserted
replaced
14:4d1637cac794 | 15:2eb5c017958d |
---|---|
2 import json | 2 import json |
3 import warnings | 3 import warnings |
4 | 4 |
5 import numpy as np | 5 import numpy as np |
6 import pandas as pd | 6 import pandas as pd |
7 from galaxy_ml.utils import get_module, load_model, read_columns, try_get_attr | 7 from galaxy_ml.model_persist import load_model_from_h5 |
8 from galaxy_ml.utils import (clean_params, get_module, read_columns, | |
9 try_get_attr) | |
8 from scipy.io import mmread | 10 from scipy.io import mmread |
9 from sklearn.pipeline import Pipeline | |
10 | 11 |
11 N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1)) | 12 N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1)) |
12 | 13 |
13 | 14 |
14 def main( | 15 def main( |
15 inputs, | 16 inputs, |
16 infile_estimator, | 17 infile_estimator, |
17 outfile_predict, | 18 outfile_predict, |
18 infile_weights=None, | |
19 infile1=None, | 19 infile1=None, |
20 fasta_path=None, | 20 fasta_path=None, |
21 ref_seq=None, | 21 ref_seq=None, |
22 vcf_path=None, | 22 vcf_path=None, |
23 ): | 23 ): |
25 Parameter | 25 Parameter |
26 --------- | 26 --------- |
27 inputs : str | 27 inputs : str |
28 File path to galaxy tool parameter | 28 File path to galaxy tool parameter |
29 | 29 |
30 infile_estimator : strgit | 30 infile_estimator : str |
31 File path to trained estimator input | 31 File path to trained estimator input |
32 | 32 |
33 outfile_predict : str | 33 outfile_predict : str |
34 File path to save the prediction results, tabular | 34 File path to save the prediction results, tabular |
35 | |
36 infile_weights : str | |
37 File path to weights input | |
38 | 35 |
39 infile1 : str | 36 infile1 : str |
40 File path to dataset containing features | 37 File path to dataset containing features |
41 | 38 |
42 fasta_path : str | 39 fasta_path : str |
52 | 49 |
53 with open(inputs, "r") as param_handler: | 50 with open(inputs, "r") as param_handler: |
54 params = json.load(param_handler) | 51 params = json.load(param_handler) |
55 | 52 |
56 # load model | 53 # load model |
57 with open(infile_estimator, "rb") as est_handler: | 54 estimator = load_model_from_h5(infile_estimator) |
58 estimator = load_model(est_handler) | 55 estimator = clean_params(estimator) |
59 | |
60 main_est = estimator | |
61 if isinstance(estimator, Pipeline): | |
62 main_est = estimator.steps[-1][-1] | |
63 if hasattr(main_est, "config") and hasattr(main_est, "load_weights"): | |
64 if not infile_weights or infile_weights == "None": | |
65 raise ValueError( | |
66 "The selected model skeleton asks for weights, " | |
67 "but dataset for weights wan not selected!" | |
68 ) | |
69 main_est.load_weights(infile_weights) | |
70 | 56 |
71 # handle data input | 57 # handle data input |
72 input_type = params["input_options"]["selected_input"] | 58 input_type = params["input_options"]["selected_input"] |
73 # tabular input | 59 # tabular input |
74 if input_type == "tabular": | 60 if input_type == "tabular": |
219 | 205 |
220 if __name__ == "__main__": | 206 if __name__ == "__main__": |
221 aparser = argparse.ArgumentParser() | 207 aparser = argparse.ArgumentParser() |
222 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | 208 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) |
223 aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator") | 209 aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator") |
224 aparser.add_argument("-w", "--infile_weights", dest="infile_weights") | |
225 aparser.add_argument("-X", "--infile1", dest="infile1") | 210 aparser.add_argument("-X", "--infile1", dest="infile1") |
226 aparser.add_argument("-O", "--outfile_predict", dest="outfile_predict") | 211 aparser.add_argument("-O", "--outfile_predict", dest="outfile_predict") |
227 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") | 212 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") |
228 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") | 213 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") |
229 aparser.add_argument("-v", "--vcf_path", dest="vcf_path") | 214 aparser.add_argument("-v", "--vcf_path", dest="vcf_path") |
231 | 216 |
232 main( | 217 main( |
233 args.inputs, | 218 args.inputs, |
234 args.infile_estimator, | 219 args.infile_estimator, |
235 args.outfile_predict, | 220 args.outfile_predict, |
236 infile_weights=args.infile_weights, | |
237 infile1=args.infile1, | 221 infile1=args.infile1, |
238 fasta_path=args.fasta_path, | 222 fasta_path=args.fasta_path, |
239 ref_seq=args.ref_seq, | 223 ref_seq=args.ref_seq, |
240 vcf_path=args.vcf_path, | 224 vcf_path=args.vcf_path, |
241 ) | 225 ) |