Mercurial > repos > bgruening > sklearn_discriminant_classifier
comparison utils.py @ 20:f051d64eb12e draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 8cf3d813ec755166ee0bd517b4ecbbd4f84d4df1
author | bgruening |
---|---|
date | Thu, 23 Aug 2018 16:19:35 -0400 |
parents | 98b632c407ae |
children | 56ddc98c484e |
comparison
equal
deleted
inserted
replaced
19:98b632c407ae | 20:f051d64eb12e |
---|---|
1 import sys | 1 import sys |
2 import os | 2 import os |
3 import pandas | 3 import pandas |
4 import re | 4 import re |
5 import pickle | 5 import cPickle as pickle |
6 import warnings | 6 import warnings |
7 import numpy as np | 7 import numpy as np |
8 import xgboost | 8 import xgboost |
9 import scipy | 9 import scipy |
10 import sklearn | 10 import sklearn |
11 import ast | 11 import ast |
12 from asteval import Interpreter, make_symbol_table | 12 from asteval import Interpreter, make_symbol_table |
13 from sklearn import metrics, model_selection, ensemble, svm, linear_model, naive_bayes, tree, neighbors | 13 from sklearn import (cluster, decomposition, ensemble, feature_extraction, feature_selection, |
14 gaussian_process, kernel_approximation, linear_model, metrics, | |
15 model_selection, naive_bayes, neighbors, pipeline, preprocessing, | |
16 svm, linear_model, tree, discriminant_analysis) | |
14 | 17 |
15 N_JOBS = int( os.environ.get('GALAXY_SLOTS', 1) ) | 18 N_JOBS = int( os.environ.get('GALAXY_SLOTS', 1) ) |
19 | |
20 class SafePickler(object): | |
21 """ | |
22 Used to safely deserialize scikit-learn model objects serialized by cPickle.dump | |
23 Usage: | |
24 eg.: SafePickler.load(pickled_file_object) | |
25 """ | |
26 @classmethod | |
27 def find_class(self, module, name): | |
28 | |
29 bad_names = ('and', 'as', 'assert', 'break', 'class', 'continue', | |
30 'def', 'del', 'elif', 'else', 'except', 'exec', | |
31 'finally', 'for', 'from', 'global', 'if', 'import', | |
32 'in', 'is', 'lambda', 'not', 'or', 'pass', 'print', | |
33 'raise', 'return', 'try', 'system', 'while', 'with', | |
34 'True', 'False', 'None', 'eval', 'execfile', '__import__', | |
35 '__package__', '__subclasses__', '__bases__', '__globals__', | |
36 '__code__', '__closure__', '__func__', '__self__', '__module__', | |
37 '__dict__', '__class__', '__call__', '__get__', | |
38 '__getattribute__', '__subclasshook__', '__new__', | |
39 '__init__', 'func_globals', 'func_code', 'func_closure', | |
40 'im_class', 'im_func', 'im_self', 'gi_code', 'gi_frame', | |
41 '__asteval__', 'f_locals', '__mro__') | |
42 good_names = ('copy_reg._reconstructor', '__builtin__.object') | |
43 | |
44 if re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name): | |
45 fullname = module + '.' + name | |
46 if (fullname in good_names)\ | |
47 or ( ( module.startswith('sklearn.') | |
48 or module.startswith('xgboost.') | |
49 or module.startswith('skrebate.') | |
50 or module.startswith('numpy.') | |
51 or module == 'numpy' | |
52 ) | |
53 and (name not in bad_names) | |
54 ) : | |
55 # TODO: replace with a whitelist checker | |
56 if fullname not in SK_NAMES + SKR_NAMES + XGB_NAMES + NUMPY_NAMES + good_names: | |
57 print("Warning: global %s is not in pickler whitelist yet and will loss support soon. Contact tool author or leave a message at github.com" % fullname) | |
58 mod = sys.modules[module] | |
59 return getattr(mod, name) | |
60 | |
61 raise pickle.UnpicklingError("global '%s' is forbidden" % fullname) | |
62 | |
63 @classmethod | |
64 def load(self, file): | |
65 obj = pickle.Unpickler(file) | |
66 obj.find_global = self.find_class | |
67 return obj.load() | |
16 | 68 |
17 def read_columns(f, c=None, c_option='by_index_number', return_df=False, **args): | 69 def read_columns(f, c=None, c_option='by_index_number', return_df=False, **args): |
18 data = pandas.read_csv(f, **args) | 70 data = pandas.read_csv(f, **args) |
19 if c_option == 'by_index_number': | 71 if c_option == 'by_index_number': |
20 cols = list(map(lambda x: x - 1, c)) | 72 cols = list(map(lambda x: x - 1, c)) |
46 if not options['threshold'] or options['threshold'] == 'None': | 98 if not options['threshold'] or options['threshold'] == 'None': |
47 options['threshold'] = None | 99 options['threshold'] = None |
48 if inputs['model_inputter']['input_mode'] == 'prefitted': | 100 if inputs['model_inputter']['input_mode'] == 'prefitted': |
49 model_file = inputs['model_inputter']['fitted_estimator'] | 101 model_file = inputs['model_inputter']['fitted_estimator'] |
50 with open(model_file, 'rb') as model_handler: | 102 with open(model_file, 'rb') as model_handler: |
51 fitted_estimator = pickle.load(model_handler) | 103 fitted_estimator = SafePickler.load(model_handler) |
52 new_selector = selector(fitted_estimator, prefit=True, **options) | 104 new_selector = selector(fitted_estimator, prefit=True, **options) |
53 else: | 105 else: |
54 estimator_json = inputs['model_inputter']["estimator_selector"] | 106 estimator_json = inputs['model_inputter']["estimator_selector"] |
55 estimator = get_estimator(estimator_json) | 107 estimator = get_estimator(estimator_json) |
56 new_selector = selector(estimator, **options) | 108 new_selector = selector(estimator, **options) |
130 | 182 |
131 syms = make_symbol_table(use_numpy=False, **new_syms) | 183 syms = make_symbol_table(use_numpy=False, **new_syms) |
132 | 184 |
133 if load_scipy: | 185 if load_scipy: |
134 scipy_distributions = scipy.stats.distributions.__dict__ | 186 scipy_distributions = scipy.stats.distributions.__dict__ |
135 for key in scipy_distributions.keys(): | 187 for k, v in scipy_distributions.items(): |
136 if isinstance(scipy_distributions[key], (scipy.stats.rv_continuous, scipy.stats.rv_discrete)): | 188 if isinstance(v, (scipy.stats.rv_continuous, scipy.stats.rv_discrete)): |
137 syms['scipy_stats_' + key] = scipy_distributions[key] | 189 syms['scipy_stats_' + k] = v |
138 | 190 |
139 if load_numpy: | 191 if load_numpy: |
140 from_numpy_random = ['beta', 'binomial', 'bytes', 'chisquare', 'choice', 'dirichlet', 'division', | 192 from_numpy_random = ['beta', 'binomial', 'bytes', 'chisquare', 'choice', 'dirichlet', 'division', |
141 'exponential', 'f', 'gamma', 'geometric', 'gumbel', 'hypergeometric', | 193 'exponential', 'f', 'gamma', 'geometric', 'gumbel', 'hypergeometric', |
142 'laplace', 'logistic', 'lognormal', 'logseries', 'mtrand', 'multinomial', | 194 'laplace', 'logistic', 'lognormal', 'logseries', 'mtrand', 'multinomial', |