comparison model_prediction.py @ 0:f96efab83b65 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ba6a47bdf76bbf4cb276206ac1a8cbf61332fd16"
author bgruening
date Fri, 13 Sep 2019 12:23:39 -0400
parents
children 6b94d76a1397
comparison
equal deleted inserted replaced
-1:000000000000 0:f96efab83b65
1 import argparse
2 import json
3 import numpy as np
4 import pandas as pd
5 import tabix
6 import warnings
7
8 from scipy.io import mmread
9 from sklearn.pipeline import Pipeline
10
11 from galaxy_ml.externals.selene_sdk.sequences import Genome
12 from galaxy_ml.utils import (load_model, read_columns,
13 get_module, try_get_attr)
14
15
# Parallel job count, read from the GALAXY_SLOTS environment variable
# (falls back to 1 when the variable is unset).
N_JOBS = int(__import__('os').getenv('GALAXY_SLOTS', 1))
17
18
def _predict(estimator, method, X, **kwargs):
    """Call ``estimator.predict`` or ``estimator.predict_proba`` on ``X``.

    Parameters
    ----------
    estimator : object
        Loaded estimator (possibly a ``Pipeline``).

    method : str
        ``'predict'`` selects ``estimator.predict``; any other value
        selects ``estimator.predict_proba``.

    X : array-like
        Passed as the first positional argument.

    **kwargs
        Forwarded unchanged (e.g. ``data_generator``, ``steps``).

    Returns
    -------
    Whatever the selected estimator method returns.
    """
    if method == 'predict':
        return estimator.predict(X, **kwargs)
    return estimator.predict_proba(X, **kwargs)


def _drop_blacklisted_variants(pred_data_generator, variants):
    """Remove variants whose sequence window overlaps a blacklisted region.

    Workaround to be removed after galaxy-ml v0.7.13. When the generator's
    reference genome has no (truthy) ``_blacklist_tabix`` attribute,
    ``variants`` is returned unchanged.

    Parameters
    ----------
    pred_data_generator : object
        Fitted ``GenomicVariantBatchGenerator``. Its ``reference_genome_``,
        ``start_radius_`` and ``end_radius_`` attributes are read.

    variants : iterable of tuple
        ``(chrom, pos, name, ref, alt, strand)`` records.

    Returns
    -------
    list of tuple, or ``variants`` itself when no blacklist is available.
        For a ``Genome`` reference, the kept tuples carry the adjusted
        chromosome name: "chr" is prepended when absent, and the final
        character is dropped when the name contains "MT" (chrMT -> chrM).
    """
    reference_genome = pred_data_generator.reference_genome_
    blacklist_tabix = getattr(reference_genome, '_blacklist_tabix', None)
    if not blacklist_tabix:
        return variants

    start_radius = pred_data_generator.start_radius_
    end_radius = pred_data_generator.end_radius_
    # The reference genome does not change inside the loop.
    is_genome = isinstance(reference_genome, Genome)

    clean_variants = []
    for chrom, pos, name, ref, alt, strand in variants:
        center = pos + len(ref) // 2
        start = center - start_radius
        end = center + end_radius

        if is_genome:
            if "chr" not in chrom:
                chrom = "chr" + chrom
            if "MT" in chrom:
                chrom = chrom[:-1]

        try:
            # A single hit is enough to discard the variant.
            overlaps = any(
                True for _ in blacklist_tabix.query(chrom, start, end))
        except tabix.TabixError:
            # Best effort: keep the variant when the region can't be
            # queried (e.g. unknown contig).
            overlaps = False
        if overlaps:
            continue

        clean_variants.append((chrom, pos, name, ref, alt, strand))
    return clean_variants


