Mercurial > repos > bgruening > sklearn_svm_classifier
diff stacking_ensembles.py @ 25:b878e4cdd63a draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author | bgruening |
---|---|
date | Wed, 09 Aug 2023 12:24:57 +0000 |
parents | 14fa42b095c4 |
children |
line wrap: on
line diff
--- a/stacking_ensembles.py Thu Aug 11 09:40:47 2022 +0000 +++ b/stacking_ensembles.py Wed Aug 09 12:24:57 2023 +0000 @@ -1,22 +1,22 @@ import argparse import ast import json -import pickle import sys import warnings +from distutils.version import LooseVersion as Version import mlxtend.classifier import mlxtend.regressor -import pandas as pd -from galaxy_ml.utils import (get_cv, get_estimator, get_search_params, - load_model) +from galaxy_ml import __version__ as galaxy_ml_version +from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5 +from galaxy_ml.utils import get_cv, get_estimator warnings.filterwarnings("ignore") N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1)) -def main(inputs_path, output_obj, base_paths=None, meta_path=None, outfile_params=None): +def main(inputs_path, output_obj, base_paths=None, meta_path=None): """ Parameter --------- @@ -31,9 +31,6 @@ meta_path : str File path - - outfile_params : str - File path for params output """ with open(inputs_path, "r") as param_handler: params = json.load(param_handler) @@ -43,8 +40,7 @@ base_estimators = [] for idx, base_file in enumerate(base_paths.split(",")): if base_file and base_file != "None": - with open(base_file, "rb") as handler: - model = load_model(handler) + model = load_model_from_h5(base_file) else: estimator_json = params["base_est_builder"][idx]["estimator_selector"] model = get_estimator(estimator_json) @@ -59,8 +55,7 @@ # get meta estimator, if applicable if estimator_type.startswith("mlxtend"): if meta_path: - with open(meta_path, "rb") as f: - meta_estimator = load_model(f) + meta_estimator = load_model_from_h5(meta_path) else: estimator_json = params["algo_selection"]["meta_estimator"][ "estimator_selector" @@ -71,7 +66,9 @@ cv_selector = options.pop("cv_selector", None) if cv_selector: - splitter, _groups = get_cv(cv_selector) + if Version(galaxy_ml_version) < Version("0.8.3"): + cv_selector.pop("n_stratification_bins", None) + splitter, groups = get_cv(cv_selector) options["cv"] = splitter # set n_jobs options["n_jobs"] = N_JOBS @@ -104,13 +101,7 @@ for base_est in base_estimators: print(base_est) - with open(output_obj, "wb") as out_handler: - pickle.dump(ensemble_estimator, out_handler, pickle.HIGHEST_PROTOCOL) - - if params["get_params"] and outfile_params: - results = get_search_params(ensemble_estimator) - df = pd.DataFrame(results, columns=["", "Parameter", "Value"]) - df.to_csv(outfile_params, sep="\t", index=False) + dump_model_to_h5(ensemble_estimator, output_obj) if __name__ == "__main__": @@ -119,13 +110,6 @@ aparser.add_argument("-m", "--meta", dest="meta") aparser.add_argument("-i", "--inputs", dest="inputs") aparser.add_argument("-o", "--outfile", dest="outfile") - aparser.add_argument("-p", "--outfile_params", dest="outfile_params") args = aparser.parse_args() - main( - args.inputs, - args.outfile, - base_paths=args.bases, - meta_path=args.meta, - outfile_params=args.outfile_params, - ) + main(args.inputs, args.outfile, base_paths=args.bases, meta_path=args.meta)