Mercurial > repos > bgruening > sklearn_nn_classifier
diff utils.py @ 10:e9ba818e7877 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
author | bgruening |
---|---|
date | Tue, 14 May 2019 18:09:43 -0400 |
parents | ed7b1654e841 |
children |
line wrap: on
line diff
--- a/utils.py Sun Dec 30 01:54:35 2018 -0500 +++ b/utils.py Tue May 14 18:09:43 2019 -0400 @@ -1,80 +1,134 @@ +import ast import json +import imblearn import numpy as np -import os import pandas import pickle import re import scipy import sklearn +import skrebate import sys import warnings import xgboost +from collections import Counter from asteval import Interpreter, make_symbol_table -from sklearn import (cluster, compose, decomposition, ensemble, feature_extraction, - feature_selection, gaussian_process, kernel_approximation, metrics, - model_selection, naive_bayes, neighbors, pipeline, preprocessing, - svm, linear_model, tree, discriminant_analysis) +from imblearn import under_sampling, over_sampling, combine +from imblearn.pipeline import Pipeline as imbPipeline +from mlxtend import regressor, classifier +from scipy.io import mmread +from sklearn import ( + cluster, compose, decomposition, ensemble, feature_extraction, + feature_selection, gaussian_process, kernel_approximation, metrics, + model_selection, naive_bayes, neighbors, pipeline, preprocessing, + svm, linear_model, tree, discriminant_analysis) + +try: + import iraps_classifier +except ImportError: + pass try: - import skrebate -except ModuleNotFoundError: + import model_validations +except ImportError: + pass + +try: + import feature_selectors +except ImportError: pass - -N_JOBS = int(os.environ.get('GALAXY_SLOTS', 1)) +try: + import preprocessors +except ImportError: + pass -try: - sk_whitelist -except NameError: - sk_whitelist = None +# handle pickle white list file +WL_FILE = __import__('os').path.join( + __import__('os').path.dirname(__file__), 'pk_whitelist.json') + +N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1)) -class SafePickler(pickle.Unpickler): +class _SafePickler(pickle.Unpickler, object): """ - Used to safely deserialize scikit-learn model objects serialized by cPickle.dump + Used to safely deserialize scikit-learn model objects Usage: - eg.: SafePickler.load(pickled_file_object) + eg.: _SafePickler.load(pickled_file_object) """ - def find_class(self, module, name): + def __init__(self, file): + super(_SafePickler, self).__init__(file) + # load global white list + with open(WL_FILE, 'r') as f: + self.pk_whitelist = json.load(f) - # sk_whitelist could be read from tool - global sk_whitelist - if not sk_whitelist: - whitelist_file = os.path.join(os.path.dirname(__file__), 'sk_whitelist.json') - with open(whitelist_file, 'r') as f: - sk_whitelist = json.load(f) + self.bad_names = ( + 'and', 'as', 'assert', 'break', 'class', 'continue', + 'def', 'del', 'elif', 'else', 'except', 'exec', + 'finally', 'for', 'from', 'global', 'if', 'import', + 'in', 'is', 'lambda', 'not', 'or', 'pass', 'print', + 'raise', 'return', 'try', 'system', 'while', 'with', + 'True', 'False', 'None', 'eval', 'execfile', '__import__', + '__package__', '__subclasses__', '__bases__', '__globals__', + '__code__', '__closure__', '__func__', '__self__', '__module__', + '__dict__', '__class__', '__call__', '__get__', + '__getattribute__', '__subclasshook__', '__new__', + '__init__', 'func_globals', 'func_code', 'func_closure', + 'im_class', 'im_func', 'im_self', 'gi_code', 'gi_frame', + '__asteval__', 'f_locals', '__mro__') - bad_names = ('and', 'as', 'assert', 'break', 'class', 'continue', - 'def', 'del', 'elif', 'else', 'except', 'exec', - 'finally', 'for', 'from', 'global', 'if', 'import', - 'in', 'is', 'lambda', 'not', 'or', 'pass', 'print', - 'raise', 'return', 'try', 'system', 'while', 'with', - 'True', 'False', 'None', 'eval', 'execfile', '__import__', - '__package__', '__subclasses__', '__bases__', '__globals__', - '__code__', '__closure__', '__func__', '__self__', '__module__', - '__dict__', '__class__', '__call__', '__get__', - '__getattribute__', '__subclasshook__', '__new__', - '__init__', 'func_globals', 'func_code', 'func_closure', - 'im_class', 'im_func', 'im_self', 'gi_code', 'gi_frame', - '__asteval__', 'f_locals', '__mro__') - good_names = ['copy_reg._reconstructor', '__builtin__.object'] + # unclassified good globals + self.good_names = [ + 'copy_reg._reconstructor', '__builtin__.object', + '__builtin__.bytearray', 'builtins.object', + 'builtins.bytearray', 'keras.engine.sequential.Sequential', + 'keras.engine.sequential.Model'] + + # custom module in Galaxy-ML + self.custom_modules = [ + '__main__', 'keras_galaxy_models', 'feature_selectors', + 'preprocessors', 'iraps_classifier', 'model_validations'] + # override + def find_class(self, module, name): + # balack list first + if name in self.bad_names: + raise pickle.UnpicklingError("global '%s.%s' is forbidden" + % (module, name)) + + # custom module in Galaxy-ML + if module in self.custom_modules: + cutom_module = sys.modules.get(module, None) + if cutom_module: + return getattr(cutom_module, name) + else: + raise pickle.UnpicklingError("Module %s' is not imported" + % module) + + # For objects from outside libraries, it's necessary to verify + # both module and name. Currently only a blacklist checker + # is working. + # TODO: replace with a whitelist checker. + good_names = self.good_names + pk_whitelist = self.pk_whitelist if re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name): fullname = module + '.' + name if (fullname in good_names)\ - or ( ( module.startswith('sklearn.') - or module.startswith('xgboost.') - or module.startswith('skrebate.') - or module.startswith('imblearn') - or module.startswith('numpy.') - or module == 'numpy' - ) - and (name not in bad_names) - ): - # TODO: replace with a whitelist checker - if fullname not in sk_whitelist['SK_NAMES'] + sk_whitelist['SKR_NAMES'] + sk_whitelist['XGB_NAMES'] + sk_whitelist['NUMPY_NAMES'] + sk_whitelist['IMBLEARN_NAMES'] + good_names: - print("Warning: global %s is not in pickler whitelist yet and will loss support soon. Contact tool author or leave a message at github.com" % fullname) + or (module.startswith(('sklearn.', 'xgboost.', 'skrebate.', + 'imblearn.', 'mlxtend.', 'numpy.')) + or module == 'numpy'): + if fullname not in (pk_whitelist['SK_NAMES'] + + pk_whitelist['SKR_NAMES'] + + pk_whitelist['XGB_NAMES'] + + pk_whitelist['NUMPY_NAMES'] + + pk_whitelist['IMBLEARN_NAMES'] + + pk_whitelist['MLXTEND_NAMES'] + + good_names): + # raise pickle.UnpicklingError + print("Warning: global %s is not in pickler whitelist " + "yet and will loss support soon. Contact tool " + "author or leave a message at github.com" % fullname) mod = sys.modules[module] return getattr(mod, name) @@ -82,10 +136,15 @@ def load_model(file): - return SafePickler(file).load() + """Load pickled object with `_SafePicker` + """ + return _SafePickler(file).load() -def read_columns(f, c=None, c_option='by_index_number', return_df=False, **args): +def read_columns(f, c=None, c_option='by_index_number', + return_df=False, **args): + """Return array from a tabular dataset by various columns selection + """ data = pandas.read_csv(f, **args) if c_option == 'by_index_number': cols = list(map(lambda x: x - 1, c)) @@ -106,10 +165,21 @@ return y -## generate an instance for one of sklearn.feature_selection classes -def feature_selector(inputs): +def feature_selector(inputs, X=None, y=None): + """generate an instance of sklearn.feature_selection classes + + Parameters + ---------- + inputs : dict + From galaxy tool parameters. + X : array + Containing training features. + y : array or list + Target values. + """ selector = inputs['selected_algorithm'] - selector = getattr(sklearn.feature_selection, selector) + if selector != 'DyRFECV': + selector = getattr(sklearn.feature_selection, selector) options = inputs['options'] if inputs['selected_algorithm'] == 'SelectFromModel': @@ -128,27 +198,60 @@ else: estimator_json = inputs['model_inputter']['estimator_selector'] estimator = get_estimator(estimator_json) + check_feature_importances = try_get_attr( + 'feature_selectors', 'check_feature_importances') + estimator = check_feature_importances(estimator) new_selector = selector(estimator, **options) elif inputs['selected_algorithm'] == 'RFE': - estimator = get_estimator(inputs['estimator_selector']) step = options.get('step', None) if step and step >= 1.0: options['step'] = int(step) + estimator = get_estimator(inputs["estimator_selector"]) + check_feature_importances = try_get_attr( + 'feature_selectors', 'check_feature_importances') + estimator = check_feature_importances(estimator) new_selector = selector(estimator, **options) elif inputs['selected_algorithm'] == 'RFECV': options['scoring'] = get_scoring(options['scoring']) options['n_jobs'] = N_JOBS splitter, groups = get_cv(options.pop('cv_selector')) - # TODO support group cv splitters - options['cv'] = splitter + if groups is None: + options['cv'] = splitter + else: + options['cv'] = list(splitter.split(X, y, groups=groups)) step = options.get('step', None) if step and step >= 1.0: options['step'] = int(step) estimator = get_estimator(inputs['estimator_selector']) + check_feature_importances = try_get_attr( + 'feature_selectors', 'check_feature_importances') + estimator = check_feature_importances(estimator) new_selector = selector(estimator, **options) + elif inputs['selected_algorithm'] == 'DyRFECV': + options['scoring'] = get_scoring(options['scoring']) + options['n_jobs'] = N_JOBS + splitter, groups = get_cv(options.pop('cv_selector')) + if groups is None: + options['cv'] = splitter + else: + options['cv'] = list(splitter.split(X, y, groups=groups)) + step = options.get('step') + if not step or step == 'None': + step = None + else: + step = ast.literal_eval(step) + options['step'] = step + estimator = get_estimator(inputs["estimator_selector"]) + check_feature_importances = try_get_attr( + 'feature_selectors', 'check_feature_importances') + estimator = check_feature_importances(estimator) + DyRFECV = try_get_attr('feature_selectors', 'DyRFECV') + + new_selector = DyRFECV(estimator, **options) + elif inputs['selected_algorithm'] == 'VarianceThreshold': new_selector = selector(**options) @@ -161,12 +264,20 @@ def get_X_y(params, file1, file2): - input_type = params['selected_tasks']['selected_algorithms']['input_options']['selected_input'] + """Return machine learning inputs X, y from tabluar inputs + """ + input_type = (params['selected_tasks']['selected_algorithms'] + ['input_options']['selected_input']) if input_type == 'tabular': - header = 'infer' if params['selected_tasks']['selected_algorithms']['input_options']['header1'] else None - column_option = params['selected_tasks']['selected_algorithms']['input_options']['column_selector_options_1']['selected_column_selector_option'] - if column_option in ['by_index_number', 'all_but_by_index_number', 'by_header_name', 'all_but_by_header_name']: - c = params['selected_tasks']['selected_algorithms']['input_options']['column_selector_options_1']['col1'] + header = 'infer' if (params['selected_tasks']['selected_algorithms'] + ['input_options']['header1']) else None + column_option = (params['selected_tasks']['selected_algorithms'] + ['input_options']['column_selector_options_1'] + ['selected_column_selector_option']) + if column_option in ['by_index_number', 'all_but_by_index_number', + 'by_header_name', 'all_but_by_header_name']: + c = (params['selected_tasks']['selected_algorithms'] + ['input_options']['column_selector_options_1']['col1']) else: c = None X = read_columns( @@ -175,15 +286,19 @@ c_option=column_option, sep='\t', header=header, - parse_dates=True - ) + parse_dates=True).astype(float) else: X = mmread(file1) - header = 'infer' if params['selected_tasks']['selected_algorithms']['input_options']['header2'] else None - column_option = params['selected_tasks']['selected_algorithms']['input_options']['column_selector_options_2']['selected_column_selector_option2'] - if column_option in ['by_index_number', 'all_but_by_index_number', 'by_header_name', 'all_but_by_header_name']: - c = params['selected_tasks']['selected_algorithms']['input_options']['column_selector_options_2']['col2'] + header = 'infer' if (params['selected_tasks']['selected_algorithms'] + ['input_options']['header2']) else None + column_option = (params['selected_tasks']['selected_algorithms'] + ['input_options']['column_selector_options_2'] + ['selected_column_selector_option2']) + if column_option in ['by_index_number', 'all_but_by_index_number', + 'by_header_name', 'all_but_by_header_name']: + c = (params['selected_tasks']['selected_algorithms'] + ['input_options']['column_selector_options_2']['col2']) else: c = None y = read_columns( @@ -192,15 +307,17 @@ c_option=column_option, sep='\t', header=header, - parse_dates=True - ) + parse_dates=True) y = y.ravel() + return X, y class SafeEval(Interpreter): - - def __init__(self, load_scipy=False, load_numpy=False, load_estimators=False): + """Customized symbol table for safely literal eval + """ + def __init__(self, load_scipy=False, load_numpy=False, + load_estimators=False): # File opening and other unneeded functions could be dropped unwanted = ['open', 'type', 'dir', 'id', 'str', 'repr'] @@ -208,7 +325,8 @@ # Allowed symbol table. Add more if needed. new_syms = { 'np_arange': getattr(np, 'arange'), - 'ensemble_ExtraTreesClassifier': getattr(ensemble, 'ExtraTreesClassifier') + 'ensemble_ExtraTreesClassifier': + getattr(ensemble, 'ExtraTreesClassifier') } syms = make_symbol_table(use_numpy=False, **new_syms) @@ -216,80 +334,109 @@ if load_scipy: scipy_distributions = scipy.stats.distributions.__dict__ for k, v in scipy_distributions.items(): - if isinstance(v, (scipy.stats.rv_continuous, scipy.stats.rv_discrete)): + if isinstance(v, (scipy.stats.rv_continuous, + scipy.stats.rv_discrete)): syms['scipy_stats_' + k] = v if load_numpy: - from_numpy_random = ['beta', 'binomial', 'bytes', 'chisquare', 'choice', 'dirichlet', 'division', - 'exponential', 'f', 'gamma', 'geometric', 'gumbel', 'hypergeometric', - 'laplace', 'logistic', 'lognormal', 'logseries', 'mtrand', 'multinomial', - 'multivariate_normal', 'negative_binomial', 'noncentral_chisquare', 'noncentral_f', - 'normal', 'pareto', 'permutation', 'poisson', 'power', 'rand', 'randint', - 'randn', 'random', 'random_integers', 'random_sample', 'ranf', 'rayleigh', - 'sample', 'seed', 'set_state', 'shuffle', 'standard_cauchy', 'standard_exponential', - 'standard_gamma', 'standard_normal', 'standard_t', 'triangular', 'uniform', - 'vonmises', 'wald', 'weibull', 'zipf'] + from_numpy_random = [ + 'beta', 'binomial', 'bytes', 'chisquare', 'choice', + 'dirichlet', 'division', 'exponential', 'f', 'gamma', + 'geometric', 'gumbel', 'hypergeometric', 'laplace', + 'logistic', 'lognormal', 'logseries', 'mtrand', + 'multinomial', 'multivariate_normal', 'negative_binomial', + 'noncentral_chisquare', 'noncentral_f', 'normal', 'pareto', + 'permutation', 'poisson', 'power', 'rand', 'randint', + 'randn', 'random', 'random_integers', 'random_sample', + 'ranf', 'rayleigh', 'sample', 'seed', 'set_state', + 'shuffle', 'standard_cauchy', 'standard_exponential', + 'standard_gamma', 'standard_normal', 'standard_t', + 'triangular', 'uniform', 'vonmises', 'wald', 'weibull', 'zipf'] for f in from_numpy_random: syms['np_random_' + f] = getattr(np.random, f) if load_estimators: estimator_table = { - 'sklearn_svm' : getattr(sklearn, 'svm'), - 'sklearn_tree' : getattr(sklearn, 'tree'), - 'sklearn_ensemble' : getattr(sklearn, 'ensemble'), - 'sklearn_neighbors' : getattr(sklearn, 'neighbors'), - 'sklearn_naive_bayes' : getattr(sklearn, 'naive_bayes'), - 'sklearn_linear_model' : getattr(sklearn, 'linear_model'), - 'sklearn_cluster' : getattr(sklearn, 'cluster'), - 'sklearn_decomposition' : getattr(sklearn, 'decomposition'), - 'sklearn_preprocessing' : getattr(sklearn, 'preprocessing'), - 'sklearn_feature_selection' : getattr(sklearn, 'feature_selection'), - 'sklearn_kernel_approximation' : getattr(sklearn, 'kernel_approximation'), + 'sklearn_svm': getattr(sklearn, 'svm'), + 'sklearn_tree': getattr(sklearn, 'tree'), + 'sklearn_ensemble': getattr(sklearn, 'ensemble'), + 'sklearn_neighbors': getattr(sklearn, 'neighbors'), + 'sklearn_naive_bayes': getattr(sklearn, 'naive_bayes'), + 'sklearn_linear_model': getattr(sklearn, 'linear_model'), + 'sklearn_cluster': getattr(sklearn, 'cluster'), + 'sklearn_decomposition': getattr(sklearn, 'decomposition'), + 'sklearn_preprocessing': getattr(sklearn, 'preprocessing'), + 'sklearn_feature_selection': + getattr(sklearn, 'feature_selection'), + 'sklearn_kernel_approximation': + getattr(sklearn, 'kernel_approximation'), 'skrebate_ReliefF': getattr(skrebate, 'ReliefF'), 'skrebate_SURF': getattr(skrebate, 'SURF'), 'skrebate_SURFstar': getattr(skrebate, 'SURFstar'), 'skrebate_MultiSURF': getattr(skrebate, 'MultiSURF'), 'skrebate_MultiSURFstar': getattr(skrebate, 'MultiSURFstar'), 'skrebate_TuRF': getattr(skrebate, 'TuRF'), - 'xgboost_XGBClassifier' : getattr(xgboost, 'XGBClassifier'), - 'xgboost_XGBRegressor' : getattr(xgboost, 'XGBRegressor') + 'xgboost_XGBClassifier': getattr(xgboost, 'XGBClassifier'), + 'xgboost_XGBRegressor': getattr(xgboost, 'XGBRegressor'), + 'imblearn_over_sampling': getattr(imblearn, 'over_sampling'), + 'imblearn_combine': getattr(imblearn, 'combine') } syms.update(estimator_table) for key in unwanted: syms.pop(key, None) - super(SafeEval, self).__init__(symtable=syms, use_numpy=False, minimal=False, - no_if=True, no_for=True, no_while=True, no_try=True, - no_functiondef=True, no_ifexp=True, no_listcomp=False, - no_augassign=False, no_assert=True, no_delete=True, - no_raise=True, no_print=True) - + super(SafeEval, self).__init__( + symtable=syms, use_numpy=False, minimal=False, + no_if=True, no_for=True, no_while=True, no_try=True, + no_functiondef=True, no_ifexp=True, no_listcomp=False, + no_augassign=False, no_assert=True, no_delete=True, + no_raise=True, no_print=True) def get_estimator(estimator_json): - + """Return a sklearn or compatible estimator from Galaxy tool inputs + """ estimator_module = estimator_json['selected_module'] - if estimator_module == 'customer_estimator': + if estimator_module == 'custom_estimator': c_estimator = estimator_json['c_estimator'] with open(c_estimator, 'rb') as model_handler: new_model = load_model(model_handler) return new_model + if estimator_module == "binarize_target": + wrapped_estimator = estimator_json['wrapped_estimator'] + with open(wrapped_estimator, 'rb') as model_handler: + wrapped_estimator = load_model(model_handler) + options = {} + if estimator_json['z_score'] is not None: + options['z_score'] = estimator_json['z_score'] + if estimator_json['value'] is not None: + options['value'] = estimator_json['value'] + options['less_is_positive'] = estimator_json['less_is_positive'] + if estimator_json['clf_or_regr'] == 'BinarizeTargetClassifier': + klass = try_get_attr('iraps_classifier', + 'BinarizeTargetClassifier') + else: + klass = try_get_attr('iraps_classifier', + 'BinarizeTargetRegressor') + return klass(wrapped_estimator, **options) + estimator_cls = estimator_json['selected_estimator'] if estimator_module == 'xgboost': - cls = getattr(xgboost, estimator_cls) + klass = getattr(xgboost, estimator_cls) else: module = getattr(sklearn, estimator_module) - cls = getattr(module, estimator_cls) + klass = getattr(module, estimator_cls) - estimator = cls() + estimator = klass() estimator_params = estimator_json['text_params'].strip() if estimator_params != '': try: + safe_eval = SafeEval() params = safe_eval('dict(' + estimator_params + ')') except ValueError: sys.exit("Unsupported parameter input: `%s`" % estimator_params) @@ -301,9 +448,13 @@ def get_cv(cv_json): - """ - cv_json: - e.g.: + """ Return CV splitter from Galaxy tool inputs + + Parameters + ---------- + cv_json : dict + From Galaxy tool inputs. + e.g.: { 'selected_cv': 'StratifiedKFold', 'n_splits': 3, @@ -315,15 +466,25 @@ if cv == 'default': return cv_json['n_splits'], None - groups = cv_json.pop('groups', None) - if groups: - groups = groups.strip() - if groups != '': - if groups.startswith('__ob__'): - groups = groups[6:] - if groups.endswith('__cb__'): - groups = groups[:-6] - groups = [int(x.strip()) for x in groups.split(',')] + groups = cv_json.pop('groups_selector', None) + if groups is not None: + infile_g = groups['infile_g'] + header = 'infer' if groups['header_g'] else None + column_option = (groups['column_selector_options_g'] + ['selected_column_selector_option_g']) + if column_option in ['by_index_number', 'all_but_by_index_number', + 'by_header_name', 'all_but_by_header_name']: + c = groups['column_selector_options_g']['col_g'] + else: + c = None + groups = read_columns( + infile_g, + c=c, + c_option=column_option, + sep='\t', + header=header, + parse_dates=True) + groups = groups.ravel() for k, v in cv_json.items(): if v == '': @@ -341,7 +502,12 @@ if test_size and test_size > 1.0: cv_json['test_size'] = int(test_size) - cv_class = getattr(model_selection, cv) + if cv == 'OrderedKFold': + cv_class = try_get_attr('model_validations', 'OrderedKFold') + elif cv == 'RepeatedOrderedKFold': + cv_class = try_get_attr('model_validations', 'RepeatedOrderedKFold') + else: + cv_class = getattr(model_selection, cv) splitter = cv_class(**cv_json) return splitter, groups @@ -349,6 +515,9 @@ # needed when sklearn < v0.20 def balanced_accuracy_score(y_true, y_pred): + """Compute balanced accuracy score, which is now available in + scikit-learn from v0.20.0. + """ C = metrics.confusion_matrix(y_true, y_pred) with np.errstate(divide='ignore', invalid='ignore'): per_class = np.diag(C) / C.sum(axis=1) @@ -360,21 +529,71 @@ def get_scoring(scoring_json): - + """Return single sklearn scorer class + or multiple scoers in dictionary + """ if scoring_json['primary_scoring'] == 'default': return None my_scorers = metrics.SCORERS + my_scorers['binarize_auc_scorer'] =\ + try_get_attr('iraps_classifier', 'binarize_auc_scorer') + my_scorers['binarize_average_precision_scorer'] =\ + try_get_attr('iraps_classifier', 'binarize_average_precision_scorer') if 'balanced_accuracy' not in my_scorers: - my_scorers['balanced_accuracy'] = metrics.make_scorer(balanced_accuracy_score) + my_scorers['balanced_accuracy'] =\ + metrics.make_scorer(balanced_accuracy_score) if scoring_json['secondary_scoring'] != 'None'\ - and scoring_json['secondary_scoring'] != scoring_json['primary_scoring']: - scoring = {} - scoring['primary'] = my_scorers[scoring_json['primary_scoring']] + and scoring_json['secondary_scoring'] !=\ + scoring_json['primary_scoring']: + return_scoring = {} + primary_scoring = scoring_json['primary_scoring'] + return_scoring[primary_scoring] = my_scorers[primary_scoring] for scorer in scoring_json['secondary_scoring'].split(','): if scorer != scoring_json['primary_scoring']: - scoring[scorer] = my_scorers[scorer] - return scoring + return_scoring[scorer] = my_scorers[scorer] + return return_scoring return my_scorers[scoring_json['primary_scoring']] + + +def get_search_params(estimator): + """Format the output of `estimator.get_params()` + """ + params = estimator.get_params() + results = [] + for k, v in params.items(): + # params below won't be shown for search in the searchcv tool + keywords = ('n_jobs', 'pre_dispatch', 'memory', 'steps', + 'nthread', 'verbose') + if k.endswith(keywords): + results.append(['*', k, k+": "+repr(v)]) + else: + results.append(['@', k, k+": "+repr(v)]) + results.append( + ["", "Note:", + "@, params eligible for search in searchcv tool."]) + + return results + + +def try_get_attr(module, name): + """try to get attribute from a custom module + + Parameters + ---------- + module : str + Module name + name : str + Attribute (class/function) name. + + Returns + ------- + class or function + """ + mod = sys.modules.get(module, None) + if mod: + return getattr(mod, name) + else: + raise Exception("No module named %s." % module)