Mercurial > repos > bgruening > sklearn_discriminant_classifier
diff train_test_eval.py @ 41:d769d83ec796 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author | bgruening |
---|---|
date | Wed, 09 Aug 2023 13:14:12 +0000 |
parents | e76f6dfea5c9 |
children |
line wrap: on
line diff
--- a/train_test_eval.py Thu Aug 11 08:53:29 2022 +0000 +++ b/train_test_eval.py Wed Aug 09 13:14:12 2023 +0000 @@ -1,22 +1,27 @@ import argparse import json import os -import pickle import warnings from itertools import chain import joblib import numpy as np import pandas as pd +from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5 from galaxy_ml.model_validations import train_test_split -from galaxy_ml.utils import (get_module, get_scoring, load_model, - read_columns, SafeEval, try_get_attr) +from galaxy_ml.utils import ( + clean_params, + get_module, + get_scoring, + read_columns, + SafeEval, + try_get_attr +) from scipy.io import mmread from sklearn import pipeline -from sklearn.metrics.scorer import _check_multimetric_scoring from sklearn.model_selection import _search, _validation from sklearn.model_selection._validation import _score -from sklearn.utils import indexable, safe_indexing +from sklearn.utils import _safe_indexing, indexable _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score") setattr(_search, "_fit_and_score", _fit_and_score) @@ -93,7 +98,7 @@ train = index_arr[~np.isin(groups, group_names)] rval = list( chain.from_iterable( - (safe_indexing(a, train), safe_indexing(a, test)) for a in new_arrays + (_safe_indexing(a, train), _safe_indexing(a, test)) for a in new_arrays ) ) else: @@ -164,8 +169,8 @@ params = json.load(param_handler) # load estimator - with open(infile_estimator, "rb") as estimator_handler: - estimator = load_model(estimator_handler) + estimator = load_model_from_h5(infile_estimator) + estimator = clean_params(estimator) # swap hyperparameter swapping = params["experiment_schemes"]["hyperparams_swapping"] @@ -348,7 +353,6 @@ # If secondary_scoring is specified, convert the list into comman separated string scoring["secondary_scoring"] = ",".join(scoring["secondary_scoring"]) scorer = get_scoring(scoring) - scorer, _ = _check_multimetric_scoring(estimator, scoring=scorer) # handle test (first) split test_split_options = params["experiment_schemes"]["test_split"]["split_algos"] @@ -412,7 +416,7 @@ X_test, y_test=y_test, scorer=scorer, is_multimetric=True ) else: - scores = _score(estimator, X_test, y_test, scorer, is_multimetric=True) + scores = _score(estimator, X_test, y_test, scorer) # handle output for name, score in scores.items(): scores[name] = [score] @@ -441,8 +445,7 @@ if getattr(main_est, "data_generator_", None): del main_est.data_generator_ - with open(outfile_object, "wb") as output_handler: - pickle.dump(estimator, output_handler, pickle.HIGHEST_PROTOCOL) + dump_model_to_h5(estimator, outfile_object) if __name__ == "__main__":