comparison search_model_validation.py @ 26:dde0f1654d18 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 60f0fbc0eafd7c11bc60fb6c77f2937782efd8a9-dirty
author bgruening
date Fri, 09 Aug 2019 07:23:21 -0400
parents e94395c672bd
children 47d4baa183b2
comparison
equal deleted inserted replaced
25:360330a867f0 26:dde0f1654d18
1 import argparse 1 import argparse
2 import collections 2 import collections
3 import imblearn 3 import imblearn
4 import joblib
4 import json 5 import json
5 import numpy as np 6 import numpy as np
6 import pandas 7 import pandas as pd
7 import pickle 8 import pickle
8 import skrebate 9 import skrebate
9 import sklearn 10 import sklearn
10 import sys 11 import sys
11 import xgboost 12 import xgboost
12 import warnings 13 import warnings
13 import iraps_classifier
14 import model_validations
15 import preprocessors
16 import feature_selectors
17 from imblearn import under_sampling, over_sampling, combine 14 from imblearn import under_sampling, over_sampling, combine
18 from scipy.io import mmread 15 from scipy.io import mmread
19 from mlxtend import classifier, regressor 16 from mlxtend import classifier, regressor
17 from sklearn.base import clone
20 from sklearn import (cluster, compose, decomposition, ensemble, 18 from sklearn import (cluster, compose, decomposition, ensemble,
21 feature_extraction, feature_selection, 19 feature_extraction, feature_selection,
22 gaussian_process, kernel_approximation, metrics, 20 gaussian_process, kernel_approximation, metrics,
23 model_selection, naive_bayes, neighbors, 21 model_selection, naive_bayes, neighbors,
24 pipeline, preprocessing, svm, linear_model, 22 pipeline, preprocessing, svm, linear_model,
25 tree, discriminant_analysis) 23 tree, discriminant_analysis)
26 from sklearn.exceptions import FitFailedWarning 24 from sklearn.exceptions import FitFailedWarning
27 from sklearn.externals import joblib 25 from sklearn.model_selection._validation import _score, cross_validate
28 from sklearn.model_selection._validation import _score 26 from sklearn.model_selection import _search, _validation
29 27
30 from utils import (SafeEval, get_cv, get_scoring, get_X_y, 28 from galaxy_ml.utils import (SafeEval, get_cv, get_scoring, load_model,
31 load_model, read_columns) 29 read_columns, try_get_attr, get_module)
32 from model_validations import train_test_split 30
33 31
32 _fit_and_score = try_get_attr('galaxy_ml.model_validations', '_fit_and_score')
33 setattr(_search, '_fit_and_score', _fit_and_score)
34 setattr(_validation, '_fit_and_score', _fit_and_score)
34 35
35 N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1)) 36 N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1))
36 CACHE_DIR = './cached' 37 CACHE_DIR = './cached'
37 NON_SEARCHABLE = ('n_jobs', 'pre_dispatch', 'memory', 'steps', 38 NON_SEARCHABLE = ('n_jobs', 'pre_dispatch', 'memory', '_path',
38 'nthread', 'verbose') 39 'nthread', 'callbacks')
40 ALLOWED_CALLBACKS = ('EarlyStopping', 'TerminateOnNaN', 'ReduceLROnPlateau',
41 'CSVLogger', 'None')
39 42
40 43
41 def _eval_search_params(params_builder): 44 def _eval_search_params(params_builder):
42 search_params = {} 45 search_params = {}
43 46
60 # Have `:` before search list, asks for estimator evaluation 63 # Have `:` before search list, asks for estimator evaluation
61 safe_eval_es = SafeEval(load_estimators=True) 64 safe_eval_es = SafeEval(load_estimators=True)
62 search_list = search_list[1:].strip() 65 search_list = search_list[1:].strip()
63 # TODO maybe add regular expression check 66 # TODO maybe add regular expression check
64 ev = safe_eval_es(search_list) 67 ev = safe_eval_es(search_list)
65 preprocessors = ( 68 preprocessings = (
66 preprocessing.StandardScaler(), preprocessing.Binarizer(), 69 preprocessing.StandardScaler(), preprocessing.Binarizer(),
67 preprocessing.Imputer(), preprocessing.MaxAbsScaler(), 70 preprocessing.MaxAbsScaler(),
68 preprocessing.Normalizer(), preprocessing.MinMaxScaler(), 71 preprocessing.Normalizer(), preprocessing.MinMaxScaler(),
69 preprocessing.PolynomialFeatures(), 72 preprocessing.PolynomialFeatures(),
70 preprocessing.RobustScaler(), feature_selection.SelectKBest(), 73 preprocessing.RobustScaler(), feature_selection.SelectKBest(),
71 feature_selection.GenericUnivariateSelect(), 74 feature_selection.GenericUnivariateSelect(),
72 feature_selection.SelectPercentile(), 75 feature_selection.SelectPercentile(),
131 newlist = [] 134 newlist = []
132 for obj in ev: 135 for obj in ev:
133 if obj is None: 136 if obj is None:
134 newlist.append(None) 137 newlist.append(None)
135 elif obj == 'all_0': 138 elif obj == 'all_0':
136 newlist.extend(preprocessors[0:36]) 139 newlist.extend(preprocessings[0:35])
137 elif obj == 'sk_prep_all': # no KernalCenter() 140 elif obj == 'sk_prep_all': # no KernalCenter()
138 newlist.extend(preprocessors[0:8]) 141 newlist.extend(preprocessings[0:7])
139 elif obj == 'fs_all': 142 elif obj == 'fs_all':
140 newlist.extend(preprocessors[8:15]) 143 newlist.extend(preprocessings[7:14])
141 elif obj == 'decomp_all': 144 elif obj == 'decomp_all':
142 newlist.extend(preprocessors[15:26]) 145 newlist.extend(preprocessings[14:25])
143 elif obj == 'k_appr_all': 146 elif obj == 'k_appr_all':
144 newlist.extend(preprocessors[26:30]) 147 newlist.extend(preprocessings[25:29])
145 elif obj == 'reb_all': 148 elif obj == 'reb_all':
146 newlist.extend(preprocessors[31:36]) 149 newlist.extend(preprocessings[30:35])
147 elif obj == 'imb_all': 150 elif obj == 'imb_all':
148 newlist.extend(preprocessors[36:55]) 151 newlist.extend(preprocessings[35:54])
149 elif type(obj) is int and -1 < obj < len(preprocessors): 152 elif type(obj) is int and -1 < obj < len(preprocessings):
150 newlist.append(preprocessors[obj]) 153 newlist.append(preprocessings[obj])
151 elif hasattr(obj, 'get_params'): # user uploaded object 154 elif hasattr(obj, 'get_params'): # user uploaded object
152 if 'n_jobs' in obj.get_params(): 155 if 'n_jobs' in obj.get_params():
153 newlist.append(obj.set_params(n_jobs=N_JOBS)) 156 newlist.append(obj.set_params(n_jobs=N_JOBS))
154 else: 157 else:
155 newlist.append(obj) 158 newlist.append(obj)
160 163
161 return search_params 164 return search_params
162 165
163 166
164 def main(inputs, infile_estimator, infile1, infile2, 167 def main(inputs, infile_estimator, infile1, infile2,
165 outfile_result, outfile_object=None, groups=None): 168 outfile_result, outfile_object=None,
169 outfile_weights=None, groups=None,
170 ref_seq=None, intervals=None, targets=None,
171 fasta_path=None):
166 """ 172 """
167 Parameter 173 Parameter
168 --------- 174 ---------
169 inputs : str 175 inputs : str
170 File path to galaxy tool parameter 176 File path to galaxy tool parameter
182 File path to save the results, either cv_results or test result 188 File path to save the results, either cv_results or test result
183 189
184 outfile_object : str, optional 190 outfile_object : str, optional
185 File path to save searchCV object 191 File path to save searchCV object
186 192
193 outfile_weights : str, optional
194 File path to save model weights
195
187 groups : str 196 groups : str
188 File path to dataset containing groups labels 197 File path to dataset containing groups labels
198
199 ref_seq : str
200 File path to dataset containing genome sequence file
201
202 intervals : str
203 File path to dataset containing interval file
204
205 targets : str
206 File path to dataset compressed target bed file
207
208 fasta_path : str
209 File path to dataset containing fasta file
189 """ 210 """
190
191 warnings.simplefilter('ignore') 211 warnings.simplefilter('ignore')
192 212
193 with open(inputs, 'r') as param_handler: 213 with open(inputs, 'r') as param_handler:
194 params = json.load(param_handler) 214 params = json.load(param_handler)
195 if groups:
196 (params['search_schemes']['options']['cv_selector']
197 ['groups_selector']['infile_g']) = groups
198 215
199 params_builder = params['search_schemes']['search_params_builder'] 216 params_builder = params['search_schemes']['search_params_builder']
200 217
218 with open(infile_estimator, 'rb') as estimator_handler:
219 estimator = load_model(estimator_handler)
220 estimator_params = estimator.get_params()
221
222 # store read dataframe object
223 loaded_df = {}
224
201 input_type = params['input_options']['selected_input'] 225 input_type = params['input_options']['selected_input']
226 # tabular input
202 if input_type == 'tabular': 227 if input_type == 'tabular':
203 header = 'infer' if params['input_options']['header1'] else None 228 header = 'infer' if params['input_options']['header1'] else None
204 column_option = (params['input_options']['column_selector_options_1'] 229 column_option = (params['input_options']['column_selector_options_1']
205 ['selected_column_selector_option']) 230 ['selected_column_selector_option'])
206 if column_option in ['by_index_number', 'all_but_by_index_number', 231 if column_option in ['by_index_number', 'all_but_by_index_number',
207 'by_header_name', 'all_but_by_header_name']: 232 'by_header_name', 'all_but_by_header_name']:
208 c = params['input_options']['column_selector_options_1']['col1'] 233 c = params['input_options']['column_selector_options_1']['col1']
209 else: 234 else:
210 c = None 235 c = None
211 X = read_columns( 236
212 infile1, 237 df_key = infile1 + repr(header)
213 c=c, 238 df = pd.read_csv(infile1, sep='\t', header=header,
214 c_option=column_option, 239 parse_dates=True)
215 sep='\t', 240 loaded_df[df_key] = df
216 header=header, 241
217 parse_dates=True).astype(float) 242 X = read_columns(df, c=c, c_option=column_option).astype(float)
218 else: 243 # sparse input
244 elif input_type == 'sparse':
219 X = mmread(open(infile1, 'r')) 245 X = mmread(open(infile1, 'r'))
220 246
247 # fasta_file input
248 elif input_type == 'seq_fasta':
249 pyfaidx = get_module('pyfaidx')
250 sequences = pyfaidx.Fasta(fasta_path)
251 n_seqs = len(sequences.keys())
252 X = np.arange(n_seqs)[:, np.newaxis]
253 for param in estimator_params.keys():
254 if param.endswith('fasta_path'):
255 estimator.set_params(
256 **{param: fasta_path})
257 break
258 else:
259 raise ValueError(
260 "The selected estimator doesn't support "
261 "fasta file input! Please consider using "
262 "KerasGBatchClassifier with "
263 "FastaDNABatchGenerator/FastaProteinBatchGenerator "
264 "or having GenomeOneHotEncoder/ProteinOneHotEncoder "
265 "in pipeline!")
266
267 elif input_type == 'refseq_and_interval':
268 path_params = {
269 'data_batch_generator__ref_genome_path': ref_seq,
270 'data_batch_generator__intervals_path': intervals,
271 'data_batch_generator__target_path': targets
272 }
273 estimator.set_params(**path_params)
274 n_intervals = sum(1 for line in open(intervals))
275 X = np.arange(n_intervals)[:, np.newaxis]
276
277 # Get target y
221 header = 'infer' if params['input_options']['header2'] else None 278 header = 'infer' if params['input_options']['header2'] else None
222 column_option = (params['input_options']['column_selector_options_2'] 279 column_option = (params['input_options']['column_selector_options_2']
223 ['selected_column_selector_option2']) 280 ['selected_column_selector_option2'])
224 if column_option in ['by_index_number', 'all_but_by_index_number', 281 if column_option in ['by_index_number', 'all_but_by_index_number',
225 'by_header_name', 'all_but_by_header_name']: 282 'by_header_name', 'all_but_by_header_name']:
226 c = params['input_options']['column_selector_options_2']['col2'] 283 c = params['input_options']['column_selector_options_2']['col2']
227 else: 284 else:
228 c = None 285 c = None
286
287 df_key = infile2 + repr(header)
288 if df_key in loaded_df:
289 infile2 = loaded_df[df_key]
290 else:
291 infile2 = pd.read_csv(infile2, sep='\t',
292 header=header, parse_dates=True)
293 loaded_df[df_key] = infile2
294
229 y = read_columns( 295 y = read_columns(
230 infile2, 296 infile2,
231 c=c, 297 c=c,
232 c_option=column_option, 298 c_option=column_option,
233 sep='\t', 299 sep='\t',
234 header=header, 300 header=header,
235 parse_dates=True) 301 parse_dates=True)
236 y = y.ravel() 302 if len(y.shape) == 2 and y.shape[1] == 1:
303 y = y.ravel()
304 if input_type == 'refseq_and_interval':
305 estimator.set_params(
306 data_batch_generator__features=y.ravel().tolist())
307 y = None
308 # end y
237 309
238 optimizer = params['search_schemes']['selected_search_scheme'] 310 optimizer = params['search_schemes']['selected_search_scheme']
239 optimizer = getattr(model_selection, optimizer) 311 optimizer = getattr(model_selection, optimizer)
240 312
313 # handle gridsearchcv options
241 options = params['search_schemes']['options'] 314 options = params['search_schemes']['options']
315
316 if groups:
317 header = 'infer' if (options['cv_selector']['groups_selector']
318 ['header_g']) else None
319 column_option = (options['cv_selector']['groups_selector']
320 ['column_selector_options_g']
321 ['selected_column_selector_option_g'])
322 if column_option in ['by_index_number', 'all_but_by_index_number',
323 'by_header_name', 'all_but_by_header_name']:
324 c = (options['cv_selector']['groups_selector']
325 ['column_selector_options_g']['col_g'])
326 else:
327 c = None
328
329 df_key = groups + repr(header)
330 if df_key in loaded_df:
331 groups = loaded_df[df_key]
332
333 groups = read_columns(
334 groups,
335 c=c,
336 c_option=column_option,
337 sep='\t',
338 header=header,
339 parse_dates=True)
340 groups = groups.ravel()
341 options['cv_selector']['groups_selector'] = groups
242 342
243 splitter, groups = get_cv(options.pop('cv_selector')) 343 splitter, groups = get_cv(options.pop('cv_selector'))
244 options['cv'] = splitter 344 options['cv'] = splitter
245 options['n_jobs'] = N_JOBS 345 options['n_jobs'] = N_JOBS
246 primary_scoring = options['scoring']['primary_scoring'] 346 primary_scoring = options['scoring']['primary_scoring']
252 if options['refit'] and isinstance(options['scoring'], dict): 352 if options['refit'] and isinstance(options['scoring'], dict):
253 options['refit'] = primary_scoring 353 options['refit'] = primary_scoring
254 if 'pre_dispatch' in options and options['pre_dispatch'] == '': 354 if 'pre_dispatch' in options and options['pre_dispatch'] == '':
255 options['pre_dispatch'] = None 355 options['pre_dispatch'] = None
256 356
257 with open(infile_estimator, 'rb') as estimator_handler: 357 # del loaded_df
258 estimator = load_model(estimator_handler) 358 del loaded_df
259 359
360 # handle memory
260 memory = joblib.Memory(location=CACHE_DIR, verbose=0) 361 memory = joblib.Memory(location=CACHE_DIR, verbose=0)
261 # cache iraps_core fits could increase search speed significantly 362 # cache iraps_core fits could increase search speed significantly
262 if estimator.__class__.__name__ == 'IRAPSClassifier': 363 if estimator.__class__.__name__ == 'IRAPSClassifier':
263 estimator.set_params(memory=memory) 364 estimator.set_params(memory=memory)
264 else: 365 else:
265 for p, v in estimator.get_params().items(): 366 # For iraps buried in pipeline
367 for p, v in estimator_params.items():
266 if p.endswith('memory'): 368 if p.endswith('memory'):
369 # for case of `__irapsclassifier__memory`
267 if len(p) > 8 and p[:-8].endswith('irapsclassifier'): 370 if len(p) > 8 and p[:-8].endswith('irapsclassifier'):
268 # cache iraps_core fits could increase search 371 # cache iraps_core fits could increase search
269 # speed significantly 372 # speed significantly
270 new_params = {p: memory} 373 new_params = {p: memory}
271 estimator.set_params(**new_params) 374 estimator.set_params(**new_params)
375 # security reason, we don't want memory being
376 # modified unexpectedly
272 elif v: 377 elif v:
273 new_params = {p, None} 378 new_params = {p, None}
274 estimator.set_params(**new_params) 379 estimator.set_params(**new_params)
380 # For now, 1 CPU is suggested for irapsclassifier
275 elif p.endswith('n_jobs'): 381 elif p.endswith('n_jobs'):
276 new_params = {p: 1} 382 new_params = {p: 1}
277 estimator.set_params(**new_params) 383 estimator.set_params(**new_params)
384 # for security reason, types of callbacks are limited
385 elif p.endswith('callbacks'):
386 for cb in v:
387 cb_type = cb['callback_selection']['callback_type']
388 if cb_type not in ALLOWED_CALLBACKS:
389 raise ValueError(
390 "Prohibited callback type: %s!" % cb_type)
278 391
279 param_grid = _eval_search_params(params_builder) 392 param_grid = _eval_search_params(params_builder)
280 searcher = optimizer(estimator, param_grid, **options) 393 searcher = optimizer(estimator, param_grid, **options)
281 394
282 # do train_test_split 395 # do nested split
283 do_train_test_split = params['train_test_split'].pop('do_split') 396 split_mode = params['outer_split'].pop('split_mode')
284 if do_train_test_split == 'yes': 397 # nested CV, outer cv using cross_validate
285 # make sure refit is choosen 398 if split_mode == 'nested_cv':
286 if not options['refit']: 399 outer_cv, _ = get_cv(params['outer_split']['cv_selector'])
287 raise ValueError("Refit must be `True` for shuffle splitting!") 400
288 split_options = params['train_test_split'] 401 if options['error_score'] == 'raise':
289 402 rval = cross_validate(
290 # splits 403 searcher, X, y, scoring=options['scoring'],
291 if split_options['shuffle'] == 'stratified': 404 cv=outer_cv, n_jobs=N_JOBS, verbose=0,
292 split_options['labels'] = y 405 error_score=options['error_score'])
293 X, X_test, y, y_test = train_test_split(X, y, **split_options) 406 else:
294 elif split_options['shuffle'] == 'group': 407 warnings.simplefilter('always', FitFailedWarning)
295 if not groups: 408 with warnings.catch_warnings(record=True) as w:
296 raise ValueError("No group based CV option was " 409 try:
297 "choosen for group shuffle!") 410 rval = cross_validate(
298 split_options['labels'] = groups 411 searcher, X, y,
299 X, X_test, y, y_test, groups, _ =\ 412 scoring=options['scoring'],
300 train_test_split(X, y, **split_options) 413 cv=outer_cv, n_jobs=N_JOBS,
301 else: 414 verbose=0,
302 if split_options['shuffle'] == 'None': 415 error_score=options['error_score'])
303 split_options['shuffle'] = None 416 except ValueError:
304 X, X_test, y, y_test =\ 417 pass
305 train_test_split(X, y, **split_options) 418 for warning in w:
306 # end train_test_split 419 print(repr(warning.message))
307 420
308 if options['error_score'] == 'raise': 421 keys = list(rval.keys())
309 searcher.fit(X, y, groups=groups) 422 for k in keys:
423 if k.startswith('test'):
424 rval['mean_' + k] = np.mean(rval[k])
425 rval['std_' + k] = np.std(rval[k])
426 if k.endswith('time'):
427 rval.pop(k)
428 rval = pd.DataFrame(rval)
429 rval = rval[sorted(rval.columns)]
430 rval.to_csv(path_or_buf=outfile_result, sep='\t',
431 header=True, index=False)
310 else: 432 else:
311 warnings.simplefilter('always', FitFailedWarning) 433 if split_mode == 'train_test_split':
312 with warnings.catch_warnings(record=True) as w: 434 train_test_split = try_get_attr(
313 try: 435 'galaxy_ml.model_validations', 'train_test_split')
314 searcher.fit(X, y, groups=groups) 436 # make sure refit is choosen
315 except ValueError: 437 # this could be True for sklearn models, but not the case for
316 pass 438 # deep learning models
317 for warning in w: 439 if not options['refit'] and \
318 print(repr(warning.message)) 440 not all(hasattr(estimator, attr)
319 441 for attr in ('config', 'model_type')):
320 if do_train_test_split == 'no': 442 warnings.warn("Refit is change to `True` for nested "
321 # save results 443 "validation!")
322 cv_results = pandas.DataFrame(searcher.cv_results_) 444 setattr(searcher, 'refit', True)
323 cv_results = cv_results[sorted(cv_results.columns)] 445 split_options = params['outer_split']
324 cv_results.to_csv(path_or_buf=outfile_result, sep='\t', 446
325 header=True, index=False) 447 # splits
326 448 if split_options['shuffle'] == 'stratified':
327 # output test result using best_estimator_ 449 split_options['labels'] = y
328 else: 450 X, X_test, y, y_test = train_test_split(X, y, **split_options)
329 best_estimator_ = searcher.best_estimator_ 451 elif split_options['shuffle'] == 'group':
330 if isinstance(options['scoring'], collections.Mapping): 452 if groups is None:
331 is_multimetric = True 453 raise ValueError("No group based CV option was "
332 else: 454 "choosen for group shuffle!")
333 is_multimetric = False 455 split_options['labels'] = groups
334 456 if y is None:
335 test_score = _score(best_estimator_, X_test, 457 X, X_test, groups, _ =\
336 y_test, options['scoring'], 458 train_test_split(X, groups, **split_options)
337 is_multimetric=is_multimetric) 459 else:
338 if not is_multimetric: 460 X, X_test, y, y_test, groups, _ =\
339 test_score = {primary_scoring: test_score} 461 train_test_split(X, y, groups, **split_options)
340 for key, value in test_score.items(): 462 else:
341 test_score[key] = [value] 463 if split_options['shuffle'] == 'None':
342 result_df = pandas.DataFrame(test_score) 464 split_options['shuffle'] = None
343 result_df.to_csv(path_or_buf=outfile_result, sep='\t', 465 X, X_test, y, y_test =\
344 header=True, index=False) 466 train_test_split(X, y, **split_options)
467 # end train_test_split
468
469 # shared by both train_test_split and non-split
470 if options['error_score'] == 'raise':
471 searcher.fit(X, y, groups=groups)
472 else:
473 warnings.simplefilter('always', FitFailedWarning)
474 with warnings.catch_warnings(record=True) as w:
475 try:
476 searcher.fit(X, y, groups=groups)
477 except ValueError:
478 pass
479 for warning in w:
480 print(repr(warning.message))
481
482 # no outer split
483 if split_mode == 'no':
484 # save results
485 cv_results = pd.DataFrame(searcher.cv_results_)
486 cv_results = cv_results[sorted(cv_results.columns)]
487 cv_results.to_csv(path_or_buf=outfile_result, sep='\t',
488 header=True, index=False)
489
490 # train_test_split, output test result using best_estimator_
491 # or rebuild the trained estimator using weights if applicable.
492 else:
493 scorer_ = searcher.scorer_
494 if isinstance(scorer_, collections.Mapping):
495 is_multimetric = True
496 else:
497 is_multimetric = False
498
499 best_estimator_ = getattr(searcher, 'best_estimator_', None)
500 if not best_estimator_:
501 raise ValueError("GridSearchCV object has no "
502 "`best_estimator_` when `refit`=False!")
503
504 if best_estimator_.__class__.__name__ == 'KerasGBatchClassifier' \
505 and hasattr(estimator.data_batch_generator, 'target_path'):
506 test_score = best_estimator_.evaluate(
507 X_test, scorer=scorer_, is_multimetric=is_multimetric)
508 else:
509 test_score = _score(best_estimator_, X_test,
510 y_test, scorer_,
511 is_multimetric=is_multimetric)
512
513 if not is_multimetric:
514 test_score = {primary_scoring: test_score}
515 for key, value in test_score.items():
516 test_score[key] = [value]
517 result_df = pd.DataFrame(test_score)
518 result_df.to_csv(path_or_buf=outfile_result, sep='\t',
519 header=True, index=False)
345 520
346 memory.clear(warn=False) 521 memory.clear(warn=False)
347 522
348 if outfile_object: 523 if outfile_object:
524 best_estimator_ = getattr(searcher, 'best_estimator_', None)
525 if not best_estimator_:
526 warnings.warn("GridSearchCV object has no attribute "
527 "'best_estimator_', because either it's "
528 "nested gridsearch or `refit` is False!")
529 return
530
531 main_est = best_estimator_
532 if isinstance(best_estimator_, pipeline.Pipeline):
533 main_est = best_estimator_.steps[-1][-1]
534
535 if hasattr(main_est, 'model_') \
536 and hasattr(main_est, 'save_weights'):
537 if outfile_weights:
538 main_est.save_weights(outfile_weights)
539 del main_est.model_
540 del main_est.fit_params
541 del main_est.model_class_
542 del main_est.validation_data
543 if getattr(main_est, 'data_generator_', None):
544 del main_est.data_generator_
545 del main_est.data_batch_generator
546
349 with open(outfile_object, 'wb') as output_handler: 547 with open(outfile_object, 'wb') as output_handler:
350 pickle.dump(searcher, output_handler, pickle.HIGHEST_PROTOCOL) 548 pickle.dump(best_estimator_, output_handler,
549 pickle.HIGHEST_PROTOCOL)
351 550
352 551
353 if __name__ == '__main__': 552 if __name__ == '__main__':
354 aparser = argparse.ArgumentParser() 553 aparser = argparse.ArgumentParser()
355 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) 554 aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
356 aparser.add_argument("-e", "--estimator", dest="infile_estimator") 555 aparser.add_argument("-e", "--estimator", dest="infile_estimator")
357 aparser.add_argument("-X", "--infile1", dest="infile1") 556 aparser.add_argument("-X", "--infile1", dest="infile1")
358 aparser.add_argument("-y", "--infile2", dest="infile2") 557 aparser.add_argument("-y", "--infile2", dest="infile2")
359 aparser.add_argument("-r", "--outfile_result", dest="outfile_result") 558 aparser.add_argument("-O", "--outfile_result", dest="outfile_result")
360 aparser.add_argument("-o", "--outfile_object", dest="outfile_object") 559 aparser.add_argument("-o", "--outfile_object", dest="outfile_object")
560 aparser.add_argument("-w", "--outfile_weights", dest="outfile_weights")
361 aparser.add_argument("-g", "--groups", dest="groups") 561 aparser.add_argument("-g", "--groups", dest="groups")
562 aparser.add_argument("-r", "--ref_seq", dest="ref_seq")
563 aparser.add_argument("-b", "--intervals", dest="intervals")
564 aparser.add_argument("-t", "--targets", dest="targets")
565 aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
362 args = aparser.parse_args() 566 args = aparser.parse_args()
363 567
364 main(args.inputs, args.infile_estimator, args.infile1, args.infile2, 568 main(args.inputs, args.infile_estimator, args.infile1, args.infile2,
365 args.outfile_result, outfile_object=args.outfile_object, 569 args.outfile_result, outfile_object=args.outfile_object,
366 groups=args.groups) 570 outfile_weights=args.outfile_weights, groups=args.groups,
571 ref_seq=args.ref_seq, intervals=args.intervals,
572 targets=args.targets, fasta_path=args.fasta_path)