comparison search_model_validation.py @ 3:7a64b9f39a46 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ba6a47bdf76bbf4cb276206ac1a8cbf61332fd16"
author bgruening
date Fri, 13 Sep 2019 12:15:10 -0400
parents c411ff569a26
children 7a9a9349eb42
comparison
equal deleted inserted replaced
2:c411ff569a26 3:7a64b9f39a46
211 warnings.simplefilter('ignore') 211 warnings.simplefilter('ignore')
212 212
213 with open(inputs, 'r') as param_handler: 213 with open(inputs, 'r') as param_handler:
214 params = json.load(param_handler) 214 params = json.load(param_handler)
215 215
216 # conflict param checker
217 if params['outer_split']['split_mode'] == 'nested_cv' \
218 and params['save'] != 'nope':
219 raise ValueError("Save best estimator is not possible for nested CV!")
220
221 if not (params['search_schemes']['options']['refit']) \
222 and params['save'] != 'nope':
223 raise ValueError("Save best estimator is not possible when refit "
224 "is False!")
225
216 params_builder = params['search_schemes']['search_params_builder'] 226 params_builder = params['search_schemes']['search_params_builder']
217 227
218 with open(infile_estimator, 'rb') as estimator_handler: 228 with open(infile_estimator, 'rb') as estimator_handler:
219 estimator = load_model(estimator_handler) 229 estimator = load_model(estimator_handler)
220 estimator_params = estimator.get_params() 230 estimator_params = estimator.get_params()
540 del main_est.fit_params 550 del main_est.fit_params
541 del main_est.model_class_ 551 del main_est.model_class_
542 del main_est.validation_data 552 del main_est.validation_data
543 if getattr(main_est, 'data_generator_', None): 553 if getattr(main_est, 'data_generator_', None):
544 del main_est.data_generator_ 554 del main_est.data_generator_
545 del main_est.data_batch_generator
546 555
547 with open(outfile_object, 'wb') as output_handler: 556 with open(outfile_object, 'wb') as output_handler:
548 pickle.dump(best_estimator_, output_handler, 557 pickle.dump(best_estimator_, output_handler,
549 pickle.HIGHEST_PROTOCOL) 558 pickle.HIGHEST_PROTOCOL)
550 559