comparison stacking_ensembles.py @ 15:6eb4e7fb0f91 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author bgruening
date Wed, 09 Aug 2023 13:23:40 +0000
parents 0c933465d70e
children
comparison
equal deleted inserted replaced
14:8a794e6d3388 15:6eb4e7fb0f91
1 import argparse 1 import argparse
2 import ast 2 import ast
3 import json 3 import json
4 import pickle
5 import sys 4 import sys
6 import warnings 5 import warnings
6 from distutils.version import LooseVersion as Version
7 7
8 import mlxtend.classifier 8 import mlxtend.classifier
9 import mlxtend.regressor 9 import mlxtend.regressor
10 import pandas as pd 10 from galaxy_ml import __version__ as galaxy_ml_version
11 from galaxy_ml.utils import (get_cv, get_estimator, get_search_params, 11 from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5
12 load_model) 12 from galaxy_ml.utils import get_cv, get_estimator
13 13
14 warnings.filterwarnings("ignore") 14 warnings.filterwarnings("ignore")
15 15
16 N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1)) 16 N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1))
17 17
18 18
19 def main(inputs_path, output_obj, base_paths=None, meta_path=None, outfile_params=None): 19 def main(inputs_path, output_obj, base_paths=None, meta_path=None):
20 """ 20 """
21 Parameter 21 Parameter
22 --------- 22 ---------
23 inputs_path : str 23 inputs_path : str
24 File path for Galaxy parameters 24 File path for Galaxy parameters
29 base_paths : str 29 base_paths : str
30 File path or paths concatenated by comma. 30 File path or paths concatenated by comma.
31 31
32 meta_path : str 32 meta_path : str
33 File path 33 File path
34
35 outfile_params : str
36 File path for params output
37 """ 34 """
38 with open(inputs_path, "r") as param_handler: 35 with open(inputs_path, "r") as param_handler:
39 params = json.load(param_handler) 36 params = json.load(param_handler)
40 37
41 estimator_type = params["algo_selection"]["estimator_type"] 38 estimator_type = params["algo_selection"]["estimator_type"]
42 # get base estimators 39 # get base estimators
43 base_estimators = [] 40 base_estimators = []
44 for idx, base_file in enumerate(base_paths.split(",")): 41 for idx, base_file in enumerate(base_paths.split(",")):
45 if base_file and base_file != "None": 42 if base_file and base_file != "None":
46 with open(base_file, "rb") as handler: 43 model = load_model_from_h5(base_file)
47 model = load_model(handler)
48 else: 44 else:
49 estimator_json = params["base_est_builder"][idx]["estimator_selector"] 45 estimator_json = params["base_est_builder"][idx]["estimator_selector"]
50 model = get_estimator(estimator_json) 46 model = get_estimator(estimator_json)
51 47
52 if estimator_type.startswith("sklearn"): 48 if estimator_type.startswith("sklearn"):
57 base_estimators.append(model) 53 base_estimators.append(model)
58 54
59 # get meta estimator, if applicable 55 # get meta estimator, if applicable
60 if estimator_type.startswith("mlxtend"): 56 if estimator_type.startswith("mlxtend"):
61 if meta_path: 57 if meta_path:
62 with open(meta_path, "rb") as f: 58 meta_estimator = load_model_from_h5(meta_path)
63 meta_estimator = load_model(f)
64 else: 59 else:
65 estimator_json = params["algo_selection"]["meta_estimator"][ 60 estimator_json = params["algo_selection"]["meta_estimator"][
66 "estimator_selector" 61 "estimator_selector"
67 ] 62 ]
68 meta_estimator = get_estimator(estimator_json) 63 meta_estimator = get_estimator(estimator_json)
69 64
70 options = params["algo_selection"]["options"] 65 options = params["algo_selection"]["options"]
71 66
72 cv_selector = options.pop("cv_selector", None) 67 cv_selector = options.pop("cv_selector", None)
73 if cv_selector: 68 if cv_selector:
74 splitter, _groups = get_cv(cv_selector) 69 if Version(galaxy_ml_version) < Version("0.8.3"):
70 cv_selector.pop("n_stratification_bins", None)
71 splitter, groups = get_cv(cv_selector)
75 options["cv"] = splitter 72 options["cv"] = splitter
76 # set n_jobs 73 # set n_jobs
77 options["n_jobs"] = N_JOBS 74 options["n_jobs"] = N_JOBS
78 75
79 weights = options.pop("weights", None) 76 weights = options.pop("weights", None)
102 99
103 print(ensemble_estimator) 100 print(ensemble_estimator)
104 for base_est in base_estimators: 101 for base_est in base_estimators:
105 print(base_est) 102 print(base_est)
106 103
107 with open(output_obj, "wb") as out_handler: 104 dump_model_to_h5(ensemble_estimator, output_obj)
108 pickle.dump(ensemble_estimator, out_handler, pickle.HIGHEST_PROTOCOL)
109
110 if params["get_params"] and outfile_params:
111 results = get_search_params(ensemble_estimator)
112 df = pd.DataFrame(results, columns=["", "Parameter", "Value"])
113 df.to_csv(outfile_params, sep="\t", index=False)
114 105
115 106
116 if __name__ == "__main__": 107 if __name__ == "__main__":
117 aparser = argparse.ArgumentParser() 108 aparser = argparse.ArgumentParser()
118 aparser.add_argument("-b", "--bases", dest="bases") 109 aparser.add_argument("-b", "--bases", dest="bases")
119 aparser.add_argument("-m", "--meta", dest="meta") 110 aparser.add_argument("-m", "--meta", dest="meta")
120 aparser.add_argument("-i", "--inputs", dest="inputs") 111 aparser.add_argument("-i", "--inputs", dest="inputs")
121 aparser.add_argument("-o", "--outfile", dest="outfile") 112 aparser.add_argument("-o", "--outfile", dest="outfile")
122 aparser.add_argument("-p", "--outfile_params", dest="outfile_params")
123 args = aparser.parse_args() 113 args = aparser.parse_args()
124 114
125 main( 115 main(args.inputs, args.outfile, base_paths=args.bases, meta_path=args.meta)
126 args.inputs,
127 args.outfile,
128 base_paths=args.bases,
129 meta_path=args.meta,
130 outfile_params=args.outfile_params,
131 )