Mercurial > repos > bgruening > sklearn_discriminant_classifier
comparison model_prediction.py @ 26:9bb505eafac9 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 60f0fbc0eafd7c11bc60fb6c77f2937782efd8a9-dirty
author | bgruening |
---|---|
date | Fri, 09 Aug 2019 07:06:17 -0400 |
parents | |
children | 8e49f26b14d3 |
comparison
equal
deleted
inserted
replaced
25:3e2921875c58 | 26:9bb505eafac9 |
---|---|
1 import argparse | |
2 import json | |
3 import numpy as np | |
4 import pandas as pd | |
5 import warnings | |
6 | |
7 from scipy.io import mmread | |
8 from sklearn.pipeline import Pipeline | |
9 | |
10 from galaxy_ml.utils import (load_model, read_columns, | |
11 get_module, try_get_attr) | |
12 | |
13 | |
# Worker count handed to Keras' ``predict_generator`` for variant input.
# Read from the GALAXY_SLOTS environment variable (presumably the number of
# cores Galaxy assigned to this job); falls back to 1 when it is unset.
N_JOBS = int(__import__('os').getenv('GALAXY_SLOTS', 1))
15 | |
16 | |
def _predict_with(estimator, X, method, **kwargs):
    """Dispatch to ``predict`` or ``predict_proba`` on *estimator*.

    Parameters
    ----------
    estimator : object
        Estimator exposing ``predict`` / ``predict_proba``.
    X : array-like
        Passed as the first positional argument to the chosen method.
    method : str
        ``'predict'`` selects ``estimator.predict``; any other value
        selects ``estimator.predict_proba`` (same rule the tool always used).
    **kwargs
        Forwarded unchanged to the chosen method (the fasta branch passes
        ``data_generator`` and ``steps`` this way).

    Returns
    -------
    Whatever the chosen estimator method returns.
    """
    if method == 'predict':
        return estimator.predict(X, **kwargs)
    return estimator.predict_proba(X, **kwargs)


def _postprocess_variant_preds(estimator, preds, method):
    """Convert raw network outputs for ``variant_effect`` input.

    - Warns when any value lies outside [0, 1].
    - ``predict_proba`` with a single output column is expanded to two
      columns ``[1 - p, p]`` (class 0, class 1).
    - ``predict`` with a single output column is thresholded at 0.5 and
      mapped through ``estimator.classes_``.
    - ``predict`` with several output columns uses ``argmax`` when the
      last layer's activation is named ``softmax`` (multi-class) and maps
      the result through ``estimator.classes_``. Otherwise the problem is
      treated as multi-label, and a 0/1 ``int32`` indicator matrix is returned.
      (Previously this branch raised ``NameError`` because ``classes`` was
      never assigned.)
    - Any other combination returns *preds* unchanged.
    """
    if preds.min() < 0. or preds.max() > 1.:
        warnings.warn('Network returning invalid probability values. '
                      'The last layer might not normalize predictions '
                      'into probabilities '
                      '(like softmax or sigmoid would).')

    if method == 'predict_proba' and preds.shape[1] == 1:
        # first column is probability of class 0 and second is of class 1
        return np.hstack([1 - preds, preds])

    if method == 'predict':
        if preds.shape[-1] > 1:
            # With a softmax output the probabilities sum to 1, so the task
            # is treated as multi-class; otherwise it is multi-label.
            act = getattr(estimator.model_.layers[-1], 'activation', None)
            if act and act.__name__ == 'softmax':
                return estimator.classes_[preds.argmax(axis=-1)]
            # multi-label: emit the binary indicator matrix as-is
            return (preds > 0.5).astype('int32')
        return estimator.classes_[(preds > 0.5).astype('int32')]

    return preds


