comparison: pycaret_classification.py @ 12:e674b9e946fb (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
author | goeckslab |
date | Mon, 08 Sep 2025 22:39:12 +0000 |
parents | 1aed7d47c5ec |
children | |
11:4eca9d109de1 | 12:e674b9e946fb |
---|---|
1 import logging | 1 import logging |
2 import types | 2 import types |
3 from typing import Dict | 3 from typing import Dict |
4 | 4 |
| 5 import numpy as np |
| 6 import pandas as pd |
| 7 import plotly.graph_objects as go |
5 from base_model_trainer import BaseModelTrainer | 8 from base_model_trainer import BaseModelTrainer |
6 from dashboard import generate_classifier_explainer_dashboard | 9 from dashboard import generate_classifier_explainer_dashboard |
7 from plotly.graph_objects import Figure | |
8 from pycaret.classification import ClassificationExperiment | 10 from pycaret.classification import ClassificationExperiment |
| 11 from sklearn.metrics import auc, confusion_matrix, precision_recall_curve, roc_curve |
9 from utils import predict_proba | 12 from utils import predict_proba |
10 | 13 |
11 LOG = logging.getLogger(__name__) | 14 LOG = logging.getLogger(__name__) |
| 15 |
| 16 |
| 17 def _apply_report_layout(fig: go.Figure) -> go.Figure: |
| 18 # Give the left side more space for y-axis title/ticks and let axes auto-reserve room |
| 19 fig.update_xaxes(automargin=True, title_standoff=12) |
| 20 fig.update_yaxes(automargin=True, title_standoff=12) |
| 21 fig.update_layout( |
| 22 autosize=True, |
| 23 margin=dict(l=120, r=40, t=60, b=60), # bump 'l' if you still see clipping |
| 24 ) |
| 25 return fig |
12 | 26 |
13 | 27 |
14 class ClassificationModelTrainer(BaseModelTrainer): | 28 class ClassificationModelTrainer(BaseModelTrainer): |
15 def __init__( | 29 def __init__( |
16 self, | 30 self, |
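The hunk above adds an `_apply_report_layout` helper that widens the left margin and lets Plotly auto-reserve room for axis titles and ticks. A minimal sketch of how such a helper would typically be applied before a figure is embedded in an HTML report; the bar data and output file name below are purely illustrative:

```python
import plotly.graph_objects as go

# Illustrative stand-in for one of the report figures.
fig = go.Figure(data=go.Bar(x=["a", "b", "c"], y=[3, 1, 2]))
fig.update_layout(title="Example figure", yaxis_title="Count")

# Same treatment as _apply_report_layout: auto margins plus a wider left edge.
fig.update_xaxes(automargin=True, title_standoff=12)
fig.update_yaxes(automargin=True, title_standoff=12)
fig.update_layout(autosize=True, margin=dict(l=120, r=40, t=60, b=60))

# Write a standalone HTML snippet (path is hypothetical).
fig.write_html("report_figure.html", include_plotlyjs="cdn")
```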
48 LOG.warning( | 62 LOG.warning( |
49 f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch." | 63 f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch." |
50 ) | 64 ) |
51 | 65 |
52 plots = [ | 66 plots = [ |
53 'confusion_matrix', | 67 "auc", |
54 'auc', | 68 "threshold", |
55 'threshold', | 69 "pr", |
56 'pr', | 70 "error", |
57 'error', | 71 "class_report", |
58 'class_report', | 72 "learning", |
59 'learning', | 73 "calibration", |
60 'calibration', | 74 "vc", |
61 'vc', | 75 "dimension", |
62 'dimension', | 76 "manifold", |
63 'manifold', | 77 "rfe", |
64 'rfe', | 78 "feature", |
65 'feature', | 79 "feature_all", |
66 'feature_all', | |
67 ] | 80 ] |
68 for plot_name in plots: | 81 for plot_name in plots: |
69 try: | 82 try: |
70 if plot_name == "threshold": | 83 if plot_name == "threshold": |
71 plot_path = self.exp.plot_model( | 84 plot_path = self.exp.plot_model( |
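Both `generate_plots` and `generate_plots_explainer` guard against estimators that expose no `predict_proba` by binding a module-level helper onto the fitted model with `types.MethodType`. The helper itself lives in `utils` and is not shown in this diff; the sketch below illustrates only the binding technique, with a hypothetical fallback that squashes `decision_function` scores into probabilities:

```python
import types

import numpy as np


def predict_proba(self, X):
    # Hypothetical fallback: sigmoid of decision_function scores,
    # returned as two-column probabilities for a binary estimator.
    scores = self.decision_function(np.asarray(X, dtype=float))
    pos = 1.0 / (1.0 + np.exp(-scores))
    return np.column_stack([1.0 - pos, pos])


class ScoreOnlyModel:
    """Toy estimator with decision_function but no predict_proba."""

    def decision_function(self, X):
        return np.asarray(X).sum(axis=1)


model = ScoreOnlyModel()
if not hasattr(model, "predict_proba"):
    # Bind the function to this one instance, as the trainer does.
    model.predict_proba = types.MethodType(predict_proba, model)

print(model.predict_proba([[0.5, -1.0], [2.0, 1.0]]))
```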
100 def generate_plots_explainer(self): | 113 def generate_plots_explainer(self): |
101 from explainerdashboard import ClassifierExplainer | 114 from explainerdashboard import ClassifierExplainer |
102 | 115 |
103 LOG.info("Generating explainer plots") | 116 LOG.info("Generating explainer plots") |
104 | 117 |
| 118 # Ensure predict_proba is available here too |
| 119 if not hasattr(self.best_model, "predict_proba"): |
| 120 self.best_model.predict_proba = types.MethodType( |
| 121 predict_proba, self.best_model |
| 122 ) |
| 123 LOG.warning( |
| 124 f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch." |
| 125 ) |
| 126 |
105 X_test = self.exp.X_test_transformed.copy() | 127 X_test = self.exp.X_test_transformed.copy() |
106 y_test = self.exp.y_test_transformed | 128 y_test = self.exp.y_test_transformed |
107 explainer = ClassifierExplainer(self.best_model, X_test, y_test) | 129 explainer = ClassifierExplainer(self.best_model, X_test, y_test) |
108 | 130 |
109 # a dict to hold the raw Figure objects or callables | 131 # a dict to hold the raw Figure objects or callables |
110 self.explainer_plots: Dict[str, Figure] = {} | 132 self.explainer_plots: Dict[str, go.Figure] = {} |
111 | 133 |
112 # these go into the Test tab | 134 # --- Threshold-aware overrides for CM / ROC / PR --- |
| 135 prob_thresh = getattr(self, "probability_threshold", None) |
| 136 |
| 137 # Only for binary classification and when threshold is provided |
| 138 if (prob_thresh is not None) and (not self.exp.is_multiclass): |
| 139 X = self.exp.X_test_transformed |
| 140 y = pd.Series(self.exp.y_test_transformed).reset_index(drop=True) |
| 141 |
| 142 # Get positive-class scores (robust defaults) |
| 143 classes = list(getattr(self.best_model, "classes_", [0, 1])) |
| 144 try: |
| 145 pos_idx = classes.index(1) if 1 in classes else 1 |
| 146 except Exception: |
| 147 pos_idx = 1 |
| 148 |
| 149 proba = self.best_model.predict_proba(X) |
| 150 y_scores = proba[:, pos_idx] |
| 151 |
| 152 # Derive label names consistently |
| 153 pos_label = classes[pos_idx] if len(classes) > pos_idx else 1 |
| 154 neg_label = classes[1 - pos_idx] if len(classes) > 1 else 0 |
| 155 |
| 156 # ---- Confusion Matrix @ threshold ---- |
| 157 try: |
| 158 y_pred = np.where(y_scores >= prob_thresh, pos_label, neg_label) |
| 159 cm = confusion_matrix(y, y_pred, labels=[neg_label, pos_label]) |
| 160 fig_cm = go.Figure( |
| 161 data=go.Heatmap( |
| 162 z=cm, |
| 163 x=[f"Pred {neg_label}", f"Pred {pos_label}"], |
| 164 y=[f"True {neg_label}", f"True {pos_label}"], |
| 165 text=cm, |
| 166 texttemplate="%{text}", |
| 167 colorscale="Blues", |
| 168 showscale=False, |
| 169 ) |
| 170 ) |
| 171 fig_cm.update_layout( |
| 172 title=f"Confusion Matrix @ threshold={prob_thresh:.2f}", |
| 173 xaxis_title="Predicted label", |
| 174 yaxis_title="True label", |
| 175 ) |
| 176 _apply_report_layout(fig_cm) |
| 177 self.explainer_plots["confusion_matrix"] = fig_cm |
| 178 except Exception as e: |
| 179 LOG.warning( |
| 180 f"Threshold-aware confusion matrix failed; falling back: {e}" |
| 181 ) |
| 182 |
| 183 # ---- ROC with threshold marker ---- |
| 184 try: |
| 185 fpr, tpr, thr = roc_curve(y, y_scores) |
| 186 roc_auc = auc(fpr, tpr) |
| 187 fig_roc = go.Figure() |
| 188 fig_roc.add_scatter( |
| 189 x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={roc_auc:.3f})" |
| 190 ) |
| 191 if len(thr): |
| 192 mask = np.isfinite(thr) |
| 193 if mask.any(): |
| 194 idx_local = int(np.argmin(np.abs(thr[mask] - prob_thresh))) |
| 195 idx = np.where(mask)[0][idx_local] |
| 196 if 0 <= idx < len(fpr): |
| 197 fig_roc.add_scatter( |
| 198 x=[fpr[idx]], |
| 199 y=[tpr[idx]], |
| 200 mode="markers", |
| 201 name=f"@ {prob_thresh:.2f}", |
| 202 marker=dict(size=10), |
| 203 ) |
| 204 fig_roc.update_layout( |
| 205 title=f"ROC Curve (marker at threshold={prob_thresh:.2f})", |
| 206 xaxis_title="False Positive Rate", |
| 207 yaxis_title="True Positive Rate", |
| 208 ) |
| 209 _apply_report_layout(fig_roc) |
| 210 self.explainer_plots["roc_auc"] = fig_roc |
| 211 except Exception as e: |
| 212 LOG.warning(f"Threshold marker on ROC failed; falling back: {e}") |
| 213 |
| 214 # ---- PR with threshold marker ---- |
| 215 try: |
| 216 precision, recall, thr_pr = precision_recall_curve(y, y_scores) |
| 217 pr_auc = auc(recall, precision) |
| 218 fig_pr = go.Figure() |
| 219 fig_pr.add_scatter( |
| 220 x=recall, y=precision, mode="lines", name=f"PR (AUC={pr_auc:.3f})" |
| 221 ) |
| 222 if len(thr_pr): |
| 223 idx_pr = int(np.argmin(np.abs(thr_pr - prob_thresh))) |
| 224 # note: thr_pr has length = len(precision) - 1 |
| 225 idx_pr = max(0, min(idx_pr, len(recall) - 1)) |
| 226 fig_pr.add_scatter( |
| 227 x=[recall[idx_pr]], |
| 228 y=[precision[idx_pr]], |
| 229 mode="markers", |
| 230 name=f"@ {prob_thresh:.2f}", |
| 231 marker=dict(size=10), |
| 232 ) |
| 233 fig_pr.update_layout( |
| 234 title=f"Precision–Recall (marker at threshold={prob_thresh:.2f})", |
| 235 xaxis_title="Recall", |
| 236 yaxis_title="Precision", |
| 237 ) |
| 238 _apply_report_layout(fig_pr) |
| 239 self.explainer_plots["pr_auc"] = fig_pr |
| 240 except Exception as e: |
| 241 LOG.warning(f"Threshold marker on PR failed; falling back: {e}") |
| 242 |
| 243 # these go into the Test tab (don't overwrite overrides) |
113 for key, fn in [ | 244 for key, fn in [ |
114 ("roc_auc", explainer.plot_roc_auc), | 245 ("roc_auc", explainer.plot_roc_auc), |
115 ("pr_auc", explainer.plot_pr_auc), | 246 ("pr_auc", explainer.plot_pr_auc), |
116 ("lift_curve", explainer.plot_lift_curve), | 247 ("lift_curve", explainer.plot_lift_curve), |
117 ("confusion_matrix", explainer.plot_confusion_matrix), | 248 ("confusion_matrix", explainer.plot_confusion_matrix), |
118 ("threshold", explainer.plot_precision), # Percentage vs probability | 249 ("threshold", explainer.plot_precision), # percentage vs probability |
119 ("cumulative_precision", explainer.plot_cumulative_precision), | 250 ("cumulative_precision", explainer.plot_cumulative_precision), |
120 ]: | 251 ]: |
121 try: | 252 if key in self.explainer_plots: |
122 self.explainer_plots[key] = fn() | 253 continue |
| 254 try: |
| 255 fig = fn() |
| 256 if fig is not None: |
| 257 self.explainer_plots[key] = fig |
123 except Exception as e: | 258 except Exception as e: |
124 LOG.error(f"Error generating explainer plot {key}: {e}") | 259 LOG.error(f"Error generating explainer plot {key}: {e}") |
125 | 260 |
126 # mean SHAP importances | 261 # mean SHAP importances |
127 try: | 262 try: |
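The PR marker in the hunk above clamps its index because `precision_recall_curve` returns one fewer threshold than it returns curve points, while `roc_curve` can prepend an infinite threshold that has to be masked out. A short standalone check of those shapes and of the nearest-threshold lookup; the labels and scores below are toy values:

```python
import numpy as np
from sklearn.metrics import precision_recall_curve, roc_curve

y_true = np.array([0, 0, 1, 1, 0, 1])
y_scores = np.array([0.10, 0.40, 0.35, 0.80, 0.60, 0.70])

precision, recall, thr_pr = precision_recall_curve(y_true, y_scores)
# thr_pr is one element shorter than precision/recall.
assert len(precision) == len(recall) == len(thr_pr) + 1

fpr, tpr, thr_roc = roc_curve(y_true, y_scores)
# The first ROC threshold may be infinite; keep only finite ones.
finite = np.isfinite(thr_roc)

# Curve point nearest a chosen probability threshold, clamped as in the diff.
prob_thresh = 0.5
idx = int(np.argmin(np.abs(thr_pr - prob_thresh)))
idx = max(0, min(idx, len(recall) - 1))
print(recall[idx], precision[idx], fpr[finite].shape)
```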
141 valid_feats = [] | 276 valid_feats = [] |
142 for feat in self.features_name: | 277 for feat in self.features_name: |
143 if feat in explainer.X.columns or feat in explainer.onehot_cols: | 278 if feat in explainer.X.columns or feat in explainer.onehot_cols: |
144 valid_feats.append(feat) | 279 valid_feats.append(feat) |
145 else: | 280 else: |
146 LOG.warning(f"Skipping PDP for feature {feat!r}: not found in explainer data") | 281 LOG.warning( |
282 f"Skipping PDP for feature {feat!r}: not found in explainer data" | |
283 ) | |
147 | 284 |
148 for feat in valid_feats: | 285 for feat in valid_feats: |
149 # wrap each PDP call to catch any unexpected AssertionErrors | 286 # wrap each PDP call to catch any unexpected AssertionErrors |
150 def make_pdp_plotter(f): | 287 def make_pdp_plotter(f): |
151 def _plot(): | 288 def _plot(): |
155 LOG.warning(f"PDP AssertionError for {f!r}: {ae}") | 292 LOG.warning(f"PDP AssertionError for {f!r}: {ae}") |
156 return None | 293 return None |
157 except Exception as e: | 294 except Exception as e: |
158 LOG.error(f"Unexpected error plotting PDP for {f!r}: {e}") | 295 LOG.error(f"Unexpected error plotting PDP for {f!r}: {e}") |
159 return None | 296 return None |
| 297 |
160 return _plot | 298 return _plot |
161 | 299 |
162 self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat) | 300 self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat) |
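The `make_pdp_plotter` factory at the end of the diff exists so each deferred PDP callable captures its own feature name instead of the shared loop variable. A minimal sketch of that late-binding pitfall and the factory fix, using toy callables with no explainerdashboard dependency:

```python
# Late binding: every lambda ends up seeing the final value of `feat`.
plots_wrong = {}
for feat in ["age", "fare", "pclass"]:
    plots_wrong[f"pdp__{feat}"] = lambda: f"plot for {feat}"
print(plots_wrong["pdp__age"]())  # -> "plot for pclass"


# Factory fix, mirroring make_pdp_plotter: bind the value via a parameter.
def make_plotter(f):
    def _plot():
        return f"plot for {f}"
    return _plot


plots_right = {}
for feat in ["age", "fare", "pclass"]:
    plots_right[f"pdp__{feat}"] = make_plotter(feat)
print(plots_right["pdp__age"]())  # -> "plot for age"
```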