Mercurial > repos > bgruening > sklearn_train_test_split
comparison stacking_ensembles.py @ 6:13b9ac5d277c draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 208a8d348e7c7a182cfbe1b6f17868146428a7e2"
author | bgruening |
---|---|
date | Tue, 13 Apr 2021 22:24:07 +0000 |
parents | 0985b0dd6f1a |
children | 3312fb686ffb |
comparison
equal
deleted
inserted
replaced
5:ce2fd1edbc6e | 6:13b9ac5d277c |
---|---|
import argparse
import ast
import json
import os
import pickle
import sys
import warnings

import mlxtend.classifier
import mlxtend.regressor
import pandas as pd
from galaxy_ml.utils import get_cv, get_estimator, get_search_params, load_model
15 | 12 |
16 | 13 |
# Silence all warnings so Galaxy job logs aren't flooded with
# sklearn/mlxtend deprecation chatter.
warnings.filterwarnings("ignore")

# Worker count for parallel estimators: Galaxy exposes the job's allotted
# slot count via GALAXY_SLOTS; default to 1 outside a Galaxy job.
# (Replaces the original `__import__("os")` dodge with a plain import.)
N_JOBS = int(os.environ.get("GALAXY_SLOTS", 1))
20 | 17 |
21 | 18 |
def main(inputs_path, output_obj, base_paths=None, meta_path=None, outfile_params=None):
    """Build an ensemble (sklearn or mlxtend stacking) estimator and pickle it.

    Parameter
    ---------
    inputs_path : str
        File path for Galaxy parameters (JSON).
    output_obj : str
        File path for the pickled ensemble estimator output.
    base_paths : str, optional
        Comma-separated file paths of pre-trained base estimators.
        An empty or 'None' entry falls back to the estimator described
        in the Galaxy params for that index.
    meta_path : str, optional
        File path of a pre-trained meta estimator (mlxtend only).
    outfile_params : str, optional
        File path for params output (tab-separated table).
    """
    with open(inputs_path, "r") as param_handler:
        params = json.load(param_handler)

    estimator_type = params["algo_selection"]["estimator_type"]

    # Collect base estimators: each comes either from a pickled model file
    # or is constructed from the Galaxy params JSON at the same index.
    # Guard against base_paths=None (the default) which would otherwise
    # raise AttributeError on .split().
    base_estimators = []
    for idx, base_file in enumerate((base_paths or "").split(",")):
        if base_file and base_file != "None":
            with open(base_file, "rb") as handler:
                model = load_model(handler)
        else:
            estimator_json = params["base_est_builder"][idx]["estimator_selector"]
            model = get_estimator(estimator_json)

        if estimator_type.startswith("sklearn"):
            # sklearn ensembles take (name, estimator) pairs; names must be
            # unique, hence the index prefix.
            named = model.__class__.__name__.lower()
            named = "base_%d_%s" % (idx, named)
            base_estimators.append((named, model))
        else:
            # mlxtend takes a plain list of estimators.
            base_estimators.append(model)

    # Meta estimator is only meaningful for mlxtend stacking ensembles.
    if estimator_type.startswith("mlxtend"):
        if meta_path:
            with open(meta_path, "rb") as f:
                meta_estimator = load_model(f)
        else:
            estimator_json = params["algo_selection"]["meta_estimator"]["estimator_selector"]
            meta_estimator = get_estimator(estimator_json)

    options = params["algo_selection"]["options"]

    cv_selector = options.pop("cv_selector", None)
    if cv_selector:
        # get_cv returns (splitter, groups); groups are unused here.
        splitter, _groups = get_cv(cv_selector)
        options["cv"] = splitter
        # set n_jobs
        options["n_jobs"] = N_JOBS

    # 'weights' arrives as a string (e.g. "[1, 2]"); parse it safely and
    # only keep it when non-empty.
    weights = options.pop("weights", None)
    if weights:
        weights = ast.literal_eval(weights)
        if weights:
            options["weights"] = weights

    # estimator_type is e.g. "sklearn_VotingClassifier" or
    # "mlxtend_StackingRegressor": module name, underscore, class name.
    mod_and_name = estimator_type.split("_")
    mod = sys.modules[mod_and_name[0]]
    klass = getattr(mod, mod_and_name[1])

    if estimator_type.startswith("sklearn"):
        options["n_jobs"] = N_JOBS
        ensemble_estimator = klass(base_estimators, **options)

    elif mod == mlxtend.classifier:
        ensemble_estimator = klass(classifiers=base_estimators, meta_classifier=meta_estimator, **options)

    else:
        ensemble_estimator = klass(regressors=base_estimators, meta_regressor=meta_estimator, **options)

    # Echo the constructed ensemble to stdout for the Galaxy job log.
    print(ensemble_estimator)
    for base_est in base_estimators:
        print(base_est)

    with open(output_obj, "wb") as out_handler:
        pickle.dump(ensemble_estimator, out_handler, pickle.HIGHEST_PROTOCOL)

    if params["get_params"] and outfile_params:
        results = get_search_params(ensemble_estimator)
        df = pd.DataFrame(results, columns=["", "Parameter", "Value"])
        df.to_csv(outfile_params, sep="\t", index=False)
120 | 108 |
121 | 109 |
if __name__ == "__main__":
    # CLI entry point: wire the Galaxy tool's command-line flags to main().
    # (dest defaults to the long-flag name, so no explicit dest= is needed.)
    cli = argparse.ArgumentParser()
    for flags in (
        ("-b", "--bases"),
        ("-m", "--meta"),
        ("-i", "--inputs"),
        ("-o", "--outfile"),
        ("-p", "--outfile_params"),
    ):
        cli.add_argument(*flags)
    opts = cli.parse_args()

    main(
        opts.inputs,
        opts.outfile,
        base_paths=opts.bases,
        meta_path=opts.meta,
        outfile_params=opts.outfile_params,
    )