# HG changeset patch # User bgruening # Date 1699050802 0 # Node ID 46b43ee6d367acf1d2ee0220168e987cf8806118 # Parent 74adae8d7b0f549541d417ee3bfd7c29a1927d2d planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 5eca9041ce0154eded5aec07195502d5eb3cdd4f diff -r 74adae8d7b0f -r 46b43ee6d367 keras_train_and_eval.py --- a/keras_train_and_eval.py Mon Oct 02 10:30:40 2023 +0000 +++ b/keras_train_and_eval.py Fri Nov 03 22:33:22 2023 +0000 @@ -10,6 +10,8 @@ from galaxy_ml.keras_galaxy_models import ( _predict_generator, KerasGBatchClassifier, + KerasGClassifier, + KerasGRegressor ) from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5 from galaxy_ml.model_validations import train_test_split @@ -400,8 +402,16 @@ # handle scorer, convert to scorer dict scoring = params["experiment_schemes"]["metrics"]["scoring"] scorer = get_scoring(scoring) - if not isinstance(scorer, (dict, list)): - scorer = [scoring["primary_scoring"]] + + # We get 'None' back from the call to 'get_scoring()' if + # the primary scoring is 'default'. Replace 'default' with + # the default scoring for classification/regression (accuracy/r2) + if scorer is None: + if isinstance(estimator, KerasGClassifier): + scorer = ['accuracy'] + if isinstance(estimator, KerasGRegressor): + scorer = ['r2'] + scorer = _check_multimetric_scoring(estimator, scoring=scorer) # handle test (first) split @@ -499,8 +509,15 @@ else: predictions = estimator.predict(X_test) - y_true = y_test - sk_scores = _score(estimator, X_test, y_test, scorer) + # Un-do OHE of the validation labels + if len(y_test.shape) == 2: + rounded_test_labels = np.argmax(y_test, axis=1) + y_true = rounded_test_labels + sk_scores = _score(estimator, X_test, rounded_test_labels, scorer) + else: + y_true = y_test + sk_scores = _score(estimator, X_test, y_true, scorer) + scores.update(sk_scores) # handle output diff -r 74adae8d7b0f -r 46b43ee6d367 main_macros.xml --- a/main_macros.xml Mon Oct 02 10:30:40 2023 +0000 +++ b/main_macros.xml Fri Nov 03 22:33:22 2023 +0000 @@ -1,5 +1,5 @@ - 1.0.10.0 + 1.0.11.0 21.05