Mercurial > repos > goeckslab > image_learner
comparison plotly_plots.py @ 15:d17e3a1b8659 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
| author | goeckslab |
|---|---|
| date | Fri, 28 Nov 2025 15:45:49 +0000 |
| parents | c5150cceab47 |
| children | |
comparison
equal
deleted
inserted
replaced
| 14:94cd9ac4a9b1 | 15:d17e3a1b8659 |
|---|---|
| 5 import numpy as np | 5 import numpy as np |
| 6 import pandas as pd | 6 import pandas as pd |
| 7 import plotly.graph_objects as go | 7 import plotly.graph_objects as go |
| 8 import plotly.io as pio | 8 import plotly.io as pio |
| 9 from constants import LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME | 9 from constants import LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME |
| 10 from sklearn.metrics import auc, roc_curve | 10 |
| 11 from sklearn.preprocessing import label_binarize | 11 |
| 12 def _style_fig(fig: go.Figure, font_size: int = 12) -> go.Figure: | |
| 13 """Apply consistent styling across Plotly figures.""" | |
| 14 fig.update_layout( | |
| 15 font=dict(size=font_size), | |
| 16 plot_bgcolor="#ffffff", | |
| 17 paper_bgcolor="#ffffff", | |
| 18 ) | |
| 19 fig.update_xaxes(gridcolor="#e8e8e8") | |
| 20 fig.update_yaxes(gridcolor="#e8e8e8") | |
| 21 return fig | |
| 22 | |
| 23 | |
| 24 def _labels_from_metadata_dict(meta_dict: dict) -> List[str]: | |
| 25 """Extract ordered label names from Ludwig train_set_metadata.""" | |
| 26 if not isinstance(meta_dict, dict): | |
| 27 return [] | |
| 28 | |
| 29 for key in ("idx2str", "idx2label", "vocab"): | |
| 30 seq = meta_dict.get(key) | |
| 31 if isinstance(seq, list) and seq: | |
| 32 return [str(v) for v in seq] | |
| 33 | |
| 34 str2idx = meta_dict.get("str2idx") | |
| 35 if isinstance(str2idx, dict) and str2idx: | |
| 36 int_indices = [v for v in str2idx.values() if isinstance(v, int)] | |
| 37 if int_indices: | |
| 38 max_idx = max(int_indices) | |
| 39 ordered = [None] * (max_idx + 1) | |
| 40 for name, idx in str2idx.items(): | |
| 41 if isinstance(idx, int) and 0 <= idx < len(ordered): | |
| 42 ordered[idx] = name | |
| 43 return [str(v) for v in ordered if v is not None] | |
| 44 | |
| 45 return [] | |
| 46 | |
| 47 | |
| 48 def _resolve_confusion_labels( | |
| 49 label_stats: dict, | |
| 50 n_classes: int, | |
| 51 metadata_csv_path: Optional[str], | |
| 52 train_set_metadata_path: Optional[str], | |
| 53 ) -> List[str]: | |
| 54 """Prefer original labels from metadata; fall back to stats if unavailable.""" | |
| 55 if train_set_metadata_path: | |
| 56 try: | |
| 57 meta_path = Path(train_set_metadata_path) | |
| 58 if meta_path.exists(): | |
| 59 with open(meta_path, "r") as f: | |
| 60 meta_json = json.load(f) | |
| 61 label_meta = meta_json.get(LABEL_COLUMN_NAME) | |
| 62 if not isinstance(label_meta, dict): | |
| 63 label_meta = next( | |
| 64 ( | |
| 65 v | |
| 66 for v in meta_json.values() | |
| 67 if isinstance(v, dict) | |
| 68 and any(k in v for k in ("idx2str", "str2idx", "idx2label", "vocab")) | |
| 69 ), | |
| 70 None, | |
| 71 ) | |
| 72 labels_from_meta = _labels_from_metadata_dict(label_meta) if label_meta else [] | |
| 73 if labels_from_meta and len(labels_from_meta) >= n_classes: | |
| 74 return [str(label) for label in labels_from_meta[:n_classes]] | |
| 75 except Exception as exc: | |
| 76 print(f"Warning: Unable to read labels from train_set_metadata: {exc}") | |
| 77 | |
| 78 if metadata_csv_path: | |
| 79 try: | |
| 80 csv_path = Path(metadata_csv_path) | |
| 81 if csv_path.exists(): | |
| 82 df_meta = pd.read_csv(csv_path) | |
| 83 if LABEL_COLUMN_NAME in df_meta.columns: | |
| 84 uniques = df_meta[LABEL_COLUMN_NAME].dropna().unique().tolist() | |
| 85 if uniques and len(uniques) >= n_classes: | |
| 86 return [str(u) for u in uniques[:n_classes]] | |
| 87 except Exception as exc: | |
| 88 print(f"Warning: Unable to read labels from metadata CSV: {exc}") | |
| 89 | |
| 90 pcs = label_stats.get("per_class_stats", {}) | |
| 91 if pcs: | |
| 92 pcs_labels = [str(k) for k in pcs.keys()] | |
| 93 if len(pcs_labels) >= n_classes: | |
| 94 return pcs_labels[:n_classes] | |
| 95 | |
| 96 labels = label_stats.get("labels") | |
| 97 if not labels: | |
| 98 labels = [str(i) for i in range(n_classes)] | |
| 99 if len(labels) < n_classes: | |
| 100 labels = labels + [str(i) for i in range(len(labels), n_classes)] | |
| 101 return [str(label) for label in labels[:n_classes]] | |
| 12 | 102 |
| 13 | 103 |
| 14 def build_classification_plots( | 104 def build_classification_plots( |
| 15 test_stats_path: str, | 105 test_stats_path: str, |
| 16 training_stats_path: Optional[str] = None, | 106 training_stats_path: Optional[str] = None, |
| 107 metadata_csv_path: Optional[str] = None, | |
| 108 train_set_metadata_path: Optional[str] = None, | |
| 17 ) -> List[Dict[str, str]]: | 109 ) -> List[Dict[str, str]]: |
| 18 """ | 110 """ |
| 19 Read Ludwig’s test_statistics.json and build three interactive Plotly panels: | 111 Read Ludwig’s test_statistics.json and build three interactive Plotly panels: |
| 20 - Confusion Matrix | 112 - Confusion Matrix |
| 21 - ROC-AUC | 113 - ROC-AUC |
| 22 - Classification Report Heatmap | 114 - Classification Report Heatmap |
| 115 | |
| 116 If metadata paths are provided, the confusion matrix axes will use the original | |
| 117 label values from the training metadata rather than integer-encoded labels. | |
| 23 | 118 |
| 24 Returns a list of dicts, each with: | 119 Returns a list of dicts, each with: |
| 25 { | 120 { |
| 26 "title": <plot title>, | 121 "title": <plot title>, |
| 27 "html": <HTML fragment for embedding> | 122 "html": <HTML fragment for embedding> |
| 40 | 135 |
| 41 plots: List[Dict[str, str]] = [] | 136 plots: List[Dict[str, str]] = [] |
| 42 | 137 |
| 43 # 0) Confusion Matrix | 138 # 0) Confusion Matrix |
| 44 cm = np.array(label_stats["confusion_matrix"], dtype=int) | 139 cm = np.array(label_stats["confusion_matrix"], dtype=int) |
| 45 # Try to get actual class names from per_class_stats keys (which contain the real labels) | 140 labels = _resolve_confusion_labels( |
| 46 pcs = label_stats.get("per_class_stats", {}) | 141 label_stats, |
| 47 if pcs: | 142 n_classes, |
| 48 labels = list(pcs.keys()) | 143 metadata_csv_path=metadata_csv_path, |
| 49 else: | 144 train_set_metadata_path=train_set_metadata_path, |
| 50 labels = label_stats.get("labels", [str(i) for i in range(cm.shape[0])]) | 145 ) |
| 51 total = cm.sum() | 146 total = cm.sum() |
| 52 | 147 |
| 53 fig_cm = go.Figure( | 148 fig_cm = go.Figure( |
| 54 go.Heatmap( | 149 go.Heatmap( |
| 55 z=cm, | 150 z=cm, |
| 68 yaxis_autorange="reversed", | 163 yaxis_autorange="reversed", |
| 69 width=side_px, | 164 width=side_px, |
| 70 height=side_px, | 165 height=side_px, |
| 71 margin=dict(t=100, l=80, r=80, b=80), | 166 margin=dict(t=100, l=80, r=80, b=80), |
| 72 ) | 167 ) |
| 168 _style_fig(fig_cm) | |
| 73 | 169 |
| 74 # annotate counts and percentages | 170 # annotate counts and percentages |
| 75 mval = cm.max() if cm.size else 0 | 171 mval = cm.max() if cm.size else 0 |
| 76 thresh = mval / 2 | 172 thresh = mval / 2 |
| 77 for i in range(cm.shape[0]): | 173 for i in range(cm.shape[0]): |
| 108 include_plotlyjs="cdn", | 204 include_plotlyjs="cdn", |
| 109 config=common_cfg | 205 config=common_cfg |
| 110 ) | 206 ) |
| 111 }) | 207 }) |
| 112 | 208 |
| 113 # 1) ROC-AUC Curves (Multi-class) | 209 # 1) ROC Curve (from test_statistics) |
| 114 roc_plot = _build_roc_auc_plot(test_stats_path, labels, common_cfg) | 210 roc_plot = _build_static_roc_plot(label_stats, common_cfg, friendly_labels=labels) |
| 115 if roc_plot: | 211 if roc_plot: |
| 116 plots.append(roc_plot) | 212 plots.append(roc_plot) |
| 213 | |
| 214 # 2) Precision-Recall Curve (from test_statistics) | |
| 215 pr_plot = _build_precision_recall_plot(label_stats, common_cfg) | |
| 216 if pr_plot: | |
| 217 plots.append(pr_plot) | |
| 117 | 218 |
| 118 # 2) Classification Report Heatmap | 219 # 2) Classification Report Heatmap |
| 119 pcs = label_stats.get("per_class_stats", {}) | 220 pcs = label_stats.get("per_class_stats", {}) |
| 120 if pcs: | 221 if pcs: |
| 121 classes = list(pcs.keys()) | 222 classes = list(pcs.keys()) |
| 122 metrics = ["precision", "recall", "f1_score"] | 223 metrics = [ |
| 224 "precision", | |
| 225 "recall", | |
| 226 "f1_score", | |
| 227 "accuracy", | |
| 228 "matthews_correlation_coefficient", | |
| 229 "specificity", | |
| 230 ] | |
| 123 z, txt = [], [] | 231 z, txt = [], [] |
| 124 for c in classes: | 232 for c in classes: |
| 125 row, trow = [], [] | 233 row, trow = [], [] |
| 126 for m in metrics: | 234 for m in metrics: |
| 127 val = pcs[c].get(m, 0) | 235 val = pcs[c].get(m, 0) |
| 131 txt.append(trow) | 239 txt.append(trow) |
| 132 | 240 |
| 133 fig_cr = go.Figure( | 241 fig_cr = go.Figure( |
| 134 go.Heatmap( | 242 go.Heatmap( |
| 135 z=z, | 243 z=z, |
| 136 x=metrics, | 244 x=[m.replace("_", " ") for m in metrics], |
| 137 y=[str(c) for c in classes], | 245 y=[str(c) for c in classes], |
| 138 text=txt, | 246 text=txt, |
| 139 texttemplate="%{text}", | 247 texttemplate="%{text}", |
| 140 colorscale="Reds", | 248 colorscale="Reds", |
| 141 showscale=True, | 249 showscale=True, |
| 142 colorbar=dict(title="Value"), | 250 colorbar=dict(title="Value"), |
| 143 ) | 251 ) |
| 144 ) | 252 ) |
| 145 fig_cr.update_layout( | 253 fig_cr.update_layout( |
| 146 title="Classification Report", | 254 title="Per-Class metrics", |
| 147 xaxis_title="", | 255 xaxis_title="", |
| 148 yaxis_title="Class", | 256 yaxis_title="Class", |
| 149 width=side_px, | 257 width=side_px, |
| 150 height=side_px, | 258 height=side_px, |
| 151 margin=dict(t=80, l=80, r=80, b=80), | 259 margin=dict(t=80, l=80, r=80, b=80), |
| 152 ) | 260 ) |
| 261 _style_fig(fig_cr) | |
| 153 plots.append({ | 262 plots.append({ |
| 154 "title": "Classification Report", | 263 "title": "Per-Class metrics", |
| 155 "html": pio.to_html( | 264 "html": pio.to_html( |
| 156 fig_cr, | 265 fig_cr, |
| 157 full_html=False, | 266 full_html=False, |
| 158 include_plotlyjs=False, | 267 include_plotlyjs=False, |
| 159 config=common_cfg | 268 config=common_cfg |
| 160 ) | 269 ) |
| 161 }) | 270 }) |
| 162 | 271 |
| 272 # 3) Prediction Diagnostics (from predictions.csv) | |
| 273 # Note: appended separately in generate_html_report, not returned here. | |
| 274 | |
| 163 return plots | 275 return plots |
| 164 | 276 |
| 165 | 277 |
| 166 def _build_roc_auc_plot(test_stats_path: str, class_labels: List[str], config: dict) -> Optional[Dict[str, str]]: | 278 def build_train_validation_plots(train_stats_path: str) -> List[Dict[str, str]]: |
| 167 """ | 279 """Generate Train/Validation learning curve plots from training_statistics.json.""" |
| 168 Build an interactive ROC-AUC curve plot for multi-class classification. | 280 if not train_stats_path or not Path(train_stats_path).exists(): |
| 169 Following sklearn's ROC example with micro-average and per-class curves. | 281 return [] |
| 170 | |
| 171 Args: | |
| 172 test_stats_path: Path to test_statistics.json | |
| 173 class_labels: List of class label names | |
| 174 config: Plotly config dict | |
| 175 | |
| 176 Returns: | |
| 177 Dict with title and HTML, or None if data unavailable | |
| 178 """ | |
| 179 try: | 282 try: |
| 180 # Get the experiment directory from test_stats_path | 283 with open(train_stats_path, "r") as f: |
| 181 exp_dir = Path(test_stats_path).parent | 284 train_stats = json.load(f) |
| 182 | 285 except Exception as exc: |
| 183 # Load predictions with probabilities | 286 print(f"Warning: Unable to read training statistics: {exc}") |
| 184 predictions_path = exp_dir / "predictions.csv" | 287 return [] |
| 185 if not predictions_path.exists(): | 288 |
| 289 label_train = (train_stats.get("training") or {}).get("label", {}) | |
| 290 label_val = (train_stats.get("validation") or {}).get("label", {}) | |
| 291 if not label_train and not label_val: | |
| 292 return [] | |
| 293 plots: List[Dict[str, str]] = [] | |
| 294 include_js = True # Load Plotly.js once for this group | |
| 295 | |
| 296 def _get_series(stats: dict, metric: str) -> List[float]: | |
| 297 if metric not in stats: | |
| 298 return [] | |
| 299 vals = stats.get(metric, []) | |
| 300 if isinstance(vals, list): | |
| 301 return [float(v) for v in vals] | |
| 302 try: | |
| 303 return [float(vals)] | |
| 304 except Exception: | |
| 305 return [] | |
| 306 | |
| 307 def _line_plot(metric_key: str, title: str, yaxis_title: str) -> Optional[Dict[str, str]]: | |
| 308 train_series = _get_series(label_train, metric_key) | |
| 309 val_series = _get_series(label_val, metric_key) | |
| 310 if not train_series and not val_series: | |
| 186 return None | 311 return None |
| 187 | 312 epochs_train = list(range(1, len(train_series) + 1)) |
| 313 epochs_val = list(range(1, len(val_series) + 1)) | |
| 314 fig = go.Figure() | |
| 315 if train_series: | |
| 316 fig.add_trace( | |
| 317 go.Scatter( | |
| 318 x=epochs_train, | |
| 319 y=train_series, | |
| 320 mode="lines+markers", | |
| 321 name="Train", | |
| 322 line=dict(width=4), | |
| 323 ) | |
| 324 ) | |
| 325 if val_series: | |
| 326 fig.add_trace( | |
| 327 go.Scatter( | |
| 328 x=epochs_val, | |
| 329 y=val_series, | |
| 330 mode="lines+markers", | |
| 331 name="Validation", | |
| 332 line=dict(width=4), | |
| 333 ) | |
| 334 ) | |
| 335 fig.update_layout( | |
| 336 title=dict(text=title, x=0.5), | |
| 337 xaxis_title="Epoch", | |
| 338 yaxis_title=yaxis_title, | |
| 339 width=760, | |
| 340 height=520, | |
| 341 hovermode="x unified", | |
| 342 ) | |
| 343 _style_fig(fig) | |
| 344 return { | |
| 345 "title": title, | |
| 346 "html": pio.to_html( | |
| 347 fig, | |
| 348 full_html=False, | |
| 349 include_plotlyjs="cdn" if include_js else False, | |
| 350 ), | |
| 351 } | |
| 352 | |
| 353 # Core learning curves | |
| 354 for key, title in [ | |
| 355 ("roc_auc", "ROC-AUC across epochs"), | |
| 356 ("precision", "Precision across epochs"), | |
| 357 ("recall", "Recall/Sensitivity across epochs"), | |
| 358 ("specificity", "Specificity across epochs"), | |
| 359 ]: | |
| 360 plot = _line_plot(key, title, title.replace("Learning Curve", "").strip()) | |
| 361 if plot: | |
| 362 plots.append(plot) | |
| 363 include_js = False | |
| 364 | |
| 365 # Precision vs Recall evolution (validation) | |
| 366 val_prec = _get_series(label_val, "precision") | |
| 367 val_rec = _get_series(label_val, "recall") | |
| 368 if val_prec and val_rec: | |
| 369 epochs = list(range(1, min(len(val_prec), len(val_rec)) + 1)) | |
| 370 fig_pr = go.Figure() | |
| 371 fig_pr.add_trace( | |
| 372 go.Scatter( | |
| 373 x=epochs, | |
| 374 y=val_prec[: len(epochs)], | |
| 375 mode="lines+markers", | |
| 376 name="Precision", | |
| 377 ) | |
| 378 ) | |
| 379 fig_pr.add_trace( | |
| 380 go.Scatter( | |
| 381 x=epochs, | |
| 382 y=val_rec[: len(epochs)], | |
| 383 mode="lines+markers", | |
| 384 name="Recall", | |
| 385 ) | |
| 386 ) | |
| 387 fig_pr.update_layout( | |
| 388 title=dict(text="Validation Precision and Recall by Epoch", x=0.5), | |
| 389 xaxis_title="Epoch", | |
| 390 yaxis_title="Value", | |
| 391 width=760, | |
| 392 height=520, | |
| 393 hovermode="x unified", | |
| 394 ) | |
| 395 _style_fig(fig_pr) | |
| 396 plots.append({ | |
| 397 "title": "Precision vs Recall Evolution", | |
| 398 "html": pio.to_html( | |
| 399 fig_pr, | |
| 400 full_html=False, | |
| 401 include_plotlyjs="cdn" if include_js else False, | |
| 402 ), | |
| 403 }) | |
| 404 include_js = False | |
| 405 | |
| 406 # F1-score derived | |
| 407 def _compute_f1(p: List[float], r: List[float]) -> List[float]: | |
| 408 f1_vals = [] | |
| 409 for prec, rec in zip(p, r): | |
| 410 if (prec + rec) == 0: | |
| 411 f1_vals.append(0.0) | |
| 412 else: | |
| 413 f1_vals.append(2 * prec * rec / (prec + rec)) | |
| 414 return f1_vals | |
| 415 | |
| 416 f1_train = _compute_f1(_get_series(label_train, "precision"), _get_series(label_train, "recall")) | |
| 417 f1_val = _compute_f1(val_prec, val_rec) | |
| 418 if f1_train or f1_val: | |
| 419 fig = go.Figure() | |
| 420 if f1_train: | |
| 421 fig.add_trace(go.Scatter(x=list(range(1, len(f1_train) + 1)), y=f1_train, mode="lines+markers", name="Train", line=dict(width=4))) | |
| 422 if f1_val: | |
| 423 fig.add_trace(go.Scatter(x=list(range(1, len(f1_val) + 1)), y=f1_val, mode="lines+markers", name="Validation", line=dict(width=4))) | |
| 424 fig.update_layout( | |
| 425 title=dict(text="F1-Score across epochs (derived)", x=0.5), | |
| 426 xaxis_title="Epoch", | |
| 427 yaxis_title="F1-Score", | |
| 428 width=760, | |
| 429 height=520, | |
| 430 hovermode="x unified", | |
| 431 ) | |
| 432 _style_fig(fig) | |
| 433 plots.append({ | |
| 434 "title": "F1-Score across epochs (derived)", | |
| 435 "html": pio.to_html( | |
| 436 fig, | |
| 437 full_html=False, | |
| 438 include_plotlyjs="cdn" if include_js else False, | |
| 439 ), | |
| 440 }) | |
| 441 include_js = False | |
| 442 | |
| 443 # Overfitting Gap: Train vs Val ROC-AUC (gap) | |
| 444 roc_train = _get_series(label_train, "roc_auc") | |
| 445 roc_val = _get_series(label_val, "roc_auc") | |
| 446 if roc_train and roc_val: | |
| 447 epochs_gap = list(range(1, min(len(roc_train), len(roc_val)) + 1)) | |
| 448 gaps = [t - v for t, v in zip(roc_train[:len(epochs_gap)], roc_val[:len(epochs_gap)])] | |
| 449 fig_gap = go.Figure() | |
| 450 fig_gap.add_trace(go.Scatter(x=epochs_gap, y=gaps, mode="lines+markers", name="Train - Val ROC-AUC", line=dict(width=4))) | |
| 451 fig_gap.update_layout( | |
| 452 title=dict(text="Overfitting gap: ROC-AUC across epochs", x=0.5), | |
| 453 xaxis_title="Epoch", | |
| 454 yaxis_title="Gap", | |
| 455 width=760, | |
| 456 height=520, | |
| 457 hovermode="x unified", | |
| 458 ) | |
| 459 _style_fig(fig_gap) | |
| 460 plots.append({ | |
| 461 "title": "Overfitting gap: ROC-AUC across epochs", | |
| 462 "html": pio.to_html( | |
| 463 fig_gap, | |
| 464 full_html=False, | |
| 465 include_plotlyjs="cdn" if include_js else False, | |
| 466 ), | |
| 467 }) | |
| 468 include_js = False | |
| 469 | |
| 470 # Best Epoch Dashboard (based on max val ROC-AUC) | |
| 471 if roc_val: | |
| 472 best_idx = int(np.argmax(roc_val)) | |
| 473 best_epoch = best_idx + 1 | |
| 474 spec_val = _get_series(label_val, "specificity") | |
| 475 metrics_at_best = { | |
| 476 "ROC-AUC": roc_val[best_idx] if best_idx < len(roc_val) else None, | |
| 477 "Precision": val_prec[best_idx] if best_idx < len(val_prec) else None, | |
| 478 "Recall": val_rec[best_idx] if best_idx < len(val_rec) else None, | |
| 479 "Specificity": spec_val[best_idx] if best_idx < len(spec_val) else None, | |
| 480 "F1-Score": f1_val[best_idx] if best_idx < len(f1_val) else None, | |
| 481 } | |
| 482 fig_best = go.Figure() | |
| 483 for name, value in metrics_at_best.items(): | |
| 484 if value is not None: | |
| 485 fig_best.add_trace(go.Bar(name=name, x=[name], y=[value])) | |
| 486 fig_best.update_layout( | |
| 487 title=dict(text=f"Best Epoch Dashboard (Val ROC-AUC @ epoch {best_epoch})", x=0.5), | |
| 488 xaxis_title="Metric", | |
| 489 yaxis_title="Value", | |
| 490 width=760, | |
| 491 height=520, | |
| 492 showlegend=False, | |
| 493 ) | |
| 494 _style_fig(fig_best) | |
| 495 plots.append({ | |
| 496 "title": "Best Validation Epoch Snapshot (Metrics)", | |
| 497 "html": pio.to_html( | |
| 498 fig_best, | |
| 499 full_html=False, | |
| 500 include_plotlyjs="cdn" if include_js else False, | |
| 501 ), | |
| 502 }) | |
| 503 include_js = False | |
| 504 | |
| 505 return plots | |
| 506 | |
| 507 | |
| 508 def _get_regression_series(split_stats: dict, metric: str) -> List[float]: | |
| 509 if metric not in split_stats: | |
| 510 return [] | |
| 511 vals = split_stats.get(metric, []) | |
| 512 if isinstance(vals, list): | |
| 513 return [float(v) for v in vals] | |
| 514 try: | |
| 515 return [float(vals)] | |
| 516 except Exception: | |
| 517 return [] | |
| 518 | |
| 519 | |
| 520 def _regression_line_plot( | |
| 521 train_split: dict, | |
| 522 val_split: dict, | |
| 523 metric_key: str, | |
| 524 title: str, | |
| 525 yaxis_title: str, | |
| 526 include_js: bool, | |
| 527 ) -> Optional[Dict[str, str]]: | |
| 528 train_series = _get_regression_series(train_split, metric_key) | |
| 529 val_series = _get_regression_series(val_split, metric_key) | |
| 530 if not train_series and not val_series: | |
| 531 return None | |
| 532 epochs_train = list(range(1, len(train_series) + 1)) | |
| 533 epochs_val = list(range(1, len(val_series) + 1)) | |
| 534 fig = go.Figure() | |
| 535 if train_series: | |
| 536 fig.add_trace( | |
| 537 go.Scatter( | |
| 538 x=epochs_train, | |
| 539 y=train_series, | |
| 540 mode="lines+markers", | |
| 541 name="Train", | |
| 542 line=dict(width=4), | |
| 543 ) | |
| 544 ) | |
| 545 if val_series: | |
| 546 fig.add_trace( | |
| 547 go.Scatter( | |
| 548 x=epochs_val, | |
| 549 y=val_series, | |
| 550 mode="lines+markers", | |
| 551 name="Validation", | |
| 552 line=dict(width=4), | |
| 553 ) | |
| 554 ) | |
| 555 fig.update_layout( | |
| 556 title=dict(text=title, x=0.5), | |
| 557 xaxis_title="Epoch", | |
| 558 yaxis_title=yaxis_title, | |
| 559 width=760, | |
| 560 height=520, | |
| 561 hovermode="x unified", | |
| 562 ) | |
| 563 _style_fig(fig) | |
| 564 return { | |
| 565 "title": title, | |
| 566 "html": pio.to_html( | |
| 567 fig, | |
| 568 full_html=False, | |
| 569 include_plotlyjs="cdn" if include_js else False, | |
| 570 ), | |
| 571 } | |
| 572 | |
| 573 | |
| 574 def build_regression_train_val_plots(train_stats_path: str) -> List[Dict[str, str]]: | |
| 575 """Generate regression Train/Validation learning curve plots from training_statistics.json.""" | |
| 576 if not train_stats_path or not Path(train_stats_path).exists(): | |
| 577 return [] | |
| 578 try: | |
| 579 with open(train_stats_path, "r") as f: | |
| 580 train_stats = json.load(f) | |
| 581 except Exception as exc: | |
| 582 print(f"Warning: Unable to read training statistics: {exc}") | |
| 583 return [] | |
| 584 | |
| 585 label_train = (train_stats.get("training") or {}).get("label", {}) | |
| 586 label_val = (train_stats.get("validation") or {}).get("label", {}) | |
| 587 if not label_train and not label_val: | |
| 588 return [] | |
| 589 | |
| 590 plots: List[Dict[str, str]] = [] | |
| 591 include_js = True | |
| 592 for metric_key, title, ytitle in [ | |
| 593 ("mean_absolute_error", "Mean Absolute Error across epochs", "MAE"), | |
| 594 ("root_mean_squared_error", "Root Mean Squared Error across epochs", "RMSE"), | |
| 595 ("mean_absolute_percentage_error", "Mean Absolute Percentage Error across epochs", "MAPE"), | |
| 596 ("r2", "R² across epochs", "R²"), | |
| 597 ("loss", "Loss across epochs", "Loss"), | |
| 598 ]: | |
| 599 plot = _regression_line_plot(label_train, label_val, metric_key, title, ytitle, include_js) | |
| 600 if plot: | |
| 601 plots.append(plot) | |
| 602 include_js = False | |
| 603 return plots | |
| 604 | |
| 605 | |
| 606 def build_regression_test_plots(train_stats_path: str) -> List[Dict[str, str]]: | |
| 607 """Generate regression Test learning curves from training_statistics.json.""" | |
| 608 if not train_stats_path or not Path(train_stats_path).exists(): | |
| 609 return [] | |
| 610 try: | |
| 611 with open(train_stats_path, "r") as f: | |
| 612 train_stats = json.load(f) | |
| 613 except Exception as exc: | |
| 614 print(f"Warning: Unable to read training statistics: {exc}") | |
| 615 return [] | |
| 616 | |
| 617 label_test = (train_stats.get("test") or {}).get("label", {}) | |
| 618 if not label_test: | |
| 619 return [] | |
| 620 | |
| 621 plots: List[Dict[str, str]] = [] | |
| 622 include_js = True | |
| 623 metrics = [ | |
| 624 ("mean_absolute_error", "Mean Absolute Error Across Epochs", "MAE"), | |
| 625 ("root_mean_squared_error", "Root Mean Squared Error Across Epochs", "RMSE"), | |
| 626 ("mean_absolute_percentage_error", "Mean Absolute Percentage Error Across Epochs", "MAPE"), | |
| 627 ("r2", "R² Across Epochs", "R²"), | |
| 628 ("loss", "Loss Across Epochs", "Loss"), | |
| 629 ] | |
| 630 epochs = None | |
| 631 for metric_key, title, ytitle in metrics: | |
| 632 series = _get_regression_series(label_test, metric_key) | |
| 633 if not series: | |
| 634 continue | |
| 635 if epochs is None: | |
| 636 epochs = list(range(1, len(series) + 1)) | |
| 637 fig = go.Figure() | |
| 638 fig.add_trace( | |
| 639 go.Scatter( | |
| 640 x=epochs, | |
| 641 y=series[: len(epochs)], | |
| 642 mode="lines+markers", | |
| 643 name="Test", | |
| 644 line=dict(width=4), | |
| 645 ) | |
| 646 ) | |
| 647 fig.update_layout( | |
| 648 title=dict(text=title, x=0.5), | |
| 649 xaxis_title="Epoch", | |
| 650 yaxis_title=ytitle, | |
| 651 width=760, | |
| 652 height=520, | |
| 653 hovermode="x unified", | |
| 654 ) | |
| 655 _style_fig(fig) | |
| 656 plots.append({ | |
| 657 "title": title, | |
| 658 "html": pio.to_html( | |
| 659 fig, | |
| 660 full_html=False, | |
| 661 include_plotlyjs="cdn" if include_js else False, | |
| 662 ), | |
| 663 }) | |
| 664 include_js = False | |
| 665 return plots | |
| 666 | |
| 667 | |
| 668 def _build_static_roc_plot( | |
| 669 label_stats: dict, config: dict, friendly_labels: Optional[List[str]] = None | |
| 670 ) -> Optional[Dict[str, str]]: | |
| 671 """Build ROC curve directly from test_statistics.json (single curve).""" | |
| 672 roc_data = label_stats.get("roc_curve") | |
| 673 if not isinstance(roc_data, dict): | |
| 674 return None | |
| 675 | |
| 676 fpr = roc_data.get("false_positive_rate") | |
| 677 tpr = roc_data.get("true_positive_rate") | |
| 678 if not fpr or not tpr or len(fpr) != len(tpr): | |
| 679 return None | |
| 680 | |
| 681 try: | |
| 682 fig = go.Figure() | |
| 683 fig.add_trace( | |
| 684 go.Scatter( | |
| 685 x=fpr, | |
| 686 y=tpr, | |
| 687 mode="lines+markers", | |
| 688 name="ROC Curve", | |
| 689 line=dict(color="#1f77b4", width=4), | |
| 690 hovertemplate="FPR: %{x:.3f}<br>TPR: %{y:.3f}<extra></extra>", | |
| 691 ) | |
| 692 ) | |
| 693 fig.add_trace( | |
| 694 go.Scatter( | |
| 695 x=[0, 1], | |
| 696 y=[0, 1], | |
| 697 mode="lines", | |
| 698 name="Random Classifier", | |
| 699 line=dict(color="gray", width=2, dash="dash"), | |
| 700 hovertemplate="Random Classifier<extra></extra>", | |
| 701 ) | |
| 702 ) | |
| 703 | |
| 704 auc_val = label_stats.get("roc_auc") or label_stats.get("roc_auc_macro") or label_stats.get("roc_auc_micro") | |
| 705 auc_txt = f" (AUC = {auc_val:.3f})" if isinstance(auc_val, (int, float)) else "" | |
| 706 | |
| 707 # Determine which label is treated as positive for the curve | |
| 708 label_list: List = [] | |
| 709 pcs = label_stats.get("per_class_stats", {}) | |
| 710 if pcs: | |
| 711 label_list = list(pcs.keys()) | |
| 712 if not label_list: | |
| 713 labels_from_stats = label_stats.get("labels") | |
| 714 if isinstance(labels_from_stats, list): | |
| 715 label_list = labels_from_stats | |
| 716 | |
| 717 # Try to resolve index of the positive label explicitly provided by Ludwig | |
| 718 pos_label_raw = ( | |
| 719 roc_data.get("positive_label") | |
| 720 or roc_data.get("positive_class") | |
| 721 or label_stats.get("positive_label") | |
| 722 ) | |
| 723 pos_label_idx = None | |
| 724 if pos_label_raw is not None and isinstance(label_list, list): | |
| 725 try: | |
| 726 pos_label_idx = label_list.index(pos_label_raw) | |
| 727 except ValueError: | |
| 728 pos_label_idx = None | |
| 729 | |
| 730 # Fallback: use the second label if available, otherwise the first | |
| 731 if pos_label_idx is None: | |
| 732 if isinstance(label_list, list) and len(label_list) >= 2: | |
| 733 pos_label_idx = 1 | |
| 734 elif isinstance(label_list, list) and label_list: | |
| 735 pos_label_idx = 0 | |
| 736 | |
| 737 if pos_label_raw is None and isinstance(label_list, list) and pos_label_idx is not None: | |
| 738 pos_label_raw = label_list[pos_label_idx] | |
| 739 | |
| 740 # Map to friendly label if we have one from metadata/CSV | |
| 741 pos_label_display = pos_label_raw | |
| 742 if ( | |
| 743 friendly_labels | |
| 744 and isinstance(pos_label_idx, int) | |
| 745 and 0 <= pos_label_idx < len(friendly_labels) | |
| 746 ): | |
| 747 pos_label_display = friendly_labels[pos_label_idx] | |
| 748 | |
| 749 pos_label_txt = ( | |
| 750 f"Positive class: {pos_label_display}" | |
| 751 if pos_label_display is not None | |
| 752 else "Positive class: (not available)" | |
| 753 ) | |
| 754 | |
| 755 title_label = f"ROC Curve{auc_txt}" | |
| 756 if pos_label_display is not None: | |
| 757 title_label = f"ROC Curve (Positive Class: {pos_label_display}){auc_txt}" | |
| 758 | |
| 759 fig.update_layout( | |
| 760 title=dict(text=title_label, x=0.5), | |
| 761 xaxis_title="False Positive Rate", | |
| 762 yaxis_title="True Positive Rate", | |
| 763 width=700, | |
| 764 height=600, | |
| 765 margin=dict(t=80, l=80, r=80, b=110), | |
| 766 hovermode="closest", | |
| 767 legend=dict( | |
| 768 x=0.6, | |
| 769 y=0.1, | |
| 770 bgcolor="rgba(255,255,255,0.9)", | |
| 771 bordercolor="rgba(0,0,0,0.2)", | |
| 772 borderwidth=1, | |
| 773 ), | |
| 774 ) | |
| 775 _style_fig(fig) | |
| 776 fig.update_xaxes(range=[0, 1.0]) | |
| 777 fig.update_yaxes(range=[0, 1.05]) | |
| 778 | |
| 779 fig.add_annotation( | |
| 780 x=0.5, | |
| 781 y=-0.15, | |
| 782 xref="paper", | |
| 783 yref="paper", | |
| 784 showarrow=False, | |
| 785 text=f"<span style='font-size:12px;color:#555;'>{pos_label_txt}</span>", | |
| 786 xanchor="center", | |
| 787 ) | |
| 788 | |
| 789 return { | |
| 790 "title": "ROC Curve", | |
| 791 "html": pio.to_html( | |
| 792 fig, | |
| 793 full_html=False, | |
| 794 include_plotlyjs=False, | |
| 795 config=config, | |
| 796 ), | |
| 797 } | |
| 798 except Exception as e: | |
| 799 print(f"Error building ROC plot: {e}") | |
| 800 return None | |
| 801 | |
| 802 | |
| 803 def _build_precision_recall_plot(label_stats: dict, config: dict) -> Optional[Dict[str, str]]: | |
| 804 """Build Precision-Recall curve directly from test_statistics.json.""" | |
| 805 pr_data = label_stats.get("precision_recall_curve") | |
| 806 if not isinstance(pr_data, dict): | |
| 807 return None | |
| 808 | |
| 809 precisions = pr_data.get("precisions") | |
| 810 recalls = pr_data.get("recalls") | |
| 811 if not precisions or not recalls or len(precisions) != len(recalls): | |
| 812 return None | |
| 813 | |
| 814 try: | |
| 815 fig = go.Figure() | |
| 816 fig.add_trace( | |
| 817 go.Scatter( | |
| 818 x=recalls, | |
| 819 y=precisions, | |
| 820 mode="lines+markers", | |
| 821 name="Precision-Recall", | |
| 822 line=dict(color="#d62728", width=4), | |
| 823 hovertemplate="Recall: %{x:.3f}<br>Precision: %{y:.3f}<extra></extra>", | |
| 824 ) | |
| 825 ) | |
| 826 | |
| 827 ap_val = ( | |
| 828 label_stats.get("average_precision_macro") | |
| 829 or label_stats.get("average_precision_micro") | |
| 830 or label_stats.get("average_precision_samples") | |
| 831 ) | |
| 832 ap_txt = f" (AP = {ap_val:.3f})" if isinstance(ap_val, (int, float)) else "" | |
| 833 | |
| 834 fig.update_layout( | |
| 835 title=dict(text=f"Precision-Recall Curve{ap_txt}", x=0.5), | |
| 836 xaxis_title="Recall", | |
| 837 yaxis_title="Precision", | |
| 838 width=700, | |
| 839 height=600, | |
| 840 margin=dict(t=80, l=80, r=80, b=80), | |
| 841 hovermode="closest", | |
| 842 legend=dict( | |
| 843 x=0.6, | |
| 844 y=0.1, | |
| 845 bgcolor="rgba(255,255,255,0.9)", | |
| 846 bordercolor="rgba(0,0,0,0.2)", | |
| 847 borderwidth=1, | |
| 848 ), | |
| 849 ) | |
| 850 _style_fig(fig) | |
| 851 fig.update_xaxes(range=[0, 1.0]) | |
| 852 fig.update_yaxes(range=[0, 1.05]) | |
| 853 | |
| 854 return { | |
| 855 "title": "Precision-Recall Curve", | |
| 856 "html": pio.to_html( | |
| 857 fig, | |
| 858 full_html=False, | |
| 859 include_plotlyjs=False, | |
| 860 config=config, | |
| 861 ), | |
| 862 } | |
| 863 except Exception as e: | |
| 864 print(f"Error building Precision-Recall plot: {e}") | |
| 865 return None | |
| 866 | |
| 867 | |
| 868 def build_prediction_diagnostics( | |
| 869 predictions_path: str, | |
| 870 label_data_path: Optional[str] = None, | |
| 871 split_value: int = 2, | |
| 872 threshold: Optional[float] = None, | |
| 873 ) -> List[Dict[str, str]]: | |
| 874 """Generate diagnostic plots from predictions.csv for classification tasks.""" | |
| 875 preds_file = Path(predictions_path) | |
| 876 if not preds_file.exists(): | |
| 877 return [] | |
| 878 | |
| 879 try: | |
| 188 df_pred = pd.read_csv(predictions_path) | 880 df_pred = pd.read_csv(predictions_path) |
| 189 | 881 except Exception as exc: |
| 190 if SPLIT_COLUMN_NAME in df_pred.columns: | 882 print(f"Warning: Unable to read predictions CSV: {exc}") |
| 191 split_series = df_pred[SPLIT_COLUMN_NAME].astype(str).str.lower() | 883 return [] |
| 192 test_mask = split_series.isin({"2", "test", "testing"}) | 884 |
| 193 if test_mask.any(): | 885 plots: List[Dict[str, str]] = [] |
| 194 df_pred = df_pred[test_mask].reset_index(drop=True) | 886 |
| 195 | 887 # Identify probability columns |
| 196 if df_pred.empty: | 888 prob_cols = [ |
| 197 return None | 889 c for c in df_pred.columns |
| 198 | 890 if c.startswith("label_probabilities_") and c != "label_probabilities" |
| 199 # Extract probability columns (label_probabilities_0, label_probabilities_1, etc.) | 891 ] |
| 200 # or label_probabilities_<class_name> for string labels | 892 prob_cols_sorted = sorted(prob_cols) |
| 201 prob_cols = [col for col in df_pred.columns if col.startswith('label_probabilities_') and col != 'label_probabilities'] | 893 |
| 202 | 894 def _select_positive_prob(): |
| 203 # Sort by class number if numeric, otherwise keep alphabetical order | 895 if not prob_cols_sorted: |
| 204 if prob_cols and prob_cols[0].split('_')[-1].isdigit(): | 896 return None, None |
| 205 prob_cols.sort(key=lambda x: int(x.split('_')[-1])) | 897 # Prefer a column indicating positive/event/true/1 |
| 206 else: | 898 preferred_keys = ("event", "true", "positive", "pos", "1") |
| 207 prob_cols.sort() # Alphabetical sort for string class names | 899 for col in prob_cols_sorted: |
| 208 | 900 suffix = col.replace("label_probabilities_", "").lower() |
| 209 if not prob_cols: | 901 if any(k in suffix for k in preferred_keys): |
| 210 return None | 902 return col, suffix |
| 211 | 903 if len(prob_cols_sorted) == 2: |
| 212 # Get probabilities matrix (n_samples x n_classes) | 904 col = prob_cols_sorted[1] |
| 213 y_score = df_pred[prob_cols].values | 905 return col, col.replace("label_probabilities_", "") |
| 214 n_classes = len(prob_cols) | 906 col = prob_cols_sorted[0] |
| 215 | 907 return col, col.replace("label_probabilities_", "") |
| 216 y_true = None | 908 |
| 217 candidate_cols = [ | 909 pos_prob_col, pos_label_hint = _select_positive_prob() |
| 910 pos_prob_series = df_pred[pos_prob_col] if pos_prob_col and pos_prob_col in df_pred else None | |
| 911 | |
| 912 # Confidence series: prefer label_probability, otherwise positive prob, otherwise max prob | |
| 913 confidence_series = None | |
| 914 if "label_probability" in df_pred.columns: | |
| 915 confidence_series = df_pred["label_probability"] | |
| 916 elif pos_prob_series is not None: | |
| 917 confidence_series = pos_prob_series | |
| 918 elif prob_cols_sorted: | |
| 919 confidence_series = df_pred[prob_cols_sorted].max(axis=1) | |
| 920 | |
| 921 # True labels | |
| 922 def _extract_labels(): | |
| 923 candidates = [ | |
| 218 LABEL_COLUMN_NAME, | 924 LABEL_COLUMN_NAME, |
| 219 f"{LABEL_COLUMN_NAME}_ground_truth", | 925 f"{LABEL_COLUMN_NAME}_ground_truth", |
| 220 f"{LABEL_COLUMN_NAME}__ground_truth", | 926 f"{LABEL_COLUMN_NAME}__ground_truth", |
| 221 f"{LABEL_COLUMN_NAME}_target", | 927 f"{LABEL_COLUMN_NAME}_target", |
| 222 f"{LABEL_COLUMN_NAME}__target", | 928 f"{LABEL_COLUMN_NAME}__target", |
| 929 "label", | |
| 930 "label_true", | |
| 223 ] | 931 ] |
| 224 candidate_cols.extend( | 932 candidates.extend( |
| 225 [ | 933 [ |
| 226 col | 934 col |
| 227 for col in df_pred.columns | 935 for col in df_pred.columns |
| 228 if (col.startswith(f"{LABEL_COLUMN_NAME}_") or col.startswith(f"{LABEL_COLUMN_NAME}__")) | 936 if (col.startswith(f"{LABEL_COLUMN_NAME}_") or col.startswith(f"{LABEL_COLUMN_NAME}__")) |
| 229 and "probabilities" not in col | 937 and "probabilities" not in col |
| 230 and "predictions" not in col | 938 and "predictions" not in col |
| 231 ] | 939 ] |
| 232 ) | 940 ) |
| 233 for col in candidate_cols: | 941 for col in candidates: |
| 234 if col in df_pred.columns and col not in prob_cols: | 942 if col in df_pred.columns and col not in prob_cols_sorted: |
| 235 y_true = df_pred[col].values | 943 return df_pred[col] |
| 236 break | 944 if label_data_path and Path(label_data_path).exists(): |
| 237 | 945 try: |
| 238 if y_true is None: | 946 df_all = pd.read_csv(label_data_path) |
| 239 desc_path = exp_dir / "description.json" | 947 if SPLIT_COLUMN_NAME in df_all.columns: |
| 240 if desc_path.exists(): | 948 df_all = df_all[df_all[SPLIT_COLUMN_NAME] == split_value].reset_index(drop=True) |
| 241 try: | 949 if LABEL_COLUMN_NAME in df_all.columns: |
| 242 with open(desc_path, 'r') as f: | 950 return df_all[LABEL_COLUMN_NAME].reset_index(drop=True) |
| 243 desc = json.load(f) | 951 except Exception as exc: |
| 244 dataset_path = desc.get('dataset', '') | 952 print(f"Warning: Unable to load labels from dataset: {exc}") |
| 245 if dataset_path and Path(dataset_path).exists(): | 953 return None |
| 246 df_orig = pd.read_csv(dataset_path) | 954 |
| 247 if SPLIT_COLUMN_NAME in df_orig.columns: | 955 labels_series = _extract_labels() |
| 248 df_orig = df_orig[df_orig[SPLIT_COLUMN_NAME] == 2].reset_index(drop=True) | 956 |
| 249 if LABEL_COLUMN_NAME in df_orig.columns: | 957 # Plot 1: Confidence Histogram |
| 250 y_true = df_orig[LABEL_COLUMN_NAME].values | 958 if confidence_series is not None: |
| 251 if len(y_true) != len(df_pred): | 959 fig_conf = go.Figure() |
| 252 print( | 960 fig_conf.add_trace( |
| 253 f"Warning: Test set size mismatch. Truncating to {len(df_pred)} samples for ROC plot." | 961 go.Histogram( |
| 254 ) | 962 x=confidence_series, |
| 255 y_true = y_true[:len(df_pred)] | 963 nbinsx=20, |
| 256 else: | 964 marker=dict(color="#1f77b4", line=dict(color="#ffffff", width=1)), |
| 257 print("Warning: Original dataset referenced in description.json is unavailable.") | 965 opacity=0.8, |
| 258 except Exception as exc: # pragma: no cover - defensive | 966 histnorm="percent", |
| 259 print(f"Warning: Failed to recover labels from dataset: {exc}") | 967 ) |
| 260 | 968 ) |
| 261 if y_true is None or len(y_true) == 0: | 969 fig_conf.update_layout( |
| 262 print("Warning: Unable to locate ground-truth labels for ROC plot.") | 970 title=dict(text="Prediction Confidence Distribution", x=0.5), |
| 263 return None | 971 xaxis_title="Predicted probability (confidence)", |
| 264 | 972 yaxis_title="Percentage (%)", |
| 265 if len(y_true) != len(y_score): | 973 bargap=0.05, |
| 266 limit = min(len(y_true), len(y_score)) | |
| 267 if limit == 0: | |
| 268 return None | |
| 269 print(f"Warning: Aligning prediction and label lengths to {limit} samples for ROC plot.") | |
| 270 y_true = y_true[:limit] | |
| 271 y_score = y_score[:limit] | |
| 272 | |
| 273 # Get actual class names from probability column names | |
| 274 actual_classes = [col.replace('label_probabilities_', '') for col in prob_cols] | |
| 275 display_classes = class_labels if len(class_labels) == n_classes else actual_classes | |
| 276 | |
| 277 # Binarize the output following sklearn example | |
| 278 # Use actual class names if they're strings, otherwise use range | |
| 279 if isinstance(y_true[0], str): | |
| 280 y_test = label_binarize(y_true, classes=actual_classes) | |
| 281 else: | |
| 282 y_test = label_binarize(y_true, classes=list(range(n_classes))) | |
| 283 | |
| 284 # Handle binary classification case | |
| 285 if y_test.ndim != 2: | |
| 286 y_test = np.atleast_2d(y_test) | |
| 287 | |
| 288 if n_classes == 2: | |
| 289 if y_test.shape[1] == 1: | |
| 290 y_test = np.hstack([1 - y_test, y_test]) | |
| 291 elif y_test.shape[1] != 2: | |
| 292 print("Warning: Unexpected label binarization shape for binary ROC plot.") | |
| 293 return None | |
| 294 elif y_test.shape[1] != n_classes: | |
| 295 print("Warning: Label binarization did not produce expected class dimension; skipping ROC plot.") | |
| 296 return None | |
| 297 | |
| 298 # Compute ROC curve and ROC area for each class (following sklearn example) | |
| 299 fpr = dict() | |
| 300 tpr = dict() | |
| 301 roc_auc = dict() | |
| 302 | |
| 303 for i in range(n_classes): | |
| 304 if np.sum(y_test[:, i]) > 0: # Check if class exists in test set | |
| 305 fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i]) | |
| 306 roc_auc[i] = auc(fpr[i], tpr[i]) | |
| 307 | |
| 308 # Compute micro-average ROC curve and ROC area (sklearn example) | |
| 309 fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel()) | |
| 310 roc_auc["micro"] = auc(fpr["micro"], tpr["micro"]) | |
| 311 | |
| 312 # Create ROC curve plot | |
| 313 fig_roc = go.Figure() | |
| 314 | |
| 315 # Colors for different classes | |
| 316 colors = [ | |
| 317 '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', | |
| 318 '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf' | |
| 319 ] | |
| 320 | |
| 321 # Plot micro-average ROC curve first (most important) | |
| 322 fig_roc.add_trace(go.Scatter( | |
| 323 x=fpr["micro"], | |
| 324 y=tpr["micro"], | |
| 325 mode='lines', | |
| 326 name=f'Micro-average ROC (AUC = {roc_auc["micro"]:.3f})', | |
| 327 line=dict(color='deeppink', width=3, dash='dot'), | |
| 328 hovertemplate=('<b>Micro-average ROC</b><br>' | |
| 329 'FPR: %{x:.3f}<br>' | |
| 330 'TPR: %{y:.3f}<br>' | |
| 331 f'AUC: {roc_auc["micro"]:.3f}<extra></extra>') | |
| 332 )) | |
| 333 | |
| 334 # Plot ROC curve for each class | |
| 335 for i in range(n_classes): | |
| 336 if i in roc_auc: # Only plot if class exists in test set | |
| 337 class_name = display_classes[i] if i < len(display_classes) else f"Class {i}" | |
| 338 color = colors[i % len(colors)] | |
| 339 | |
| 340 fig_roc.add_trace(go.Scatter( | |
| 341 x=fpr[i], | |
| 342 y=tpr[i], | |
| 343 mode='lines', | |
| 344 name=f'{class_name} (AUC = {roc_auc[i]:.3f})', | |
| 345 line=dict(color=color, width=2), | |
| 346 hovertemplate=(f'<b>{class_name}</b><br>' | |
| 347 'FPR: %{x:.3f}<br>' | |
| 348 'TPR: %{y:.3f}<br>' | |
| 349 f'AUC: {roc_auc[i]:.3f}<extra></extra>') | |
| 350 )) | |
| 351 | |
| 352 # Add diagonal line (random classifier) | |
| 353 fig_roc.add_trace(go.Scatter( | |
| 354 x=[0, 1], | |
| 355 y=[0, 1], | |
| 356 mode='lines', | |
| 357 name='Random Classifier', | |
| 358 line=dict(color='gray', width=1, dash='dash'), | |
| 359 hovertemplate='Random Classifier<br>AUC = 0.500<extra></extra>' | |
| 360 )) | |
| 361 | |
| 362 # Calculate macro-average AUC | |
| 363 class_aucs = [roc_auc[i] for i in range(n_classes) if i in roc_auc] | |
| 364 if class_aucs: | |
| 365 macro_auc = np.mean(class_aucs) | |
| 366 title_text = f"ROC Curves (Micro-avg = {roc_auc['micro']:.3f}, Macro-avg = {macro_auc:.3f})" | |
| 367 else: | |
| 368 title_text = f"ROC Curves (Micro-avg = {roc_auc['micro']:.3f})" | |
| 369 | |
| 370 fig_roc.update_layout( | |
| 371 title=dict(text=title_text, x=0.5), | |
| 372 xaxis_title="False Positive Rate", | |
| 373 yaxis_title="True Positive Rate", | |
| 374 width=700, | 974 width=700, |
| 375 height=600, | 975 height=500, |
| 376 margin=dict(t=80, l=80, r=80, b=80), | 976 ) |
| 377 legend=dict( | 977 _style_fig(fig_conf) |
| 378 x=0.6, | 978 plots.append({ |
| 379 y=0.1, | 979 "title": "Prediction Confidence Distribution", |
| 380 bgcolor="rgba(255,255,255,0.9)", | 980 "html": pio.to_html(fig_conf, full_html=False, include_plotlyjs=False), |
| 381 bordercolor="rgba(0,0,0,0.2)", | 981 }) |
| 382 borderwidth=1 | 982 |
| 383 ), | 983 # The remaining plots require true labels and a positive-class probability |
| 384 hovermode='closest' | 984 if labels_series is None or pos_prob_series is None: |
| 385 ) | 985 return plots |
| 386 | 986 |
| 387 # Set equal aspect ratio and proper range | 987 # Align lengths |
| 388 fig_roc.update_xaxes(range=[0, 1.0]) | 988 min_len = min(len(labels_series), len(pos_prob_series)) |
| 389 fig_roc.update_yaxes(range=[0, 1.05]) | 989 if min_len == 0: |
| 390 | 990 return plots |
| 391 return { | 991 y_true_raw = labels_series.iloc[:min_len] |
| 392 "title": "ROC-AUC Curves", | 992 y_score = np.array(pos_prob_series.iloc[:min_len], dtype=float) |
| 393 "html": pio.to_html( | 993 |
| 394 fig_roc, | 994 # Determine positive label |
| 395 full_html=False, | 995 unique_labels = pd.unique(y_true_raw) |
| 396 include_plotlyjs=False, | 996 unique_labels_list = list(unique_labels) |
| 397 config=config | 997 positive_label = None |
| 398 ) | 998 if pos_label_hint and str(pos_label_hint) in [str(u) for u in unique_labels_list]: |
| 399 } | 999 positive_label = pos_label_hint |
| 400 | 1000 elif len(unique_labels_list) == 2: |
| 401 except Exception as e: | 1001 positive_label = unique_labels_list[1] |
| 402 print(f"Error building ROC-AUC plot: {e}") | 1002 else: |
| 403 return None | 1003 positive_label = unique_labels_list[0] |
| 1004 | |
| 1005 y_true = (y_true_raw == positive_label).astype(int).values | |
| 1006 | |
| 1007 # Plot 2: Calibration Curve | |
| 1008 bins = np.linspace(0.0, 1.0, 11) | |
| 1009 bin_ids = np.digitize(y_score, bins, right=True) | |
| 1010 bin_centers = [] | |
| 1011 frac_positives = [] | |
| 1012 for b in range(1, len(bins)): | |
| 1013 mask = bin_ids == b | |
| 1014 if not np.any(mask): | |
| 1015 continue | |
| 1016 bin_centers.append(y_score[mask].mean()) | |
| 1017 frac_positives.append(y_true[mask].mean()) | |
| 1018 if bin_centers and frac_positives: | |
| 1019 fig_cal = go.Figure() | |
| 1020 fig_cal.add_trace( | |
| 1021 go.Scatter( | |
| 1022 x=bin_centers, | |
| 1023 y=frac_positives, | |
| 1024 mode="lines+markers", | |
| 1025 name="Calibration", | |
| 1026 line=dict(color="#2ca02c", width=4), | |
| 1027 ) | |
| 1028 ) | |
| 1029 fig_cal.add_trace( | |
| 1030 go.Scatter( | |
| 1031 x=[0, 1], | |
| 1032 y=[0, 1], | |
| 1033 mode="lines", | |
| 1034 name="Perfect Calibration", | |
| 1035 line=dict(color="gray", width=2, dash="dash"), | |
| 1036 ) | |
| 1037 ) | |
| 1038 fig_cal.update_layout( | |
| 1039 title=dict(text="Calibration Curve", x=0.5), | |
| 1040 xaxis_title="Predicted probability", | |
| 1041 yaxis_title="Observed frequency", | |
| 1042 width=700, | |
| 1043 height=500, | |
| 1044 ) | |
| 1045 _style_fig(fig_cal) | |
| 1046 plots.append({ | |
| 1047 "title": "Calibration Curve (Predicted Probability vs Observed Frequency)", | |
| 1048 "html": pio.to_html(fig_cal, full_html=False, include_plotlyjs=False), | |
| 1049 }) | |
| 1050 | |
| 1051 # Plot 3: Threshold vs Metrics | |
| 1052 thresholds = np.linspace(0.0, 1.0, 21) | |
| 1053 accs, f1s, sens, specs = [], [], [], [] | |
| 1054 for t in thresholds: | |
| 1055 y_pred = (y_score >= t).astype(int) | |
| 1056 tp = np.sum((y_true == 1) & (y_pred == 1)) | |
| 1057 tn = np.sum((y_true == 0) & (y_pred == 0)) | |
| 1058 fp = np.sum((y_true == 0) & (y_pred == 1)) | |
| 1059 fn = np.sum((y_true == 1) & (y_pred == 0)) | |
| 1060 acc = (tp + tn) / max(len(y_true), 1) | |
| 1061 prec = tp / max(tp + fp, 1e-9) | |
| 1062 rec = tp / max(tp + fn, 1e-9) | |
| 1063 f1 = 0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec) | |
| 1064 sensitivity = rec | |
| 1065 specificity = tn / max(tn + fp, 1e-9) | |
| 1066 accs.append(acc) | |
| 1067 f1s.append(f1) | |
| 1068 sens.append(sensitivity) | |
| 1069 specs.append(specificity) | |
| 1070 | |
| 1071 fig_thresh = go.Figure() | |
| 1072 fig_thresh.add_trace(go.Scatter(x=thresholds, y=accs, mode="lines", name="Accuracy", line=dict(width=4))) | |
| 1073 fig_thresh.add_trace(go.Scatter(x=thresholds, y=f1s, mode="lines", name="F1", line=dict(width=4))) | |
| 1074 fig_thresh.add_trace(go.Scatter(x=thresholds, y=sens, mode="lines", name="Sensitivity", line=dict(width=4))) | |
| 1075 fig_thresh.add_trace(go.Scatter(x=thresholds, y=specs, mode="lines", name="Specificity", line=dict(width=4))) | |
| 1076 fig_thresh.update_layout( | |
| 1077 title=dict(text="Threshold Sweep: Accuracy, F1, Sensitivity, Specificity", x=0.5), | |
| 1078 xaxis_title="Decision threshold", | |
| 1079 yaxis_title="Metric value", | |
| 1080 width=700, | |
| 1081 height=500, | |
| 1082 legend=dict( | |
| 1083 x=0.7, | |
| 1084 y=0.2, | |
| 1085 bgcolor="rgba(255,255,255,0.9)", | |
| 1086 bordercolor="rgba(0,0,0,0.2)", | |
| 1087 borderwidth=1, | |
| 1088 ), | |
| 1089 shapes=[ | |
| 1090 dict( | |
| 1091 type="line", | |
| 1092 x0=threshold, | |
| 1093 x1=threshold, | |
| 1094 y0=0, | |
| 1095 y1=1, | |
| 1096 xref="x", | |
| 1097 yref="paper", | |
| 1098 line=dict(color="#d62728", width=2, dash="dash"), | |
| 1099 ) | |
| 1100 ] if isinstance(threshold, (int, float)) else [], | |
| 1101 annotations=[ | |
| 1102 dict( | |
| 1103 x=threshold, | |
| 1104 y=1.02, | |
| 1105 xref="x", | |
| 1106 yref="paper", | |
| 1107 showarrow=False, | |
| 1108 text=f"Threshold = {threshold:.2f}", | |
| 1109 font=dict(size=11, color="#d62728"), | |
| 1110 ) | |
| 1111 ] if isinstance(threshold, (int, float)) else [], | |
| 1112 ) | |
| 1113 _style_fig(fig_thresh) | |
| 1114 plots.append({ | |
| 1115 "title": "Threshold Sweep: Accuracy, F1, Sensitivity, Specificity", | |
| 1116 "html": pio.to_html(fig_thresh, full_html=False, include_plotlyjs=False), | |
| 1117 }) | |
| 1118 | |
| 1119 return plots |
