comparison stacking_ensembles.py @ 0:f96efab83b65 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ba6a47bdf76bbf4cb276206ac1a8cbf61332fd16"
author bgruening
date Fri, 13 Sep 2019 12:23:39 -0400
parents
children 09efff9a5765
comparison
equal deleted inserted replaced
-1:000000000000 0:f96efab83b65
1 import argparse
2 import ast
3 import json
4 import mlxtend.regressor
5 import mlxtend.classifier
6 import pandas as pd
7 import pickle
8 import sklearn
9 import sys
10 import warnings
11 from sklearn import ensemble
12
13 from galaxy_ml.utils import (load_model, get_cv, get_estimator,
14 get_search_params)
15
16
# Silence every Python warning for the whole process.
warnings.simplefilter('ignore')

# Parallelism for estimators, read from Galaxy's GALAXY_SLOTS environment
# variable; falls back to 1 when the variable is not set.
N_JOBS = int(__import__('os').getenv('GALAXY_SLOTS', 1))
20
21
def main(inputs_path, output_obj, base_paths=None, meta_path=None,
         outfile_params=None):
    """
    Build an ensemble estimator from Galaxy parameters and pickle it.

    Parameters
    ----------
    inputs_path : str
        File path for Galaxy parameters (JSON).

    output_obj : str
        File path for the pickled ensemble estimator output.

    base_paths : str or None
        File path or paths concatenated by comma, one entry per base
        estimator. An empty entry or the literal string ``'None'`` means
        that base estimator is built from
        ``params['base_est_builder'][idx]['estimator_selector']`` instead of
        being loaded from file. When ``base_paths`` is None, every entry of
        ``params['base_est_builder']`` is built from the parameters.

    meta_path : str or None
        File path of a saved meta estimator; only used when the estimator
        type starts with ``'mlxtend'``. When not given, the meta estimator
        is built from
        ``params['algo_selection']['meta_estimator']['estimator_selector']``.

    outfile_params : str or None
        File path for the tab-separated params output; written only when
        ``params['get_params']`` is truthy.
    """
    with open(inputs_path, 'r') as param_handler:
        params = json.load(param_handler)

    estimator_type = params['algo_selection']['estimator_type']

    # get base estimators
    if base_paths is None:
        # No files supplied at all: build every base estimator from JSON.
        base_files = ['None'] * len(params.get('base_est_builder', []))
    else:
        base_files = base_paths.split(',')

    base_estimators = []
    for idx, base_file in enumerate(base_files):
        if base_file and base_file != 'None':
            with open(base_file, 'rb') as handler:
                model = load_model(handler)
        else:
            estimator_json = (params['base_est_builder'][idx]
                              ['estimator_selector'])
            model = get_estimator(estimator_json)

        if estimator_type.startswith('sklearn'):
            # sklearn ensembles receive (name, estimator) pairs; the index
            # prefix keeps names unique when the same class appears twice.
            named = 'base_%d_%s' % (idx, model.__class__.__name__.lower())
            base_estimators.append((named, model))
        else:
            base_estimators.append(model)

    # get meta estimator, if applicable
    meta_estimator = None
    if estimator_type.startswith('mlxtend'):
        if meta_path:
            with open(meta_path, 'rb') as f:
                meta_estimator = load_model(f)
        else:
            estimator_json = (params['algo_selection']
                              ['meta_estimator']['estimator_selector'])
            meta_estimator = get_estimator(estimator_json)

    options = params['algo_selection']['options']

    cv_selector = options.pop('cv_selector', None)
    if cv_selector:
        splitter, _groups = get_cv(cv_selector)
        options['cv'] = splitter
        # set n_jobs (the sklearn branch below sets it independently)
        options['n_jobs'] = N_JOBS

    weights = options.pop('weights', None)
    if weights:
        weights = ast.literal_eval(weights)
        # The string may evaluate to an empty/None value (e.g. '[]');
        # only forward weights that actually carry values.
        if weights:
            options['weights'] = weights

    # estimator_type has the form '<module name>_<class name>', where the
    # module must already be imported (present in sys.modules).
    mod_name, _, klass_name = estimator_type.partition('_')
    mod = sys.modules[mod_name]
    klass = getattr(mod, klass_name)

    if estimator_type.startswith('sklearn'):
        options['n_jobs'] = N_JOBS
        ensemble_estimator = klass(base_estimators, **options)

    elif mod == mlxtend.classifier:
        ensemble_estimator = klass(
            classifiers=base_estimators,
            meta_classifier=meta_estimator,
            **options)

    else:
        ensemble_estimator = klass(
            regressors=base_estimators,
            meta_regressor=meta_estimator,
            **options)

    print(ensemble_estimator)
    for base_est in base_estimators:
        print(base_est)

    with open(output_obj, 'wb') as out_handler:
        pickle.dump(ensemble_estimator, out_handler, pickle.HIGHEST_PROTOCOL)

    if params.get('get_params') and outfile_params:
        results = get_search_params(ensemble_estimator)
        df = pd.DataFrame(results, columns=['', 'Parameter', 'Value'])
        df.to_csv(outfile_params, sep='\t', index=False)
108 for base_est in base_estimators:
109 print(base_est)
110
111 with open(output_obj, 'wb') as out_handler:
112 pickle.dump(ensemble_estimator, out_handler, pickle.HIGHEST_PROTOCOL)
113
114 if params['get_params'] and outfile_params:
115 results = get_search_params(ensemble_estimator)
116 df = pd.DataFrame(results, columns=['', 'Parameter', 'Value'])
117 df.to_csv(outfile_params, sep='\t', index=False)
118
119
if __name__ == '__main__':
    aparser = argparse.ArgumentParser(
        description='Build and pickle an ensemble estimator from Galaxy '
                    'tool parameters.')
    aparser.add_argument("-b", "--bases", dest="bases",
                         help="comma-separated base estimator file paths; "
                              "'None' entries are built from --inputs")
    aparser.add_argument("-m", "--meta", dest="meta",
                         help="meta estimator file path (mlxtend only)")
    # main() opens both of these paths unconditionally, so fail early with
    # a clear argparse error instead of a TypeError deep inside main().
    aparser.add_argument("-i", "--inputs", dest="inputs", required=True,
                         help="JSON file with Galaxy tool parameters")
    aparser.add_argument("-o", "--outfile", dest="outfile", required=True,
                         help="output path for the pickled ensemble")
    aparser.add_argument("-p", "--outfile_params", dest="outfile_params",
                         help="optional output path for the params table")
    args = aparser.parse_args()

    main(args.inputs, args.outfile, base_paths=args.bases,
         meta_path=args.meta, outfile_params=args.outfile_params)