changeset 18:bbf30253c99f draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
author goeckslab
date Sun, 14 Dec 2025 03:27:12 +0000
parents db9be962dc13
children
files constants.py image_learner.xml image_learner_cli.py ludwig_backend.py
diffstat 4 files changed, 76 insertions(+), 55 deletions(-) [+]
line wrap: on
line diff
--- a/constants.py	Wed Dec 10 00:24:13 2025 +0000
+++ b/constants.py	Sun Dec 14 03:27:12 2025 +0000
@@ -174,6 +174,7 @@
 }
 METRIC_DISPLAY_NAMES = {
     "accuracy": "Accuracy",
+    "balanced_accuracy": "Balanced Accuracy",
     "accuracy_micro": "Micro Accuracy",
     "loss": "Loss",
     "roc_auc": "ROC-AUC",
--- a/image_learner.xml	Wed Dec 10 00:24:13 2025 +0000
+++ b/image_learner.xml	Sun Dec 14 03:27:12 2025 +0000
@@ -106,33 +106,26 @@
                 <param name="validation_metric_binary" type="select" optional="true" label="Validation metric (binary)" help="Metrics accepted by Ludwig for binary outputs.">
                     <option value="roc_auc" selected="true">ROC-AUC</option>
                     <option value="accuracy">Accuracy</option>
-                    <option value="balanced_accuracy">Balanced Accuracy</option>
                     <option value="precision">Precision</option>
                     <option value="recall">Recall</option>
-                    <option value="f1">F1</option>
                     <option value="specificity">Specificity</option>
-                    <option value="log_loss">Log Loss</option>
                     <option value="loss">Loss</option>
                 </param>
             </when>
             <when value="classification">
                 <param name="validation_metric_multiclass" type="select" optional="true" label="Validation metric (multi-class)" help="Metrics accepted by Ludwig for multi-class outputs.">
                     <option value="accuracy" selected="true">Accuracy</option>
-                    <option value="roc_auc">ROC-AUC</option>
+                    <option value="balanced_accuracy">Balanced Accuracy</option>
+                    <option value="hits_at_k">Hits at K (top-k)</option>
                     <option value="loss">Loss</option>
-                    <option value="balanced_accuracy">Balanced Accuracy</option>
-                    <option value="precision">Precision</option>
-                    <option value="recall">Recall</option>
-                    <option value="f1">F1</option>
-                    <option value="specificity">Specificity</option>
-                    <option value="log_loss">Log Loss</option>
                 </param>
             </when>
             <when value="regression">
                 <param name="validation_metric_regression" type="select" optional="true" label="Validation metric (regression)" help="Metrics accepted by Ludwig for regression outputs.">
-                    <option value="mae" selected="true">MAE</option>
-                    <option value="mse">MSE</option>
-                    <option value="rmse">RMSE</option>
+                    <option value="mean_squared_error" selected="true">Mean Squared Error</option>
+                    <option value="mean_absolute_error">Mean Absolute Error</option>
+                    <option value="root_mean_squared_error">Root Mean Squared Error</option>
+                    <option value="root_mean_squared_percentage_error">Root Mean Squared Percentage Error</option>
                     <option value="loss">Loss</option>
                 </param>
             </when>
--- a/image_learner_cli.py	Wed Dec 10 00:24:13 2025 +0000
+++ b/image_learner_cli.py	Sun Dec 14 03:27:12 2025 +0000
@@ -145,26 +145,11 @@
     parser.add_argument(
         "--validation-metric",
         type=str,
-        default="roc_auc",
-        choices=[
-            "accuracy",
-            "loss",
-            "roc_auc",
-            "balanced_accuracy",
-            "precision",
-            "recall",
-            "f1",
-            "specificity",
-            "log_loss",
-            "pearson_r",
-            "mae",
-            "mse",
-            "rmse",
-            "mape",
-            "r2",
-            "explained_variance",
-        ],
-        help="Metric Ludwig uses to select the best model during training/validation.",
+        default=None,
+        help=(
+            "Metric Ludwig uses to select the best model during training/validation. "
+            "Leave unset to let the tool pick a default for the inferred task."
+        ),
     )
     parser.add_argument(
         "--target-column",
--- a/ludwig_backend.py	Wed Dec 10 00:24:13 2025 +0000
+++ b/ludwig_backend.py	Sun Dec 14 03:27:12 2025 +0000
@@ -403,42 +403,38 @@
             # No explicit resize provided; keep for reporting purposes
             config_params.setdefault("image_size", "original")
 
-        def _resolve_validation_metric(task: str, requested: Optional[str]) -> Optional[str]:
-            """Pick a validation metric that Ludwig will accept for the resolved task."""
+        def _resolve_validation_metric(
+            task: str, requested: Optional[str], output_feature: Dict[str, Any]
+        ) -> Optional[str]:
+            """
+            Pick a validation metric that Ludwig will accept for the resolved task/output.
+            If the requested metric is invalid, fall back to a safe option or omit it entirely.
+            """
             default_map = {
-                "regression": "pearson_r",
+                "regression": "mean_squared_error",
                 "binary": "roc_auc",
                 "category": "accuracy",
             }
             allowed_map = {
                 "regression": {
-                    "pearson_r",
                     "mean_absolute_error",
                     "mean_squared_error",
                     "root_mean_squared_error",
-                    "mean_absolute_percentage_error",
-                    "r2",
-                    "explained_variance",
+                    "root_mean_squared_percentage_error",
                     "loss",
                 },
-                # Ludwig rejects f1 and balanced_accuracy for binary outputs; keep to known-safe set.
                 "binary": {
                     "roc_auc",
                     "accuracy",
                     "precision",
                     "recall",
                     "specificity",
-                    "log_loss",
                     "loss",
                 },
                 "category": {
                     "accuracy",
                     "balanced_accuracy",
-                    "precision",
-                    "recall",
-                    "f1",
-                    "specificity",
-                    "log_loss",
+                    "hits_at_k",
                     "loss",
                 },
             }
@@ -447,25 +443,64 @@
                     "mae": "mean_absolute_error",
                     "mse": "mean_squared_error",
                     "rmse": "root_mean_squared_error",
-                    "mape": "mean_absolute_percentage_error",
+                    "rmspe": "root_mean_squared_percentage_error",
+                },
+                "category": {},
+                "binary": {
+                    "roc_auc": "roc_auc",
                 },
             }
 
             default_metric = default_map.get(task)
