diff keras_train_and_eval.py @ 46:0087d6db4290 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 5eca9041ce0154eded5aec07195502d5eb3cdd4f
author bgruening
date Fri, 03 Nov 2023 23:03:23 +0000
parents f3dfa4bdf87e
children
line wrap: on
line diff
--- a/keras_train_and_eval.py	Mon Oct 02 09:37:47 2023 +0000
+++ b/keras_train_and_eval.py	Fri Nov 03 23:03:23 2023 +0000
@@ -10,6 +10,8 @@
 from galaxy_ml.keras_galaxy_models import (
     _predict_generator,
     KerasGBatchClassifier,
+    KerasGClassifier,
+    KerasGRegressor
 )
 from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5
 from galaxy_ml.model_validations import train_test_split
@@ -400,8 +402,16 @@
     # handle scorer, convert to scorer dict
     scoring = params["experiment_schemes"]["metrics"]["scoring"]
     scorer = get_scoring(scoring)
-    if not isinstance(scorer, (dict, list)):
-        scorer = [scoring["primary_scoring"]]
+
+    # We get 'None' back from the call to 'get_scoring()' if
+    # the primary scoring is 'default'. Replace 'default' with
+    # the default scoring for classification/regression (accuracy/r2)
+    if scorer is None:
+        if isinstance(estimator, KerasGClassifier):
+            scorer = ['accuracy']
+        if isinstance(estimator, KerasGRegressor):
+            scorer = ['r2']
+
     scorer = _check_multimetric_scoring(estimator, scoring=scorer)
 
     # handle test (first) split
@@ -499,8 +509,15 @@
         else:
             predictions = estimator.predict(X_test)
 
-        y_true = y_test
-        sk_scores = _score(estimator, X_test, y_test, scorer)
+        # Un-do OHE of the validation labels
+        if len(y_test.shape) == 2:
+            rounded_test_labels = np.argmax(y_test, axis=1)
+            y_true = rounded_test_labels
+            sk_scores = _score(estimator, X_test, rounded_test_labels, scorer)
+        else:
+            y_true = y_test
+            sk_scores = _score(estimator, X_test, y_true, scorer)
+
         scores.update(sk_scores)
 
     # handle output