diff ludwig_backend.py @ 20:64872c48a21f draft

planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
author goeckslab
date Tue, 06 Jan 2026 15:35:11 +0000
parents c460abae83eb
children
line wrap: on
line diff
--- a/ludwig_backend.py	Thu Dec 18 16:59:58 2025 +0000
+++ b/ludwig_backend.py	Tue Jan 06 15:35:11 2026 +0000
@@ -416,7 +416,8 @@
                 "binary": "roc_auc",
                 "category": "accuracy",
             }
-            allowed_map = {
+            # Safe defaults when Ludwig's registry isn't available.
+            safe_allowed_map = {
                 "regression": {
                     "mean_absolute_error",
                     "mean_squared_error",
@@ -429,13 +430,10 @@
                     "accuracy",
                     "precision",
                     "recall",
-                    "specificity",
                     "loss",
                 },
                 "category": {
                     "accuracy",
-                    "balanced_accuracy",
-                    "hits_at_k",
                     "loss",
                 },
             }
@@ -472,6 +470,16 @@
                     )
                     if isinstance(metrics_attr, dict):
                         registry_metrics = set(metrics_attr.keys())
+                    elif isinstance(metrics_attr, (list, tuple, set)):
+                        extracted = set()
+                        for item in metrics_attr:
+                            if isinstance(item, str):
+                                extracted.add(item)
+                            elif hasattr(item, "name"):
+                                extracted.add(str(item.name))
+                            elif hasattr(item, "__name__"):
+                                extracted.add(str(item.__name__))
+                        registry_metrics = extracted or None
             except Exception as exc:
                 logger.debug(
                     "Could not inspect Ludwig metrics for output type %s: %s",
@@ -479,12 +487,10 @@
                     exc,
                 )
 
-            allowed = set(allowed_map.get(task, set()))
-            if registry_metrics:
-                # Only keep metrics that Ludwig actually exposes for this output type;
-                # if the intersection is empty, fall back to the registry set.
-                intersected = allowed.intersection(registry_metrics)
-                allowed = intersected or registry_metrics
+            allowed = set(safe_allowed_map.get(task, set()))
+            if registry_metrics is not None:
+                # Use Ludwig's registry when available; fall back to safe defaults if it's empty.
+                allowed = registry_metrics or allowed
 
             if allowed and metric not in allowed:
                 fallback_candidates = [