diff utils.py @ 7:5072ac474cd5 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 2a058459e6daf0486871f93845f00fdb4a4eaca1
author bgruening
date Sat, 29 Sep 2018 07:29:02 -0400
parents e972a913e61a
children ed7b1654e841
line wrap: on
line diff
--- a/utils.py	Thu Aug 23 16:15:30 2018 -0400
+++ b/utils.py	Sat Sep 29 07:29:02 2018 -0400
@@ -2,28 +2,27 @@
 import os
 import pandas
 import re
-import cPickle as pickle
+import pickle
 import warnings
 import numpy as np
 import xgboost
 import scipy
 import sklearn
-import ast
 from asteval import Interpreter, make_symbol_table
 from sklearn import (cluster, decomposition, ensemble, feature_extraction, feature_selection,
-                    gaussian_process, kernel_approximation, linear_model, metrics,
+                    gaussian_process, kernel_approximation, metrics,
                     model_selection, naive_bayes, neighbors, pipeline, preprocessing,
                     svm, linear_model, tree, discriminant_analysis)
 
-N_JOBS = int( os.environ.get('GALAXY_SLOTS', 1) )
+N_JOBS = int(os.environ.get('GALAXY_SLOTS', 1))
 
-class SafePickler(object):
+
+class SafePickler(pickle.Unpickler):
     """
     Used to safely deserialize scikit-learn model objects serialized by cPickle.dump
     Usage:
         eg.: SafePickler.load(pickled_file_object)
     """
-    @classmethod
     def find_class(self, module, name):
 
         bad_names = ('and', 'as', 'assert', 'break', 'class', 'continue',
@@ -39,11 +38,11 @@
                     '__init__', 'func_globals', 'func_code', 'func_closure',
                     'im_class', 'im_func', 'im_self', 'gi_code', 'gi_frame',
                     '__asteval__', 'f_locals', '__mro__')
-        good_names = ('copy_reg._reconstructor', '__builtin__.object')
+        good_names = ['copy_reg._reconstructor', '__builtin__.object']
 
         if re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name):
             fullname = module + '.' + name
-            if  (fullname in good_names)\
+            if (fullname in good_names)\
                 or  (   (   module.startswith('sklearn.')
                             or module.startswith('xgboost.')
                             or module.startswith('skrebate.')
@@ -51,26 +50,25 @@
                             or module == 'numpy'
                         )
                         and (name not in bad_names)
-                    ) :
+                    ):
                 # TODO: replace with a whitelist checker
-                if fullname not in SK_NAMES + SKR_NAMES + XGB_NAMES + NUMPY_NAMES + good_names:
+                if fullname not in sk_whitelist['SK_NAMES'] + sk_whitelist['SKR_NAMES'] + sk_whitelist['XGB_NAMES'] + sk_whitelist['NUMPY_NAMES'] + good_names:
                     print("Warning: global %s is not in pickler whitelist yet and will loss support soon. Contact tool author or leave a message at github.com" % fullname)
                 mod = sys.modules[module]
                 return getattr(mod, name)
 
         raise pickle.UnpicklingError("global '%s' is forbidden" % fullname)
 
-    @classmethod
-    def load(self, file):
-        obj = pickle.Unpickler(file)
-        obj.find_global = self.find_class
-        return obj.load()
+
+def load_model(file):
+    return SafePickler(file).load()
+
 
 def read_columns(f, c=None, c_option='by_index_number', return_df=False, **args):
     data = pandas.read_csv(f, **args)
     if c_option == 'by_index_number':
         cols = list(map(lambda x: x - 1, c))
-        data = data.iloc[:,cols]
+        data = data.iloc[:, cols]
     if c_option == 'all_but_by_index_number':
         cols = list(map(lambda x: x - 1, c))
         data.drop(data.columns[cols], axis=1, inplace=True)
@@ -100,7 +98,7 @@
         if inputs['model_inputter']['input_mode'] == 'prefitted':
             model_file = inputs['model_inputter']['fitted_estimator']
             with open(model_file, 'rb') as model_handler:
-                fitted_estimator = SafePickler.load(model_handler)
+                fitted_estimator = load_model(model_handler)
             new_selector = selector(fitted_estimator, prefit=True, **options)
         else:
             estimator_json = inputs['model_inputter']["estimator_selector"]
@@ -108,14 +106,14 @@
             new_selector = selector(estimator, **options)
 
     elif inputs['selected_algorithm'] == 'RFE':
-        estimator=get_estimator(inputs["estimator_selector"])
+        estimator = get_estimator(inputs["estimator_selector"])
         new_selector = selector(estimator, **options)
 
     elif inputs['selected_algorithm'] == 'RFECV':
         options['scoring'] = get_scoring(options['scoring'])
         options['n_jobs'] = N_JOBS
