Mercurial > repos > bgruening > sklearn_nn_classifier
diff utils.py @ 7:5072ac474cd5 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 2a058459e6daf0486871f93845f00fdb4a4eaca1
author | bgruening |
---|---|
date | Sat, 29 Sep 2018 07:29:02 -0400 |
parents | e972a913e61a |
children | ed7b1654e841 |
line wrap: on
line diff
--- a/utils.py Thu Aug 23 16:15:30 2018 -0400 +++ b/utils.py Sat Sep 29 07:29:02 2018 -0400 @@ -2,28 +2,27 @@ import os import pandas import re -import cPickle as pickle +import pickle import warnings import numpy as np import xgboost import scipy import sklearn -import ast from asteval import Interpreter, make_symbol_table from sklearn import (cluster, decomposition, ensemble, feature_extraction, feature_selection, - gaussian_process, kernel_approximation, linear_model, metrics, + gaussian_process, kernel_approximation, metrics, model_selection, naive_bayes, neighbors, pipeline, preprocessing, svm, linear_model, tree, discriminant_analysis) -N_JOBS = int( os.environ.get('GALAXY_SLOTS', 1) ) +N_JOBS = int(os.environ.get('GALAXY_SLOTS', 1)) -class SafePickler(object): + +class SafePickler(pickle.Unpickler): """ Used to safely deserialize scikit-learn model objects serialized by cPickle.dump Usage: eg.: SafePickler.load(pickled_file_object) """ - @classmethod def find_class(self, module, name): bad_names = ('and', 'as', 'assert', 'break', 'class', 'continue', @@ -39,11 +38,11 @@ '__init__', 'func_globals', 'func_code', 'func_closure', 'im_class', 'im_func', 'im_self', 'gi_code', 'gi_frame', '__asteval__', 'f_locals', '__mro__') - good_names = ('copy_reg._reconstructor', '__builtin__.object') + good_names = ['copy_reg._reconstructor', '__builtin__.object'] if re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name): fullname = module + '.' + name - if (fullname in good_names)\ + if (fullname in good_names)\ or ( ( module.startswith('sklearn.') or module.startswith('xgboost.') or module.startswith('skrebate.') @@ -51,26 +50,25 @@ or module == 'numpy' ) and (name not in bad_names) - ) : + ): # TODO: replace with a whitelist checker - if fullname not in SK_NAMES + SKR_NAMES + XGB_NAMES + NUMPY_NAMES + good_names: + if fullname not in sk_whitelist['SK_NAMES'] + sk_whitelist['SKR_NAMES'] + sk_whitelist['XGB_NAMES'] + sk_whitelist['NUMPY_NAMES'] + good_names: print("Warning: global %s is not in pickler whitelist yet and will loss support soon. Contact tool author or leave a message at github.com" % fullname) mod = sys.modules[module] return getattr(mod, name) raise pickle.UnpicklingError("global '%s' is forbidden" % fullname) - @classmethod - def load(self, file): - obj = pickle.Unpickler(file) - obj.find_global = self.find_class - return obj.load() + +def load_model(file): + return SafePickler(file).load() + def read_columns(f, c=None, c_option='by_index_number', return_df=False, **args): data = pandas.read_csv(f, **args) if c_option == 'by_index_number': cols = list(map(lambda x: x - 1, c)) - data = data.iloc[:,cols] + data = data.iloc[:, cols] if c_option == 'all_but_by_index_number': cols = list(map(lambda x: x - 1, c)) data.drop(data.columns[cols], axis=1, inplace=True) @@ -100,7 +98,7 @@ if inputs['model_inputter']['input_mode'] == 'prefitted': model_file = inputs['model_inputter']['fitted_estimator'] with open(model_file, 'rb') as model_handler: - fitted_estimator = SafePickler.load(model_handler) + fitted_estimator = load_model(model_handler) new_selector = selector(fitted_estimator, prefit=True, **options) else: estimator_json = inputs['model_inputter']["estimator_selector"] @@ -108,14 +106,14 @@ new_selector = selector(estimator, **options) elif inputs['selected_algorithm'] == 'RFE': - estimator=get_estimator(inputs["estimator_selector"]) + estimator = get_estimator(inputs["estimator_selector"]) new_selector = selector(estimator, **options) elif inputs['selected_algorithm'] == 'RFECV': options['scoring'] = get_scoring(options['scoring']) options['n_jobs'] = N_JOBS - options['cv'] = get_cv( options['cv'].strip() ) - estimator=get_estimator(inputs["estimator_selector"]) + options['cv'] = get_cv(options['cv'].strip()) + estimator = get_estimator(inputs["estimator_selector"]) new_selector = selector(estimator, **options) elif inputs['selected_algorithm'] == "VarianceThreshold": @@ -127,11 +125,11 @@ new_selector = selector(score_func, **options) return new_selector - + def get_X_y(params, file1, file2): input_type = params["selected_tasks"]["selected_algorithms"]["input_options"]["selected_input"] - if input_type=="tabular": + if input_type == "tabular": header = 'infer' if params["selected_tasks"]["selected_algorithms"]["input_options"]["header1"] else None column_option = params["selected_tasks"]["selected_algorithms"]["input_options"]["column_selector_options_1"]["selected_column_selector_option"] if column_option in ["by_index_number", "all_but_by_index_number", "by_header_name", "all_but_by_header_name"]: @@ -140,8 +138,8 @@ c = None X = read_columns( file1, - c = c, - c_option = column_option, + c=c, + c_option=column_option, sep='\t', header=header, parse_dates=True @@ -157,13 +155,13 @@ c = None y = read_columns( file2, - c = c, - c_option = column_option, + c=c, + c_option=column_option, sep='\t', header=header, parse_dates=True ) - y=y.ravel() + y = y.ravel() return X, y @@ -197,14 +195,14 @@ 'randn', 'random', 'random_integers', 'random_sample', 'ranf', 'rayleigh', 'sample', 'seed', 'set_state', 'shuffle', 'standard_cauchy', 'standard_exponential', 'standard_gamma', 'standard_normal', 'standard_t', 'triangular', 'uniform', - 'vonmises', 'wald', 'weibull', 'zipf' ] + 'vonmises', 'wald', 'weibull', 'zipf'] for f in from_numpy_random: syms['np_random_' + f] = getattr(np.random, f) for key in unwanted: syms.pop(key, None) - super(SafeEval, self).__init__( symtable=syms, use_numpy=False, minimal=False, + super(SafeEval, self).__init__(symtable=syms, use_numpy=False, minimal=False, no_if=True, no_for=True, no_while=True, no_try=True, no_functiondef=True, no_ifexp=True, no_listcomp=False, no_augassign=False, no_assert=True, no_delete=True, @@ -250,10 +248,10 @@ try: params = safe_eval('dict(' + estimator_params + ')') except ValueError: - sys.exit("Unsupported parameter input: `%s`" %estimator_params) + sys.exit("Unsupported parameter input: `%s`" % estimator_params) estimator.set_params(**params) if 'n_jobs' in estimator.get_params(): - estimator.set_params( n_jobs=N_JOBS ) + estimator.set_params(n_jobs=N_JOBS) return estimator @@ -266,10 +264,10 @@ return int(literal) m = re.match(r'^(?P<method>\w+)\((?P<args>.*)\)$', literal) if m: - my_class = getattr( model_selection, m.group('method') ) - args = safe_eval( 'dict('+ m.group('args') + ')' ) - return my_class( **args ) - sys.exit("Unsupported CV input: %s" %literal) + my_class = getattr(model_selection, m.group('method')) + args = safe_eval('dict('+ m.group('args') + ')') + return my_class(**args) + sys.exit("Unsupported CV input: %s" % literal) def get_scoring(scoring_json): @@ -293,11 +291,10 @@ if scoring_json['secondary_scoring'] != 'None'\ and scoring_json['secondary_scoring'] != scoring_json['primary_scoring']: scoring = {} - scoring['primary'] = my_scorers[ scoring_json['primary_scoring'] ] + scoring['primary'] = my_scorers[scoring_json['primary_scoring']] for scorer in scoring_json['secondary_scoring'].split(','): if scorer != scoring_json['primary_scoring']: scoring[scorer] = my_scorers[scorer] return scoring - return my_scorers[ scoring_json['primary_scoring'] ] - + return my_scorers[scoring_json['primary_scoring']]