comparison model_prediction.py @ 1:0fd7d8e90e2a draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ba6a47bdf76bbf4cb276206ac1a8cbf61332fd16"
author bgruening
date Fri, 13 Sep 2019 12:19:45 -0400
parents 1046cf73236b
children c3813c64d678
comparison
equal deleted inserted replaced
0:1046cf73236b 1:0fd7d8e90e2a
1 import argparse 1 import argparse
2 import json 2 import json
3 import numpy as np 3 import numpy as np
4 import pandas as pd 4 import pandas as pd
5 import tabix
5 import warnings 6 import warnings
6 7
7 from scipy.io import mmread 8 from scipy.io import mmread
8 from sklearn.pipeline import Pipeline 9 from sklearn.pipeline import Pipeline
9 10
11 from galaxy_ml.externals.selene_sdk.sequences import Genome
10 from galaxy_ml.utils import (load_model, read_columns, 12 from galaxy_ml.utils import (load_model, read_columns,
11 get_module, try_get_attr) 13 get_module, try_get_attr)
12 14
13 15
14 N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1)) 16 N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1))
136 pred_data_generator = klass( 138 pred_data_generator = klass(
137 ref_genome_path=ref_seq, vcf_path=vcf_path, **options) 139 ref_genome_path=ref_seq, vcf_path=vcf_path, **options)
138 140
139 pred_data_generator.fit() 141 pred_data_generator.fit()
140 142
141 preds = estimator.model_.predict_generator( 143 variants = pred_data_generator.variants
142 pred_data_generator.flow(batch_size=32), 144 # TODO : remove the following block after galaxy-ml v0.7.13
143 workers=N_JOBS, 145 blacklist_tabix = getattr(pred_data_generator.reference_genome_,
144 use_multiprocessing=True) 146 '_blacklist_tabix', None)
145 147 clean_variants = []
146 if preds.min() < 0. or preds.max() > 1.: 148 if blacklist_tabix:
147 warnings.warn('Network returning invalid probability values. ' 149 start_radius = pred_data_generator.start_radius_
148 'The last layer might not normalize predictions ' 150 end_radius = pred_data_generator.end_radius_
149 'into probabilities ' 151
150 '(like softmax or sigmoid would).') 152 for chrom, pos, name, ref, alt, strand in variants:
151 153 center = pos + len(ref) // 2
152 if params['method'] == 'predict_proba' and preds.shape[1] == 1: 154 start = center - start_radius
153 # first column is probability of class 0 and second is of class 1 155 end = center + end_radius
154 preds = np.hstack([1 - preds, preds]) 156
155 157 if isinstance(pred_data_generator.reference_genome_, Genome):
156 elif params['method'] == 'predict': 158 if "chr" not in chrom:
157 if preds.shape[-1] > 1: 159 chrom = "chr" + chrom
158 # if the last activation is `softmax`, the sum of all 160 if "MT" in chrom:
159 # probabilities will be 1, the classification is considered as 161 chrom = chrom[:-1]
160 # multi-class problem, otherwise, we take it as multi-label. 162 try:
161 act = getattr(estimator.model_.layers[-1], 'activation', None) 163 rows = blacklist_tabix.query(chrom, start, end)
162 if act and act.__name__ == 'softmax': 164 found = 0
163 classes = preds.argmax(axis=-1) 165 for row in rows:
166 found = 1
167 break
168 if found:
169 continue
170 except tabix.TabixError:
171 pass
172
173 clean_variants.append((chrom, pos, name, ref, alt, strand))
174 else:
175 clean_variants = variants
176
177 setattr(pred_data_generator, 'variants', clean_variants)
178
179 variants = np.array(clean_variants)
180 # predict 1600 samples at once, then write to file
181 gen_flow = pred_data_generator.flow(batch_size=1600)
182
183 file_writer = open(outfile_predict, 'w')
184 header_row = '\t'.join(['chrom', 'pos', 'name', 'ref',
185 'alt', 'strand'])
186 file_writer.write(header_row)
187 header_done = False
188
189 steps_done = 0
190
191 # TODO: multiple threading
192 try:
193 while steps_done < len(gen_flow):
194 index_array = next(gen_flow.index_generator)
195 batch_X = gen_flow._get_batches_of_transformed_samples(
196 index_array)
197
198 if params['method'] == 'predict':
199 batch_preds = estimator.predict(
200 batch_X,
201 # The presence of `pred_data_generator` below is to
202 # override model carrying data_generator if there
203 # is any.
204 data_generator=pred_data_generator)
164 else: 205 else:
165 preds = (preds > 0.5).astype('int32') 206 batch_preds = estimator.predict_proba(
166 else: 207 batch_X,
167 classes = (preds > 0.5).astype('int32') 208 # The presence of `pred_data_generator` below is to
168 209 # override model carrying data_generator if there
169 preds = estimator.classes_[classes] 210 # is any.
211 data_generator=pred_data_generator)
212
213 if batch_preds.ndim == 1:
214 batch_preds = batch_preds[:, np.newaxis]
215
216 batch_meta = variants[index_array]
217 batch_out = np.column_stack([batch_meta, batch_preds])
218
219 if not header_done:
220 heads = np.arange(batch_preds.shape[-1]).astype(str)
221 heads_str = '\t'.join(heads)
222 file_writer.write("\t%s\n" % heads_str)
223 header_done = True
224
225 for row in batch_out:
226 row_str = '\t'.join(row)
227 file_writer.write("%s\n" % row_str)
228
229 steps_done += 1
230
231 finally:
232 file_writer.close()
233 # TODO: make api `pred_data_generator.close()`
234 pred_data_generator.close()
235 return 0
170 # end input 236 # end input
171 237
172 # output 238 # output
173 if input_type == 'variant_effect': # TODO: save in batches 239 if len(preds.shape) == 1:
174 rval = pd.DataFrame(preds)
175 meta = pd.DataFrame(
176 pred_data_generator.variants,
177 columns=['chrom', 'pos', 'name', 'ref', 'alt', 'strand'])
178
179 rval = pd.concat([meta, rval], axis=1)
180
181 elif len(preds.shape) == 1:
182 rval = pd.DataFrame(preds, columns=['Predicted']) 240 rval = pd.DataFrame(preds, columns=['Predicted'])
183 else: 241 else:
184 rval = pd.DataFrame(preds) 242 rval = pd.DataFrame(preds)
185 243
186 rval.to_csv(outfile_predict, sep='\t', 244 rval.to_csv(outfile_predict, sep='\t', header=True, index=False)
187 header=True, index=False)
188 245
189 246
190 if __name__ == '__main__': 247 if __name__ == '__main__':
191 aparser = argparse.ArgumentParser() 248 aparser = argparse.ArgumentParser()
192 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) 249 aparser.add_argument("-i", "--inputs", dest="inputs", required=True)