def _predict_variant_effect(estimator, params, ref_seq, vcf_path,
                            outfile_predict):
    """Predict on VCF variants batch by batch and stream rows to a file.

    The output is tab-separated. The header is ``chrom pos name ref alt
    strand`` followed by one column per prediction output, named ``0``,
    ``1``, and so on. Each row is a variant's metadata followed by its
    predictions.

    Parameters
    ----------
    estimator : object
        Loaded estimator.

    params : dict
        Tool parameters. ``params['method']`` selects predict or
        predict_proba. The entries of ``params['input_options']`` other
        than ``'selected_input'`` are forwarded as keyword arguments to
        ``GenomicVariantBatchGenerator``. This dict is mutated in place.

    ref_seq : str
        Reference genome path.

    vcf_path : str
        VCF file path.

    outfile_predict : str
        Output file path.
    """
    klass = try_get_attr('galaxy_ml.preprocessors',
                         'GenomicVariantBatchGenerator')

    options = params['input_options']
    options.pop('selected_input')
    if options['blacklist_regions'] == 'none':
        options['blacklist_regions'] = None

    pred_data_generator = klass(
        ref_genome_path=ref_seq, vcf_path=vcf_path, **options)
    pred_data_generator.fit()

    # Everything after a successful fit runs under try/finally so the
    # generator is always closed, even if opening the output file fails.
    try:
        clean_variants = _drop_blacklisted_variants(
            pred_data_generator, pred_data_generator.variants)
        pred_data_generator.variants = clean_variants
        variants = np.array(clean_variants)

        # predict 1600 samples at once, then write them to file
        gen_flow = pred_data_generator.flow(batch_size=1600)

        with open(outfile_predict, 'w') as file_writer:
            file_writer.write('\t'.join(['chrom', 'pos', 'name', 'ref',
                                         'alt', 'strand']))
            header_done = False

            # TODO: multiple threading
            for _ in range(len(gen_flow)):
                index_array = next(gen_flow.index_generator)
                batch_X = gen_flow._get_batches_of_transformed_samples(
                    index_array)

                # Passing `pred_data_generator` overrides any
                # data_generator the model itself carries.
                batch_preds = _predict(
                    estimator, params['method'], batch_X,
                    data_generator=pred_data_generator)

                if batch_preds.ndim == 1:
                    batch_preds = batch_preds[:, np.newaxis]

                batch_out = np.column_stack(
                    [variants[index_array], batch_preds])

                # The number of prediction columns is only known once
                # the first batch has been predicted.
                if not header_done:
                    heads = np.arange(batch_preds.shape[-1]).astype(str)
                    file_writer.write("\t%s\n" % '\t'.join(heads))
                    header_done = True

                file_writer.writelines(
                    "%s\n" % '\t'.join(row) for row in batch_out)

            # With zero batches, still terminate the header line.
            if not header_done:
                file_writer.write('\n')
    finally:
        pred_data_generator.close()


