Mercurial > repos > bgruening > sklearn_train_test_split
comparison model_prediction.py @ 6:13b9ac5d277c draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 208a8d348e7c7a182cfbe1b6f17868146428a7e2"
author | bgruening |
---|---|
date | Tue, 13 Apr 2021 22:24:07 +0000 |
parents | 5a092779412e |
children | 3312fb686ffb |
comparison
legend: equal | deleted | inserted | replaced
5:ce2fd1edbc6e | 6:13b9ac5d277c |
---|---|
1 import argparse | 1 import argparse |
2 import json | 2 import json |
3 import warnings | |
4 | |
3 import numpy as np | 5 import numpy as np |
4 import pandas as pd | 6 import pandas as pd |
5 import warnings | 7 from galaxy_ml.utils import get_module, load_model, read_columns, try_get_attr |
6 | |
7 from scipy.io import mmread | 8 from scipy.io import mmread |
8 from sklearn.pipeline import Pipeline | 9 from sklearn.pipeline import Pipeline |
9 | 10 |
10 from galaxy_ml.utils import (load_model, read_columns, | 11 N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1)) |
11 get_module, try_get_attr) | 12 |
12 | 13 |
13 | 14 def main( |
14 N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1)) | 15 inputs, |
15 | 16 infile_estimator, |
16 | 17 outfile_predict, |
17 def main(inputs, infile_estimator, outfile_predict, | 18 infile_weights=None, |
18 infile_weights=None, infile1=None, | 19 infile1=None, |
19 fasta_path=None, ref_seq=None, | 20 fasta_path=None, |
20 vcf_path=None): | 21 ref_seq=None, |
22 vcf_path=None, | |
23 ): | |
21 """ | 24 """ |
22 Parameter | 25 Parameter |
23 --------- | 26 --------- |
24 inputs : str | 27 inputs : str |
25 File path to galaxy tool parameter | 28 File path to galaxy tool parameter |
43 File path to dataset containing the reference genome sequence. | 46 File path to dataset containing the reference genome sequence. |
44 | 47 |
45 vcf_path : str | 48 vcf_path : str |
46 File path to dataset containing variants info. | 49 File path to dataset containing variants info. |
47 """ | 50 """ |
48 warnings.filterwarnings('ignore') | 51 warnings.filterwarnings("ignore") |
49 | 52 |
50 with open(inputs, 'r') as param_handler: | 53 with open(inputs, "r") as param_handler: |
51 params = json.load(param_handler) | 54 params = json.load(param_handler) |
52 | 55 |
53 # load model | 56 # load model |
54 with open(infile_estimator, 'rb') as est_handler: | 57 with open(infile_estimator, "rb") as est_handler: |
55 estimator = load_model(est_handler) | 58 estimator = load_model(est_handler) |
56 | 59 |
57 main_est = estimator | 60 main_est = estimator |
58 if isinstance(estimator, Pipeline): | 61 if isinstance(estimator, Pipeline): |
59 main_est = estimator.steps[-1][-1] | 62 main_est = estimator.steps[-1][-1] |
60 if hasattr(main_est, 'config') and hasattr(main_est, 'load_weights'): | 63 if hasattr(main_est, "config") and hasattr(main_est, "load_weights"): |
61 if not infile_weights or infile_weights == 'None': | 64 if not infile_weights or infile_weights == "None": |
62 raise ValueError("The selected model skeleton asks for weights, " | 65 raise ValueError( |
63 "but dataset for weights wan not selected!") | 66 "The selected model skeleton asks for weights, " "but dataset for weights wan not selected!" |
67 ) | |
64 main_est.load_weights(infile_weights) | 68 main_est.load_weights(infile_weights) |
65 | 69 |
66 # handle data input | 70 # handle data input |
67 input_type = params['input_options']['selected_input'] | 71 input_type = params["input_options"]["selected_input"] |
68 # tabular input | 72 # tabular input |
69 if input_type == 'tabular': | 73 if input_type == "tabular": |
70 header = 'infer' if params['input_options']['header1'] else None | 74 header = "infer" if params["input_options"]["header1"] else None |
71 column_option = (params['input_options'] | 75 column_option = params["input_options"]["column_selector_options_1"]["selected_column_selector_option"] |
72 ['column_selector_options_1'] | 76 if column_option in [ |
73 ['selected_column_selector_option']) | 77 "by_index_number", |
74 if column_option in ['by_index_number', 'all_but_by_index_number', | 78 "all_but_by_index_number", |
75 'by_header_name', 'all_but_by_header_name']: | 79 "by_header_name", |
76 c = params['input_options']['column_selector_options_1']['col1'] | 80 "all_but_by_header_name", |
81 ]: | |
82 c = params["input_options"]["column_selector_options_1"]["col1"] | |
77 else: | 83 else: |
78 c = None | 84 c = None |
79 | 85 |
80 df = pd.read_csv(infile1, sep='\t', header=header, parse_dates=True) | 86 df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True) |
81 | 87 |
82 X = read_columns(df, c=c, c_option=column_option).astype(float) | 88 X = read_columns(df, c=c, c_option=column_option).astype(float) |
83 | 89 |
84 if params['method'] == 'predict': | 90 if params["method"] == "predict": |
85 preds = estimator.predict(X) | 91 preds = estimator.predict(X) |
86 else: | 92 else: |
87 preds = estimator.predict_proba(X) | 93 preds = estimator.predict_proba(X) |
88 | 94 |
89 # sparse input | 95 # sparse input |
90 elif input_type == 'sparse': | 96 elif input_type == "sparse": |
91 X = mmread(open(infile1, 'r')) | 97 X = mmread(open(infile1, "r")) |
92 if params['method'] == 'predict': | 98 if params["method"] == "predict": |
93 preds = estimator.predict(X) | 99 preds = estimator.predict(X) |
94 else: | 100 else: |
95 preds = estimator.predict_proba(X) | 101 preds = estimator.predict_proba(X) |
96 | 102 |
97 # fasta input | 103 # fasta input |
98 elif input_type == 'seq_fasta': | 104 elif input_type == "seq_fasta": |
99 if not hasattr(estimator, 'data_batch_generator'): | 105 if not hasattr(estimator, "data_batch_generator"): |
100 raise ValueError( | 106 raise ValueError( |
101 "To do prediction on sequences in fasta input, " | 107 "To do prediction on sequences in fasta input, " |
102 "the estimator must be a `KerasGBatchClassifier`" | 108 "the estimator must be a `KerasGBatchClassifier`" |
103 "equipped with data_batch_generator!") | 109 "equipped with data_batch_generator!" |
104 pyfaidx = get_module('pyfaidx') | 110 ) |
111 pyfaidx = get_module("pyfaidx") | |
105 sequences = pyfaidx.Fasta(fasta_path) | 112 sequences = pyfaidx.Fasta(fasta_path) |
106 n_seqs = len(sequences.keys()) | 113 n_seqs = len(sequences.keys()) |
107 X = np.arange(n_seqs)[:, np.newaxis] | 114 X = np.arange(n_seqs)[:, np.newaxis] |
108 seq_length = estimator.data_batch_generator.seq_length | 115 seq_length = estimator.data_batch_generator.seq_length |
109 batch_size = getattr(estimator, 'batch_size', 32) | 116 batch_size = getattr(estimator, "batch_size", 32) |
110 steps = (n_seqs + batch_size - 1) // batch_size | 117 steps = (n_seqs + batch_size - 1) // batch_size |
111 | 118 |
112 seq_type = params['input_options']['seq_type'] | 119 seq_type = params["input_options"]["seq_type"] |
113 klass = try_get_attr( | 120 klass = try_get_attr("galaxy_ml.preprocessors", seq_type) |
114 'galaxy_ml.preprocessors', seq_type) | 121 |
115 | 122 pred_data_generator = klass(fasta_path, seq_length=seq_length) |
116 pred_data_generator = klass( | 123 |
117 fasta_path, seq_length=seq_length) | 124 if params["method"] == "predict": |
118 | 125 preds = estimator.predict(X, data_generator=pred_data_generator, steps=steps) |
119 if params['method'] == 'predict': | 126 else: |
120 preds = estimator.predict( | 127 preds = estimator.predict_proba(X, data_generator=pred_data_generator, steps=steps) |
121 X, data_generator=pred_data_generator, steps=steps) | |
122 else: | |
123 preds = estimator.predict_proba( | |
124 X, data_generator=pred_data_generator, steps=steps) | |
125 | 128 |
126 # vcf input | 129 # vcf input |
127 elif input_type == 'variant_effect': | 130 elif input_type == "variant_effect": |
128 klass = try_get_attr('galaxy_ml.preprocessors', | 131 klass = try_get_attr("galaxy_ml.preprocessors", "GenomicVariantBatchGenerator") |
129 'GenomicVariantBatchGenerator') | 132 |
130 | 133 options = params["input_options"] |
131 options = params['input_options'] | 134 options.pop("selected_input") |
132 options.pop('selected_input') | 135 if options["blacklist_regions"] == "none": |
133 if options['blacklist_regions'] == 'none': | 136 options["blacklist_regions"] = None |
134 options['blacklist_regions'] = None | 137 |
135 | 138 pred_data_generator = klass(ref_genome_path=ref_seq, vcf_path=vcf_path, **options) |
136 pred_data_generator = klass( | |
137 ref_genome_path=ref_seq, vcf_path=vcf_path, **options) | |
138 | 139 |
139 pred_data_generator.set_processing_attrs() | 140 pred_data_generator.set_processing_attrs() |
140 | 141 |
141 variants = pred_data_generator.variants | 142 variants = pred_data_generator.variants |
142 | 143 |
143 # predict 1600 sample at once then write to file | 144 # predict 1600 sample at once then write to file |
144 gen_flow = pred_data_generator.flow(batch_size=1600) | 145 gen_flow = pred_data_generator.flow(batch_size=1600) |
145 | 146 |
146 file_writer = open(outfile_predict, 'w') | 147 file_writer = open(outfile_predict, "w") |
147 header_row = '\t'.join(['chrom', 'pos', 'name', 'ref', | 148 header_row = "\t".join(["chrom", "pos", "name", "ref", "alt", "strand"]) |
148 'alt', 'strand']) | |
149 file_writer.write(header_row) | 149 file_writer.write(header_row) |
150 header_done = False | 150 header_done = False |
151 | 151 |
152 steps_done = 0 | 152 steps_done = 0 |
153 | 153 |
154 # TODO: multiple threading | 154 # TODO: multiple threading |
155 try: | 155 try: |
156 while steps_done < len(gen_flow): | 156 while steps_done < len(gen_flow): |
157 index_array = next(gen_flow.index_generator) | 157 index_array = next(gen_flow.index_generator) |
158 batch_X = gen_flow._get_batches_of_transformed_samples( | 158 batch_X = gen_flow._get_batches_of_transformed_samples(index_array) |
159 index_array) | 159 |
160 | 160 if params["method"] == "predict": |
161 if params['method'] == 'predict': | |
162 batch_preds = estimator.predict( | 161 batch_preds = estimator.predict( |
163 batch_X, | 162 batch_X, |
164 # The presence of `pred_data_generator` below is to | 163 # The presence of `pred_data_generator` below is to |
165 # override model carrying data_generator if there | 164 # override model carrying data_generator if there |
166 # is any. | 165 # is any. |
167 data_generator=pred_data_generator) | 166 data_generator=pred_data_generator, |
167 ) | |
168 else: | 168 else: |
169 batch_preds = estimator.predict_proba( | 169 batch_preds = estimator.predict_proba( |
170 batch_X, | 170 batch_X, |
171 # The presence of `pred_data_generator` below is to | 171 # The presence of `pred_data_generator` below is to |
172 # override model carrying data_generator if there | 172 # override model carrying data_generator if there |
173 # is any. | 173 # is any. |
174 data_generator=pred_data_generator) | 174 data_generator=pred_data_generator, |
175 ) | |
175 | 176 |
176 if batch_preds.ndim == 1: | 177 if batch_preds.ndim == 1: |
177 batch_preds = batch_preds[:, np.newaxis] | 178 batch_preds = batch_preds[:, np.newaxis] |
178 | 179 |
179 batch_meta = variants[index_array] | 180 batch_meta = variants[index_array] |
180 batch_out = np.column_stack([batch_meta, batch_preds]) | 181 batch_out = np.column_stack([batch_meta, batch_preds]) |
181 | 182 |
182 if not header_done: | 183 if not header_done: |
183 heads = np.arange(batch_preds.shape[-1]).astype(str) | 184 heads = np.arange(batch_preds.shape[-1]).astype(str) |
184 heads_str = '\t'.join(heads) | 185 heads_str = "\t".join(heads) |
185 file_writer.write("\t%s\n" % heads_str) | 186 file_writer.write("\t%s\n" % heads_str) |
186 header_done = True | 187 header_done = True |
187 | 188 |
188 for row in batch_out: | 189 for row in batch_out: |
189 row_str = '\t'.join(row) | 190 row_str = "\t".join(row) |
190 file_writer.write("%s\n" % row_str) | 191 file_writer.write("%s\n" % row_str) |
191 | 192 |
192 steps_done += 1 | 193 steps_done += 1 |
193 | 194 |
194 finally: | 195 finally: |
198 return 0 | 199 return 0 |
199 # end input | 200 # end input |
200 | 201 |
201 # output | 202 # output |
202 if len(preds.shape) == 1: | 203 if len(preds.shape) == 1: |
203 rval = pd.DataFrame(preds, columns=['Predicted']) | 204 rval = pd.DataFrame(preds, columns=["Predicted"]) |
204 else: | 205 else: |
205 rval = pd.DataFrame(preds) | 206 rval = pd.DataFrame(preds) |
206 | 207 |
207 rval.to_csv(outfile_predict, sep='\t', header=True, index=False) | 208 rval.to_csv(outfile_predict, sep="\t", header=True, index=False) |
208 | 209 |
209 | 210 |
210 if __name__ == '__main__': | 211 if __name__ == "__main__": |
211 aparser = argparse.ArgumentParser() | 212 aparser = argparse.ArgumentParser() |
212 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | 213 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) |
213 aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator") | 214 aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator") |
214 aparser.add_argument("-w", "--infile_weights", dest="infile_weights") | 215 aparser.add_argument("-w", "--infile_weights", dest="infile_weights") |
215 aparser.add_argument("-X", "--infile1", dest="infile1") | 216 aparser.add_argument("-X", "--infile1", dest="infile1") |
217 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") | 218 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") |
218 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") | 219 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") |
219 aparser.add_argument("-v", "--vcf_path", dest="vcf_path") | 220 aparser.add_argument("-v", "--vcf_path", dest="vcf_path") |
220 args = aparser.parse_args() | 221 args = aparser.parse_args() |
221 | 222 |
222 main(args.inputs, args.infile_estimator, args.outfile_predict, | 223 main( |
223 infile_weights=args.infile_weights, infile1=args.infile1, | 224 args.inputs, |
224 fasta_path=args.fasta_path, ref_seq=args.ref_seq, | 225 args.infile_estimator, |
225 vcf_path=args.vcf_path) | 226 args.outfile_predict, |
227 infile_weights=args.infile_weights, | |
228 infile1=args.infile1, | |
229 fasta_path=args.fasta_path, | |
230 ref_seq=args.ref_seq, | |
231 vcf_path=args.vcf_path, | |
232 ) |