diff base_model_trainer.py @ 9:c6c1f8777aae draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 4a11e8a4c4e9daa884bddedfa47090476c517667
author goeckslab
date Thu, 31 Jul 2025 15:41:24 +0000
parents 1aed7d47c5ec
children
line wrap: on
line diff
--- a/base_model_trainer.py	Fri Jul 25 19:02:32 2025 +0000
+++ b/base_model_trainer.py	Thu Jul 31 15:41:24 2025 +0000
@@ -175,7 +175,13 @@
 
         if self.task_type == "classification":
             self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
-        _ = self.exp.predict_model(self.best_model)
+
+        prob_thresh = getattr(self, "probability_threshold", None)
+        if self.task_type == "classification" and prob_thresh is not None:
+            _ = self.exp.predict_model(self.best_model, probability_threshold=prob_thresh)
+        else:
+            _ = self.exp.predict_model(self.best_model)
+
         self.test_result_df = self.exp.pull()
         if self.task_type == "classification":
             self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
@@ -233,7 +239,7 @@
             best_model_name = type(self.best_model).__name__
         LOG.info(f"Best model determined as: {best_model_name}")
 
-        # 2) Compute training sample count
+    # 2) Compute training sample count
         try:
             n_train = self.exp.X_train.shape[0]
         except Exception:
@@ -241,7 +247,10 @@
         total_rows = self.data.shape[0]
 
         # 3) Build setup parameters table
-        all_params = self.setup_params
+        all_params = self.setup_params.copy()
+        if self.task_type == "classification" and hasattr(self, "probability_threshold"):
+            all_params["probability_threshold"] = self.probability_threshold
+
         display_keys = [
             "Target",
             "Session ID",
@@ -255,6 +264,7 @@
             "Polynomial Features",
             "Fix Imbalance",
             "Models",
+            "Probability Threshold",
         ]
         setup_rows = []
         for key in display_keys:
@@ -281,6 +291,8 @@
                 dv = v if v is not None else "None"
             elif key == "Models":
                 dv = ", ".join(map(str, v)) if isinstance(v, (list, tuple)) else "None"
+            elif key == "Probability Threshold":
+                dv = v if v is not None else "None"
             else:
                 dv = v if v is not None else "None"
             setup_rows.append([key, dv])