comparison keras_train_and_eval.py @ 40:06d772036a62 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author bgruening
date Wed, 09 Aug 2023 13:11:48 +0000
parents 73e7f1c76ece
children bb9fc9d46ea4
comparison
equal deleted inserted replaced
39:7dd3fb35904f 40:06d772036a62
1 import argparse 1 import argparse
2 import json 2 import json
3 import os 3 import os
4 import pickle
5 import warnings 4 import warnings
6 from itertools import chain 5 from itertools import chain
7 6
8 import joblib 7 import joblib
9 import numpy as np 8 import numpy as np
10 import pandas as pd 9 import pandas as pd
11 from galaxy_ml.externals.selene_sdk.utils import compute_score 10 from galaxy_ml.keras_galaxy_models import (
12 from galaxy_ml.keras_galaxy_models import _predict_generator 11 _predict_generator,
12 KerasGBatchClassifier,
13 )
14 from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5
13 from galaxy_ml.model_validations import train_test_split 15 from galaxy_ml.model_validations import train_test_split
14 from galaxy_ml.utils import (clean_params, get_main_estimator, 16 from galaxy_ml.utils import (
15 get_module, get_scoring, load_model, read_columns, 17 clean_params,
16 SafeEval, try_get_attr) 18 gen_compute_scores,
19 get_main_estimator,
20 get_module,
21 get_scoring,
22 read_columns,
23 SafeEval
24 )
17 from scipy.io import mmread 25 from scipy.io import mmread
18 from sklearn.metrics.scorer import _check_multimetric_scoring 26 from sklearn.metrics._scorer import _check_multimetric_scoring
19 from sklearn.model_selection import _search, _validation
20 from sklearn.model_selection._validation import _score 27 from sklearn.model_selection._validation import _score
21 from sklearn.pipeline import Pipeline 28 from sklearn.utils import _safe_indexing, indexable
22 from sklearn.utils import indexable, safe_indexing
23
24 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score")
25 setattr(_search, "_fit_and_score", _fit_and_score)
26 setattr(_validation, "_fit_and_score", _fit_and_score)
27 29
28 N_JOBS = int(os.environ.get("GALAXY_SLOTS", 1)) 30 N_JOBS = int(os.environ.get("GALAXY_SLOTS", 1))
29 CACHE_DIR = os.path.join(os.getcwd(), "cached") 31 CACHE_DIR = os.path.join(os.getcwd(), "cached")
30 del os 32 NON_SEARCHABLE = (
31 NON_SEARCHABLE = ("n_jobs", "pre_dispatch", "memory", "_path", "nthread", "callbacks") 33 "n_jobs",
34 "pre_dispatch",
35 "memory",
36 "_path",
37 "_dir",
38 "nthread",
39 "callbacks",
40 )
32 ALLOWED_CALLBACKS = ( 41 ALLOWED_CALLBACKS = (
33 "EarlyStopping", 42 "EarlyStopping",
34 "TerminateOnNaN", 43 "TerminateOnNaN",
35 "ReduceLROnPlateau", 44 "ReduceLROnPlateau",
36 "CSVLogger", 45 "CSVLogger",
94 index_arr = np.arange(n_samples) 103 index_arr = np.arange(n_samples)
95 test = index_arr[np.isin(groups, group_names)] 104 test = index_arr[np.isin(groups, group_names)]
96 train = index_arr[~np.isin(groups, group_names)] 105 train = index_arr[~np.isin(groups, group_names)]
97 rval = list( 106 rval = list(
98 chain.from_iterable( 107 chain.from_iterable(
99 (safe_indexing(a, train), safe_indexing(a, test)) for a in new_arrays 108 (_safe_indexing(a, train), _safe_indexing(a, test)) for a in new_arrays
100 ) 109 )
101 ) 110 )
102 else: 111 else:
103 rval = train_test_split(*new_arrays, **kwargs) 112 rval = train_test_split(*new_arrays, **kwargs)
104 113
106 rval[pos * 2: 2] = [None, None] 115 rval[pos * 2: 2] = [None, None]
107 116
108 return rval 117 return rval
109 118
110 119
111 def _evaluate(y_true, pred_probas, scorer, is_multimetric=True): 120 def _evaluate_keras_and_sklearn_scores(
112 """output scores based on input scorer 121 estimator,
122 data_generator,
123 X,
124 y=None,
125 sk_scoring=None,
126 steps=None,
127 batch_size=32,
128 return_predictions=False,
129 ):
130 """output scores for bother keras and sklearn metrics
113 131
114 Parameters 132 Parameters
115 ---------- 133 -----------
116 y_true : array 134 estimator : object
117 True label or target values 135 Fitted `galaxy_ml.keras_galaxy_models.KerasGBatchClassifier`.
118 pred_probas : array 136 data_generator : object
119 Prediction values, probability for classification problem 137 From `galaxy_ml.preprocessors.ImageDataFrameBatchGenerator`.
120 scorer : dict 138 X : 2-D array
121 dict of `sklearn.metrics.scorer.SCORER` 139 Contains indecies of images that need to be evaluated.
122 is_multimetric : bool, default is True 140 y : None
141 Target value.
142 sk_scoring : dict
143 Galaxy tool input parameters.
144 steps : integer or None
145 Evaluation/prediction steps before stop.
146 batch_size : integer
147 Number of samples in a batch
148 return_predictions : bool, default is False
149 Whether to return predictions and true labels.
123 """ 150 """
124 if y_true.ndim == 1 or y_true.shape[-1] == 1: 151 scores = {}
125 pred_probas = pred_probas.ravel() 152
126 pred_labels = (pred_probas > 0.5).astype("int32") 153 generator = data_generator.flow(X, y=y, batch_size=batch_size)
127 targets = y_true.ravel().astype("int32") 154 # keras metrics evaluation
128 if not is_multimetric: 155 # handle scorer, convert to scorer dict
129 preds = ( 156 generator.reset()
130 pred_labels 157 score_results = estimator.model_.evaluate_generator(generator, steps=steps)
131 if scorer.__class__.__name__ == "_PredictScorer" 158 metrics_names = estimator.model_.metrics_names
132 else pred_probas 159 if not isinstance(metrics_names, list):
133 ) 160 scores[metrics_names] = score_results
134 score = scorer._score_func(targets, preds, **scorer._kwargs) 161 else:
135 162 scores = dict(zip(metrics_names, score_results))
136 return score 163
137 else: 164 if sk_scoring["primary_scoring"] == "default" and not return_predictions:
138 scores = {} 165 return scores
139 for name, one_scorer in scorer.items(): 166
140 preds = ( 167 generator.reset()
141 pred_labels 168 predictions, y_true = _predict_generator(estimator.model_, generator, steps=steps)
142 if one_scorer.__class__.__name__ == "_PredictScorer" 169
143 else pred_probas 170 # for sklearn metrics
144 ) 171 if sk_scoring["primary_scoring"] != "default":
145 score = one_scorer._score_func(targets, preds, **one_scorer._kwargs) 172 scorer = get_scoring(sk_scoring)
146 scores[name] = score 173 if not isinstance(scorer, (dict, list)):
147 174 scorer = [sk_scoring["primary_scoring"]]
148 # TODO: multi-class metrics 175 scorer = _check_multimetric_scoring(estimator, scoring=scorer)
149 # multi-label 176 sk_scores = gen_compute_scores(y_true, predictions, scorer)
150 else: 177 scores.update(sk_scores)
151 pred_labels = (pred_probas > 0.5).astype("int32") 178
152 targets = y_true.astype("int32") 179 if return_predictions:
153 if not is_multimetric: 180 return scores, predictions, y_true
154 preds = ( 181 else:
155 pred_labels 182 return scores, None, None
156 if scorer.__class__.__name__ == "_PredictScorer"
157 else pred_probas
158 )
159 score, _ = compute_score(preds, targets, scorer._score_func)
160 return score
161 else:
162 scores = {}
163 for name, one_scorer in scorer.items():
164 preds = (
165 pred_labels
166 if one_scorer.__class__.__name__ == "_PredictScorer"
167 else pred_probas
168 )
169 score, _ = compute_score(preds, targets, one_scorer._score_func)
170 scores[name] = score
171
172 return scores
173 183
174 184
175 def main( 185 def main(
176 inputs, 186 inputs,
177 infile_estimator, 187 infile_estimator,
178 infile1, 188 infile1,
179 infile2, 189 infile2,
180 outfile_result, 190 outfile_result,
181 outfile_object=None, 191 outfile_object=None,
182 outfile_weights=None,
183 outfile_y_true=None, 192 outfile_y_true=None,
184 outfile_y_preds=None, 193 outfile_y_preds=None,
185 groups=None, 194 groups=None,
186 ref_seq=None, 195 ref_seq=None,
187 intervals=None, 196 intervals=None,
190 ): 199 ):
191 """ 200 """
192 Parameter 201 Parameter
193 --------- 202 ---------
194 inputs : str 203 inputs : str
195 File path to galaxy tool parameter 204 File path to galaxy tool parameter.
196 205
197 infile_estimator : str 206 infile_estimator : str
198 File path to estimator 207 File path to estimator.
199 208
200 infile1 : str 209 infile1 : str
201 File path to dataset containing features 210 File path to dataset containing features.
202 211
203 infile2 : str 212 infile2 : str
204 File path to dataset containing target values 213 File path to dataset containing target values.
205 214
206 outfile_result : str 215 outfile_result : str
207 File path to save the results, either cv_results or test result 216 File path to save the results, either cv_results or test result.
208 217
209 outfile_object : str, optional 218 outfile_object : str, optional
210 File path to save searchCV object 219 File path to save searchCV object.
211
212 outfile_weights : str, optional
213 File path to save deep learning model weights
214 220
215 outfile_y_true : str, optional 221 outfile_y_true : str, optional
216 File path to target values for prediction 222 File path to target values for prediction.
217 223
218 outfile_y_preds : str, optional 224 outfile_y_preds : str, optional
219 File path to save deep learning model weights 225 File path to save predictions.
220 226
221 groups : str 227 groups : str
222 File path to dataset containing groups labels 228 File path to dataset containing groups labels.
223 229
224 ref_seq : str 230 ref_seq : str
225 File path to dataset containing genome sequence file 231 File path to dataset containing genome sequence file.
226 232
227 intervals : str 233 intervals : str
228 File path to dataset containing interval file 234 File path to dataset containing interval file.
229 235
230 targets : str 236 targets : str
231 File path to dataset compressed target bed file 237 File path to dataset compressed target bed file.
232 238
233 fasta_path : str 239 fasta_path : str
234 File path to dataset containing fasta file 240 File path to dataset containing fasta file.
235 """ 241 """
236 warnings.simplefilter("ignore") 242 warnings.simplefilter("ignore")
237 243
238 with open(inputs, "r") as param_handler: 244 with open(inputs, "r") as param_handler:
239 params = json.load(param_handler) 245 params = json.load(param_handler)
240 246
241 # load estimator 247 # load estimator
242 with open(infile_estimator, "rb") as estimator_handler: 248 estimator = load_model_from_h5(infile_estimator)
243 estimator = load_model(estimator_handler)
244 249
245 estimator = clean_params(estimator) 250 estimator = clean_params(estimator)
246 251
247 # swap hyperparameter 252 # swap hyperparameter
248 swapping = params["experiment_schemes"]["hyperparams_swapping"] 253 swapping = params["experiment_schemes"]["hyperparams_swapping"]
331 else: 336 else:
332 infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True) 337 infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True)
333 loaded_df[df_key] = infile2 338 loaded_df[df_key] = infile2
334 339
335 y = read_columns( 340 y = read_columns(
336 infile2, c=c, c_option=column_option, sep="\t", header=header, parse_dates=True 341 infile2,
342 c=c,
343 c_option=column_option,
344 sep="\t",
345 header=header,
346 parse_dates=True,
337 ) 347 )
338 if len(y.shape) == 2 and y.shape[1] == 1: 348 if len(y.shape) == 2 and y.shape[1] == 1:
339 y = y.ravel() 349 y = y.ravel()
340 if input_type == "refseq_and_interval": 350 if input_type == "refseq_and_interval":
341 estimator.set_params(data_batch_generator__features=y.ravel().tolist()) 351 estimator.set_params(data_batch_generator__features=y.ravel().tolist())
385 if main_est.__class__.__name__ == "IRAPSClassifier": 395 if main_est.__class__.__name__ == "IRAPSClassifier":
386 main_est.set_params(memory=memory) 396 main_est.set_params(memory=memory)
387 397
388 # handle scorer, convert to scorer dict 398 # handle scorer, convert to scorer dict
389 scoring = params["experiment_schemes"]["metrics"]["scoring"] 399 scoring = params["experiment_schemes"]["metrics"]["scoring"]
390 if scoring is not None:
391 # get_scoring() expects secondary_scoring to be a comma separated string (not a list)
392 # Check if secondary_scoring is specified
393 secondary_scoring = scoring.get("secondary_scoring", None)
394 if secondary_scoring is not None:
395 # If secondary_scoring is specified, convert the list into comman separated string
396 scoring["secondary_scoring"] = ",".join(scoring["secondary_scoring"])
397
398 scorer = get_scoring(scoring) 400 scorer = get_scoring(scoring)
399 scorer, _ = _check_multimetric_scoring(estimator, scoring=scorer) 401 if not isinstance(scorer, (dict, list)):
402 scorer = [scoring["primary_scoring"]]
403 scorer = _check_multimetric_scoring(estimator, scoring=scorer)
400 404
401 # handle test (first) split 405 # handle test (first) split
402 test_split_options = params["experiment_schemes"]["test_split"]["split_algos"] 406 test_split_options = params["experiment_schemes"]["test_split"]["split_algos"]
403 407
404 if test_split_options["shuffle"] == "group": 408 if test_split_options["shuffle"] == "group":
409 else: 413 else:
410 raise ValueError( 414 raise ValueError(
411 "Stratified shuffle split is not " "applicable on empty target values!" 415 "Stratified shuffle split is not " "applicable on empty target values!"
412 ) 416 )
413 417
414 ( 418 X_train, X_test, y_train, y_test, groups_train, groups_test = train_test_split_none(
415 X_train, 419 X, y, groups, **test_split_options
416 X_test, 420 )
417 y_train,
418 y_test,
419 groups_train,
420 _groups_test,
421 ) = train_test_split_none(X, y, groups, **test_split_options)
422 421
423 exp_scheme = params["experiment_schemes"]["selected_exp_scheme"] 422 exp_scheme = params["experiment_schemes"]["selected_exp_scheme"]
424 423
425 # handle validation (second) split 424 # handle validation (second) split
426 if exp_scheme == "train_val_test": 425 if exp_scheme == "train_val_test":
441 X_train, 440 X_train,
442 X_val, 441 X_val,
443 y_train, 442 y_train,
444 y_val, 443 y_val,
445 groups_train, 444 groups_train,
446 _groups_val, 445 groups_val,
447 ) = train_test_split_none(X_train, y_train, groups_train, **val_split_options) 446 ) = train_test_split_none(X_train, y_train, groups_train, **val_split_options)
448 447
449 # train and eval 448 # train and eval
450 if hasattr(estimator, "validation_data"): 449 if hasattr(estimator, "config") and hasattr(estimator, "model_type"):
451 if exp_scheme == "train_val_test": 450 if exp_scheme == "train_val_test":
452 estimator.fit(X_train, y_train, validation_data=(X_val, y_val)) 451 estimator.fit(X_train, y_train, validation_data=(X_val, y_val))
453 else: 452 else:
454 estimator.fit(X_train, y_train, validation_data=(X_test, y_test)) 453 estimator.fit(X_train, y_train, validation_data=(X_test, y_test))
455 else: 454 else:
456 estimator.fit(X_train, y_train) 455 estimator.fit(X_train, y_train)
457 456
458 if hasattr(estimator, "evaluate"): 457 if isinstance(estimator, KerasGBatchClassifier):
458 scores = {}
459 steps = estimator.prediction_steps 459 steps = estimator.prediction_steps
460 batch_size = estimator.batch_size 460 batch_size = estimator.batch_size
461 generator = estimator.data_generator_.flow( 461 data_generator = estimator.data_generator_
462 X_test, y=y_test, batch_size=batch_size 462
463 scores, predictions, y_true = _evaluate_keras_and_sklearn_scores(
464 estimator,
465 data_generator,
466 X_test,
467 y=y_test,
468 sk_scoring=scoring,
469 steps=steps,
470 batch_size=batch_size,
471 return_predictions=bool(outfile_y_true),
463 ) 472 )
464 predictions, y_true = _predict_generator( 473
465 estimator.model_, generator, steps=steps 474 else:
466 ) 475 scores = {}
467 scores = _evaluate(y_true, predictions, scorer, is_multimetric=True) 476 if hasattr(estimator, "model_") and hasattr(estimator.model_, "metrics_names"):
468 477 batch_size = estimator.batch_size
469 else: 478 score_results = estimator.model_.evaluate(
479 X_test, y=y_test, batch_size=batch_size, verbose=0
480 )
481 metrics_names = estimator.model_.metrics_names
482 if not isinstance(metrics_names, list):
483 scores[metrics_names] = score_results
484 else:
485 scores = dict(zip(metrics_names, score_results))
486
470 if hasattr(estimator, "predict_proba"): 487 if hasattr(estimator, "predict_proba"):
471 predictions = estimator.predict_proba(X_test) 488 predictions = estimator.predict_proba(X_test)
472 else: 489 else:
473 predictions = estimator.predict(X_test) 490 predictions = estimator.predict(X_test)
474 491
475 y_true = y_test 492 y_true = y_test
476 scores = _score(estimator, X_test, y_test, scorer, is_multimetric=True) 493 sk_scores = _score(estimator, X_test, y_test, scorer)
494 scores.update(sk_scores)
495
496 # handle output
477 if outfile_y_true: 497 if outfile_y_true:
478 try: 498 try:
479 pd.DataFrame(y_true).to_csv(outfile_y_true, sep="\t", index=False) 499 pd.DataFrame(y_true).to_csv(outfile_y_true, sep="\t", index=False)
480 pd.DataFrame(predictions).astype(np.float32).to_csv( 500 pd.DataFrame(predictions).astype(np.float32).to_csv(
481 outfile_y_preds, 501 outfile_y_preds,
484 float_format="%g", 504 float_format="%g",
485 chunksize=10000, 505 chunksize=10000,
486 ) 506 )
487 except Exception as e: 507 except Exception as e:
488 print("Error in saving predictions: %s" % e) 508 print("Error in saving predictions: %s" % e)
489
490 # handle output 509 # handle output
491 for name, score in scores.items(): 510 for name, score in scores.items():
492 scores[name] = [score] 511 scores[name] = [score]
493 df = pd.DataFrame(scores) 512 df = pd.DataFrame(scores)
494 df = df[sorted(df.columns)] 513 df = df[sorted(df.columns)]
495 df.to_csv(path_or_buf=outfile_result, sep="\t", header=True, index=False) 514 df.to_csv(path_or_buf=outfile_result, sep="\t", header=True, index=False)
496 515
497 memory.clear(warn=False) 516 memory.clear(warn=False)
498 517
499 if outfile_object: 518 if outfile_object:
500 main_est = estimator 519 dump_model_to_h5(estimator, outfile_object)
501 if isinstance(estimator, Pipeline):
502 main_est = estimator.steps[-1][-1]
503
504 if hasattr(main_est, "model_") and hasattr(main_est, "save_weights"):
505 if outfile_weights:
506 main_est.save_weights(outfile_weights)
507 del main_est.model_
508 del main_est.fit_params
509 del main_est.model_class_
510 if getattr(main_est, "validation_data", None):
511 del main_est.validation_data
512 if getattr(main_est, "data_generator_", None):
513 del main_est.data_generator_
514
515 with open(outfile_object, "wb") as output_handler:
516 pickle.dump(estimator, output_handler, pickle.HIGHEST_PROTOCOL)
517 520
518 521
519 if __name__ == "__main__": 522 if __name__ == "__main__":
520 aparser = argparse.ArgumentParser() 523 aparser = argparse.ArgumentParser()
521 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) 524 aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
522 aparser.add_argument("-e", "--estimator", dest="infile_estimator") 525 aparser.add_argument("-e", "--estimator", dest="infile_estimator")
523 aparser.add_argument("-X", "--infile1", dest="infile1") 526 aparser.add_argument("-X", "--infile1", dest="infile1")
524 aparser.add_argument("-y", "--infile2", dest="infile2") 527 aparser.add_argument("-y", "--infile2", dest="infile2")
525 aparser.add_argument("-O", "--outfile_result", dest="outfile_result") 528 aparser.add_argument("-O", "--outfile_result", dest="outfile_result")
526 aparser.add_argument("-o", "--outfile_object", dest="outfile_object") 529 aparser.add_argument("-o", "--outfile_object", dest="outfile_object")
527 aparser.add_argument("-w", "--outfile_weights", dest="outfile_weights")
528 aparser.add_argument("-l", "--outfile_y_true", dest="outfile_y_true") 530 aparser.add_argument("-l", "--outfile_y_true", dest="outfile_y_true")
529 aparser.add_argument("-p", "--outfile_y_preds", dest="outfile_y_preds") 531 aparser.add_argument("-p", "--outfile_y_preds", dest="outfile_y_preds")
530 aparser.add_argument("-g", "--groups", dest="groups") 532 aparser.add_argument("-g", "--groups", dest="groups")
531 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") 533 aparser.add_argument("-r", "--ref_seq", dest="ref_seq")
532 aparser.add_argument("-b", "--intervals", dest="intervals") 534 aparser.add_argument("-b", "--intervals", dest="intervals")
539 args.infile_estimator, 541 args.infile_estimator,
540 args.infile1, 542 args.infile1,
541 args.infile2, 543 args.infile2,
542 args.outfile_result, 544 args.outfile_result,
543 outfile_object=args.outfile_object, 545 outfile_object=args.outfile_object,
544 outfile_weights=args.outfile_weights,
545 outfile_y_true=args.outfile_y_true, 546 outfile_y_true=args.outfile_y_true,
546 outfile_y_preds=args.outfile_y_preds, 547 outfile_y_preds=args.outfile_y_preds,
547 groups=args.groups, 548 groups=args.groups,
548 ref_seq=args.ref_seq, 549 ref_seq=args.ref_seq,
549 intervals=args.intervals, 550 intervals=args.intervals,