diff search_model_validation.py @ 15:6eb4e7fb0f91 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author bgruening
date Wed, 09 Aug 2023 13:23:40 +0000
parents 0c933465d70e
children
line wrap: on
line diff
--- a/search_model_validation.py	Thu Aug 11 09:49:51 2022 +0000
+++ b/search_model_validation.py	Wed Aug 09 13:23:40 2023 +0000
@@ -1,35 +1,55 @@
 import argparse
-import collections
 import json
 import os
-import pickle
 import sys
 import warnings
+from distutils.version import LooseVersion as Version
 
 import imblearn
 import joblib
 import numpy as np
 import pandas as pd
 import skrebate
-from galaxy_ml.utils import (clean_params, get_cv,
-                             get_main_estimator, get_module, get_scoring,
-                             load_model, read_columns, SafeEval, try_get_attr)
+from galaxy_ml import __version__ as galaxy_ml_version
+from galaxy_ml.binarize_target import IRAPSClassifier
+from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5
+from galaxy_ml.utils import (
+    clean_params,
+    get_cv,
+    get_main_estimator,
+    get_module,
+    get_scoring,
+    read_columns,
+    SafeEval,
+    try_get_attr
+)
 from scipy.io import mmread
-from sklearn import (cluster, decomposition, feature_selection,
-                     kernel_approximation, model_selection, preprocessing)
+from sklearn import (
+    cluster,
+    decomposition,
+    feature_selection,
+    kernel_approximation,
+    model_selection,
+    preprocessing,
+)
 from sklearn.exceptions import FitFailedWarning
 from sklearn.model_selection import _search, _validation
 from sklearn.model_selection._validation import _score, cross_validate
-
-_fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score")
-setattr(_search, "_fit_and_score", _fit_and_score)
-setattr(_validation, "_fit_and_score", _fit_and_score)
+from sklearn.preprocessing import LabelEncoder
+from skopt import BayesSearchCV
 
 N_JOBS = int(os.environ.get("GALAXY_SLOTS", 1))
 # handle  disk cache
 CACHE_DIR = os.path.join(os.getcwd(), "cached")
-del os
-NON_SEARCHABLE = ("n_jobs", "pre_dispatch", "memory", "_path", "nthread", "callbacks")
+NON_SEARCHABLE = (
+    "n_jobs",
+    "pre_dispatch",
+    "memory",
+    "_path",
+    "_dir",
+    "nthread",
+    "callbacks",
+)
 
 
 def _eval_search_params(params_builder):
@@ -100,33 +120,29 @@
                 imblearn.under_sampling.CondensedNearestNeighbour(
                     random_state=0, n_jobs=N_JOBS
                 ),
-                imblearn.under_sampling.EditedNearestNeighbours(
-                    random_state=0, n_jobs=N_JOBS
-                ),
-                imblearn.under_sampling.RepeatedEditedNearestNeighbours(
-                    random_state=0, n_jobs=N_JOBS
-                ),
-                imblearn.under_sampling.AllKNN(random_state=0, n_jobs=N_JOBS),
+                imblearn.under_sampling.EditedNearestNeighbours(n_jobs=N_JOBS),
+                imblearn.under_sampling.RepeatedEditedNearestNeighbours(n_jobs=N_JOBS),
+                imblearn.under_sampling.AllKNN(n_jobs=N_JOBS),
                 imblearn.under_sampling.InstanceHardnessThreshold(
                     random_state=0, n_jobs=N_JOBS
                 ),
-                imblearn.under_sampling.NearMiss(random_state=0, n_jobs=N_JOBS),
-                imblearn.under_sampling.NeighbourhoodCleaningRule(
-                    random_state=0, n_jobs=N_JOBS
-                ),
+                imblearn.under_sampling.NearMiss(n_jobs=N_JOBS),
+                imblearn.under_sampling.NeighbourhoodCleaningRule(n_jobs=N_JOBS),
                 imblearn.under_sampling.OneSidedSelection(
                     random_state=0, n_jobs=N_JOBS
                 ),
                 imblearn.under_sampling.RandomUnderSampler(random_state=0),
-                imblearn.under_sampling.TomekLinks(random_state=0, n_jobs=N_JOBS),
+                imblearn.under_sampling.TomekLinks(n_jobs=N_JOBS),
                 imblearn.over_sampling.ADASYN(random_state=0, n_jobs=N_JOBS),
