Mercurial > repos > bgruening > stacking_ensemble_models
comparison utils.py @ 0:8e93241d5d28 draft default tip
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
author | bgruening |
---|---|
date | Tue, 14 May 2019 18:04:46 -0400 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:8e93241d5d28 |
---|---|
1 import ast | |
2 import json | |
3 import imblearn | |
4 import numpy as np | |
5 import pandas | |
6 import pickle | |
7 import re | |
8 import scipy | |
9 import sklearn | |
10 import skrebate | |
11 import sys | |
12 import warnings | |
13 import xgboost | |
14 | |
15 from collections import Counter | |
16 from asteval import Interpreter, make_symbol_table | |
17 from imblearn import under_sampling, over_sampling, combine | |
18 from imblearn.pipeline import Pipeline as imbPipeline | |
19 from mlxtend import regressor, classifier | |
20 from scipy.io import mmread | |
21 from sklearn import ( | |
22 cluster, compose, decomposition, ensemble, feature_extraction, | |
23 feature_selection, gaussian_process, kernel_approximation, metrics, | |
24 model_selection, naive_bayes, neighbors, pipeline, preprocessing, | |
25 svm, linear_model, tree, discriminant_analysis) | |
26 | |
27 try: | |
28 import iraps_classifier | |
29 except ImportError: | |
30 pass | |
31 | |
32 try: | |
33 import model_validations | |
34 except ImportError: | |
35 pass | |
36 | |
37 try: | |
38 import feature_selectors | |
39 except ImportError: | |
40 pass | |
41 | |
42 try: | |
43 import preprocessors | |
44 except ImportError: | |
45 pass | |
46 | |
# Path of the pickle white list file (JSON), looked up in the same
# directory as this module.
WL_FILE = __import__('os').path.join(
    __import__('os').path.dirname(__file__), 'pk_whitelist.json')

# Number of parallel jobs, read from the GALAXY_SLOTS environment
# variable; falls back to 1 when the variable is not set.
N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1))
52 | |
53 | |
class _SafePickler(pickle.Unpickler, object):
    """
    Unpickler that restricts which globals a pickle stream may load.

    Used to deserialize scikit-learn (and compatible) model objects.
    Every global referenced by the stream goes through `find_class`,
    which rejects black-listed names, resolves Galaxy-ML custom modules
    from `sys.modules` and otherwise only accepts globals that live in a
    fixed set of library namespaces.

    Usage:
        eg.: _SafePickler(pickled_file_object).load()
    """
    def __init__(self, file):
        super(_SafePickler, self).__init__(file)
        # load global white list
        with open(WL_FILE, 'r') as f:
            self.pk_whitelist = json.load(f)

        # attribute / keyword names that are never allowed as a global
        self.bad_names = (
            'and', 'as', 'assert', 'break', 'class', 'continue',
            'def', 'del', 'elif', 'else', 'except', 'exec',
            'finally', 'for', 'from', 'global', 'if', 'import',
            'in', 'is', 'lambda', 'not', 'or', 'pass', 'print',
            'raise', 'return', 'try', 'system', 'while', 'with',
            'True', 'False', 'None', 'eval', 'execfile', '__import__',
            '__package__', '__subclasses__', '__bases__', '__globals__',
            '__code__', '__closure__', '__func__', '__self__', '__module__',
            '__dict__', '__class__', '__call__', '__get__',
            '__getattribute__', '__subclasshook__', '__new__',
            '__init__', 'func_globals', 'func_code', 'func_closure',
            'im_class', 'im_func', 'im_self', 'gi_code', 'gi_frame',
            '__asteval__', 'f_locals', '__mro__')

        # unclassified good globals
        self.good_names = [
            'copy_reg._reconstructor', '__builtin__.object',
            '__builtin__.bytearray', 'builtins.object',
            'builtins.bytearray', 'keras.engine.sequential.Sequential',
            'keras.engine.sequential.Model']

        # custom module in Galaxy-ML
        self.custom_modules = [
            '__main__', 'keras_galaxy_models', 'feature_selectors',
            'preprocessors', 'iraps_classifier', 'model_validations']

    # override
    def find_class(self, module, name):
        """Resolve the global `module.name` requested by the pickle stream.

        Parameters
        ----------
        module : str
            Module name recorded in the pickle.
        name : str
            Attribute name recorded in the pickle.

        Returns
        -------
        The attribute `name` of the already imported module `module`.

        Raises
        ------
        pickle.UnpicklingError
            If `name` is black-listed, if a custom module is not
            imported, or if the global is not in an accepted namespace
            (including when `name` is not a plain identifier).
        """
        # black list first
        if name in self.bad_names:
            raise pickle.UnpicklingError("global '%s.%s' is forbidden"
                                         % (module, name))

        # custom module in Galaxy-ML
        if module in self.custom_modules:
            custom_module = sys.modules.get(module, None)
            if custom_module:
                return getattr(custom_module, name)
            else:
                raise pickle.UnpicklingError("Module %s' is not imported"
                                             % module)

        # Built before the identifier check: the final error message
        # needs it even when `name` is not a valid identifier.
        fullname = module + '.' + name

        # For objects from outside libraries, it's necessary to verify
        # both module and name. Currently only a blacklist checker
        # is working.
        # TODO: replace with a whitelist checker.
        good_names = self.good_names
        pk_whitelist = self.pk_whitelist
        if re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name):
            if (fullname in good_names)\
                or (module.startswith(('sklearn.', 'xgboost.', 'skrebate.',
                                       'imblearn.', 'mlxtend.', 'numpy.'))
                    or module == 'numpy'):
                if fullname not in (pk_whitelist['SK_NAMES'] +
                                    pk_whitelist['SKR_NAMES'] +
                                    pk_whitelist['XGB_NAMES'] +
                                    pk_whitelist['NUMPY_NAMES'] +
                                    pk_whitelist['IMBLEARN_NAMES'] +
                                    pk_whitelist['MLXTEND_NAMES'] +
                                    good_names):
                    # not whitelisted yet: only warn for now
                    # raise pickle.UnpicklingError
                    print("Warning: global %s is not in pickler whitelist "
                          "yet and will loss support soon. Contact tool "
                          "author or leave a message at github.com" % fullname)
                # the module is looked up, never imported, here
                mod = sys.modules[module]
                return getattr(mod, name)

        raise pickle.UnpicklingError("global '%s' is forbidden" % fullname)
