Mercurial > repos > bgruening > sklearn_nn_classifier
diff fitted_model_eval.py @ 27:22f0b9db4ea1 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author | bgruening |
---|---|
date | Wed, 09 Aug 2023 12:57:05 +0000 |
parents | 823ecc0bce45 |
children |
line wrap: on
line diff
--- a/fitted_model_eval.py Thu Aug 11 09:54:23 2022 +0000 +++ b/fitted_model_eval.py Wed Aug 09 12:57:05 2023 +0000 @@ -3,11 +3,11 @@ import warnings import pandas as pd -from galaxy_ml.utils import get_scoring, load_model, read_columns +from galaxy_ml.model_persist import load_model_from_h5 +from galaxy_ml.utils import clean_params, get_scoring, read_columns from scipy.io import mmread -from sklearn.metrics.scorer import _check_multimetric_scoring +from sklearn.metrics._scorer import _check_multimetric_scoring from sklearn.model_selection._validation import _score -from sklearn.pipeline import Pipeline def _get_X_y(params, infile1, infile2): @@ -75,7 +75,12 @@ loaded_df[df_key] = infile2 y = read_columns( - infile2, c=c, c_option=column_option, sep="\t", header=header, parse_dates=True + infile2, + c=c, + c_option=column_option, + sep="\t", + header=header, + parse_dates=True, ) if len(y.shape) == 2 and y.shape[1] == 1: y = y.ravel() @@ -83,14 +88,7 @@ return X, y -def main( - inputs, - infile_estimator, - outfile_eval, - infile_weights=None, - infile1=None, - infile2=None, -): +def main(inputs, infile_estimator, outfile_eval, infile1=None, infile2=None): """ Parameter --------- @@ -103,9 +101,6 @@ outfile_eval : str File path to save the evalulation results, tabular - infile_weights : str - File path to weights input - infile1 : str File path to dataset containing features @@ -120,40 +115,20 @@ X_test, y_test = _get_X_y(params, infile1, infile2) # load model - with open(infile_estimator, "rb") as est_handler: - estimator = load_model(est_handler) - - main_est = estimator - if isinstance(estimator, Pipeline): - main_est = estimator.steps[-1][-1] - if hasattr(main_est, "config") and hasattr(main_est, "load_weights"): - if not infile_weights or infile_weights == "None": - raise ValueError( - "The selected model skeleton asks for weights, " - "but no dataset for weights was provided!" - ) - main_est.load_weights(infile_weights) + estimator = load_model_from_h5(infile_estimator) + estimator = clean_params(estimator) # handle scorer, convert to scorer dict - # Check if scoring is specified scoring = params["scoring"] - if scoring is not None: - # get_scoring() expects secondary_scoring to be a comma separated string (not a list) - # Check if secondary_scoring is specified - secondary_scoring = scoring.get("secondary_scoring", None) - if secondary_scoring is not None: - # If secondary_scoring is specified, convert the list into comman separated string - scoring["secondary_scoring"] = ",".join(scoring["secondary_scoring"]) - scorer = get_scoring(scoring) - scorer, _ = _check_multimetric_scoring(estimator, scoring=scorer) + if not isinstance(scorer, (dict, list)): + scorer = [scoring["primary_scoring"]] + scorer = _check_multimetric_scoring(estimator, scoring=scorer) if hasattr(estimator, "evaluate"): - scores = estimator.evaluate( - X_test, y_test=y_test, scorer=scorer, is_multimetric=True - ) + scores = estimator.evaluate(X_test, y_test=y_test, scorer=scorer) else: - scores = _score(estimator, X_test, y_test, scorer, is_multimetric=True) + scores = _score(estimator, X_test, y_test, scorer) # handle output for name, score in scores.items(): @@ -167,7 +142,6 @@ aparser = argparse.ArgumentParser() aparser.add_argument("-i", "--inputs", dest="inputs", required=True) aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator") - aparser.add_argument("-w", "--infile_weights", dest="infile_weights") aparser.add_argument("-X", "--infile1", dest="infile1") aparser.add_argument("-y", "--infile2", dest="infile2") aparser.add_argument("-O", "--outfile_eval", dest="outfile_eval") @@ -177,7 +151,6 @@ args.inputs, args.infile_estimator, args.outfile_eval, - infile_weights=args.infile_weights, infile1=args.infile1, infile2=args.infile2, )