+                imblearn.over_sampling.BorderlineSMOTE(random_state=0, n_jobs=N_JOBS),
+                imblearn.over_sampling.KMeansSMOTE(random_state=0, n_jobs=N_JOBS),
                 imblearn.over_sampling.RandomOverSampler(random_state=0),
                 imblearn.over_sampling.SMOTE(random_state=0, n_jobs=N_JOBS),
-                imblearn.over_sampling.SVMSMOTE(random_state=0, n_jobs=N_JOBS),
-                imblearn.over_sampling.BorderlineSMOTE(random_state=0, n_jobs=N_JOBS),
+                imblearn.over_sampling.SMOTEN(random_state=0, n_jobs=N_JOBS),
                 imblearn.over_sampling.SMOTENC(
                     categorical_features=[], random_state=0, n_jobs=N_JOBS
                 ),
+                imblearn.over_sampling.SVMSMOTE(random_state=0, n_jobs=N_JOBS),
                 imblearn.combine.SMOTEENN(random_state=0),
                 imblearn.combine.SMOTETomek(random_state=0),
             )
@@ -288,7 +304,12 @@
         loaded_df[df_key] = infile2
 
     y = read_columns(
-        infile2, c=c, c_option=column_option, sep="\t", header=header, parse_dates=True
+        infile2,
+        c=c,
+        c_option=column_option,
+        sep="\t",
+        header=header,
+        parse_dates=True,
     )
     if len(y.shape) == 2 and y.shape[1] == 1:
         y = y.ravel()
@@ -416,24 +437,19 @@
                 print(repr(warning.message))
 
     scorer_ = searcher.scorer_
-    if isinstance(scorer_, collections.Mapping):
-        is_multimetric = True
-    else:
-        is_multimetric = False
 
     best_estimator_ = getattr(searcher, "best_estimator_")
 
     # TODO Solve deep learning models in pipeline
     if best_estimator_.__class__.__name__ == "KerasGBatchClassifier":
         test_score = best_estimator_.evaluate(
-            X_test, scorer=scorer_, is_multimetric=is_multimetric
+            X_test,
+            scorer=scorer_,
         )
     else:
-        test_score = _score(
-            best_estimator_, X_test, y_test, scorer_, is_multimetric=is_multimetric
-        )
+        test_score = _score(best_estimator_, X_test, y_test, scorer_)
 
-    if not is_multimetric:
+    if not isinstance(scorer_, dict):
         test_score = {primary_scoring: test_score}
     for key, value in test_score.items():
         test_score[key] = [value]
@@ -443,6 +459,34 @@
     return searcher
 
 