136 | |
137 | |
def load_model(file):
    """Deserialize the pickled object in `file` through `_SafePickler`.
    """
    unpickler = _SafePickler(file)
    return unpickler.load()
142 | |
143 | |
def read_columns(f, c=None, c_option='by_index_number',
                 return_df=False, **args):
    """Read a tabular dataset and return the selected columns as an array.

    Parameters
    ----------
    f : str or file object
        Passed straight to `pandas.read_csv`.
    c : list of int or str
        Column selection. 1-based column numbers for the
        `*by_index_number` options, a comma separated string of header
        names for the `*by_header_name` options.
    c_option : str
        One of 'by_index_number', 'all_but_by_index_number',
        'by_header_name', 'all_but_by_header_name'. Any other value
        keeps every column.
    return_df : bool
        If True, also return the selected `pandas.DataFrame`.
    **args
        Extra keyword arguments for `pandas.read_csv`.

    Returns
    -------
    The values array, or the tuple `(values, dataframe)` when
    `return_df` is True.
    """
    frame = pandas.read_csv(f, **args)

    if c_option in ('by_index_number', 'all_but_by_index_number'):
        # tool inputs are 1-based, pandas positions are 0-based
        positions = [number - 1 for number in c]
        if c_option == 'by_index_number':
            frame = frame.iloc[:, positions]
        else:
            frame.drop(frame.columns[positions], axis=1, inplace=True)
    elif c_option in ('by_header_name', 'all_but_by_header_name'):
        headers = [header.strip() for header in c.split(',')]
        if c_option == 'by_header_name':
            frame = frame[headers]
        else:
            frame.drop(headers, axis=1, inplace=True)

    values = frame.values
    return (values, frame) if return_df else values