-        options['cv'] = get_cv( options['cv'].strip() )
-        estimator=get_estimator(inputs["estimator_selector"])
+        options['cv'] = get_cv(options['cv'].strip())
+        estimator = get_estimator(inputs["estimator_selector"])
         new_selector = selector(estimator, **options)
 
     elif inputs['selected_algorithm'] == "VarianceThreshold":
@@ -127,11 +125,11 @@
         new_selector = selector(score_func, **options)
 
     return new_selector
- 
+
 
 def get_X_y(params, file1, file2):
     input_type = params["selected_tasks"]["selected_algorithms"]["input_options"]["selected_input"]
-    if input_type=="tabular":
+    if input_type == "tabular":
         header = 'infer' if params["selected_tasks"]["selected_algorithms"]["input_options"]["header1"] else None
         column_option = params["selected_tasks"]["selected_algorithms"]["input_options"]["column_selector_options_1"]["selected_column_selector_option"]
         if column_option in ["by_index_number", "all_but_by_index_number", "by_header_name", "all_but_by_header_name"]:
@@ -140,8 +138,8 @@
             c = None
         X = read_columns(
             file1,
-            c = c,
-            c_option = column_option,
+            c=c,
+            c_option=column_option,
             sep='\t',
             header=header,
             parse_dates=True
@@ -157,13 +155,13 @@
         c = None
     y = read_columns(
         file2,
-        c = c,
-        c_option = column_option,
+        c=c,
+        c_option=column_option,
         sep='\t',
         header=header,
         parse_dates=True
     )
-    y=y.ravel()
+    y = y.ravel()
     return X, y
 
 
@@ -197,14 +195,14 @@
                                 'randn', 'random', 'random_integers', 'random_sample', 'ranf', 'rayleigh',
                                 'sample', 'seed', 'set_state', 'shuffle', 'standard_cauchy', 'standard_exponential',
                                 'standard_gamma', 'standard_normal', 'standard_t', 'triangular', 'uniform',
-                                'vonmises', 'wald', 'weibull', 'zipf' ]
+                                'vonmises', 'wald', 'weibull', 'zipf']
             for f in from_numpy_random:
                 syms['np_random_' + f] = getattr(np.random, f)
 
         for key in unwanted:
             syms.pop(key, None)
 
-        super(SafeEval, self).__init__( symtable=syms, use_numpy=False, minimal=False,
+        super(SafeEval, self).__init__(symtable=syms, use_numpy=False, minimal=False,
                                         no_if=True, no_for=True, no_while=True, no_try=True,
                                         no_functiondef=True, no_ifexp=True, no_listcomp=False,
                                         no_augassign=False, no_assert=True, no_delete=True,
@@ -250,10 +248,10 @@
         try:
             params = safe_eval('dict(' + estimator_params + ')')
         except ValueError:
-            sys.exit("Unsupported parameter input: `%s`" %estimator_params)
+            sys.exit("Unsupported parameter input: `%s`" % estimator_params)
         estimator.set_params(**params)
     if 'n_jobs' in estimator.get_params():
-        estimator.set_params( n_jobs=N_JOBS )
+        estimator.set_params(n_jobs=N_JOBS)
 
     return estimator
 
@@ -266,10 +264,10 @@
         return int(literal)
     m = re.match(r'^(?P<method>\w+)\((?P<args>.*)\)$', literal)
     if m:
-        my_class = getattr( model_selection, m.group('method') )
-        args = safe_eval( 'dict('+ m.group('args') + ')' )
-        return my_class( **args )
-    sys.exit("Unsupported CV input: %s" %literal)
+        my_class = getattr(model_selection, m.group('method'))
+        args = safe_eval('dict('+ m.group('args') + ')')
+        return my_class(**args)
+    sys.exit("Unsupported CV input: %s" % literal)
 
 
 def get_scoring(scoring_json):
@@ -293,11 +291,10 @@
     if scoring_json['secondary_scoring'] != 'None'\
             and scoring_json['secondary_scoring'] != scoring_json['primary_scoring']:
         scoring = {}
-        scoring['primary'] = my_scorers[ scoring_json['primary_scoring'] ]
+        scoring['primary'] = my_scorers[scoring_json['primary_scoring']]
         for scorer in scoring_json['secondary_scoring'].split(','):
             if scorer != scoring_json['primary_scoring']:
                 scoring[scorer] = my_scorers[scorer]
         return scoring
 
-    return my_scorers[ scoring_json['primary_scoring'] ]
-
+    return my_scorers[scoring_json['primary_scoring']]