comparison search_model_validation.py @ 6:13b9ac5d277c draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 208a8d348e7c7a182cfbe1b6f17868146428a7e2"
author bgruening
date Tue, 13 Apr 2021 22:24:07 +0000
parents 5a092779412e
children 3312fb686ffb
5:ce2fd1edbc6e 6:13b9ac5d277c
1 import argparse 1 import argparse
2 import collections 2 import collections
3 import json
4 import os
5 import pickle
6 import sys
7 import warnings
8
3 import imblearn 9 import imblearn
4 import joblib 10 import joblib
5 import json
6 import numpy as np 11 import numpy as np
7 import os
8 import pandas as pd 12 import pandas as pd
9 import pickle
10 import skrebate 13 import skrebate
11 import sys 14 from galaxy_ml.utils import (
12 import warnings 15 clean_params,
16 get_cv,
17 get_main_estimator,
18 get_module,
19 get_scoring,
20 load_model,
21 read_columns,
22 SafeEval,
23 try_get_attr
24 )
13 from scipy.io import mmread 25 from scipy.io import mmread
14 from sklearn import (cluster, decomposition, feature_selection, 26 from sklearn import (
15 kernel_approximation, model_selection, preprocessing) 27 cluster,
28 decomposition,
29 feature_selection,
30 kernel_approximation,
31 model_selection,
32 preprocessing,
33 )
16 from sklearn.exceptions import FitFailedWarning 34 from sklearn.exceptions import FitFailedWarning
35 from sklearn.model_selection import _search, _validation
17 from sklearn.model_selection._validation import _score, cross_validate 36 from sklearn.model_selection._validation import _score, cross_validate
18 from sklearn.model_selection import _search, _validation 37
19 from sklearn.pipeline import Pipeline 38
20 39 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score")
21 from galaxy_ml.utils import (SafeEval, get_cv, get_scoring, load_model, 40 setattr(_search, "_fit_and_score", _fit_and_score)
22 read_columns, try_get_attr, get_module, 41 setattr(_validation, "_fit_and_score", _fit_and_score)
23 clean_params, get_main_estimator) 42
24 43 N_JOBS = int(os.environ.get("GALAXY_SLOTS", 1))
25
26 _fit_and_score = try_get_attr('galaxy_ml.model_validations', '_fit_and_score')
27 setattr(_search, '_fit_and_score', _fit_and_score)
28 setattr(_validation, '_fit_and_score', _fit_and_score)
29
30 N_JOBS = int(os.environ.get('GALAXY_SLOTS', 1))
31 # handle disk cache 44 # handle disk cache
32 CACHE_DIR = os.path.join(os.getcwd(), 'cached') 45 CACHE_DIR = os.path.join(os.getcwd(), "cached")
33 del os 46 del os
34 NON_SEARCHABLE = ('n_jobs', 'pre_dispatch', 'memory', '_path', 47 NON_SEARCHABLE = ("n_jobs", "pre_dispatch", "memory", "_path", "nthread", "callbacks")
35 'nthread', 'callbacks')
36 48
37 49
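(Note: `str.endswith` accepts a tuple of suffixes, which is how the NON_SEARCHABLE filter in `_eval_search_params` below matches any of the listed endings. A one-line illustration:)

    # str.endswith() with a tuple matches any of the listed suffixes, e.g.:
    "estimator__n_jobs".lower().endswith(NON_SEARCHABLE)  # -> True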
38 def _eval_search_params(params_builder): 50 def _eval_search_params(params_builder):
39 search_params = {} 51 search_params = {}
40 52
41 for p in params_builder['param_set']: 53 for p in params_builder["param_set"]:
42 search_list = p['sp_list'].strip() 54 search_list = p["sp_list"].strip()
43 if search_list == '': 55 if search_list == "":
44 continue 56 continue
45 57
46 param_name = p['sp_name'] 58 param_name = p["sp_name"]
47 if param_name.lower().endswith(NON_SEARCHABLE): 59 if param_name.lower().endswith(NON_SEARCHABLE):
48 print("Warning: `%s` is not eligible for search and was " 60 print("Warning: `%s` is not eligible for search and was " "omitted!" % param_name)
49 "omitted!" % param_name)
50 continue 61 continue
51 62
52 if not search_list.startswith(':'): 63 if not search_list.startswith(":"):
53 safe_eval = SafeEval(load_scipy=True, load_numpy=True) 64 safe_eval = SafeEval(load_scipy=True, load_numpy=True)
54 ev = safe_eval(search_list) 65 ev = safe_eval(search_list)
55 search_params[param_name] = ev 66 search_params[param_name] = ev
56 else: 67 else:
57 # Have `:` before search list, asks for estimator evaluation 68 # Have `:` before search list, asks for estimator evaluation
58 safe_eval_es = SafeEval(load_estimators=True) 69 safe_eval_es = SafeEval(load_estimators=True)
59 search_list = search_list[1:].strip() 70 search_list = search_list[1:].strip()
60 # TODO maybe add regular expression check 71 # TODO maybe add regular expression check
61 ev = safe_eval_es(search_list) 72 ev = safe_eval_es(search_list)
62 preprocessings = ( 73 preprocessings = (
63 preprocessing.StandardScaler(), preprocessing.Binarizer(), 74 preprocessing.StandardScaler(),
75 preprocessing.Binarizer(),
64 preprocessing.MaxAbsScaler(), 76 preprocessing.MaxAbsScaler(),
65 preprocessing.Normalizer(), preprocessing.MinMaxScaler(), 77 preprocessing.Normalizer(),
78 preprocessing.MinMaxScaler(),
66 preprocessing.PolynomialFeatures(), 79 preprocessing.PolynomialFeatures(),
67 preprocessing.RobustScaler(), feature_selection.SelectKBest(), 80 preprocessing.RobustScaler(),
81 feature_selection.SelectKBest(),
68 feature_selection.GenericUnivariateSelect(), 82 feature_selection.GenericUnivariateSelect(),
69 feature_selection.SelectPercentile(), 83 feature_selection.SelectPercentile(),
70 feature_selection.SelectFpr(), feature_selection.SelectFdr(), 84 feature_selection.SelectFpr(),
85 feature_selection.SelectFdr(),
71 feature_selection.SelectFwe(), 86 feature_selection.SelectFwe(),
72 feature_selection.VarianceThreshold(), 87 feature_selection.VarianceThreshold(),
73 decomposition.FactorAnalysis(random_state=0), 88 decomposition.FactorAnalysis(random_state=0),
74 decomposition.FastICA(random_state=0), 89 decomposition.FastICA(random_state=0),
75 decomposition.IncrementalPCA(), 90 decomposition.IncrementalPCA(),
76 decomposition.KernelPCA(random_state=0, n_jobs=N_JOBS), 91 decomposition.KernelPCA(random_state=0, n_jobs=N_JOBS),
77 decomposition.LatentDirichletAllocation( 92 decomposition.LatentDirichletAllocation(random_state=0, n_jobs=N_JOBS),
78 random_state=0, n_jobs=N_JOBS), 93 decomposition.MiniBatchDictionaryLearning(random_state=0, n_jobs=N_JOBS),
79 decomposition.MiniBatchDictionaryLearning( 94 decomposition.MiniBatchSparsePCA(random_state=0, n_jobs=N_JOBS),
80 random_state=0, n_jobs=N_JOBS),
81 decomposition.MiniBatchSparsePCA(
82 random_state=0, n_jobs=N_JOBS),
83 decomposition.NMF(random_state=0), 95 decomposition.NMF(random_state=0),
84 decomposition.PCA(random_state=0), 96 decomposition.PCA(random_state=0),
85 decomposition.SparsePCA(random_state=0, n_jobs=N_JOBS), 97 decomposition.SparsePCA(random_state=0, n_jobs=N_JOBS),
86 decomposition.TruncatedSVD(random_state=0), 98 decomposition.TruncatedSVD(random_state=0),
87 kernel_approximation.Nystroem(random_state=0), 99 kernel_approximation.Nystroem(random_state=0),
92 skrebate.ReliefF(n_jobs=N_JOBS), 104 skrebate.ReliefF(n_jobs=N_JOBS),
93 skrebate.SURF(n_jobs=N_JOBS), 105 skrebate.SURF(n_jobs=N_JOBS),
94 skrebate.SURFstar(n_jobs=N_JOBS), 106 skrebate.SURFstar(n_jobs=N_JOBS),
95 skrebate.MultiSURF(n_jobs=N_JOBS), 107 skrebate.MultiSURF(n_jobs=N_JOBS),
96 skrebate.MultiSURFstar(n_jobs=N_JOBS), 108 skrebate.MultiSURFstar(n_jobs=N_JOBS),
97 imblearn.under_sampling.ClusterCentroids( 109 imblearn.under_sampling.ClusterCentroids(random_state=0, n_jobs=N_JOBS),
98 random_state=0, n_jobs=N_JOBS), 110 imblearn.under_sampling.CondensedNearestNeighbour(random_state=0, n_jobs=N_JOBS),
99 imblearn.under_sampling.CondensedNearestNeighbour( 111 imblearn.under_sampling.EditedNearestNeighbours(random_state=0, n_jobs=N_JOBS),
100 random_state=0, n_jobs=N_JOBS), 112 imblearn.under_sampling.RepeatedEditedNearestNeighbours(random_state=0, n_jobs=N_JOBS),
101 imblearn.under_sampling.EditedNearestNeighbours(
102 random_state=0, n_jobs=N_JOBS),
103 imblearn.under_sampling.RepeatedEditedNearestNeighbours(
104 random_state=0, n_jobs=N_JOBS),
105 imblearn.under_sampling.AllKNN(random_state=0, n_jobs=N_JOBS), 113 imblearn.under_sampling.AllKNN(random_state=0, n_jobs=N_JOBS),
106 imblearn.under_sampling.InstanceHardnessThreshold( 114 imblearn.under_sampling.InstanceHardnessThreshold(random_state=0, n_jobs=N_JOBS),
107 random_state=0, n_jobs=N_JOBS), 115 imblearn.under_sampling.NearMiss(random_state=0, n_jobs=N_JOBS),
108 imblearn.under_sampling.NearMiss( 116 imblearn.under_sampling.NeighbourhoodCleaningRule(random_state=0, n_jobs=N_JOBS),
109 random_state=0, n_jobs=N_JOBS), 117 imblearn.under_sampling.OneSidedSelection(random_state=0, n_jobs=N_JOBS),
110 imblearn.under_sampling.NeighbourhoodCleaningRule( 118 imblearn.under_sampling.RandomUnderSampler(random_state=0),
111 random_state=0, n_jobs=N_JOBS), 119 imblearn.under_sampling.TomekLinks(random_state=0, n_jobs=N_JOBS),
112 imblearn.under_sampling.OneSidedSelection(
113 random_state=0, n_jobs=N_JOBS),
114 imblearn.under_sampling.RandomUnderSampler(
115 random_state=0),
116 imblearn.under_sampling.TomekLinks(
117 random_state=0, n_jobs=N_JOBS),
118 imblearn.over_sampling.ADASYN(random_state=0, n_jobs=N_JOBS), 120 imblearn.over_sampling.ADASYN(random_state=0, n_jobs=N_JOBS),
119 imblearn.over_sampling.RandomOverSampler(random_state=0), 121 imblearn.over_sampling.RandomOverSampler(random_state=0),
120 imblearn.over_sampling.SMOTE(random_state=0, n_jobs=N_JOBS), 122 imblearn.over_sampling.SMOTE(random_state=0, n_jobs=N_JOBS),
121 imblearn.over_sampling.SVMSMOTE(random_state=0, n_jobs=N_JOBS), 123 imblearn.over_sampling.SVMSMOTE(random_state=0, n_jobs=N_JOBS),
122 imblearn.over_sampling.BorderlineSMOTE( 124 imblearn.over_sampling.BorderlineSMOTE(random_state=0, n_jobs=N_JOBS),
123 random_state=0, n_jobs=N_JOBS), 125 imblearn.over_sampling.SMOTENC(categorical_features=[], random_state=0, n_jobs=N_JOBS),
124 imblearn.over_sampling.SMOTENC(
125 categorical_features=[], random_state=0, n_jobs=N_JOBS),
126 imblearn.combine.SMOTEENN(random_state=0), 126 imblearn.combine.SMOTEENN(random_state=0),
127 imblearn.combine.SMOTETomek(random_state=0)) 127 imblearn.combine.SMOTETomek(random_state=0),
128 )
128 newlist = [] 129 newlist = []
129 for obj in ev: 130 for obj in ev:
130 if obj is None: 131 if obj is None:
131 newlist.append(None) 132 newlist.append(None)
132 elif obj == 'all_0': 133 elif obj == "all_0":
133 newlist.extend(preprocessings[0:35]) 134 newlist.extend(preprocessings[0:35])
133 elif obj == 'sk_prep_all': # no KernelCenterer() 134 elif obj == "sk_prep_all": # no KernelCenterer()
135 newlist.extend(preprocessings[0:7]) 136 newlist.extend(preprocessings[0:7])
136 elif obj == 'fs_all': 137 elif obj == "fs_all":
137 newlist.extend(preprocessings[7:14]) 138 newlist.extend(preprocessings[7:14])
138 elif obj == 'decomp_all': 139 elif obj == "decomp_all":
139 newlist.extend(preprocessings[14:25]) 140 newlist.extend(preprocessings[14:25])
140 elif obj == 'k_appr_all': 141 elif obj == "k_appr_all":
141 newlist.extend(preprocessings[25:29]) 142 newlist.extend(preprocessings[25:29])
142 elif obj == 'reb_all': 143 elif obj == "reb_all":
143 newlist.extend(preprocessings[30:35]) 144 newlist.extend(preprocessings[30:35])
144 elif obj == 'imb_all': 145 elif obj == "imb_all":
145 newlist.extend(preprocessings[35:54]) 146 newlist.extend(preprocessings[35:54])
146 elif type(obj) is int and -1 < obj < len(preprocessings): 147 elif type(obj) is int and -1 < obj < len(preprocessings):
147 newlist.append(preprocessings[obj]) 148 newlist.append(preprocessings[obj])
148 elif hasattr(obj, 'get_params'): # user uploaded object 149 elif hasattr(obj, "get_params"): # user uploaded object
149 if 'n_jobs' in obj.get_params(): 150 if "n_jobs" in obj.get_params():
150 newlist.append(obj.set_params(n_jobs=N_JOBS)) 151 newlist.append(obj.set_params(n_jobs=N_JOBS))
151 else: 152 else:
152 newlist.append(obj) 153 newlist.append(obj)
153 else: 154 else:
154 sys.exit("Unsupported estimator type: %r" % (obj)) 155 sys.exit("Unsupported estimator type: %r" % (obj))
156 search_params[param_name] = newlist 157 search_params[param_name] = newlist
157 158
158 return search_params 159 return search_params
159 160
160 161
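(For orientation, a minimal sketch of the input `_eval_search_params` expects; the key names mirror the lookups in the function body above, while the parameter names and values are hypothetical, not taken from a real Galaxy job:)

    # Hypothetical search-parameter builder; keys as read above, values invented.
    params_builder = {
        "param_set": [
            # evaluated with SafeEval(load_scipy=True, load_numpy=True)
            {"sp_name": "estimator__C", "sp_list": "[0.1, 1.0, 10.0]"},
            # a leading ':' switches to estimator-evaluation mode
            {"sp_name": "preprocessing", "sp_list": ": [None, 'sk_prep_all']"},
            # empty search lists are skipped
            {"sp_name": "estimator__tol", "sp_list": ""},
        ]
    }
    search_params = _eval_search_params(params_builder)
    # -> {"estimator__C": [0.1, 1.0, 10.0],
    #     "preprocessing": [None, StandardScaler(), Binarizer(), ...]}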
161 def _handle_X_y(estimator, params, infile1, infile2, loaded_df={}, 162 def _handle_X_y(
162 ref_seq=None, intervals=None, targets=None, 163 estimator,
163 fasta_path=None): 164 params,
165 infile1,
166 infile2,
167 loaded_df={},
168 ref_seq=None,
169 intervals=None,
170 targets=None,
171 fasta_path=None,
172 ):
164 """read inputs 173 """read inputs
165 174
166 Parameters 175 Parameters
167 ---------- 176 ----------
168 estimator : estimator object 177 estimator : estimator object
190 X : numpy array 199 X : numpy array
191 y : numpy array 200 y : numpy array
192 """ 201 """
193 estimator_params = estimator.get_params() 202 estimator_params = estimator.get_params()
194 203
195 input_type = params['input_options']['selected_input'] 204 input_type = params["input_options"]["selected_input"]
196 # tabular input 205 # tabular input
197 if input_type == 'tabular': 206 if input_type == "tabular":
198 header = 'infer' if params['input_options']['header1'] else None 207 header = "infer" if params["input_options"]["header1"] else None
199 column_option = (params['input_options']['column_selector_options_1'] 208 column_option = params["input_options"]["column_selector_options_1"]["selected_column_selector_option"]
200 ['selected_column_selector_option']) 209 if column_option in [
201 if column_option in ['by_index_number', 'all_but_by_index_number', 210 "by_index_number",
202 'by_header_name', 'all_but_by_header_name']: 211 "all_but_by_index_number",
203 c = params['input_options']['column_selector_options_1']['col1'] 212 "by_header_name",
213 "all_but_by_header_name",
214 ]:
215 c = params["input_options"]["column_selector_options_1"]["col1"]
204 else: 216 else:
205 c = None 217 c = None
206 218
207 df_key = infile1 + repr(header) 219 df_key = infile1 + repr(header)
208 220
209 if df_key in loaded_df: 221 if df_key in loaded_df:
210 infile1 = loaded_df[df_key] 222 infile1 = loaded_df[df_key]
211 223
212 df = pd.read_csv(infile1, sep='\t', header=header, 224 df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True)
213 parse_dates=True)
214 loaded_df[df_key] = df 225 loaded_df[df_key] = df
215 226
216 X = read_columns(df, c=c, c_option=column_option).astype(float) 227 X = read_columns(df, c=c, c_option=column_option).astype(float)
217 # sparse input 228 # sparse input
218 elif input_type == 'sparse': 229 elif input_type == "sparse":
219 X = mmread(open(infile1, 'r')) 230 X = mmread(open(infile1, "r"))
220 231
221 # fasta_file input 232 # fasta_file input
222 elif input_type == 'seq_fasta': 233 elif input_type == "seq_fasta":
223 pyfaidx = get_module('pyfaidx') 234 pyfaidx = get_module("pyfaidx")
224 sequences = pyfaidx.Fasta(fasta_path) 235 sequences = pyfaidx.Fasta(fasta_path)
225 n_seqs = len(sequences.keys()) 236 n_seqs = len(sequences.keys())
226 X = np.arange(n_seqs)[:, np.newaxis] 237 X = np.arange(n_seqs)[:, np.newaxis]
227 for param in estimator_params.keys(): 238 for param in estimator_params.keys():
228 if param.endswith('fasta_path'): 239 if param.endswith("fasta_path"):
229 estimator.set_params( 240 estimator.set_params(**{param: fasta_path})
230 **{param: fasta_path})
231 break 241 break
232 else: 242 else:
233 raise ValueError( 243 raise ValueError(
234 "The selected estimator doesn't support " 244 "The selected estimator doesn't support "
235 "fasta file input! Please consider using " 245 "fasta file input! Please consider using "
236 "KerasGBatchClassifier with " 246 "KerasGBatchClassifier with "
237 "FastaDNABatchGenerator/FastaProteinBatchGenerator " 247 "FastaDNABatchGenerator/FastaProteinBatchGenerator "
238 "or having GenomeOneHotEncoder/ProteinOneHotEncoder " 248 "or having GenomeOneHotEncoder/ProteinOneHotEncoder "
239 "in pipeline!") 249 "in pipeline!"
240 250 )
241 elif input_type == 'refseq_and_interval': 251
252 elif input_type == "refseq_and_interval":
242 path_params = { 253 path_params = {
243 'data_batch_generator__ref_genome_path': ref_seq, 254 "data_batch_generator__ref_genome_path": ref_seq,
244 'data_batch_generator__intervals_path': intervals, 255 "data_batch_generator__intervals_path": intervals,
245 'data_batch_generator__target_path': targets 256 "data_batch_generator__target_path": targets,
246 } 257 }
247 estimator.set_params(**path_params) 258 estimator.set_params(**path_params)
248 n_intervals = sum(1 for line in open(intervals)) 259 n_intervals = sum(1 for line in open(intervals))
249 X = np.arange(n_intervals)[:, np.newaxis] 260 X = np.arange(n_intervals)[:, np.newaxis]
250 261
251 # Get target y 262 # Get target y
252 header = 'infer' if params['input_options']['header2'] else None 263 header = "infer" if params["input_options"]["header2"] else None
253 column_option = (params['input_options']['column_selector_options_2'] 264 column_option = params["input_options"]["column_selector_options_2"]["selected_column_selector_option2"]
254 ['selected_column_selector_option2']) 265 if column_option in [
255 if column_option in ['by_index_number', 'all_but_by_index_number', 266 "by_index_number",
256 'by_header_name', 'all_but_by_header_name']: 267 "all_but_by_index_number",
257 c = params['input_options']['column_selector_options_2']['col2'] 268 "by_header_name",
269 "all_but_by_header_name",
270 ]:
271 c = params["input_options"]["column_selector_options_2"]["col2"]
258 else: 272 else:
259 c = None 273 c = None
260 274
261 df_key = infile2 + repr(header) 275 df_key = infile2 + repr(header)
262 if df_key in loaded_df: 276 if df_key in loaded_df:
263 infile2 = loaded_df[df_key] 277 infile2 = loaded_df[df_key]
264 else: 278 else:
265 infile2 = pd.read_csv(infile2, sep='\t', 279 infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True)
266 header=header, parse_dates=True)
267 loaded_df[df_key] = infile2 280 loaded_df[df_key] = infile2
268 281
269 y = read_columns( 282 y = read_columns(infile2, c=c, c_option=column_option, sep="\t", header=header, parse_dates=True)
270 infile2,
271 c=c,
272 c_option=column_option,
273 sep='\t',
274 header=header,
275 parse_dates=True)
276 if len(y.shape) == 2 and y.shape[1] == 1: 283 if len(y.shape) == 2 and y.shape[1] == 1:
277 y = y.ravel() 284 y = y.ravel()
278 if input_type == 'refseq_and_interval': 285 if input_type == "refseq_and_interval":
279 estimator.set_params( 286 estimator.set_params(data_batch_generator__features=y.ravel().tolist())
280 data_batch_generator__features=y.ravel().tolist())
281 y = None 287 y = None
282 # end y 288 # end y
283 289
284 return estimator, X, y 290 return estimator, X, y
285 291
286 292
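(A hedged sketch of the `params` dict the tabular branch of `_handle_X_y` reads; only keys actually looked up above are shown, and every value and file path is a placeholder:)

    # Hypothetical Galaxy parameters for tabular X and y inputs.
    params = {
        "input_options": {
            "selected_input": "tabular",
            "header1": True,
            "column_selector_options_1": {
                "selected_column_selector_option": "all_but_by_index_number",
                "col1": [1],   # exclude column 1 from X
            },
            "header2": True,
            "column_selector_options_2": {
                "selected_column_selector_option2": "by_index_number",
                "col2": [1],   # use column 1 as y
            },
        }
    }
    estimator, X, y = _handle_X_y(estimator, params, "X.tsv", "y.tsv")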
287 def _do_outer_cv(searcher, X, y, outer_cv, scoring, error_score='raise', 293 def _do_outer_cv(searcher, X, y, outer_cv, scoring, error_score="raise", outfile=None):
288 outfile=None):
289 """Do outer cross-validation for nested CV 294 """Do outer cross-validation for nested CV
290 295
291 Parameters 296 Parameters
292 ---------- 297 ----------
293 searcher : object 298 searcher : object
303 error_score: str, float or numpy float 308 error_score: str, float or numpy float
304 Whether to raise fit error or return a value 309 Whether to raise fit error or return a value
305 outfile : str 310 outfile : str
306 File path to store the results 311 File path to store the results
307 """ 312 """
308 if error_score == 'raise': 313 if error_score == "raise":
309 rval = cross_validate( 314 rval = cross_validate(
310 searcher, X, y, scoring=scoring, 315 searcher,
311 cv=outer_cv, n_jobs=N_JOBS, verbose=0, 316 X,
312 error_score=error_score) 317 y,
313 else: 318 scoring=scoring,
314 warnings.simplefilter('always', FitFailedWarning) 319 cv=outer_cv,
320 n_jobs=N_JOBS,
321 verbose=0,
322 error_score=error_score,
323 )
324 else:
325 warnings.simplefilter("always", FitFailedWarning)
315 with warnings.catch_warnings(record=True) as w: 326 with warnings.catch_warnings(record=True) as w:
316 try: 327 try:
317 rval = cross_validate( 328 rval = cross_validate(
318 searcher, X, y, 329 searcher,
330 X,
331 y,
319 scoring=scoring, 332 scoring=scoring,
320 cv=outer_cv, n_jobs=N_JOBS, 333 cv=outer_cv,
334 n_jobs=N_JOBS,
321 verbose=0, 335 verbose=0,
322 error_score=error_score) 336 error_score=error_score,
337 )
323 except ValueError: 338 except ValueError:
324 pass 339 pass
325 for warning in w: 340 for warning in w:
326 print(repr(warning.message)) 341 print(repr(warning.message))
327 342
328 keys = list(rval.keys()) 343 keys = list(rval.keys())
329 for k in keys: 344 for k in keys:
330 if k.startswith('test'): 345 if k.startswith("test"):
331 rval['mean_' + k] = np.mean(rval[k]) 346 rval["mean_" + k] = np.mean(rval[k])
332 rval['std_' + k] = np.std(rval[k]) 347 rval["std_" + k] = np.std(rval[k])
333 if k.endswith('time'): 348 if k.endswith("time"):
334 rval.pop(k) 349 rval.pop(k)
335 rval = pd.DataFrame(rval) 350 rval = pd.DataFrame(rval)
336 rval = rval[sorted(rval.columns)] 351 rval = rval[sorted(rval.columns)]
337 rval.to_csv(path_or_buf=outfile, sep='\t', header=True, index=False) 352 rval.to_csv(path_or_buf=outfile, sep="\t", header=True, index=False)
338 353
339 354
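(`_do_outer_cv` above is the nested-CV outer loop: the searcher is itself scored by `cross_validate`, then mean/std columns are added and the `*_time` keys dropped. A self-contained sketch of the same pattern using only the public scikit-learn API:)

    # Minimal nested-CV sketch with plain scikit-learn.
    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import GridSearchCV, KFold, cross_validate

    X, y = load_iris(return_X_y=True)
    inner = GridSearchCV(LogisticRegression(max_iter=1000),
                         param_grid={"C": [0.1, 1.0, 10.0]}, cv=3)
    rval = cross_validate(inner, X, y, scoring="accuracy",
                          cv=KFold(n_splits=5), n_jobs=1)
    print(rval["test_score"].mean(), rval["test_score"].std())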
340 def _do_train_test_split_val(searcher, X, y, params, error_score='raise', 355 def _do_train_test_split_val(
341 primary_scoring=None, groups=None, 356 searcher,
342 outfile=None): 357 X,
343 """ do train test split, searchCV validates on the train and then use 358 y,
359 params,
360 error_score="raise",
361 primary_scoring=None,
362 groups=None,
363 outfile=None,
364 ):
365 """do train test split, searchCV validates on the train and then use
344 the best_estimator_ to evaluate on the test 366 the best_estimator_ to evaluate on the test
345 367
346 Returns 368 Returns
347 -------- 369 --------
348 Fitted SearchCV object 370 Fitted SearchCV object
349 """ 371 """
350 train_test_split = try_get_attr( 372 train_test_split = try_get_attr("galaxy_ml.model_validations", "train_test_split")
351 'galaxy_ml.model_validations', 'train_test_split') 373 split_options = params["outer_split"]
352 split_options = params['outer_split']
353 374
354 # splits 375 # splits
355 if split_options['shuffle'] == 'stratified': 376 if split_options["shuffle"] == "stratified":
356 split_options['labels'] = y 377 split_options["labels"] = y
357 X, X_test, y, y_test = train_test_split(X, y, **split_options) 378 X, X_test, y, y_test = train_test_split(X, y, **split_options)
358 elif split_options['shuffle'] == 'group': 379 elif split_options["shuffle"] == "group":
359 if groups is None: 380 if groups is None:
360 raise ValueError("No group based CV option was chosen for " 381 raise ValueError("No group based CV option was chosen for " "group shuffle!")
361 "group shuffle!") 382 split_options["labels"] = groups
362 split_options['labels'] = groups
363 if y is None: 383 if y is None:
364 X, X_test, groups, _ =\ 384 X, X_test, groups, _ = train_test_split(X, groups, **split_options)
365 train_test_split(X, groups, **split_options)
366 else: 385 else:
367 X, X_test, y, y_test, groups, _ =\ 386 X, X_test, y, y_test, groups, _ = train_test_split(X, y, groups, **split_options)
368 train_test_split(X, y, groups, **split_options) 387 else:
369 else: 388 if split_options["shuffle"] == "None":
370 if split_options['shuffle'] == 'None': 389 split_options["shuffle"] = None
371 split_options['shuffle'] = None 390 X, X_test, y, y_test = train_test_split(X, y, **split_options)
372 X, X_test, y, y_test =\ 391
373 train_test_split(X, y, **split_options) 392 if error_score == "raise":
374
375 if error_score == 'raise':
376 searcher.fit(X, y, groups=groups) 393 searcher.fit(X, y, groups=groups)
377 else: 394 else:
378 warnings.simplefilter('always', FitFailedWarning) 395 warnings.simplefilter("always", FitFailedWarning)
379 with warnings.catch_warnings(record=True) as w: 396 with warnings.catch_warnings(record=True) as w:
380 try: 397 try:
381 searcher.fit(X, y, groups=groups) 398 searcher.fit(X, y, groups=groups)
382 except ValueError: 399 except ValueError:
383 pass 400 pass
388 if isinstance(scorer_, collections.Mapping): 405 if isinstance(scorer_, collections.Mapping):
389 is_multimetric = True 406 is_multimetric = True
390 else: 407 else:
391 is_multimetric = False 408 is_multimetric = False
392 409
393 best_estimator_ = getattr(searcher, 'best_estimator_') 410 best_estimator_ = getattr(searcher, "best_estimator_")
394 411
395 # TODO Solve deep learning models in pipeline 412 # TODO Solve deep learning models in pipeline
396 if best_estimator_.__class__.__name__ == 'KerasGBatchClassifier': 413 if best_estimator_.__class__.__name__ == "KerasGBatchClassifier":
397 test_score = best_estimator_.evaluate( 414 test_score = best_estimator_.evaluate(X_test, scorer=scorer_, is_multimetric=is_multimetric)
398 X_test, scorer=scorer_, is_multimetric=is_multimetric) 415 else:
399 else: 416 test_score = _score(best_estimator_, X_test, y_test, scorer_, is_multimetric=is_multimetric)
400 test_score = _score(best_estimator_, X_test,
401 y_test, scorer_,
402 is_multimetric=is_multimetric)
403 417
404 if not is_multimetric: 418 if not is_multimetric:
405 test_score = {primary_scoring: test_score} 419 test_score = {primary_scoring: test_score}
406 for key, value in test_score.items(): 420 for key, value in test_score.items():
407 test_score[key] = [value] 421 test_score[key] = [value]
408 result_df = pd.DataFrame(test_score) 422 result_df = pd.DataFrame(test_score)
409 result_df.to_csv(path_or_buf=outfile, sep='\t', header=True, 423 result_df.to_csv(path_or_buf=outfile, sep="\t", header=True, index=False)
410 index=False)
411 424
412 return searcher 425 return searcher
413 426
414 427
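(A hypothetical `params["outer_split"]` block as consumed by `_do_train_test_split_val`; only `shuffle` is read explicitly above, and the remaining keys are assumed pass-through kwargs for galaxy_ml's `train_test_split`, which this diff does not confirm:)

    # Hypothetical outer-split options; 'test_size'/'random_state' are
    # assumed train_test_split kwargs, not confirmed by this diff.
    params = {
        "outer_split": {
            "shuffle": "stratified",  # or "group", "None", ...
            "test_size": 0.25,
            "random_state": 0,
        }
    }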
415 def main(inputs, infile_estimator, infile1, infile2, 428 def main(
416 outfile_result, outfile_object=None, 429 inputs,
417 outfile_weights=None, groups=None, 430 infile_estimator,
418 ref_seq=None, intervals=None, targets=None, 431 infile1,
419 fasta_path=None): 432 infile2,
433 outfile_result,
434 outfile_object=None,
435 outfile_weights=None,
436 groups=None,
437 ref_seq=None,
438 intervals=None,
439 targets=None,
440 fasta_path=None,
441 ):
420 """ 442 """
421 Parameter 443 Parameter
422 --------- 444 ---------
423 inputs : str 445 inputs : str
424 File path to galaxy tool parameter 446 File path to galaxy tool parameter
454 File path to dataset compressed target bed file 476 File path to dataset compressed target bed file
455 477
456 fasta_path : str 478 fasta_path : str
457 File path to dataset containing fasta file 479 File path to dataset containing fasta file
458 """ 480 """
459 warnings.simplefilter('ignore') 481 warnings.simplefilter("ignore")
460 482
461 # store read dataframe object 483 # store read dataframe object
462 loaded_df = {} 484 loaded_df = {}
463 485
464 with open(inputs, 'r') as param_handler: 486 with open(inputs, "r") as param_handler:
465 params = json.load(param_handler) 487 params = json.load(param_handler)
466 488
467 # Override the refit parameter 489 # Override the refit parameter
468 params['search_schemes']['options']['refit'] = True \ 490 params["search_schemes"]["options"]["refit"] = True if params["save"] != "nope" else False
469 if params['save'] != 'nope' else False 491
470 492 with open(infile_estimator, "rb") as estimator_handler:
471 with open(infile_estimator, 'rb') as estimator_handler:
472 estimator = load_model(estimator_handler) 493 estimator = load_model(estimator_handler)
473 494
474 optimizer = params['search_schemes']['selected_search_scheme'] 495 optimizer = params["search_schemes"]["selected_search_scheme"]
475 optimizer = getattr(model_selection, optimizer) 496 optimizer = getattr(model_selection, optimizer)
476 497
477 # handle gridsearchcv options 498 # handle gridsearchcv options
478 options = params['search_schemes']['options'] 499 options = params["search_schemes"]["options"]
479 500
480 if groups: 501 if groups:
481 header = 'infer' if (options['cv_selector']['groups_selector'] 502 header = "infer" if (options["cv_selector"]["groups_selector"]["header_g"]) else None
482 ['header_g']) else None 503 column_option = options["cv_selector"]["groups_selector"]["column_selector_options_g"][
483 column_option = (options['cv_selector']['groups_selector'] 504 "selected_column_selector_option_g"
484 ['column_selector_options_g'] 505 ]
485 ['selected_column_selector_option_g']) 506 if column_option in [
486 if column_option in ['by_index_number', 'all_but_by_index_number', 507 "by_index_number",
487 'by_header_name', 'all_but_by_header_name']: 508 "all_but_by_index_number",
488 c = (options['cv_selector']['groups_selector'] 509 "by_header_name",
489 ['column_selector_options_g']['col_g']) 510 "all_but_by_header_name",
511 ]:
512 c = options["cv_selector"]["groups_selector"]["column_selector_options_g"]["col_g"]
490 else: 513 else:
491 c = None 514 c = None
492 515
493 df_key = groups + repr(header) 516 df_key = groups + repr(header)
494 517
495 groups = pd.read_csv(groups, sep='\t', header=header, 518 groups = pd.read_csv(groups, sep="\t", header=header, parse_dates=True)
496 parse_dates=True)
497 loaded_df[df_key] = groups 519 loaded_df[df_key] = groups
498 520
499 groups = read_columns( 521 groups = read_columns(
500 groups, 522 groups,
501 c=c, 523 c=c,
502 c_option=column_option, 524 c_option=column_option,
503 sep='\t', 525 sep="\t",
504 header=header, 526 header=header,
505 parse_dates=True) 527 parse_dates=True,
528 )
506 groups = groups.ravel() 529 groups = groups.ravel()
507 options['cv_selector']['groups_selector'] = groups 530 options["cv_selector"]["groups_selector"] = groups
508 531
509 splitter, groups = get_cv(options.pop('cv_selector')) 532 splitter, groups = get_cv(options.pop("cv_selector"))
510 options['cv'] = splitter 533 options["cv"] = splitter
511 primary_scoring = options['scoring']['primary_scoring'] 534 primary_scoring = options["scoring"]["primary_scoring"]
512 options['scoring'] = get_scoring(options['scoring']) 535 # get_scoring() expects secondary_scoring to be a comma separated string (not a list)
513 if options['error_score']: 536 # Check if secondary_scoring is specified
514 options['error_score'] = 'raise' 537 secondary_scoring = options["scoring"].get("secondary_scoring", None)
515 else: 538 if secondary_scoring is not None:
516 options['error_score'] = np.NaN 539 # If secondary_scoring is specified, convert the list into a comma separated string
517 if options['refit'] and isinstance(options['scoring'], dict): 540 options["scoring"]["secondary_scoring"] = ",".join(options["scoring"]["secondary_scoring"])
518 options['refit'] = primary_scoring 541 options["scoring"] = get_scoring(options["scoring"])
519 if 'pre_dispatch' in options and options['pre_dispatch'] == '': 542 if options["error_score"]:
520 options['pre_dispatch'] = None 543 options["error_score"] = "raise"
521 544 else:
522 params_builder = params['search_schemes']['search_params_builder'] 545 options["error_score"] = np.NaN
546 if options["refit"] and isinstance(options["scoring"], dict):
547 options["refit"] = primary_scoring
548 if "pre_dispatch" in options and options["pre_dispatch"] == "":
549 options["pre_dispatch"] = None
550
551 params_builder = params["search_schemes"]["search_params_builder"]
523 param_grid = _eval_search_params(params_builder) 552 param_grid = _eval_search_params(params_builder)
524 553
525 estimator = clean_params(estimator) 554 estimator = clean_params(estimator)
526 555
527 # save the SearchCV object without fit 556 # save the SearchCV object without fit
528 if params['save'] == 'save_no_fit': 557 if params["save"] == "save_no_fit":
529 searcher = optimizer(estimator, param_grid, **options) 558 searcher = optimizer(estimator, param_grid, **options)
530 print(searcher) 559 print(searcher)
531 with open(outfile_object, 'wb') as output_handler: 560 with open(outfile_object, "wb") as output_handler:
532 pickle.dump(searcher, output_handler, 561 pickle.dump(searcher, output_handler, pickle.HIGHEST_PROTOCOL)
533 pickle.HIGHEST_PROTOCOL)
534 return 0 562 return 0
535 563
536 # read inputs and load new attributes, like paths 564 # read inputs and load new attributes, like paths
537 estimator, X, y = _handle_X_y(estimator, params, infile1, infile2, 565 estimator, X, y = _handle_X_y(
538 loaded_df=loaded_df, ref_seq=ref_seq, 566 estimator,
539 intervals=intervals, targets=targets, 567 params,
540 fasta_path=fasta_path) 568 infile1,
569 infile2,
570 loaded_df=loaded_df,
571 ref_seq=ref_seq,
572 intervals=intervals,
573 targets=targets,
574 fasta_path=fasta_path,
575 )
541 576
542 # caching iraps_core fits could increase search speed significantly 577 # caching iraps_core fits could increase search speed significantly
543 memory = joblib.Memory(location=CACHE_DIR, verbose=0) 578 memory = joblib.Memory(location=CACHE_DIR, verbose=0)
544 main_est = get_main_estimator(estimator) 579 main_est = get_main_estimator(estimator)
545 if main_est.__class__.__name__ == 'IRAPSClassifier': 580 if main_est.__class__.__name__ == "IRAPSClassifier":
546 main_est.set_params(memory=memory) 581 main_est.set_params(memory=memory)
547 582
548 searcher = optimizer(estimator, param_grid, **options) 583 searcher = optimizer(estimator, param_grid, **options)
549 584
550 split_mode = params['outer_split'].pop('split_mode') 585 split_mode = params["outer_split"].pop("split_mode")
551 586
552 if split_mode == 'nested_cv': 587 if split_mode == "nested_cv":
553 # make sure refit is chosen 588 # make sure refit is chosen
554 # this could be True for sklearn models, but not the case for 589 # this could be True for sklearn models, but not the case for
555 # deep learning models 590 # deep learning models
556 if not options['refit'] and \ 591 if not options["refit"] and not all(hasattr(estimator, attr) for attr in ("config", "model_type")):
557 not all(hasattr(estimator, attr)
558 for attr in ('config', 'model_type')):
559 warnings.warn("Refit is changed to `True` for nested validation!") 592 warnings.warn("Refit is changed to `True` for nested validation!")
560 setattr(searcher, 'refit', True) 593 setattr(searcher, "refit", True)
561 594
562 outer_cv, _ = get_cv(params['outer_split']['cv_selector']) 595 outer_cv, _ = get_cv(params["outer_split"]["cv_selector"])
563 # nested CV, outer cv using cross_validate 596 # nested CV, outer cv using cross_validate
564 if options['error_score'] == 'raise': 597 if options["error_score"] == "raise":
565 rval = cross_validate( 598 rval = cross_validate(
566 searcher, X, y, scoring=options['scoring'], 599 searcher,
567 cv=outer_cv, n_jobs=N_JOBS, 600 X,
568 verbose=options['verbose'], 601 y,
569 return_estimator=(params['save'] == 'save_estimator'), 602 scoring=options["scoring"],
570 error_score=options['error_score'], 603 cv=outer_cv,
571 return_train_score=True) 604 n_jobs=N_JOBS,
605 verbose=options["verbose"],
606 return_estimator=(params["save"] == "save_estimator"),
607 error_score=options["error_score"],
608 return_train_score=True,
609 )
572 else: 610 else:
573 warnings.simplefilter('always', FitFailedWarning) 611 warnings.simplefilter("always", FitFailedWarning)
574 with warnings.catch_warnings(record=True) as w: 612 with warnings.catch_warnings(record=True) as w:
575 try: 613 try:
576 rval = cross_validate( 614 rval = cross_validate(
577 searcher, X, y, 615 searcher,
578 scoring=options['scoring'], 616 X,
579 cv=outer_cv, n_jobs=N_JOBS, 617 y,
580 verbose=options['verbose'], 618 scoring=options["scoring"],
581 return_estimator=(params['save'] == 'save_estimator'), 619 cv=outer_cv,
582 error_score=options['error_score'], 620 n_jobs=N_JOBS,
583 return_train_score=True) 621 verbose=options["verbose"],
622 return_estimator=(params["save"] == "save_estimator"),
623 error_score=options["error_score"],
624 return_train_score=True,
625 )
584 except ValueError: 626 except ValueError:
585 pass 627 pass
586 for warning in w: 628 for warning in w:
587 print(repr(warning.message)) 629 print(repr(warning.message))
588 630
589 fitted_searchers = rval.pop('estimator', []) 631 fitted_searchers = rval.pop("estimator", [])
590 if fitted_searchers: 632 if fitted_searchers:
591 import os 633 import os
634
592 pwd = os.getcwd() 635 pwd = os.getcwd()
593 save_dir = os.path.join(pwd, 'cv_results_in_folds') 636 save_dir = os.path.join(pwd, "cv_results_in_folds")
594 try: 637 try:
595 os.mkdir(save_dir) 638 os.mkdir(save_dir)
596 for idx, obj in enumerate(fitted_searchers): 639 for idx, obj in enumerate(fitted_searchers):
597 target_name = 'cv_results_' + 'split%d' % idx 640 target_name = "cv_results_" + "split%d" % idx
598 target_path = os.path.join(pwd, save_dir, target_name) 641 target_path = os.path.join(pwd, save_dir, target_name)
599 cv_results_ = getattr(obj, 'cv_results_', None) 642 cv_results_ = getattr(obj, "cv_results_", None)
600 if not cv_results_: 643 if not cv_results_:
601 print("%s is not available" % target_name) 644 print("%s is not available" % target_name)
602 continue 645 continue
603 cv_results_ = pd.DataFrame(cv_results_) 646 cv_results_ = pd.DataFrame(cv_results_)
604 cv_results_ = cv_results_[sorted(cv_results_.columns)] 647 cv_results_ = cv_results_[sorted(cv_results_.columns)]
605 cv_results_.to_csv(target_path, sep='\t', header=True, 648 cv_results_.to_csv(target_path, sep="\t", header=True, index=False)
606 index=False)
607 except Exception as e: 649 except Exception as e:
608 print(e) 650 print(e)
609 finally: 651 finally:
610 del os 652 del os
611 653
612 keys = list(rval.keys()) 654 keys = list(rval.keys())
613 for k in keys: 655 for k in keys:
614 if k.startswith('test'): 656 if k.startswith("test"):
615 rval['mean_' + k] = np.mean(rval[k]) 657 rval["mean_" + k] = np.mean(rval[k])
616 rval['std_' + k] = np.std(rval[k]) 658 rval["std_" + k] = np.std(rval[k])
617 if k.endswith('time'): 659 if k.endswith("time"):
618 rval.pop(k) 660 rval.pop(k)
619 rval = pd.DataFrame(rval) 661 rval = pd.DataFrame(rval)
620 rval = rval[sorted(rval.columns)] 662 rval = rval[sorted(rval.columns)]
621 rval.to_csv(path_or_buf=outfile_result, sep='\t', header=True, 663 rval.to_csv(path_or_buf=outfile_result, sep="\t", header=True, index=False)
622 index=False)
623
624 return 0
625
626 # deprecate train test split mode 664 # deprecate train test split mode
627 """searcher = _do_train_test_split_val( 665 """searcher = _do_train_test_split_val(
628 searcher, X, y, params, 666 searcher, X, y, params,
629 primary_scoring=primary_scoring, 667 primary_scoring=primary_scoring,
630 error_score=options['error_score'], 668 error_score=options['error_score'],
631 groups=groups, 669 groups=groups,
632 outfile=outfile_result)""" 670 outfile=outfile_result)"""
671 return 0
633 672
634 # no outer split 673 # no outer split
635 else: 674 else:
636 searcher.set_params(n_jobs=N_JOBS) 675 searcher.set_params(n_jobs=N_JOBS)
637 if options['error_score'] == 'raise': 676 if options["error_score"] == "raise":
638 searcher.fit(X, y, groups=groups) 677 searcher.fit(X, y, groups=groups)
639 else: 678 else:
640 warnings.simplefilter('always', FitFailedWarning) 679 warnings.simplefilter("always", FitFailedWarning)
641 with warnings.catch_warnings(record=True) as w: 680 with warnings.catch_warnings(record=True) as w:
642 try: 681 try:
643 searcher.fit(X, y, groups=groups) 682 searcher.fit(X, y, groups=groups)
644 except ValueError: 683 except ValueError:
645 pass 684 pass
646 for warning in w: 685 for warning in w:
647 print(repr(warning.message)) 686 print(repr(warning.message))
648 687
649 cv_results = pd.DataFrame(searcher.cv_results_) 688 cv_results = pd.DataFrame(searcher.cv_results_)
650 cv_results = cv_results[sorted(cv_results.columns)] 689 cv_results = cv_results[sorted(cv_results.columns)]
651 cv_results.to_csv(path_or_buf=outfile_result, sep='\t', 690 cv_results.to_csv(path_or_buf=outfile_result, sep="\t", header=True, index=False)
652 header=True, index=False)
653 691
654 memory.clear(warn=False) 692 memory.clear(warn=False)
655 693
656 # output best estimator, and weights if applicable 694 # output best estimator, and weights if applicable
657 if outfile_object: 695 if outfile_object:
658 best_estimator_ = getattr(searcher, 'best_estimator_', None) 696 best_estimator_ = getattr(searcher, "best_estimator_", None)
659 if not best_estimator_: 697 if not best_estimator_:
660 warnings.warn("GridSearchCV object has no attribute " 698 warnings.warn(
661 "'best_estimator_', because either it's " 699 "GridSearchCV object has no attribute "
662 "nested gridsearch or `refit` is False!") 700 "'best_estimator_', because either it's "
701 "nested gridsearch or `refit` is False!"
702 )
663 return 703 return
664 704
665 # clean params 705 # clean params
666 best_estimator_ = clean_params(best_estimator_) 706 best_estimator_ = clean_params(best_estimator_)
667 707
668 main_est = get_main_estimator(best_estimator_) 708 main_est = get_main_estimator(best_estimator_)
669 709
670 if hasattr(main_est, 'model_') \ 710 if hasattr(main_est, "model_") and hasattr(main_est, "save_weights"):
671 and hasattr(main_est, 'save_weights'):
672 if outfile_weights: 711 if outfile_weights:
673 main_est.save_weights(outfile_weights) 712 main_est.save_weights(outfile_weights)
674 del main_est.model_ 713 del main_est.model_
675 del main_est.fit_params 714 del main_est.fit_params
676 del main_est.model_class_ 715 del main_est.model_class_
677 del main_est.validation_data 716 del main_est.validation_data
678 if getattr(main_est, 'data_generator_', None): 717 if getattr(main_est, "data_generator_", None):
679 del main_est.data_generator_ 718 del main_est.data_generator_
680 719
681 with open(outfile_object, 'wb') as output_handler: 720 with open(outfile_object, "wb") as output_handler:
682 print("Best estimator is saved: %s " % repr(best_estimator_)) 721 print("Best estimator is saved: %s " % repr(best_estimator_))
683 pickle.dump(best_estimator_, output_handler, 722 pickle.dump(best_estimator_, output_handler, pickle.HIGHEST_PROTOCOL)
684 pickle.HIGHEST_PROTOCOL) 723
685 724
686 725 if __name__ == "__main__":
687 if __name__ == '__main__':
688 aparser = argparse.ArgumentParser() 726 aparser = argparse.ArgumentParser()
689 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) 727 aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
690 aparser.add_argument("-e", "--estimator", dest="infile_estimator") 728 aparser.add_argument("-e", "--estimator", dest="infile_estimator")
691 aparser.add_argument("-X", "--infile1", dest="infile1") 729 aparser.add_argument("-X", "--infile1", dest="infile1")
692 aparser.add_argument("-y", "--infile2", dest="infile2") 730 aparser.add_argument("-y", "--infile2", dest="infile2")
698 aparser.add_argument("-b", "--intervals", dest="intervals") 736 aparser.add_argument("-b", "--intervals", dest="intervals")
699 aparser.add_argument("-t", "--targets", dest="targets") 737 aparser.add_argument("-t", "--targets", dest="targets")
700 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") 738 aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
701 args = aparser.parse_args() 739 args = aparser.parse_args()
702 740
703 main(args.inputs, args.infile_estimator, args.infile1, args.infile2, 741 main(
704 args.outfile_result, outfile_object=args.outfile_object, 742 args.inputs,
705 outfile_weights=args.outfile_weights, groups=args.groups, 743 args.infile_estimator,
706 ref_seq=args.ref_seq, intervals=args.intervals, 744 args.infile1,
707 targets=args.targets, fasta_path=args.fasta_path) 745 args.infile2,
746 args.outfile_result,
747 outfile_object=args.outfile_object,
748 outfile_weights=args.outfile_weights,
749 groups=args.groups,
750 ref_seq=args.ref_seq,
751 intervals=args.intervals,
752 targets=args.targets,
753 fasta_path=args.fasta_path,
754 )