Mercurial > repos > goeckslab > pycaret_predict
comparison base_model_trainer.py @ 9:c6c1f8777aae draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 4a11e8a4c4e9daa884bddedfa47090476c517667
author | goeckslab |
---|---|
date | Thu, 31 Jul 2025 15:41:24 +0000 |
parents | 1aed7d47c5ec |
children |
Comparison legend:
- equal
- deleted
- inserted
- replaced
8:1aed7d47c5ec | 9:c6c1f8777aae |
---|---|
173 self.best_model = self.exp.tune_model(self.best_model) | 173 self.best_model = self.exp.tune_model(self.best_model) |
174 self.results = self.exp.pull() | 174 self.results = self.exp.pull() |
175 | 175 |
176 if self.task_type == "classification": | 176 if self.task_type == "classification": |
177 self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True) | 177 self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True) |
178 _ = self.exp.predict_model(self.best_model) | 178 |
179 prob_thresh = getattr(self, "probability_threshold", None) | |
180 if self.task_type == "classification" and prob_thresh is not None: | |
181 _ = self.exp.predict_model(self.best_model, probability_threshold=prob_thresh) | |
182 else: | |
183 _ = self.exp.predict_model(self.best_model) | |
184 | |
179 self.test_result_df = self.exp.pull() | 185 self.test_result_df = self.exp.pull() |
180 if self.task_type == "classification": | 186 if self.task_type == "classification": |
181 self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True) | 187 self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True) |
182 | 188 |
183 def save_model(self): | 189 def save_model(self): |
231 best_model_name = str(self.results.iloc[0]["Model"]) | 237 best_model_name = str(self.results.iloc[0]["Model"]) |
232 except Exception: | 238 except Exception: |
233 best_model_name = type(self.best_model).__name__ | 239 best_model_name = type(self.best_model).__name__ |
234 LOG.info(f"Best model determined as: {best_model_name}") | 240 LOG.info(f"Best model determined as: {best_model_name}") |
235 | 241 |
236 # 2) Compute training sample count | 242 # 2) Compute training sample count |
237 try: | 243 try: |
238 n_train = self.exp.X_train.shape[0] | 244 n_train = self.exp.X_train.shape[0] |
239 except Exception: | 245 except Exception: |
240 n_train = getattr(self.exp, "X_train_transformed", pd.DataFrame()).shape[0] | 246 n_train = getattr(self.exp, "X_train_transformed", pd.DataFrame()).shape[0] |
241 total_rows = self.data.shape[0] | 247 total_rows = self.data.shape[0] |
242 | 248 |
243 # 3) Build setup parameters table | 249 # 3) Build setup parameters table |
244 all_params = self.setup_params | 250 all_params = self.setup_params.copy() |
251 if self.task_type == "classification" and hasattr(self, "probability_threshold"): | |
252 all_params["probability_threshold"] = self.probability_threshold | |
253 | |
245 display_keys = [ | 254 display_keys = [ |
246 "Target", | 255 "Target", |
247 "Session ID", | 256 "Session ID", |
248 "Train Size", | 257 "Train Size", |
249 "Normalize", | 258 "Normalize", |
253 "Remove Outliers", | 262 "Remove Outliers", |
254 "Remove Multicollinearity", | 263 "Remove Multicollinearity", |
255 "Polynomial Features", | 264 "Polynomial Features", |
256 "Fix Imbalance", | 265 "Fix Imbalance", |
257 "Models", | 266 "Models", |
267 "Probability Threshold", | |
258 ] | 268 ] |
259 setup_rows = [] | 269 setup_rows = [] |
260 for key in display_keys: | 270 for key in display_keys: |
261 pk = key.lower().replace(" ", "_") | 271 pk = key.lower().replace(" ", "_") |
262 v = all_params.get(pk) | 272 v = all_params.get(pk) |
279 dv = bool(v) | 289 dv = bool(v) |
280 elif key == "Cross Validation Folds": | 290 elif key == "Cross Validation Folds": |
281 dv = v if v is not None else "None" | 291 dv = v if v is not None else "None" |
282 elif key == "Models": | 292 elif key == "Models": |
283 dv = ", ".join(map(str, v)) if isinstance(v, (list, tuple)) else "None" | 293 dv = ", ".join(map(str, v)) if isinstance(v, (list, tuple)) else "None" |
294 elif key == "Probability Threshold": | |
295 dv = v if v is not None else "None" | |
284 else: | 296 else: |
285 dv = v if v is not None else "None" | 297 dv = v if v is not None else "None" |
286 setup_rows.append([key, dv]) | 298 setup_rows.append([key, dv]) |
287 if hasattr(self.exp, "_fold_metric"): | 299 if hasattr(self.exp, "_fold_metric"): |
288 setup_rows.append(["best_model_metric", self.exp._fold_metric]) | 300 setup_rows.append(["best_model_metric", self.exp._fold_metric]) |