Mercurial > repos > bgruening > keras_model_config
comparison keras_train_and_eval.py @ 20:463a197abbd1 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 5eca9041ce0154eded5aec07195502d5eb3cdd4f
author | bgruening |
---|---|
date | Fri, 03 Nov 2023 23:11:53 +0000 |
parents | f22a9297440f |
children |
comparison
equal
deleted
inserted
replaced
19:f22a9297440f | 20:463a197abbd1 |
---|---|
8 import numpy as np | 8 import numpy as np |
9 import pandas as pd | 9 import pandas as pd |
10 from galaxy_ml.keras_galaxy_models import ( | 10 from galaxy_ml.keras_galaxy_models import ( |
11 _predict_generator, | 11 _predict_generator, |
12 KerasGBatchClassifier, | 12 KerasGBatchClassifier, |
13 KerasGClassifier, | |
14 KerasGRegressor | |
13 ) | 15 ) |
14 from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5 | 16 from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5 |
15 from galaxy_ml.model_validations import train_test_split | 17 from galaxy_ml.model_validations import train_test_split |
16 from galaxy_ml.utils import ( | 18 from galaxy_ml.utils import ( |
17 clean_params, | 19 clean_params, |
398 main_est.set_params(memory=memory) | 400 main_est.set_params(memory=memory) |
399 | 401 |
400 # handle scorer, convert to scorer dict | 402 # handle scorer, convert to scorer dict |
401 scoring = params["experiment_schemes"]["metrics"]["scoring"] | 403 scoring = params["experiment_schemes"]["metrics"]["scoring"] |
402 scorer = get_scoring(scoring) | 404 scorer = get_scoring(scoring) |
403 if not isinstance(scorer, (dict, list)): | 405 |
404 scorer = [scoring["primary_scoring"]] | 406 # We get 'None' back from the call to 'get_scoring()' if |
407 # the primary scoring is 'default'. Replace 'default' with | |
408 # the default scoring for classification/regression (accuracy/r2) | |
409 if scorer is None: | |
410 if isinstance(estimator, KerasGClassifier): | |
411 scorer = ['accuracy'] | |
412 if isinstance(estimator, KerasGRegressor): | |
413 scorer = ['r2'] | |
414 | |
405 scorer = _check_multimetric_scoring(estimator, scoring=scorer) | 415 scorer = _check_multimetric_scoring(estimator, scoring=scorer) |
406 | 416 |
407 # handle test (first) split | 417 # handle test (first) split |
408 test_split_options = params["experiment_schemes"]["test_split"]["split_algos"] | 418 test_split_options = params["experiment_schemes"]["test_split"]["split_algos"] |
409 | 419 |
497 if hasattr(estimator, "predict_proba"): | 507 if hasattr(estimator, "predict_proba"): |
498 predictions = estimator.predict_proba(X_test) | 508 predictions = estimator.predict_proba(X_test) |
499 else: | 509 else: |
500 predictions = estimator.predict(X_test) | 510 predictions = estimator.predict(X_test) |
501 | 511 |
502 y_true = y_test | 512 # Un-do OHE of the validation labels |
503 sk_scores = _score(estimator, X_test, y_test, scorer) | 513 if len(y_test.shape) == 2: |
514 rounded_test_labels = np.argmax(y_test, axis=1) | |
515 y_true = rounded_test_labels | |
516 sk_scores = _score(estimator, X_test, rounded_test_labels, scorer) | |
517 else: | |
518 y_true = y_test | |
519 sk_scores = _score(estimator, X_test, y_true, scorer) | |
520 | |
504 scores.update(sk_scores) | 521 scores.update(sk_scores) |
505 | 522 |
506 # handle output | 523 # handle output |
507 if outfile_y_true: | 524 if outfile_y_true: |
508 try: | 525 try: |