comparison: pycaret_classification.py @ 12:e674b9e946fb (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
| author | goeckslab |
|---|---|
| date | Mon, 08 Sep 2025 22:39:12 +0000 |
| parents | 1aed7d47c5ec |
| children | |
| 11:4eca9d109de1 | 12:e674b9e946fb |
|---|---|
| 1 import logging | 1 import logging |
| 2 import types | 2 import types |
| 3 from typing import Dict | 3 from typing import Dict |
| 4 | 4 |
| | 5 import numpy as np |
| | 6 import pandas as pd |
| | 7 import plotly.graph_objects as go |
| 5 from base_model_trainer import BaseModelTrainer | 8 from base_model_trainer import BaseModelTrainer |
| 6 from dashboard import generate_classifier_explainer_dashboard | 9 from dashboard import generate_classifier_explainer_dashboard |
| 7 from plotly.graph_objects import Figure | |
| 8 from pycaret.classification import ClassificationExperiment | 10 from pycaret.classification import ClassificationExperiment |
| | 11 from sklearn.metrics import auc, confusion_matrix, precision_recall_curve, roc_curve |
| 9 from utils import predict_proba | 12 from utils import predict_proba |
| 10 | 13 |
| 11 LOG = logging.getLogger(__name__) | 14 LOG = logging.getLogger(__name__) |
| | 15 |
| | 16 |
| | 17 def _apply_report_layout(fig: go.Figure) -> go.Figure: |
| | 18 # Give the left side more space for y-axis title/ticks and let axes auto-reserve room |
| | 19 fig.update_xaxes(automargin=True, title_standoff=12) |
| | 20 fig.update_yaxes(automargin=True, title_standoff=12) |
| | 21 fig.update_layout( |
| | 22 autosize=True, |
| | 23 margin=dict(l=120, r=40, t=60, b=60), # bump 'l' if you still see clipping |
| | 24 ) |
| | 25 return fig |
| 12 | 26 |
| 13 | 27 |
| 14 class ClassificationModelTrainer(BaseModelTrainer): | 28 class ClassificationModelTrainer(BaseModelTrainer): |
| 15 def __init__( | 29 def __init__( |
| 16 self, | 30 self, |
| 48 LOG.warning( | 62 LOG.warning( |
| 49 f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch." | 63 f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch." |
| 50 ) | 64 ) |
| 51 | 65 |
| 52 plots = [ | 66 plots = [ |
| 53 'confusion_matrix', | 67 "auc", |
| 54 'auc', | 68 "threshold", |
| 55 'threshold', | 69 "pr", |
| 56 'pr', | 70 "error", |
| 57 'error', | 71 "class_report", |
| 58 'class_report', | 72 "learning", |
| 59 'learning', | 73 "calibration", |
| 60 'calibration', | 74 "vc", |
| 61 'vc', | 75 "dimension", |
| 62 'dimension', | 76 "manifold", |
| 63 'manifold', | 77 "rfe", |
| 64 'rfe', | 78 "feature", |
| 65 'feature', | 79 "feature_all", |
| 66 'feature_all', | |
| 67 ] | 80 ] |
| 68 for plot_name in plots: | 81 for plot_name in plots: |
| 69 try: | 82 try: |
| 70 if plot_name == "threshold": | 83 if plot_name == "threshold": |
| 71 plot_path = self.exp.plot_model( | 84 plot_path = self.exp.plot_model( |
| 100 def generate_plots_explainer(self): | 113 def generate_plots_explainer(self): |
| 101 from explainerdashboard import ClassifierExplainer | 114 from explainerdashboard import ClassifierExplainer |
| 102 | 115 |
| 103 LOG.info("Generating explainer plots") | 116 LOG.info("Generating explainer plots") |
| 104 | 117 |
| | 118 # Ensure predict_proba is available here too |
| | 119 if not hasattr(self.best_model, "predict_proba"): |
| | 120 self.best_model.predict_proba = types.MethodType( |
| | 121 predict_proba, self.best_model |
| | 122 ) |
| | 123 LOG.warning( |
| | 124 f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch." |
| | 125 ) |
| | 126 |
| 105 X_test = self.exp.X_test_transformed.copy() | 127 X_test = self.exp.X_test_transformed.copy() |
| 106 y_test = self.exp.y_test_transformed | 128 y_test = self.exp.y_test_transformed |
| 107 explainer = ClassifierExplainer(self.best_model, X_test, y_test) | 129 explainer = ClassifierExplainer(self.best_model, X_test, y_test) |
| 108 | 130 |
| 109 # a dict to hold the raw Figure objects or callables | 131 # a dict to hold the raw Figure objects or callables |
| 110 self.explainer_plots: Dict[str, Figure] = {} | 132 self.explainer_plots: Dict[str, go.Figure] = {} |
| 111 | 133 |
| 112 # these go into the Test tab | 134 # --- Threshold-aware overrides for CM / ROC / PR --- |
| | 135 prob_thresh = getattr(self, "probability_threshold", None) |
| | 136 |
| | 137 # Only for binary classification and when threshold is provided |
| | 138 if (prob_thresh is not None) and (not self.exp.is_multiclass): |
| | 139 X = self.exp.X_test_transformed |
| | 140 y = pd.Series(self.exp.y_test_transformed).reset_index(drop=True) |
| | 141 |
| | 142 # Get positive-class scores (robust defaults) |
| | 143 classes = list(getattr(self.best_model, "classes_", [0, 1])) |
| | 144 try: |
| | 145 pos_idx = classes.index(1) if 1 in classes else 1 |
| | 146 except Exception: |
| | 147 pos_idx = 1 |
| | 148 |
| | 149 proba = self.best_model.predict_proba(X) |
| | 150 y_scores = proba[:, pos_idx] |
| | 151 |
| | 152 # Derive label names consistently |
| | 153 pos_label = classes[pos_idx] if len(classes) > pos_idx else 1 |
| | 154 neg_label = classes[1 - pos_idx] if len(classes) > 1 else 0 |
| | 155 |
| | 156 # ---- Confusion Matrix @ threshold ---- |
| | 157 try: |
| | 158 y_pred = np.where(y_scores >= prob_thresh, pos_label, neg_label) |
| | 159 cm = confusion_matrix(y, y_pred, labels=[neg_label, pos_label]) |
| | 160 fig_cm = go.Figure( |
| | 161 data=go.Heatmap( |
| | 162 z=cm, |
| | 163 x=[f"Pred {neg_label}", f"Pred {pos_label}"], |
| | 164 y=[f"True {neg_label}", f"True {pos_label}"], |
| | 165 text=cm, |
| | 166 texttemplate="%{text}", |
| | 167 colorscale="Blues", |
| | 168 showscale=False, |
| | 169 ) |
| | 170 ) |
| | 171 fig_cm.update_layout( |
| | 172 title=f"Confusion Matrix @ threshold={prob_thresh:.2f}", |
| | 173 xaxis_title="Predicted label", |
| | 174 yaxis_title="True label", |
| | 175 ) |
| | 176 _apply_report_layout(fig_cm) |
| | 177 self.explainer_plots["confusion_matrix"] = fig_cm |
| | 178 except Exception as e: |
| | 179 LOG.warning( |
| | 180 f"Threshold-aware confusion matrix failed; falling back: {e}" |
| | 181 ) |
| | 182 |
| | 183 # ---- ROC with threshold marker ---- |
| | 184 try: |
| | 185 fpr, tpr, thr = roc_curve(y, y_scores) |
| | 186 roc_auc = auc(fpr, tpr) |
| | 187 fig_roc = go.Figure() |
| | 188 fig_roc.add_scatter( |
| | 189 x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={roc_auc:.3f})" |
| | 190 ) |
| | 191 if len(thr): |
| | 192 mask = np.isfinite(thr) |
| | 193 if mask.any(): |
| | 194 idx_local = int(np.argmin(np.abs(thr[mask] - prob_thresh))) |
| | 195 idx = np.where(mask)[0][idx_local] |
| | 196 if 0 <= idx < len(fpr): |
| | 197 fig_roc.add_scatter( |
| | 198 x=[fpr[idx]], |
| | 199 y=[tpr[idx]], |
| | 200 mode="markers", |
| | 201 name=f"@ {prob_thresh:.2f}", |
| | 202 marker=dict(size=10), |
| | 203 ) |
| | 204 fig_roc.update_layout( |
| | 205 title=f"ROC Curve (marker at threshold={prob_thresh:.2f})", |
| | 206 xaxis_title="False Positive Rate", |
| | 207 yaxis_title="True Positive Rate", |
| | 208 ) |
| | 209 _apply_report_layout(fig_roc) |
| | 210 self.explainer_plots["roc_auc"] = fig_roc |
| | 211 except Exception as e: |
| | 212 LOG.warning(f"Threshold marker on ROC failed; falling back: {e}") |
| | 213 |
| | 214 # ---- PR with threshold marker ---- |
| | 215 try: |
| | 216 precision, recall, thr_pr = precision_recall_curve(y, y_scores) |
| | 217 pr_auc = auc(recall, precision) |
| | 218 fig_pr = go.Figure() |
| | 219 fig_pr.add_scatter( |
| | 220 x=recall, y=precision, mode="lines", name=f"PR (AUC={pr_auc:.3f})" |
| | 221 ) |
| | 222 if len(thr_pr): |
| | 223 idx_pr = int(np.argmin(np.abs(thr_pr - prob_thresh))) |
| | 224 # note: thr_pr has length = len(precision) - 1 |
| | 225 idx_pr = max(0, min(idx_pr, len(recall) - 1)) |
| | 226 fig_pr.add_scatter( |
| | 227 x=[recall[idx_pr]], |
| | 228 y=[precision[idx_pr]], |
| | 229 mode="markers", |
| | 230 name=f"@ {prob_thresh:.2f}", |
| | 231 marker=dict(size=10), |
| | 232 ) |
| | 233 fig_pr.update_layout( |
| | 234 title=f"Precision–Recall (marker at threshold={prob_thresh:.2f})", |
| | 235 xaxis_title="Recall", |
| | 236 yaxis_title="Precision", |
| | 237 ) |
| | 238 _apply_report_layout(fig_pr) |
| | 239 self.explainer_plots["pr_auc"] = fig_pr |
| | 240 except Exception as e: |
| | 241 LOG.warning(f"Threshold marker on PR failed; falling back: {e}") |
| | 242 |
| | 243 # these go into the Test tab (don't overwrite overrides) |
| 113 for key, fn in [ | 244 for key, fn in [ |
| 114 ("roc_auc", explainer.plot_roc_auc), | 245 ("roc_auc", explainer.plot_roc_auc), |
| 115 ("pr_auc", explainer.plot_pr_auc), | 246 ("pr_auc", explainer.plot_pr_auc), |
| 116 ("lift_curve", explainer.plot_lift_curve), | 247 ("lift_curve", explainer.plot_lift_curve), |
| 117 ("confusion_matrix", explainer.plot_confusion_matrix), | 248 ("confusion_matrix", explainer.plot_confusion_matrix), |
| 118 ("threshold", explainer.plot_precision), # Percentage vs probability | 249 ("threshold", explainer.plot_precision), # percentage vs probability |
| 119 ("cumulative_precision", explainer.plot_cumulative_precision), | 250 ("cumulative_precision", explainer.plot_cumulative_precision), |
| 120 ]: | 251 ]: |
| 121 try: | 252 if key in self.explainer_plots: |
| 122 self.explainer_plots[key] = fn() | 253 continue |
| | 254 try: |
| | 255 fig = fn() |
| | 256 if fig is not None: |
| | 257 self.explainer_plots[key] = fig |
| 123 except Exception as e: | 258 except Exception as e: |
| 124 LOG.error(f"Error generating explainer plot {key}: {e}") | 259 LOG.error(f"Error generating explainer plot {key}: {e}") |
| 125 | 260 |
| 126 # mean SHAP importances | 261 # mean SHAP importances |
| 127 try: | 262 try: |
| 141 valid_feats = [] | 276 valid_feats = [] |
| 142 for feat in self.features_name: | 277 for feat in self.features_name: |
| 143 if feat in explainer.X.columns or feat in explainer.onehot_cols: | 278 if feat in explainer.X.columns or feat in explainer.onehot_cols: |
| 144 valid_feats.append(feat) | 279 valid_feats.append(feat) |
| 145 else: | 280 else: |
| 146 LOG.warning(f"Skipping PDP for feature {feat!r}: not found in explainer data") | 281 LOG.warning( |
| 282 f"Skipping PDP for feature {feat!r}: not found in explainer data" | |
| 283 ) | |
| 147 | 284 |
| 148 for feat in valid_feats: | 285 for feat in valid_feats: |
| 149 # wrap each PDP call to catch any unexpected AssertionErrors | 286 # wrap each PDP call to catch any unexpected AssertionErrors |
| 150 def make_pdp_plotter(f): | 287 def make_pdp_plotter(f): |
| 151 def _plot(): | 288 def _plot(): |
| 155 LOG.warning(f"PDP AssertionError for {f!r}: {ae}") | 292 LOG.warning(f"PDP AssertionError for {f!r}: {ae}") |
| 156 return None | 293 return None |
| 157 except Exception as e: | 294 except Exception as e: |
| 158 LOG.error(f"Unexpected error plotting PDP for {f!r}: {e}") | 295 LOG.error(f"Unexpected error plotting PDP for {f!r}: {e}") |
| 159 return None | 296 return None |
| | 297 |
| 160 return _plot | 298 return _plot |
| 161 | 299 |
| 162 self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat) | 300 self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat) |
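
The threshold-aware overrides introduced in this revision binarize the positive-class scores at a user-supplied probability threshold before tabulating the confusion matrix, rather than relying on the estimator's default 0.5 cut-off, and then render the result as a Plotly heatmap. Below is a minimal standalone sketch of that pattern using synthetic data; the names `y_true`, `y_scores`, and the 0.35 threshold are illustrative placeholders, not taken from the repository.

```python
# Sketch: confusion matrix at a custom probability threshold, rendered with Plotly.
# y_true, y_scores, and THRESHOLD are synthetic/illustrative, not from the repo.
import numpy as np
import plotly.graph_objects as go
from sklearn.metrics import confusion_matrix

THRESHOLD = 0.35  # custom cut-off instead of the default 0.5

# Synthetic binary labels and positive-class probabilities
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=200)
y_scores = np.clip(y_true * 0.6 + rng.normal(0.2, 0.25, size=200), 0, 1)

# Binarize the scores at the chosen threshold, then tabulate the confusion matrix
y_pred = (y_scores >= THRESHOLD).astype(int)
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

# Plot as an annotated heatmap, mirroring the report layout used above
fig = go.Figure(
    go.Heatmap(
        z=cm,
        x=["Pred 0", "Pred 1"],
        y=["True 0", "True 1"],
        text=cm,
        texttemplate="%{text}",
        colorscale="Blues",
        showscale=False,
    )
)
fig.update_layout(
    title=f"Confusion Matrix @ threshold={THRESHOLD:.2f}",
    xaxis_title="Predicted label",
    yaxis_title="True label",
)
fig.write_html("confusion_matrix_at_threshold.html")
```

Because the figure is built directly from `y_scores`, the same scores can feed the ROC and precision-recall markers shown above without re-running prediction.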
