comparison model_prediction.py @ 6:13b9ac5d277c draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 208a8d348e7c7a182cfbe1b6f17868146428a7e2"
author bgruening
date Tue, 13 Apr 2021 22:24:07 +0000
parents 5a092779412e
children 3312fb686ffb
comparison 5:ce2fd1edbc6e → 6:13b9ac5d277c
import argparse
import json
import warnings

import numpy as np
import pandas as pd
from galaxy_ml.utils import get_module, load_model, read_columns, try_get_attr
from scipy.io import mmread
from sklearn.pipeline import Pipeline
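
# GALAXY_SLOTS is set by Galaxy to the number of CPU slots allocated to a
# job; outside Galaxy it is unset, so fall back to a single worker.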
N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1))


def main(
    inputs,
    infile_estimator,
    outfile_predict,
    infile_weights=None,
    infile1=None,
    fasta_path=None,
    ref_seq=None,
    vcf_path=None,
):
    """
    Parameters
    ----------
    inputs : str
        File path to galaxy tool parameter

    infile_estimator : str
        File path to trained estimator input

    outfile_predict : str
        File path to save the prediction results, tabular

    infile_weights : str
        File path to weights input

    infile1 : str
        File path to dataset containing features

    fasta_path : str
        File path to dataset containing fasta file

    ref_seq : str
        File path to dataset containing the reference genome sequence.

    vcf_path : str
        File path to dataset containing variants info.
    """
    warnings.filterwarnings("ignore")

    with open(inputs, "r") as param_handler:
        params = json.load(param_handler)
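
    # For illustration, a minimal `inputs` JSON layout for tabular prediction,
    # inferred from the key accesses below (the authoritative schema comes
    # from the Galaxy tool wrapper, so treat this sketch as an assumption):
    #     {
    #         "method": "predict",
    #         "input_options": {
    #             "selected_input": "tabular",
    #             "header1": true,
    #             "column_selector_options_1": {
    #                 "selected_column_selector_option": "by_index_number",
    #                 "col1": [1, 2, 3]
    #             }
    #         }
    #     }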

    # load model
    with open(infile_estimator, "rb") as est_handler:
        estimator = load_model(est_handler)

    main_est = estimator
    if isinstance(estimator, Pipeline):
        main_est = estimator.steps[-1][-1]
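
    # `config` + `load_weights` duck-typing identifies Keras-backed estimators
    # (e.g. galaxy_ml's KerasGBatchClassifier), whose weights arrive as a
    # separate dataset from the pickled model skeleton.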
    if hasattr(main_est, "config") and hasattr(main_est, "load_weights"):
        if not infile_weights or infile_weights == "None":
            raise ValueError(
                "The selected model skeleton asks for weights, "
                "but dataset for weights was not selected!"
            )
        main_est.load_weights(infile_weights)

    # handle data input
    input_type = params["input_options"]["selected_input"]
    # tabular input
    if input_type == "tabular":
        header = "infer" if params["input_options"]["header1"] else None
        column_option = params["input_options"]["column_selector_options_1"]["selected_column_selector_option"]
        if column_option in [
            "by_index_number",
            "all_but_by_index_number",
            "by_header_name",
            "all_but_by_header_name",
        ]:
            c = params["input_options"]["column_selector_options_1"]["col1"]
        else:
            c = None

        df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True)

        X = read_columns(df, c=c, c_option=column_option).astype(float)

        if params["method"] == "predict":
            preds = estimator.predict(X)
        else:
            preds = estimator.predict_proba(X)

    # sparse input
    elif input_type == "sparse":
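        # the dataset is in Matrix Market text format; mmread returns a SciPy
        # sparse matrix, which sklearn estimators generally accept as-is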
        with open(infile1, "r") as f:
            X = mmread(f)
        if params["method"] == "predict":
            preds = estimator.predict(X)
        else:
            preds = estimator.predict_proba(X)

    # fasta input
    elif input_type == "seq_fasta":
        if not hasattr(estimator, "data_batch_generator"):
            raise ValueError(
                "To do prediction on sequences in fasta input, "
                "the estimator must be a `KerasGBatchClassifier` "
                "equipped with data_batch_generator!"
            )
        pyfaidx = get_module("pyfaidx")
        sequences = pyfaidx.Fasta(fasta_path)
        n_seqs = len(sequences.keys())
        X = np.arange(n_seqs)[:, np.newaxis]
        seq_length = estimator.data_batch_generator.seq_length
        batch_size = getattr(estimator, "batch_size", 32)
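        # ceiling division: a trailing partial batch still counts as one step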
        steps = (n_seqs + batch_size - 1) // batch_size

        seq_type = params["input_options"]["seq_type"]
        klass = try_get_attr("galaxy_ml.preprocessors", seq_type)

        pred_data_generator = klass(fasta_path, seq_length=seq_length)

        if params["method"] == "predict":
            preds = estimator.predict(X, data_generator=pred_data_generator, steps=steps)
        else:
            preds = estimator.predict_proba(X, data_generator=pred_data_generator, steps=steps)

    # vcf input
    elif input_type == "variant_effect":
        klass = try_get_attr("galaxy_ml.preprocessors", "GenomicVariantBatchGenerator")

        options = params["input_options"]
        options.pop("selected_input")
        if options["blacklist_regions"] == "none":
            options["blacklist_regions"] = None
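
        # everything left in `options` is forwarded verbatim to the generator
        # as keyword arguments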
        pred_data_generator = klass(ref_genome_path=ref_seq, vcf_path=vcf_path, **options)

        pred_data_generator.set_processing_attrs()

        variants = pred_data_generator.variants

        # predict 1600 samples at once, then write to file
        gen_flow = pred_data_generator.flow(batch_size=1600)
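
        # len(gen_flow) is the number of batches per pass; each loop turn
        # below pulls one batch of indices and writes its predictions straight
        # to disk, so memory stays bounded however large the VCF is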

        file_writer = open(outfile_predict, "w")
        header_row = "\t".join(["chrom", "pos", "name", "ref", "alt", "strand"])
        file_writer.write(header_row)
        header_done = False

        steps_done = 0

        # TODO: multiple threading
        try:
            while steps_done < len(gen_flow):
                index_array = next(gen_flow.index_generator)
                batch_X = gen_flow._get_batches_of_transformed_samples(index_array)
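                # batch_X holds the encoded sequence windows for this batch;
                # index_array keeps the rows aligned with `variants` below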

                if params["method"] == "predict":
                    batch_preds = estimator.predict(
                        batch_X,
                        # pass `pred_data_generator` explicitly to override
                        # any data_generator the model itself carries
                        data_generator=pred_data_generator,
                    )
                else:
                    batch_preds = estimator.predict_proba(
                        batch_X,
                        # pass `pred_data_generator` explicitly to override
                        # any data_generator the model itself carries
                        data_generator=pred_data_generator,
                    )

                if batch_preds.ndim == 1:
                    batch_preds = batch_preds[:, np.newaxis]

                batch_meta = variants[index_array]
                batch_out = np.column_stack([batch_meta, batch_preds])

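                # the numeric prediction columns are appended to the header
                # lazily, because their count is unknown until the first
                # batch of predictions is seen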
                if not header_done:
                    heads = np.arange(batch_preds.shape[-1]).astype(str)
                    heads_str = "\t".join(heads)
                    file_writer.write("\t%s\n" % heads_str)
                    header_done = True

                for row in batch_out:
                    row_str = "\t".join(row)
                    file_writer.write("%s\n" % row_str)

                steps_done += 1

        finally:
            file_writer.close()

        return 0
    # end input

    # output
    if len(preds.shape) == 1:
        rval = pd.DataFrame(preds, columns=["Predicted"])
    else:
        rval = pd.DataFrame(preds)

    rval.to_csv(outfile_predict, sep="\t", header=True, index=False)


if __name__ == "__main__":
    aparser = argparse.ArgumentParser()
    aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
    aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator")
    aparser.add_argument("-w", "--infile_weights", dest="infile_weights")
    aparser.add_argument("-X", "--infile1", dest="infile1")
    # the "-O" short flag is an assumption; main() below requires outfile_predict
    aparser.add_argument("-O", "--outfile_predict", dest="outfile_predict")
    aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
    aparser.add_argument("-r", "--ref_seq", dest="ref_seq")
    aparser.add_argument("-v", "--vcf_path", dest="vcf_path")
    args = aparser.parse_args()

    main(
        args.inputs,
        args.infile_estimator,
        args.outfile_predict,
        infile_weights=args.infile_weights,
        infile1=args.infile1,
        fasta_path=args.fasta_path,
        ref_seq=args.ref_seq,
        vcf_path=args.vcf_path,
    )
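
A typical invocation, for illustration (flags as declared above; the "-O"
flag for the prediction output is assumed from main()'s signature, and the
file names are placeholders):

    python model_prediction.py -i inputs.json -e estimator_input \
        -X features.tabular -O predictions.tabular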