-            allowed = allowed_map.get(task, set())
             metric = requested or default_metric
-
             if metric is None:
                 return None
 
             metric = alias_map.get(task, {}).get(metric, metric)
 
-            if metric not in allowed:
+            # Prefer Ludwig's own metric registry when available; intersect with known-safe sets.
+            registry_metrics = None
+            try:
+                from ludwig.features.feature_registries import output_type_registry
+
+                feature_cls = output_type_registry.get(output_feature.get("type"))
+                if feature_cls:
+                    feature_obj = feature_cls(feature=output_feature)
+                    metrics_attr = getattr(feature_obj, "metric_functions", None) or getattr(
+                        feature_obj, "metrics", None
+                    )
+                    if isinstance(metrics_attr, dict):
+                        registry_metrics = set(metrics_attr.keys())
+            except Exception as exc:
+                logger.debug(
+                    "Could not inspect Ludwig metrics for output type %s: %s",
+                    output_feature.get("type"),
+                    exc,
+                )
+
+            allowed = set(allowed_map.get(task, set()))
+            if registry_metrics:
+                # Only keep metrics that Ludwig actually exposes for this output type;
+                # if the intersection is empty, fall back to the registry set.
+                intersected = allowed.intersection(registry_metrics)
+                allowed = intersected or registry_metrics
+
+            if allowed and metric not in allowed:
+                fallback_candidates = [
+                    default_metric if default_metric in allowed else None,
+                    "loss" if "loss" in allowed else None,
+                    next(iter(allowed), None),
+                ]
+                fallback = next((m for m in fallback_candidates if m in allowed), None)
                 if requested:
                     logger.warning(
-                        f"Validation metric '{requested}' is not supported for {task} outputs; using '{default_metric}' instead."
+                        "Validation metric '%s' is not supported for %s outputs; %s",
+                        requested,
+                        task,
+                        (f"using '{fallback}' instead." if fallback else "omitting validation_metric."),
                     )
-                metric = default_metric
+                metric = fallback
+
             return metric
 
         if task_type == "regression":
@@ -475,7 +510,11 @@
                 "decoder": {"type": "regressor"},
                 "loss": {"type": "mean_squared_error"},
             }
-            val_metric = _resolve_validation_metric("regression", config_params.get("validation_metric"))
+            val_metric = _resolve_validation_metric(
+                "regression",
+                config_params.get("validation_metric"),
+                output_feat,
+            )
 
         else:
             if num_unique_labels == 2:
@@ -495,6 +534,7 @@
             val_metric = _resolve_validation_metric(
                 "binary" if num_unique_labels == 2 else "category",
                 config_params.get("validation_metric"),
+                output_feat,
             )
 
         # Propagate the resolved validation metric (including any task-based fallback or alias normalization)
@@ -610,7 +650,9 @@
             raise RuntimeError("Ludwig argument error.") from e
         except Exception:
             logger.error(
-                "LudwigDirectBackend: Experiment execution error.",
+                "LudwigDirectBackend: Experiment execution error. "
+                "If this relates to validation_metric, confirm the XML task selection "
+                "passes a metric that matches the inferred task type.",
                 exc_info=True,
             )
             raise