Mercurial > repos > bgruening > sklearn_train_test_eval
comparison train_test_eval.py @ 15:2eb5c017958d draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author | bgruening |
---|---|
date | Wed, 09 Aug 2023 13:15:27 +0000 |
parents | caf7d2b71a48 |
children |
comparison
equal
deleted
inserted
replaced
14:4d1637cac794 | 15:2eb5c017958d |
---|---|
1 import argparse | 1 import argparse |
2 import json | 2 import json |
3 import os | 3 import os |
4 import pickle | |
5 import warnings | 4 import warnings |
6 from itertools import chain | 5 from itertools import chain |
7 | 6 |
8 import joblib | 7 import joblib |
9 import numpy as np | 8 import numpy as np |
10 import pandas as pd | 9 import pandas as pd |
| | 10 from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5 |
11 from galaxy_ml.model_validations import train_test_split | 11 from galaxy_ml.model_validations import train_test_split |
12 from galaxy_ml.utils import (get_module, get_scoring, load_model, | 12 from galaxy_ml.utils import ( |
13 read_columns, SafeEval, try_get_attr) | 13 clean_params, |
| | 14 get_module, |
| | 15 get_scoring, |
| | 16 read_columns, |
| | 17 SafeEval, |
| | 18 try_get_attr |
| | 19 ) |
14 from scipy.io import mmread | 20 from scipy.io import mmread |
15 from sklearn import pipeline | 21 from sklearn import pipeline |
16 from sklearn.metrics.scorer import _check_multimetric_scoring | |
17 from sklearn.model_selection import _search, _validation | 22 from sklearn.model_selection import _search, _validation |
18 from sklearn.model_selection._validation import _score | 23 from sklearn.model_selection._validation import _score |
19 from sklearn.utils import indexable, safe_indexing | 24 from sklearn.utils import _safe_indexing, indexable |
20 | 25 |
21 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score") | 26 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score") |
22 setattr(_search, "_fit_and_score", _fit_and_score) | 27 setattr(_search, "_fit_and_score", _fit_and_score) |
23 setattr(_validation, "_fit_and_score", _fit_and_score) | 28 setattr(_validation, "_fit_and_score", _fit_and_score) |
24 | 29 |
91 index_arr = np.arange(n_samples) | 96 index_arr = np.arange(n_samples) |
92 test = index_arr[np.isin(groups, group_names)] | 97 test = index_arr[np.isin(groups, group_names)] |
93 train = index_arr[~np.isin(groups, group_names)] | 98 train = index_arr[~np.isin(groups, group_names)] |
94 rval = list( | 99 rval = list( |
95 chain.from_iterable( | 100 chain.from_iterable( |
96 (safe_indexing(a, train), safe_indexing(a, test)) for a in new_arrays | 101 (_safe_indexing(a, train), _safe_indexing(a, test)) for a in new_arrays |
97 ) | 102 ) |
98 ) | 103 ) |
99 else: | 104 else: |
100 rval = train_test_split(*new_arrays, **kwargs) | 105 rval = train_test_split(*new_arrays, **kwargs) |
101 | 106 |
162 | 167 |
163 with open(inputs, "r") as param_handler: | 168 with open(inputs, "r") as param_handler: |
164 params = json.load(param_handler) | 169 params = json.load(param_handler) |
165 | 170 |
166 # load estimator | 171 # load estimator |
167 with open(infile_estimator, "rb") as estimator_handler: | 172 estimator = load_model_from_h5(infile_estimator) |
168 estimator = load_model(estimator_handler) | 173 estimator = clean_params(estimator) |
169 | 174 |
170 # swap hyperparameter | 175 # swap hyperparameter |
171 swapping = params["experiment_schemes"]["hyperparams_swapping"] | 176 swapping = params["experiment_schemes"]["hyperparams_swapping"] |
172 swap_params = _eval_swap_params(swapping) | 177 swap_params = _eval_swap_params(swapping) |
173 estimator.set_params(**swap_params) | 178 estimator.set_params(**swap_params) |
346 secondary_scoring = scoring.get("secondary_scoring", None) | 351 secondary_scoring = scoring.get("secondary_scoring", None) |
347 if secondary_scoring is not None: | 352 if secondary_scoring is not None: |
348 # If secondary_scoring is specified, convert the list into comman separated string | 353 # If secondary_scoring is specified, convert the list into comman separated string |
349 scoring["secondary_scoring"] = ",".join(scoring["secondary_scoring"]) | 354 scoring["secondary_scoring"] = ",".join(scoring["secondary_scoring"]) |
350 scorer = get_scoring(scoring) | 355 scorer = get_scoring(scoring) |
351 scorer, _ = _check_multimetric_scoring(estimator, scoring=scorer) | |
352 | 356 |
353 # handle test (first) split | 357 # handle test (first) split |
354 test_split_options = params["experiment_schemes"]["test_split"]["split_algos"] | 358 test_split_options = params["experiment_schemes"]["test_split"]["split_algos"] |
355 | 359 |
356 if test_split_options["shuffle"] == "group": | 360 if test_split_options["shuffle"] == "group": |
410 if hasattr(estimator, "evaluate"): | 414 if hasattr(estimator, "evaluate"): |
411 scores = estimator.evaluate( | 415 scores = estimator.evaluate( |
412 X_test, y_test=y_test, scorer=scorer, is_multimetric=True | 416 X_test, y_test=y_test, scorer=scorer, is_multimetric=True |
413 ) | 417 ) |
414 else: | 418 else: |
415 scores = _score(estimator, X_test, y_test, scorer, is_multimetric=True) | 419 scores = _score(estimator, X_test, y_test, scorer) |
416 # handle output | 420 # handle output |
417 for name, score in scores.items(): | 421 for name, score in scores.items(): |
418 scores[name] = [score] | 422 scores[name] = [score] |
419 df = pd.DataFrame(scores) | 423 df = pd.DataFrame(scores) |
420 df = df[sorted(df.columns)] | 424 df = df[sorted(df.columns)] |
439 if getattr(main_est, "validation_data", None): | 443 if getattr(main_est, "validation_data", None): |
440 del main_est.validation_data | 444 del main_est.validation_data |
441 if getattr(main_est, "data_generator_", None): | 445 if getattr(main_est, "data_generator_", None): |
442 del main_est.data_generator_ | 446 del main_est.data_generator_ |
443 | 447 |
444 with open(outfile_object, "wb") as output_handler: | 448 dump_model_to_h5(estimator, outfile_object) |
445 pickle.dump(estimator, output_handler, pickle.HIGHEST_PROTOCOL) | |
446 | 449 |
447 | 450 |
448 if __name__ == "__main__": | 451 if __name__ == "__main__": |
449 aparser = argparse.ArgumentParser() | 452 aparser = argparse.ArgumentParser() |
450 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | 453 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) |