Mercurial > repos > bgruening > sklearn_data_preprocess
diff utils.py @ 20:2bda387c73e4 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 8cf3d813ec755166ee0bd517b4ecbbd4f84d4df1
author | bgruening |
---|---|
date | Thu, 23 Aug 2018 16:16:54 -0400 |
parents | f196d4715cfb |
children | f156acc7239b |
line wrap: on
line diff
--- a/utils.py Fri Aug 17 12:28:58 2018 -0400 +++ b/utils.py Thu Aug 23 16:16:54 2018 -0400 @@ -2,7 +2,7 @@ import os import pandas import re -import pickle +import cPickle as pickle import warnings import numpy as np import xgboost @@ -10,10 +10,62 @@ import sklearn import ast from asteval import Interpreter, make_symbol_table -from sklearn import metrics, model_selection, ensemble, svm, linear_model, naive_bayes, tree, neighbors +from sklearn import (cluster, decomposition, ensemble, feature_extraction, feature_selection, + gaussian_process, kernel_approximation, linear_model, metrics, + model_selection, naive_bayes, neighbors, pipeline, preprocessing, + svm, linear_model, tree, discriminant_analysis) N_JOBS = int( os.environ.get('GALAXY_SLOTS', 1) ) +class SafePickler(object): + """ + Used to safely deserialize scikit-learn model objects serialized by cPickle.dump + Usage: + eg.: SafePickler.load(pickled_file_object) + """ + @classmethod + def find_class(self, module, name): + + bad_names = ('and', 'as', 'assert', 'break', 'class', 'continue', + 'def', 'del', 'elif', 'else', 'except', 'exec', + 'finally', 'for', 'from', 'global', 'if', 'import', + 'in', 'is', 'lambda', 'not', 'or', 'pass', 'print', + 'raise', 'return', 'try', 'system', 'while', 'with', + 'True', 'False', 'None', 'eval', 'execfile', '__import__', + '__package__', '__subclasses__', '__bases__', '__globals__', + '__code__', '__closure__', '__func__', '__self__', '__module__', + '__dict__', '__class__', '__call__', '__get__', + '__getattribute__', '__subclasshook__', '__new__', + '__init__', 'func_globals', 'func_code', 'func_closure', + 'im_class', 'im_func', 'im_self', 'gi_code', 'gi_frame', + '__asteval__', 'f_locals', '__mro__') + good_names = ('copy_reg._reconstructor', '__builtin__.object') + + if re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name): + fullname = module + '.' + name + if (fullname in good_names)\ + or ( ( module.startswith('sklearn.') + or module.startswith('xgboost.') + or module.startswith('skrebate.') + or module.startswith('numpy.') + or module == 'numpy' + ) + and (name not in bad_names) + ) : + # TODO: replace with a whitelist checker + if fullname not in SK_NAMES + SKR_NAMES + XGB_NAMES + NUMPY_NAMES + good_names: + print("Warning: global %s is not in pickler whitelist yet and will loss support soon. Contact tool author or leave a message at github.com" % fullname) + mod = sys.modules[module] + return getattr(mod, name) + + raise pickle.UnpicklingError("global '%s' is forbidden" % fullname) + + @classmethod + def load(self, file): + obj = pickle.Unpickler(file) + obj.find_global = self.find_class + return obj.load() + def read_columns(f, c=None, c_option='by_index_number', return_df=False, **args): data = pandas.read_csv(f, **args) if c_option == 'by_index_number': @@ -48,7 +100,7 @@ if inputs['model_inputter']['input_mode'] == 'prefitted': model_file = inputs['model_inputter']['fitted_estimator'] with open(model_file, 'rb') as model_handler: - fitted_estimator = pickle.load(model_handler) + fitted_estimator = SafePickler.load(model_handler) new_selector = selector(fitted_estimator, prefit=True, **options) else: estimator_json = inputs['model_inputter']["estimator_selector"] @@ -132,9 +184,9 @@ if load_scipy: scipy_distributions = scipy.stats.distributions.__dict__ - for key in scipy_distributions.keys(): - if isinstance(scipy_distributions[key], (scipy.stats.rv_continuous, scipy.stats.rv_discrete)): - syms['scipy_stats_' + key] = scipy_distributions[key] + for k, v in scipy_distributions.items(): + if isinstance(v, (scipy.stats.rv_continuous, scipy.stats.rv_discrete)): + syms['scipy_stats_' + k] = v if load_numpy: from_numpy_random = ['beta', 'binomial', 'bytes', 'chisquare', 'choice', 'dirichlet', 'division',