Mercurial > repos > bgruening > sklearn_svm_classifier
comparison utils.py @ 5:1c5989b930e3 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 2a058459e6daf0486871f93845f00fdb4a4eaca1
author | bgruening |
---|---|
date | Sat, 29 Sep 2018 07:26:04 -0400 |
parents | 41d0edb7d1fc |
children | 372582a7a34d |
comparison
equal
deleted
inserted
replaced
4:41d0edb7d1fc | 5:1c5989b930e3 |
---|---|
1 import sys | 1 import sys |
2 import os | 2 import os |
3 import pandas | 3 import pandas |
4 import re | 4 import re |
5 import cPickle as pickle | 5 import pickle |
6 import warnings | 6 import warnings |
7 import numpy as np | 7 import numpy as np |
8 import xgboost | 8 import xgboost |
9 import scipy | 9 import scipy |
10 import sklearn | 10 import sklearn |
11 import ast | |
12 from asteval import Interpreter, make_symbol_table | 11 from asteval import Interpreter, make_symbol_table |
13 from sklearn import (cluster, decomposition, ensemble, feature_extraction, feature_selection, | 12 from sklearn import (cluster, decomposition, ensemble, feature_extraction, feature_selection, |
14 gaussian_process, kernel_approximation, linear_model, metrics, | 13 gaussian_process, kernel_approximation, metrics, |
15 model_selection, naive_bayes, neighbors, pipeline, preprocessing, | 14 model_selection, naive_bayes, neighbors, pipeline, preprocessing, |
16 svm, linear_model, tree, discriminant_analysis) | 15 svm, linear_model, tree, discriminant_analysis) |
17 | 16 |
18 N_JOBS = int( os.environ.get('GALAXY_SLOTS', 1) ) | 17 N_JOBS = int(os.environ.get('GALAXY_SLOTS', 1)) |
19 | 18 |
20 class SafePickler(object): | 19 |
20 class SafePickler(pickle.Unpickler): | |
21 """ | 21 """ |
22 Used to safely deserialize scikit-learn model objects serialized by cPickle.dump | 22 Used to safely deserialize scikit-learn model objects serialized by cPickle.dump |
23 Usage: | 23 Usage: |
24 eg.: SafePickler.load(pickled_file_object) | 24 eg.: SafePickler.load(pickled_file_object) |
25 """ | 25 """ |
26 @classmethod | |
27 def find_class(self, module, name): | 26 def find_class(self, module, name): |
28 | 27 |
29 bad_names = ('and', 'as', 'assert', 'break', 'class', 'continue', | 28 bad_names = ('and', 'as', 'assert', 'break', 'class', 'continue', |
30 'def', 'del', 'elif', 'else', 'except', 'exec', | 29 'def', 'del', 'elif', 'else', 'except', 'exec', |
31 'finally', 'for', 'from', 'global', 'if', 'import', | 30 'finally', 'for', 'from', 'global', 'if', 'import', |
37 '__dict__', '__class__', '__call__', '__get__', | 36 '__dict__', '__class__', '__call__', '__get__', |
38 '__getattribute__', '__subclasshook__', '__new__', | 37 '__getattribute__', '__subclasshook__', '__new__', |
39 '__init__', 'func_globals', 'func_code', 'func_closure', | 38 '__init__', 'func_globals', 'func_code', 'func_closure', |
40 'im_class', 'im_func', 'im_self', 'gi_code', 'gi_frame', | 39 'im_class', 'im_func', 'im_self', 'gi_code', 'gi_frame', |
41 '__asteval__', 'f_locals', '__mro__') | 40 '__asteval__', 'f_locals', '__mro__') |
42 good_names = ('copy_reg._reconstructor', '__builtin__.object') | 41 good_names = ['copy_reg._reconstructor', '__builtin__.object'] |
43 | 42 |
44 if re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name): | 43 if re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name): |
45 fullname = module + '.' + name | 44 fullname = module + '.' + name |
46 if (fullname in good_names)\ | 45 if (fullname in good_names)\ |
47 or ( ( module.startswith('sklearn.') | 46 or ( ( module.startswith('sklearn.') |
48 or module.startswith('xgboost.') | 47 or module.startswith('xgboost.') |
49 or module.startswith('skrebate.') | 48 or module.startswith('skrebate.') |
50 or module.startswith('numpy.') | 49 or module.startswith('numpy.') |
51 or module == 'numpy' | 50 or module == 'numpy' |
52 ) | 51 ) |
53 and (name not in bad_names) | 52 and (name not in bad_names) |
54 ) : | 53 ): |
55 # TODO: replace with a whitelist checker | 54 # TODO: replace with a whitelist checker |
56 if fullname not in SK_NAMES + SKR_NAMES + XGB_NAMES + NUMPY_NAMES + good_names: | 55 if fullname not in sk_whitelist['SK_NAMES'] + sk_whitelist['SKR_NAMES'] + sk_whitelist['XGB_NAMES'] + sk_whitelist['NUMPY_NAMES'] + good_names: |
57 print("Warning: global %s is not in pickler whitelist yet and will loss support soon. Contact tool author or leave a message at github.com" % fullname) | 56 print("Warning: global %s is not in pickler whitelist yet and will loss support soon. Contact tool author or leave a message at github.com" % fullname) |
58 mod = sys.modules[module] | 57 mod = sys.modules[module] |
59 return getattr(mod, name) | 58 return getattr(mod, name) |
60 | 59 |
61 raise pickle.UnpicklingError("global '%s' is forbidden" % fullname) | 60 raise pickle.UnpicklingError("global '%s' is forbidden" % fullname) |
62 | 61 |
63 @classmethod | 62 |
64 def load(self, file): | 63 def load_model(file): |
65 obj = pickle.Unpickler(file) | 64 return SafePickler(file).load() |
66 obj.find_global = self.find_class | 65 |
67 return obj.load() | |
68 | 66 |
69 def read_columns(f, c=None, c_option='by_index_number', return_df=False, **args): | 67 def read_columns(f, c=None, c_option='by_index_number', return_df=False, **args): |
70 data = pandas.read_csv(f, **args) | 68 data = pandas.read_csv(f, **args) |
71 if c_option == 'by_index_number': | 69 if c_option == 'by_index_number': |
72 cols = list(map(lambda x: x - 1, c)) | 70 cols = list(map(lambda x: x - 1, c)) |
73 data = data.iloc[:,cols] | 71 data = data.iloc[:, cols] |
74 if c_option == 'all_but_by_index_number': | 72 if c_option == 'all_but_by_index_number': |
75 cols = list(map(lambda x: x - 1, c)) | 73 cols = list(map(lambda x: x - 1, c)) |
76 data.drop(data.columns[cols], axis=1, inplace=True) | 74 data.drop(data.columns[cols], axis=1, inplace=True) |
77 if c_option == 'by_header_name': | 75 if c_option == 'by_header_name': |
78 cols = [e.strip() for e in c.split(',')] | 76 cols = [e.strip() for e in c.split(',')] |
98 if not options['threshold'] or options['threshold'] == 'None': | 96 if not options['threshold'] or options['threshold'] == 'None': |
99 options['threshold'] = None | 97 options['threshold'] = None |
100 if inputs['model_inputter']['input_mode'] == 'prefitted': | 98 if inputs['model_inputter']['input_mode'] == 'prefitted': |
101 model_file = inputs['model_inputter']['fitted_estimator'] | 99 model_file = inputs['model_inputter']['fitted_estimator'] |
102 with open(model_file, 'rb') as model_handler: | 100 with open(model_file, 'rb') as model_handler: |
103 fitted_estimator = SafePickler.load(model_handler) | 101 fitted_estimator = load_model(model_handler) |
104 new_selector = selector(fitted_estimator, prefit=True, **options) | 102 new_selector = selector(fitted_estimator, prefit=True, **options) |
105 else: | 103 else: |
106 estimator_json = inputs['model_inputter']["estimator_selector"] | 104 estimator_json = inputs['model_inputter']["estimator_selector"] |
107 estimator = get_estimator(estimator_json) | 105 estimator = get_estimator(estimator_json) |
108 new_selector = selector(estimator, **options) | 106 new_selector = selector(estimator, **options) |
109 | 107 |
110 elif inputs['selected_algorithm'] == 'RFE': | 108 elif inputs['selected_algorithm'] == 'RFE': |
111 estimator=get_estimator(inputs["estimator_selector"]) | 109 estimator = get_estimator(inputs["estimator_selector"]) |
112 new_selector = selector(estimator, **options) | 110 new_selector = selector(estimator, **options) |
113 | 111 |
114 elif inputs['selected_algorithm'] == 'RFECV': | 112 elif inputs['selected_algorithm'] == 'RFECV': |
115 options['scoring'] = get_scoring(options['scoring']) | 113 options['scoring'] = get_scoring(options['scoring']) |
116 options['n_jobs'] = N_JOBS | 114 options['n_jobs'] = N_JOBS |
117 options['cv'] = get_cv( options['cv'].strip() ) | 115 options['cv'] = get_cv(options['cv'].strip()) |
118 estimator=get_estimator(inputs["estimator_selector"]) | 116 estimator = get_estimator(inputs["estimator_selector"]) |
119 new_selector = selector(estimator, **options) | 117 new_selector = selector(estimator, **options) |
120 | 118 |
121 elif inputs['selected_algorithm'] == "VarianceThreshold": | 119 elif inputs['selected_algorithm'] == "VarianceThreshold": |
122 new_selector = selector(**options) | 120 new_selector = selector(**options) |
123 | 121 |
125 score_func = inputs["score_func"] | 123 score_func = inputs["score_func"] |
126 score_func = getattr(sklearn.feature_selection, score_func) | 124 score_func = getattr(sklearn.feature_selection, score_func) |
127 new_selector = selector(score_func, **options) | 125 new_selector = selector(score_func, **options) |
128 | 126 |
129 return new_selector | 127 return new_selector |
130 | 128 |
131 | 129 |
132 def get_X_y(params, file1, file2): | 130 def get_X_y(params, file1, file2): |
133 input_type = params["selected_tasks"]["selected_algorithms"]["input_options"]["selected_input"] | 131 input_type = params["selected_tasks"]["selected_algorithms"]["input_options"]["selected_input"] |
134 if input_type=="tabular": | 132 if input_type == "tabular": |
135 header = 'infer' if params["selected_tasks"]["selected_algorithms"]["input_options"]["header1"] else None | 133 header = 'infer' if params["selected_tasks"]["selected_algorithms"]["input_options"]["header1"] else None |
136 column_option = params["selected_tasks"]["selected_algorithms"]["input_options"]["column_selector_options_1"]["selected_column_selector_option"] | 134 column_option = params["selected_tasks"]["selected_algorithms"]["input_options"]["column_selector_options_1"]["selected_column_selector_option"] |
137 if column_option in ["by_index_number", "all_but_by_index_number", "by_header_name", "all_but_by_header_name"]: | 135 if column_option in ["by_index_number", "all_but_by_index_number", "by_header_name", "all_but_by_header_name"]: |
138 c = params["selected_tasks"]["selected_algorithms"]["input_options"]["column_selector_options_1"]["col1"] | 136 c = params["selected_tasks"]["selected_algorithms"]["input_options"]["column_selector_options_1"]["col1"] |
139 else: | 137 else: |
140 c = None | 138 c = None |
141 X = read_columns( | 139 X = read_columns( |
142 file1, | 140 file1, |
143 c = c, | 141 c=c, |
144 c_option = column_option, | 142 c_option=column_option, |
145 sep='\t', | 143 sep='\t', |
146 header=header, | 144 header=header, |
147 parse_dates=True | 145 parse_dates=True |
148 ) | 146 ) |
149 else: | 147 else: |
155 c = params["selected_tasks"]["selected_algorithms"]["input_options"]["column_selector_options_2"]["col2"] | 153 c = params["selected_tasks"]["selected_algorithms"]["input_options"]["column_selector_options_2"]["col2"] |
156 else: | 154 else: |
157 c = None | 155 c = None |
158 y = read_columns( | 156 y = read_columns( |
159 file2, | 157 file2, |
160 c = c, | 158 c=c, |
161 c_option = column_option, | 159 c_option=column_option, |
162 sep='\t', | 160 sep='\t', |
163 header=header, | 161 header=header, |
164 parse_dates=True | 162 parse_dates=True |
165 ) | 163 ) |
166 y=y.ravel() | 164 y = y.ravel() |
167 return X, y | 165 return X, y |
168 | 166 |
169 | 167 |
170 class SafeEval(Interpreter): | 168 class SafeEval(Interpreter): |
171 | 169 |
195 'multivariate_normal', 'negative_binomial', 'noncentral_chisquare', 'noncentral_f', | 193 'multivariate_normal', 'negative_binomial', 'noncentral_chisquare', 'noncentral_f', |
196 'normal', 'pareto', 'permutation', 'poisson', 'power', 'rand', 'randint', | 194 'normal', 'pareto', 'permutation', 'poisson', 'power', 'rand', 'randint', |
197 'randn', 'random', 'random_integers', 'random_sample', 'ranf', 'rayleigh', | 195 'randn', 'random', 'random_integers', 'random_sample', 'ranf', 'rayleigh', |
198 'sample', 'seed', 'set_state', 'shuffle', 'standard_cauchy', 'standard_exponential', | 196 'sample', 'seed', 'set_state', 'shuffle', 'standard_cauchy', 'standard_exponential', |
199 'standard_gamma', 'standard_normal', 'standard_t', 'triangular', 'uniform', | 197 'standard_gamma', 'standard_normal', 'standard_t', 'triangular', 'uniform', |
200 'vonmises', 'wald', 'weibull', 'zipf' ] | 198 'vonmises', 'wald', 'weibull', 'zipf'] |
201 for f in from_numpy_random: | 199 for f in from_numpy_random: |
202 syms['np_random_' + f] = getattr(np.random, f) | 200 syms['np_random_' + f] = getattr(np.random, f) |
203 | 201 |
204 for key in unwanted: | 202 for key in unwanted: |
205 syms.pop(key, None) | 203 syms.pop(key, None) |
206 | 204 |
207 super(SafeEval, self).__init__( symtable=syms, use_numpy=False, minimal=False, | 205 super(SafeEval, self).__init__(symtable=syms, use_numpy=False, minimal=False, |
208 no_if=True, no_for=True, no_while=True, no_try=True, | 206 no_if=True, no_for=True, no_while=True, no_try=True, |
209 no_functiondef=True, no_ifexp=True, no_listcomp=False, | 207 no_functiondef=True, no_ifexp=True, no_listcomp=False, |
210 no_augassign=False, no_assert=True, no_delete=True, | 208 no_augassign=False, no_assert=True, no_delete=True, |
211 no_raise=True, no_print=True) | 209 no_raise=True, no_print=True) |
212 | 210 |
248 estimator_params = estimator_json['text_params'].strip() | 246 estimator_params = estimator_json['text_params'].strip() |
249 if estimator_params != "": | 247 if estimator_params != "": |
250 try: | 248 try: |
251 params = safe_eval('dict(' + estimator_params + ')') | 249 params = safe_eval('dict(' + estimator_params + ')') |
252 except ValueError: | 250 except ValueError: |
253 sys.exit("Unsupported parameter input: `%s`" %estimator_params) | 251 sys.exit("Unsupported parameter input: `%s`" % estimator_params) |
254 estimator.set_params(**params) | 252 estimator.set_params(**params) |
255 if 'n_jobs' in estimator.get_params(): | 253 if 'n_jobs' in estimator.get_params(): |
256 estimator.set_params( n_jobs=N_JOBS ) | 254 estimator.set_params(n_jobs=N_JOBS) |
257 | 255 |
258 return estimator | 256 return estimator |
259 | 257 |
260 | 258 |
261 def get_cv(literal): | 259 def get_cv(literal): |
264 return None | 262 return None |
265 if literal.isdigit(): | 263 if literal.isdigit(): |
266 return int(literal) | 264 return int(literal) |
267 m = re.match(r'^(?P<method>\w+)\((?P<args>.*)\)$', literal) | 265 m = re.match(r'^(?P<method>\w+)\((?P<args>.*)\)$', literal) |
268 if m: | 266 if m: |
269 my_class = getattr( model_selection, m.group('method') ) | 267 my_class = getattr(model_selection, m.group('method')) |
270 args = safe_eval( 'dict('+ m.group('args') + ')' ) | 268 args = safe_eval('dict('+ m.group('args') + ')') |
271 return my_class( **args ) | 269 return my_class(**args) |
272 sys.exit("Unsupported CV input: %s" %literal) | 270 sys.exit("Unsupported CV input: %s" % literal) |
273 | 271 |
274 | 272 |
275 def get_scoring(scoring_json): | 273 def get_scoring(scoring_json): |
276 def balanced_accuracy_score(y_true, y_pred): | 274 def balanced_accuracy_score(y_true, y_pred): |
277 C = metrics.confusion_matrix(y_true, y_pred) | 275 C = metrics.confusion_matrix(y_true, y_pred) |
291 my_scorers['balanced_accuracy'] = metrics.make_scorer(balanced_accuracy_score) | 289 my_scorers['balanced_accuracy'] = metrics.make_scorer(balanced_accuracy_score) |
292 | 290 |
293 if scoring_json['secondary_scoring'] != 'None'\ | 291 if scoring_json['secondary_scoring'] != 'None'\ |
294 and scoring_json['secondary_scoring'] != scoring_json['primary_scoring']: | 292 and scoring_json['secondary_scoring'] != scoring_json['primary_scoring']: |
295 scoring = {} | 293 scoring = {} |
296 scoring['primary'] = my_scorers[ scoring_json['primary_scoring'] ] | 294 scoring['primary'] = my_scorers[scoring_json['primary_scoring']] |
297 for scorer in scoring_json['secondary_scoring'].split(','): | 295 for scorer in scoring_json['secondary_scoring'].split(','): |
298 if scorer != scoring_json['primary_scoring']: | 296 if scorer != scoring_json['primary_scoring']: |
299 scoring[scorer] = my_scorers[scorer] | 297 scoring[scorer] = my_scorers[scorer] |
300 return scoring | 298 return scoring |
301 | 299 |
302 return my_scorers[ scoring_json['primary_scoring'] ] | 300 return my_scorers[scoring_json['primary_scoring']] |
303 |