166 | |
167 | |
def feature_selector(inputs, X=None, y=None):
    """Build a feature selector object from Galaxy tool parameters.

    Parameters
    ----------
    inputs : dict
        From galaxy tool parameters. `inputs['options']` is updated in
        place while the selector arguments are prepared.
    X : array
        Containing training features. Only used to pre-compute the CV
        splits when the configured splitter comes with groups.
    y : array or list
        Target values. Used together with `X` for the same purpose.

    Returns
    -------
    The configured selector instance.
    """
    algorithm = inputs['selected_algorithm']
    selector = algorithm
    # DyRFECV is not part of sklearn; it is resolved further down
    if algorithm != 'DyRFECV':
        selector = getattr(sklearn.feature_selection, algorithm)
    options = inputs['options']

    def _checked_estimator(estimator_json):
        # build the estimator, then pass it through
        # `feature_selectors.check_feature_importances`
        estimator = get_estimator(estimator_json)
        checker = try_get_attr(
            'feature_selectors', 'check_feature_importances')
        return checker(estimator)

    def _setup_cross_validation():
        # fill in the scoring / n_jobs / cv options
        options['scoring'] = get_scoring(options['scoring'])
        options['n_jobs'] = N_JOBS
        splitter, groups = get_cv(options.pop('cv_selector'))
        if groups is None:
            options['cv'] = splitter
        else:
            options['cv'] = list(splitter.split(X, y, groups=groups))

    def _coerce_step():
        # a step of 1 or more is a feature count and must be an int
        step = options.get('step', None)
        if step and step >= 1.0:
            options['step'] = int(step)

    if algorithm == 'SelectFromModel':
        threshold = options['threshold']
        if not threshold or threshold == 'None':
            options['threshold'] = None
        else:
            try:
                options['threshold'] = float(threshold)
            except ValueError:
                # non-numeric thresholds are passed on unchanged
                pass

        model_inputter = inputs['model_inputter']
        if model_inputter['input_mode'] == 'prefitted':
            model_file = model_inputter['fitted_estimator']
            with open(model_file, 'rb') as model_handler:
                fitted_estimator = load_model(model_handler)
            return selector(fitted_estimator, prefit=True, **options)

        estimator = _checked_estimator(model_inputter['estimator_selector'])
        return selector(estimator, **options)

    if algorithm == 'RFE':
        _coerce_step()
        estimator = _checked_estimator(inputs["estimator_selector"])
        return selector(estimator, **options)

    if algorithm == 'RFECV':
        _setup_cross_validation()
        _coerce_step()
        estimator = _checked_estimator(inputs['estimator_selector'])
        return selector(estimator, **options)

    if algorithm == 'DyRFECV':
        _setup_cross_validation()
        raw_step = options.get('step')
        if not raw_step or raw_step == 'None':
            options['step'] = None
        else:
            options['step'] = ast.literal_eval(raw_step)
        estimator = _checked_estimator(inputs["estimator_selector"])
        dy_rfecv = try_get_attr('feature_selectors', 'DyRFECV')
        return dy_rfecv(estimator, **options)

    if algorithm == 'VarianceThreshold':
        return selector(**options)

    # remaining selectors take a score function as first argument
    score_func = getattr(sklearn.feature_selection, inputs['score_func'])
    return selector(score_func, **options)
264 | |
265 | |
def get_X_y(params, file1, file2):
    """Return machine learning inputs X, y read from the tool inputs.

    Parameters
    ----------
    params : dict
        Galaxy tool parameters; the settings are read from
        `params['selected_tasks']['selected_algorithms']['input_options']`.
    file1 : str or file object
        Feature file: tabular, or anything else is read with `mmread`.
    file2 : str or file object
        Tabular target file.

    Returns
    -------
    Tuple `(X, y)`, with `y` flattened by `ravel()`.
    """
    input_options = (params['selected_tasks']['selected_algorithms']
                     ['input_options'])
    # column selector options that come with a column specification
    with_columns = ('by_index_number', 'all_but_by_index_number',
                    'by_header_name', 'all_but_by_header_name')

    if input_options['selected_input'] == 'tabular':
        x_header = 'infer' if input_options['header1'] else None
        x_selector = input_options['column_selector_options_1']
        x_option = x_selector['selected_column_selector_option']
        x_columns = x_selector['col1'] if x_option in with_columns else None
        X = read_columns(
            file1,
            c=x_columns,
            c_option=x_option,
            sep='\t',
            header=x_header,
            parse_dates=True).astype(float)
    else:
        X = mmread(file1)

    y_header = 'infer' if input_options['header2'] else None
    y_selector = input_options['column_selector_options_2']
    y_option = y_selector['selected_column_selector_option2']
    y_columns = y_selector['col2'] if y_option in with_columns else None
    y = read_columns(
        file2,
        c=y_columns,
        c_option=y_option,
        sep='\t',
        header=y_header,
        parse_dates=True).ravel()

    return X, y
