diff train_test_eval.py @ 15:2eb5c017958d draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author bgruening
date Wed, 09 Aug 2023 13:15:27 +0000
parents caf7d2b71a48
children
line wrap: on
line diff
--- a/train_test_eval.py	Thu Aug 11 09:13:40 2022 +0000
+++ b/train_test_eval.py	Wed Aug 09 13:15:27 2023 +0000
@@ -1,22 +1,27 @@
 import argparse
 import json
 import os
-import pickle
 import warnings
 from itertools import chain
 
 import joblib
 import numpy as np
 import pandas as pd
+from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5
 from galaxy_ml.model_validations import train_test_split
-from galaxy_ml.utils import (get_module, get_scoring, load_model,
-                             read_columns, SafeEval, try_get_attr)
+from galaxy_ml.utils import (
+    clean_params,
+    get_module,
+    get_scoring,
+    read_columns,
+    SafeEval,
+    try_get_attr
+)
 from scipy.io import mmread
 from sklearn import pipeline
-from sklearn.metrics.scorer import _check_multimetric_scoring
 from sklearn.model_selection import _search, _validation
 from sklearn.model_selection._validation import _score
-from sklearn.utils import indexable, safe_indexing
+from sklearn.utils import _safe_indexing, indexable
 
 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score")
 setattr(_search, "_fit_and_score", _fit_and_score)
@@ -93,7 +98,7 @@
         train = index_arr[~np.isin(groups, group_names)]
         rval = list(
             chain.from_iterable(
-                (safe_indexing(a, train), safe_indexing(a, test)) for a in new_arrays
+                (_safe_indexing(a, train), _safe_indexing(a, test)) for a in new_arrays
             )
         )
     else:
@@ -164,8 +169,8 @@
         params = json.load(param_handler)
 
     #  load estimator
-    with open(infile_estimator, "rb") as estimator_handler:
-        estimator = load_model(estimator_handler)
+    estimator = load_model_from_h5(infile_estimator)
+    estimator = clean_params(estimator)
 
     # swap hyperparameter
     swapping = params["experiment_schemes"]["hyperparams_swapping"]
@@ -348,7 +353,6 @@
             # If secondary_scoring is specified, convert the list into comman separated string
             scoring["secondary_scoring"] = ",".join(scoring["secondary_scoring"])
     scorer = get_scoring(scoring)
-    scorer, _ = _check_multimetric_scoring(estimator, scoring=scorer)
 
     # handle test (first) split
     test_split_options = params["experiment_schemes"]["test_split"]["split_algos"]
@@ -412,7 +416,7 @@
             X_test, y_test=y_test, scorer=scorer, is_multimetric=True
         )
     else:
-        scores = _score(estimator, X_test, y_test, scorer, is_multimetric=True)
+        scores = _score(estimator, X_test, y_test, scorer)
     # handle output
     for name, score in scores.items():
         scores[name] = [score]
@@ -441,8 +445,7 @@
             if getattr(main_est, "data_generator_", None):
                 del main_est.data_generator_
 
-        with open(outfile_object, "wb") as output_handler:
-            pickle.dump(estimator, output_handler, pickle.HIGHEST_PROTOCOL)
+        dump_model_to_h5(estimator, outfile_object)
 
 
 if __name__ == "__main__":