diff utils.py @ 6:e972a913e61a draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 8cf3d813ec755166ee0bd517b4ecbbd4f84d4df1
author bgruening
date Thu, 23 Aug 2018 16:15:30 -0400
parents 753ebd417b17
children 5072ac474cd5
line wrap: on
line diff
--- a/utils.py	Fri Aug 17 12:27:46 2018 -0400
+++ b/utils.py	Thu Aug 23 16:15:30 2018 -0400
@@ -2,7 +2,7 @@
 import os
 import pandas
 import re
-import pickle
+import cPickle as pickle
 import warnings
 import numpy as np
 import xgboost
@@ -10,10 +10,62 @@
 import sklearn
 import ast
 from asteval import Interpreter, make_symbol_table
-from sklearn import metrics, model_selection, ensemble, svm, linear_model, naive_bayes, tree, neighbors
+from sklearn import (cluster, decomposition, ensemble, feature_extraction, feature_selection,
+                    gaussian_process, kernel_approximation, linear_model, metrics,
+                    model_selection, naive_bayes, neighbors, pipeline, preprocessing,
+                    svm, linear_model, tree, discriminant_analysis)
 
 N_JOBS = int( os.environ.get('GALAXY_SLOTS', 1) )
 
+class SafePickler(object):
+    """
+    Used to safely deserialize scikit-learn model objects serialized by cPickle.dump
+    Usage:
+        eg.: SafePickler.load(pickled_file_object)
+    """
+    @classmethod
+    def find_class(self, module, name):
+
+        bad_names = ('and', 'as', 'assert', 'break', 'class', 'continue',
+                    'def', 'del', 'elif', 'else', 'except', 'exec',
+                    'finally', 'for', 'from', 'global', 'if', 'import',
+                    'in', 'is', 'lambda', 'not', 'or', 'pass', 'print',
+                    'raise', 'return', 'try', 'system', 'while', 'with',
+                    'True', 'False', 'None', 'eval', 'execfile', '__import__',
+                    '__package__', '__subclasses__', '__bases__', '__globals__',
+                    '__code__', '__closure__', '__func__', '__self__', '__module__',
+                    '__dict__', '__class__', '__call__', '__get__',
+                    '__getattribute__', '__subclasshook__', '__new__',
+                    '__init__', 'func_globals', 'func_code', 'func_closure',
+                    'im_class', 'im_func', 'im_self', 'gi_code', 'gi_frame',
+                    '__asteval__', 'f_locals', '__mro__')
+        good_names = ('copy_reg._reconstructor', '__builtin__.object')
+
+        if re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name):
+            fullname = module + '.' + name
+            if  (fullname in good_names)\
+                or  (   (   module.startswith('sklearn.')
+                            or module.startswith('xgboost.')
+                            or module.startswith('skrebate.')
+                            or module.startswith('numpy.')
+                            or module == 'numpy'
+                        )
+                        and (name not in bad_names)
+                    ) :
+                # TODO: replace with a whitelist checker
+                if fullname not in SK_NAMES + SKR_NAMES + XGB_NAMES + NUMPY_NAMES + good_names:
+                    print("Warning: global %s is not in pickler whitelist yet and will loss support soon. Contact tool author or leave a message at github.com" % fullname)
+                mod = sys.modules[module]
+                return getattr(mod, name)
+
+        raise pickle.UnpicklingError("global '%s' is forbidden" % fullname)
+
+    @classmethod
+    def load(self, file):
+        obj = pickle.Unpickler(file)
+        obj.find_global = self.find_class
+        return obj.load()
+
 def read_columns(f, c=None, c_option='by_index_number', return_df=False, **args):
     data = pandas.read_csv(f, **args)
     if c_option == 'by_index_number':
@@ -48,7 +100,7 @@
         if inputs['model_inputter']['input_mode'] == 'prefitted':
             model_file = inputs['model_inputter']['fitted_estimator']
             with open(model_file, 'rb') as model_handler:
-                fitted_estimator = pickle.load(model_handler)
+                fitted_estimator = SafePickler.load(model_handler)
             new_selector = selector(fitted_estimator, prefit=True, **options)
         else:
             estimator_json = inputs['model_inputter']["estimator_selector"]
@@ -132,9 +184,9 @@
 
         if load_scipy:
             scipy_distributions = scipy.stats.distributions.__dict__
-            for key in scipy_distributions.keys():
-                if isinstance(scipy_distributions[key], (scipy.stats.rv_continuous, scipy.stats.rv_discrete)):
-                    syms['scipy_stats_' + key] = scipy_distributions[key]
+            for k, v in scipy_distributions.items():
+                if isinstance(v, (scipy.stats.rv_continuous, scipy.stats.rv_discrete)):
+                    syms['scipy_stats_' + k] = v
 
         if load_numpy:
             from_numpy_random = ['beta', 'binomial', 'bytes', 'chisquare', 'choice', 'dirichlet', 'division',