+def _set_memory(estimator, memory):
+    """set memeory cache
+
+    Parameters
+    ----------
+    estimator : python object
+    memory : joblib.Memory object
+
+    Returns
+    -------
+    estimator : estimator object after setting new attributes
+    """
+    if isinstance(estimator, IRAPSClassifier):
+        estimator.set_params(memory=memory)
+        return estimator
+
+    estimator_params = estimator.get_params()
+
+    new_params = {}
+    for k in estimator_params.keys():
+        if k.endswith("irapsclassifier__memory"):
+            new_params[k] = memory
+
+    estimator.set_params(**new_params)
+
+    return estimator
+
+
 def main(
     inputs,
     infile_estimator,
@@ -450,7 +494,6 @@
     infile2,
     outfile_result,
     outfile_object=None,
-    outfile_weights=None,
     groups=None,
     ref_seq=None,
     intervals=None,
@@ -461,10 +504,10 @@
     Parameter
     ---------
     inputs : str
-        File path to galaxy tool parameter
+        File path to galaxy tool parameter.
 
     infile_estimator : str
-        File path to estimator
+        File path to estimator.
 
     infile1 : str
         File path to dataset containing features
@@ -478,9 +521,6 @@
     outfile_object : str, optional
         File path to save searchCV object
 
-    outfile_weights : str, optional
-        File path to save model weights
-
     groups : str
         File path to dataset containing groups labels
 
@@ -505,18 +545,38 @@
         params = json.load(param_handler)
 
     # Override the refit parameter
-    params["search_schemes"]["options"]["refit"] = (
-        True if params["save"] != "nope" else False
+    params["options"]["refit"] = (
+        True
+        if (
+            params["save"] != "nope"
+            or params["outer_split"]["split_mode"] == "nested_cv"
+        )
+        else False
     )
 
-    with open(infile_estimator, "rb") as estimator_handler:
-        estimator = load_model(estimator_handler)
+    estimator = load_model_from_h5(infile_estimator)
+
+    estimator = clean_params(estimator)
+
+    if estimator.__class__.__name__ == "KerasGBatchClassifier":
+        _fit_and_score = try_get_attr(
+            "galaxy_ml.model_validations",
+            "_fit_and_score",
+        )
 
-    optimizer = params["search_schemes"]["selected_search_scheme"]
-    optimizer = getattr(model_selection, optimizer)
+        setattr(_search, "_fit_and_score", _fit_and_score)
+        setattr(_validation, "_fit_and_score", _fit_and_score)
+
+    search_algos_and_options = params["search_algos"]
+    optimizer = search_algos_and_options.pop("selected_search_algo")
+    if optimizer == "skopt.BayesSearchCV":
+        optimizer = BayesSearchCV
+    else:
+        optimizer = getattr(model_selection, optimizer)
 
     # handle gridsearchcv options
-    options = params["search_schemes"]["options"]
+    options = params["options"]
+    options.update(search_algos_and_options)
 
     if groups:
         header = (
@@ -553,38 +613,36 @@
         groups = groups.ravel()
         options["cv_selector"]["groups_selector"] = groups
 
-    splitter, groups = get_cv(options.pop("cv_selector"))
+    cv_selector = options.pop("cv_selector")
+    if Version(galaxy_ml_version) < Version("0.8.3"):
+        cv_selector.pop("n_stratification_bins", None)
+    splitter, groups = get_cv(cv_selector)
     options["cv"] = splitter
     primary_scoring = options["scoring"]["primary_scoring"]
-    # get_scoring() expects secondary_scoring to be a comma separated string (not a list)
-    # Check if secondary_scoring is specified
-    secondary_scoring = options["scoring"].get("secondary_scoring", None)
-    if secondary_scoring is not None:
-        # If secondary_scoring is specified, convert the list into comman separated string
-        options["scoring"]["secondary_scoring"] = ",".join(
-            options["scoring"]["secondary_scoring"]
+    options["scoring"] = get_scoring(options["scoring"])
+    # TODO make BayesSearchCV support multiple scoring
+    if optimizer == "skopt.BayesSearchCV" and isinstance(options["scoring"], dict):
+        options["scoring"] = options["scoring"][primary_scoring]
+        warnings.warn(
+            "BayesSearchCV doesn't support multiple "
+            "scorings! Primary scoring is used."
         )
-    options["scoring"] = get_scoring(options["scoring"])
     if options["error_score"]:
         options["error_score"] = "raise"
     else:
-        options["error_score"] = np.nan
+        options["error_score"] = np.NaN
     if options["refit"] and isinstance(options["scoring"], dict):
         options["refit"] = primary_scoring
     if "pre_dispatch" in options and options["pre_dispatch"] == "":
         options["pre_dispatch"] = None
 
-    params_builder = params["search_schemes"]["search_params_builder"]
+    params_builder = params["search_params_builder"]
     param_grid = _eval_search_params(params_builder)
 
-    estimator = clean_params(estimator)
-
     # save the SearchCV object without fit
     if params["save"] == "save_no_fit":
         searcher = optimizer(estimator, param_grid, **options)
-        print(searcher)
-        with open(outfile_object, "wb") as output_handler:
-            pickle.dump(searcher, output_handler, pickle.HIGHEST_PROTOCOL)
+        dump_model_to_h5(searcher, outfile_object)
         return 0
 
     # read inputs and loads new attributes, like paths
@@ -600,37 +658,36 @@
         fasta_path=fasta_path,
     )
 
+    label_encoder = LabelEncoder()
+    if get_main_estimator(estimator).__class__.__name__ == "XGBClassifier":
+        y = label_encoder.fit_transform(y)
+
     # cache iraps_core fits could increase search speed significantly
     memory = joblib.Memory(location=CACHE_DIR, verbose=0)
-    main_est = get_main_estimator(estimator)
-    if main_est.__class__.__name__ == "IRAPSClassifier":
-        main_est.set_params(memory=memory)
+    estimator = _set_memory(estimator, memory)
 
     searcher = optimizer(estimator, param_grid, **options)
 
     split_mode = params["outer_split"].pop("split_mode")
 
+    # Nested CV
     if split_mode == "nested_cv":
-        # make sure refit is choosen
-        # this could be True for sklearn models, but not the case for
-        # deep learning models
-        if not options["refit"] and not all(
-            hasattr(estimator, attr) for attr in ("config", "model_type")
-        ):
-            warnings.warn("Refit is change to `True` for nested validation!")
-            setattr(searcher, "refit", True)
-
-        outer_cv, _ = get_cv(params["outer_split"]["cv_selector"])
+        cv_selector = params["outer_split"]["cv_selector"]
+        if Version(galaxy_ml_version) < Version("0.8.3"):
+            cv_selector.pop("n_stratification_bins", None)
+        outer_cv, _ = get_cv(cv_selector)
         # nested CV, outer cv using cross_validate
         if options["error_score"] == "raise":
             rval = cross_validate(
                 searcher,
                 X,
                 y,
+                groups=groups,
                 scoring=options["scoring"],
                 cv=outer_cv,
                 n_jobs=N_JOBS,
                 verbose=options["verbose"],
+                fit_params={"groups": groups},
                 return_estimator=(params["save"] == "save_estimator"),
                 error_score=options["error_score"],
                 return_train_score=True,
@@ -643,10 +700,12 @@
                         searcher,
                         X,
                         y,
+                        groups=groups,
                         scoring=options["scoring"],
                         cv=outer_cv,
                         n_jobs=N_JOBS,
                         verbose=options["verbose"],
+                        fit_params={"groups": groups},
                         return_estimator=(params["save"] == "save_estimator"),
                         error_score=options["error_score"],
                         return_train_score=True,
@@ -676,8 +735,6 @@
                     cv_results_.to_csv(target_path, sep="\t", header=True, index=False)
             except Exception as e:
                 print(e)
-            finally:
-                del os
 
         keys = list(rval.keys())
         for k in keys:
@@ -689,6 +746,9 @@
         rval = pd.DataFrame(rval)
         rval = rval[sorted(rval.columns)]
         rval.to_csv(path_or_buf=outfile_result, sep="\t", header=True, index=False)
+
+        return 0
+
         # deprecate train test split mode
         """searcher = _do_train_test_split_val(
             searcher, X, y, params,
@@ -696,7 +756,6 @@
             error_score=options['error_score'],
             groups=groups,
             outfile=outfile_result)"""
-        return 0
 
     # no outer split
     else:
@@ -732,24 +791,7 @@
             )
             return
 
-        # clean prams
-        best_estimator_ = clean_params(best_estimator_)
-
-        main_est = get_main_estimator(best_estimator_)
-
-        if hasattr(main_est, "model_") and hasattr(main_est, "save_weights"):
-            if outfile_weights:
-                main_est.save_weights(outfile_weights)
-            del main_est.model_
-            del main_est.fit_params
-            del main_est.model_class_
-            del main_est.validation_data
-            if getattr(main_est, "data_generator_", None):
-                del main_est.data_generator_
-
-        with open(outfile_object, "wb") as output_handler:
-            print("Best estimator is saved: %s " % repr(best_estimator_))
-            pickle.dump(best_estimator_, output_handler, pickle.HIGHEST_PROTOCOL)
+        dump_model_to_h5(best_estimator_, outfile_object)
 
 
 if __name__ == "__main__":
@@ -760,7 +802,6 @@
     aparser.add_argument("-y", "--infile2", dest="infile2")
     aparser.add_argument("-O", "--outfile_result", dest="outfile_result")
     aparser.add_argument("-o", "--outfile_object", dest="outfile_object")
-    aparser.add_argument("-w", "--outfile_weights", dest="outfile_weights")
     aparser.add_argument("-g", "--groups", dest="groups")
     aparser.add_argument("-r", "--ref_seq", dest="ref_seq")
     aparser.add_argument("-b", "--intervals", dest="intervals")
@@ -768,17 +809,4 @@
     aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
     args = aparser.parse_args()
 
-    main(
-        args.inputs,
-        args.infile_estimator,
-        args.infile1,
-        args.infile2,
-        args.outfile_result,
-        outfile_object=args.outfile_object,
-        outfile_weights=args.outfile_weights,
-        groups=args.groups,
-        ref_seq=args.ref_seq,
-        intervals=args.intervals,
-        targets=args.targets,
-        fasta_path=args.fasta_path,
-    )
+    main(**vars(args))