def main(inputs, infile_estimator, outfile_predict,
         infile_weights=None, infile1=None,
         fasta_path=None, ref_seq=None,
         vcf_path=None):
    """Load a trained estimator, predict on the selected input and save
    the predictions as a tab-separated file.

    Parameters
    ----------
    inputs : str
        File path to galaxy tool parameter (JSON).

    infile_estimator : str
        File path to trained estimator input.

    outfile_predict : str
        File path to save the prediction results, tabular.

    infile_weights : str
        File path to weights input. Required when the final estimator
        exposes both ``config`` and ``load_weights``.

    infile1 : str
        File path to dataset containing features (tabular or sparse input).

    fasta_path : str
        File path to dataset containing fasta file.

    ref_seq : str
        File path to dataset containing the reference genome sequence.

    vcf_path : str
        File path to dataset containing variants info.

    Raises
    ------
    ValueError
        If weights are required but not supplied, if fasta input is used
        with an estimator lacking ``data_batch_generator``, or if the
        selected input type is not recognised.
    """
    warnings.filterwarnings('ignore')

    with open(inputs, 'r') as param_handler:
        params = json.load(param_handler)

    # load model
    with open(infile_estimator, 'rb') as est_handler:
        estimator = load_model(est_handler)

    # For pipelines, the weights belong to the final step.
    main_est = estimator
    if isinstance(estimator, Pipeline):
        main_est = estimator.steps[-1][-1]
    if hasattr(main_est, 'config') and hasattr(main_est, 'load_weights'):
        # Galaxy passes the literal string 'None' for an unset optional input.
        if not infile_weights or infile_weights == 'None':
            raise ValueError("The selected model skeleton asks for weights, "
                             "but dataset for weights was not selected!")
        main_est.load_weights(infile_weights)

    # handle data input
    input_type = params['input_options']['selected_input']
    method = params['method']

    # tabular input
    if input_type == 'tabular':
        header = 'infer' if params['input_options']['header1'] else None
        column_option = (params['input_options']
                               ['column_selector_options_1']
                               ['selected_column_selector_option'])
        if column_option in ['by_index_number', 'all_but_by_index_number',
                             'by_header_name', 'all_but_by_header_name']:
            c = params['input_options']['column_selector_options_1']['col1']
        else:
            c = None

        df = pd.read_csv(infile1, sep='\t', header=header, parse_dates=True)

        X = read_columns(df, c=c, c_option=column_option).astype(float)
        preds = _predict_with(estimator, X, method)

    # sparse input
    elif input_type == 'sparse':
        # close the handle once the matrix has been read
        with open(infile1, 'r') as sparse_handler:
            X = mmread(sparse_handler)
        preds = _predict_with(estimator, X, method)

    # fasta input
    elif input_type == 'seq_fasta':
        if not hasattr(estimator, 'data_batch_generator'):
            raise ValueError(
                "To do prediction on sequences in fasta input, "
                "the estimator must be a `KerasGBatchClassifier` "
                "equipped with data_batch_generator!")
        pyfaidx = get_module('pyfaidx')
        sequences = pyfaidx.Fasta(fasta_path)
        n_seqs = len(sequences.keys())
        # X holds only row indices 0..n_seqs-1 as a column; presumably the
        # data generator resolves each index to a sequence — TODO confirm.
        X = np.arange(n_seqs)[:, np.newaxis]
        seq_length = estimator.data_batch_generator.seq_length
        batch_size = getattr(estimator, 'batch_size', 32)
        # ceil(n_seqs / batch_size) so the last partial batch is included
        steps = (n_seqs + batch_size - 1) // batch_size

        seq_type = params['input_options']['seq_type']
        klass = try_get_attr(
            'galaxy_ml.preprocessors', seq_type)

        pred_data_generator = klass(
            fasta_path, seq_length=seq_length)

        preds = _predict_with(estimator, X, method,
                              data_generator=pred_data_generator,
                              steps=steps)

    # vcf input
    elif input_type == 'variant_effect':
        klass = try_get_attr('galaxy_ml.preprocessors',
                             'GenomicVariantBatchGenerator')

        # Remaining input options are forwarded as generator kwargs.
        options = params['input_options']
        options.pop('selected_input')
        if options['blacklist_regions'] == 'none':
            options['blacklist_regions'] = None

        pred_data_generator = klass(
            ref_genome_path=ref_seq, vcf_path=vcf_path, **options)

        pred_data_generator.fit()

        preds = estimator.model_.predict_generator(
            pred_data_generator.flow(batch_size=32),
            workers=N_JOBS,
            use_multiprocessing=True)

        preds = _postprocess_variant_preds(estimator, preds, method)

    else:
        raise ValueError("Unsupported input type: %r" % (input_type,))
    # end input

    # output
    if input_type == 'variant_effect':  # TODO: save in batches
        rval = pd.DataFrame(preds)
        meta = pd.DataFrame(
            pred_data_generator.variants,
            columns=['chrom', 'pos', 'name', 'ref', 'alt', 'strand'])

        rval = pd.concat([meta, rval], axis=1)

    elif len(preds.shape) == 1:
        rval = pd.DataFrame(preds, columns=['Predicted'])
    else:
        rval = pd.DataFrame(preds)

    rval.to_csv(outfile_predict, sep='\t',
                header=True, index=False)
188 | |
189 | |
if __name__ == '__main__':
    # Command-line entry point: every option maps 1:1 onto a main() argument.
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--inputs", dest="inputs", required=True)

    # (short flag, long flag, destination) for the optional arguments,
    # registered in the same order as before so --help output is unchanged.
    optional_specs = (
        ("-e", "--infile_estimator", "infile_estimator"),
        ("-w", "--infile_weights", "infile_weights"),
        ("-X", "--infile1", "infile1"),
        ("-O", "--outfile_predict", "outfile_predict"),
        ("-f", "--fasta_path", "fasta_path"),
        ("-r", "--ref_seq", "ref_seq"),
        ("-v", "--vcf_path", "vcf_path"),
    )
    for short_flag, long_flag, dest_name in optional_specs:
        parser.add_argument(short_flag, long_flag, dest=dest_name)

    cli = parser.parse_args()

    main(cli.inputs, cli.infile_estimator, cli.outfile_predict,
         infile_weights=cli.infile_weights,
         infile1=cli.infile1,
         fasta_path=cli.fasta_path,
         ref_seq=cli.ref_seq,
         vcf_path=cli.vcf_path)