comparison search_model_validation.py @ 9:4aa701f5a393 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
author bgruening
date Tue, 13 Apr 2021 18:00:54 +0000
parents 6efb9bc6bf32
children 22f9cbcf1582
comparison
equal deleted inserted replaced
8:83228baae3c5 9:4aa701f5a393
9 import pickle 9 import pickle
10 import skrebate 10 import skrebate
11 import sys 11 import sys
12 import warnings 12 import warnings
13 from scipy.io import mmread 13 from scipy.io import mmread
14 from sklearn import (cluster, decomposition, feature_selection, 14 from sklearn import (
15 kernel_approximation, model_selection, preprocessing) 15 cluster,
16 decomposition,
17 feature_selection,
18 kernel_approximation,
19 model_selection,
20 preprocessing,
21 )
16 from sklearn.exceptions import FitFailedWarning 22 from sklearn.exceptions import FitFailedWarning
17 from sklearn.model_selection._validation import _score, cross_validate 23 from sklearn.model_selection._validation import _score, cross_validate
18 from sklearn.model_selection import _search, _validation 24 from sklearn.model_selection import _search, _validation
19 from sklearn.pipeline import Pipeline 25 from sklearn.pipeline import Pipeline
20 26
21 from galaxy_ml.utils import (SafeEval, get_cv, get_scoring, load_model, 27 from galaxy_ml.utils import (
22 read_columns, try_get_attr, get_module, 28 SafeEval,
23 clean_params, get_main_estimator) 29 get_cv,
24 30 get_scoring,
25 31 load_model,
26 _fit_and_score = try_get_attr('galaxy_ml.model_validations', '_fit_and_score') 32 read_columns,
27 setattr(_search, '_fit_and_score', _fit_and_score) 33 try_get_attr,
28 setattr(_validation, '_fit_and_score', _fit_and_score) 34 get_module,
29 35 clean_params,
30 N_JOBS = int(os.environ.get('GALAXY_SLOTS', 1)) 36 get_main_estimator,
37 )
38
39
40 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score")
41 setattr(_search, "_fit_and_score", _fit_and_score)
42 setattr(_validation, "_fit_and_score", _fit_and_score)
43
44 N_JOBS = int(os.environ.get("GALAXY_SLOTS", 1))
31 # handle disk cache 45 # handle disk cache
32 CACHE_DIR = os.path.join(os.getcwd(), 'cached') 46 CACHE_DIR = os.path.join(os.getcwd(), "cached")
33 del os 47 del os
34 NON_SEARCHABLE = ('n_jobs', 'pre_dispatch', 'memory', '_path', 48 NON_SEARCHABLE = ("n_jobs", "pre_dispatch", "memory", "_path", "nthread", "callbacks")
35 'nthread', 'callbacks')
36 49
37 50
38 def _eval_search_params(params_builder): 51 def _eval_search_params(params_builder):
39 search_params = {} 52 search_params = {}
40 53
41 for p in params_builder['param_set']: 54 for p in params_builder["param_set"]:
42 search_list = p['sp_list'].strip() 55 search_list = p["sp_list"].strip()
43 if search_list == '': 56 if search_list == "":
44 continue 57 continue
45 58
46 param_name = p['sp_name'] 59 param_name = p["sp_name"]
47 if param_name.lower().endswith(NON_SEARCHABLE): 60 if param_name.lower().endswith(NON_SEARCHABLE):
48 print("Warning: `%s` is not eligible for search and was " 61 print("Warning: `%s` is not eligible for search and was " "omitted!" % param_name)
49 "omitted!" % param_name)
50 continue 62 continue
51 63
52 if not search_list.startswith(':'): 64 if not search_list.startswith(":"):
53 safe_eval = SafeEval(load_scipy=True, load_numpy=True) 65 safe_eval = SafeEval(load_scipy=True, load_numpy=True)
54 ev = safe_eval(search_list) 66 ev = safe_eval(search_list)
55 search_params[param_name] = ev 67 search_params[param_name] = ev
56 else: 68 else:
57 # Have `:` before search list, asks for estimator evaluatio 69 # Have `:` before search list, asks for estimator evaluatio
58 safe_eval_es = SafeEval(load_estimators=True) 70 safe_eval_es = SafeEval(load_estimators=True)
59 search_list = search_list[1:].strip() 71 search_list = search_list[1:].strip()
60 # TODO maybe add regular express check 72 # TODO maybe add regular express check
61 ev = safe_eval_es(search_list) 73 ev = safe_eval_es(search_list)
62 preprocessings = ( 74 preprocessings = (
63 preprocessing.StandardScaler(), preprocessing.Binarizer(), 75 preprocessing.StandardScaler(),
76 preprocessing.Binarizer(),
64 preprocessing.MaxAbsScaler(), 77 preprocessing.MaxAbsScaler(),
65 preprocessing.Normalizer(), preprocessing.MinMaxScaler(), 78 preprocessing.Normalizer(),
79 preprocessing.MinMaxScaler(),
66 preprocessing.PolynomialFeatures(), 80 preprocessing.PolynomialFeatures(),
67 preprocessing.RobustScaler(), feature_selection.SelectKBest(), 81 preprocessing.RobustScaler(),
82 feature_selection.SelectKBest(),
68 feature_selection.GenericUnivariateSelect(), 83 feature_selection.GenericUnivariateSelect(),
69 feature_selection.SelectPercentile(), 84 feature_selection.SelectPercentile(),
70 feature_selection.SelectFpr(), feature_selection.SelectFdr(), 85 feature_selection.SelectFpr(),
86 feature_selection.SelectFdr(),
71 feature_selection.SelectFwe(), 87 feature_selection.SelectFwe(),
72 feature_selection.VarianceThreshold(), 88 feature_selection.VarianceThreshold(),
73 decomposition.FactorAnalysis(random_state=0), 89 decomposition.FactorAnalysis(random_state=0),
74 decomposition.FastICA(random_state=0), 90 decomposition.FastICA(random_state=0),
75 decomposition.IncrementalPCA(), 91 decomposition.IncrementalPCA(),
76 decomposition.KernelPCA(random_state=0, n_jobs=N_JOBS), 92 decomposition.KernelPCA(random_state=0, n_jobs=N_JOBS),
77 decomposition.LatentDirichletAllocation( 93 decomposition.LatentDirichletAllocation(random_state=0, n_jobs=N_JOBS),
78 random_state=0, n_jobs=N_JOBS), 94 decomposition.MiniBatchDictionaryLearning(random_state=0, n_jobs=N_JOBS),
79 decomposition.MiniBatchDictionaryLearning( 95 decomposition.MiniBatchSparsePCA(random_state=0, n_jobs=N_JOBS),
80 random_state=0, n_jobs=N_JOBS),
81 decomposition.MiniBatchSparsePCA(
82 random_state=0, n_jobs=N_JOBS),
83 decomposition.NMF(random_state=0), 96 decomposition.NMF(random_state=0),
84 decomposition.PCA(random_state=0), 97 decomposition.PCA(random_state=0),
85 decomposition.SparsePCA(random_state=0, n_jobs=N_JOBS), 98 decomposition.SparsePCA(random_state=0, n_jobs=N_JOBS),
86 decomposition.TruncatedSVD(random_state=0), 99 decomposition.TruncatedSVD(random_state=0),
87 kernel_approximation.Nystroem(random_state=0), 100 kernel_approximation.Nystroem(random_state=0),
92 skrebate.ReliefF(n_jobs=N_JOBS), 105 skrebate.ReliefF(n_jobs=N_JOBS),
93 skrebate.SURF(n_jobs=N_JOBS), 106 skrebate.SURF(n_jobs=N_JOBS),
94 skrebate.SURFstar(n_jobs=N_JOBS), 107 skrebate.SURFstar(n_jobs=N_JOBS),
95 skrebate.MultiSURF(n_jobs=N_JOBS), 108 skrebate.MultiSURF(n_jobs=N_JOBS),
96 skrebate.MultiSURFstar(n_jobs=N_JOBS), 109 skrebate.MultiSURFstar(n_jobs=N_JOBS),
97 imblearn.under_sampling.ClusterCentroids( 110 imblearn.under_sampling.ClusterCentroids(random_state=0, n_jobs=N_JOBS),
98 random_state=0, n_jobs=N_JOBS), 111 imblearn.under_sampling.CondensedNearestNeighbour(random_state=0, n_jobs=N_JOBS),
99 imblearn.under_sampling.CondensedNearestNeighbour( 112 imblearn.under_sampling.EditedNearestNeighbours(random_state=0, n_jobs=N_JOBS),
100 random_state=0, n_jobs=N_JOBS), 113 imblearn.under_sampling.RepeatedEditedNearestNeighbours(random_state=0, n_jobs=N_JOBS),
101 imblearn.under_sampling.EditedNearestNeighbours(
102 random_state=0, n_jobs=N_JOBS),
103 imblearn.under_sampling.RepeatedEditedNearestNeighbours(
104 random_state=0, n_jobs=N_JOBS),
105 imblearn.under_sampling.AllKNN(random_state=0, n_jobs=N_JOBS), 114 imblearn.under_sampling.AllKNN(random_state=0, n_jobs=N_JOBS),
106 imblearn.under_sampling.InstanceHardnessThreshold( 115 imblearn.under_sampling.InstanceHardnessThreshold(random_state=0, n_jobs=N_JOBS),
107 random_state=0, n_jobs=N_JOBS), 116 imblearn.under_sampling.NearMiss(random_state=0, n_jobs=N_JOBS),
108 imblearn.under_sampling.NearMiss( 117 imblearn.under_sampling.NeighbourhoodCleaningRule(random_state=0, n_jobs=N_JOBS),
109 random_state=0, n_jobs=N_JOBS), 118 imblearn.under_sampling.OneSidedSelection(random_state=0, n_jobs=N_JOBS),
110 imblearn.under_sampling.NeighbourhoodCleaningRule( 119 imblearn.under_sampling.RandomUnderSampler(random_state=0),
111 random_state=0, n_jobs=N_JOBS), 120 imblearn.under_sampling.TomekLinks(random_state=0, n_jobs=N_JOBS),
112 imblearn.under_sampling.OneSidedSelection(
113 random_state=0, n_jobs=N_JOBS),
114 imblearn.under_sampling.RandomUnderSampler(
115 random_state=0),
116 imblearn.under_sampling.TomekLinks(
117 random_state=0, n_jobs=N_JOBS),
118 imblearn.over_sampling.ADASYN(random_state=0, n_jobs=N_JOBS), 121 imblearn.over_sampling.ADASYN(random_state=0, n_jobs=N_JOBS),
119 imblearn.over_sampling.RandomOverSampler(random_state=0), 122 imblearn.over_sampling.RandomOverSampler(random_state=0),
120 imblearn.over_sampling.SMOTE(random_state=0, n_jobs=N_JOBS), 123 imblearn.over_sampling.SMOTE(random_state=0, n_jobs=N_JOBS),
121 imblearn.over_sampling.SVMSMOTE(random_state=0, n_jobs=N_JOBS), 124 imblearn.over_sampling.SVMSMOTE(random_state=0, n_jobs=N_JOBS),
122 imblearn.over_sampling.BorderlineSMOTE( 125 imblearn.over_sampling.BorderlineSMOTE(random_state=0, n_jobs=N_JOBS),
123 random_state=0, n_jobs=N_JOBS), 126 imblearn.over_sampling.SMOTENC(categorical_features=[], random_state=0, n_jobs=N_JOBS),
124 imblearn.over_sampling.SMOTENC(
125 categorical_features=[], random_state=0, n_jobs=N_JOBS),
126 imblearn.combine.SMOTEENN(random_state=0), 127 imblearn.combine.SMOTEENN(random_state=0),
127 imblearn.combine.SMOTETomek(random_state=0)) 128 imblearn.combine.SMOTETomek(random_state=0),
129 )
128 newlist = [] 130 newlist = []
129 for obj in ev: 131 for obj in ev:
130 if obj is None: 132 if obj is None:
131 newlist.append(None) 133 newlist.append(None)
132 elif obj == 'all_0': 134 elif obj == "all_0":
133 newlist.extend(preprocessings[0:35]) 135 newlist.extend(preprocessings[0:35])
134 elif obj == 'sk_prep_all': # no KernalCenter() 136 elif obj == "sk_prep_all": # no KernalCenter()
135 newlist.extend(preprocessings[0:7]) 137 newlist.extend(preprocessings[0:7])
136 elif obj == 'fs_all': 138 elif obj == "fs_all":
137 newlist.extend(preprocessings[7:14]) 139 newlist.extend(preprocessings[7:14])
138 elif obj == 'decomp_all': 140 elif obj == "decomp_all":
139 newlist.extend(preprocessings[14:25]) 141 newlist.extend(preprocessings[14:25])
140 elif obj == 'k_appr_all': 142 elif obj == "k_appr_all":
141 newlist.extend(preprocessings[25:29]) 143 newlist.extend(preprocessings[25:29])
142 elif obj == 'reb_all': 144 elif obj == "reb_all":
143 newlist.extend(preprocessings[30:35]) 145 newlist.extend(preprocessings[30:35])
144 elif obj == 'imb_all': 146 elif obj == "imb_all":
145 newlist.extend(preprocessings[35:54]) 147 newlist.extend(preprocessings[35:54])
146 elif type(obj) is int and -1 < obj < len(preprocessings): 148 elif type(obj) is int and -1 < obj < len(preprocessings):
147 newlist.append(preprocessings[obj]) 149 newlist.append(preprocessings[obj])
148 elif hasattr(obj, 'get_params'): # user uploaded object 150 elif hasattr(obj, "get_params"): # user uploaded object
149 if 'n_jobs' in obj.get_params(): 151 if "n_jobs" in obj.get_params():
150 newlist.append(obj.set_params(n_jobs=N_JOBS)) 152 newlist.append(obj.set_params(n_jobs=N_JOBS))
151 else: 153 else:
152 newlist.append(obj) 154 newlist.append(obj)
153 else: 155 else:
154 sys.exit("Unsupported estimator type: %r" % (obj)) 156 sys.exit("Unsupported estimator type: %r" % (obj))
156 search_params[param_name] = newlist 158 search_params[param_name] = newlist
157 159
158 return search_params 160 return search_params
159 161
160 162
161 def _handle_X_y(estimator, params, infile1, infile2, loaded_df={}, 163 def _handle_X_y(
162 ref_seq=None, intervals=None, targets=None, 164 estimator,
163 fasta_path=None): 165 params,
166 infile1,
167 infile2,
168 loaded_df={},
169 ref_seq=None,
170 intervals=None,
171 targets=None,
172 fasta_path=None,
173 ):
164 """read inputs 174 """read inputs
165 175
166 Params 176 Params
167 ------- 177 -------
168 estimator : estimator object 178 estimator : estimator object
190 X : numpy array 200 X : numpy array
191 y : numpy array 201 y : numpy array
192 """ 202 """
193 estimator_params = estimator.get_params() 203 estimator_params = estimator.get_params()
194 204
195 input_type = params['input_options']['selected_input'] 205 input_type = params["input_options"]["selected_input"]
196 # tabular input 206 # tabular input
197 if input_type == 'tabular': 207 if input_type == "tabular":
198 header = 'infer' if params['input_options']['header1'] else None 208 header = "infer" if params["input_options"]["header1"] else None
199 column_option = (params['input_options']['column_selector_options_1'] 209 column_option = params["input_options"]["column_selector_options_1"]["selected_column_selector_option"]
200 ['selected_column_selector_option']) 210 if column_option in [
201 if column_option in ['by_index_number', 'all_but_by_index_number', 211 "by_index_number",
202 'by_header_name', 'all_but_by_header_name']: 212 "all_but_by_index_number",
203 c = params['input_options']['column_selector_options_1']['col1'] 213 "by_header_name",
214 "all_but_by_header_name",
215 ]:
216 c = params["input_options"]["column_selector_options_1"]["col1"]
204 else: 217 else:
205 c = None 218 c = None
206 219
207 df_key = infile1 + repr(header) 220 df_key = infile1 + repr(header)
208 221
209 if df_key in loaded_df: 222 if df_key in loaded_df:
210 infile1 = loaded_df[df_key] 223 infile1 = loaded_df[df_key]
211 224
212 df = pd.read_csv(infile1, sep='\t', header=header, 225 df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True)
213 parse_dates=True)
214 loaded_df[df_key] = df 226 loaded_df[df_key] = df
215 227
216 X = read_columns(df, c=c, c_option=column_option).astype(float) 228 X = read_columns(df, c=c, c_option=column_option).astype(float)
217 # sparse input 229 # sparse input
218 elif input_type == 'sparse': 230 elif input_type == "sparse":
219 X = mmread(open(infile1, 'r')) 231 X = mmread(open(infile1, "r"))
220 232
221 # fasta_file input 233 # fasta_file input
222 elif input_type == 'seq_fasta': 234 elif input_type == "seq_fasta":
223 pyfaidx = get_module('pyfaidx') 235 pyfaidx = get_module("pyfaidx")
224 sequences = pyfaidx.Fasta(fasta_path) 236 sequences = pyfaidx.Fasta(fasta_path)
225 n_seqs = len(sequences.keys()) 237 n_seqs = len(sequences.keys())
226 X = np.arange(n_seqs)[:, np.newaxis] 238 X = np.arange(n_seqs)[:, np.newaxis]
227 for param in estimator_params.keys(): 239 for param in estimator_params.keys():
228 if param.endswith('fasta_path'): 240 if param.endswith("fasta_path"):
229 estimator.set_params( 241 estimator.set_params(**{param: fasta_path})
230 **{param: fasta_path})
231 break 242 break
232 else: 243 else:
233 raise ValueError( 244 raise ValueError(
234 "The selected estimator doesn't support " 245 "The selected estimator doesn't support "
235 "fasta file input! Please consider using " 246 "fasta file input! Please consider using "
236 "KerasGBatchClassifier with " 247 "KerasGBatchClassifier with "
237 "FastaDNABatchGenerator/FastaProteinBatchGenerator " 248 "FastaDNABatchGenerator/FastaProteinBatchGenerator "
238 "or having GenomeOneHotEncoder/ProteinOneHotEncoder " 249 "or having GenomeOneHotEncoder/ProteinOneHotEncoder "
239 "in pipeline!") 250 "in pipeline!"
240 251 )
241 elif input_type == 'refseq_and_interval': 252
253 elif input_type == "refseq_and_interval":
242 path_params = { 254 path_params = {
243 'data_batch_generator__ref_genome_path': ref_seq, 255 "data_batch_generator__ref_genome_path": ref_seq,
244 'data_batch_generator__intervals_path': intervals, 256 "data_batch_generator__intervals_path": intervals,
245 'data_batch_generator__target_path': targets 257 "data_batch_generator__target_path": targets,
246 } 258 }
247 estimator.set_params(**path_params) 259 estimator.set_params(**path_params)
248 n_intervals = sum(1 for line in open(intervals)) 260 n_intervals = sum(1 for line in open(intervals))
249 X = np.arange(n_intervals)[:, np.newaxis] 261 X = np.arange(n_intervals)[:, np.newaxis]
250 262
251 # Get target y 263 # Get target y
252 header = 'infer' if params['input_options']['header2'] else None 264 header = "infer" if params["input_options"]["header2"] else None
253 column_option = (params['input_options']['column_selector_options_2'] 265 column_option = params["input_options"]["column_selector_options_2"]["selected_column_selector_option2"]
254 ['selected_column_selector_option2']) 266 if column_option in [
255 if column_option in ['by_index_number', 'all_but_by_index_number', 267 "by_index_number",
256 'by_header_name', 'all_but_by_header_name']: 268 "all_but_by_index_number",
257 c = params['input_options']['column_selector_options_2']['col2'] 269 "by_header_name",
270 "all_but_by_header_name",
271 ]:
272 c = params["input_options"]["column_selector_options_2"]["col2"]
258 else: 273 else:
259 c = None 274 c = None
260 275
261 df_key = infile2 + repr(header) 276 df_key = infile2 + repr(header)
262 if df_key in loaded_df: 277 if df_key in loaded_df:
263 infile2 = loaded_df[df_key] 278 infile2 = loaded_df[df_key]
264 else: 279 else:
265 infile2 = pd.read_csv(infile2, sep='\t', 280 infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True)
266 header=header, parse_dates=True)
267 loaded_df[df_key] = infile2 281 loaded_df[df_key] = infile2
268 282
269 y = read_columns( 283 y = read_columns(infile2, c=c, c_option=column_option, sep="\t", header=header, parse_dates=True)
270 infile2,
271 c=c,
272 c_option=column_option,
273 sep='\t',
274 header=header,
275 parse_dates=True)
276 if len(y.shape) == 2 and y.shape[1] == 1: 284 if len(y.shape) == 2 and y.shape[1] == 1:
277 y = y.ravel() 285 y = y.ravel()
278 if input_type == 'refseq_and_interval': 286 if input_type == "refseq_and_interval":
279 estimator.set_params( 287 estimator.set_params(data_batch_generator__features=y.ravel().tolist())
280 data_batch_generator__features=y.ravel().tolist())
281 y = None 288 y = None
282 # end y 289 # end y
283 290
284 return estimator, X, y 291 return estimator, X, y
285 292
286 293
287 def _do_outer_cv(searcher, X, y, outer_cv, scoring, error_score='raise', 294 def _do_outer_cv(searcher, X, y, outer_cv, scoring, error_score="raise", outfile=None):
288 outfile=None):
289 """Do outer cross-validation for nested CV 295 """Do outer cross-validation for nested CV
290 296
291 Parameters 297 Parameters
292 ---------- 298 ----------
293 searcher : object 299 searcher : object
303 error_score: str, float or numpy float 309 error_score: str, float or numpy float
304 Whether to raise fit error or return an value 310 Whether to raise fit error or return an value
305 outfile : str 311 outfile : str
306 File path to store the restuls 312 File path to store the restuls
307 """ 313 """
308 if error_score == 'raise': 314 if error_score == "raise":
309 rval = cross_validate( 315 rval = cross_validate(
310 searcher, X, y, scoring=scoring, 316 searcher,
311 cv=outer_cv, n_jobs=N_JOBS, verbose=0, 317 X,
312 error_score=error_score) 318 y,
313 else: 319 scoring=scoring,
314 warnings.simplefilter('always', FitFailedWarning) 320 cv=outer_cv,
321 n_jobs=N_JOBS,
322 verbose=0,
323 error_score=error_score,
324 )
325 else:
326 warnings.simplefilter("always", FitFailedWarning)
315 with warnings.catch_warnings(record=True) as w: 327 with warnings.catch_warnings(record=True) as w:
316 try: 328 try:
317 rval = cross_validate( 329 rval = cross_validate(
318 searcher, X, y, 330 searcher,
331 X,
332 y,
319 scoring=scoring, 333 scoring=scoring,
320 cv=outer_cv, n_jobs=N_JOBS, 334 cv=outer_cv,
335 n_jobs=N_JOBS,
321 verbose=0, 336 verbose=0,
322 error_score=error_score) 337 error_score=error_score,
338 )
323 except ValueError: 339 except ValueError:
324 pass 340 pass
325 for warning in w: 341 for warning in w:
326 print(repr(warning.message)) 342 print(repr(warning.message))
327 343
328 keys = list(rval.keys()) 344 keys = list(rval.keys())
329 for k in keys: 345 for k in keys:
330 if k.startswith('test'): 346 if k.startswith("test"):
331 rval['mean_' + k] = np.mean(rval[k]) 347 rval["mean_" + k] = np.mean(rval[k])
332 rval['std_' + k] = np.std(rval[k]) 348 rval["std_" + k] = np.std(rval[k])
333 if k.endswith('time'): 349 if k.endswith("time"):
334 rval.pop(k) 350 rval.pop(k)
335 rval = pd.DataFrame(rval) 351 rval = pd.DataFrame(rval)
336 rval = rval[sorted(rval.columns)] 352 rval = rval[sorted(rval.columns)]
337 rval.to_csv(path_or_buf=outfile, sep='\t', header=True, index=False) 353 rval.to_csv(path_or_buf=outfile, sep="\t", header=True, index=False)
338 354
339 355
340 def _do_train_test_split_val(searcher, X, y, params, error_score='raise', 356 def _do_train_test_split_val(
341 primary_scoring=None, groups=None, 357 searcher,
342 outfile=None): 358 X,
343 """ do train test split, searchCV validates on the train and then use 359 y,
360 params,
361 error_score="raise",
362 primary_scoring=None,
363 groups=None,
364 outfile=None,
365 ):
366 """do train test split, searchCV validates on the train and then use
344 the best_estimator_ to evaluate on the test 367 the best_estimator_ to evaluate on the test
345 368
346 Returns 369 Returns
347 -------- 370 --------
348 Fitted SearchCV object 371 Fitted SearchCV object
349 """ 372 """
350 train_test_split = try_get_attr( 373 train_test_split = try_get_attr("galaxy_ml.model_validations", "train_test_split")
351 'galaxy_ml.model_validations', 'train_test_split') 374 split_options = params["outer_split"]
352 split_options = params['outer_split']
353 375
354 # splits 376 # splits
355 if split_options['shuffle'] == 'stratified': 377 if split_options["shuffle"] == "stratified":
356 split_options['labels'] = y 378 split_options["labels"] = y
357 X, X_test, y, y_test = train_test_split(X, y, **split_options) 379 X, X_test, y, y_test = train_test_split(X, y, **split_options)
358 elif split_options['shuffle'] == 'group': 380 elif split_options["shuffle"] == "group":
359 if groups is None: 381 if groups is None:
360 raise ValueError("No group based CV option was choosen for " 382 raise ValueError("No group based CV option was choosen for " "group shuffle!")
361 "group shuffle!") 383 split_options["labels"] = groups
362 split_options['labels'] = groups
363 if y is None: 384 if y is None:
364 X, X_test, groups, _ =\ 385 X, X_test, groups, _ = train_test_split(X, groups, **split_options)
365 train_test_split(X, groups, **split_options)
366 else: 386 else:
367 X, X_test, y, y_test, groups, _ =\ 387 X, X_test, y, y_test, groups, _ = train_test_split(X, y, groups, **split_options)
368 train_test_split(X, y, groups, **split_options) 388 else:
369 else: 389 if split_options["shuffle"] == "None":
370 if split_options['shuffle'] == 'None': 390 split_options["shuffle"] = None
371 split_options['shuffle'] = None 391 X, X_test, y, y_test = train_test_split(X, y, **split_options)
372 X, X_test, y, y_test =\ 392
373 train_test_split(X, y, **split_options) 393 if error_score == "raise":
374
375 if error_score == 'raise':
376 searcher.fit(X, y, groups=groups) 394 searcher.fit(X, y, groups=groups)
377 else: 395 else:
378 warnings.simplefilter('always', FitFailedWarning) 396 warnings.simplefilter("always", FitFailedWarning)
379 with warnings.catch_warnings(record=True) as w: 397 with warnings.catch_warnings(record=True) as w:
380 try: 398 try:
381 searcher.fit(X, y, groups=groups) 399 searcher.fit(X, y, groups=groups)
382 except ValueError: 400 except ValueError:
383 pass 401 pass
388 if isinstance(scorer_, collections.Mapping): 406 if isinstance(scorer_, collections.Mapping):
389 is_multimetric = True 407 is_multimetric = True
390 else: 408 else:
391 is_multimetric = False 409 is_multimetric = False
392 410
393 best_estimator_ = getattr(searcher, 'best_estimator_') 411 best_estimator_ = getattr(searcher, "best_estimator_")
394 412
395 # TODO Solve deep learning models in pipeline 413 # TODO Solve deep learning models in pipeline
396 if best_estimator_.__class__.__name__ == 'KerasGBatchClassifier': 414 if best_estimator_.__class__.__name__ == "KerasGBatchClassifier":
397 test_score = best_estimator_.evaluate( 415 test_score = best_estimator_.evaluate(X_test, scorer=scorer_, is_multimetric=is_multimetric)
398 X_test, scorer=scorer_, is_multimetric=is_multimetric) 416 else:
399 else: 417 test_score = _score(best_estimator_, X_test, y_test, scorer_, is_multimetric=is_multimetric)
400 test_score = _score(best_estimator_, X_test,
401 y_test, scorer_,
402 is_multimetric=is_multimetric)
403 418
404 if not is_multimetric: 419 if not is_multimetric:
405 test_score = {primary_scoring: test_score} 420 test_score = {primary_scoring: test_score}
406 for key, value in test_score.items(): 421 for key, value in test_score.items():
407 test_score[key] = [value] 422 test_score[key] = [value]
408 result_df = pd.DataFrame(test_score) 423 result_df = pd.DataFrame(test_score)
409 result_df.to_csv(path_or_buf=outfile, sep='\t', header=True, 424 result_df.to_csv(path_or_buf=outfile, sep="\t", header=True, index=False)
410 index=False)
411 425
412 return searcher 426 return searcher
413 427
414 428
415 def main(inputs, infile_estimator, infile1, infile2, 429 def main(
416 outfile_result, outfile_object=None, 430 inputs,
417 outfile_weights=None, groups=None, 431 infile_estimator,
418 ref_seq=None, intervals=None, targets=None, 432 infile1,
419 fasta_path=None): 433 infile2,
434 outfile_result,
435 outfile_object=None,
436 outfile_weights=None,
437 groups=None,
438 ref_seq=None,
439 intervals=None,
440 targets=None,
441 fasta_path=None,
442 ):
420 """ 443 """
421 Parameter 444 Parameter
422 --------- 445 ---------
423 inputs : str 446 inputs : str
424 File path to galaxy tool parameter 447 File path to galaxy tool parameter
454 File path to dataset compressed target bed file 477 File path to dataset compressed target bed file
455 478
456 fasta_path : str 479 fasta_path : str
457 File path to dataset containing fasta file 480 File path to dataset containing fasta file
458 """ 481 """
459 warnings.simplefilter('ignore') 482 warnings.simplefilter("ignore")
460 483
461 # store read dataframe object 484 # store read dataframe object
462 loaded_df = {} 485 loaded_df = {}
463 486
464 with open(inputs, 'r') as param_handler: 487 with open(inputs, "r") as param_handler:
465 params = json.load(param_handler) 488 params = json.load(param_handler)
466 489
467 # Override the refit parameter 490 # Override the refit parameter
468 params['search_schemes']['options']['refit'] = True \ 491 params["search_schemes"]["options"]["refit"] = True if params["save"] != "nope" else False
469 if params['save'] != 'nope' else False 492
470 493 with open(infile_estimator, "rb") as estimator_handler:
471 with open(infile_estimator, 'rb') as estimator_handler:
472 estimator = load_model(estimator_handler) 494 estimator = load_model(estimator_handler)
473 495
474 optimizer = params['search_schemes']['selected_search_scheme'] 496 optimizer = params["search_schemes"]["selected_search_scheme"]
475 optimizer = getattr(model_selection, optimizer) 497 optimizer = getattr(model_selection, optimizer)
476 498
477 # handle gridsearchcv options 499 # handle gridsearchcv options
478 options = params['search_schemes']['options'] 500 options = params["search_schemes"]["options"]
479 501
480 if groups: 502 if groups:
481 header = 'infer' if (options['cv_selector']['groups_selector'] 503 header = "infer" if (options["cv_selector"]["groups_selector"]["header_g"]) else None
482 ['header_g']) else None 504 column_option = options["cv_selector"]["groups_selector"]["column_selector_options_g"][
483 column_option = (options['cv_selector']['groups_selector'] 505 "selected_column_selector_option_g"
484 ['column_selector_options_g'] 506 ]
485 ['selected_column_selector_option_g']) 507 if column_option in [
486 if column_option in ['by_index_number', 'all_but_by_index_number', 508 "by_index_number",
487 'by_header_name', 'all_but_by_header_name']: 509 "all_but_by_index_number",
488 c = (options['cv_selector']['groups_selector'] 510 "by_header_name",
489 ['column_selector_options_g']['col_g']) 511 "all_but_by_header_name",
512 ]:
513 c = options["cv_selector"]["groups_selector"]["column_selector_options_g"]["col_g"]
490 else: 514 else:
491 c = None 515 c = None
492 516
493 df_key = groups + repr(header) 517 df_key = groups + repr(header)
494 518
495 groups = pd.read_csv(groups, sep='\t', header=header, 519 groups = pd.read_csv(groups, sep="\t", header=header, parse_dates=True)
496 parse_dates=True)
497 loaded_df[df_key] = groups 520 loaded_df[df_key] = groups
498 521
499 groups = read_columns( 522 groups = read_columns(
500 groups, 523 groups,
501 c=c, 524 c=c,
502 c_option=column_option, 525 c_option=column_option,
503 sep='\t', 526 sep="\t",
504 header=header, 527 header=header,
505 parse_dates=True) 528 parse_dates=True,
529 )
506 groups = groups.ravel() 530 groups = groups.ravel()
507 options['cv_selector']['groups_selector'] = groups 531 options["cv_selector"]["groups_selector"] = groups
508 532
509 splitter, groups = get_cv(options.pop('cv_selector')) 533 splitter, groups = get_cv(options.pop("cv_selector"))
510 options['cv'] = splitter 534 options["cv"] = splitter
511 primary_scoring = options['scoring']['primary_scoring'] 535 primary_scoring = options["scoring"]["primary_scoring"]
512 options['scoring'] = get_scoring(options['scoring']) 536 # get_scoring() expects secondary_scoring to be a comma separated string (not a list)
513 if options['error_score']: 537 # Check if secondary_scoring is specified
514 options['error_score'] = 'raise' 538 secondary_scoring = options["scoring"].get("secondary_scoring", None)
515 else: 539 if secondary_scoring is not None:
516 options['error_score'] = np.NaN 540 # If secondary_scoring is specified, convert the list into comman separated string
517 if options['refit'] and isinstance(options['scoring'], dict): 541 options["scoring"]["secondary_scoring"] = ",".join(options["scoring"]["secondary_scoring"])
518 options['refit'] = primary_scoring 542 options["scoring"] = get_scoring(options["scoring"])
519 if 'pre_dispatch' in options and options['pre_dispatch'] == '': 543 if options["error_score"]:
520 options['pre_dispatch'] = None 544 options["error_score"] = "raise"
521 545 else:
522 params_builder = params['search_schemes']['search_params_builder'] 546 options["error_score"] = np.NaN
547 if options["refit"] and isinstance(options["scoring"], dict):
548 options["refit"] = primary_scoring
549 if "pre_dispatch" in options and options["pre_dispatch"] == "":
550 options["pre_dispatch"] = None
551
552 params_builder = params["search_schemes"]["search_params_builder"]
523 param_grid = _eval_search_params(params_builder) 553 param_grid = _eval_search_params(params_builder)
524 554
525 estimator = clean_params(estimator) 555 estimator = clean_params(estimator)
526 556
527 # save the SearchCV object without fit 557 # save the SearchCV object without fit
528 if params['save'] == 'save_no_fit': 558 if params["save"] == "save_no_fit":
529 searcher = optimizer(estimator, param_grid, **options) 559 searcher = optimizer(estimator, param_grid, **options)
530 print(searcher) 560 print(searcher)
531 with open(outfile_object, 'wb') as output_handler: 561 with open(outfile_object, "wb") as output_handler:
532 pickle.dump(searcher, output_handler, 562 pickle.dump(searcher, output_handler, pickle.HIGHEST_PROTOCOL)
533 pickle.HIGHEST_PROTOCOL)
534 return 0 563 return 0
535 564
536 # read inputs and loads new attributes, like paths 565 # read inputs and loads new attributes, like paths
537 estimator, X, y = _handle_X_y(estimator, params, infile1, infile2, 566 estimator, X, y = _handle_X_y(
538 loaded_df=loaded_df, ref_seq=ref_seq, 567 estimator,
539 intervals=intervals, targets=targets, 568 params,
540 fasta_path=fasta_path) 569 infile1,
570 infile2,
571 loaded_df=loaded_df,
572 ref_seq=ref_seq,
573 intervals=intervals,
574 targets=targets,
575 fasta_path=fasta_path,
576 )
541 577
542 # cache iraps_core fits could increase search speed significantly 578 # cache iraps_core fits could increase search speed significantly
543 memory = joblib.Memory(location=CACHE_DIR, verbose=0) 579 memory = joblib.Memory(location=CACHE_DIR, verbose=0)
544 main_est = get_main_estimator(estimator) 580 main_est = get_main_estimator(estimator)
545 if main_est.__class__.__name__ == 'IRAPSClassifier': 581 if main_est.__class__.__name__ == "IRAPSClassifier":
546 main_est.set_params(memory=memory) 582 main_est.set_params(memory=memory)
547 583
548 searcher = optimizer(estimator, param_grid, **options) 584 searcher = optimizer(estimator, param_grid, **options)
549 585
550 split_mode = params['outer_split'].pop('split_mode') 586 split_mode = params["outer_split"].pop("split_mode")
551 587
552 if split_mode == 'nested_cv': 588 if split_mode == "nested_cv":
553 # make sure refit is choosen 589 # make sure refit is choosen
554 # this could be True for sklearn models, but not the case for 590 # this could be True for sklearn models, but not the case for
555 # deep learning models 591 # deep learning models
556 if not options['refit'] and \ 592 if not options["refit"] and not all(hasattr(estimator, attr) for attr in ("config", "model_type")):
557 not all(hasattr(estimator, attr)
558 for attr in ('config', 'model_type')):
559 warnings.warn("Refit is change to `True` for nested validation!") 593 warnings.warn("Refit is change to `True` for nested validation!")
560 setattr(searcher, 'refit', True) 594 setattr(searcher, "refit", True)
561 595
562 outer_cv, _ = get_cv(params['outer_split']['cv_selector']) 596 outer_cv, _ = get_cv(params["outer_split"]["cv_selector"])
563 # nested CV, outer cv using cross_validate 597 # nested CV, outer cv using cross_validate
564 if options['error_score'] == 'raise': 598 if options["error_score"] == "raise":
565 rval = cross_validate( 599 rval = cross_validate(
566 searcher, X, y, scoring=options['scoring'], 600 searcher,
567 cv=outer_cv, n_jobs=N_JOBS, 601 X,
568 verbose=options['verbose'], 602 y,
569 return_estimator=(params['save'] == 'save_estimator'), 603 scoring=options["scoring"],
570 error_score=options['error_score'], 604 cv=outer_cv,
571 return_train_score=True) 605 n_jobs=N_JOBS,
606 verbose=options["verbose"],
607 return_estimator=(params["save"] == "save_estimator"),
608 error_score=options["error_score"],
609 return_train_score=True,
610 )
572 else: 611 else:
573 warnings.simplefilter('always', FitFailedWarning) 612 warnings.simplefilter("always", FitFailedWarning)
574 with warnings.catch_warnings(record=True) as w: 613 with warnings.catch_warnings(record=True) as w:
575 try: 614 try:
576 rval = cross_validate( 615 rval = cross_validate(
577 searcher, X, y, 616 searcher,
578 scoring=options['scoring'], 617 X,
579 cv=outer_cv, n_jobs=N_JOBS, 618 y,
580 verbose=options['verbose'], 619 scoring=options["scoring"],
581 return_estimator=(params['save'] == 'save_estimator'), 620 cv=outer_cv,
582 error_score=options['error_score'], 621 n_jobs=N_JOBS,
583 return_train_score=True) 622 verbose=options["verbose"],
623 return_estimator=(params["save"] == "save_estimator"),
624 error_score=options["error_score"],
625 return_train_score=True,
626 )
584 except ValueError: 627 except ValueError:
585 pass 628 pass
586 for warning in w: 629 for warning in w:
587 print(repr(warning.message)) 630 print(repr(warning.message))
588 631
589 fitted_searchers = rval.pop('estimator', []) 632 fitted_searchers = rval.pop("estimator", [])
590 if fitted_searchers: 633 if fitted_searchers:
591 import os 634 import os
635
592 pwd = os.getcwd() 636 pwd = os.getcwd()
593 save_dir = os.path.join(pwd, 'cv_results_in_folds') 637 save_dir = os.path.join(pwd, "cv_results_in_folds")
594 try: 638 try:
595 os.mkdir(save_dir) 639 os.mkdir(save_dir)
596 for idx, obj in enumerate(fitted_searchers): 640 for idx, obj in enumerate(fitted_searchers):
597 target_name = 'cv_results_' + '_' + 'split%d' % idx 641 target_name = "cv_results_" + "_" + "split%d" % idx
598 target_path = os.path.join(pwd, save_dir, target_name) 642 target_path = os.path.join(pwd, save_dir, target_name)
599 cv_results_ = getattr(obj, 'cv_results_', None) 643 cv_results_ = getattr(obj, "cv_results_", None)
600 if not cv_results_: 644 if not cv_results_:
601 print("%s is not available" % target_name) 645 print("%s is not available" % target_name)
602 continue 646 continue
603 cv_results_ = pd.DataFrame(cv_results_) 647 cv_results_ = pd.DataFrame(cv_results_)
604 cv_results_ = cv_results_[sorted(cv_results_.columns)] 648 cv_results_ = cv_results_[sorted(cv_results_.columns)]
605 cv_results_.to_csv(target_path, sep='\t', header=True, 649 cv_results_.to_csv(target_path, sep="\t", header=True, index=False)
606 index=False)
607 except Exception as e: 650 except Exception as e:
608 print(e) 651 print(e)
609 finally: 652 finally:
610 del os 653 del os
611 654
612 keys = list(rval.keys()) 655 keys = list(rval.keys())
613 for k in keys: 656 for k in keys:
614 if k.startswith('test'): 657 if k.startswith("test"):
615 rval['mean_' + k] = np.mean(rval[k]) 658 rval["mean_" + k] = np.mean(rval[k])
616 rval['std_' + k] = np.std(rval[k]) 659 rval["std_" + k] = np.std(rval[k])
617 if k.endswith('time'): 660 if k.endswith("time"):
618 rval.pop(k) 661 rval.pop(k)
619 rval = pd.DataFrame(rval) 662 rval = pd.DataFrame(rval)
620 rval = rval[sorted(rval.columns)] 663 rval = rval[sorted(rval.columns)]
621 rval.to_csv(path_or_buf=outfile_result, sep='\t', header=True, 664 rval.to_csv(path_or_buf=outfile_result, sep="\t", header=True, index=False)
622 index=False)
623
624 return 0
625
626 # deprecate train test split mode 665 # deprecate train test split mode
627 """searcher = _do_train_test_split_val( 666 """searcher = _do_train_test_split_val(
628 searcher, X, y, params, 667 searcher, X, y, params,
629 primary_scoring=primary_scoring, 668 primary_scoring=primary_scoring,
630 error_score=options['error_score'], 669 error_score=options['error_score'],
631 groups=groups, 670 groups=groups,
632 outfile=outfile_result)""" 671 outfile=outfile_result)"""
672 return 0
633 673
634 # no outer split 674 # no outer split
635 else: 675 else:
636 searcher.set_params(n_jobs=N_JOBS) 676 searcher.set_params(n_jobs=N_JOBS)
637 if options['error_score'] == 'raise': 677 if options["error_score"] == "raise":
638 searcher.fit(X, y, groups=groups) 678 searcher.fit(X, y, groups=groups)
639 else: 679 else:
640 warnings.simplefilter('always', FitFailedWarning) 680 warnings.simplefilter("always", FitFailedWarning)
641 with warnings.catch_warnings(record=True) as w: 681 with warnings.catch_warnings(record=True) as w:
642 try: 682 try:
643 searcher.fit(X, y, groups=groups) 683 searcher.fit(X, y, groups=groups)
644 except ValueError: 684 except ValueError:
645 pass 685 pass
646 for warning in w: 686 for warning in w:
647 print(repr(warning.message)) 687 print(repr(warning.message))
648 688
649 cv_results = pd.DataFrame(searcher.cv_results_) 689 cv_results = pd.DataFrame(searcher.cv_results_)
650 cv_results = cv_results[sorted(cv_results.columns)] 690 cv_results = cv_results[sorted(cv_results.columns)]
651 cv_results.to_csv(path_or_buf=outfile_result, sep='\t', 691 cv_results.to_csv(path_or_buf=outfile_result, sep="\t", header=True, index=False)
652 header=True, index=False)
653 692
654 memory.clear(warn=False) 693 memory.clear(warn=False)
655 694
656 # output best estimator, and weights if applicable 695 # output best estimator, and weights if applicable
657 if outfile_object: 696 if outfile_object:
658 best_estimator_ = getattr(searcher, 'best_estimator_', None) 697 best_estimator_ = getattr(searcher, "best_estimator_", None)
659 if not best_estimator_: 698 if not best_estimator_:
660 warnings.warn("GridSearchCV object has no attribute " 699 warnings.warn(
661 "'best_estimator_', because either it's " 700 "GridSearchCV object has no attribute "
662 "nested gridsearch or `refit` is False!") 701 "'best_estimator_', because either it's "
702 "nested gridsearch or `refit` is False!"
703 )
663 return 704 return
664 705
665 # clean prams 706 # clean prams
666 best_estimator_ = clean_params(best_estimator_) 707 best_estimator_ = clean_params(best_estimator_)
667 708
668 main_est = get_main_estimator(best_estimator_) 709 main_est = get_main_estimator(best_estimator_)
669 710
670 if hasattr(main_est, 'model_') \ 711 if hasattr(main_est, "model_") and hasattr(main_est, "save_weights"):
671 and hasattr(main_est, 'save_weights'):
672 if outfile_weights: 712 if outfile_weights:
673 main_est.save_weights(outfile_weights) 713 main_est.save_weights(outfile_weights)
674 del main_est.model_ 714 del main_est.model_
675 del main_est.fit_params 715 del main_est.fit_params
676 del main_est.model_class_ 716 del main_est.model_class_
677 del main_est.validation_data 717 del main_est.validation_data
678 if getattr(main_est, 'data_generator_', None): 718 if getattr(main_est, "data_generator_", None):
679 del main_est.data_generator_ 719 del main_est.data_generator_
680 720
681 with open(outfile_object, 'wb') as output_handler: 721 with open(outfile_object, "wb") as output_handler:
682 print("Best estimator is saved: %s " % repr(best_estimator_)) 722 print("Best estimator is saved: %s " % repr(best_estimator_))
683 pickle.dump(best_estimator_, output_handler, 723 pickle.dump(best_estimator_, output_handler, pickle.HIGHEST_PROTOCOL)
684 pickle.HIGHEST_PROTOCOL) 724
685 725
686 726 if __name__ == "__main__":
687 if __name__ == '__main__':
688 aparser = argparse.ArgumentParser() 727 aparser = argparse.ArgumentParser()
689 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) 728 aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
690 aparser.add_argument("-e", "--estimator", dest="infile_estimator") 729 aparser.add_argument("-e", "--estimator", dest="infile_estimator")
691 aparser.add_argument("-X", "--infile1", dest="infile1") 730 aparser.add_argument("-X", "--infile1", dest="infile1")
692 aparser.add_argument("-y", "--infile2", dest="infile2") 731 aparser.add_argument("-y", "--infile2", dest="infile2")
698 aparser.add_argument("-b", "--intervals", dest="intervals") 737 aparser.add_argument("-b", "--intervals", dest="intervals")
699 aparser.add_argument("-t", "--targets", dest="targets") 738 aparser.add_argument("-t", "--targets", dest="targets")
700 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") 739 aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
701 args = aparser.parse_args() 740 args = aparser.parse_args()
702 741
703 main(args.inputs, args.infile_estimator, args.infile1, args.infile2, 742 main(
704 args.outfile_result, outfile_object=args.outfile_object, 743 args.inputs,
705 outfile_weights=args.outfile_weights, groups=args.groups, 744 args.infile_estimator,
706 ref_seq=args.ref_seq, intervals=args.intervals, 745 args.infile1,
707 targets=args.targets, fasta_path=args.fasta_path) 746 args.infile2,
747 args.outfile_result,
748 outfile_object=args.outfile_object,
749 outfile_weights=args.outfile_weights,
750 groups=args.groups,
751 ref_seq=args.ref_seq,
752 intervals=args.intervals,
753 targets=args.targets,
754 fasta_path=args.fasta_path,
755 )