Mercurial > repos > goeckslab > pycaret_predict
comparison base_model_trainer.py @ 16:4fee4504646e draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 2b826699ef9518d4610f5cfb6468ce719ec8039d
| author | goeckslab |
|---|---|
| date | Fri, 28 Nov 2025 22:28:26 +0000 |
| parents | 7d78a6afc958 |
| children | |
Comparison legend: equal, deleted, inserted, replaced
| 15:a2aeeb754d76 | 16:4fee4504646e |
|---|---|
| 44 self.target = None | 44 self.target = None |
| 45 self.best_model = None | 45 self.best_model = None |
| 46 self.results = None | 46 self.results = None |
| 47 self.tuning_results = None | 47 self.tuning_results = None |
| 48 self.features_name = None | 48 self.features_name = None |
| 49 self.plot_feature_names = None | |
| 49 self.plots = {} | 50 self.plots = {} |
| 50 self.explainer_plots = {} | 51 self.explainer_plots = {} |
| 51 self.plots_explainer_html = None | 52 self.plots_explainer_html = None |
| 52 self.trees = [] | 53 self.trees = [] |
| 53 self.user_kwargs = kwargs.copy() | 54 self.user_kwargs = kwargs.copy() |
| 54 for key, value in self.user_kwargs.items(): | 55 for key, value in self.user_kwargs.items(): |
| 55 setattr(self, key, value) | 56 setattr(self, key, value) |
| 57 if not hasattr(self, "plot_feature_limit"): | |
| 58 self.plot_feature_limit = 30 | |
| 59 self._shap_row_cap = None | |
| 60 if getattr(self, "polynomial_features", False): | |
| 61 # Keep feature importance responsive by trimming plots/SHAP rows | |
| 62 try: | |
| 63 limit_val = int(self.plot_feature_limit) | |
| 64 except (TypeError, ValueError): | |
| 65 limit_val = 30 | |
| 66 self.plot_feature_limit = min(limit_val, 15) | |
| 67 self._shap_row_cap = 200 | |
| 68 LOG.info( | |
| 69 "Polynomial features enabled; limiting feature plots to %s and SHAP rows to %s", | |
| 70 self.plot_feature_limit, | |
| 71 self._shap_row_cap, | |
| 72 ) | |
| 73 self.imputed_training_data = None | |
| 74 self._best_model_metric_used = None | |
| 56 self.setup_params = {} | 75 self.setup_params = {} |
| 57 self.test_file = test_file | 76 self.test_file = test_file |
| 58 self.test_data = None | 77 self.test_data = None |
| 59 | 78 |
| 60 if not self.output_dir: | 79 if not self.output_dir: |
| 125 # Update names after possible drop | 144 # Update names after possible drop |
| 126 names = self.data.columns.to_list() | 145 names = self.data.columns.to_list() |
| 127 LOG.info(f"Dataset columns after processing: {names}") | 146 LOG.info(f"Dataset columns after processing: {names}") |
| 128 | 147 |
| 129 self.features_name = [n for n in names if n != self.target] | 148 self.features_name = [n for n in names if n != self.target] |
| 130 | 149 self.plot_feature_names = self._select_plot_features(self.features_name) |
| 131 if getattr(self, "missing_value_strategy", None): | |
| 132 strat = self.missing_value_strategy | |
| 133 if strat == "mean": | |
| 134 self.data = self.data.fillna( | |
| 135 self.data.mean(numeric_only=True) | |
| 136 ) | |
| 137 elif strat == "median": | |
| 138 self.data = self.data.fillna( | |
| 139 self.data.median(numeric_only=True) | |
| 140 ) | |
| 141 elif strat == "drop": | |
| 142 self.data = self.data.dropna() | |
| 143 else: | |
| 144 self.data = self.data.fillna( | |
| 145 self.data.median(numeric_only=True) | |
| 146 ) | |
| 147 | 150 |
| 148 if self.test_file: | 151 if self.test_file: |
| 149 LOG.info(f"Loading test data from {self.test_file}") | 152 LOG.info(f"Loading test data from {self.test_file}") |
| 150 df_test = pd.read_csv( | 153 df_test = pd.read_csv( |
| 151 self.test_file, sep=None, engine="python" | 154 self.test_file, sep=None, engine="python" |
| 152 ) | 155 ) |
| 153 df_test.columns = df_test.columns.str.replace(".", "_") | 156 df_test.columns = df_test.columns.str.replace(".", "_") |
| 154 self.test_data = df_test | 157 self.test_data = df_test |
| 158 | |
| 159 def _select_plot_features(self, all_features): | |
| 160 limit = getattr(self, "plot_feature_limit", 30) | |
| 161 if not isinstance(limit, int) or limit <= 0: | |
| 162 LOG.info( | |
| 163 "Feature plotting limit disabled (plot_feature_limit=%s).", limit | |
| 164 ) | |
| 165 return all_features | |
| 166 if len(all_features) <= limit: | |
| 167 LOG.info( | |
| 168 "Feature plotting limit not needed (%s features <= limit %s).", | |
| 169 len(all_features), | |
| 170 limit, | |
| 171 ) | |
| 172 return all_features | |
| 173 df = self.data[all_features].copy() | |
| 174 numeric_cols = df.select_dtypes(include=["number"]).columns | |
| 175 ranked = [] | |
| 176 if len(numeric_cols) > 0: | |
| 177 variances = ( | |
| 178 df[numeric_cols] | |
| 179 .var() | |
| 180 .fillna(0) | |
| 181 .abs() | |
| 182 .sort_values(ascending=False) | |
| 183 ) | |
| 184 ranked = variances.index.tolist() | |
| 185 selected = [] | |
| 186 for col in ranked: | |
| 187 if len(selected) >= limit: | |
| 188 break | |
| 189 selected.append(col) | |
| 190 if len(selected) < limit: | |
| 191 for col in all_features: | |
| 192 if col in selected: | |
| 193 continue | |
| 194 selected.append(col) | |
| 195 if len(selected) >= limit: | |
| 196 break | |
| 197 LOG.info( | |
| 198 "Limiting feature-level plots to %s of %s available features (limit=%s).", | |
| 199 len(selected), | |
| 200 len(all_features), | |
| 201 limit, | |
| 202 ) | |
| 203 return selected | |
| 155 | 204 |
| 156 def setup_pycaret(self): | 205 def setup_pycaret(self): |
| 157 LOG.info("Initializing PyCaret") | 206 LOG.info("Initializing PyCaret") |
| 158 self.setup_params = { | 207 self.setup_params = { |
| 159 "target": self.target, | 208 "target": self.target, |
| 196 raise ValueError( | 245 raise ValueError( |
| 197 "task_type must be 'classification' or 'regression'" | 246 "task_type must be 'classification' or 'regression'" |
| 198 ) | 247 ) |
| 199 | 248 |
| 200 self.exp.setup(self.data, **self.setup_params) | 249 self.exp.setup(self.data, **self.setup_params) |
| 250 self._capture_imputed_training_data() | |
| 201 self.setup_params.update(self.user_kwargs) | 251 self.setup_params.update(self.user_kwargs) |
| 202 | 252 |
| 203 def _normalize_metric(self, m: str) -> str: | 253 def _capture_imputed_training_data(self): |
| 204 if not m: | 254 """ |
| 205 return "R2" if self.task_type == "regression" else "Accuracy" | 255 Cache the dataset as transformed/imputed by PyCaret so downstream |
| 206 m_low = str(m).strip().lower() | 256 components (e.g., feature importance) can operate on the exact data |
| 207 alias = { | 257 used for training. |
| 208 "auc": "AUC", "roc_auc": "AUC", "roc-auc": "AUC", | 258 """ |
| 209 "accuracy": "Accuracy", | 259 if self.exp is None: |
| 210 "precision": "Precision", | 260 return |
| 211 "recall": "Recall", | 261 try: |
| 212 "f1": "F1", | 262 X_processed = self.exp.get_config("X_transformed").copy() |
| 213 "kappa": "Kappa", | 263 y_processed = self.exp.get_config("y") |
| 214 "logloss": "Log Loss", "log_loss": "Log Loss", | 264 if isinstance(y_processed, pd.Series): |
| 215 "pr_auc": "PR-AUC-Weighted", "prauc": "PR-AUC-Weighted", | 265 y_series = y_processed.reset_index(drop=True) |
| 216 "r2": "R2", | 266 else: |
| 217 "mae": "MAE", | 267 y_series = pd.Series(y_processed) |
| 218 "mse": "MSE", | 268 y_series.name = self.target |
| 219 "rmse": "RMSE", | 269 X_processed = X_processed.reset_index(drop=True) |
| 220 "rmsle": "RMSLE", | 270 self.imputed_training_data = pd.concat( |
| 221 "mape": "MAPE", | 271 [X_processed, y_series], axis=1 |
| 222 } | 272 ) |
| 223 return alias.get(m_low, m) | 273 LOG.info( |
| 274 "Captured imputed training dataset from PyCaret " | |
| 275 "(%s rows, %s features).", | |
| 276 self.imputed_training_data.shape[0], | |
| 277 self.imputed_training_data.shape[1] - 1, | |
| 278 ) | |
| 279 except Exception as exc: | |
| 280 LOG.warning( | |
| 281 "Unable to capture processed training data from PyCaret: %s", | |
| 282 exc, | |
| 283 ) | |
| 284 self.imputed_training_data = None | |
| 224 | 285 |
| 225 def train_model(self): | 286 def train_model(self): |
| 226 LOG.info("Training and selecting the best model") | 287 LOG.info("Training and selecting the best model") |
| 227 if self.task_type == "classification": | 288 if self.task_type == "classification": |
| 228 self.exp.add_metric( | 289 self.exp.add_metric( |
| 243 | 304 |
| 244 # Respect explicit fold count | 305 # Respect explicit fold count |
| 245 if getattr(self, "cross_validation_folds", None) is not None: | 306 if getattr(self, "cross_validation_folds", None) is not None: |
| 246 compare_kwargs["fold"] = self.cross_validation_folds | 307 compare_kwargs["fold"] = self.cross_validation_folds |
| 247 | 308 |
| 248 chosen_metric = self._normalize_metric(getattr(self, "best_model_metric", None)) | 309 best_metric = getattr(self, "best_model_metric", None) |
| 249 if chosen_metric: | 310 if best_metric: |
| 250 compare_kwargs["sort"] = chosen_metric | 311 compare_kwargs["sort"] = best_metric |
| 251 self.chosen_metric_label = chosen_metric | 312 self._best_model_metric_used = best_metric |
| 252 try: | 313 LOG.info(f"Ranking models using metric: {best_metric}") |
| 253 setattr(self.exp, "_fold_metric", chosen_metric) | |
| 254 except Exception as e: | |
| 255 LOG.warning(f"Failed to set '_fold_metric' to '{chosen_metric}': {e}", exc_info=True) | |
| 256 | 314 |
| 257 LOG.info(f"compare_models kwargs: {compare_kwargs}") | 315 LOG.info(f"compare_models kwargs: {compare_kwargs}") |
| 258 self.best_model = self.exp.compare_models(**compare_kwargs) | 316 self.best_model = self.exp.compare_models(**compare_kwargs) |
| 317 if self._best_model_metric_used is None: | |
| 318 self._best_model_metric_used = getattr(self.exp, "_fold_metric", None) | |
| 259 self.results = self.exp.pull() | 319 self.results = self.exp.pull() |
| 260 if getattr(self, "tune_model", False): | 320 if getattr(self, "tune_model", False): |
| 261 LOG.info("Tuning hyperparameters of the best model") | 321 LOG.info("Tuning hyperparameters of the best model") |
| 262 self.best_model = self.exp.tune_model(self.best_model) | 322 self.best_model = self.exp.tune_model(self.best_model) |
| 263 self.tuning_results = self.exp.pull() | 323 self.tuning_results = self.exp.pull() |
| 324 LOG.warning(f"Could not generate {name} plot: {e}") | 384 LOG.warning(f"Could not generate {name} plot: {e}") |
| 325 | 385 |
| 326 def encode_image_to_base64(self, img_path: str) -> str: | 386 def encode_image_to_base64(self, img_path: str) -> str: |
| 327 with open(img_path, "rb") as img_file: | 387 with open(img_path, "rb") as img_file: |
| 328 return base64.b64encode(img_file.read()).decode("utf-8") | 388 return base64.b64encode(img_file.read()).decode("utf-8") |
| 389 | |
| 390 def _resolve_plot_callable(self, key, fig_or_fn, section): | |
| 391 """ | |
| 392 Safely execute stored plot callables so a single failure does not | |
| 393 abort the entire HTML report generation. | |
| 394 """ | |
| 395 if fig_or_fn is None: | |
| 396 return None | |
| 397 try: | |
| 398 return fig_or_fn() if callable(fig_or_fn) else fig_or_fn | |
| 399 except Exception as exc: | |
| 400 extra = "" | |
| 401 if isinstance(exc, ValueError) and "Input contains NaN" in str(exc): | |
| 402 extra = ( | |
| 403 " (model returned NaN probabilities; " | |
| 404 "consider checking data preprocessing)" | |
| 405 ) | |
| 406 LOG.warning( | |
| 407 "Skipping %s plot '%s' due to error: %s%s", | |
| 408 section, | |
| 409 key, | |
| 410 exc, | |
| 411 extra, | |
| 412 ) | |
| 413 return None | |
| 329 | 414 |
| 330 def save_html_report(self): | 415 def save_html_report(self): |
| 331 LOG.info("Saving HTML report") | 416 LOG.info("Saving HTML report") |
| 332 | 417 |
| 333 # 1) Determine best model name | 418 # 1) Determine best model name |
| 399 elif key == "Probability Threshold": | 484 elif key == "Probability Threshold": |
| 400 dv = f"{v:.2f}" if v is not None else "0.5" | 485 dv = f"{v:.2f}" if v is not None else "0.5" |
| 401 else: | 486 else: |
| 402 dv = v if v is not None else "None" | 487 dv = v if v is not None else "None" |
| 403 setup_rows.append([key, dv]) | 488 setup_rows.append([key, dv]) |
| 404 if getattr(self, "chosen_metric_label", None): | 489 metric_label = self._best_model_metric_used or getattr( |
| 405 setup_rows.append(["Best Model Metric", self.chosen_metric_label]) | 490 self.exp, "_fold_metric", None |
| 491 ) | |
| 492 if metric_label: | |
| 493 setup_rows.append(["Best Model Metric", metric_label]) | |
| 406 | 494 |
| 407 df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"]) | 495 df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"]) |
| 408 df_setup.to_csv( | 496 df_setup.to_csv( |
| 409 Path(self.output_dir) / "setup_params.csv", index=False | 497 Path(self.output_dir) / "setup_params.csv", index=False |
| 410 ) | 498 ) |
| 562 test_order = [ | 650 test_order = [ |
| 563 "confusion_matrix", | 651 "confusion_matrix", |
| 564 "roc_auc", | 652 "roc_auc", |
| 565 "pr_auc", | 653 "pr_auc", |
| 566 "lift_curve", | 654 "lift_curve", |
| 567 "threshold", | |
| 568 "cumulative_precision", | 655 "cumulative_precision", |
| 569 ] | 656 ] |
| 570 for key in test_order: | 657 for key in test_order: |
| 571 fig_or_fn = self.explainer_plots.pop(key, None) | 658 fig_or_fn = self.explainer_plots.pop(key, None) |
| 572 if fig_or_fn is not None: | 659 if fig_or_fn is not None: |
| 573 fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn | 660 fig = self._resolve_plot_callable( |
| 661 key, fig_or_fn, section="test/explainer" | |
| 662 ) | |
| 663 if fig is None: | |
| 664 continue | |
| 574 title = plot_title_map.get( | 665 title = plot_title_map.get( |
| 575 key, key.replace("_", " ").title() | 666 key, key.replace("_", " ").title() |
| 576 ) | 667 ) |
| 577 test_html += ( | 668 test_html += ( |
| 578 f"<h2>{title}</h2>" + add_plot_to_html(fig) | 669 f"<h2>{title}</h2>" + add_plot_to_html(fig) |
| 582 for name, path in self.plots.items(): | 673 for name, path in self.plots.items(): |
| 583 # classification: include only the small extras, before | 674 # classification: include only the small extras, before |
| 584 # skipping anything | 675 # skipping anything |
| 585 if self.task_type == "classification" and ( | 676 if self.task_type == "classification" and ( |
| 586 name in { | 677 name in { |
| 587 "threshold", | |
| 588 "pr_auc", | 678 "pr_auc", |
| 589 "class_report", | 679 "class_report", |
| 590 } | 680 } |
| 591 ): | 681 ): |
| 592 title = plot_title_map.get( | 682 title = plot_title_map.get( |
| 628 | 718 |
| 629 # — Feature Importance — | 719 # — Feature Importance — |
| 630 feature_html = header | 720 feature_html = header |
| 631 | 721 |
| 632 # 6a) PyCaret’s default feature importances | 722 # 6a) PyCaret’s default feature importances |
| 633 feature_html += FeatureImportanceAnalyzer( | 723 imputed_data = ( |
| 634 data=self.data, | 724 self.imputed_training_data |
| 725 if self.imputed_training_data is not None | |
| 726 else self.data | |
| 727 ) | |
| 728 fi_analyzer = FeatureImportanceAnalyzer( | |
| 729 data=imputed_data, | |
| 635 target_col=self.target_col, | 730 target_col=self.target_col, |
| 636 task_type=self.task_type, | 731 task_type=self.task_type, |
| 637 output_dir=self.output_dir, | 732 output_dir=self.output_dir, |
| 638 exp=self.exp, | 733 exp=self.exp, |
| 639 best_model=self.best_model, | 734 best_model=self.best_model, |
| 640 ).run() | 735 max_plot_features=self.plot_feature_limit, |
| 736 processed_data=self.imputed_training_data, | |
| 737 max_shap_rows=self._shap_row_cap, | |
| 738 ) | |
| 739 fi_html = fi_analyzer.run() | |
| 740 # Add a small table to show SHAP feature caps near the Best Model header. | |
| 741 cap_rows = [] | |
| 742 if fi_analyzer.shap_total_features is not None: | |
| 743 cap_rows.append( | |
| 744 ("Total transformed features", fi_analyzer.shap_total_features) | |
| 745 ) | |
| 746 if fi_analyzer.shap_used_features is not None: | |
| 747 cap_rows.append( | |
| 748 ("Features used in SHAP", fi_analyzer.shap_used_features) | |
| 749 ) | |
| 750 if cap_rows: | |
| 751 cap_table = ( | |
| 752 "<div class='table-wrapper'>" | |
| 753 "<table class='table sortable'>" | |
| 754 "<thead><tr><th>Feature Importance Scope</th><th>Count</th></tr></thead>" | |
| 755 "<tbody>" | |
| 756 + "".join( | |
| 757 f"<tr><td>{label}</td><td>{value}</td></tr>" | |
| 758 for label, value in cap_rows | |
| 759 ) | |
| 760 + "</tbody></table></div>" | |
| 761 ) | |
| 762 feature_html += cap_table | |
| 763 feature_html += fi_html | |
| 641 | 764 |
| 642 # 6b) Explainer SHAP importances | 765 # 6b) Explainer SHAP importances |
| 643 for key in ["shap_mean", "shap_perm"]: | 766 for key in ["shap_mean", "shap_perm"]: |
| 644 fig_or_fn = self.explainer_plots.pop(key, None) | 767 fig_or_fn = self.explainer_plots.pop(key, None) |
| 645 if fig_or_fn is not None: | 768 if fig_or_fn is not None: |
| 646 fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn | 769 fig = self._resolve_plot_callable( |
| 770 key, fig_or_fn, section="feature importance" | |
| 771 ) | |
| 772 if fig is None: | |
| 773 continue | |
| 647 # give SHAP plots explicit titles | 774 # give SHAP plots explicit titles |
| 648 title = ( | 775 title = ( |
| 649 "Mean Absolute SHAP Value Impact" | 776 "Mean Absolute SHAP Value Impact" |
| 650 if key == "shap_mean" | 777 if key == "shap_mean" |
| 651 else "Permutation Feature Importance" | 778 else "Permutation Feature Importance" |
| 659 pdp_keys = sorted( | 786 pdp_keys = sorted( |
| 660 k for k in self.explainer_plots if k.startswith("pdp__") | 787 k for k in self.explainer_plots if k.startswith("pdp__") |
| 661 ) | 788 ) |
| 662 for k in pdp_keys: | 789 for k in pdp_keys: |
| 663 fig_or_fn = self.explainer_plots[k] | 790 fig_or_fn = self.explainer_plots[k] |
| 664 fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn | 791 fig = self._resolve_plot_callable( |
| 792 k, fig_or_fn, section="pdp" | |
| 793 ) | |
| 794 if fig is None: | |
| 795 continue | |
| 665 # extract feature name | 796 # extract feature name |
| 666 feature = k.split("__", 1)[1] | 797 feature = k.split("__", 1)[1] |
| 667 title = f"Partial Dependence for {feature}" | 798 title = f"Partial Dependence for {feature}" |
| 668 feature_html += ( | 799 feature_html += ( |
| 669 f"<h2>{title}</h2>" + add_plot_to_html(fig) | 800 f"<h2>{title}</h2>" + add_plot_to_html(fig) |
