comparison train_test_eval.py @ 6:13b9ac5d277c draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 208a8d348e7c7a182cfbe1b6f17868146428a7e2"
author bgruening
date Tue, 13 Apr 2021 22:24:07 +0000
parents 5a092779412e
children 3312fb686ffb
comparison
equal deleted inserted replaced
5:ce2fd1edbc6e 6:13b9ac5d277c
1 import argparse 1 import argparse
2 import joblib
3 import json 2 import json
4 import numpy as np
5 import os 3 import os
6 import pandas as pd
7 import pickle 4 import pickle
8 import warnings 5 import warnings
9 from itertools import chain 6 from itertools import chain
7
8 import joblib
9 import numpy as np
10 import pandas as pd
11 from galaxy_ml.model_validations import train_test_split
12 from galaxy_ml.utils import (
13 get_module,
14 get_scoring,
15 load_model,
16 read_columns,
17 SafeEval,
18 try_get_attr,
19 )
10 from scipy.io import mmread 20 from scipy.io import mmread
11 from sklearn.base import clone 21 from sklearn import pipeline
12 from sklearn import (cluster, compose, decomposition, ensemble,
13 feature_extraction, feature_selection,
14 gaussian_process, kernel_approximation, metrics,
15 model_selection, naive_bayes, neighbors,
16 pipeline, preprocessing, svm, linear_model,
17 tree, discriminant_analysis)
18 from sklearn.exceptions import FitFailedWarning
19 from sklearn.metrics.scorer import _check_multimetric_scoring 22 from sklearn.metrics.scorer import _check_multimetric_scoring
20 from sklearn.model_selection._validation import _score, cross_validate
21 from sklearn.model_selection import _search, _validation 23 from sklearn.model_selection import _search, _validation
24 from sklearn.model_selection._validation import _score
22 from sklearn.utils import indexable, safe_indexing 25 from sklearn.utils import indexable, safe_indexing
23 26
24 from galaxy_ml.model_validations import train_test_split 27
25 from galaxy_ml.utils import (SafeEval, get_scoring, load_model, 28 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score")
26 read_columns, try_get_attr, get_module) 29 setattr(_search, "_fit_and_score", _fit_and_score)
27 30 setattr(_validation, "_fit_and_score", _fit_and_score)
28 31
29 _fit_and_score = try_get_attr('galaxy_ml.model_validations', '_fit_and_score') 32 N_JOBS = int(os.environ.get("GALAXY_SLOTS", 1))
30 setattr(_search, '_fit_and_score', _fit_and_score) 33 CACHE_DIR = os.path.join(os.getcwd(), "cached")
31 setattr(_validation, '_fit_and_score', _fit_and_score)
32
33 N_JOBS = int(os.environ.get('GALAXY_SLOTS', 1))
34 CACHE_DIR = os.path.join(os.getcwd(), 'cached')
35 del os 34 del os
36 NON_SEARCHABLE = ('n_jobs', 'pre_dispatch', 'memory', '_path', 35 NON_SEARCHABLE = ("n_jobs", "pre_dispatch", "memory", "_path", "nthread", "callbacks")
37 'nthread', 'callbacks') 36 ALLOWED_CALLBACKS = (
38 ALLOWED_CALLBACKS = ('EarlyStopping', 'TerminateOnNaN', 'ReduceLROnPlateau', 37 "EarlyStopping",
39 'CSVLogger', 'None') 38 "TerminateOnNaN",
39 "ReduceLROnPlateau",
40 "CSVLogger",
41 "None",
42 )
40 43
41 44
42 def _eval_swap_params(params_builder): 45 def _eval_swap_params(params_builder):
43 swap_params = {} 46 swap_params = {}
44 47
45 for p in params_builder['param_set']: 48 for p in params_builder["param_set"]:
46 swap_value = p['sp_value'].strip() 49 swap_value = p["sp_value"].strip()
47 if swap_value == '': 50 if swap_value == "":
48 continue 51 continue
49 52
50 param_name = p['sp_name'] 53 param_name = p["sp_name"]
51 if param_name.lower().endswith(NON_SEARCHABLE): 54 if param_name.lower().endswith(NON_SEARCHABLE):
52 warnings.warn("Warning: `%s` is not eligible for search and was " 55 warnings.warn(
53 "omitted!" % param_name) 56 "Warning: `%s` is not eligible for search and was "
57 "omitted!" % param_name
58 )
54 continue 59 continue
55 60
56 if not swap_value.startswith(':'): 61 if not swap_value.startswith(":"):
57 safe_eval = SafeEval(load_scipy=True, load_numpy=True) 62 safe_eval = SafeEval(load_scipy=True, load_numpy=True)
58 ev = safe_eval(swap_value) 63 ev = safe_eval(swap_value)
59 else: 64 else:
60 # Have `:` before search list, asks for estimator evaluatio 65 # Have `:` before search list, asks for estimator evaluatio
61 safe_eval_es = SafeEval(load_estimators=True) 66 safe_eval_es = SafeEval(load_estimators=True)
78 if arr is None: 83 if arr is None:
79 nones.append(idx) 84 nones.append(idx)
80 else: 85 else:
81 new_arrays.append(arr) 86 new_arrays.append(arr)
82 87
83 if kwargs['shuffle'] == 'None': 88 if kwargs["shuffle"] == "None":
84 kwargs['shuffle'] = None 89 kwargs["shuffle"] = None
85 90
86 group_names = kwargs.pop('group_names', None) 91 group_names = kwargs.pop("group_names", None)
87 92
88 if group_names is not None and group_names.strip(): 93 if group_names is not None and group_names.strip():
89 group_names = [name.strip() for name in 94 group_names = [name.strip() for name in group_names.split(",")]
90 group_names.split(',')]
91 new_arrays = indexable(*new_arrays) 95 new_arrays = indexable(*new_arrays)
92 groups = kwargs['labels'] 96 groups = kwargs["labels"]
93 n_samples = new_arrays[0].shape[0] 97 n_samples = new_arrays[0].shape[0]
94 index_arr = np.arange(n_samples) 98 index_arr = np.arange(n_samples)
95 test = index_arr[np.isin(groups, group_names)] 99 test = index_arr[np.isin(groups, group_names)]
96 train = index_arr[~np.isin(groups, group_names)] 100 train = index_arr[~np.isin(groups, group_names)]
97 rval = list(chain.from_iterable( 101 rval = list(
98 (safe_indexing(a, train), 102 chain.from_iterable(
99 safe_indexing(a, test)) for a in new_arrays)) 103 (safe_indexing(a, train), safe_indexing(a, test)) for a in new_arrays
104 )
105 )
100 else: 106 else:
101 rval = train_test_split(*new_arrays, **kwargs) 107 rval = train_test_split(*new_arrays, **kwargs)
102 108
103 for pos in nones: 109 for pos in nones:
104 rval[pos * 2: 2] = [None, None] 110 rval[pos * 2: 2] = [None, None]
105 111
106 return rval 112 return rval
107 113
108 114
109 def main(inputs, infile_estimator, infile1, infile2, 115 def main(
110 outfile_result, outfile_object=None, 116 inputs,
111 outfile_weights=None, groups=None, 117 infile_estimator,
112 ref_seq=None, intervals=None, targets=None, 118 infile1,
113 fasta_path=None): 119 infile2,
120 outfile_result,
121 outfile_object=None,
122 outfile_weights=None,
123 groups=None,
124 ref_seq=None,
125 intervals=None,
126 targets=None,
127 fasta_path=None,
128 ):
114 """ 129 """
115 Parameter 130 Parameter
116 --------- 131 ---------
117 inputs : str 132 inputs : str
118 File path to galaxy tool parameter 133 File path to galaxy tool parameter
148 File path to dataset compressed target bed file 163 File path to dataset compressed target bed file
149 164
150 fasta_path : str 165 fasta_path : str
151 File path to dataset containing fasta file 166 File path to dataset containing fasta file
152 """ 167 """
153 warnings.simplefilter('ignore') 168 warnings.simplefilter("ignore")
154 169
155 with open(inputs, 'r') as param_handler: 170 with open(inputs, "r") as param_handler:
156 params = json.load(param_handler) 171 params = json.load(param_handler)
157 172
158 # load estimator 173 # load estimator
159 with open(infile_estimator, 'rb') as estimator_handler: 174 with open(infile_estimator, "rb") as estimator_handler:
160 estimator = load_model(estimator_handler) 175 estimator = load_model(estimator_handler)
161 176
162 # swap hyperparameter 177 # swap hyperparameter
163 swapping = params['experiment_schemes']['hyperparams_swapping'] 178 swapping = params["experiment_schemes"]["hyperparams_swapping"]
164 swap_params = _eval_swap_params(swapping) 179 swap_params = _eval_swap_params(swapping)
165 estimator.set_params(**swap_params) 180 estimator.set_params(**swap_params)
166 181
167 estimator_params = estimator.get_params() 182 estimator_params = estimator.get_params()
168 183
169 # store read dataframe object 184 # store read dataframe object
170 loaded_df = {} 185 loaded_df = {}
171 186
172 input_type = params['input_options']['selected_input'] 187 input_type = params["input_options"]["selected_input"]
173 # tabular input 188 # tabular input
174 if input_type == 'tabular': 189 if input_type == "tabular":
175 header = 'infer' if params['input_options']['header1'] else None 190 header = "infer" if params["input_options"]["header1"] else None
176 column_option = (params['input_options']['column_selector_options_1'] 191 column_option = params["input_options"]["column_selector_options_1"][
177 ['selected_column_selector_option']) 192 "selected_column_selector_option"
178 if column_option in ['by_index_number', 'all_but_by_index_number', 193 ]
179 'by_header_name', 'all_but_by_header_name']: 194 if column_option in [
180 c = params['input_options']['column_selector_options_1']['col1'] 195 "by_index_number",
196 "all_but_by_index_number",
197 "by_header_name",
198 "all_but_by_header_name",
199 ]:
200 c = params["input_options"]["column_selector_options_1"]["col1"]
181 else: 201 else:
182 c = None 202 c = None
183 203
184 df_key = infile1 + repr(header) 204 df_key = infile1 + repr(header)
185 df = pd.read_csv(infile1, sep='\t', header=header, 205 df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True)
186 parse_dates=True)
187 loaded_df[df_key] = df 206 loaded_df[df_key] = df
188 207
189 X = read_columns(df, c=c, c_option=column_option).astype(float) 208 X = read_columns(df, c=c, c_option=column_option).astype(float)
190 # sparse input 209 # sparse input
191 elif input_type == 'sparse': 210 elif input_type == "sparse":
192 X = mmread(open(infile1, 'r')) 211 X = mmread(open(infile1, "r"))
193 212
194 # fasta_file input 213 # fasta_file input
195 elif input_type == 'seq_fasta': 214 elif input_type == "seq_fasta":
196 pyfaidx = get_module('pyfaidx') 215 pyfaidx = get_module("pyfaidx")
197 sequences = pyfaidx.Fasta(fasta_path) 216 sequences = pyfaidx.Fasta(fasta_path)
198 n_seqs = len(sequences.keys()) 217 n_seqs = len(sequences.keys())
199 X = np.arange(n_seqs)[:, np.newaxis] 218 X = np.arange(n_seqs)[:, np.newaxis]
200 for param in estimator_params.keys(): 219 for param in estimator_params.keys():
201 if param.endswith('fasta_path'): 220 if param.endswith("fasta_path"):
202 estimator.set_params( 221 estimator.set_params(**{param: fasta_path})
203 **{param: fasta_path})
204 break 222 break
205 else: 223 else:
206 raise ValueError( 224 raise ValueError(
207 "The selected estimator doesn't support " 225 "The selected estimator doesn't support "
208 "fasta file input! Please consider using " 226 "fasta file input! Please consider using "
209 "KerasGBatchClassifier with " 227 "KerasGBatchClassifier with "
210 "FastaDNABatchGenerator/FastaProteinBatchGenerator " 228 "FastaDNABatchGenerator/FastaProteinBatchGenerator "
211 "or having GenomeOneHotEncoder/ProteinOneHotEncoder " 229 "or having GenomeOneHotEncoder/ProteinOneHotEncoder "
212 "in pipeline!") 230 "in pipeline!"
213 231 )
214 elif input_type == 'refseq_and_interval': 232
233 elif input_type == "refseq_and_interval":
215 path_params = { 234 path_params = {
216 'data_batch_generator__ref_genome_path': ref_seq, 235 "data_batch_generator__ref_genome_path": ref_seq,
217 'data_batch_generator__intervals_path': intervals, 236 "data_batch_generator__intervals_path": intervals,
218 'data_batch_generator__target_path': targets 237 "data_batch_generator__target_path": targets,
219 } 238 }
220 estimator.set_params(**path_params) 239 estimator.set_params(**path_params)
221 n_intervals = sum(1 for line in open(intervals)) 240 n_intervals = sum(1 for line in open(intervals))
222 X = np.arange(n_intervals)[:, np.newaxis] 241 X = np.arange(n_intervals)[:, np.newaxis]
223 242
224 # Get target y 243 # Get target y
225 header = 'infer' if params['input_options']['header2'] else None 244 header = "infer" if params["input_options"]["header2"] else None
226 column_option = (params['input_options']['column_selector_options_2'] 245 column_option = params["input_options"]["column_selector_options_2"][
227 ['selected_column_selector_option2']) 246 "selected_column_selector_option2"
228 if column_option in ['by_index_number', 'all_but_by_index_number', 247 ]
229 'by_header_name', 'all_but_by_header_name']: 248 if column_option in [
230 c = params['input_options']['column_selector_options_2']['col2'] 249 "by_index_number",
250 "all_but_by_index_number",
251 "by_header_name",
252 "all_but_by_header_name",
253 ]:
254 c = params["input_options"]["column_selector_options_2"]["col2"]
231 else: 255 else:
232 c = None 256 c = None
233 257
234 df_key = infile2 + repr(header) 258 df_key = infile2 + repr(header)
235 if df_key in loaded_df: 259 if df_key in loaded_df:
236 infile2 = loaded_df[df_key] 260 infile2 = loaded_df[df_key]
237 else: 261 else:
238 infile2 = pd.read_csv(infile2, sep='\t', 262 infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True)
239 header=header, parse_dates=True)
240 loaded_df[df_key] = infile2 263 loaded_df[df_key] = infile2
241 264
242 y = read_columns( 265 y = read_columns(infile2,
243 infile2, 266 c=c,
244 c=c, 267 c_option=column_option,
245 c_option=column_option, 268 sep='\t',
246 sep='\t', 269 header=header,
247 header=header, 270 parse_dates=True)
248 parse_dates=True)
249 if len(y.shape) == 2 and y.shape[1] == 1: 271 if len(y.shape) == 2 and y.shape[1] == 1:
250 y = y.ravel() 272 y = y.ravel()
251 if input_type == 'refseq_and_interval': 273 if input_type == "refseq_and_interval":
252 estimator.set_params( 274 estimator.set_params(data_batch_generator__features=y.ravel().tolist())
253 data_batch_generator__features=y.ravel().tolist())
254 y = None 275 y = None
255 # end y 276 # end y
256 277
257 # load groups 278 # load groups
258 if groups: 279 if groups:
259 groups_selector = (params['experiment_schemes']['test_split'] 280 groups_selector = (
260 ['split_algos']).pop('groups_selector') 281 params["experiment_schemes"]["test_split"]["split_algos"]
261 282 ).pop("groups_selector")
262 header = 'infer' if groups_selector['header_g'] else None 283
263 column_option = \ 284 header = "infer" if groups_selector["header_g"] else None
264 (groups_selector['column_selector_options_g'] 285 column_option = groups_selector["column_selector_options_g"][
265 ['selected_column_selector_option_g']) 286 "selected_column_selector_option_g"
266 if column_option in ['by_index_number', 'all_but_by_index_number', 287 ]
267 'by_header_name', 'all_but_by_header_name']: 288 if column_option in [
268 c = groups_selector['column_selector_options_g']['col_g'] 289 "by_index_number",
290 "all_but_by_index_number",
291 "by_header_name",
292 "all_but_by_header_name",
293 ]:
294 c = groups_selector["column_selector_options_g"]["col_g"]
269 else: 295 else:
270 c = None 296 c = None
271 297
272 df_key = groups + repr(header) 298 df_key = groups + repr(header)
273 if df_key in loaded_df: 299 if df_key in loaded_df:
274 groups = loaded_df[df_key] 300 groups = loaded_df[df_key]
275 301
276 groups = read_columns( 302 groups = read_columns(groups,
277 groups, 303 c=c,
278 c=c, 304 c_option=column_option,
279 c_option=column_option, 305 sep='\t',
280 sep='\t', 306 header=header,
281 header=header, 307 parse_dates=True)
282 parse_dates=True)
283 groups = groups.ravel() 308 groups = groups.ravel()
284 309
285 # del loaded_df 310 # del loaded_df
286 del loaded_df 311 del loaded_df
287 312
288 # handle memory 313 # handle memory
289 memory = joblib.Memory(location=CACHE_DIR, verbose=0) 314 memory = joblib.Memory(location=CACHE_DIR, verbose=0)
290 # cache iraps_core fits could increase search speed significantly 315 # cache iraps_core fits could increase search speed significantly
291 if estimator.__class__.__name__ == 'IRAPSClassifier': 316 if estimator.__class__.__name__ == "IRAPSClassifier":
292 estimator.set_params(memory=memory) 317 estimator.set_params(memory=memory)
293 else: 318 else:
294 # For iraps buried in pipeline 319 # For iraps buried in pipeline
295 new_params = {} 320 new_params = {}
296 for p, v in estimator_params.items(): 321 for p, v in estimator_params.items():
297 if p.endswith('memory'): 322 if p.endswith("memory"):
298 # for case of `__irapsclassifier__memory` 323 # for case of `__irapsclassifier__memory`
299 if len(p) > 8 and p[:-8].endswith('irapsclassifier'): 324 if len(p) > 8 and p[:-8].endswith("irapsclassifier"):
300 # cache iraps_core fits could increase search 325 # cache iraps_core fits could increase search
301 # speed significantly 326 # speed significantly
302 new_params[p] = memory 327 new_params[p] = memory
303 # security reason, we don't want memory being 328 # security reason, we don't want memory being
304 # modified unexpectedly 329 # modified unexpectedly
305 elif v: 330 elif v:
306 new_params[p] = None 331 new_params[p] = None
307 # handle n_jobs 332 # handle n_jobs
308 elif p.endswith('n_jobs'): 333 elif p.endswith("n_jobs"):
309 # For now, 1 CPU is suggested for iprasclassifier 334 # For now, 1 CPU is suggested for iprasclassifier
310 if len(p) > 8 and p[:-8].endswith('irapsclassifier'): 335 if len(p) > 8 and p[:-8].endswith("irapsclassifier"):
311 new_params[p] = 1 336 new_params[p] = 1
312 else: 337 else:
313 new_params[p] = N_JOBS 338 new_params[p] = N_JOBS
314 # for security reason, types of callback are limited 339 # for security reason, types of callback are limited
315 elif p.endswith('callbacks'): 340 elif p.endswith("callbacks"):
316 for cb in v: 341 for cb in v:
317 cb_type = cb['callback_selection']['callback_type'] 342 cb_type = cb["callback_selection"]["callback_type"]
318 if cb_type not in ALLOWED_CALLBACKS: 343 if cb_type not in ALLOWED_CALLBACKS:
319 raise ValueError( 344 raise ValueError("Prohibited callback type: %s!" % cb_type)
320 "Prohibited callback type: %s!" % cb_type)
321 345
322 estimator.set_params(**new_params) 346 estimator.set_params(**new_params)
323 347
324 # handle scorer, convert to scorer dict 348 # handle scorer, convert to scorer dict
325 scoring = params['experiment_schemes']['metrics']['scoring'] 349 # Check if scoring is specified
350 scoring = params["experiment_schemes"]["metrics"].get("scoring", None)
351 if scoring is not None:
352 # get_scoring() expects secondary_scoring to be a comma separated string (not a list)
353 # Check if secondary_scoring is specified
354 secondary_scoring = scoring.get("secondary_scoring", None)
355 if secondary_scoring is not None:
356 # If secondary_scoring is specified, convert the list into comman separated string
357 scoring["secondary_scoring"] = ",".join(scoring["secondary_scoring"])
326 scorer = get_scoring(scoring) 358 scorer = get_scoring(scoring)
327 scorer, _ = _check_multimetric_scoring(estimator, scoring=scorer) 359 scorer, _ = _check_multimetric_scoring(estimator, scoring=scorer)
328 360
329 # handle test (first) split 361 # handle test (first) split
330 test_split_options = (params['experiment_schemes'] 362 test_split_options = params["experiment_schemes"]["test_split"]["split_algos"]
331 ['test_split']['split_algos']) 363
332 364 if test_split_options["shuffle"] == "group":
333 if test_split_options['shuffle'] == 'group': 365 test_split_options["labels"] = groups
334 test_split_options['labels'] = groups 366 if test_split_options["shuffle"] == "stratified":
335 if test_split_options['shuffle'] == 'stratified':
336 if y is not None: 367 if y is not None:
337 test_split_options['labels'] = y 368 test_split_options["labels"] = y
338 else: 369 else:
339 raise ValueError("Stratified shuffle split is not " 370 raise ValueError(
340 "applicable on empty target values!") 371 "Stratified shuffle split is not " "applicable on empty target values!"
341 372 )
342 X_train, X_test, y_train, y_test, groups_train, groups_test = \ 373
343 train_test_split_none(X, y, groups, **test_split_options) 374 X_train, X_test, y_train, y_test, groups_train, _groups_test = train_test_split_none(
344 375 X, y, groups, **test_split_options
345 exp_scheme = params['experiment_schemes']['selected_exp_scheme'] 376 )
377
378 exp_scheme = params["experiment_schemes"]["selected_exp_scheme"]
346 379
347 # handle validation (second) split 380 # handle validation (second) split
348 if exp_scheme == 'train_val_test': 381 if exp_scheme == "train_val_test":
349 val_split_options = (params['experiment_schemes'] 382 val_split_options = params["experiment_schemes"]["val_split"]["split_algos"]
350 ['val_split']['split_algos']) 383
351 384 if val_split_options["shuffle"] == "group":
352 if val_split_options['shuffle'] == 'group': 385 val_split_options["labels"] = groups_train
353 val_split_options['labels'] = groups_train 386 if val_split_options["shuffle"] == "stratified":
354 if val_split_options['shuffle'] == 'stratified':
355 if y_train is not None: 387 if y_train is not None:
356 val_split_options['labels'] = y_train 388 val_split_options["labels"] = y_train
357 else: 389 else:
358 raise ValueError("Stratified shuffle split is not " 390 raise ValueError(
359 "applicable on empty target values!") 391 "Stratified shuffle split is not "
360 392 "applicable on empty target values!"
361 X_train, X_val, y_train, y_val, groups_train, groups_val = \ 393 )
362 train_test_split_none(X_train, y_train, groups_train, 394
363 **val_split_options) 395 (
396 X_train,
397 X_val,
398 y_train,
399 y_val,
400 groups_train,
401 _groups_val,
402 ) = train_test_split_none(X_train, y_train, groups_train, **val_split_options)
364 403
365 # train and eval 404 # train and eval
366 if hasattr(estimator, 'validation_data'): 405 if hasattr(estimator, "validation_data"):
367 if exp_scheme == 'train_val_test': 406 if exp_scheme == "train_val_test":
368 estimator.fit(X_train, y_train, 407 estimator.fit(X_train, y_train, validation_data=(X_val, y_val))
369 validation_data=(X_val, y_val)) 408 else:
370 else: 409 estimator.fit(X_train, y_train, validation_data=(X_test, y_test))
371 estimator.fit(X_train, y_train,
372 validation_data=(X_test, y_test))
373 else: 410 else:
374 estimator.fit(X_train, y_train) 411 estimator.fit(X_train, y_train)
375 412
376 if hasattr(estimator, 'evaluate'): 413 if hasattr(estimator, "evaluate"):
377 scores = estimator.evaluate(X_test, y_test=y_test, 414 scores = estimator.evaluate(
378 scorer=scorer, 415 X_test, y_test=y_test, scorer=scorer, is_multimetric=True
379 is_multimetric=True) 416 )
380 else: 417 else:
381 scores = _score(estimator, X_test, y_test, scorer, 418 scores = _score(estimator, X_test, y_test, scorer, is_multimetric=True)
382 is_multimetric=True)
383 # handle output 419 # handle output
384 for name, score in scores.items(): 420 for name, score in scores.items():
385 scores[name] = [score] 421 scores[name] = [score]
386 df = pd.DataFrame(scores) 422 df = pd.DataFrame(scores)
387 df = df[sorted(df.columns)] 423 df = df[sorted(df.columns)]
388 df.to_csv(path_or_buf=outfile_result, sep='\t', 424 df.to_csv(path_or_buf=outfile_result, sep="\t", header=True, index=False)
389 header=True, index=False)
390 425
391 memory.clear(warn=False) 426 memory.clear(warn=False)
392 427
393 if outfile_object: 428 if outfile_object:
394 main_est = estimator 429 main_est = estimator
395 if isinstance(estimator, pipeline.Pipeline): 430 if isinstance(estimator, pipeline.Pipeline):
396 main_est = estimator.steps[-1][-1] 431 main_est = estimator.steps[-1][-1]
397 432
398 if hasattr(main_est, 'model_') \ 433 if hasattr(main_est, "model_") and hasattr(main_est, "save_weights"):
399 and hasattr(main_est, 'save_weights'):
400 if outfile_weights: 434 if outfile_weights:
401 main_est.save_weights(outfile_weights) 435 main_est.save_weights(outfile_weights)
402 del main_est.model_ 436 if getattr(main_est, "model_", None):
403 del main_est.fit_params 437 del main_est.model_
404 del main_est.model_class_ 438 if getattr(main_est, "fit_params", None):
405 del main_est.validation_data 439 del main_est.fit_params
406 if getattr(main_est, 'data_generator_', None): 440 if getattr(main_est, "model_class_", None):
441 del main_est.model_class_
442 if getattr(main_est, "validation_data", None):
443 del main_est.validation_data
444 if getattr(main_est, "data_generator_", None):
407 del main_est.data_generator_ 445 del main_est.data_generator_
408 446
409 with open(outfile_object, 'wb') as output_handler: 447 with open(outfile_object, "wb") as output_handler:
410 pickle.dump(estimator, output_handler, 448 pickle.dump(estimator, output_handler, pickle.HIGHEST_PROTOCOL)
411 pickle.HIGHEST_PROTOCOL) 449
412 450
413 451 if __name__ == "__main__":
414 if __name__ == '__main__':
415 aparser = argparse.ArgumentParser() 452 aparser = argparse.ArgumentParser()
416 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) 453 aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
417 aparser.add_argument("-e", "--estimator", dest="infile_estimator") 454 aparser.add_argument("-e", "--estimator", dest="infile_estimator")
418 aparser.add_argument("-X", "--infile1", dest="infile1") 455 aparser.add_argument("-X", "--infile1", dest="infile1")
419 aparser.add_argument("-y", "--infile2", dest="infile2") 456 aparser.add_argument("-y", "--infile2", dest="infile2")
425 aparser.add_argument("-b", "--intervals", dest="intervals") 462 aparser.add_argument("-b", "--intervals", dest="intervals")
426 aparser.add_argument("-t", "--targets", dest="targets") 463 aparser.add_argument("-t", "--targets", dest="targets")
427 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") 464 aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
428 args = aparser.parse_args() 465 args = aparser.parse_args()
429 466
430 main(args.inputs, args.infile_estimator, args.infile1, args.infile2, 467 main(
431 args.outfile_result, outfile_object=args.outfile_object, 468 args.inputs,
432 outfile_weights=args.outfile_weights, groups=args.groups, 469 args.infile_estimator,
433 ref_seq=args.ref_seq, intervals=args.intervals, 470 args.infile1,
434 targets=args.targets, fasta_path=args.fasta_path) 471 args.infile2,
472 args.outfile_result,
473 outfile_object=args.outfile_object,
474 outfile_weights=args.outfile_weights,
475 groups=args.groups,
476 ref_seq=args.ref_seq,
477 intervals=args.intervals,
478 targets=args.targets,
479 fasta_path=args.fasta_path,
480 )