Mercurial > repos > bgruening > sklearn_train_test_eval
comparison search_model_validation.py @ 11:caf7d2b71a48 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ea12f973df4b97a2691d9e4ce6bf6fae59d57717"
author | bgruening |
---|---|
date | Sat, 01 May 2021 01:47:26 +0000 |
parents | a9e0b963b7bb |
children | 2eb5c017958d |
comparison
equal
deleted
inserted
replaced
10:a9e0b963b7bb | 11:caf7d2b71a48 |
---|---|
9 import imblearn | 9 import imblearn |
10 import joblib | 10 import joblib |
11 import numpy as np | 11 import numpy as np |
12 import pandas as pd | 12 import pandas as pd |
13 import skrebate | 13 import skrebate |
14 from galaxy_ml.utils import ( | 14 from galaxy_ml.utils import (clean_params, get_cv, |
15 clean_params, | 15 get_main_estimator, get_module, get_scoring, |
16 get_cv, | 16 load_model, read_columns, SafeEval, try_get_attr) |
17 get_main_estimator, | |
18 get_module, | |
19 get_scoring, | |
20 load_model, | |
21 read_columns, | |
22 SafeEval, | |
23 try_get_attr | |
24 ) | |
25 from scipy.io import mmread | 17 from scipy.io import mmread |
26 from sklearn import ( | 18 from sklearn import (cluster, decomposition, feature_selection, |
27 cluster, | 19 kernel_approximation, model_selection, preprocessing) |
28 decomposition, | |
29 feature_selection, | |
30 kernel_approximation, | |
31 model_selection, | |
32 preprocessing, | |
33 ) | |
34 from sklearn.exceptions import FitFailedWarning | 20 from sklearn.exceptions import FitFailedWarning |
35 from sklearn.model_selection import _search, _validation | 21 from sklearn.model_selection import _search, _validation |
36 from sklearn.model_selection._validation import _score, cross_validate | 22 from sklearn.model_selection._validation import _score, cross_validate |
37 | |
38 | 23 |
39 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score") | 24 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score") |
40 setattr(_search, "_fit_and_score", _fit_and_score) | 25 setattr(_search, "_fit_and_score", _fit_and_score) |
41 setattr(_validation, "_fit_and_score", _fit_and_score) | 26 setattr(_validation, "_fit_and_score", _fit_and_score) |
42 | 27 |
55 if search_list == "": | 40 if search_list == "": |
56 continue | 41 continue |
57 | 42 |
58 param_name = p["sp_name"] | 43 param_name = p["sp_name"] |
59 if param_name.lower().endswith(NON_SEARCHABLE): | 44 if param_name.lower().endswith(NON_SEARCHABLE): |
60 print("Warning: `%s` is not eligible for search and was " "omitted!" % param_name) | 45 print( |
46 "Warning: `%s` is not eligible for search and was " | |
47 "omitted!" % param_name | |
48 ) | |
61 continue | 49 continue |
62 | 50 |
63 if not search_list.startswith(":"): | 51 if not search_list.startswith(":"): |
64 safe_eval = SafeEval(load_scipy=True, load_numpy=True) | 52 safe_eval = SafeEval(load_scipy=True, load_numpy=True) |
65 ev = safe_eval(search_list) | 53 ev = safe_eval(search_list) |
88 decomposition.FactorAnalysis(random_state=0), | 76 decomposition.FactorAnalysis(random_state=0), |
89 decomposition.FastICA(random_state=0), | 77 decomposition.FastICA(random_state=0), |
90 decomposition.IncrementalPCA(), | 78 decomposition.IncrementalPCA(), |
91 decomposition.KernelPCA(random_state=0, n_jobs=N_JOBS), | 79 decomposition.KernelPCA(random_state=0, n_jobs=N_JOBS), |
92 decomposition.LatentDirichletAllocation(random_state=0, n_jobs=N_JOBS), | 80 decomposition.LatentDirichletAllocation(random_state=0, n_jobs=N_JOBS), |
93 decomposition.MiniBatchDictionaryLearning(random_state=0, n_jobs=N_JOBS), | 81 decomposition.MiniBatchDictionaryLearning( |
82 random_state=0, n_jobs=N_JOBS | |
83 ), | |
94 decomposition.MiniBatchSparsePCA(random_state=0, n_jobs=N_JOBS), | 84 decomposition.MiniBatchSparsePCA(random_state=0, n_jobs=N_JOBS), |
95 decomposition.NMF(random_state=0), | 85 decomposition.NMF(random_state=0), |
96 decomposition.PCA(random_state=0), | 86 decomposition.PCA(random_state=0), |
97 decomposition.SparsePCA(random_state=0, n_jobs=N_JOBS), | 87 decomposition.SparsePCA(random_state=0, n_jobs=N_JOBS), |
98 decomposition.TruncatedSVD(random_state=0), | 88 decomposition.TruncatedSVD(random_state=0), |
105 skrebate.SURF(n_jobs=N_JOBS), | 95 skrebate.SURF(n_jobs=N_JOBS), |
106 skrebate.SURFstar(n_jobs=N_JOBS), | 96 skrebate.SURFstar(n_jobs=N_JOBS), |
107 skrebate.MultiSURF(n_jobs=N_JOBS), | 97 skrebate.MultiSURF(n_jobs=N_JOBS), |
108 skrebate.MultiSURFstar(n_jobs=N_JOBS), | 98 skrebate.MultiSURFstar(n_jobs=N_JOBS), |
109 imblearn.under_sampling.ClusterCentroids(random_state=0, n_jobs=N_JOBS), | 99 imblearn.under_sampling.ClusterCentroids(random_state=0, n_jobs=N_JOBS), |
110 imblearn.under_sampling.CondensedNearestNeighbour(random_state=0, n_jobs=N_JOBS), | 100 imblearn.under_sampling.CondensedNearestNeighbour( |
111 imblearn.under_sampling.EditedNearestNeighbours(random_state=0, n_jobs=N_JOBS), | 101 random_state=0, n_jobs=N_JOBS |
112 imblearn.under_sampling.RepeatedEditedNearestNeighbours(random_state=0, n_jobs=N_JOBS), | 102 ), |
103 imblearn.under_sampling.EditedNearestNeighbours( | |
104 random_state=0, n_jobs=N_JOBS | |
105 ), | |
106 imblearn.under_sampling.RepeatedEditedNearestNeighbours( | |
107 random_state=0, n_jobs=N_JOBS | |
108 ), | |
113 imblearn.under_sampling.AllKNN(random_state=0, n_jobs=N_JOBS), | 109 imblearn.under_sampling.AllKNN(random_state=0, n_jobs=N_JOBS), |
114 imblearn.under_sampling.InstanceHardnessThreshold(random_state=0, n_jobs=N_JOBS), | 110 imblearn.under_sampling.InstanceHardnessThreshold( |
111 random_state=0, n_jobs=N_JOBS | |
112 ), | |
115 imblearn.under_sampling.NearMiss(random_state=0, n_jobs=N_JOBS), | 113 imblearn.under_sampling.NearMiss(random_state=0, n_jobs=N_JOBS), |
116 imblearn.under_sampling.NeighbourhoodCleaningRule(random_state=0, n_jobs=N_JOBS), | 114 imblearn.under_sampling.NeighbourhoodCleaningRule( |
117 imblearn.under_sampling.OneSidedSelection(random_state=0, n_jobs=N_JOBS), | 115 random_state=0, n_jobs=N_JOBS |
116 ), | |
117 imblearn.under_sampling.OneSidedSelection( | |
118 random_state=0, n_jobs=N_JOBS | |
119 ), | |
118 imblearn.under_sampling.RandomUnderSampler(random_state=0), | 120 imblearn.under_sampling.RandomUnderSampler(random_state=0), |
119 imblearn.under_sampling.TomekLinks(random_state=0, n_jobs=N_JOBS), | 121 imblearn.under_sampling.TomekLinks(random_state=0, n_jobs=N_JOBS), |
120 imblearn.over_sampling.ADASYN(random_state=0, n_jobs=N_JOBS), | 122 imblearn.over_sampling.ADASYN(random_state=0, n_jobs=N_JOBS), |
121 imblearn.over_sampling.RandomOverSampler(random_state=0), | 123 imblearn.over_sampling.RandomOverSampler(random_state=0), |
122 imblearn.over_sampling.SMOTE(random_state=0, n_jobs=N_JOBS), | 124 imblearn.over_sampling.SMOTE(random_state=0, n_jobs=N_JOBS), |
123 imblearn.over_sampling.SVMSMOTE(random_state=0, n_jobs=N_JOBS), | 125 imblearn.over_sampling.SVMSMOTE(random_state=0, n_jobs=N_JOBS), |
124 imblearn.over_sampling.BorderlineSMOTE(random_state=0, n_jobs=N_JOBS), | 126 imblearn.over_sampling.BorderlineSMOTE(random_state=0, n_jobs=N_JOBS), |
125 imblearn.over_sampling.SMOTENC(categorical_features=[], random_state=0, n_jobs=N_JOBS), | 127 imblearn.over_sampling.SMOTENC( |
128 categorical_features=[], random_state=0, n_jobs=N_JOBS | |
129 ), | |
126 imblearn.combine.SMOTEENN(random_state=0), | 130 imblearn.combine.SMOTEENN(random_state=0), |
127 imblearn.combine.SMOTETomek(random_state=0), | 131 imblearn.combine.SMOTETomek(random_state=0), |
128 ) | 132 ) |
129 newlist = [] | 133 newlist = [] |
130 for obj in ev: | 134 for obj in ev: |
203 | 207 |
204 input_type = params["input_options"]["selected_input"] | 208 input_type = params["input_options"]["selected_input"] |
205 # tabular input | 209 # tabular input |
206 if input_type == "tabular": | 210 if input_type == "tabular": |
207 header = "infer" if params["input_options"]["header1"] else None | 211 header = "infer" if params["input_options"]["header1"] else None |
208 column_option = params["input_options"]["column_selector_options_1"]["selected_column_selector_option"] | 212 column_option = params["input_options"]["column_selector_options_1"][ |
213 "selected_column_selector_option" | |
214 ] | |
209 if column_option in [ | 215 if column_option in [ |
210 "by_index_number", | 216 "by_index_number", |
211 "all_but_by_index_number", | 217 "all_but_by_index_number", |
212 "by_header_name", | 218 "by_header_name", |
213 "all_but_by_header_name", | 219 "all_but_by_header_name", |
259 n_intervals = sum(1 for line in open(intervals)) | 265 n_intervals = sum(1 for line in open(intervals)) |
260 X = np.arange(n_intervals)[:, np.newaxis] | 266 X = np.arange(n_intervals)[:, np.newaxis] |
261 | 267 |
262 # Get target y | 268 # Get target y |
263 header = "infer" if params["input_options"]["header2"] else None | 269 header = "infer" if params["input_options"]["header2"] else None |
264 column_option = params["input_options"]["column_selector_options_2"]["selected_column_selector_option2"] | 270 column_option = params["input_options"]["column_selector_options_2"][ |
271 "selected_column_selector_option2" | |
272 ] | |
265 if column_option in [ | 273 if column_option in [ |
266 "by_index_number", | 274 "by_index_number", |
267 "all_but_by_index_number", | 275 "all_but_by_index_number", |
268 "by_header_name", | 276 "by_header_name", |
269 "all_but_by_header_name", | 277 "all_but_by_header_name", |
277 infile2 = loaded_df[df_key] | 285 infile2 = loaded_df[df_key] |
278 else: | 286 else: |
279 infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True) | 287 infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True) |
280 loaded_df[df_key] = infile2 | 288 loaded_df[df_key] = infile2 |
281 | 289 |
282 y = read_columns(infile2, c=c, c_option=column_option, sep="\t", header=header, parse_dates=True) | 290 y = read_columns( |
291 infile2, c=c, c_option=column_option, sep="\t", header=header, parse_dates=True | |
292 ) | |
283 if len(y.shape) == 2 and y.shape[1] == 1: | 293 if len(y.shape) == 2 and y.shape[1] == 1: |
284 y = y.ravel() | 294 y = y.ravel() |
285 if input_type == "refseq_and_interval": | 295 if input_type == "refseq_and_interval": |
286 estimator.set_params(data_batch_generator__features=y.ravel().tolist()) | 296 estimator.set_params(data_batch_generator__features=y.ravel().tolist()) |
287 y = None | 297 y = None |
376 if split_options["shuffle"] == "stratified": | 386 if split_options["shuffle"] == "stratified": |
377 split_options["labels"] = y | 387 split_options["labels"] = y |
378 X, X_test, y, y_test = train_test_split(X, y, **split_options) | 388 X, X_test, y, y_test = train_test_split(X, y, **split_options) |
379 elif split_options["shuffle"] == "group": | 389 elif split_options["shuffle"] == "group": |
380 if groups is None: | 390 if groups is None: |
381 raise ValueError("No group based CV option was choosen for " "group shuffle!") | 391 raise ValueError( |
392 "No group based CV option was choosen for " "group shuffle!" | |
393 ) | |
382 split_options["labels"] = groups | 394 split_options["labels"] = groups |
383 if y is None: | 395 if y is None: |
384 X, X_test, groups, _ = train_test_split(X, groups, **split_options) | 396 X, X_test, groups, _ = train_test_split(X, groups, **split_options) |
385 else: | 397 else: |
386 X, X_test, y, y_test, groups, _ = train_test_split(X, y, groups, **split_options) | 398 X, X_test, y, y_test, groups, _ = train_test_split( |
399 X, y, groups, **split_options | |
400 ) | |
387 else: | 401 else: |
388 if split_options["shuffle"] == "None": | 402 if split_options["shuffle"] == "None": |
389 split_options["shuffle"] = None | 403 split_options["shuffle"] = None |
390 X, X_test, y, y_test = train_test_split(X, y, **split_options) | 404 X, X_test, y, y_test = train_test_split(X, y, **split_options) |
391 | 405 |
409 | 423 |
410 best_estimator_ = getattr(searcher, "best_estimator_") | 424 best_estimator_ = getattr(searcher, "best_estimator_") |
411 | 425 |
412 # TODO Solve deep learning models in pipeline | 426 # TODO Solve deep learning models in pipeline |
413 if best_estimator_.__class__.__name__ == "KerasGBatchClassifier": | 427 if best_estimator_.__class__.__name__ == "KerasGBatchClassifier": |
414 test_score = best_estimator_.evaluate(X_test, scorer=scorer_, is_multimetric=is_multimetric) | 428 test_score = best_estimator_.evaluate( |
415 else: | 429 X_test, scorer=scorer_, is_multimetric=is_multimetric |
416 test_score = _score(best_estimator_, X_test, y_test, scorer_, is_multimetric=is_multimetric) | 430 ) |
431 else: | |
432 test_score = _score( | |
433 best_estimator_, X_test, y_test, scorer_, is_multimetric=is_multimetric | |
434 ) | |
417 | 435 |
418 if not is_multimetric: | 436 if not is_multimetric: |
419 test_score = {primary_scoring: test_score} | 437 test_score = {primary_scoring: test_score} |
420 for key, value in test_score.items(): | 438 for key, value in test_score.items(): |
421 test_score[key] = [value] | 439 test_score[key] = [value] |
485 | 503 |
486 with open(inputs, "r") as param_handler: | 504 with open(inputs, "r") as param_handler: |
487 params = json.load(param_handler) | 505 params = json.load(param_handler) |
488 | 506 |
489 # Override the refit parameter | 507 # Override the refit parameter |
490 params["search_schemes"]["options"]["refit"] = True if params["save"] != "nope" else False | 508 params["search_schemes"]["options"]["refit"] = ( |
509 True if params["save"] != "nope" else False | |
510 ) | |
491 | 511 |
492 with open(infile_estimator, "rb") as estimator_handler: | 512 with open(infile_estimator, "rb") as estimator_handler: |
493 estimator = load_model(estimator_handler) | 513 estimator = load_model(estimator_handler) |
494 | 514 |
495 optimizer = params["search_schemes"]["selected_search_scheme"] | 515 optimizer = params["search_schemes"]["selected_search_scheme"] |
497 | 517 |
498 # handle gridsearchcv options | 518 # handle gridsearchcv options |
499 options = params["search_schemes"]["options"] | 519 options = params["search_schemes"]["options"] |
500 | 520 |
501 if groups: | 521 if groups: |
502 header = "infer" if (options["cv_selector"]["groups_selector"]["header_g"]) else None | 522 header = ( |
503 column_option = options["cv_selector"]["groups_selector"]["column_selector_options_g"][ | 523 "infer" if (options["cv_selector"]["groups_selector"]["header_g"]) else None |
504 "selected_column_selector_option_g" | 524 ) |
505 ] | 525 column_option = options["cv_selector"]["groups_selector"][ |
526 "column_selector_options_g" | |
527 ]["selected_column_selector_option_g"] | |
506 if column_option in [ | 528 if column_option in [ |
507 "by_index_number", | 529 "by_index_number", |
508 "all_but_by_index_number", | 530 "all_but_by_index_number", |
509 "by_header_name", | 531 "by_header_name", |
510 "all_but_by_header_name", | 532 "all_but_by_header_name", |
511 ]: | 533 ]: |
512 c = options["cv_selector"]["groups_selector"]["column_selector_options_g"]["col_g"] | 534 c = options["cv_selector"]["groups_selector"]["column_selector_options_g"][ |
535 "col_g" | |
536 ] | |
513 else: | 537 else: |
514 c = None | 538 c = None |
515 | 539 |
516 df_key = groups + repr(header) | 540 df_key = groups + repr(header) |
517 | 541 |
535 # get_scoring() expects secondary_scoring to be a comma separated string (not a list) | 559 # get_scoring() expects secondary_scoring to be a comma separated string (not a list) |
536 # Check if secondary_scoring is specified | 560 # Check if secondary_scoring is specified |
537 secondary_scoring = options["scoring"].get("secondary_scoring", None) | 561 secondary_scoring = options["scoring"].get("secondary_scoring", None) |
538 if secondary_scoring is not None: | 562 if secondary_scoring is not None: |
539 # If secondary_scoring is specified, convert the list into comman separated string | 563 # If secondary_scoring is specified, convert the list into comman separated string |
540 options["scoring"]["secondary_scoring"] = ",".join(options["scoring"]["secondary_scoring"]) | 564 options["scoring"]["secondary_scoring"] = ",".join( |
565 options["scoring"]["secondary_scoring"] | |
566 ) | |
541 options["scoring"] = get_scoring(options["scoring"]) | 567 options["scoring"] = get_scoring(options["scoring"]) |
542 if options["error_score"]: | 568 if options["error_score"]: |
543 options["error_score"] = "raise" | 569 options["error_score"] = "raise" |
544 else: | 570 else: |
545 options["error_score"] = np.NaN | 571 options["error_score"] = np.nan |
546 if options["refit"] and isinstance(options["scoring"], dict): | 572 if options["refit"] and isinstance(options["scoring"], dict): |
547 options["refit"] = primary_scoring | 573 options["refit"] = primary_scoring |
548 if "pre_dispatch" in options and options["pre_dispatch"] == "": | 574 if "pre_dispatch" in options and options["pre_dispatch"] == "": |
549 options["pre_dispatch"] = None | 575 options["pre_dispatch"] = None |
550 | 576 |
586 | 612 |
587 if split_mode == "nested_cv": | 613 if split_mode == "nested_cv": |
588 # make sure refit is choosen | 614 # make sure refit is choosen |
589 # this could be True for sklearn models, but not the case for | 615 # this could be True for sklearn models, but not the case for |
590 # deep learning models | 616 # deep learning models |
591 if not options["refit"] and not all(hasattr(estimator, attr) for attr in ("config", "model_type")): | 617 if not options["refit"] and not all( |
618 hasattr(estimator, attr) for attr in ("config", "model_type") | |
619 ): | |
592 warnings.warn("Refit is change to `True` for nested validation!") | 620 warnings.warn("Refit is change to `True` for nested validation!") |
593 setattr(searcher, "refit", True) | 621 setattr(searcher, "refit", True) |
594 | 622 |
595 outer_cv, _ = get_cv(params["outer_split"]["cv_selector"]) | 623 outer_cv, _ = get_cv(params["outer_split"]["cv_selector"]) |
596 # nested CV, outer cv using cross_validate | 624 # nested CV, outer cv using cross_validate |
685 for warning in w: | 713 for warning in w: |
686 print(repr(warning.message)) | 714 print(repr(warning.message)) |
687 | 715 |
688 cv_results = pd.DataFrame(searcher.cv_results_) | 716 cv_results = pd.DataFrame(searcher.cv_results_) |
689 cv_results = cv_results[sorted(cv_results.columns)] | 717 cv_results = cv_results[sorted(cv_results.columns)] |
690 cv_results.to_csv(path_or_buf=outfile_result, sep="\t", header=True, index=False) | 718 cv_results.to_csv( |
719 path_or_buf=outfile_result, sep="\t", header=True, index=False | |
720 ) | |
691 | 721 |
692 memory.clear(warn=False) | 722 memory.clear(warn=False) |
693 | 723 |
694 # output best estimator, and weights if applicable | 724 # output best estimator, and weights if applicable |
695 if outfile_object: | 725 if outfile_object: |