comparison: pycaret_classification.py @ 17:c5c324ac29fc (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
| author | goeckslab |
|---|---|
| date | Sat, 06 Dec 2025 14:20:36 +0000 |
| parents | a2aeeb754d76 |
| children | |
| 16:4fee4504646e | 17:c5c324ac29fc |
|---|---|
| 6 import pandas as pd | 6 import pandas as pd |
| 7 import plotly.graph_objects as go | 7 import plotly.graph_objects as go |
| 8 from base_model_trainer import BaseModelTrainer | 8 from base_model_trainer import BaseModelTrainer |
| 9 from dashboard import generate_classifier_explainer_dashboard | 9 from dashboard import generate_classifier_explainer_dashboard |
| 10 from pycaret.classification import ClassificationExperiment | 10 from pycaret.classification import ClassificationExperiment |
| 11 from sklearn.metrics import auc, confusion_matrix, precision_recall_curve, roc_curve | 11 from sklearn.metrics import ( |
| | 12 auc, |
| | 13 confusion_matrix, |
| | 14 matthews_corrcoef, |
| | 15 precision_recall_curve, |
| | 16 precision_recall_fscore_support, |
| | 17 roc_curve, |
| | 18 ) |
| 12 from utils import predict_proba | 19 from utils import predict_proba |
| 13 | 20 |
| 14 LOG = logging.getLogger(__name__) | 21 LOG = logging.getLogger(__name__) |
| 15 | 22 |
| 16 | 23 |
| 135 explainer = ClassifierExplainer(self.best_model, X_test, y_test) | 142 explainer = ClassifierExplainer(self.best_model, X_test, y_test) |
| 136 | 143 |
| 137 # a dict to hold the raw Figure objects or callables | 144 # a dict to hold the raw Figure objects or callables |
| 138 self.explainer_plots: Dict[str, go.Figure] = {} | 145 self.explainer_plots: Dict[str, go.Figure] = {} |
| 139 | 146 |
| | 147 y_true, y_pred, label_values, y_scores = self._get_test_predictions() |
| | 148 |
| | 149 # — Classification report (Plotly table) — |
| | 150 try: |
| | 151 fig_report = self._build_classification_report_fig( |
| | 152 y_true, y_pred, label_values |
| | 153 ) |
| | 154 if fig_report is not None: |
| | 155 self.explainer_plots["class_report"] = fig_report |
| | 156 except Exception as e: |
| | 157 LOG.warning(f"Could not generate Plotly classification report: {e}") |
| | 158 |
| | 159 # — Confusion matrix with actual labels — |
| | 160 try: |
| | 161 fig_cm = self._build_confusion_matrix_fig(y_true, y_pred, label_values) |
| | 162 if fig_cm is not None: |
| | 163 self.explainer_plots["confusion_matrix"] = fig_cm |
| | 164 except Exception as e: |
| | 165 LOG.warning(f"Could not generate Plotly confusion matrix: {e}") |
| | 166 |
| 140 # --- Threshold-aware overrides for CM / ROC / PR --- | 167 # --- Threshold-aware overrides for CM / ROC / PR --- |
| 141 prob_thresh = getattr(self, "probability_threshold", None) | 168 prob_thresh = getattr(self, "probability_threshold", None) |
| 142 | 169 |
| 143 # Only for binary classification and when threshold is provided | 170 # Only for binary classification and when threshold is provided |
| 144 if (prob_thresh is not None) and (not self.exp.is_multiclass): | 171 if (prob_thresh is not None) and (not self.exp.is_multiclass): |
| 145 X = self.exp.X_test_transformed | |
| 146 y = pd.Series(self.exp.y_test_transformed).reset_index(drop=True) | |
| 147 | |
| 148 # Get positive-class scores (robust defaults) | |
| 149 classes = list(getattr(self.best_model, "classes_", [0, 1])) | |
| 150 try: | |
| 151 pos_idx = classes.index(1) if 1 in classes else 1 | |
| 152 except Exception: | |
| 153 pos_idx = 1 | |
| 154 | |
| 155 proba = self.best_model.predict_proba(X) | |
| 156 y_scores = proba[:, pos_idx] | |
| 157 | |
| 158 # Derive label names consistently | |
| 159 pos_label = classes[pos_idx] if len(classes) > pos_idx else 1 | |
| 160 neg_label = classes[1 - pos_idx] if len(classes) > 1 else 0 | |
| 161 | |
| 162 # ---- Confusion Matrix @ threshold ---- | |
| 163 try: | |
| 164 y_pred = np.where(y_scores >= prob_thresh, pos_label, neg_label) | |
| 165 cm = confusion_matrix(y, y_pred, labels=[neg_label, pos_label]) | |
| 166 fig_cm = go.Figure( | |
| 167 data=go.Heatmap( | |
| 168 z=cm, | |
| 169 x=[f"Pred {neg_label}", f"Pred {pos_label}"], | |
| 170 y=[f"True {neg_label}", f"True {pos_label}"], | |
| 171 text=cm, | |
| 172 texttemplate="%{text}", | |
| 173 colorscale="Blues", | |
| 174 showscale=False, | |
| 175 ) | |
| 176 ) | |
| 177 fig_cm.update_layout( | |
| 178 title=f"Confusion Matrix @ threshold={prob_thresh:.2f}", | |
| 179 xaxis_title="Predicted label", | |
| 180 yaxis_title="True label", | |
| 181 ) | |
| 182 _apply_report_layout(fig_cm) | |
| 183 self.explainer_plots["confusion_matrix"] = fig_cm | |
| 184 except Exception as e: | |
| 185 LOG.warning( | |
| 186 f"Threshold-aware confusion matrix failed; falling back: {e}" | |
| 187 ) | |
| 188 | |
| 189 # ---- ROC with threshold marker ---- | 172 # ---- ROC with threshold marker ---- |
| 190 try: | 173 try: |
| 191 fpr, tpr, thr = roc_curve(y, y_scores) | 174 if y_scores is None: |
| | 175 raise ValueError("Predicted probabilities unavailable") |
| | 176 fpr, tpr, thr = roc_curve(y_true, y_scores) |
| 192 roc_auc = auc(fpr, tpr) | 177 roc_auc = auc(fpr, tpr) |
| 193 fig_roc = go.Figure() | 178 fig_roc = go.Figure() |
| 194 fig_roc.add_scatter( | 179 fig_roc.add_scatter( |
| 195 x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={roc_auc:.3f})" | 180 x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={roc_auc:.3f})" |
| 196 ) | 181 ) |
| 217 except Exception as e: | 202 except Exception as e: |
| 218 LOG.warning(f"Threshold marker on ROC failed; falling back: {e}") | 203 LOG.warning(f"Threshold marker on ROC failed; falling back: {e}") |
| 219 | 204 |
| 220 # ---- PR with threshold marker ---- | 205 # ---- PR with threshold marker ---- |
| 221 try: | 206 try: |
| 222 precision, recall, thr_pr = precision_recall_curve(y, y_scores) | 207 if y_scores is None: |
| | 208 raise ValueError("Predicted probabilities unavailable") |
| | 209 precision, recall, thr_pr = precision_recall_curve(y_true, y_scores) |
| 223 pr_auc = auc(recall, precision) | 210 pr_auc = auc(recall, precision) |
| 224 fig_pr = go.Figure() | 211 fig_pr = go.Figure() |
| 225 fig_pr.add_scatter( | 212 fig_pr.add_scatter( |
| 226 x=recall, y=precision, mode="lines", name=f"PR (AUC={pr_auc:.3f})" | 213 x=recall, y=precision, mode="lines", name=f"PR (AUC={pr_auc:.3f})" |
| 227 ) | 214 ) |
| 302 return None | 289 return None |
| 303 | 290 |
| 304 return _plot | 291 return _plot |
| 305 | 292 |
| 306 self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat) | 293 self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat) |
| | 294 |
| | 295 def _get_test_predictions(self): |
| | 296 """ |
| | 297 Return y_true, y_pred, label list, and (optionally) positive-class |
| | 298 probabilities when available. Ensures predictions respect the optional |
| | 299 probability threshold for binary tasks. |
| | 300 """ |
| | 301 y_true = pd.Series(self.exp.y_test_transformed).reset_index(drop=True) |
| | 302 X_test = self.exp.X_test_transformed |
| | 303 prob_thresh = getattr(self, "probability_threshold", None) |
| | 304 |
| | 305 y_scores = None |
| | 306 try: |
| | 307 proba = self.best_model.predict_proba(X_test) |
| | 308 y_scores = proba |
| | 309 except Exception: |
| | 310 LOG.debug("predict_proba unavailable for test predictions.") |
| | 311 |
| | 312 try: |
| | 313 if ( |
| | 314 prob_thresh is not None |
| | 315 and not self.exp.is_multiclass |
| | 316 and y_scores is not None |
| | 317 and y_scores.ndim == 2 |
| | 318 and y_scores.shape[1] > 1 |
| | 319 ): |
| | 320 classes = list(getattr(self.best_model, "classes_", [])) |
| | 321 try: |
| | 322 pos_idx = classes.index(1) if 1 in classes else 1 |
| | 323 except Exception: |
| | 324 pos_idx = 1 |
| | 325 neg_idx = 1 - pos_idx if y_scores.shape[1] > 1 else 0 |
| | 326 pos_label = classes[pos_idx] if len(classes) > pos_idx else 1 |
| | 327 neg_label = classes[neg_idx] if len(classes) > neg_idx else 0 |
| | 328 y_pred = np.where(y_scores[:, pos_idx] >= prob_thresh, pos_label, neg_label) |
| | 329 y_scores = y_scores[:, pos_idx] |
| | 330 else: |
| | 331 y_pred = self.best_model.predict(X_test) |
| | 332 except Exception as exc: |
| | 333 LOG.warning("Falling back to raw predict for test predictions: %s", exc) |
| | 334 y_pred = self.best_model.predict(X_test) |
| | 335 |
| | 336 y_pred = pd.Series(y_pred).reset_index(drop=True) |
| | 337 if y_scores is not None: |
| | 338 y_scores = np.asarray(y_scores) |
| | 339 if y_scores.ndim > 1 and y_scores.shape[1] == 1: |
| | 340 y_scores = y_scores.ravel() |
| | 341 if self.exp.is_multiclass and y_scores.ndim > 1: |
| | 342 # Avoid passing multiclass score matrices to ROC/PR utilities |
| | 343 y_scores = None |
| | 344 label_values = pd.unique(pd.concat([y_true, y_pred], ignore_index=True)) |
| | 345 return y_true, y_pred, label_values.tolist(), y_scores |
| | 346 |
| | 347 def _threshold_suffix(self) -> str: |
| | 348 """ |
| | 349 Build a suffix like ' (threshold=0.50)' for binary tasks; omit for |
| | 350 multiclass where thresholds are not applied. |
| | 351 """ |
| | 352 if getattr(self, "task_type", None) != "classification": |
| | 353 return "" |
| | 354 if getattr(self.exp, "is_multiclass", False): |
| | 355 return "" |
| | 356 prob_thresh = getattr(self, "probability_threshold", None) |
| | 357 if prob_thresh is None: |
| | 358 return " (threshold=0.50)" |
| | 359 try: |
| | 360 return f" (threshold={float(prob_thresh):.2f})" |
| | 361 except Exception: |
| | 362 return f" (threshold={prob_thresh})" |
| | 363 |
| | 364 def _build_confusion_matrix_fig(self, y_true, y_pred, labels): |
| | 365 def _label_sort_key(lbl): |
| | 366 try: |
| | 367 return (0, float(lbl)) |
| | 368 except Exception: |
| | 369 return (1, str(lbl)) |
| | 370 |
| | 371 ordered_labels = sorted(labels, key=_label_sort_key) |
| | 372 cm = confusion_matrix(y_true, y_pred, labels=ordered_labels) |
| | 373 label_names = [str(lbl) for lbl in ordered_labels] |
| | 374 fig_cm = go.Figure( |
| | 375 data=go.Heatmap( |
| | 376 z=cm, |
| | 377 x=[f"Pred {lbl}" for lbl in label_names], |
| | 378 y=[f"True {lbl}" for lbl in label_names], |
| | 379 text=cm, |
| | 380 texttemplate="%{text}", |
| | 381 colorscale="Blues", |
| | 382 showscale=False, |
| | 383 ) |
| | 384 ) |
| | 385 fig_cm.update_layout( |
| | 386 title=f"Confusion Matrix{self._threshold_suffix()}", |
| | 387 xaxis_title=f"Predicted label ({self.target})", |
| | 388 yaxis_title=f"True label ({self.target})", |
| | 389 ) |
| | 390 fig_cm.update_xaxes( |
| | 391 type="category", |
| | 392 categoryorder="array", |
| | 393 categoryarray=[f"Pred {lbl}" for lbl in label_names], |
| | 394 ) |
| | 395 fig_cm.update_yaxes( |
| | 396 type="category", |
| | 397 categoryorder="array", |
| | 398 categoryarray=[f"True {lbl}" for lbl in label_names], |
| | 399 autorange="reversed", |
| | 400 ) |
| | 401 _apply_report_layout(fig_cm) |
| | 402 return fig_cm |
| | 403 |
| | 404 def _build_classification_report_fig(self, y_true, y_pred, labels): |
| | 405 precision, recall, f1, support = precision_recall_fscore_support( |
| | 406 y_true, y_pred, labels=labels, zero_division=0 |
| | 407 ) |
| | 408 mcc_scores = [] |
| | 409 for lbl in labels: |
| | 410 y_true_bin = (y_true == lbl).astype(int) |
| | 411 y_pred_bin = (y_pred == lbl).astype(int) |
| | 412 try: |
| | 413 mcc_val = matthews_corrcoef(y_true_bin, y_pred_bin) |
| | 414 except Exception: |
| | 415 mcc_val = 0.0 |
| | 416 mcc_scores.append(mcc_val) |
| | 417 |
| | 418 label_names = [str(lbl) for lbl in labels] |
| | 419 metrics = ["precision", "recall", "f1", "support"] |
| | 420 |
| | 421 max_support = float(max(support) if len(support) else 0) |
| | 422 z_rows = [] |
| | 423 text_rows = [] |
| | 424 for i, lbl in enumerate(label_names): |
| | 425 norm_support = (support[i] / max_support) if max_support else 0.0 |
| | 426 z_rows.append( |
| | 427 [ |
| | 428 precision[i], |
| | 429 recall[i], |
| | 430 f1[i], |
| | 431 norm_support, |
| | 432 ] |
| | 433 ) |
| | 434 text_rows.append( |
| | 435 [ |
| | 436 f"{precision[i]:.3f}", |
| | 437 f"{recall[i]:.3f}", |
| | 438 f"{f1[i]:.3f}", |
| | 439 f"{int(support[i])}", |
| | 440 ] |
| | 441 ) |
| | 442 |
| | 443 fig = go.Figure( |
| | 444 data=go.Heatmap( |
| | 445 z=z_rows, |
| | 446 x=metrics, |
| | 447 y=label_names, |
| | 448 colorscale="YlOrRd", |
| | 449 zmin=0, |
| | 450 zmax=1, |
| | 451 colorbar=dict(title="Scale"), |
| | 452 text=text_rows, |
| | 453 texttemplate="%{text}", |
| | 454 hovertemplate="Label=%{y}<br>Metric=%{x}<br>Value=%{text}<extra></extra>", |
| | 455 ) |
| | 456 ) |
| | 457 fig.update_yaxes( |
| | 458 title_text=f"Label ({self.target})", |
| | 459 autorange="reversed", |
| | 460 type="category", |
| | 461 tickmode="array", |
| | 462 tickvals=label_names, |
| | 463 ticktext=label_names, |
| | 464 showgrid=False, |
| | 465 ) |
| | 466 fig.update_xaxes(title_text="", tickangle=45) |
| | 467 fig.update_layout( |
| | 468 title=f"Per-Class Metrics{self._threshold_suffix()}", |
| | 469 margin=dict(l=70, r=60, t=70, b=80), |
| | 470 ) |
| | 471 _apply_report_layout(fig) |
| | 472 return fig |

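The new `_get_test_predictions` helper centralises how test-set predictions are derived: for a binary task with a custom `probability_threshold`, class labels come from thresholding the positive-class column of `predict_proba` rather than from `predict`. Below is a minimal standalone sketch of that pattern, not the tool's code; the scikit-learn model, dataset, and threshold value are illustrative assumptions.

```python
# Sketch: derive binary predictions at a custom probability threshold,
# mirroring the thresholding logic added in this revision.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=200, random_state=0)
model = LogisticRegression(max_iter=1000).fit(X, y)

threshold = 0.35                                   # hypothetical operating point (default behaviour ~0.50)
classes = list(model.classes_)
pos_idx = classes.index(1) if 1 in classes else 1  # same positive-class fallback as the diff
proba = model.predict_proba(X)

y_scores = proba[:, pos_idx]                       # positive-class probabilities
y_pred = np.where(y_scores >= threshold,           # apply the custom threshold
                  classes[pos_idx], classes[1 - pos_idx])
```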
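Similarly, `_build_classification_report_fig` renders per-class precision, recall, F1, and support as a Plotly heatmap instead of a text report. A minimal sketch of the same idea follows, assuming `plotly` and scikit-learn are installed; the toy labels are illustrative and the tool's `_apply_report_layout` styling is omitted.

```python
# Sketch: per-class metrics as a Plotly heatmap (values printed over the colour scale).
import plotly.graph_objects as go
from sklearn.metrics import precision_recall_fscore_support

y_true = [0, 0, 1, 1, 1, 0]   # toy data for illustration only
y_pred = [0, 1, 1, 1, 0, 0]
labels = [0, 1]

precision, recall, f1, support = precision_recall_fscore_support(
    y_true, y_pred, labels=labels, zero_division=0
)

max_support = float(max(support)) or 1.0  # normalise support into the 0-1 colour range
z = [[precision[i], recall[i], f1[i], support[i] / max_support] for i in range(len(labels))]
text = [[f"{precision[i]:.3f}", f"{recall[i]:.3f}", f"{f1[i]:.3f}", str(int(support[i]))]
        for i in range(len(labels))]

fig = go.Figure(
    go.Heatmap(
        z=z,
        x=["precision", "recall", "f1", "support"],
        y=[str(lbl) for lbl in labels],
        text=text,
        texttemplate="%{text}",
        zmin=0, zmax=1,
        colorscale="YlOrRd",
    )
)
fig.update_yaxes(autorange="reversed", type="category")
# fig.show()
```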