comparison keras_train_and_eval.py @ 18:9991c4ddde14 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 80417bf0158a9b596e485dd66408f738f405145a
author bgruening
date Mon, 02 Oct 2023 10:26:06 +0000
parents 3bb1b688b0e4
children 61ed5b826c32
comparison
equal deleted inserted replaced
17:980bf31faa05 18:9991c4ddde14
186 inputs, 186 inputs,
187 infile_estimator, 187 infile_estimator,
188 infile1, 188 infile1,
189 infile2, 189 infile2,
190 outfile_result, 190 outfile_result,
191 outfile_history=None,
191 outfile_object=None, 192 outfile_object=None,
192 outfile_y_true=None, 193 outfile_y_true=None,
193 outfile_y_preds=None, 194 outfile_y_preds=None,
194 groups=None, 195 groups=None,
195 ref_seq=None, 196 ref_seq=None,
213 File path to dataset containing target values. 214 File path to dataset containing target values.
214 215
215 outfile_result : str 216 outfile_result : str
216 File path to save the results, either cv_results or test result. 217 File path to save the results, either cv_results or test result.
217 218
219 outfile_history : str, optional
220 File path to save the training history.
221
218 outfile_object : str, optional 222 outfile_object : str, optional
219 File path to save searchCV object. 223 File path to save searchCV object.
220 224
221 outfile_y_true : str, optional 225 outfile_y_true : str, optional
222 File path to target values for prediction. 226 File path to target values for prediction.
251 255
252 # swap hyperparameter 256 # swap hyperparameter
253 swapping = params["experiment_schemes"]["hyperparams_swapping"] 257 swapping = params["experiment_schemes"]["hyperparams_swapping"]
254 swap_params = _eval_swap_params(swapping) 258 swap_params = _eval_swap_params(swapping)
255 estimator.set_params(**swap_params) 259 estimator.set_params(**swap_params)
256
257 estimator_params = estimator.get_params() 260 estimator_params = estimator.get_params()
258
259 # store read dataframe object 261 # store read dataframe object
260 loaded_df = {} 262 loaded_df = {}
261 263
262 input_type = params["input_options"]["selected_input"] 264 input_type = params["input_options"]["selected_input"]
263 # tabular input 265 # tabular input
446 ) = train_test_split_none(X_train, y_train, groups_train, **val_split_options) 448 ) = train_test_split_none(X_train, y_train, groups_train, **val_split_options)
447 449
448 # train and eval 450 # train and eval
449 if hasattr(estimator, "config") and hasattr(estimator, "model_type"): 451 if hasattr(estimator, "config") and hasattr(estimator, "model_type"):
450 if exp_scheme == "train_val_test": 452 if exp_scheme == "train_val_test":
451 estimator.fit(X_train, y_train, validation_data=(X_val, y_val)) 453 history = estimator.fit(X_train, y_train, validation_data=(X_val, y_val))
452 else: 454 else:
453 estimator.fit(X_train, y_train, validation_data=(X_test, y_test)) 455 history = estimator.fit(X_train, y_train, validation_data=(X_test, y_test))
454 else: 456 else:
455 estimator.fit(X_train, y_train) 457 history = estimator.fit(X_train, y_train)
456 458 if "callbacks" in estimator_params:
459 for cb in estimator_params["callbacks"]:
460 if cb["callback_selection"]["callback_type"] == "CSVLogger":
461 hist_df = pd.DataFrame(history.history)
462 hist_df["epoch"] = np.arange(1, estimator_params["epochs"] + 1)
463 epo_col = hist_df.pop('epoch')
464 hist_df.insert(0, 'epoch', epo_col)
465 hist_df.to_csv(path_or_buf=outfile_history, sep="\t", header=True, index=False)
466 break
457 if isinstance(estimator, KerasGBatchClassifier): 467 if isinstance(estimator, KerasGBatchClassifier):
458 scores = {} 468 scores = {}
459 steps = estimator.prediction_steps 469 steps = estimator.prediction_steps
460 batch_size = estimator.batch_size 470 batch_size = estimator.batch_size
461 data_generator = estimator.data_generator_ 471 data_generator = estimator.data_generator_
524 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) 534 aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
525 aparser.add_argument("-e", "--estimator", dest="infile_estimator") 535 aparser.add_argument("-e", "--estimator", dest="infile_estimator")
526 aparser.add_argument("-X", "--infile1", dest="infile1") 536 aparser.add_argument("-X", "--infile1", dest="infile1")
527 aparser.add_argument("-y", "--infile2", dest="infile2") 537 aparser.add_argument("-y", "--infile2", dest="infile2")
528 aparser.add_argument("-O", "--outfile_result", dest="outfile_result") 538 aparser.add_argument("-O", "--outfile_result", dest="outfile_result")
539 aparser.add_argument("-hi", "--outfile_history", dest="outfile_history")
529 aparser.add_argument("-o", "--outfile_object", dest="outfile_object") 540 aparser.add_argument("-o", "--outfile_object", dest="outfile_object")
530 aparser.add_argument("-l", "--outfile_y_true", dest="outfile_y_true") 541 aparser.add_argument("-l", "--outfile_y_true", dest="outfile_y_true")
531 aparser.add_argument("-p", "--outfile_y_preds", dest="outfile_y_preds") 542 aparser.add_argument("-p", "--outfile_y_preds", dest="outfile_y_preds")
532 aparser.add_argument("-g", "--groups", dest="groups") 543 aparser.add_argument("-g", "--groups", dest="groups")
533 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") 544 aparser.add_argument("-r", "--ref_seq", dest="ref_seq")
540 args.inputs, 551 args.inputs,
541 args.infile_estimator, 552 args.infile_estimator,
542 args.infile1, 553 args.infile1,
543 args.infile2, 554 args.infile2,
544 args.outfile_result, 555 args.outfile_result,
556 outfile_history=args.outfile_history,
545 outfile_object=args.outfile_object, 557 outfile_object=args.outfile_object,
546 outfile_y_true=args.outfile_y_true, 558 outfile_y_true=args.outfile_y_true,
547 outfile_y_preds=args.outfile_y_preds, 559 outfile_y_preds=args.outfile_y_preds,
548 groups=args.groups, 560 groups=args.groups,
549 ref_seq=args.ref_seq, 561 ref_seq=args.ref_seq,