314 | |
315 | |
class SafeEval(Interpreter):
    """Interpreter with a customized symbol table for safe evaluation
    of tool text parameters.

    Parameters
    ----------
    load_scipy : bool
        If True, expose the `scipy.stats` distribution instances as
        `scipy_stats_<name>`.
    load_numpy : bool
        If True, expose a fixed list of `numpy.random` members as
        `np_random_<name>`.
    load_estimators : bool
        If True, expose selected sklearn, skrebate, xgboost and imblearn
        modules / classes under underscore-joined names.
    """
    def __init__(self, load_scipy=False, load_numpy=False,
                 load_estimators=False):

        # File opening and other unneeded functions could be dropped
        unwanted = ['open', 'type', 'dir', 'id', 'str', 'repr']

        # Allowed symbol table. Add more if needed.
        new_syms = {
            'np_arange': getattr(np, 'arange'),
            'ensemble_ExtraTreesClassifier':
                getattr(ensemble, 'ExtraTreesClassifier')
        }

        syms = make_symbol_table(use_numpy=False, **new_syms)

        if load_scipy:
            # import the submodule explicitly instead of relying on
            # another package having loaded `scipy.stats` already
            from scipy import stats as scipy_stats
            scipy_distributions = scipy_stats.distributions.__dict__
            for k, v in scipy_distributions.items():
                if isinstance(v, (scipy_stats.rv_continuous,
                                  scipy_stats.rv_discrete)):
                    syms['scipy_stats_' + k] = v

        if load_numpy:
            from_numpy_random = [
                'beta', 'binomial', 'bytes', 'chisquare', 'choice',
                'dirichlet', 'division', 'exponential', 'f', 'gamma',
                'geometric', 'gumbel', 'hypergeometric', 'laplace',
                'logistic', 'lognormal', 'logseries', 'mtrand',
                'multinomial', 'multivariate_normal', 'negative_binomial',
                'noncentral_chisquare', 'noncentral_f', 'normal', 'pareto',
                'permutation', 'poisson', 'power', 'rand', 'randint',
                'randn', 'random', 'random_integers', 'random_sample',
                'ranf', 'rayleigh', 'sample', 'seed', 'set_state',
                'shuffle', 'standard_cauchy', 'standard_exponential',
                'standard_gamma', 'standard_normal', 'standard_t',
                'triangular', 'uniform', 'vonmises', 'wald', 'weibull', 'zipf']
            for f in from_numpy_random:
                # Some of these names (e.g. 'division') are not
                # attributes of `numpy.random` in every NumPy release;
                # skip the missing ones instead of failing.
                member = getattr(np.random, f, None)
                if member is not None:
                    syms['np_random_' + f] = member

        if load_estimators:
            estimator_table = {
                'sklearn_svm': getattr(sklearn, 'svm'),
                'sklearn_tree': getattr(sklearn, 'tree'),
                'sklearn_ensemble': getattr(sklearn, 'ensemble'),
                'sklearn_neighbors': getattr(sklearn, 'neighbors'),
                'sklearn_naive_bayes': getattr(sklearn, 'naive_bayes'),
                'sklearn_linear_model': getattr(sklearn, 'linear_model'),
                'sklearn_cluster': getattr(sklearn, 'cluster'),
                'sklearn_decomposition': getattr(sklearn, 'decomposition'),
                'sklearn_preprocessing': getattr(sklearn, 'preprocessing'),
                'sklearn_feature_selection':
                    getattr(sklearn, 'feature_selection'),
                'sklearn_kernel_approximation':
                    getattr(sklearn, 'kernel_approximation'),
                'skrebate_ReliefF': getattr(skrebate, 'ReliefF'),
                'skrebate_SURF': getattr(skrebate, 'SURF'),
                'skrebate_SURFstar': getattr(skrebate, 'SURFstar'),
                'skrebate_MultiSURF': getattr(skrebate, 'MultiSURF'),
                'skrebate_MultiSURFstar': getattr(skrebate, 'MultiSURFstar'),
                'skrebate_TuRF': getattr(skrebate, 'TuRF'),
                'xgboost_XGBClassifier': getattr(xgboost, 'XGBClassifier'),
                'xgboost_XGBRegressor': getattr(xgboost, 'XGBRegressor'),
                'imblearn_over_sampling': getattr(imblearn, 'over_sampling'),
                'imblearn_combine': getattr(imblearn, 'combine')
            }
            syms.update(estimator_table)

        # drop the unwanted builtins last, so nothing above re-adds them
        for key in unwanted:
            syms.pop(key, None)

        super(SafeEval, self).__init__(
            symtable=syms, use_numpy=False, minimal=False,
            no_if=True, no_for=True, no_while=True, no_try=True,
            no_functiondef=True, no_ifexp=True, no_listcomp=False,
            no_augassign=False, no_assert=True, no_delete=True,
            no_raise=True, no_print=True)
