comparison keras_train_and_eval.py @ 29:93f3b307485f draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
author bgruening
date Tue, 13 Apr 2021 18:21:34 +0000
parents 41b109e70a7f
children 1d20e0dce176
comparison
equal deleted inserted replaced
28:6d21b03e00a1 29:93f3b307485f
8 import warnings 8 import warnings
9 from itertools import chain 9 from itertools import chain
10 from scipy.io import mmread 10 from scipy.io import mmread
11 from sklearn.pipeline import Pipeline 11 from sklearn.pipeline import Pipeline
12 from sklearn.metrics.scorer import _check_multimetric_scoring 12 from sklearn.metrics.scorer import _check_multimetric_scoring
13 from sklearn import model_selection
14 from sklearn.model_selection._validation import _score 13 from sklearn.model_selection._validation import _score
15 from sklearn.model_selection import _search, _validation 14 from sklearn.model_selection import _search, _validation
16 from sklearn.utils import indexable, safe_indexing 15 from sklearn.utils import indexable, safe_indexing
17 16
18 from galaxy_ml.externals.selene_sdk.utils import compute_score 17 from galaxy_ml.externals.selene_sdk.utils import compute_score
19 from galaxy_ml.model_validations import train_test_split 18 from galaxy_ml.model_validations import train_test_split
20 from galaxy_ml.keras_galaxy_models import _predict_generator 19 from galaxy_ml.keras_galaxy_models import _predict_generator
21 from galaxy_ml.utils import (SafeEval, get_scoring, load_model, 20 from galaxy_ml.utils import (
22 read_columns, try_get_attr, get_module, 21 SafeEval,
23 clean_params, get_main_estimator) 22 get_scoring,
24 23 load_model,
25 24 read_columns,
26 _fit_and_score = try_get_attr('galaxy_ml.model_validations', '_fit_and_score') 25 try_get_attr,
27 setattr(_search, '_fit_and_score', _fit_and_score) 26 get_module,
28 setattr(_validation, '_fit_and_score', _fit_and_score) 27 clean_params,
29 28 get_main_estimator,
30 N_JOBS = int(os.environ.get('GALAXY_SLOTS', 1)) 29 )
31 CACHE_DIR = os.path.join(os.getcwd(), 'cached') 30
31
32 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score")
33 setattr(_search, "_fit_and_score", _fit_and_score)
34 setattr(_validation, "_fit_and_score", _fit_and_score)
35
36 N_JOBS = int(os.environ.get("GALAXY_SLOTS", 1))
37 CACHE_DIR = os.path.join(os.getcwd(), "cached")
32 del os 38 del os
33 NON_SEARCHABLE = ('n_jobs', 'pre_dispatch', 'memory', '_path', 39 NON_SEARCHABLE = ("n_jobs", "pre_dispatch", "memory", "_path", "nthread", "callbacks")
34 'nthread', 'callbacks') 40 ALLOWED_CALLBACKS = (
35 ALLOWED_CALLBACKS = ('EarlyStopping', 'TerminateOnNaN', 'ReduceLROnPlateau', 41 "EarlyStopping",
36 'CSVLogger', 'None') 42 "TerminateOnNaN",
43 "ReduceLROnPlateau",
44 "CSVLogger",
45 "None",
46 )
37 47
38 48
39 def _eval_swap_params(params_builder): 49 def _eval_swap_params(params_builder):
40 swap_params = {} 50 swap_params = {}
41 51
42 for p in params_builder['param_set']: 52 for p in params_builder["param_set"]:
43 swap_value = p['sp_value'].strip() 53 swap_value = p["sp_value"].strip()
44 if swap_value == '': 54 if swap_value == "":
45 continue 55 continue
46 56
47 param_name = p['sp_name'] 57 param_name = p["sp_name"]
48 if param_name.lower().endswith(NON_SEARCHABLE): 58 if param_name.lower().endswith(NON_SEARCHABLE):
49 warnings.warn("Warning: `%s` is not eligible for search and was " 59 warnings.warn("Warning: `%s` is not eligible for search and was " "omitted!" % param_name)
50 "omitted!" % param_name)
51 continue 60 continue
52 61
53 if not swap_value.startswith(':'): 62 if not swap_value.startswith(":"):
54 safe_eval = SafeEval(load_scipy=True, load_numpy=True) 63 safe_eval = SafeEval(load_scipy=True, load_numpy=True)
55 ev = safe_eval(swap_value) 64 ev = safe_eval(swap_value)
56 else: 65 else:
57 # Have `:` before search list, asks for estimator evaluatio 66 # Have `:` before search list, asks for estimator evaluatio
58 safe_eval_es = SafeEval(load_estimators=True) 67 safe_eval_es = SafeEval(load_estimators=True)
75 if arr is None: 84 if arr is None:
76 nones.append(idx) 85 nones.append(idx)
77 else: 86 else:
78 new_arrays.append(arr) 87 new_arrays.append(arr)
79 88
80 if kwargs['shuffle'] == 'None': 89 if kwargs["shuffle"] == "None":
81 kwargs['shuffle'] = None 90 kwargs["shuffle"] = None
82 91
83 group_names = kwargs.pop('group_names', None) 92 group_names = kwargs.pop("group_names", None)
84 93
85 if group_names is not None and group_names.strip(): 94 if group_names is not None and group_names.strip():
86 group_names = [name.strip() for name in 95 group_names = [name.strip() for name in group_names.split(",")]
87 group_names.split(',')]
88 new_arrays = indexable(*new_arrays) 96 new_arrays = indexable(*new_arrays)
89 groups = kwargs['labels'] 97 groups = kwargs["labels"]
90 n_samples = new_arrays[0].shape[0] 98 n_samples = new_arrays[0].shape[0]
91 index_arr = np.arange(n_samples) 99 index_arr = np.arange(n_samples)
92 test = index_arr[np.isin(groups, group_names)] 100 test = index_arr[np.isin(groups, group_names)]
93 train = index_arr[~np.isin(groups, group_names)] 101 train = index_arr[~np.isin(groups, group_names)]
94 rval = list(chain.from_iterable( 102 rval = list(chain.from_iterable((safe_indexing(a, train), safe_indexing(a, test)) for a in new_arrays))
95 (safe_indexing(a, train),
96 safe_indexing(a, test)) for a in new_arrays))
97 else: 103 else:
98 rval = train_test_split(*new_arrays, **kwargs) 104 rval = train_test_split(*new_arrays, **kwargs)
99 105
100 for pos in nones: 106 for pos in nones:
101 rval[pos * 2: 2] = [None, None] 107 rval[pos * 2 : 2] = [None, None]
102 108
103 return rval 109 return rval
104 110
105 111
106 def _evaluate(y_true, pred_probas, scorer, is_multimetric=True): 112 def _evaluate(y_true, pred_probas, scorer, is_multimetric=True):
107 """ output scores based on input scorer 113 """output scores based on input scorer
108 114
109 Parameters 115 Parameters
110 ---------- 116 ----------
111 y_true : array 117 y_true : array
112 True label or target values 118 True label or target values
116 dict of `sklearn.metrics.scorer.SCORER` 122 dict of `sklearn.metrics.scorer.SCORER`
117 is_multimetric : bool, default is True 123 is_multimetric : bool, default is True
118 """ 124 """
119 if y_true.ndim == 1 or y_true.shape[-1] == 1: 125 if y_true.ndim == 1 or y_true.shape[-1] == 1:
120 pred_probas = pred_probas.ravel() 126 pred_probas = pred_probas.ravel()
121 pred_labels = (pred_probas > 0.5).astype('int32') 127 pred_labels = (pred_probas > 0.5).astype("int32")
122 targets = y_true.ravel().astype('int32') 128 targets = y_true.ravel().astype("int32")
123 if not is_multimetric: 129 if not is_multimetric:
124 preds = pred_labels if scorer.__class__.__name__ == \ 130 preds = pred_labels if scorer.__class__.__name__ == "_PredictScorer" else pred_probas
125 '_PredictScorer' else pred_probas
126 score = scorer._score_func(targets, preds, **scorer._kwargs) 131 score = scorer._score_func(targets, preds, **scorer._kwargs)
127 132
128 return score 133 return score
129 else: 134 else:
130 scores = {} 135 scores = {}
131 for name, one_scorer in scorer.items(): 136 for name, one_scorer in scorer.items():
132 preds = pred_labels if one_scorer.__class__.__name__\ 137 preds = pred_labels if one_scorer.__class__.__name__ == "_PredictScorer" else pred_probas
133 == '_PredictScorer' else pred_probas 138 score = one_scorer._score_func(targets, preds, **one_scorer._kwargs)
134 score = one_scorer._score_func(targets, preds,
135 **one_scorer._kwargs)
136 scores[name] = score 139 scores[name] = score
137 140
138 # TODO: multi-class metrics 141 # TODO: multi-class metrics
139 # multi-label 142 # multi-label
140 else: 143 else:
141 pred_labels = (pred_probas > 0.5).astype('int32') 144 pred_labels = (pred_probas > 0.5).astype("int32")
142 targets = y_true.astype('int32') 145 targets = y_true.astype("int32")
143 if not is_multimetric: 146 if not is_multimetric:
144 preds = pred_labels if scorer.__class__.__name__ == \ 147 preds = pred_labels if scorer.__class__.__name__ == "_PredictScorer" else pred_probas
145 '_PredictScorer' else pred_probas 148 score, _ = compute_score(preds, targets, scorer._score_func)
146 score, _ = compute_score(preds, targets,
147 scorer._score_func)
148 return score 149 return score
149 else: 150 else:
150 scores = {} 151 scores = {}
151 for name, one_scorer in scorer.items(): 152 for name, one_scorer in scorer.items():
152 preds = pred_labels if one_scorer.__class__.__name__\ 153 preds = pred_labels if one_scorer.__class__.__name__ == "_PredictScorer" else pred_probas
153 == '_PredictScorer' else pred_probas 154 score, _ = compute_score(preds, targets, one_scorer._score_func)
154 score, _ = compute_score(preds, targets,
155 one_scorer._score_func)
156 scores[name] = score 155 scores[name] = score
157 156
158 return scores 157 return scores
159 158
160 159
161 def main(inputs, infile_estimator, infile1, infile2, 160 def main(
162 outfile_result, outfile_object=None, 161 inputs,
163 outfile_weights=None, outfile_y_true=None, 162 infile_estimator,
164 outfile_y_preds=None, groups=None, 163 infile1,
165 ref_seq=None, intervals=None, targets=None, 164 infile2,
166 fasta_path=None): 165 outfile_result,
166 outfile_object=None,
167 outfile_weights=None,
168 outfile_y_true=None,
169 outfile_y_preds=None,
170 groups=None,
171 ref_seq=None,
172 intervals=None,
173 targets=None,
174 fasta_path=None,
175 ):
167 """ 176 """
168 Parameter 177 Parameter
169 --------- 178 ---------
170 inputs : str 179 inputs : str
171 File path to galaxy tool parameter 180 File path to galaxy tool parameter
207 File path to dataset compressed target bed file 216 File path to dataset compressed target bed file
208 217
209 fasta_path : str 218 fasta_path : str
210 File path to dataset containing fasta file 219 File path to dataset containing fasta file
211 """ 220 """
212 warnings.simplefilter('ignore') 221 warnings.simplefilter("ignore")
213 222
214 with open(inputs, 'r') as param_handler: 223 with open(inputs, "r") as param_handler:
215 params = json.load(param_handler) 224 params = json.load(param_handler)
216 225
217 # load estimator 226 # load estimator
218 with open(infile_estimator, 'rb') as estimator_handler: 227 with open(infile_estimator, "rb") as estimator_handler:
219 estimator = load_model(estimator_handler) 228 estimator = load_model(estimator_handler)
220 229
221 estimator = clean_params(estimator) 230 estimator = clean_params(estimator)
222 231
223 # swap hyperparameter 232 # swap hyperparameter
224 swapping = params['experiment_schemes']['hyperparams_swapping'] 233 swapping = params["experiment_schemes"]["hyperparams_swapping"]
225 swap_params = _eval_swap_params(swapping) 234 swap_params = _eval_swap_params(swapping)
226 estimator.set_params(**swap_params) 235 estimator.set_params(**swap_params)
227 236
228 estimator_params = estimator.get_params() 237 estimator_params = estimator.get_params()
229 238
230 # store read dataframe object 239 # store read dataframe object
231 loaded_df = {} 240 loaded_df = {}
232 241
233 input_type = params['input_options']['selected_input'] 242 input_type = params["input_options"]["selected_input"]
234 # tabular input 243 # tabular input
235 if input_type == 'tabular': 244 if input_type == "tabular":
236 header = 'infer' if params['input_options']['header1'] else None 245 header = "infer" if params["input_options"]["header1"] else None
237 column_option = (params['input_options']['column_selector_options_1'] 246 column_option = params["input_options"]["column_selector_options_1"]["selected_column_selector_option"]
238 ['selected_column_selector_option']) 247 if column_option in [
239 if column_option in ['by_index_number', 'all_but_by_index_number', 248 "by_index_number",
240 'by_header_name', 'all_but_by_header_name']: 249 "all_but_by_index_number",
241 c = params['input_options']['column_selector_options_1']['col1'] 250 "by_header_name",
251 "all_but_by_header_name",
252 ]:
253 c = params["input_options"]["column_selector_options_1"]["col1"]
242 else: 254 else:
243 c = None 255 c = None
244 256
245 df_key = infile1 + repr(header) 257 df_key = infile1 + repr(header)
246 df = pd.read_csv(infile1, sep='\t', header=header, 258 df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True)
247 parse_dates=True)
248 loaded_df[df_key] = df 259 loaded_df[df_key] = df
249 260
250 X = read_columns(df, c=c, c_option=column_option).astype(float) 261 X = read_columns(df, c=c, c_option=column_option).astype(float)
251 # sparse input 262 # sparse input
252 elif input_type == 'sparse': 263 elif input_type == "sparse":
253 X = mmread(open(infile1, 'r')) 264 X = mmread(open(infile1, "r"))
254 265
255 # fasta_file input 266 # fasta_file input
256 elif input_type == 'seq_fasta': 267 elif input_type == "seq_fasta":
257 pyfaidx = get_module('pyfaidx') 268 pyfaidx = get_module("pyfaidx")
258 sequences = pyfaidx.Fasta(fasta_path) 269 sequences = pyfaidx.Fasta(fasta_path)
259 n_seqs = len(sequences.keys()) 270 n_seqs = len(sequences.keys())
260 X = np.arange(n_seqs)[:, np.newaxis] 271 X = np.arange(n_seqs)[:, np.newaxis]
261 for param in estimator_params.keys(): 272 for param in estimator_params.keys():
262 if param.endswith('fasta_path'): 273 if param.endswith("fasta_path"):
263 estimator.set_params( 274 estimator.set_params(**{param: fasta_path})
264 **{param: fasta_path})
265 break 275 break
266 else: 276 else:
267 raise ValueError( 277 raise ValueError(
268 "The selected estimator doesn't support " 278 "The selected estimator doesn't support "
269 "fasta file input! Please consider using " 279 "fasta file input! Please consider using "
270 "KerasGBatchClassifier with " 280 "KerasGBatchClassifier with "
271 "FastaDNABatchGenerator/FastaProteinBatchGenerator " 281 "FastaDNABatchGenerator/FastaProteinBatchGenerator "
272 "or having GenomeOneHotEncoder/ProteinOneHotEncoder " 282 "or having GenomeOneHotEncoder/ProteinOneHotEncoder "
273 "in pipeline!") 283 "in pipeline!"
274 284 )
275 elif input_type == 'refseq_and_interval': 285
286 elif input_type == "refseq_and_interval":
276 path_params = { 287 path_params = {
277 'data_batch_generator__ref_genome_path': ref_seq, 288 "data_batch_generator__ref_genome_path": ref_seq,
278 'data_batch_generator__intervals_path': intervals, 289 "data_batch_generator__intervals_path": intervals,
279 'data_batch_generator__target_path': targets 290 "data_batch_generator__target_path": targets,
280 } 291 }
281 estimator.set_params(**path_params) 292 estimator.set_params(**path_params)
282 n_intervals = sum(1 for line in open(intervals)) 293 n_intervals = sum(1 for line in open(intervals))
283 X = np.arange(n_intervals)[:, np.newaxis] 294 X = np.arange(n_intervals)[:, np.newaxis]
284 295
285 # Get target y 296 # Get target y
286 header = 'infer' if params['input_options']['header2'] else None 297 header = "infer" if params["input_options"]["header2"] else None
287 column_option = (params['input_options']['column_selector_options_2'] 298 column_option = params["input_options"]["column_selector_options_2"]["selected_column_selector_option2"]
288 ['selected_column_selector_option2']) 299 if column_option in [
289 if column_option in ['by_index_number', 'all_but_by_index_number', 300 "by_index_number",
290 'by_header_name', 'all_but_by_header_name']: 301 "all_but_by_index_number",
291 c = params['input_options']['column_selector_options_2']['col2'] 302 "by_header_name",
303 "all_but_by_header_name",
304 ]:
305 c = params["input_options"]["column_selector_options_2"]["col2"]
292 else: 306 else:
293 c = None 307 c = None
294 308
295 df_key = infile2 + repr(header) 309 df_key = infile2 + repr(header)
296 if df_key in loaded_df: 310 if df_key in loaded_df:
297 infile2 = loaded_df[df_key] 311 infile2 = loaded_df[df_key]
298 else: 312 else:
299 infile2 = pd.read_csv(infile2, sep='\t', 313 infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True)
300 header=header, parse_dates=True)
301 loaded_df[df_key] = infile2 314 loaded_df[df_key] = infile2
302 315
303 y = read_columns( 316 y = read_columns(infile2,
304 infile2, 317 c=c,
305 c=c, 318 c_option=column_option,
306 c_option=column_option, 319 sep='\t',
307 sep='\t', 320 header=header,
308 header=header, 321 parse_dates=True)
309 parse_dates=True)
310 if len(y.shape) == 2 and y.shape[1] == 1: 322 if len(y.shape) == 2 and y.shape[1] == 1:
311 y = y.ravel() 323 y = y.ravel()
312 if input_type == 'refseq_and_interval': 324 if input_type == "refseq_and_interval":
313 estimator.set_params( 325 estimator.set_params(data_batch_generator__features=y.ravel().tolist())
314 data_batch_generator__features=y.ravel().tolist())
315 y = None 326 y = None
316 # end y 327 # end y
317 328
318 # load groups 329 # load groups
319 if groups: 330 if groups:
320 groups_selector = (params['experiment_schemes']['test_split'] 331 groups_selector = (params["experiment_schemes"]["test_split"]["split_algos"]).pop("groups_selector")
321 ['split_algos']).pop('groups_selector') 332
322 333 header = "infer" if groups_selector["header_g"] else None
323 header = 'infer' if groups_selector['header_g'] else None 334 column_option = groups_selector["column_selector_options_g"]["selected_column_selector_option_g"]
324 column_option = \ 335 if column_option in [
325 (groups_selector['column_selector_options_g'] 336 "by_index_number",
326 ['selected_column_selector_option_g']) 337 "all_but_by_index_number",
327 if column_option in ['by_index_number', 'all_but_by_index_number', 338 "by_header_name",
328 'by_header_name', 'all_but_by_header_name']: 339 "all_but_by_header_name",
329 c = groups_selector['column_selector_options_g']['col_g'] 340 ]:
341 c = groups_selector["column_selector_options_g"]["col_g"]
330 else: 342 else:
331 c = None 343 c = None
332 344
333 df_key = groups + repr(header) 345 df_key = groups + repr(header)
334 if df_key in loaded_df: 346 if df_key in loaded_df:
335 groups = loaded_df[df_key] 347 groups = loaded_df[df_key]
336 348
337 groups = read_columns( 349 groups = read_columns(groups,
338 groups, 350 c=c,
339 c=c, 351 c_option=column_option,
340 c_option=column_option, 352 sep='\t',
341 sep='\t', 353 header=header,
342 header=header, 354 parse_dates=True)
343 parse_dates=True)
344 groups = groups.ravel() 355 groups = groups.ravel()
345 356
346 # del loaded_df 357 # del loaded_df
347 del loaded_df 358 del loaded_df
348 359
349 # cache iraps_core fits could increase search speed significantly 360 # cache iraps_core fits could increase search speed significantly
350 memory = joblib.Memory(location=CACHE_DIR, verbose=0) 361 memory = joblib.Memory(location=CACHE_DIR, verbose=0)
351 main_est = get_main_estimator(estimator) 362 main_est = get_main_estimator(estimator)
352 if main_est.__class__.__name__ == 'IRAPSClassifier': 363 if main_est.__class__.__name__ == "IRAPSClassifier":
353 main_est.set_params(memory=memory) 364 main_est.set_params(memory=memory)
354 365
355 # handle scorer, convert to scorer dict 366 # handle scorer, convert to scorer dict
356 scoring = params['experiment_schemes']['metrics']['scoring'] 367 scoring = params['experiment_schemes']['metrics']['scoring']
368 if scoring is not None:
369 # get_scoring() expects secondary_scoring to be a comma separated string (not a list)
370 # Check if secondary_scoring is specified
371 secondary_scoring = scoring.get("secondary_scoring", None)
372 if secondary_scoring is not None:
373 # If secondary_scoring is specified, convert the list into comman separated string
374 scoring["secondary_scoring"] = ",".join(scoring["secondary_scoring"])
375
357 scorer = get_scoring(scoring) 376 scorer = get_scoring(scoring)
358 scorer, _ = _check_multimetric_scoring(estimator, scoring=scorer) 377 scorer, _ = _check_multimetric_scoring(estimator, scoring=scorer)
359 378
360 # handle test (first) split 379 # handle test (first) split
361 test_split_options = (params['experiment_schemes'] 380 test_split_options = params["experiment_schemes"]["test_split"]["split_algos"]
362 ['test_split']['split_algos']) 381
363 382 if test_split_options["shuffle"] == "group":
364 if test_split_options['shuffle'] == 'group': 383 test_split_options["labels"] = groups
365 test_split_options['labels'] = groups 384 if test_split_options["shuffle"] == "stratified":
366 if test_split_options['shuffle'] == 'stratified':
367 if y is not None: 385 if y is not None:
368 test_split_options['labels'] = y 386 test_split_options["labels"] = y
369 else: 387 else:
370 raise ValueError("Stratified shuffle split is not " 388 raise ValueError("Stratified shuffle split is not " "applicable on empty target values!")
371 "applicable on empty target values!") 389
372 390 (
373 X_train, X_test, y_train, y_test, groups_train, groups_test = \ 391 X_train,
374 train_test_split_none(X, y, groups, **test_split_options) 392 X_test,
375 393 y_train,
376 exp_scheme = params['experiment_schemes']['selected_exp_scheme'] 394 y_test,
395 groups_train,
396 _groups_test,
397 ) = train_test_split_none(X, y, groups, **test_split_options)
398
399 exp_scheme = params["experiment_schemes"]["selected_exp_scheme"]
377 400
378 # handle validation (second) split 401 # handle validation (second) split
379 if exp_scheme == 'train_val_test': 402 if exp_scheme == "train_val_test":
380 val_split_options = (params['experiment_schemes'] 403 val_split_options = params["experiment_schemes"]["val_split"]["split_algos"]
381 ['val_split']['split_algos']) 404
382 405 if val_split_options["shuffle"] == "group":
383 if val_split_options['shuffle'] == 'group': 406 val_split_options["labels"] = groups_train
384 val_split_options['labels'] = groups_train 407 if val_split_options["shuffle"] == "stratified":
385 if val_split_options['shuffle'] == 'stratified':
386 if y_train is not None: 408 if y_train is not None:
387 val_split_options['labels'] = y_train 409 val_split_options["labels"] = y_train
388 else: 410 else:
389 raise ValueError("Stratified shuffle split is not " 411 raise ValueError("Stratified shuffle split is not " "applicable on empty target values!")
390 "applicable on empty target values!") 412
391 413 (
392 X_train, X_val, y_train, y_val, groups_train, groups_val = \ 414 X_train,
393 train_test_split_none(X_train, y_train, groups_train, 415 X_val,
394 **val_split_options) 416 y_train,
417 y_val,
418 groups_train,
419 _groups_val,
420 ) = train_test_split_none(X_train, y_train, groups_train, **val_split_options)
395 421
396 # train and eval 422 # train and eval
397 if hasattr(estimator, 'validation_data'): 423 if hasattr(estimator, "validation_data"):
398 if exp_scheme == 'train_val_test': 424 if exp_scheme == "train_val_test":
399 estimator.fit(X_train, y_train, 425 estimator.fit(X_train, y_train, validation_data=(X_val, y_val))
400 validation_data=(X_val, y_val)) 426 else:
401 else: 427 estimator.fit(X_train, y_train, validation_data=(X_test, y_test))
402 estimator.fit(X_train, y_train,
403 validation_data=(X_test, y_test))
404 else: 428 else:
405 estimator.fit(X_train, y_train) 429 estimator.fit(X_train, y_train)
406 430
407 if hasattr(estimator, 'evaluate'): 431 if hasattr(estimator, "evaluate"):
408 steps = estimator.prediction_steps 432 steps = estimator.prediction_steps
409 batch_size = estimator.batch_size 433 batch_size = estimator.batch_size
410 generator = estimator.data_generator_.flow(X_test, y=y_test, 434 generator = estimator.data_generator_.flow(X_test, y=y_test, batch_size=batch_size)
411 batch_size=batch_size) 435 predictions, y_true = _predict_generator(estimator.model_, generator, steps=steps)
412 predictions, y_true = _predict_generator(estimator.model_, generator,
413 steps=steps)
414 scores = _evaluate(y_true, predictions, scorer, is_multimetric=True) 436 scores = _evaluate(y_true, predictions, scorer, is_multimetric=True)
415 437
416 else: 438 else:
417 if hasattr(estimator, 'predict_proba'): 439 if hasattr(estimator, "predict_proba"):
418 predictions = estimator.predict_proba(X_test) 440 predictions = estimator.predict_proba(X_test)
419 else: 441 else:
420 predictions = estimator.predict(X_test) 442 predictions = estimator.predict(X_test)
421 443
422 y_true = y_test 444 y_true = y_test
423 scores = _score(estimator, X_test, y_test, scorer, 445 scores = _score(estimator, X_test, y_test, scorer, is_multimetric=True)
424 is_multimetric=True)
425 if outfile_y_true: 446 if outfile_y_true:
426 try: 447 try:
427 pd.DataFrame(y_true).to_csv(outfile_y_true, sep='\t', 448 pd.DataFrame(y_true).to_csv(outfile_y_true, sep="\t", index=False)
428 index=False)
429 pd.DataFrame(predictions).astype(np.float32).to_csv( 449 pd.DataFrame(predictions).astype(np.float32).to_csv(
430 outfile_y_preds, sep='\t', index=False, 450 outfile_y_preds,
431 float_format='%g', chunksize=10000) 451 sep="\t",
452 index=False,
453 float_format="%g",
454 chunksize=10000,
455 )
432 except Exception as e: 456 except Exception as e:
433 print("Error in saving predictions: %s" % e) 457 print("Error in saving predictions: %s" % e)
434 458
435 # handle output 459 # handle output
436 for name, score in scores.items(): 460 for name, score in scores.items():
437 scores[name] = [score] 461 scores[name] = [score]
438 df = pd.DataFrame(scores) 462 df = pd.DataFrame(scores)
439 df = df[sorted(df.columns)] 463 df = df[sorted(df.columns)]
440 df.to_csv(path_or_buf=outfile_result, sep='\t', 464 df.to_csv(path_or_buf=outfile_result, sep="\t", header=True, index=False)
441 header=True, index=False)
442 465
443 memory.clear(warn=False) 466 memory.clear(warn=False)
444 467
445 if outfile_object: 468 if outfile_object:
446 main_est = estimator 469 main_est = estimator
447 if isinstance(estimator, Pipeline): 470 if isinstance(estimator, Pipeline):
448 main_est = estimator.steps[-1][-1] 471 main_est = estimator.steps[-1][-1]
449 472
450 if hasattr(main_est, 'model_') \ 473 if hasattr(main_est, "model_") and hasattr(main_est, "save_weights"):
451 and hasattr(main_est, 'save_weights'):
452 if outfile_weights: 474 if outfile_weights:
453 main_est.save_weights(outfile_weights) 475 main_est.save_weights(outfile_weights)
454 del main_est.model_ 476 del main_est.model_
455 del main_est.fit_params 477 del main_est.fit_params
456 del main_est.model_class_ 478 del main_est.model_class_
457 del main_est.validation_data 479 if getattr(main_est, "validation_data", None):
458 if getattr(main_est, 'data_generator_', None): 480 del main_est.validation_data
481 if getattr(main_est, "data_generator_", None):
459 del main_est.data_generator_ 482 del main_est.data_generator_
460 483
461 with open(outfile_object, 'wb') as output_handler: 484 with open(outfile_object, "wb") as output_handler:
462 pickle.dump(estimator, output_handler, 485 pickle.dump(estimator, output_handler, pickle.HIGHEST_PROTOCOL)
463 pickle.HIGHEST_PROTOCOL) 486
464 487
465 488 if __name__ == "__main__":
466 if __name__ == '__main__':
467 aparser = argparse.ArgumentParser() 489 aparser = argparse.ArgumentParser()
468 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) 490 aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
469 aparser.add_argument("-e", "--estimator", dest="infile_estimator") 491 aparser.add_argument("-e", "--estimator", dest="infile_estimator")
470 aparser.add_argument("-X", "--infile1", dest="infile1") 492 aparser.add_argument("-X", "--infile1", dest="infile1")
471 aparser.add_argument("-y", "--infile2", dest="infile2") 493 aparser.add_argument("-y", "--infile2", dest="infile2")
479 aparser.add_argument("-b", "--intervals", dest="intervals") 501 aparser.add_argument("-b", "--intervals", dest="intervals")
480 aparser.add_argument("-t", "--targets", dest="targets") 502 aparser.add_argument("-t", "--targets", dest="targets")
481 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") 503 aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
482 args = aparser.parse_args() 504 args = aparser.parse_args()
483 505
484 main(args.inputs, args.infile_estimator, args.infile1, args.infile2, 506 main(
485 args.outfile_result, outfile_object=args.outfile_object, 507 args.inputs,
486 outfile_weights=args.outfile_weights, 508 args.infile_estimator,
487 outfile_y_true=args.outfile_y_true, 509 args.infile1,
488 outfile_y_preds=args.outfile_y_preds, 510 args.infile2,
489 groups=args.groups, 511 args.outfile_result,
490 ref_seq=args.ref_seq, intervals=args.intervals, 512 outfile_object=args.outfile_object,
491 targets=args.targets, fasta_path=args.fasta_path) 513 outfile_weights=args.outfile_weights,
514 outfile_y_true=args.outfile_y_true,
515 outfile_y_preds=args.outfile_y_preds,
516 groups=args.groups,
517 ref_seq=args.ref_seq,
518 intervals=args.intervals,
519 targets=args.targets,
520 fasta_path=args.fasta_path,
521 )