Mercurial > repos > bgruening > sklearn_mlxtend_association_rules
comparison stacking_ensembles.py @ 0:af2624d5ab32 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ea12f973df4b97a2691d9e4ce6bf6fae59d57717"
author | bgruening |
---|---|
date | Sat, 01 May 2021 01:24:32 +0000 |
parents | |
children | 9349ed2749c6 |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:af2624d5ab32 |
---|---|
import argparse
import ast
import importlib
import json
import os
import pickle
import sys
import warnings

import mlxtend.classifier
import mlxtend.regressor
import pandas as pd
from galaxy_ml.utils import (get_cv, get_estimator, get_search_params,
                             load_model)
13 | |
# Silence every warning process-wide so the tool's stdout/stderr only carry
# the estimator summaries printed by main().
warnings.filterwarnings("ignore")

# Degree of parallelism handed to estimators via ``n_jobs``; taken from the
# GALAXY_SLOTS environment variable, defaulting to a single job.
N_JOBS = int(os.environ.get("GALAXY_SLOTS", 1))
17 | |
18 | |
def _is_missing_path(path):
    """Return True when *path* denotes "no file supplied" (None, "" or "None")."""
    return not path or path == "None"


def _load_base_estimators(params, base_paths, estimator_type):
    """Load or build the base estimators, named for sklearn ensembles.

    Each comma-separated entry of *base_paths* is either a model file,
    deserialized with ``load_model``, or a missing entry ("" / "None"). A
    missing entry is built from ``params["base_est_builder"][idx]``. When
    *base_paths* is None, every base estimator is built from *params*.
    """
    if base_paths is None:
        # No paths at all: build one estimator per builder entry.
        base_files = [None] * len(params.get("base_est_builder", []))
    else:
        base_files = base_paths.split(",")

    base_estimators = []
    for idx, base_file in enumerate(base_files):
        if _is_missing_path(base_file):
            estimator_json = params["base_est_builder"][idx]["estimator_selector"]
            model = get_estimator(estimator_json)
        else:
            with open(base_file, "rb") as handler:
                model = load_model(handler)

        if estimator_type.startswith("sklearn"):
            # sklearn ensembles take (name, estimator) pairs.
            name = "base_%d_%s" % (idx, model.__class__.__name__.lower())
            base_estimators.append((name, model))
        else:
            base_estimators.append(model)
    return base_estimators


def _load_meta_estimator(params, meta_path):
    """Load the meta estimator from *meta_path*, or build it from *params*."""
    if _is_missing_path(meta_path):
        estimator_json = params["algo_selection"]["meta_estimator"][
            "estimator_selector"
        ]
        return get_estimator(estimator_json)
    with open(meta_path, "rb") as handler:
        return load_model(handler)


def _prepare_options(options):
    """Turn Galaxy option values into constructor kwargs (mutates *options*)."""
    cv_selector = options.pop("cv_selector", None)
    if cv_selector:
        splitter, _groups = get_cv(cv_selector)
        options["cv"] = splitter
        # set n_jobs together with the CV splitter; sklearn ensembles are
        # (also) given n_jobs in _build_ensemble.
        options["n_jobs"] = N_JOBS

    weights = options.pop("weights", None)
    if weights:
        # Weights arrive as a string literal such as "[1, 2]".
        weights = ast.literal_eval(weights)
        if weights:
            options["weights"] = weights
    return options


def _resolve_estimator_class(estimator_type):
    """Split "<module>_<Class>" and return ``(module, class)``.

    Imports the module explicitly instead of relying on it already being
    present in ``sys.modules``.
    """
    module_name, _, class_name = estimator_type.rpartition("_")
    if not module_name or not class_name:
        raise ValueError("Unrecognized estimator_type: %r" % estimator_type)
    mod = importlib.import_module(module_name)
    return mod, getattr(mod, class_name)


def _build_ensemble(estimator_type, base_estimators, meta_estimator, options):
    """Instantiate the ensemble class named by *estimator_type*."""
    mod, klass = _resolve_estimator_class(estimator_type)

    if estimator_type.startswith("sklearn"):
        options["n_jobs"] = N_JOBS
        return klass(base_estimators, **options)

    if mod is mlxtend.classifier:
        return klass(
            classifiers=base_estimators, meta_classifier=meta_estimator, **options
        )

    return klass(
        regressors=base_estimators, meta_regressor=meta_estimator, **options
    )


def main(inputs_path, output_obj, base_paths=None, meta_path=None, outfile_params=None):
    """Assemble an ensemble estimator from Galaxy parameters and pickle it.

    Parameters
    ----------
    inputs_path : str
        File path for the Galaxy parameters (JSON).

    output_obj : str
        File path the pickled ensemble estimator is written to.

    base_paths : str or None, default None
        Base-estimator file path(s), concatenated by comma. An empty entry
        or the literal string ``"None"`` means the estimator is built from
        ``params["base_est_builder"][idx]``. When None, all base estimators
        are built from ``params["base_est_builder"]``.

    meta_path : str or None, default None
        File path of a saved meta estimator (used only for ``mlxtend``
        ensembles). When missing (None, "" or "None"), the meta estimator is
        built from ``params["algo_selection"]["meta_estimator"]``.

    outfile_params : str or None, default None
        File path for the tab-separated parameter table. It is written only
        when ``params["get_params"]`` is truthy.
    """
    with open(inputs_path, "r") as param_handler:
        params = json.load(param_handler)

    estimator_type = params["algo_selection"]["estimator_type"]
    base_estimators = _load_base_estimators(params, base_paths, estimator_type)

    # Only mlxtend stacking ensembles take an explicit meta estimator.
    meta_estimator = None
    if estimator_type.startswith("mlxtend"):
        meta_estimator = _load_meta_estimator(params, meta_path)

    options = _prepare_options(params["algo_selection"]["options"])
    ensemble_estimator = _build_ensemble(
        estimator_type, base_estimators, meta_estimator, options
    )

    print(ensemble_estimator)
    for base_est in base_estimators:
        print(base_est)

    with open(output_obj, "wb") as out_handler:
        pickle.dump(ensemble_estimator, out_handler, pickle.HIGHEST_PROTOCOL)

    if params.get("get_params") and outfile_params:
        results = get_search_params(ensemble_estimator)
        df = pd.DataFrame(results, columns=["", "Parameter", "Value"])
        df.to_csv(outfile_params, sep="\t", index=False)
114 | |
115 | |
if __name__ == "__main__":
    aparser = argparse.ArgumentParser(
        description="Build an ensemble estimator from Galaxy tool parameters."
    )
    aparser.add_argument(
        "-b", "--bases", dest="bases",
        help="Comma-separated base estimator file paths ('None' = build from params).",
    )
    aparser.add_argument(
        "-m", "--meta", dest="meta",
        help="Meta estimator file path (mlxtend ensembles only).",
    )
    # -i and -o are mandatory: main() opens both unconditionally, so a
    # missing value should be a usage error rather than open(None).
    aparser.add_argument(
        "-i", "--inputs", dest="inputs", required=True,
        help="JSON file with the Galaxy tool parameters.",
    )
    aparser.add_argument(
        "-o", "--outfile", dest="outfile", required=True,
        help="Output path for the pickled ensemble estimator.",
    )
    aparser.add_argument(
        "-p", "--outfile_params", dest="outfile_params",
        help="Optional output path for the tab-separated parameter table.",
    )
    args = aparser.parse_args()

    main(
        args.inputs,
        args.outfile,
        base_paths=args.bases,
        meta_path=args.meta,
        outfile_params=args.outfile_params,
    )