diff base_model_trainer.py @ 13:f07850192bc2 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 84d5cd0b1fa5c1ff0ad892bc39c95dad1ceb4920
author goeckslab
date Sat, 08 Nov 2025 14:20:33 +0000
parents e2a6fed32d54
children
line wrap: on
line diff
--- a/base_model_trainer.py	Mon Sep 08 22:39:12 2025 +0000
+++ b/base_model_trainer.py	Sat Nov 08 14:20:33 2025 +0000
@@ -199,6 +199,28 @@
         self.exp.setup(self.data, **self.setup_params)
         self.setup_params.update(self.user_kwargs)
 
+    def _normalize_metric(self, m: str) -> str:
+        if not m:
+            return "R2" if self.task_type == "regression" else "Accuracy"
+        m_low = str(m).strip().lower()
+        alias = {
+            "auc": "AUC", "roc_auc": "AUC", "roc-auc": "AUC",
+            "accuracy": "Accuracy",
+            "precision": "Precision",
+            "recall": "Recall",
+            "f1": "F1",
+            "kappa": "Kappa",
+            "logloss": "Log Loss", "log_loss": "Log Loss",
+            "pr_auc": "PR-AUC-Weighted", "prauc": "PR-AUC-Weighted",
+            "r2": "R2",
+            "mae": "MAE",
+            "mse": "MSE",
+            "rmse": "RMSE",
+            "rmsle": "RMSLE",
+            "mape": "MAPE",
+        }
+        return alias.get(m_low, m)
+
     def train_model(self):
         LOG.info("Training and selecting the best model")
         if self.task_type == "classification":
@@ -222,6 +244,15 @@
         if getattr(self, "cross_validation_folds", None) is not None:
             compare_kwargs["fold"] = self.cross_validation_folds
 
+        chosen_metric = self._normalize_metric(getattr(self, "best_model_metric", None))
+        if chosen_metric:
+            compare_kwargs["sort"] = chosen_metric
+            self.chosen_metric_label = chosen_metric
+            try:
+                setattr(self.exp, "_fold_metric", chosen_metric)
+            except Exception as e:
+                LOG.warning(f"Failed to set '_fold_metric' to '{chosen_metric}': {e}", exc_info=True)
+
         LOG.info(f"compare_models kwargs: {compare_kwargs}")
         self.best_model = self.exp.compare_models(**compare_kwargs)
         self.results = self.exp.pull()
@@ -369,8 +400,8 @@
             else:
                 dv = v if v is not None else "None"
             setup_rows.append([key, dv])
-        if hasattr(self.exp, "_fold_metric"):
-            setup_rows.append(["best_model_metric", self.exp._fold_metric])
+        if getattr(self, "chosen_metric_label", None):
+            setup_rows.append(["Best Model Metric", self.chosen_metric_label])
 
         df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"])
         df_setup.to_csv(