def main(inputs, infile_estimator, outfile_predict,
         infile_weights=None, infile1=None,
         fasta_path=None, ref_seq=None,
         vcf_path=None):
    """Load a trained estimator and write its predictions to a file.

    Parameters
    ----------
    inputs : str
        File path to galaxy tool parameter (JSON).

    infile_estimator : str
        File path to trained estimator input.

    outfile_predict : str
        File path to save the prediction results, tabular.

    infile_weights : str
        File path to weights input.

    infile1 : str
        File path to dataset containing features.

    fasta_path : str
        File path to dataset containing fasta file.

    ref_seq : str
        File path to dataset containing the reference genome sequence.

    vcf_path : str
        File path to dataset containing variants info.

    Returns
    -------
    int or None
        ``0`` for ``variant_effect`` input; ``None`` otherwise.

    Raises
    ------
    ValueError
        If the model needs weights that were not supplied, if fasta input
        is used with an estimator lacking ``data_batch_generator``, or if
        ``selected_input`` is not a supported input type.
    """
    warnings.filterwarnings('ignore')

    with open(inputs, 'r') as param_handler:
        params = json.load(param_handler)

    # load model
    with open(infile_estimator, 'rb') as est_handler:
        estimator = load_model(est_handler)

    # For a Pipeline, weights are loaded into the final step.
    main_est = estimator
    if isinstance(estimator, Pipeline):
        main_est = estimator.steps[-1][-1]
    if hasattr(main_est, 'config') and hasattr(main_est, 'load_weights'):
        # Galaxy may pass the literal string 'None' for an unset dataset.
        if not infile_weights or infile_weights == 'None':
            raise ValueError("The selected model skeleton asks for weights, "
                             "but dataset for weights was not selected!")
        main_est.load_weights(infile_weights)

    # handle data input
    input_options = params['input_options']
    input_type = input_options['selected_input']

    if input_type == 'tabular':
        header = 'infer' if input_options['header1'] else None
        column_selector = input_options['column_selector_options_1']
        column_option = column_selector['selected_column_selector_option']
        if column_option in ['by_index_number', 'all_but_by_index_number',
                             'by_header_name', 'all_but_by_header_name']:
            c = column_selector['col1']
        else:
            c = None

        df = pd.read_csv(infile1, sep='\t', header=header, parse_dates=True)
        X = read_columns(df, c=c, c_option=column_option).astype(float)
        preds = _predict(estimator, params['method'], X)

    elif input_type == 'sparse':
        with open(infile1, 'r') as sparse_handler:
            X = mmread(sparse_handler)
        preds = _predict(estimator, params['method'], X)

    elif input_type == 'seq_fasta':
        if not hasattr(estimator, 'data_batch_generator'):
            raise ValueError(
                "To do prediction on sequences in fasta input, "
                "the estimator must be a `KerasGBatchClassifier` "
                "equipped with data_batch_generator!")
        pyfaidx = get_module('pyfaidx')
        sequences = pyfaidx.Fasta(fasta_path)
        n_seqs = len(sequences.keys())
        # X holds only sequence indices; the data generator reads the
        # actual sequences from fasta_path.
        X = np.arange(n_seqs)[:, np.newaxis]
        seq_length = estimator.data_batch_generator.seq_length
        batch_size = getattr(estimator, 'batch_size', 32)
        # Ceiling division: one extra step for a partial last batch.
        steps = (n_seqs + batch_size - 1) // batch_size

        klass = try_get_attr('galaxy_ml.preprocessors',
                             input_options['seq_type'])
        pred_data_generator = klass(fasta_path, seq_length=seq_length)

        preds = _predict(estimator, params['method'], X,
                         data_generator=pred_data_generator, steps=steps)

    elif input_type == 'variant_effect':
        # Results are streamed to outfile_predict directly.
        _predict_variant_effect(estimator, params, ref_seq, vcf_path,
                                outfile_predict)
        return 0

    else:
        raise ValueError("Unsupported input type: %r" % (input_type,))

    # output
    if preds.ndim == 1:
        rval = pd.DataFrame(preds, columns=['Predicted'])
    else:
        rval = pd.DataFrame(preds)

    rval.to_csv(outfile_predict, sep='\t', header=True, index=False)
245
246
if __name__ == '__main__':
    aparser = argparse.ArgumentParser()
    aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
    # All remaining options are optional; each is stored under its long
    # option name.
    for short_opt, dest_name in (("-e", "infile_estimator"),
                                 ("-w", "infile_weights"),
                                 ("-X", "infile1"),
                                 ("-O", "outfile_predict"),
                                 ("-f", "fasta_path"),
                                 ("-r", "ref_seq"),
                                 ("-v", "vcf_path")):
        aparser.add_argument(short_opt, "--" + dest_name, dest=dest_name)
    args = aparser.parse_args()

    keyword_args = {
        key: getattr(args, key)
        for key in ('infile_weights', 'infile1', 'fasta_path',
                    'ref_seq', 'vcf_path')
    }
    main(args.inputs, args.infile_estimator, args.outfile_predict,
         **keyword_args)