Mercurial > repos > bgruening > model_prediction
comparison keras_train_and_eval.py @ 19:61ed5b826c32 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 5eca9041ce0154eded5aec07195502d5eb3cdd4f
| author | bgruening |
|---|---|
| date | Fri, 03 Nov 2023 23:18:23 +0000 |
| parents | 9991c4ddde14 |
| children | |
comparison
equal
deleted
inserted
replaced
| 18:9991c4ddde14 | 19:61ed5b826c32 |
|---|---|
| 8 import numpy as np | 8 import numpy as np |
| 9 import pandas as pd | 9 import pandas as pd |
| 10 from galaxy_ml.keras_galaxy_models import ( | 10 from galaxy_ml.keras_galaxy_models import ( |
| 11 _predict_generator, | 11 _predict_generator, |
| 12 KerasGBatchClassifier, | 12 KerasGBatchClassifier, |
| 13 KerasGClassifier, | |
| 14 KerasGRegressor | |
| 13 ) | 15 ) |
| 14 from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5 | 16 from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5 |
| 15 from galaxy_ml.model_validations import train_test_split | 17 from galaxy_ml.model_validations import train_test_split |
| 16 from galaxy_ml.utils import ( | 18 from galaxy_ml.utils import ( |
| 17 clean_params, | 19 clean_params, |
| 398 main_est.set_params(memory=memory) | 400 main_est.set_params(memory=memory) |
| 399 | 401 |
| 400 # handle scorer, convert to scorer dict | 402 # handle scorer, convert to scorer dict |
| 401 scoring = params["experiment_schemes"]["metrics"]["scoring"] | 403 scoring = params["experiment_schemes"]["metrics"]["scoring"] |
| 402 scorer = get_scoring(scoring) | 404 scorer = get_scoring(scoring) |
| 403 if not isinstance(scorer, (dict, list)): | 405 |
| 404 scorer = [scoring["primary_scoring"]] | 406 # We get 'None' back from the call to 'get_scoring()' if |
| 407 # the primary scoring is 'default'. Replace 'default' with | |
| 408 # the default scoring for classification/regression (accuracy/r2) | |
| 409 if scorer is None: | |
| 410 if isinstance(estimator, KerasGClassifier): | |
| 411 scorer = ['accuracy'] | |
| 412 if isinstance(estimator, KerasGRegressor): | |
| 413 scorer = ['r2'] | |
| 414 | |
| 405 scorer = _check_multimetric_scoring(estimator, scoring=scorer) | 415 scorer = _check_multimetric_scoring(estimator, scoring=scorer) |
| 406 | 416 |
| 407 # handle test (first) split | 417 # handle test (first) split |
| 408 test_split_options = params["experiment_schemes"]["test_split"]["split_algos"] | 418 test_split_options = params["experiment_schemes"]["test_split"]["split_algos"] |
| 409 | 419 |
| 497 if hasattr(estimator, "predict_proba"): | 507 if hasattr(estimator, "predict_proba"): |
| 498 predictions = estimator.predict_proba(X_test) | 508 predictions = estimator.predict_proba(X_test) |
| 499 else: | 509 else: |
| 500 predictions = estimator.predict(X_test) | 510 predictions = estimator.predict(X_test) |
| 501 | 511 |
| 502 y_true = y_test | 512 # Un-do OHE of the validation labels |
| 503 sk_scores = _score(estimator, X_test, y_test, scorer) | 513 if len(y_test.shape) == 2: |
| 514 rounded_test_labels = np.argmax(y_test, axis=1) | |
| 515 y_true = rounded_test_labels | |
| 516 sk_scores = _score(estimator, X_test, rounded_test_labels, scorer) | |
| 517 else: | |
| 518 y_true = y_test | |
| 519 sk_scores = _score(estimator, X_test, y_true, scorer) | |
| 520 | |
| 504 scores.update(sk_scores) | 521 scores.update(sk_scores) |
| 505 | 522 |
| 506 # handle output | 523 # handle output |
| 507 if outfile_y_true: | 524 if outfile_y_true: |
| 508 try: | 525 try: |