395 | |
396 | |
def get_estimator(estimator_json):
    """Return a sklearn or compatible estimator from Galaxy tool inputs.

    Parameters
    ----------
    estimator_json : dict
        Tool parameters. `'selected_module'` decides whether the
        estimator is loaded from a pickled file ('custom_estimator'),
        wraps a pickled estimator ('binarize_target'), or is
        instantiated from xgboost / a sklearn submodule.

    Returns
    -------
    The estimator object.
    """
    module_name = estimator_json['selected_module']

    # a user supplied, pickled estimator is returned as loaded
    if module_name == 'custom_estimator':
        with open(estimator_json['c_estimator'], 'rb') as model_handler:
            return load_model(model_handler)

    if module_name == "binarize_target":
        with open(estimator_json['wrapped_estimator'], 'rb') as model_handler:
            wrapped_estimator = load_model(model_handler)
        options = {}
        for key in ('z_score', 'value'):
            if estimator_json[key] is not None:
                options[key] = estimator_json[key]
        options['less_is_positive'] = estimator_json['less_is_positive']
        if estimator_json['clf_or_regr'] == 'BinarizeTargetClassifier':
            wrapper_name = 'BinarizeTargetClassifier'
        else:
            wrapper_name = 'BinarizeTargetRegressor'
        wrapper = try_get_attr('iraps_classifier', wrapper_name)
        return wrapper(wrapped_estimator, **options)

    class_name = estimator_json['selected_estimator']
    if module_name == 'xgboost':
        namespace = xgboost
    else:
        namespace = getattr(sklearn, module_name)
    estimator = getattr(namespace, class_name)()

    # optional "key=value, ..." text, evaluated as the arguments of dict()
    text_params = estimator_json['text_params'].strip()
    if text_params:
        try:
            evaluator = SafeEval()
            params = evaluator('dict(' + text_params + ')')
        except ValueError:
            sys.exit("Unsupported parameter input: `%s`" % text_params)
        estimator.set_params(**params)

    # the job count always comes from the Galaxy environment
    if 'n_jobs' in estimator.get_params():
        estimator.set_params(n_jobs=N_JOBS)

    return estimator
448 | |
449 | |
def get_cv(cv_json):
    """ Return CV splitter from Galaxy tool inputs

    The input dict is not modified; a shallow copy is consumed instead,
    so the same parameters can be passed in again.

    Parameters
    ----------
    cv_json : dict
        From Galaxy tool inputs.
        e.g.:
            {
                'selected_cv': 'StratifiedKFold',
                'n_splits': 3,
                'shuffle': True,
                'random_state': 0
            }

    Returns
    -------
    Tuple `(splitter, groups)`. For `'selected_cv': 'default'` the
    first item is the plain `n_splits` value and `groups` is None.
    Otherwise `groups` is the flattened groups column, or None when no
    `'groups_selector'` is given.
    """
    # work on a copy: the keys popped / rewritten below must not
    # disappear from the caller's dict
    cv_json = dict(cv_json)

    cv = cv_json.pop('selected_cv')
    if cv == 'default':
        return cv_json['n_splits'], None

    groups = cv_json.pop('groups_selector', None)
    if groups is not None:
        infile_g = groups['infile_g']
        header = 'infer' if groups['header_g'] else None
        column_option = (groups['column_selector_options_g']
                         ['selected_column_selector_option_g'])
        if column_option in ['by_index_number', 'all_but_by_index_number',
                             'by_header_name', 'all_but_by_header_name']:
            c = groups['column_selector_options_g']['col_g']
        else:
            c = None
        groups = read_columns(
            infile_g,
            c=c,
            c_option=column_option,
            sep='\t',
            header=header,
            parse_dates=True)
        groups = groups.ravel()

    # empty text fields mean "not set"
    for k, v in cv_json.items():
        if v == '':
            cv_json[k] = None

    test_fold = cv_json.get('test_fold', None)
    if test_fold:
        # strip the '__ob__' / '__cb__' markers around the list
        # (presumably escaped brackets from the tool form -- TODO confirm)
        if test_fold.startswith('__ob__'):
            test_fold = test_fold[6:]
        if test_fold.endswith('__cb__'):
            test_fold = test_fold[:-6]
        cv_json['test_fold'] = [int(x.strip()) for x in test_fold.split(',')]

    # a test size above 1 is an absolute sample count
    test_size = cv_json.get('test_size', None)
    if test_size and test_size > 1.0:
        cv_json['test_size'] = int(test_size)

    if cv == 'OrderedKFold':
        cv_class = try_get_attr('model_validations', 'OrderedKFold')
    elif cv == 'RepeatedOrderedKFold':
        cv_class = try_get_attr('model_validations', 'RepeatedOrderedKFold')
    else:
        cv_class = getattr(model_selection, cv)
    splitter = cv_class(**cv_json)

    return splitter, groups
