comparison utils.py @ 13:badd86b9ce24 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 8cf3d813ec755166ee0bd517b4ecbbd4f84d4df1
author bgruening
date Thu, 23 Aug 2018 16:21:01 -0400
parents 2c1851992069
children e244d6f2df1a
comparison
equal deleted inserted replaced
12:2c1851992069 13:badd86b9ce24
1 import sys 1 import sys
2 import os 2 import os
3 import pandas 3 import pandas
4 import re 4 import re
5 import pickle 5 import cPickle as pickle
6 import warnings 6 import warnings
7 import numpy as np 7 import numpy as np
8 import xgboost 8 import xgboost
9 import scipy 9 import scipy
10 import sklearn 10 import sklearn
11 import ast 11 import ast
12 from asteval import Interpreter, make_symbol_table 12 from asteval import Interpreter, make_symbol_table
13 from sklearn import metrics, model_selection, ensemble, svm, linear_model, naive_bayes, tree, neighbors 13 from sklearn import (cluster, decomposition, ensemble, feature_extraction, feature_selection,
14 gaussian_process, kernel_approximation, linear_model, metrics,
15 model_selection, naive_bayes, neighbors, pipeline, preprocessing,
16 svm, linear_model, tree, discriminant_analysis)
14 17
15 N_JOBS = int( os.environ.get('GALAXY_SLOTS', 1) ) 18 N_JOBS = int( os.environ.get('GALAXY_SLOTS', 1) )
19
20 class SafePickler(object):
21 """
22 Used to safely deserialize scikit-learn model objects serialized by cPickle.dump
23 Usage:
24 eg.: SafePickler.load(pickled_file_object)
25 """
26 @classmethod
27 def find_class(self, module, name):
28
29 bad_names = ('and', 'as', 'assert', 'break', 'class', 'continue',
30 'def', 'del', 'elif', 'else', 'except', 'exec',
31 'finally', 'for', 'from', 'global', 'if', 'import',
32 'in', 'is', 'lambda', 'not', 'or', 'pass', 'print',
33 'raise', 'return', 'try', 'system', 'while', 'with',
34 'True', 'False', 'None', 'eval', 'execfile', '__import__',
35 '__package__', '__subclasses__', '__bases__', '__globals__',
36 '__code__', '__closure__', '__func__', '__self__', '__module__',
37 '__dict__', '__class__', '__call__', '__get__',
38 '__getattribute__', '__subclasshook__', '__new__',
39 '__init__', 'func_globals', 'func_code', 'func_closure',
40 'im_class', 'im_func', 'im_self', 'gi_code', 'gi_frame',
41 '__asteval__', 'f_locals', '__mro__')
42 good_names = ('copy_reg._reconstructor', '__builtin__.object')
43
44 if re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name):
45 fullname = module + '.' + name
46 if (fullname in good_names)\
47 or ( ( module.startswith('sklearn.')
48 or module.startswith('xgboost.')
49 or module.startswith('skrebate.')
50 or module.startswith('numpy.')
51 or module == 'numpy'
52 )
53 and (name not in bad_names)
54 ) :
55 # TODO: replace with a whitelist checker
56 if fullname not in SK_NAMES + SKR_NAMES + XGB_NAMES + NUMPY_NAMES + good_names:
57 print("Warning: global %s is not in pickler whitelist yet and will loss support soon. Contact tool author or leave a message at github.com" % fullname)
58 mod = sys.modules[module]
59 return getattr(mod, name)
60
61 raise pickle.UnpicklingError("global '%s' is forbidden" % fullname)
62
63 @classmethod
64 def load(self, file):
65 obj = pickle.Unpickler(file)
66 obj.find_global = self.find_class
67 return obj.load()
16 68
17 def read_columns(f, c=None, c_option='by_index_number', return_df=False, **args): 69 def read_columns(f, c=None, c_option='by_index_number', return_df=False, **args):
18 data = pandas.read_csv(f, **args) 70 data = pandas.read_csv(f, **args)
19 if c_option == 'by_index_number': 71 if c_option == 'by_index_number':
20 cols = list(map(lambda x: x - 1, c)) 72 cols = list(map(lambda x: x - 1, c))
46 if not options['threshold'] or options['threshold'] == 'None': 98 if not options['threshold'] or options['threshold'] == 'None':
47 options['threshold'] = None 99 options['threshold'] = None
48 if inputs['model_inputter']['input_mode'] == 'prefitted': 100 if inputs['model_inputter']['input_mode'] == 'prefitted':
49 model_file = inputs['model_inputter']['fitted_estimator'] 101 model_file = inputs['model_inputter']['fitted_estimator']
50 with open(model_file, 'rb') as model_handler: 102 with open(model_file, 'rb') as model_handler:
51 fitted_estimator = pickle.load(model_handler) 103 fitted_estimator = SafePickler.load(model_handler)
52 new_selector = selector(fitted_estimator, prefit=True, **options) 104 new_selector = selector(fitted_estimator, prefit=True, **options)
53 else: 105 else:
54 estimator_json = inputs['model_inputter']["estimator_selector"] 106 estimator_json = inputs['model_inputter']["estimator_selector"]
55 estimator = get_estimator(estimator_json) 107 estimator = get_estimator(estimator_json)
56 new_selector = selector(estimator, **options) 108 new_selector = selector(estimator, **options)
130 182
131 syms = make_symbol_table(use_numpy=False, **new_syms) 183 syms = make_symbol_table(use_numpy=False, **new_syms)
132 184
133 if load_scipy: 185 if load_scipy:
134 scipy_distributions = scipy.stats.distributions.__dict__ 186 scipy_distributions = scipy.stats.distributions.__dict__
135 for key in scipy_distributions.keys(): 187 for k, v in scipy_distributions.items():
136 if isinstance(scipy_distributions[key], (scipy.stats.rv_continuous, scipy.stats.rv_discrete)): 188 if isinstance(v, (scipy.stats.rv_continuous, scipy.stats.rv_discrete)):
137 syms['scipy_stats_' + key] = scipy_distributions[key] 189 syms['scipy_stats_' + k] = v
138 190
139 if load_numpy: 191 if load_numpy:
140 from_numpy_random = ['beta', 'binomial', 'bytes', 'chisquare', 'choice', 'dirichlet', 'division', 192 from_numpy_random = ['beta', 'binomial', 'bytes', 'chisquare', 'choice', 'dirichlet', 'division',
141 'exponential', 'f', 'gamma', 'geometric', 'gumbel', 'hypergeometric', 193 'exponential', 'f', 'gamma', 'geometric', 'gumbel', 'hypergeometric',
142 'laplace', 'logistic', 'lognormal', 'logseries', 'mtrand', 'multinomial', 194 'laplace', 'logistic', 'lognormal', 'logseries', 'mtrand', 'multinomial',