Mercurial > repos > bgruening > model_prediction
comparison keras_train_and_eval.py @ 18:9991c4ddde14 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 80417bf0158a9b596e485dd66408f738f405145a
author | bgruening |
---|---|
date | Mon, 02 Oct 2023 10:26:06 +0000 |
parents | 3bb1b688b0e4 |
children | 61ed5b826c32 |
comparison
equal
deleted
inserted
replaced
17:980bf31faa05 | 18:9991c4ddde14 |
---|---|
186 inputs, | 186 inputs, |
187 infile_estimator, | 187 infile_estimator, |
188 infile1, | 188 infile1, |
189 infile2, | 189 infile2, |
190 outfile_result, | 190 outfile_result, |
| | 191 outfile_history=None, |
191 outfile_object=None, | 192 outfile_object=None, |
192 outfile_y_true=None, | 193 outfile_y_true=None, |
193 outfile_y_preds=None, | 194 outfile_y_preds=None, |
194 groups=None, | 195 groups=None, |
195 ref_seq=None, | 196 ref_seq=None, |
213 File path to dataset containing target values. | 214 File path to dataset containing target values. |
214 | 215 |
215 outfile_result : str | 216 outfile_result : str |
216 File path to save the results, either cv_results or test result. | 217 File path to save the results, either cv_results or test result. |
217 | 218 |
| | 219 outfile_history : str, optional |
| | 220 File path to save the training history. |
| | 221 |
218 outfile_object : str, optional | 222 outfile_object : str, optional |
219 File path to save searchCV object. | 223 File path to save searchCV object. |
220 | 224 |
221 outfile_y_true : str, optional | 225 outfile_y_true : str, optional |
222 File path to target values for prediction. | 226 File path to target values for prediction. |
251 | 255 |
252 # swap hyperparameter | 256 # swap hyperparameter |
253 swapping = params["experiment_schemes"]["hyperparams_swapping"] | 257 swapping = params["experiment_schemes"]["hyperparams_swapping"] |
254 swap_params = _eval_swap_params(swapping) | 258 swap_params = _eval_swap_params(swapping) |
255 estimator.set_params(**swap_params) | 259 estimator.set_params(**swap_params) |
256 | |
257 estimator_params = estimator.get_params() | 260 estimator_params = estimator.get_params() |
258 | |
259 # store read dataframe object | 261 # store read dataframe object |
260 loaded_df = {} | 262 loaded_df = {} |
261 | 263 |
262 input_type = params["input_options"]["selected_input"] | 264 input_type = params["input_options"]["selected_input"] |
263 # tabular input | 265 # tabular input |
446 ) = train_test_split_none(X_train, y_train, groups_train, **val_split_options) | 448 ) = train_test_split_none(X_train, y_train, groups_train, **val_split_options) |
447 | 449 |
448 # train and eval | 450 # train and eval |
449 if hasattr(estimator, "config") and hasattr(estimator, "model_type"): | 451 if hasattr(estimator, "config") and hasattr(estimator, "model_type"): |
450 if exp_scheme == "train_val_test": | 452 if exp_scheme == "train_val_test": |
451 estimator.fit(X_train, y_train, validation_data=(X_val, y_val)) | 453 history = estimator.fit(X_train, y_train, validation_data=(X_val, y_val)) |
452 else: | 454 else: |
453 estimator.fit(X_train, y_train, validation_data=(X_test, y_test)) | 455 history = estimator.fit(X_train, y_train, validation_data=(X_test, y_test)) |
454 else: | 456 else: |
455 estimator.fit(X_train, y_train) | 457 history = estimator.fit(X_train, y_train) |
456 | 458 if "callbacks" in estimator_params: |
| | 459 for cb in estimator_params["callbacks"]: |
| | 460 if cb["callback_selection"]["callback_type"] == "CSVLogger": |
| | 461 hist_df = pd.DataFrame(history.history) |
| | 462 hist_df["epoch"] = np.arange(1, estimator_params["epochs"] + 1) |
| | 463 epo_col = hist_df.pop('epoch') |
| | 464 hist_df.insert(0, 'epoch', epo_col) |
| | 465 hist_df.to_csv(path_or_buf=outfile_history, sep="\t", header=True, index=False) |
| | 466 break |
457 if isinstance(estimator, KerasGBatchClassifier): | 467 if isinstance(estimator, KerasGBatchClassifier): |
458 scores = {} | 468 scores = {} |
459 steps = estimator.prediction_steps | 469 steps = estimator.prediction_steps |
460 batch_size = estimator.batch_size | 470 batch_size = estimator.batch_size |
461 data_generator = estimator.data_generator_ | 471 data_generator = estimator.data_generator_ |
524 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | 534 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) |
525 aparser.add_argument("-e", "--estimator", dest="infile_estimator") | 535 aparser.add_argument("-e", "--estimator", dest="infile_estimator") |
526 aparser.add_argument("-X", "--infile1", dest="infile1") | 536 aparser.add_argument("-X", "--infile1", dest="infile1") |
527 aparser.add_argument("-y", "--infile2", dest="infile2") | 537 aparser.add_argument("-y", "--infile2", dest="infile2") |
528 aparser.add_argument("-O", "--outfile_result", dest="outfile_result") | 538 aparser.add_argument("-O", "--outfile_result", dest="outfile_result") |
| | 539 aparser.add_argument("-hi", "--outfile_history", dest="outfile_history") |
529 aparser.add_argument("-o", "--outfile_object", dest="outfile_object") | 540 aparser.add_argument("-o", "--outfile_object", dest="outfile_object") |
530 aparser.add_argument("-l", "--outfile_y_true", dest="outfile_y_true") | 541 aparser.add_argument("-l", "--outfile_y_true", dest="outfile_y_true") |
531 aparser.add_argument("-p", "--outfile_y_preds", dest="outfile_y_preds") | 542 aparser.add_argument("-p", "--outfile_y_preds", dest="outfile_y_preds") |
532 aparser.add_argument("-g", "--groups", dest="groups") | 543 aparser.add_argument("-g", "--groups", dest="groups") |
533 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") | 544 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") |
540 args.inputs, | 551 args.inputs, |
541 args.infile_estimator, | 552 args.infile_estimator, |
542 args.infile1, | 553 args.infile1, |
543 args.infile2, | 554 args.infile2, |
544 args.outfile_result, | 555 args.outfile_result, |
| | 556 outfile_history=args.outfile_history, |
545 outfile_object=args.outfile_object, | 557 outfile_object=args.outfile_object, |
546 outfile_y_true=args.outfile_y_true, | 558 outfile_y_true=args.outfile_y_true, |
547 outfile_y_preds=args.outfile_y_preds, | 559 outfile_y_preds=args.outfile_y_preds, |
548 groups=args.groups, | 560 groups=args.groups, |
549 ref_seq=args.ref_seq, | 561 ref_seq=args.ref_seq, |