comparison search_model_validation.py @ 3:24c1cc2dd4a4 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ea12f973df4b97a2691d9e4ce6bf6fae59d57717"
author bgruening
date Sat, 01 May 2021 01:14:08 +0000
parents e36ab18cbaca
children c16818ce0424
comparison
equal deleted inserted replaced
2:e36ab18cbaca 3:24c1cc2dd4a4
9 import imblearn 9 import imblearn
10 import joblib 10 import joblib
11 import numpy as np 11 import numpy as np
12 import pandas as pd 12 import pandas as pd
13 import skrebate 13 import skrebate
14 from galaxy_ml.utils import ( 14 from galaxy_ml.utils import (clean_params, get_cv,
15 clean_params, 15 get_main_estimator, get_module, get_scoring,
16 get_cv, 16 load_model, read_columns, SafeEval, try_get_attr)
17 get_main_estimator,
18 get_module,
19 get_scoring,
20 load_model,
21 read_columns,
22 SafeEval,
23 try_get_attr
24 )
25 from scipy.io import mmread 17 from scipy.io import mmread
26 from sklearn import ( 18 from sklearn import (cluster, decomposition, feature_selection,
27 cluster, 19 kernel_approximation, model_selection, preprocessing)
28 decomposition,
29 feature_selection,
30 kernel_approximation,
31 model_selection,
32 preprocessing,
33 )
34 from sklearn.exceptions import FitFailedWarning 20 from sklearn.exceptions import FitFailedWarning
35 from sklearn.model_selection import _search, _validation 21 from sklearn.model_selection import _search, _validation
36 from sklearn.model_selection._validation import _score, cross_validate 22 from sklearn.model_selection._validation import _score, cross_validate
37
38 23
39 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score") 24 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score")
40 setattr(_search, "_fit_and_score", _fit_and_score) 25 setattr(_search, "_fit_and_score", _fit_and_score)
41 setattr(_validation, "_fit_and_score", _fit_and_score) 26 setattr(_validation, "_fit_and_score", _fit_and_score)
42 27
55 if search_list == "": 40 if search_list == "":
56 continue 41 continue
57 42
58 param_name = p["sp_name"] 43 param_name = p["sp_name"]
59 if param_name.lower().endswith(NON_SEARCHABLE): 44 if param_name.lower().endswith(NON_SEARCHABLE):
60 print("Warning: `%s` is not eligible for search and was " "omitted!" % param_name) 45 print(
46 "Warning: `%s` is not eligible for search and was "
47 "omitted!" % param_name
48 )
61 continue 49 continue
62 50
63 if not search_list.startswith(":"): 51 if not search_list.startswith(":"):
64 safe_eval = SafeEval(load_scipy=True, load_numpy=True) 52 safe_eval = SafeEval(load_scipy=True, load_numpy=True)
65 ev = safe_eval(search_list) 53 ev = safe_eval(search_list)
88 decomposition.FactorAnalysis(random_state=0), 76 decomposition.FactorAnalysis(random_state=0),
89 decomposition.FastICA(random_state=0), 77 decomposition.FastICA(random_state=0),
90 decomposition.IncrementalPCA(), 78 decomposition.IncrementalPCA(),
91 decomposition.KernelPCA(random_state=0, n_jobs=N_JOBS), 79 decomposition.KernelPCA(random_state=0, n_jobs=N_JOBS),
92 decomposition.LatentDirichletAllocation(random_state=0, n_jobs=N_JOBS), 80 decomposition.LatentDirichletAllocation(random_state=0, n_jobs=N_JOBS),
93 decomposition.MiniBatchDictionaryLearning(random_state=0, n_jobs=N_JOBS), 81 decomposition.MiniBatchDictionaryLearning(
82 random_state=0, n_jobs=N_JOBS
83 ),
94 decomposition.MiniBatchSparsePCA(random_state=0, n_jobs=N_JOBS), 84 decomposition.MiniBatchSparsePCA(random_state=0, n_jobs=N_JOBS),
95 decomposition.NMF(random_state=0), 85 decomposition.NMF(random_state=0),
96 decomposition.PCA(random_state=0), 86 decomposition.PCA(random_state=0),
97 decomposition.SparsePCA(random_state=0, n_jobs=N_JOBS), 87 decomposition.SparsePCA(random_state=0, n_jobs=N_JOBS),
98 decomposition.TruncatedSVD(random_state=0), 88 decomposition.TruncatedSVD(random_state=0),
105 skrebate.SURF(n_jobs=N_JOBS), 95 skrebate.SURF(n_jobs=N_JOBS),
106 skrebate.SURFstar(n_jobs=N_JOBS), 96 skrebate.SURFstar(n_jobs=N_JOBS),
107 skrebate.MultiSURF(n_jobs=N_JOBS), 97 skrebate.MultiSURF(n_jobs=N_JOBS),
108 skrebate.MultiSURFstar(n_jobs=N_JOBS), 98 skrebate.MultiSURFstar(n_jobs=N_JOBS),
109 imblearn.under_sampling.ClusterCentroids(random_state=0, n_jobs=N_JOBS), 99 imblearn.under_sampling.ClusterCentroids(random_state=0, n_jobs=N_JOBS),
110 imblearn.under_sampling.CondensedNearestNeighbour(random_state=0, n_jobs=N_JOBS), 100 imblearn.under_sampling.CondensedNearestNeighbour(
111 imblearn.under_sampling.EditedNearestNeighbours(random_state=0, n_jobs=N_JOBS), 101 random_state=0, n_jobs=N_JOBS
112 imblearn.under_sampling.RepeatedEditedNearestNeighbours(random_state=0, n_jobs=N_JOBS), 102 ),
103 imblearn.under_sampling.EditedNearestNeighbours(
104 random_state=0, n_jobs=N_JOBS
105 ),
106 imblearn.under_sampling.RepeatedEditedNearestNeighbours(
107 random_state=0, n_jobs=N_JOBS
108 ),
113 imblearn.under_sampling.AllKNN(random_state=0, n_jobs=N_JOBS), 109 imblearn.under_sampling.AllKNN(random_state=0, n_jobs=N_JOBS),
114 imblearn.under_sampling.InstanceHardnessThreshold(random_state=0, n_jobs=N_JOBS), 110 imblearn.under_sampling.InstanceHardnessThreshold(
111 random_state=0, n_jobs=N_JOBS
112 ),
115 imblearn.under_sampling.NearMiss(random_state=0, n_jobs=N_JOBS), 113 imblearn.under_sampling.NearMiss(random_state=0, n_jobs=N_JOBS),
116 imblearn.under_sampling.NeighbourhoodCleaningRule(random_state=0, n_jobs=N_JOBS), 114 imblearn.under_sampling.NeighbourhoodCleaningRule(
117 imblearn.under_sampling.OneSidedSelection(random_state=0, n_jobs=N_JOBS), 115 random_state=0, n_jobs=N_JOBS
116 ),
117 imblearn.under_sampling.OneSidedSelection(
118 random_state=0, n_jobs=N_JOBS
119 ),
118 imblearn.under_sampling.RandomUnderSampler(random_state=0), 120 imblearn.under_sampling.RandomUnderSampler(random_state=0),
119 imblearn.under_sampling.TomekLinks(random_state=0, n_jobs=N_JOBS), 121 imblearn.under_sampling.TomekLinks(random_state=0, n_jobs=N_JOBS),
120 imblearn.over_sampling.ADASYN(random_state=0, n_jobs=N_JOBS), 122 imblearn.over_sampling.ADASYN(random_state=0, n_jobs=N_JOBS),
121 imblearn.over_sampling.RandomOverSampler(random_state=0), 123 imblearn.over_sampling.RandomOverSampler(random_state=0),
122 imblearn.over_sampling.SMOTE(random_state=0, n_jobs=N_JOBS), 124 imblearn.over_sampling.SMOTE(random_state=0, n_jobs=N_JOBS),
123 imblearn.over_sampling.SVMSMOTE(random_state=0, n_jobs=N_JOBS), 125 imblearn.over_sampling.SVMSMOTE(random_state=0, n_jobs=N_JOBS),
124 imblearn.over_sampling.BorderlineSMOTE(random_state=0, n_jobs=N_JOBS), 126 imblearn.over_sampling.BorderlineSMOTE(random_state=0, n_jobs=N_JOBS),
125 imblearn.over_sampling.SMOTENC(categorical_features=[], random_state=0, n_jobs=N_JOBS), 127 imblearn.over_sampling.SMOTENC(
128 categorical_features=[], random_state=0, n_jobs=N_JOBS
129 ),
126 imblearn.combine.SMOTEENN(random_state=0), 130 imblearn.combine.SMOTEENN(random_state=0),
127 imblearn.combine.SMOTETomek(random_state=0), 131 imblearn.combine.SMOTETomek(random_state=0),
128 ) 132 )
129 newlist = [] 133 newlist = []
130 for obj in ev: 134 for obj in ev:
203 207
204 input_type = params["input_options"]["selected_input"] 208 input_type = params["input_options"]["selected_input"]
205 # tabular input 209 # tabular input
206 if input_type == "tabular": 210 if input_type == "tabular":
207 header = "infer" if params["input_options"]["header1"] else None 211 header = "infer" if params["input_options"]["header1"] else None
208 column_option = params["input_options"]["column_selector_options_1"]["selected_column_selector_option"] 212 column_option = params["input_options"]["column_selector_options_1"][
213 "selected_column_selector_option"
214 ]
209 if column_option in [ 215 if column_option in [
210 "by_index_number", 216 "by_index_number",
211 "all_but_by_index_number", 217 "all_but_by_index_number",
212 "by_header_name", 218 "by_header_name",
213 "all_but_by_header_name", 219 "all_but_by_header_name",
259 n_intervals = sum(1 for line in open(intervals)) 265 n_intervals = sum(1 for line in open(intervals))
260 X = np.arange(n_intervals)[:, np.newaxis] 266 X = np.arange(n_intervals)[:, np.newaxis]
261 267
262 # Get target y 268 # Get target y
263 header = "infer" if params["input_options"]["header2"] else None 269 header = "infer" if params["input_options"]["header2"] else None
264 column_option = params["input_options"]["column_selector_options_2"]["selected_column_selector_option2"] 270 column_option = params["input_options"]["column_selector_options_2"][
271 "selected_column_selector_option2"
272 ]
265 if column_option in [ 273 if column_option in [
266 "by_index_number", 274 "by_index_number",
267 "all_but_by_index_number", 275 "all_but_by_index_number",
268 "by_header_name", 276 "by_header_name",
269 "all_but_by_header_name", 277 "all_but_by_header_name",
277 infile2 = loaded_df[df_key] 285 infile2 = loaded_df[df_key]
278 else: 286 else:
279 infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True) 287 infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True)
280 loaded_df[df_key] = infile2 288 loaded_df[df_key] = infile2
281 289
282 y = read_columns(infile2, c=c, c_option=column_option, sep="\t", header=header, parse_dates=True) 290 y = read_columns(
291 infile2, c=c, c_option=column_option, sep="\t", header=header, parse_dates=True
292 )
283 if len(y.shape) == 2 and y.shape[1] == 1: 293 if len(y.shape) == 2 and y.shape[1] == 1:
284 y = y.ravel() 294 y = y.ravel()
285 if input_type == "refseq_and_interval": 295 if input_type == "refseq_and_interval":
286 estimator.set_params(data_batch_generator__features=y.ravel().tolist()) 296 estimator.set_params(data_batch_generator__features=y.ravel().tolist())
287 y = None 297 y = None
376 if split_options["shuffle"] == "stratified": 386 if split_options["shuffle"] == "stratified":
377 split_options["labels"] = y 387 split_options["labels"] = y
378 X, X_test, y, y_test = train_test_split(X, y, **split_options) 388 X, X_test, y, y_test = train_test_split(X, y, **split_options)
379 elif split_options["shuffle"] == "group": 389 elif split_options["shuffle"] == "group":
380 if groups is None: 390 if groups is None:
381 raise ValueError("No group based CV option was choosen for " "group shuffle!") 391 raise ValueError(
392 "No group based CV option was choosen for " "group shuffle!"
393 )
382 split_options["labels"] = groups 394 split_options["labels"] = groups
383 if y is None: 395 if y is None:
384 X, X_test, groups, _ = train_test_split(X, groups, **split_options) 396 X, X_test, groups, _ = train_test_split(X, groups, **split_options)
385 else: 397 else:
386 X, X_test, y, y_test, groups, _ = train_test_split(X, y, groups, **split_options) 398 X, X_test, y, y_test, groups, _ = train_test_split(
399 X, y, groups, **split_options
400 )
387 else: 401 else:
388 if split_options["shuffle"] == "None": 402 if split_options["shuffle"] == "None":
389 split_options["shuffle"] = None 403 split_options["shuffle"] = None
390 X, X_test, y, y_test = train_test_split(X, y, **split_options) 404 X, X_test, y, y_test = train_test_split(X, y, **split_options)
391 405
409 423
410 best_estimator_ = getattr(searcher, "best_estimator_") 424 best_estimator_ = getattr(searcher, "best_estimator_")
411 425
412 # TODO Solve deep learning models in pipeline 426 # TODO Solve deep learning models in pipeline
413 if best_estimator_.__class__.__name__ == "KerasGBatchClassifier": 427 if best_estimator_.__class__.__name__ == "KerasGBatchClassifier":
414 test_score = best_estimator_.evaluate(X_test, scorer=scorer_, is_multimetric=is_multimetric) 428 test_score = best_estimator_.evaluate(
415 else: 429 X_test, scorer=scorer_, is_multimetric=is_multimetric
416 test_score = _score(best_estimator_, X_test, y_test, scorer_, is_multimetric=is_multimetric) 430 )
431 else:
432 test_score = _score(
433 best_estimator_, X_test, y_test, scorer_, is_multimetric=is_multimetric
434 )
417 435
418 if not is_multimetric: 436 if not is_multimetric:
419 test_score = {primary_scoring: test_score} 437 test_score = {primary_scoring: test_score}
420 for key, value in test_score.items(): 438 for key, value in test_score.items():
421 test_score[key] = [value] 439 test_score[key] = [value]
485 503
486 with open(inputs, "r") as param_handler: 504 with open(inputs, "r") as param_handler:
487 params = json.load(param_handler) 505 params = json.load(param_handler)
488 506
489 # Override the refit parameter 507 # Override the refit parameter
490 params["search_schemes"]["options"]["refit"] = True if params["save"] != "nope" else False 508 params["search_schemes"]["options"]["refit"] = (
509 True if params["save"] != "nope" else False
510 )
491 511
492 with open(infile_estimator, "rb") as estimator_handler: 512 with open(infile_estimator, "rb") as estimator_handler:
493 estimator = load_model(estimator_handler) 513 estimator = load_model(estimator_handler)
494 514
495 optimizer = params["search_schemes"]["selected_search_scheme"] 515 optimizer = params["search_schemes"]["selected_search_scheme"]
497 517
498 # handle gridsearchcv options 518 # handle gridsearchcv options
499 options = params["search_schemes"]["options"] 519 options = params["search_schemes"]["options"]
500 520
501 if groups: 521 if groups:
502 header = "infer" if (options["cv_selector"]["groups_selector"]["header_g"]) else None 522 header = (
503 column_option = options["cv_selector"]["groups_selector"]["column_selector_options_g"][ 523 "infer" if (options["cv_selector"]["groups_selector"]["header_g"]) else None
504 "selected_column_selector_option_g" 524 )
505 ] 525 column_option = options["cv_selector"]["groups_selector"][
526 "column_selector_options_g"
527 ]["selected_column_selector_option_g"]
506 if column_option in [ 528 if column_option in [
507 "by_index_number", 529 "by_index_number",
508 "all_but_by_index_number", 530 "all_but_by_index_number",
509 "by_header_name", 531 "by_header_name",
510 "all_but_by_header_name", 532 "all_but_by_header_name",
511 ]: 533 ]:
512 c = options["cv_selector"]["groups_selector"]["column_selector_options_g"]["col_g"] 534 c = options["cv_selector"]["groups_selector"]["column_selector_options_g"][
535 "col_g"
536 ]
513 else: 537 else:
514 c = None 538 c = None
515 539
516 df_key = groups + repr(header) 540 df_key = groups + repr(header)
517 541
535 # get_scoring() expects secondary_scoring to be a comma separated string (not a list) 559 # get_scoring() expects secondary_scoring to be a comma separated string (not a list)
536 # Check if secondary_scoring is specified 560 # Check if secondary_scoring is specified
537 secondary_scoring = options["scoring"].get("secondary_scoring", None) 561 secondary_scoring = options["scoring"].get("secondary_scoring", None)
538 if secondary_scoring is not None: 562 if secondary_scoring is not None:
539 # If secondary_scoring is specified, convert the list into comman separated string 563 # If secondary_scoring is specified, convert the list into comman separated string
540 options["scoring"]["secondary_scoring"] = ",".join(options["scoring"]["secondary_scoring"]) 564 options["scoring"]["secondary_scoring"] = ",".join(
565 options["scoring"]["secondary_scoring"]
566 )
541 options["scoring"] = get_scoring(options["scoring"]) 567 options["scoring"] = get_scoring(options["scoring"])
542 if options["error_score"]: 568 if options["error_score"]:
543 options["error_score"] = "raise" 569 options["error_score"] = "raise"
544 else: 570 else:
545 options["error_score"] = np.NaN 571 options["error_score"] = np.nan
546 if options["refit"] and isinstance(options["scoring"], dict): 572 if options["refit"] and isinstance(options["scoring"], dict):
547 options["refit"] = primary_scoring 573 options["refit"] = primary_scoring
548 if "pre_dispatch" in options and options["pre_dispatch"] == "": 574 if "pre_dispatch" in options and options["pre_dispatch"] == "":
549 options["pre_dispatch"] = None 575 options["pre_dispatch"] = None
550 576
586 612
587 if split_mode == "nested_cv": 613 if split_mode == "nested_cv":
588 # make sure refit is choosen 614 # make sure refit is choosen
589 # this could be True for sklearn models, but not the case for 615 # this could be True for sklearn models, but not the case for
590 # deep learning models 616 # deep learning models
591 if not options["refit"] and not all(hasattr(estimator, attr) for attr in ("config", "model_type")): 617 if not options["refit"] and not all(
618 hasattr(estimator, attr) for attr in ("config", "model_type")
619 ):
592 warnings.warn("Refit is change to `True` for nested validation!") 620 warnings.warn("Refit is change to `True` for nested validation!")
593 setattr(searcher, "refit", True) 621 setattr(searcher, "refit", True)
594 622
595 outer_cv, _ = get_cv(params["outer_split"]["cv_selector"]) 623 outer_cv, _ = get_cv(params["outer_split"]["cv_selector"])
596 # nested CV, outer cv using cross_validate 624 # nested CV, outer cv using cross_validate
685 for warning in w: 713 for warning in w:
686 print(repr(warning.message)) 714 print(repr(warning.message))
687 715
688 cv_results = pd.DataFrame(searcher.cv_results_) 716 cv_results = pd.DataFrame(searcher.cv_results_)
689 cv_results = cv_results[sorted(cv_results.columns)] 717 cv_results = cv_results[sorted(cv_results.columns)]
690 cv_results.to_csv(path_or_buf=outfile_result, sep="\t", header=True, index=False) 718 cv_results.to_csv(
719 path_or_buf=outfile_result, sep="\t", header=True, index=False
720 )
691 721
692 memory.clear(warn=False) 722 memory.clear(warn=False)
693 723
694 # output best estimator, and weights if applicable 724 # output best estimator, and weights if applicable
695 if outfile_object: 725 if outfile_object: