changeset 13:f07850192bc2 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 84d5cd0b1fa5c1ff0ad892bc39c95dad1ceb4920
author goeckslab
date Sat, 08 Nov 2025 14:20:33 +0000
parents e674b9e946fb
children
files base_model_trainer.py pycaret_macros.xml pycaret_train.py test-data/expected_best_model_classification_customized.csv test-data/expected_model_classification_customized.h5
diffstat 5 files changed, 43 insertions(+), 22 deletions(-) [+]
line wrap: on
line diff
--- a/base_model_trainer.py	Mon Sep 08 22:39:12 2025 +0000
+++ b/base_model_trainer.py	Sat Nov 08 14:20:33 2025 +0000
@@ -199,6 +199,28 @@
         self.exp.setup(self.data, **self.setup_params)
         self.setup_params.update(self.user_kwargs)
 
+    def _normalize_metric(self, m: str) -> str:
+        """Return the metric label for a user-supplied metric name.
+
+        Blank input falls back to "R2" (regression) or "Accuracy"; known
+        aliases match case-insensitively; other names pass through trimmed.
+        """
+        m_clean = str(m).strip() if m else ""
+        if not m_clean:
+            return "R2" if self.task_type == "regression" else "Accuracy"
+        alias = {
+            "auc": "AUC", "roc_auc": "AUC", "roc-auc": "AUC",
+            "accuracy": "Accuracy", "precision": "Precision",
+            "recall": "Recall", "f1": "F1", "kappa": "Kappa",
+            "logloss": "Log Loss", "log_loss": "Log Loss",
+            "pr_auc": "PR-AUC-Weighted", "prauc": "PR-AUC-Weighted",
+            "r2": "R2", "mae": "MAE", "mse": "MSE",
+            "rmse": "RMSE", "rmsle": "RMSLE", "mape": "MAPE",
+        }
+        # Unknown names are returned trimmed rather than raw, so stray
+        # whitespace around a metric name is not passed on to callers.
+        return alias.get(m_clean.lower(), m_clean)
+
     def train_model(self):
         LOG.info("Training and selecting the best model")
         if self.task_type == "classification":
@@ -222,6 +244,15 @@
         if getattr(self, "cross_validation_folds", None) is not None:
             compare_kwargs["fold"] = self.cross_validation_folds
 
+        chosen_metric = self._normalize_metric(getattr(self, "best_model_metric", None))
+        if chosen_metric:
+            compare_kwargs["sort"] = chosen_metric
+            self.chosen_metric_label = chosen_metric
+            try:
+                setattr(self.exp, "_fold_metric", chosen_metric)
+            except Exception as e:
+                LOG.warning(f"Failed to set '_fold_metric' to '{chosen_metric}': {e}", exc_info=True)
+
         LOG.info(f"compare_models kwargs: {compare_kwargs}")
         self.best_model = self.exp.compare_models(**compare_kwargs)
         self.results = self.exp.pull()
@@ -369,8 +400,8 @@
             else:
                 dv = v if v is not None else "None"
             setup_rows.append([key, dv])
-        if hasattr(self.exp, "_fold_metric"):
-            setup_rows.append(["best_model_metric", self.exp._fold_metric])
+        if getattr(self, "chosen_metric_label", None):
+            setup_rows.append(["Best Model Metric", self.chosen_metric_label])
 
         df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"])
         df_setup.to_csv(
--- a/pycaret_macros.xml	Mon Sep 08 22:39:12 2025 +0000
+++ b/pycaret_macros.xml	Sat Nov 08 14:20:33 2025 +0000
@@ -1,5 +1,5 @@
 <macros>
-    <token name="@TABULAR_LEARNER_VERSION@">0.1.0.1</token>
+    <token name="@TABULAR_LEARNER_VERSION@">0.1.1</token>
     <token name="@PYCARET_VERSION@">3.3.2</token>
     <token name="@SUFFIX@">1</token>
     <token name="@PYCARET_PREDICT_VERSION@">@PYCARET_VERSION@+@SUFFIX@</token>
--- a/pycaret_train.py	Mon Sep 08 22:39:12 2025 +0000
+++ b/pycaret_train.py	Sat Nov 08 14:20:33 2025 +0000
@@ -120,6 +120,12 @@
         default=None,
         help="Probability threshold for classification decision,",
     )
+    parser.add_argument(
+        "--best_model_metric",
+        type=str,
+        default=None,
+        help="Metric used to select the best model (e.g. AUC, Accuracy, R2, RMSE).",
+    )
 
     args = parser.parse_args()
 
@@ -144,6 +150,7 @@
         "fix_imbalance": args.fix_imbalance,
         "tune_model": args.tune_model,
         "probability_threshold": args.probability_threshold,
+        "best_model_metric": args.best_model_metric,
     }
     LOG.info(f"Model kwargs: {model_kwargs}")
 
--- a/test-data/expected_best_model_classification_customized.csv	Mon Sep 08 22:39:12 2025 +0000
+++ b/test-data/expected_best_model_classification_customized.csv	Sat Nov 08 14:20:33 2025 +0000
@@ -1,20 +1,3 @@
 Parameter,Value
-boosting_type,gbdt
-class_weight,
-colsample_bytree,1.0
-importance_type,split
-learning_rate,0.1
-max_depth,-1
-min_child_samples,20
-min_child_weight,0.001
-min_split_gain,0.0
-n_estimators,100
-n_jobs,-1
-num_leaves,31
-objective,
-random_state,42
-reg_alpha,0.0
-reg_lambda,0.0
-subsample,1.0
-subsample_for_bin,200000
-subsample_freq,0
+priors,
+var_smoothing,1e-09
Binary file test-data/expected_model_classification_customized.h5 has changed