comparison utils.py @ 21:212e7adfe65f draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 2a058459e6daf0486871f93845f00fdb4a4eaca1
author bgruening
date Sat, 29 Sep 2018 07:39:16 -0400
parents 9b7d0655f70f
children e3bc646e63b2
comparison
equal deleted inserted replaced
20:9b7d0655f70f 21:212e7adfe65f
1 import sys 1 import sys
2 import os 2 import os
3 import pandas 3 import pandas
4 import re 4 import re
5 import cPickle as pickle 5 import pickle
6 import warnings 6 import warnings
7 import numpy as np 7 import numpy as np
8 import xgboost 8 import xgboost
9 import scipy 9 import scipy
10 import sklearn 10 import sklearn
11 import ast
12 from asteval import Interpreter, make_symbol_table 11 from asteval import Interpreter, make_symbol_table
13 from sklearn import (cluster, decomposition, ensemble, feature_extraction, feature_selection, 12 from sklearn import (cluster, decomposition, ensemble, feature_extraction, feature_selection,
14 gaussian_process, kernel_approximation, linear_model, metrics, 13 gaussian_process, kernel_approximation, metrics,
15 model_selection, naive_bayes, neighbors, pipeline, preprocessing, 14 model_selection, naive_bayes, neighbors, pipeline, preprocessing,
16 svm, linear_model, tree, discriminant_analysis) 15 svm, linear_model, tree, discriminant_analysis)
17 16
18 N_JOBS = int( os.environ.get('GALAXY_SLOTS', 1) ) 17 N_JOBS = int(os.environ.get('GALAXY_SLOTS', 1))
19 18
20 class SafePickler(object): 19
20 class SafePickler(pickle.Unpickler):
21 """ 21 """
22 Used to safely deserialize scikit-learn model objects serialized by cPickle.dump 22 Used to safely deserialize scikit-learn model objects serialized by cPickle.dump
23 Usage: 23 Usage:
24 eg.: SafePickler.load(pickled_file_object) 24 eg.: SafePickler.load(pickled_file_object)
25 """ 25 """
26 @classmethod
27 def find_class(self, module, name): 26 def find_class(self, module, name):
28 27
29 bad_names = ('and', 'as', 'assert', 'break', 'class', 'continue', 28 bad_names = ('and', 'as', 'assert', 'break', 'class', 'continue',
30 'def', 'del', 'elif', 'else', 'except', 'exec', 29 'def', 'del', 'elif', 'else', 'except', 'exec',
31 'finally', 'for', 'from', 'global', 'if', 'import', 30 'finally', 'for', 'from', 'global', 'if', 'import',
37 '__dict__', '__class__', '__call__', '__get__', 36 '__dict__', '__class__', '__call__', '__get__',
38 '__getattribute__', '__subclasshook__', '__new__', 37 '__getattribute__', '__subclasshook__', '__new__',
39 '__init__', 'func_globals', 'func_code', 'func_closure', 38 '__init__', 'func_globals', 'func_code', 'func_closure',
40 'im_class', 'im_func', 'im_self', 'gi_code', 'gi_frame', 39 'im_class', 'im_func', 'im_self', 'gi_code', 'gi_frame',
41 '__asteval__', 'f_locals', '__mro__') 40 '__asteval__', 'f_locals', '__mro__')
42 good_names = ('copy_reg._reconstructor', '__builtin__.object') 41 good_names = ['copy_reg._reconstructor', '__builtin__.object']
43 42
44 if re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name): 43 if re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name):
45 fullname = module + '.' + name 44 fullname = module + '.' + name
46 if (fullname in good_names)\ 45 if (fullname in good_names)\
47 or ( ( module.startswith('sklearn.') 46 or ( ( module.startswith('sklearn.')
48 or module.startswith('xgboost.') 47 or module.startswith('xgboost.')
49 or module.startswith('skrebate.') 48 or module.startswith('skrebate.')
50 or module.startswith('numpy.') 49 or module.startswith('numpy.')
51 or module == 'numpy' 50 or module == 'numpy'
52 ) 51 )
53 and (name not in bad_names) 52 and (name not in bad_names)
54 ) : 53 ):
55 # TODO: replace with a whitelist checker 54 # TODO: replace with a whitelist checker
56 if fullname not in SK_NAMES + SKR_NAMES + XGB_NAMES + NUMPY_NAMES + good_names: 55 if fullname not in sk_whitelist['SK_NAMES'] + sk_whitelist['SKR_NAMES'] + sk_whitelist['XGB_NAMES'] + sk_whitelist['NUMPY_NAMES'] + good_names:
57 print("Warning: global %s is not in pickler whitelist yet and will loss support soon. Contact tool author or leave a message at github.com" % fullname) 56 print("Warning: global %s is not in pickler whitelist yet and will loss support soon. Contact tool author or leave a message at github.com" % fullname)
58 mod = sys.modules[module] 57 mod = sys.modules[module]
59 return getattr(mod, name) 58 return getattr(mod, name)
60 59
61 raise pickle.UnpicklingError("global '%s' is forbidden" % fullname) 60 raise pickle.UnpicklingError("global '%s' is forbidden" % fullname)
62 61
63 @classmethod 62
64 def load(self, file): 63 def load_model(file):
65 obj = pickle.Unpickler(file) 64 return SafePickler(file).load()
66 obj.find_global = self.find_class 65
67 return obj.load()
68 66
69 def read_columns(f, c=None, c_option='by_index_number', return_df=False, **args): 67 def read_columns(f, c=None, c_option='by_index_number', return_df=False, **args):
70 data = pandas.read_csv(f, **args) 68 data = pandas.read_csv(f, **args)
71 if c_option == 'by_index_number': 69 if c_option == 'by_index_number':
72 cols = list(map(lambda x: x - 1, c)) 70 cols = list(map(lambda x: x - 1, c))
73 data = data.iloc[:,cols] 71 data = data.iloc[:, cols]
74 if c_option == 'all_but_by_index_number': 72 if c_option == 'all_but_by_index_number':
75 cols = list(map(lambda x: x - 1, c)) 73 cols = list(map(lambda x: x - 1, c))
76 data.drop(data.columns[cols], axis=1, inplace=True) 74 data.drop(data.columns[cols], axis=1, inplace=True)
77 if c_option == 'by_header_name': 75 if c_option == 'by_header_name':
78 cols = [e.strip() for e in c.split(',')] 76 cols = [e.strip() for e in c.split(',')]
98 if not options['threshold'] or options['threshold'] == 'None': 96 if not options['threshold'] or options['threshold'] == 'None':
99 options['threshold'] = None 97 options['threshold'] = None
100 if inputs['model_inputter']['input_mode'] == 'prefitted': 98 if inputs['model_inputter']['input_mode'] == 'prefitted':
101 model_file = inputs['model_inputter']['fitted_estimator'] 99 model_file = inputs['model_inputter']['fitted_estimator']
102 with open(model_file, 'rb') as model_handler: 100 with open(model_file, 'rb') as model_handler:
103 fitted_estimator = SafePickler.load(model_handler) 101 fitted_estimator = load_model(model_handler)
104 new_selector = selector(fitted_estimator, prefit=True, **options) 102 new_selector = selector(fitted_estimator, prefit=True, **options)
105 else: 103 else:
106 estimator_json = inputs['model_inputter']["estimator_selector"] 104 estimator_json = inputs['model_inputter']["estimator_selector"]
107 estimator = get_estimator(estimator_json) 105 estimator = get_estimator(estimator_json)
108 new_selector = selector(estimator, **options) 106 new_selector = selector(estimator, **options)
109 107
110 elif inputs['selected_algorithm'] == 'RFE': 108 elif inputs['selected_algorithm'] == 'RFE':
111 estimator=get_estimator(inputs["estimator_selector"]) 109 estimator = get_estimator(inputs["estimator_selector"])
112 new_selector = selector(estimator, **options) 110 new_selector = selector(estimator, **options)
113 111
114 elif inputs['selected_algorithm'] == 'RFECV': 112 elif inputs['selected_algorithm'] == 'RFECV':
115 options['scoring'] = get_scoring(options['scoring']) 113 options['scoring'] = get_scoring(options['scoring'])
116 options['n_jobs'] = N_JOBS 114 options['n_jobs'] = N_JOBS
117 options['cv'] = get_cv( options['cv'].strip() ) 115 options['cv'] = get_cv(options['cv'].strip())
118 estimator=get_estimator(inputs["estimator_selector"]) 116 estimator = get_estimator(inputs["estimator_selector"])
119 new_selector = selector(estimator, **options) 117 new_selector = selector(estimator, **options)
120 118
121 elif inputs['selected_algorithm'] == "VarianceThreshold": 119 elif inputs['selected_algorithm'] == "VarianceThreshold":
122 new_selector = selector(**options) 120 new_selector = selector(**options)
123 121
125 score_func = inputs["score_func"] 123 score_func = inputs["score_func"]
126 score_func = getattr(sklearn.feature_selection, score_func) 124 score_func = getattr(sklearn.feature_selection, score_func)
127 new_selector = selector(score_func, **options) 125 new_selector = selector(score_func, **options)
128 126
129 return new_selector 127 return new_selector
130 128
131 129
132 def get_X_y(params, file1, file2): 130 def get_X_y(params, file1, file2):
133 input_type = params["selected_tasks"]["selected_algorithms"]["input_options"]["selected_input"] 131 input_type = params["selected_tasks"]["selected_algorithms"]["input_options"]["selected_input"]
134 if input_type=="tabular": 132 if input_type == "tabular":
135 header = 'infer' if params["selected_tasks"]["selected_algorithms"]["input_options"]["header1"] else None 133 header = 'infer' if params["selected_tasks"]["selected_algorithms"]["input_options"]["header1"] else None
136 column_option = params["selected_tasks"]["selected_algorithms"]["input_options"]["column_selector_options_1"]["selected_column_selector_option"] 134 column_option = params["selected_tasks"]["selected_algorithms"]["input_options"]["column_selector_options_1"]["selected_column_selector_option"]
137 if column_option in ["by_index_number", "all_but_by_index_number", "by_header_name", "all_but_by_header_name"]: 135 if column_option in ["by_index_number", "all_but_by_index_number", "by_header_name", "all_but_by_header_name"]:
138 c = params["selected_tasks"]["selected_algorithms"]["input_options"]["column_selector_options_1"]["col1"] 136 c = params["selected_tasks"]["selected_algorithms"]["input_options"]["column_selector_options_1"]["col1"]
139 else: 137 else:
140 c = None 138 c = None
141 X = read_columns( 139 X = read_columns(
142 file1, 140 file1,
143 c = c, 141 c=c,
144 c_option = column_option, 142 c_option=column_option,
145 sep='\t', 143 sep='\t',
146 header=header, 144 header=header,
147 parse_dates=True 145 parse_dates=True
148 ) 146 )
149 else: 147 else:
155 c = params["selected_tasks"]["selected_algorithms"]["input_options"]["column_selector_options_2"]["col2"] 153 c = params["selected_tasks"]["selected_algorithms"]["input_options"]["column_selector_options_2"]["col2"]
156 else: 154 else:
157 c = None 155 c = None
158 y = read_columns( 156 y = read_columns(
159 file2, 157 file2,
160 c = c, 158 c=c,
161 c_option = column_option, 159 c_option=column_option,
162 sep='\t', 160 sep='\t',
163 header=header, 161 header=header,
164 parse_dates=True 162 parse_dates=True
165 ) 163 )
166 y=y.ravel() 164 y = y.ravel()
167 return X, y 165 return X, y
168 166
169 167
170 class SafeEval(Interpreter): 168 class SafeEval(Interpreter):
171 169
195 'multivariate_normal', 'negative_binomial', 'noncentral_chisquare', 'noncentral_f', 193 'multivariate_normal', 'negative_binomial', 'noncentral_chisquare', 'noncentral_f',
196 'normal', 'pareto', 'permutation', 'poisson', 'power', 'rand', 'randint', 194 'normal', 'pareto', 'permutation', 'poisson', 'power', 'rand', 'randint',
197 'randn', 'random', 'random_integers', 'random_sample', 'ranf', 'rayleigh', 195 'randn', 'random', 'random_integers', 'random_sample', 'ranf', 'rayleigh',
198 'sample', 'seed', 'set_state', 'shuffle', 'standard_cauchy', 'standard_exponential', 196 'sample', 'seed', 'set_state', 'shuffle', 'standard_cauchy', 'standard_exponential',
199 'standard_gamma', 'standard_normal', 'standard_t', 'triangular', 'uniform', 197 'standard_gamma', 'standard_normal', 'standard_t', 'triangular', 'uniform',
200 'vonmises', 'wald', 'weibull', 'zipf' ] 198 'vonmises', 'wald', 'weibull', 'zipf']
201 for f in from_numpy_random: 199 for f in from_numpy_random:
202 syms['np_random_' + f] = getattr(np.random, f) 200 syms['np_random_' + f] = getattr(np.random, f)
203 201
204 for key in unwanted: 202 for key in unwanted:
205 syms.pop(key, None) 203 syms.pop(key, None)
206 204
207 super(SafeEval, self).__init__( symtable=syms, use_numpy=False, minimal=False, 205 super(SafeEval, self).__init__(symtable=syms, use_numpy=False, minimal=False,
208 no_if=True, no_for=True, no_while=True, no_try=True, 206 no_if=True, no_for=True, no_while=True, no_try=True,
209 no_functiondef=True, no_ifexp=True, no_listcomp=False, 207 no_functiondef=True, no_ifexp=True, no_listcomp=False,
210 no_augassign=False, no_assert=True, no_delete=True, 208 no_augassign=False, no_assert=True, no_delete=True,
211 no_raise=True, no_print=True) 209 no_raise=True, no_print=True)
212 210
248 estimator_params = estimator_json['text_params'].strip() 246 estimator_params = estimator_json['text_params'].strip()
249 if estimator_params != "": 247 if estimator_params != "":
250 try: 248 try:
251 params = safe_eval('dict(' + estimator_params + ')') 249 params = safe_eval('dict(' + estimator_params + ')')
252 except ValueError: 250 except ValueError:
253 sys.exit("Unsupported parameter input: `%s`" %estimator_params) 251 sys.exit("Unsupported parameter input: `%s`" % estimator_params)
254 estimator.set_params(**params) 252 estimator.set_params(**params)
255 if 'n_jobs' in estimator.get_params(): 253 if 'n_jobs' in estimator.get_params():
256 estimator.set_params( n_jobs=N_JOBS ) 254 estimator.set_params(n_jobs=N_JOBS)
257 255
258 return estimator 256 return estimator
259 257
260 258
261 def get_cv(literal): 259 def get_cv(literal):
264 return None 262 return None
265 if literal.isdigit(): 263 if literal.isdigit():
266 return int(literal) 264 return int(literal)
267 m = re.match(r'^(?P<method>\w+)\((?P<args>.*)\)$', literal) 265 m = re.match(r'^(?P<method>\w+)\((?P<args>.*)\)$', literal)
268 if m: 266 if m:
269 my_class = getattr( model_selection, m.group('method') ) 267 my_class = getattr(model_selection, m.group('method'))
270 args = safe_eval( 'dict('+ m.group('args') + ')' ) 268 args = safe_eval('dict('+ m.group('args') + ')')
271 return my_class( **args ) 269 return my_class(**args)
272 sys.exit("Unsupported CV input: %s" %literal) 270 sys.exit("Unsupported CV input: %s" % literal)
273 271
274 272
275 def get_scoring(scoring_json): 273 def get_scoring(scoring_json):
276 def balanced_accuracy_score(y_true, y_pred): 274 def balanced_accuracy_score(y_true, y_pred):
277 C = metrics.confusion_matrix(y_true, y_pred) 275 C = metrics.confusion_matrix(y_true, y_pred)
291 my_scorers['balanced_accuracy'] = metrics.make_scorer(balanced_accuracy_score) 289 my_scorers['balanced_accuracy'] = metrics.make_scorer(balanced_accuracy_score)
292 290
293 if scoring_json['secondary_scoring'] != 'None'\ 291 if scoring_json['secondary_scoring'] != 'None'\
294 and scoring_json['secondary_scoring'] != scoring_json['primary_scoring']: 292 and scoring_json['secondary_scoring'] != scoring_json['primary_scoring']:
295 scoring = {} 293 scoring = {}
296 scoring['primary'] = my_scorers[ scoring_json['primary_scoring'] ] 294 scoring['primary'] = my_scorers[scoring_json['primary_scoring']]
297 for scorer in scoring_json['secondary_scoring'].split(','): 295 for scorer in scoring_json['secondary_scoring'].split(','):
298 if scorer != scoring_json['primary_scoring']: 296 if scorer != scoring_json['primary_scoring']:
299 scoring[scorer] = my_scorers[scorer] 297 scoring[scorer] = my_scorers[scorer]
300 return scoring 298 return scoring
301 299
302 return my_scorers[ scoring_json['primary_scoring'] ] 300 return my_scorers[scoring_json['primary_scoring']]
303