comparison model_prediction.py @ 9:4471d2b2de79 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
author bgruening
date Tue, 13 Apr 2021 18:25:32 +0000
parents c3813c64d678
children 64bbfa592868
comparison
equal deleted inserted replaced
8:2a853a2d2e72 9:4471d2b2de79
1 import argparse 1 import argparse
2 import json 2 import json
3 import warnings
4
3 import numpy as np 5 import numpy as np
4 import pandas as pd 6 import pandas as pd
5 import warnings
6
7 from scipy.io import mmread 7 from scipy.io import mmread
8 from sklearn.pipeline import Pipeline 8 from sklearn.pipeline import Pipeline
9 9
10 from galaxy_ml.utils import (load_model, read_columns, 10 from galaxy_ml.utils import (get_module, load_model,
11 get_module, try_get_attr) 11 read_columns, try_get_attr)
12 12
13 13
14 N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1)) 14 N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1))
15 15
16 16
17 def main(inputs, infile_estimator, outfile_predict, 17 def main(
18 infile_weights=None, infile1=None, 18 inputs,
19 fasta_path=None, ref_seq=None, 19 infile_estimator,
20 vcf_path=None): 20 outfile_predict,
21 infile_weights=None,
22 infile1=None,
23 fasta_path=None,
24 ref_seq=None,
25 vcf_path=None,
26 ):
21 """ 27 """
22 Parameter 28 Parameter
23 --------- 29 ---------
24 inputs : str 30 inputs : str
25 File path to galaxy tool parameter 31 File path to galaxy tool parameter
43 File path to dataset containing the reference genome sequence. 49 File path to dataset containing the reference genome sequence.
44 50
45 vcf_path : str 51 vcf_path : str
46 File path to dataset containing variants info. 52 File path to dataset containing variants info.
47 """ 53 """
48 warnings.filterwarnings('ignore') 54 warnings.filterwarnings("ignore")
49 55
50 with open(inputs, 'r') as param_handler: 56 with open(inputs, "r") as param_handler:
51 params = json.load(param_handler) 57 params = json.load(param_handler)
52 58
53 # load model 59 # load model
54 with open(infile_estimator, 'rb') as est_handler: 60 with open(infile_estimator, "rb") as est_handler:
55 estimator = load_model(est_handler) 61 estimator = load_model(est_handler)
56 62
57 main_est = estimator 63 main_est = estimator
58 if isinstance(estimator, Pipeline): 64 if isinstance(estimator, Pipeline):
59 main_est = estimator.steps[-1][-1] 65 main_est = estimator.steps[-1][-1]
60 if hasattr(main_est, 'config') and hasattr(main_est, 'load_weights'): 66 if hasattr(main_est, "config") and hasattr(main_est, "load_weights"):
61 if not infile_weights or infile_weights == 'None': 67 if not infile_weights or infile_weights == "None":
62 raise ValueError("The selected model skeleton asks for weights, " 68 raise ValueError(
63 "but dataset for weights wan not selected!") 69 "The selected model skeleton asks for weights, " "but dataset for weights wan not selected!"
70 )
64 main_est.load_weights(infile_weights) 71 main_est.load_weights(infile_weights)
65 72
66 # handle data input 73 # handle data input
67 input_type = params['input_options']['selected_input'] 74 input_type = params["input_options"]["selected_input"]
68 # tabular input 75 # tabular input
69 if input_type == 'tabular': 76 if input_type == "tabular":
70 header = 'infer' if params['input_options']['header1'] else None 77 header = "infer" if params["input_options"]["header1"] else None
71 column_option = (params['input_options'] 78 column_option = params["input_options"]["column_selector_options_1"]["selected_column_selector_option"]
72 ['column_selector_options_1'] 79 if column_option in [
73 ['selected_column_selector_option']) 80 "by_index_number",
74 if column_option in ['by_index_number', 'all_but_by_index_number', 81 "all_but_by_index_number",
75 'by_header_name', 'all_but_by_header_name']: 82 "by_header_name",
76 c = params['input_options']['column_selector_options_1']['col1'] 83 "all_but_by_header_name",
84 ]:
85 c = params["input_options"]["column_selector_options_1"]["col1"]
77 else: 86 else:
78 c = None 87 c = None
79 88
80 df = pd.read_csv(infile1, sep='\t', header=header, parse_dates=True) 89 df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True)
81 90
82 X = read_columns(df, c=c, c_option=column_option).astype(float) 91 X = read_columns(df, c=c, c_option=column_option).astype(float)
83 92
84 if params['method'] == 'predict': 93 if params["method"] == "predict":
85 preds = estimator.predict(X) 94 preds = estimator.predict(X)
86 else: 95 else:
87 preds = estimator.predict_proba(X) 96 preds = estimator.predict_proba(X)
88 97
89 # sparse input 98 # sparse input
90 elif input_type == 'sparse': 99 elif input_type == "sparse":
91 X = mmread(open(infile1, 'r')) 100 X = mmread(open(infile1, "r"))
92 if params['method'] == 'predict': 101 if params["method"] == "predict":
93 preds = estimator.predict(X) 102 preds = estimator.predict(X)
94 else: 103 else:
95 preds = estimator.predict_proba(X) 104 preds = estimator.predict_proba(X)
96 105
97 # fasta input 106 # fasta input
98 elif input_type == 'seq_fasta': 107 elif input_type == "seq_fasta":
99 if not hasattr(estimator, 'data_batch_generator'): 108 if not hasattr(estimator, "data_batch_generator"):
100 raise ValueError( 109 raise ValueError(
101 "To do prediction on sequences in fasta input, " 110 "To do prediction on sequences in fasta input, "
102 "the estimator must be a `KerasGBatchClassifier`" 111 "the estimator must be a `KerasGBatchClassifier`"
103 "equipped with data_batch_generator!") 112 "equipped with data_batch_generator!"
104 pyfaidx = get_module('pyfaidx') 113 )
114 pyfaidx = get_module("pyfaidx")
105 sequences = pyfaidx.Fasta(fasta_path) 115 sequences = pyfaidx.Fasta(fasta_path)
106 n_seqs = len(sequences.keys()) 116 n_seqs = len(sequences.keys())
107 X = np.arange(n_seqs)[:, np.newaxis] 117 X = np.arange(n_seqs)[:, np.newaxis]
108 seq_length = estimator.data_batch_generator.seq_length 118 seq_length = estimator.data_batch_generator.seq_length
109 batch_size = getattr(estimator, 'batch_size', 32) 119 batch_size = getattr(estimator, "batch_size", 32)
110 steps = (n_seqs + batch_size - 1) // batch_size 120 steps = (n_seqs + batch_size - 1) // batch_size
111 121
112 seq_type = params['input_options']['seq_type'] 122 seq_type = params["input_options"]["seq_type"]
113 klass = try_get_attr( 123 klass = try_get_attr("galaxy_ml.preprocessors", seq_type)
114 'galaxy_ml.preprocessors', seq_type) 124
115 125 pred_data_generator = klass(fasta_path, seq_length=seq_length)
116 pred_data_generator = klass( 126
117 fasta_path, seq_length=seq_length) 127 if params["method"] == "predict":
118 128 preds = estimator.predict(X, data_generator=pred_data_generator, steps=steps)
119 if params['method'] == 'predict': 129 else:
120 preds = estimator.predict( 130 preds = estimator.predict_proba(X, data_generator=pred_data_generator, steps=steps)
121 X, data_generator=pred_data_generator, steps=steps)
122 else:
123 preds = estimator.predict_proba(
124 X, data_generator=pred_data_generator, steps=steps)
125 131
126 # vcf input 132 # vcf input
127 elif input_type == 'variant_effect': 133 elif input_type == "variant_effect":
128 klass = try_get_attr('galaxy_ml.preprocessors', 134 klass = try_get_attr("galaxy_ml.preprocessors", "GenomicVariantBatchGenerator")
129 'GenomicVariantBatchGenerator') 135
130 136 options = params["input_options"]
131 options = params['input_options'] 137 options.pop("selected_input")
132 options.pop('selected_input') 138 if options["blacklist_regions"] == "none":
133 if options['blacklist_regions'] == 'none': 139 options["blacklist_regions"] = None
134 options['blacklist_regions'] = None 140
135 141 pred_data_generator = klass(ref_genome_path=ref_seq, vcf_path=vcf_path, **options)
136 pred_data_generator = klass(
137 ref_genome_path=ref_seq, vcf_path=vcf_path, **options)
138 142
139 pred_data_generator.set_processing_attrs() 143 pred_data_generator.set_processing_attrs()
140 144
141 variants = pred_data_generator.variants 145 variants = pred_data_generator.variants
142 146
143 # predict 1600 sample at once then write to file 147 # predict 1600 sample at once then write to file
144 gen_flow = pred_data_generator.flow(batch_size=1600) 148 gen_flow = pred_data_generator.flow(batch_size=1600)
145 149
146 file_writer = open(outfile_predict, 'w') 150 file_writer = open(outfile_predict, "w")
147 header_row = '\t'.join(['chrom', 'pos', 'name', 'ref', 151 header_row = "\t".join(["chrom", "pos", "name", "ref", "alt", "strand"])
148 'alt', 'strand'])
149 file_writer.write(header_row) 152 file_writer.write(header_row)
150 header_done = False 153 header_done = False
151 154
152 steps_done = 0 155 steps_done = 0
153 156
154 # TODO: multiple threading 157 # TODO: multiple threading
155 try: 158 try:
156 while steps_done < len(gen_flow): 159 while steps_done < len(gen_flow):
157 index_array = next(gen_flow.index_generator) 160 index_array = next(gen_flow.index_generator)
158 batch_X = gen_flow._get_batches_of_transformed_samples( 161 batch_X = gen_flow._get_batches_of_transformed_samples(index_array)
159 index_array) 162
160 163 if params["method"] == "predict":
161 if params['method'] == 'predict':
162 batch_preds = estimator.predict( 164 batch_preds = estimator.predict(
163 batch_X, 165 batch_X,
164 # The presence of `pred_data_generator` below is to 166 # The presence of `pred_data_generator` below is to
165 # override model carrying data_generator if there 167 # override model carrying data_generator if there
166 # is any. 168 # is any.
167 data_generator=pred_data_generator) 169 data_generator=pred_data_generator,
170 )
168 else: 171 else:
169 batch_preds = estimator.predict_proba( 172 batch_preds = estimator.predict_proba(
170 batch_X, 173 batch_X,
171 # The presence of `pred_data_generator` below is to 174 # The presence of `pred_data_generator` below is to
172 # override model carrying data_generator if there 175 # override model carrying data_generator if there
173 # is any. 176 # is any.
174 data_generator=pred_data_generator) 177 data_generator=pred_data_generator,
178 )
175 179
176 if batch_preds.ndim == 1: 180 if batch_preds.ndim == 1:
177 batch_preds = batch_preds[:, np.newaxis] 181 batch_preds = batch_preds[:, np.newaxis]
178 182
179 batch_meta = variants[index_array] 183 batch_meta = variants[index_array]
180 batch_out = np.column_stack([batch_meta, batch_preds]) 184 batch_out = np.column_stack([batch_meta, batch_preds])
181 185
182 if not header_done: 186 if not header_done:
183 heads = np.arange(batch_preds.shape[-1]).astype(str) 187 heads = np.arange(batch_preds.shape[-1]).astype(str)
184 heads_str = '\t'.join(heads) 188 heads_str = "\t".join(heads)
185 file_writer.write("\t%s\n" % heads_str) 189 file_writer.write("\t%s\n" % heads_str)
186 header_done = True 190 header_done = True
187 191
188 for row in batch_out: 192 for row in batch_out:
189 row_str = '\t'.join(row) 193 row_str = "\t".join(row)
190 file_writer.write("%s\n" % row_str) 194 file_writer.write("%s\n" % row_str)
191 195
192 steps_done += 1 196 steps_done += 1
193 197
194 finally: 198 finally:
198 return 0 202 return 0
199 # end input 203 # end input
200 204
201 # output 205 # output
202 if len(preds.shape) == 1: 206 if len(preds.shape) == 1:
203 rval = pd.DataFrame(preds, columns=['Predicted']) 207 rval = pd.DataFrame(preds, columns=["Predicted"])
204 else: 208 else:
205 rval = pd.DataFrame(preds) 209 rval = pd.DataFrame(preds)
206 210
207 rval.to_csv(outfile_predict, sep='\t', header=True, index=False) 211 rval.to_csv(outfile_predict, sep="\t", header=True, index=False)
208 212
209 213
210 if __name__ == '__main__': 214 if __name__ == "__main__":
211 aparser = argparse.ArgumentParser() 215 aparser = argparse.ArgumentParser()
212 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) 216 aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
213 aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator") 217 aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator")
214 aparser.add_argument("-w", "--infile_weights", dest="infile_weights") 218 aparser.add_argument("-w", "--infile_weights", dest="infile_weights")
215 aparser.add_argument("-X", "--infile1", dest="infile1") 219 aparser.add_argument("-X", "--infile1", dest="infile1")
217 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") 221 aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
218 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") 222 aparser.add_argument("-r", "--ref_seq", dest="ref_seq")
219 aparser.add_argument("-v", "--vcf_path", dest="vcf_path") 223 aparser.add_argument("-v", "--vcf_path", dest="vcf_path")
220 args = aparser.parse_args() 224 args = aparser.parse_args()
221 225
222 main(args.inputs, args.infile_estimator, args.outfile_predict, 226 main(
223 infile_weights=args.infile_weights, infile1=args.infile1, 227 args.inputs,
224 fasta_path=args.fasta_path, ref_seq=args.ref_seq, 228 args.infile_estimator,
225 vcf_path=args.vcf_path) 229 args.outfile_predict,
230 infile_weights=args.infile_weights,
231 infile1=args.infile1,
232 fasta_path=args.fasta_path,
233 ref_seq=args.ref_seq,
234 vcf_path=args.vcf_path,
235 )