514 | |
515 | |
516 # needed when sklearn < v0.20 | |
def balanced_accuracy_score(y_true, y_pred):
    """Compute the balanced accuracy score: the mean, over the classes,
    of the confusion matrix diagonal divided by its row sums.

    Rows that sum to zero give NaN; they are dropped with a warning
    before averaging. The same metric ships with scikit-learn from
    v0.20.0.
    """
    confusion = metrics.confusion_matrix(y_true, y_pred)
    # 0 / 0 is expected for empty rows, so silence the numpy warnings
    with np.errstate(divide='ignore', invalid='ignore'):
        class_scores = np.diag(confusion) / confusion.sum(axis=1)
    undefined = np.isnan(class_scores)
    if np.any(undefined):
        warnings.warn('y_pred contains classes not in y_true')
        class_scores = class_scores[~undefined]
    return np.mean(class_scores)
529 | |
530 | |
def get_scoring(scoring_json):
    """Return a single sklearn scorer, or a dict of scorers keyed by
    name when a different secondary scoring is requested.

    Returns None when the primary scoring is 'default'.
    """
    primary = scoring_json['primary_scoring']
    if primary == 'default':
        return None

    # NOTE: this is the `metrics.SCORERS` object itself, not a copy, so
    # the entries added below stay registered on it.
    my_scorers = metrics.SCORERS
    for custom_name in ('binarize_auc_scorer',
                        'binarize_average_precision_scorer'):
        my_scorers[custom_name] = try_get_attr('iraps_classifier',
                                               custom_name)
    if 'balanced_accuracy' not in my_scorers:
        my_scorers['balanced_accuracy'] =\
            metrics.make_scorer(balanced_accuracy_score)

    secondary = scoring_json['secondary_scoring']
    if secondary == 'None' or secondary == primary:
        return my_scorers[primary]

    # multi-metric case: primary scorer first, then the other ones
    return_scoring = {primary: my_scorers[primary]}
    for scorer in secondary.split(','):
        if scorer != primary:
            return_scoring[scorer] = my_scorers[scorer]
    return return_scoring
559 | |
560 | |
def get_search_params(estimator):
    """Format the output of `estimator.get_params()` as table rows.

    Each row is `[flag, name, "name: repr(value)"]`. The flag is '*'
    for parameters whose name ends with one of the excluded keywords and
    '@' for all others. A closing note row explains the '@' flag.
    """
    # params below won't be shown for search in the searchcv tool
    excluded_suffixes = ('n_jobs', 'pre_dispatch', 'memory', 'steps',
                         'nthread', 'verbose')
    results = [
        ['*' if key.endswith(excluded_suffixes) else '@',
         key,
         key + ": " + repr(value)]
        for key, value in estimator.get_params().items()
    ]
    results.append(
        ["", "Note:",
         "@, params eligible for search in searchcv tool."])

    return results
579 | |
580 | |
def try_get_attr(module, name):
    """Fetch an attribute from an already imported custom module.

    The module is looked up in `sys.modules`; it is never imported here.

    Parameters
    ----------
    module : str
        Module name
    name : str
        Attribute (class/function) name.

    Returns
    -------
    class or function

    Raises
    ------
    Exception
        If `module` has no usable entry in `sys.modules`.
    """
    loaded = sys.modules.get(module, None)
    if not loaded:
        raise Exception("No module named %s." % module)
    return getattr(loaded, name)