Mercurial > repos > goeckslab > pycaret_predict
comparison base_model_trainer.py @ 9:c6c1f8777aae draft
planemo upload for repository https://github.com/goeckslab/gleam commit 4a11e8a4c4e9daa884bddedfa47090476c517667
| author | goeckslab |
|---|---|
| date | Thu, 31 Jul 2025 15:41:24 +0000 |
| parents | 1aed7d47c5ec |
| children | e2a6fed32d54 |
comparison
equal
deleted
inserted
replaced
| 8:1aed7d47c5ec | 9:c6c1f8777aae |
|---|---|
| 173 self.best_model = self.exp.tune_model(self.best_model) | 173 self.best_model = self.exp.tune_model(self.best_model) |
| 174 self.results = self.exp.pull() | 174 self.results = self.exp.pull() |
| 175 | 175 |
| 176 if self.task_type == "classification": | 176 if self.task_type == "classification": |
| 177 self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True) | 177 self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True) |
| 178 _ = self.exp.predict_model(self.best_model) | 178 |
| 179 prob_thresh = getattr(self, "probability_threshold", None) | |
| 180 if self.task_type == "classification" and prob_thresh is not None: | |
| 181 _ = self.exp.predict_model(self.best_model, probability_threshold=prob_thresh) | |
| 182 else: | |
| 183 _ = self.exp.predict_model(self.best_model) | |
| 184 | |
| 179 self.test_result_df = self.exp.pull() | 185 self.test_result_df = self.exp.pull() |
| 180 if self.task_type == "classification": | 186 if self.task_type == "classification": |
| 181 self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True) | 187 self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True) |
| 182 | 188 |
| 183 def save_model(self): | 189 def save_model(self): |
| 231 best_model_name = str(self.results.iloc[0]["Model"]) | 237 best_model_name = str(self.results.iloc[0]["Model"]) |
| 232 except Exception: | 238 except Exception: |
| 233 best_model_name = type(self.best_model).__name__ | 239 best_model_name = type(self.best_model).__name__ |
| 234 LOG.info(f"Best model determined as: {best_model_name}") | 240 LOG.info(f"Best model determined as: {best_model_name}") |
| 235 | 241 |
| 236 # 2) Compute training sample count | 242 # 2) Compute training sample count |
| 237 try: | 243 try: |
| 238 n_train = self.exp.X_train.shape[0] | 244 n_train = self.exp.X_train.shape[0] |
| 239 except Exception: | 245 except Exception: |
| 240 n_train = getattr(self.exp, "X_train_transformed", pd.DataFrame()).shape[0] | 246 n_train = getattr(self.exp, "X_train_transformed", pd.DataFrame()).shape[0] |
| 241 total_rows = self.data.shape[0] | 247 total_rows = self.data.shape[0] |
| 242 | 248 |
| 243 # 3) Build setup parameters table | 249 # 3) Build setup parameters table |
| 244 all_params = self.setup_params | 250 all_params = self.setup_params.copy() |
| 251 if self.task_type == "classification" and hasattr(self, "probability_threshold"): | |
| 252 all_params["probability_threshold"] = self.probability_threshold | |
| 253 | |
| 245 display_keys = [ | 254 display_keys = [ |
| 246 "Target", | 255 "Target", |
| 247 "Session ID", | 256 "Session ID", |
| 248 "Train Size", | 257 "Train Size", |
| 249 "Normalize", | 258 "Normalize", |
| 253 "Remove Outliers", | 262 "Remove Outliers", |
| 254 "Remove Multicollinearity", | 263 "Remove Multicollinearity", |
| 255 "Polynomial Features", | 264 "Polynomial Features", |
| 256 "Fix Imbalance", | 265 "Fix Imbalance", |
| 257 "Models", | 266 "Models", |
| 267 "Probability Threshold", | |
| 258 ] | 268 ] |
| 259 setup_rows = [] | 269 setup_rows = [] |
| 260 for key in display_keys: | 270 for key in display_keys: |
| 261 pk = key.lower().replace(" ", "_") | 271 pk = key.lower().replace(" ", "_") |
| 262 v = all_params.get(pk) | 272 v = all_params.get(pk) |
| 279 dv = bool(v) | 289 dv = bool(v) |
| 280 elif key == "Cross Validation Folds": | 290 elif key == "Cross Validation Folds": |
| 281 dv = v if v is not None else "None" | 291 dv = v if v is not None else "None" |
| 282 elif key == "Models": | 292 elif key == "Models": |
| 283 dv = ", ".join(map(str, v)) if isinstance(v, (list, tuple)) else "None" | 293 dv = ", ".join(map(str, v)) if isinstance(v, (list, tuple)) else "None" |
| 294 elif key == "Probability Threshold": | |
| 295 dv = v if v is not None else "None" | |
| 284 else: | 296 else: |
| 285 dv = v if v is not None else "None" | 297 dv = v if v is not None else "None" |
| 286 setup_rows.append([key, dv]) | 298 setup_rows.append([key, dv]) |
| 287 if hasattr(self.exp, "_fold_metric"): | 299 if hasattr(self.exp, "_fold_metric"): |
| 288 setup_rows.append(["best_model_metric", self.exp._fold_metric]) | 300 setup_rows.append(["best_model_metric", self.exp._fold_metric]) |
