Mercurial > repos > bgruening > sklearn_estimator_attributes
comparison model_prediction.py @ 3:7a64b9f39a46 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ba6a47bdf76bbf4cb276206ac1a8cbf61332fd16"
author | bgruening |
---|---|
date | Fri, 13 Sep 2019 12:15:10 -0400 |
parents | c411ff569a26 |
children | 7a9a9349eb42 |
comparison
equal
deleted
inserted
replaced
2:c411ff569a26 | 3:7a64b9f39a46 |
---|---|
1 import argparse | 1 import argparse |
2 import json | 2 import json |
3 import numpy as np | 3 import numpy as np |
4 import pandas as pd | 4 import pandas as pd |
5 import tabix | |
5 import warnings | 6 import warnings |
6 | 7 |
7 from scipy.io import mmread | 8 from scipy.io import mmread |
8 from sklearn.pipeline import Pipeline | 9 from sklearn.pipeline import Pipeline |
9 | 10 |
11 from galaxy_ml.externals.selene_sdk.sequences import Genome | |
10 from galaxy_ml.utils import (load_model, read_columns, | 12 from galaxy_ml.utils import (load_model, read_columns, |
11 get_module, try_get_attr) | 13 get_module, try_get_attr) |
12 | 14 |
13 | 15 |
14 N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1)) | 16 N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1)) |
136 pred_data_generator = klass( | 138 pred_data_generator = klass( |
137 ref_genome_path=ref_seq, vcf_path=vcf_path, **options) | 139 ref_genome_path=ref_seq, vcf_path=vcf_path, **options) |
138 | 140 |
139 pred_data_generator.fit() | 141 pred_data_generator.fit() |
140 | 142 |
141 preds = estimator.model_.predict_generator( | 143 variants = pred_data_generator.variants |
142 pred_data_generator.flow(batch_size=32), | 144 # TODO : remove the following block after galaxy-ml v0.7.13 |
143 workers=N_JOBS, | 145 blacklist_tabix = getattr(pred_data_generator.reference_genome_, |
144 use_multiprocessing=True) | 146 '_blacklist_tabix', None) |
145 | 147 clean_variants = [] |
146 if preds.min() < 0. or preds.max() > 1.: | 148 if blacklist_tabix: |
147 warnings.warn('Network returning invalid probability values. ' | 149 start_radius = pred_data_generator.start_radius_ |
148 'The last layer might not normalize predictions ' | 150 end_radius = pred_data_generator.end_radius_ |
149 'into probabilities ' | 151 |
150 '(like softmax or sigmoid would).') | 152 for chrom, pos, name, ref, alt, strand in variants: |
151 | 153 center = pos + len(ref) // 2 |
152 if params['method'] == 'predict_proba' and preds.shape[1] == 1: | 154 start = center - start_radius |
153 # first column is probability of class 0 and second is of class 1 | 155 end = center + end_radius |
154 preds = np.hstack([1 - preds, preds]) | 156 |
155 | 157 if isinstance(pred_data_generator.reference_genome_, Genome): |
156 elif params['method'] == 'predict': | 158 if "chr" not in chrom: |
157 if preds.shape[-1] > 1: | 159 chrom = "chr" + chrom |
158 # if the last activation is `softmax`, the sum of all | 160 if "MT" in chrom: |
159 # probabilities will be 1, the classification is considered as | 161 chrom = chrom[:-1] |
160 # a multi-class problem; otherwise, we take it as multi-label. | 162 try: |
161 act = getattr(estimator.model_.layers[-1], 'activation', None) | 163 rows = blacklist_tabix.query(chrom, start, end) |
162 if act and act.__name__ == 'softmax': | 164 found = 0 |
163 classes = preds.argmax(axis=-1) | 165 for row in rows: |
166 found = 1 | |
167 break | |
168 if found: | |
169 continue | |
170 except tabix.TabixError: | |
171 pass | |
172 | |
173 clean_variants.append((chrom, pos, name, ref, alt, strand)) | |
174 else: | |
175 clean_variants = variants | |
176 | |
177 setattr(pred_data_generator, 'variants', clean_variants) | |
178 | |
179 variants = np.array(clean_variants) | |
180 # predict 1600 samples at once then write to file | |
181 gen_flow = pred_data_generator.flow(batch_size=1600) | |
182 | |
183 file_writer = open(outfile_predict, 'w') | |
184 header_row = '\t'.join(['chrom', 'pos', 'name', 'ref', | |
185 'alt', 'strand']) | |
186 file_writer.write(header_row) | |
187 header_done = False | |
188 | |
189 steps_done = 0 | |
190 | |
191 # TODO: multiple threading | |
192 try: | |
193 while steps_done < len(gen_flow): | |
194 index_array = next(gen_flow.index_generator) | |
195 batch_X = gen_flow._get_batches_of_transformed_samples( | |
196 index_array) | |
197 | |
198 if params['method'] == 'predict': | |
199 batch_preds = estimator.predict( | |
200 batch_X, | |
201 # The presence of `pred_data_generator` below is to | |
202 # override model carrying data_generator if there | |
203 # is any. | |
204 data_generator=pred_data_generator) | |
164 else: | 205 else: |
165 preds = (preds > 0.5).astype('int32') | 206 batch_preds = estimator.predict_proba( |
166 else: | 207 batch_X, |
167 classes = (preds > 0.5).astype('int32') | 208 # The presence of `pred_data_generator` below is to |
168 | 209 # override model carrying data_generator if there |
169 preds = estimator.classes_[classes] | 210 # is any. |
211 data_generator=pred_data_generator) | |
212 | |
213 if batch_preds.ndim == 1: | |
214 batch_preds = batch_preds[:, np.newaxis] | |
215 | |
216 batch_meta = variants[index_array] | |
217 batch_out = np.column_stack([batch_meta, batch_preds]) | |
218 | |
219 if not header_done: | |
220 heads = np.arange(batch_preds.shape[-1]).astype(str) | |
221 heads_str = '\t'.join(heads) | |
222 file_writer.write("\t%s\n" % heads_str) | |
223 header_done = True | |
224 | |
225 for row in batch_out: | |
226 row_str = '\t'.join(row) | |
227 file_writer.write("%s\n" % row_str) | |
228 | |
229 steps_done += 1 | |
230 | |
231 finally: | |
232 file_writer.close() | |
233 # TODO: make api `pred_data_generator.close()` | |
234 pred_data_generator.close() | |
235 return 0 | |
170 # end input | 236 # end input |
171 | 237 |
172 # output | 238 # output |
173 if input_type == 'variant_effect': # TODO: save in batches | 239 if len(preds.shape) == 1: |
174 rval = pd.DataFrame(preds) | |
175 meta = pd.DataFrame( | |
176 pred_data_generator.variants, | |
177 columns=['chrom', 'pos', 'name', 'ref', 'alt', 'strand']) | |
178 | |
179 rval = pd.concat([meta, rval], axis=1) | |
180 | |
181 elif len(preds.shape) == 1: | |
182 rval = pd.DataFrame(preds, columns=['Predicted']) | 240 rval = pd.DataFrame(preds, columns=['Predicted']) |
183 else: | 241 else: |
184 rval = pd.DataFrame(preds) | 242 rval = pd.DataFrame(preds) |
185 | 243 |
186 rval.to_csv(outfile_predict, sep='\t', | 244 rval.to_csv(outfile_predict, sep='\t', header=True, index=False) |
187 header=True, index=False) | |
188 | 245 |
189 | 246 |
190 if __name__ == '__main__': | 247 if __name__ == '__main__': |
191 aparser = argparse.ArgumentParser() | 248 aparser = argparse.ArgumentParser() |
192 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | 249 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) |