comparison base_model_trainer.py @ 9:c6c1f8777aae draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 4a11e8a4c4e9daa884bddedfa47090476c517667
author goeckslab
date Thu, 31 Jul 2025 15:41:24 +0000
parents 1aed7d47c5ec
children
comparison
equal deleted inserted replaced
8:1aed7d47c5ec 9:c6c1f8777aae
173 self.best_model = self.exp.tune_model(self.best_model) 173 self.best_model = self.exp.tune_model(self.best_model)
174 self.results = self.exp.pull() 174 self.results = self.exp.pull()
175 175
176 if self.task_type == "classification": 176 if self.task_type == "classification":
177 self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True) 177 self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
178 _ = self.exp.predict_model(self.best_model) 178
179 prob_thresh = getattr(self, "probability_threshold", None)
180 if self.task_type == "classification" and prob_thresh is not None:
181 _ = self.exp.predict_model(self.best_model, probability_threshold=prob_thresh)
182 else:
183 _ = self.exp.predict_model(self.best_model)
184
179 self.test_result_df = self.exp.pull() 185 self.test_result_df = self.exp.pull()
180 if self.task_type == "classification": 186 if self.task_type == "classification":
181 self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True) 187 self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
182 188
183 def save_model(self): 189 def save_model(self):
231 best_model_name = str(self.results.iloc[0]["Model"]) 237 best_model_name = str(self.results.iloc[0]["Model"])
232 except Exception: 238 except Exception:
233 best_model_name = type(self.best_model).__name__ 239 best_model_name = type(self.best_model).__name__
234 LOG.info(f"Best model determined as: {best_model_name}") 240 LOG.info(f"Best model determined as: {best_model_name}")
235 241
236 # 2) Compute training sample count 242 # 2) Compute training sample count
237 try: 243 try:
238 n_train = self.exp.X_train.shape[0] 244 n_train = self.exp.X_train.shape[0]
239 except Exception: 245 except Exception:
240 n_train = getattr(self.exp, "X_train_transformed", pd.DataFrame()).shape[0] 246 n_train = getattr(self.exp, "X_train_transformed", pd.DataFrame()).shape[0]
241 total_rows = self.data.shape[0] 247 total_rows = self.data.shape[0]
242 248
243 # 3) Build setup parameters table 249 # 3) Build setup parameters table
244 all_params = self.setup_params 250 all_params = self.setup_params.copy()
251 if self.task_type == "classification" and hasattr(self, "probability_threshold"):
252 all_params["probability_threshold"] = self.probability_threshold
253
245 display_keys = [ 254 display_keys = [
246 "Target", 255 "Target",
247 "Session ID", 256 "Session ID",
248 "Train Size", 257 "Train Size",
249 "Normalize", 258 "Normalize",
253 "Remove Outliers", 262 "Remove Outliers",
254 "Remove Multicollinearity", 263 "Remove Multicollinearity",
255 "Polynomial Features", 264 "Polynomial Features",
256 "Fix Imbalance", 265 "Fix Imbalance",
257 "Models", 266 "Models",
267 "Probability Threshold",
258 ] 268 ]
259 setup_rows = [] 269 setup_rows = []
260 for key in display_keys: 270 for key in display_keys:
261 pk = key.lower().replace(" ", "_") 271 pk = key.lower().replace(" ", "_")
262 v = all_params.get(pk) 272 v = all_params.get(pk)
279 dv = bool(v) 289 dv = bool(v)
280 elif key == "Cross Validation Folds": 290 elif key == "Cross Validation Folds":
281 dv = v if v is not None else "None" 291 dv = v if v is not None else "None"
282 elif key == "Models": 292 elif key == "Models":
283 dv = ", ".join(map(str, v)) if isinstance(v, (list, tuple)) else "None" 293 dv = ", ".join(map(str, v)) if isinstance(v, (list, tuple)) else "None"
294 elif key == "Probability Threshold":
295 dv = v if v is not None else "None"
284 else: 296 else:
285 dv = v if v is not None else "None" 297 dv = v if v is not None else "None"
286 setup_rows.append([key, dv]) 298 setup_rows.append([key, dv])
287 if hasattr(self.exp, "_fold_metric"): 299 if hasattr(self.exp, "_fold_metric"):
288 setup_rows.append(["best_model_metric", self.exp._fold_metric]) 300 setup_rows.append(["best_model_metric", self.exp._fold_metric])