comparison pycaret_classification.py @ 12:e674b9e946fb (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
| author | goeckslab | 
|---|---|
| date | Mon, 08 Sep 2025 22:39:12 +0000 | 
| parents | 1aed7d47c5ec | 
| children | |
| 11:4eca9d109de1 (old) | 12:e674b9e946fb (new) | 
|---|---|
| 1 import logging | 1 import logging | 
| 2 import types | 2 import types | 
| 3 from typing import Dict | 3 from typing import Dict | 
| 4 | 4 | 
| 5 import numpy as np | |
| 6 import pandas as pd | |
| 7 import plotly.graph_objects as go | |
| 5 from base_model_trainer import BaseModelTrainer | 8 from base_model_trainer import BaseModelTrainer | 
| 6 from dashboard import generate_classifier_explainer_dashboard | 9 from dashboard import generate_classifier_explainer_dashboard | 
| 7 from plotly.graph_objects import Figure | |
| 8 from pycaret.classification import ClassificationExperiment | 10 from pycaret.classification import ClassificationExperiment | 
| 11 from sklearn.metrics import auc, confusion_matrix, precision_recall_curve, roc_curve | |
| 9 from utils import predict_proba | 12 from utils import predict_proba | 
| 10 | 13 | 
| 11 LOG = logging.getLogger(__name__) | 14 LOG = logging.getLogger(__name__) | 
| 15 | |
| 16 | |
| 17 def _apply_report_layout(fig: go.Figure) -> go.Figure: | |
| 18 # Give the left side more space for y-axis title/ticks and let axes auto-reserve room | |
| 19 fig.update_xaxes(automargin=True, title_standoff=12) | |
| 20 fig.update_yaxes(automargin=True, title_standoff=12) | |
| 21 fig.update_layout( | |
| 22 autosize=True, | |
| 23 margin=dict(l=120, r=40, t=60, b=60), # bump 'l' if you still see clipping | |
| 24 ) | |
| 25 return fig | |
| 12 | 26 | 
| 13 | 27 | 
| 14 class ClassificationModelTrainer(BaseModelTrainer): | 28 class ClassificationModelTrainer(BaseModelTrainer): | 
| 15 def __init__( | 29 def __init__( | 
| 16 self, | 30 self, | 
| 48 LOG.warning( | 62 LOG.warning( | 
| 49 f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch." | 63 f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch." | 
| 50 ) | 64 ) | 
| 51 | 65 | 
| 52 plots = [ | 66 plots = [ | 
| 53 'confusion_matrix', | 67 "auc", | 
| 54 'auc', | 68 "threshold", | 
| 55 'threshold', | 69 "pr", | 
| 56 'pr', | 70 "error", | 
| 57 'error', | 71 "class_report", | 
| 58 'class_report', | 72 "learning", | 
| 59 'learning', | 73 "calibration", | 
| 60 'calibration', | 74 "vc", | 
| 61 'vc', | 75 "dimension", | 
| 62 'dimension', | 76 "manifold", | 
| 63 'manifold', | 77 "rfe", | 
| 64 'rfe', | 78 "feature", | 
| 65 'feature', | 79 "feature_all", | 
| 66 'feature_all', | |
| 67 ] | 80 ] | 
| 68 for plot_name in plots: | 81 for plot_name in plots: | 
| 69 try: | 82 try: | 
| 70 if plot_name == "threshold": | 83 if plot_name == "threshold": | 
| 71 plot_path = self.exp.plot_model( | 84 plot_path = self.exp.plot_model( | 
| 100 def generate_plots_explainer(self): | 113 def generate_plots_explainer(self): | 
| 101 from explainerdashboard import ClassifierExplainer | 114 from explainerdashboard import ClassifierExplainer | 
| 102 | 115 | 
| 103 LOG.info("Generating explainer plots") | 116 LOG.info("Generating explainer plots") | 
| 104 | 117 | 
| 118 # Ensure predict_proba is available here too | |
| 119 if not hasattr(self.best_model, "predict_proba"): | |
| 120 self.best_model.predict_proba = types.MethodType( | |
| 121 predict_proba, self.best_model | |
| 122 ) | |
| 123 LOG.warning( | |
| 124 f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch." | |
| 125 ) | |
| 126 | |
| 105 X_test = self.exp.X_test_transformed.copy() | 127 X_test = self.exp.X_test_transformed.copy() | 
| 106 y_test = self.exp.y_test_transformed | 128 y_test = self.exp.y_test_transformed | 
| 107 explainer = ClassifierExplainer(self.best_model, X_test, y_test) | 129 explainer = ClassifierExplainer(self.best_model, X_test, y_test) | 
| 108 | 130 | 
| 109 # a dict to hold the raw Figure objects or callables | 131 # a dict to hold the raw Figure objects or callables | 
| 110 self.explainer_plots: Dict[str, Figure] = {} | 132 self.explainer_plots: Dict[str, go.Figure] = {} | 
| 111 | 133 | 
| 112 # these go into the Test tab | 134 # --- Threshold-aware overrides for CM / ROC / PR --- | 
| 135 prob_thresh = getattr(self, "probability_threshold", None) | |
| 136 | |
| 137 # Only for binary classification and when threshold is provided | |
| 138 if (prob_thresh is not None) and (not self.exp.is_multiclass): | |
| 139 X = self.exp.X_test_transformed | |
| 140 y = pd.Series(self.exp.y_test_transformed).reset_index(drop=True) | |
| 141 | |
| 142 # Get positive-class scores (robust defaults) | |
| 143 classes = list(getattr(self.best_model, "classes_", [0, 1])) | |
| 144 try: | |
| 145 pos_idx = classes.index(1) if 1 in classes else 1 | |
| 146 except Exception: | |
| 147 pos_idx = 1 | |
| 148 | |
| 149 proba = self.best_model.predict_proba(X) | |
| 150 y_scores = proba[:, pos_idx] | |
| 151 | |
| 152 # Derive label names consistently | |
| 153 pos_label = classes[pos_idx] if len(classes) > pos_idx else 1 | |
| 154 neg_label = classes[1 - pos_idx] if len(classes) > 1 else 0 | |
| 155 | |
| 156 # ---- Confusion Matrix @ threshold ---- | |
| 157 try: | |
| 158 y_pred = np.where(y_scores >= prob_thresh, pos_label, neg_label) | |
| 159 cm = confusion_matrix(y, y_pred, labels=[neg_label, pos_label]) | |
| 160 fig_cm = go.Figure( | |
| 161 data=go.Heatmap( | |
| 162 z=cm, | |
| 163 x=[f"Pred {neg_label}", f"Pred {pos_label}"], | |
| 164 y=[f"True {neg_label}", f"True {pos_label}"], | |
| 165 text=cm, | |
| 166 texttemplate="%{text}", | |
| 167 colorscale="Blues", | |
| 168 showscale=False, | |
| 169 ) | |
| 170 ) | |
| 171 fig_cm.update_layout( | |
| 172 title=f"Confusion Matrix @ threshold={prob_thresh:.2f}", | |
| 173 xaxis_title="Predicted label", | |
| 174 yaxis_title="True label", | |
| 175 ) | |
| 176 _apply_report_layout(fig_cm) | |
| 177 self.explainer_plots["confusion_matrix"] = fig_cm | |
| 178 except Exception as e: | |
| 179 LOG.warning( | |
| 180 f"Threshold-aware confusion matrix failed; falling back: {e}" | |
| 181 ) | |
| 182 | |
| 183 # ---- ROC with threshold marker ---- | |
| 184 try: | |
| 185 fpr, tpr, thr = roc_curve(y, y_scores) | |
| 186 roc_auc = auc(fpr, tpr) | |
| 187 fig_roc = go.Figure() | |
| 188 fig_roc.add_scatter( | |
| 189 x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={roc_auc:.3f})" | |
| 190 ) | |
| 191 if len(thr): | |
| 192 mask = np.isfinite(thr) | |
| 193 if mask.any(): | |
| 194 idx_local = int(np.argmin(np.abs(thr[mask] - prob_thresh))) | |
| 195 idx = np.where(mask)[0][idx_local] | |
| 196 if 0 <= idx < len(fpr): | |
| 197 fig_roc.add_scatter( | |
| 198 x=[fpr[idx]], | |
| 199 y=[tpr[idx]], | |
| 200 mode="markers", | |
| 201 name=f"@ {prob_thresh:.2f}", | |
| 202 marker=dict(size=10), | |
| 203 ) | |
| 204 fig_roc.update_layout( | |
| 205 title=f"ROC Curve (marker at threshold={prob_thresh:.2f})", | |
| 206 xaxis_title="False Positive Rate", | |
| 207 yaxis_title="True Positive Rate", | |
| 208 ) | |
| 209 _apply_report_layout(fig_roc) | |
| 210 self.explainer_plots["roc_auc"] = fig_roc | |
| 211 except Exception as e: | |
| 212 LOG.warning(f"Threshold marker on ROC failed; falling back: {e}") | |
| 213 | |
| 214 # ---- PR with threshold marker ---- | |
| 215 try: | |
| 216 precision, recall, thr_pr = precision_recall_curve(y, y_scores) | |
| 217 pr_auc = auc(recall, precision) | |
| 218 fig_pr = go.Figure() | |
| 219 fig_pr.add_scatter( | |
| 220 x=recall, y=precision, mode="lines", name=f"PR (AUC={pr_auc:.3f})" | |
| 221 ) | |
| 222 if len(thr_pr): | |
| 223 idx_pr = int(np.argmin(np.abs(thr_pr - prob_thresh))) | |
| 224 # note: thr_pr has length = len(precision) - 1 | |
| 225 idx_pr = max(0, min(idx_pr, len(recall) - 1)) | |
| 226 fig_pr.add_scatter( | |
| 227 x=[recall[idx_pr]], | |
| 228 y=[precision[idx_pr]], | |
| 229 mode="markers", | |
| 230 name=f"@ {prob_thresh:.2f}", | |
| 231 marker=dict(size=10), | |
| 232 ) | |
| 233 fig_pr.update_layout( | |
| 234 title=f"Precision–Recall (marker at threshold={prob_thresh:.2f})", | |
| 235 xaxis_title="Recall", | |
| 236 yaxis_title="Precision", | |
| 237 ) | |
| 238 _apply_report_layout(fig_pr) | |
| 239 self.explainer_plots["pr_auc"] = fig_pr | |
| 240 except Exception as e: | |
| 241 LOG.warning(f"Threshold marker on PR failed; falling back: {e}") | |
| 242 | |
| 243 # these go into the Test tab (don't overwrite overrides) | |
| 113 for key, fn in [ | 244 for key, fn in [ | 
| 114 ("roc_auc", explainer.plot_roc_auc), | 245 ("roc_auc", explainer.plot_roc_auc), | 
| 115 ("pr_auc", explainer.plot_pr_auc), | 246 ("pr_auc", explainer.plot_pr_auc), | 
| 116 ("lift_curve", explainer.plot_lift_curve), | 247 ("lift_curve", explainer.plot_lift_curve), | 
| 117 ("confusion_matrix", explainer.plot_confusion_matrix), | 248 ("confusion_matrix", explainer.plot_confusion_matrix), | 
| 118 ("threshold", explainer.plot_precision), # Percentage vs probability | 249 ("threshold", explainer.plot_precision), # percentage vs probability | 
| 119 ("cumulative_precision", explainer.plot_cumulative_precision), | 250 ("cumulative_precision", explainer.plot_cumulative_precision), | 
| 120 ]: | 251 ]: | 
| 121 try: | 252 if key in self.explainer_plots: | 
| 122 self.explainer_plots[key] = fn() | 253 continue | 
| 254 try: | |
| 255 fig = fn() | |
| 256 if fig is not None: | |
| 257 self.explainer_plots[key] = fig | |
| 123 except Exception as e: | 258 except Exception as e: | 
| 124 LOG.error(f"Error generating explainer plot {key}: {e}") | 259 LOG.error(f"Error generating explainer plot {key}: {e}") | 
| 125 | 260 | 
| 126 # mean SHAP importances | 261 # mean SHAP importances | 
| 127 try: | 262 try: | 
| 141 valid_feats = [] | 276 valid_feats = [] | 
| 142 for feat in self.features_name: | 277 for feat in self.features_name: | 
| 143 if feat in explainer.X.columns or feat in explainer.onehot_cols: | 278 if feat in explainer.X.columns or feat in explainer.onehot_cols: | 
| 144 valid_feats.append(feat) | 279 valid_feats.append(feat) | 
| 145 else: | 280 else: | 
| 146 LOG.warning(f"Skipping PDP for feature {feat!r}: not found in explainer data") | 281 LOG.warning( | 
| 282 f"Skipping PDP for feature {feat!r}: not found in explainer data" | |
| 283 ) | |
| 147 | 284 | 
| 148 for feat in valid_feats: | 285 for feat in valid_feats: | 
| 149 # wrap each PDP call to catch any unexpected AssertionErrors | 286 # wrap each PDP call to catch any unexpected AssertionErrors | 
| 150 def make_pdp_plotter(f): | 287 def make_pdp_plotter(f): | 
| 151 def _plot(): | 288 def _plot(): | 
| 155 LOG.warning(f"PDP AssertionError for {f!r}: {ae}") | 292 LOG.warning(f"PDP AssertionError for {f!r}: {ae}") | 
| 156 return None | 293 return None | 
| 157 except Exception as e: | 294 except Exception as e: | 
| 158 LOG.error(f"Unexpected error plotting PDP for {f!r}: {e}") | 295 LOG.error(f"Unexpected error plotting PDP for {f!r}: {e}") | 
| 159 return None | 296 return None | 
| 297 | |
| 160 return _plot | 298 return _plot | 
| 161 | 299 | 
| 162 self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat) | 300 self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat) | 
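The `types.MethodType` patch applied above binds the repo's `utils.predict_proba` helper onto models that lack the method. That helper is not part of this diff, so the snippet below is only a minimal sketch of the pattern: `LinearSVC` stands in as a classifier without `predict_proba`, and the sigmoid-over-`decision_function` fallback is a hypothetical placeholder, not the repository's implementation.

```python
import types

import numpy as np
from sklearn.datasets import make_classification
from sklearn.svm import LinearSVC


def predict_proba(self, X):
    # Hypothetical fallback (not utils.predict_proba from this repo):
    # squash decision_function scores through a sigmoid and return
    # two-column pseudo-probabilities [P(neg), P(pos)].
    scores = self.decision_function(X)
    pos = 1.0 / (1.0 + np.exp(-scores))
    return np.column_stack([1.0 - pos, pos])


X, y = make_classification(n_samples=200, n_features=10, random_state=0)
model = LinearSVC().fit(X, y)  # LinearSVC has no predict_proba of its own

if not hasattr(model, "predict_proba"):
    # Same pattern as the trainer: bind the helper as an instance method.
    model.predict_proba = types.MethodType(predict_proba, model)

print(model.predict_proba(X[:3]).shape)  # -> (3, 2)
```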

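The threshold-aware override introduced in this revision boils down to scoring the positive class, applying the custom cutoff, and rendering the resulting confusion matrix as a Plotly heatmap. A condensed sketch of that logic, assuming an already fitted binary `model` with `predict_proba` and held-out `X_test`/`y_test` with 0/1 labels (the 0.35 cutoff is purely illustrative; the trainer reads it from `self.probability_threshold`):

```python
import numpy as np
import plotly.graph_objects as go
from sklearn.metrics import confusion_matrix

prob_thresh = 0.35  # illustrative threshold
y_scores = model.predict_proba(X_test)[:, 1]      # positive-class scores
y_pred = np.where(y_scores >= prob_thresh, 1, 0)  # apply the custom cutoff
cm = confusion_matrix(y_test, y_pred, labels=[0, 1])

fig = go.Figure(
    go.Heatmap(
        z=cm,
        x=["Pred 0", "Pred 1"],
        y=["True 0", "True 1"],
        text=cm,
        texttemplate="%{text}",
        colorscale="Blues",
        showscale=False,
    )
)
fig.update_layout(
    title=f"Confusion Matrix @ threshold={prob_thresh:.2f}",
    xaxis_title="Predicted label",
    yaxis_title="True label",
)
fig.show()
```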