comparison plotly_plots.py @ 17:db9be962dc13 draft
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
| author | goeckslab |
|---|---|
| date | Wed, 10 Dec 2025 00:24:13 +0000 |
| parents | d17e3a1b8659 |
| children | |
| 16:8729f69e9207 | 17:db9be962dc13 |
|---|---|
| 5 import numpy as np | 5 import numpy as np |
| 6 import pandas as pd | 6 import pandas as pd |
| 7 import plotly.graph_objects as go | 7 import plotly.graph_objects as go |
| 8 import plotly.io as pio | 8 import plotly.io as pio |
| 9 from constants import LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME | 9 from constants import LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME |
| 10 from sklearn.metrics import ( | |
| 11 accuracy_score, | |
| 12 auc, | |
| 13 average_precision_score, | |
| 14 f1_score, | |
| 15 precision_recall_curve, | |
| 16 precision_score, | |
| 17 recall_score, | |
| 18 roc_curve, | |
| 19 ) | |
| 20 from sklearn.preprocessing import label_binarize | |
| 10 | 21 |
| 11 | 22 |
| 12 def _style_fig(fig: go.Figure, font_size: int = 12) -> go.Figure: | 23 def _style_fig(fig: go.Figure, font_size: int = 12) -> go.Figure: |
| 13 """Apply consistent styling across Plotly figures.""" | 24 """Apply consistent styling across Plotly figures.""" |
| 14 fig.update_layout( | 25 fig.update_layout( |
| 16 plot_bgcolor="#ffffff", | 27 plot_bgcolor="#ffffff", |
| 17 paper_bgcolor="#ffffff", | 28 paper_bgcolor="#ffffff", |
| 18 ) | 29 ) |
| 19 fig.update_xaxes(gridcolor="#e8e8e8") | 30 fig.update_xaxes(gridcolor="#e8e8e8") |
| 20 fig.update_yaxes(gridcolor="#e8e8e8") | 31 fig.update_yaxes(gridcolor="#e8e8e8") |
| 32 return fig | |
| 33 | |
| 34 | |
| 35 def _fig_to_html( | |
| 36 fig: go.Figure, *, include_js: bool = False, config: Optional[dict] = None | |
| 37 ) -> str: | |
| 38 """Render a Plotly figure to a lightweight HTML fragment.""" | |
| 39 include_plotlyjs = "cdn" if include_js else False | |
| 40 return pio.to_html( | |
| 41 fig, | |
| 42 full_html=False, | |
| 43 include_plotlyjs=include_plotlyjs, | |
| 44 config=config, | |
| 45 ) | |
| 46 | |
| 47 | |
| 48 def _wrap_plot( | |
| 49 title: str, | |
| 50 fig: go.Figure, | |
| 51 *, | |
| 52 include_js: bool = False, | |
| 53 config: Optional[dict] = None, | |
| 54 ) -> Dict[str, str]: | |
| 55 """Package a figure with its title for downstream HTML rendering.""" | |
| 56 return {"title": title, "html": _fig_to_html(fig, include_js=include_js, config=config)} | |
| 57 | |
| 58 | |
| 59 def _line_chart( | |
| 60 traces: List[tuple], | |
| 61 *, | |
| 62 title: str, | |
| 63 yaxis_title: str, | |
| 64 ) -> go.Figure: | |
| 65 """Build a basic epoch-indexed line chart for train/val/test curves.""" | |
| 66 fig = go.Figure() | |
| 67 for name, series in traces: | |
| 68 if not series: | |
| 69 continue | |
| 70 epochs = list(range(1, len(series) + 1)) | |
| 71 fig.add_trace( | |
| 72 go.Scatter( | |
| 73 x=epochs, | |
| 74 y=series, | |
| 75 mode="lines+markers", | |
| 76 name=name, | |
| 77 line=dict(width=4), | |
| 78 ) | |
| 79 ) | |
| 80 | |
| 81 fig.update_layout( | |
| 82 title=dict(text=title, x=0.5), | |
| 83 xaxis_title="Epoch", | |
| 84 yaxis_title=yaxis_title, | |
| 85 width=760, | |
| 86 height=520, | |
| 87 hovermode="x unified", | |
| 88 ) | |
| 89 _style_fig(fig) | |
| 21 return fig | 90 return fig |
| 22 | 91 |
| 23 | 92 |
| 24 def _labels_from_metadata_dict(meta_dict: dict) -> List[str]: | 93 def _labels_from_metadata_dict(meta_dict: dict) -> List[str]: |
| 25 """Extract ordered label names from Ludwig train_set_metadata.""" | 94 """Extract ordered label names from Ludwig train_set_metadata.""" |
| 104 def build_classification_plots( | 173 def build_classification_plots( |
| 105 test_stats_path: str, | 174 test_stats_path: str, |
| 106 training_stats_path: Optional[str] = None, | 175 training_stats_path: Optional[str] = None, |
| 107 metadata_csv_path: Optional[str] = None, | 176 metadata_csv_path: Optional[str] = None, |
| 108 train_set_metadata_path: Optional[str] = None, | 177 train_set_metadata_path: Optional[str] = None, |
| 178 threshold: Optional[float] = None, | |
| 109 ) -> List[Dict[str, str]]: | 179 ) -> List[Dict[str, str]]: |
| 110 """ | 180 """ |
| 111 Read Ludwig’s test_statistics.json and build three interactive Plotly panels: | 181 Read Ludwig’s test_statistics.json and build three interactive Plotly panels: |
| 112 - Confusion Matrix | 182 - Confusion Matrix |
| 113 - ROC-AUC | 183 - ROC-AUC |
| 154 showscale=True, | 224 showscale=True, |
| 155 colorbar=dict(title="Count"), | 225 colorbar=dict(title="Count"), |
| 156 ) | 226 ) |
| 157 ) | 227 ) |
| 158 fig_cm.update_traces(xgap=2, ygap=2) | 228 fig_cm.update_traces(xgap=2, ygap=2) |
| 229 cm_title = "Confusion Matrix" | |
| 230 if threshold is not None: | |
| 231 cm_title = f"Confusion Matrix (Threshold: {threshold})" | |
| 159 fig_cm.update_layout( | 232 fig_cm.update_layout( |
| 160 title=dict(text="Confusion Matrix", x=0.5), | 233 title=dict(text=cm_title, x=0.5), |
| 161 xaxis_title="Predicted", | 234 xaxis_title="Predicted", |
| 162 yaxis_title="Observed", | 235 yaxis_title="Observed", |
| 163 yaxis_autorange="reversed", | 236 yaxis_autorange="reversed", |
| 164 width=side_px, | 237 width=side_px, |
| 165 height=side_px, | 238 height=side_px, |
| 194 xanchor="center", | 267 xanchor="center", |
| 195 yanchor="top", | 268 yanchor="top", |
| 196 yshift=-2, | 269 yshift=-2, |
| 197 ) | 270 ) |
| 198 | 271 |
| 199 plots.append({ | 272 plots.append( |
| 200 "title": "Confusion Matrix", | 273 _wrap_plot("Confusion Matrix", fig_cm, include_js=True, config=common_cfg) |
| 201 "html": pio.to_html( | 274 ) |
| 202 fig_cm, | 275 |
| 203 full_html=False, | 276 # 1) ROC / PR curves only for binary tasks |
| 204 include_plotlyjs="cdn", | 277 if n_classes == 2: |
| 205 config=common_cfg | 278 roc_plot = _build_static_roc_plot(label_stats, common_cfg, friendly_labels=labels) |
| 206 ) | 279 if roc_plot: |
| 207 }) | 280 plots.append(roc_plot) |
| 208 | 281 |
| 209 # 1) ROC Curve (from test_statistics) | 282 pr_plot = _build_precision_recall_plot(label_stats, common_cfg) |
| 210 roc_plot = _build_static_roc_plot(label_stats, common_cfg, friendly_labels=labels) | 283 if pr_plot: |
| 211 if roc_plot: | 284 plots.append(pr_plot) |
| 212 plots.append(roc_plot) | |
| 213 | |
| 214 # 2) Precision-Recall Curve (from test_statistics) | |
| 215 pr_plot = _build_precision_recall_plot(label_stats, common_cfg) | |
| 216 if pr_plot: | |
| 217 plots.append(pr_plot) | |
| 218 | 285 |
| 219 # 2) Classification Report Heatmap | 286 # 2) Classification Report Heatmap |
| 220 pcs = label_stats.get("per_class_stats", {}) | 287 pcs = label_stats.get("per_class_stats", {}) |
| 221 if pcs: | 288 if pcs: |
| 222 classes = list(pcs.keys()) | 289 classes = list(pcs.keys()) |
| 257 width=side_px, | 324 width=side_px, |
| 258 height=side_px, | 325 height=side_px, |
| 259 margin=dict(t=80, l=80, r=80, b=80), | 326 margin=dict(t=80, l=80, r=80, b=80), |
| 260 ) | 327 ) |
| 261 _style_fig(fig_cr) | 328 _style_fig(fig_cr) |
| 262 plots.append({ | 329 plots.append( |
| 263 "title": "Per-Class metrics", | 330 _wrap_plot("Per-Class metrics", fig_cr, config=common_cfg) |
| 264 "html": pio.to_html( | 331 ) |
| 265 fig_cr, | |
| 266 full_html=False, | |
| 267 include_plotlyjs=False, | |
| 268 config=common_cfg | |
| 269 ) | |
| 270 }) | |
| 271 | 332 |
| 272 # 3) Prediction Diagnostics (from predictions.csv) | 333 # 3) Prediction Diagnostics (from predictions.csv) |
| 273 # Note: appended separately in generate_html_report, not returned here. | 334 # Note: appended separately in generate_html_report, not returned here. |
| 274 | 335 |
| 275 return plots | 336 return plots |
| 292 return [] | 353 return [] |
| 293 plots: List[Dict[str, str]] = [] | 354 plots: List[Dict[str, str]] = [] |
| 294 include_js = True # Load Plotly.js once for this group | 355 include_js = True # Load Plotly.js once for this group |
| 295 | 356 |
| 296 def _get_series(stats: dict, metric: str) -> List[float]: | 357 def _get_series(stats: dict, metric: str) -> List[float]: |
| 297 if metric not in stats: | |
| 298 return [] | |
| 299 vals = stats.get(metric, []) | 358 vals = stats.get(metric, []) |
| 300 if isinstance(vals, list): | 359 if isinstance(vals, list): |
| 301 return [float(v) for v in vals] | 360 return [float(v) for v in vals] |
| 302 try: | 361 try: |
| 303 return [float(vals)] | 362 return [float(vals)] |
| 304 except Exception: | 363 except Exception: |
| 305 return [] | 364 return [] |
| 306 | 365 |
| 307 def _line_plot(metric_key: str, title: str, yaxis_title: str) -> Optional[Dict[str, str]]: | 366 metric_specs = [ |
| 308 train_series = _get_series(label_train, metric_key) | 367 ("loss", "Loss across epochs", "Loss"), |
| 309 val_series = _get_series(label_val, metric_key) | 368 ("accuracy", "Accuracy across epochs", "Accuracy"), |
| 369 ("roc_auc", "ROC-AUC across epochs", "ROC-AUC"), | |
| 370 ("precision", "Precision across epochs", "Precision"), | |
| 371 ("recall", "Recall/Sensitivity across epochs", "Recall"), | |
| 372 ("specificity", "Specificity across epochs", "Specificity"), | |
| 373 ] | |
| 374 | |
| 375 for key, title, yaxis in metric_specs: | |
| 376 train_series = _get_series(label_train, key) | |
| 377 val_series = _get_series(label_val, key) | |
| 310 if not train_series and not val_series: | 378 if not train_series and not val_series: |
| 311 return None | 379 continue |
| 312 epochs_train = list(range(1, len(train_series) + 1)) | 380 fig = _line_chart( |
| 313 epochs_val = list(range(1, len(val_series) + 1)) | 381 [("Train", train_series), ("Validation", val_series)], |
| 314 fig = go.Figure() | 382 title=title, |
| 315 if train_series: | 383 yaxis_title=yaxis, |
| 316 fig.add_trace( | 384 ) |
| 317 go.Scatter( | 385 plots.append(_wrap_plot(title, fig, include_js=include_js)) |
| 318 x=epochs_train, | 386 include_js = False |
| 319 y=train_series, | |
| 320 mode="lines+markers", | |
| 321 name="Train", | |
| 322 line=dict(width=4), | |
| 323 ) | |
| 324 ) | |
| 325 if val_series: | |
| 326 fig.add_trace( | |
| 327 go.Scatter( | |
| 328 x=epochs_val, | |
| 329 y=val_series, | |
| 330 mode="lines+markers", | |
| 331 name="Validation", | |
| 332 line=dict(width=4), | |
| 333 ) | |
| 334 ) | |
| 335 fig.update_layout( | |
| 336 title=dict(text=title, x=0.5), | |
| 337 xaxis_title="Epoch", | |
| 338 yaxis_title=yaxis_title, | |
| 339 width=760, | |
| 340 height=520, | |
| 341 hovermode="x unified", | |
| 342 ) | |
| 343 _style_fig(fig) | |
| 344 return { | |
| 345 "title": title, | |
| 346 "html": pio.to_html( | |
| 347 fig, | |
| 348 full_html=False, | |
| 349 include_plotlyjs="cdn" if include_js else False, | |
| 350 ), | |
| 351 } | |
| 352 | |
| 353 # Core learning curves | |
| 354 for key, title in [ | |
| 355 ("roc_auc", "ROC-AUC across epochs"), | |
| 356 ("precision", "Precision across epochs"), | |
| 357 ("recall", "Recall/Sensitivity across epochs"), | |
| 358 ("specificity", "Specificity across epochs"), | |
| 359 ]: | |
| 360 plot = _line_plot(key, title, title.replace("Learning Curve", "").strip()) | |
| 361 if plot: | |
| 362 plots.append(plot) | |
| 363 include_js = False | |
| 364 | 387 |
| 365 # Precision vs Recall evolution (validation) | 388 # Precision vs Recall evolution (validation) |
| 366 val_prec = _get_series(label_val, "precision") | 389 val_prec = _get_series(label_val, "precision") |
| 367 val_rec = _get_series(label_val, "recall") | 390 val_rec = _get_series(label_val, "recall") |
| 368 if val_prec and val_rec: | 391 if val_prec and val_rec: |
| 369 epochs = list(range(1, min(len(val_prec), len(val_rec)) + 1)) | 392 max_len = min(len(val_prec), len(val_rec)) |
| 370 fig_pr = go.Figure() | 393 fig_pr = _line_chart( |
| 371 fig_pr.add_trace( | 394 [ |
| 372 go.Scatter( | 395 ("Precision", val_prec[:max_len]), |
| 373 x=epochs, | 396 ("Recall", val_rec[:max_len]), |
| 374 y=val_prec[: len(epochs)], | 397 ], |
| 375 mode="lines+markers", | 398 title="Validation Precision and Recall by Epoch", |
| 376 name="Precision", | |
| 377 ) | |
| 378 ) | |
| 379 fig_pr.add_trace( | |
| 380 go.Scatter( | |
| 381 x=epochs, | |
| 382 y=val_rec[: len(epochs)], | |
| 383 mode="lines+markers", | |
| 384 name="Recall", | |
| 385 ) | |
| 386 ) | |
| 387 fig_pr.update_layout( | |
| 388 title=dict(text="Validation Precision and Recall by Epoch", x=0.5), | |
| 389 xaxis_title="Epoch", | |
| 390 yaxis_title="Value", | 399 yaxis_title="Value", |
| 391 width=760, | 400 ) |
| 392 height=520, | 401 plots.append(_wrap_plot("Precision vs Recall Evolution", fig_pr, include_js=include_js)) |
| 393 hovermode="x unified", | |
| 394 ) | |
| 395 _style_fig(fig_pr) | |
| 396 plots.append({ | |
| 397 "title": "Precision vs Recall Evolution", | |
| 398 "html": pio.to_html( | |
| 399 fig_pr, | |
| 400 full_html=False, | |
| 401 include_plotlyjs="cdn" if include_js else False, | |
| 402 ), | |
| 403 }) | |
| 404 include_js = False | 402 include_js = False |
| 405 | 403 |
| 406 # F1-score derived | |
| 407 def _compute_f1(p: List[float], r: List[float]) -> List[float]: | 404 def _compute_f1(p: List[float], r: List[float]) -> List[float]: |
| 408 f1_vals = [] | 405 return [ |
| 409 for prec, rec in zip(p, r): | 406 0.0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec) |
| 410 if (prec + rec) == 0: | 407 for prec, rec in zip(p, r) |
| 411 f1_vals.append(0.0) | 408 ] |
| 412 else: | |
| 413 f1_vals.append(2 * prec * rec / (prec + rec)) | |
| 414 return f1_vals | |
| 415 | 409 |
| 416 f1_train = _compute_f1(_get_series(label_train, "precision"), _get_series(label_train, "recall")) | 410 f1_train = _compute_f1(_get_series(label_train, "precision"), _get_series(label_train, "recall")) |
| 417 f1_val = _compute_f1(val_prec, val_rec) | 411 f1_val = _compute_f1(val_prec, val_rec) |
| 418 if f1_train or f1_val: | 412 if f1_train or f1_val: |
| 419 fig = go.Figure() | 413 fig_f1 = _line_chart( |
| 420 if f1_train: | 414 [("Train", f1_train), ("Validation", f1_val)], |
| 421 fig.add_trace(go.Scatter(x=list(range(1, len(f1_train) + 1)), y=f1_train, mode="lines+markers", name="Train", line=dict(width=4))) | 415 title="F1-Score across epochs (derived)", |
| 422 if f1_val: | |
| 423 fig.add_trace(go.Scatter(x=list(range(1, len(f1_val) + 1)), y=f1_val, mode="lines+markers", name="Validation", line=dict(width=4))) | |
| 424 fig.update_layout( | |
| 425 title=dict(text="F1-Score across epochs (derived)", x=0.5), | |
| 426 xaxis_title="Epoch", | |
| 427 yaxis_title="F1-Score", | 416 yaxis_title="F1-Score", |
| 428 width=760, | 417 ) |
| 429 height=520, | 418 plots.append(_wrap_plot("F1-Score across epochs (derived)", fig_f1, include_js=include_js)) |
| 430 hovermode="x unified", | |
| 431 ) | |
| 432 _style_fig(fig) | |
| 433 plots.append({ | |
| 434 "title": "F1-Score across epochs (derived)", | |
| 435 "html": pio.to_html( | |
| 436 fig, | |
| 437 full_html=False, | |
| 438 include_plotlyjs="cdn" if include_js else False, | |
| 439 ), | |
| 440 }) | |
| 441 include_js = False | 419 include_js = False |
| 442 | 420 |
| 443 # Overfitting Gap: Train vs Val ROC-AUC (gap) | 421 # Overfitting Gap: Train vs Val ROC-AUC (gap) |
| 444 roc_train = _get_series(label_train, "roc_auc") | 422 roc_train = _get_series(label_train, "roc_auc") |
| 445 roc_val = _get_series(label_val, "roc_auc") | 423 roc_val = _get_series(label_val, "roc_auc") |
| 446 if roc_train and roc_val: | 424 if roc_train and roc_val: |
| 447 epochs_gap = list(range(1, min(len(roc_train), len(roc_val)) + 1)) | 425 max_len = min(len(roc_train), len(roc_val)) |
| 448 gaps = [t - v for t, v in zip(roc_train[:len(epochs_gap)], roc_val[:len(epochs_gap)])] | 426 gaps = [t - v for t, v in zip(roc_train[:max_len], roc_val[:max_len])] |
| 449 fig_gap = go.Figure() | 427 fig_gap = _line_chart( |
| 450 fig_gap.add_trace(go.Scatter(x=epochs_gap, y=gaps, mode="lines+markers", name="Train - Val ROC-AUC", line=dict(width=4))) | 428 [("Train - Val ROC-AUC", gaps)], |
| 451 fig_gap.update_layout( | 429 title="Overfitting gap: ROC-AUC across epochs", |
| 452 title=dict(text="Overfitting gap: ROC-AUC across epochs", x=0.5), | |
| 453 xaxis_title="Epoch", | |
| 454 yaxis_title="Gap", | 430 yaxis_title="Gap", |
| 455 width=760, | 431 ) |
| 456 height=520, | 432 plots.append(_wrap_plot("Overfitting gap: ROC-AUC across epochs", fig_gap, include_js=include_js)) |
| 457 hovermode="x unified", | |
| 458 ) | |
| 459 _style_fig(fig_gap) | |
| 460 plots.append({ | |
| 461 "title": "Overfitting gap: ROC-AUC across epochs", | |
| 462 "html": pio.to_html( | |
| 463 fig_gap, | |
| 464 full_html=False, | |
| 465 include_plotlyjs="cdn" if include_js else False, | |
| 466 ), | |
| 467 }) | |
| 468 include_js = False | 433 include_js = False |
| 469 | 434 |
| 470 # Best Epoch Dashboard (based on max val ROC-AUC) | 435 # Best Epoch Dashboard (based on max val ROC-AUC) |
| 471 if roc_val: | 436 if roc_val: |
| 472 best_idx = int(np.argmax(roc_val)) | 437 best_idx = int(np.argmax(roc_val)) |
| 473 best_epoch = best_idx + 1 | 438 best_epoch = best_idx + 1 |
| 474 spec_val = _get_series(label_val, "specificity") | 439 metrics_at_best: Dict[str, Optional[float]] = { |
| 475 metrics_at_best = { | 440 "ROC-AUC": roc_val[best_idx] if best_idx < len(roc_val) else None |
| 476 "ROC-AUC": roc_val[best_idx] if best_idx < len(roc_val) else None, | |
| 477 "Precision": val_prec[best_idx] if best_idx < len(val_prec) else None, | |
| 478 "Recall": val_rec[best_idx] if best_idx < len(val_rec) else None, | |
| 479 "Specificity": spec_val[best_idx] if best_idx < len(spec_val) else None, | |
| 480 "F1-Score": f1_val[best_idx] if best_idx < len(f1_val) else None, | |
| 481 } | 441 } |
| 442 | |
| 443 for metric_key, label in [ | |
| 444 ("accuracy", "Accuracy"), | |
| 445 ("balanced_accuracy", "Balanced Accuracy"), | |
| 446 ("precision", "Precision"), | |
| 447 ("recall", "Recall"), | |
| 448 ("specificity", "Specificity"), | |
| 449 ("loss", "Loss"), | |
| 450 ]: | |
| 451 series = _get_series(label_val, metric_key) | |
| 452 if series and best_idx < len(series): | |
| 453 metrics_at_best[label] = series[best_idx] | |
| 454 | |
| 455 if f1_val and best_idx < len(f1_val): | |
| 456 metrics_at_best["F1-Score (derived)"] = f1_val[best_idx] | |
| 457 | |
| 482 fig_best = go.Figure() | 458 fig_best = go.Figure() |
| 483 for name, value in metrics_at_best.items(): | 459 for name, value in metrics_at_best.items(): |
| 484 if value is not None: | 460 if value is not None: |
| 485 fig_best.add_trace(go.Bar(name=name, x=[name], y=[value])) | 461 fig_best.add_trace(go.Bar(name=name, x=[name], y=[value])) |
| 486 fig_best.update_layout( | 462 fig_best.update_layout( |
| 490 width=760, | 466 width=760, |
| 491 height=520, | 467 height=520, |
| 492 showlegend=False, | 468 showlegend=False, |
| 493 ) | 469 ) |
| 494 _style_fig(fig_best) | 470 _style_fig(fig_best) |
| 495 plots.append({ | 471 plots.append(_wrap_plot("Best Validation Epoch Snapshot (Metrics)", fig_best, include_js=include_js)) |
| 496 "title": "Best Validation Epoch Snapshot (Metrics)", | |
| 497 "html": pio.to_html( | |
| 498 fig_best, | |
| 499 full_html=False, | |
| 500 include_plotlyjs="cdn" if include_js else False, | |
| 501 ), | |
| 502 }) | |
| 503 include_js = False | |
| 504 | 472 |
| 505 return plots | 473 return plots |
| 506 | 474 |
| 507 | 475 |
| 508 def _get_regression_series(split_stats: dict, metric: str) -> List[float]: | 476 def _get_regression_series(split_stats: dict, metric: str) -> List[float]: |
| 527 ) -> Optional[Dict[str, str]]: | 495 ) -> Optional[Dict[str, str]]: |
| 528 train_series = _get_regression_series(train_split, metric_key) | 496 train_series = _get_regression_series(train_split, metric_key) |
| 529 val_series = _get_regression_series(val_split, metric_key) | 497 val_series = _get_regression_series(val_split, metric_key) |
| 530 if not train_series and not val_series: | 498 if not train_series and not val_series: |
| 531 return None | 499 return None |
| 532 epochs_train = list(range(1, len(train_series) + 1)) | 500 |
| 533 epochs_val = list(range(1, len(val_series) + 1)) | 501 fig = _line_chart( |
| 534 fig = go.Figure() | 502 [("Train", train_series), ("Validation", val_series)], |
| 535 if train_series: | 503 title=title, |
| 536 fig.add_trace( | |
| 537 go.Scatter( | |
| 538 x=epochs_train, | |
| 539 y=train_series, | |
| 540 mode="lines+markers", | |
| 541 name="Train", | |
| 542 line=dict(width=4), | |
| 543 ) | |
| 544 ) | |
| 545 if val_series: | |
| 546 fig.add_trace( | |
| 547 go.Scatter( | |
| 548 x=epochs_val, | |
| 549 y=val_series, | |
| 550 mode="lines+markers", | |
| 551 name="Validation", | |
| 552 line=dict(width=4), | |
| 553 ) | |
| 554 ) | |
| 555 fig.update_layout( | |
| 556 title=dict(text=title, x=0.5), | |
| 557 xaxis_title="Epoch", | |
| 558 yaxis_title=yaxis_title, | 504 yaxis_title=yaxis_title, |
| 559 width=760, | |
| 560 height=520, | |
| 561 hovermode="x unified", | |
| 562 ) | 505 ) |
| 563 _style_fig(fig) | 506 return _wrap_plot(title, fig, include_js=include_js) |
| 564 return { | |
| 565 "title": title, | |
| 566 "html": pio.to_html( | |
| 567 fig, | |
| 568 full_html=False, | |
| 569 include_plotlyjs="cdn" if include_js else False, | |
| 570 ), | |
| 571 } | |
| 572 | 507 |
| 573 | 508 |
| 574 def build_regression_train_val_plots(train_stats_path: str) -> List[Dict[str, str]]: | 509 def build_regression_train_val_plots(train_stats_path: str) -> List[Dict[str, str]]: |
| 575 """Generate regression Train/Validation learning curve plots from training_statistics.json.""" | 510 """Generate regression Train/Validation learning curve plots from training_statistics.json.""" |
| 576 if not train_stats_path or not Path(train_stats_path).exists(): | 511 if not train_stats_path or not Path(train_stats_path).exists(): |
| 625 ("root_mean_squared_error", "Root Mean Squared Error Across Epochs", "RMSE"), | 560 ("root_mean_squared_error", "Root Mean Squared Error Across Epochs", "RMSE"), |
| 626 ("mean_absolute_percentage_error", "Mean Absolute Percentage Error Across Epochs", "MAPE"), | 561 ("mean_absolute_percentage_error", "Mean Absolute Percentage Error Across Epochs", "MAPE"), |
| 627 ("r2", "R² Across Epochs", "R²"), | 562 ("r2", "R² Across Epochs", "R²"), |
| 628 ("loss", "Loss Across Epochs", "Loss"), | 563 ("loss", "Loss Across Epochs", "Loss"), |
| 629 ] | 564 ] |
| 630 epochs = None | |
| 631 for metric_key, title, ytitle in metrics: | 565 for metric_key, title, ytitle in metrics: |
| 632 series = _get_regression_series(label_test, metric_key) | 566 series = _get_regression_series(label_test, metric_key) |
| 633 if not series: | 567 if not series: |
| 634 continue | 568 continue |
| 635 if epochs is None: | 569 fig = _line_chart( |
| 636 epochs = list(range(1, len(series) + 1)) | 570 [("Test", series)], |
| 637 fig = go.Figure() | 571 title=title, |
| 638 fig.add_trace( | |
| 639 go.Scatter( | |
| 640 x=epochs, | |
| 641 y=series[: len(epochs)], | |
| 642 mode="lines+markers", | |
| 643 name="Test", | |
| 644 line=dict(width=4), | |
| 645 ) | |
| 646 ) | |
| 647 fig.update_layout( | |
| 648 title=dict(text=title, x=0.5), | |
| 649 xaxis_title="Epoch", | |
| 650 yaxis_title=ytitle, | 572 yaxis_title=ytitle, |
| 651 width=760, | 573 ) |
| 652 height=520, | 574 plots.append(_wrap_plot(title, fig, include_js=include_js)) |
| 653 hovermode="x unified", | |
| 654 ) | |
| 655 _style_fig(fig) | |
| 656 plots.append({ | |
| 657 "title": title, | |
| 658 "html": pio.to_html( | |
| 659 fig, | |
| 660 full_html=False, | |
| 661 include_plotlyjs="cdn" if include_js else False, | |
| 662 ), | |
| 663 }) | |
| 664 include_js = False | 575 include_js = False |
| 665 return plots | 576 return plots |
| 666 | 577 |
| 667 | 578 |
| 668 def _build_static_roc_plot( | 579 def _build_static_roc_plot( |
| 669 label_stats: dict, config: dict, friendly_labels: Optional[List[str]] = None | 580 label_stats: dict, |
| 581 config: dict, | |
| 582 friendly_labels: Optional[List[str]] = None, | |
| 583 threshold: Optional[float] = None, | |
| 670 ) -> Optional[Dict[str, str]]: | 584 ) -> Optional[Dict[str, str]]: |
| 671 """Build ROC curve directly from test_statistics.json (single curve).""" | 585 """Build ROC curve directly from test_statistics.json (single curve).""" |
| 672 roc_data = label_stats.get("roc_curve") | 586 roc_data = label_stats.get("roc_curve") |
| 673 if not isinstance(roc_data, dict): | 587 if not isinstance(roc_data, dict): |
| 674 return None | 588 return None |
| 774 ) | 688 ) |
| 775 _style_fig(fig) | 689 _style_fig(fig) |
| 776 fig.update_xaxes(range=[0, 1.0]) | 690 fig.update_xaxes(range=[0, 1.0]) |
| 777 fig.update_yaxes(range=[0, 1.05]) | 691 fig.update_yaxes(range=[0, 1.05]) |
| 778 | 692 |
| 693 roc_thresholds = roc_data.get("thresholds") | |
| 694 if threshold is not None and isinstance(roc_thresholds, list) and len(roc_thresholds) == len(fpr): | |
| 695 try: | |
| 696 diffs = [abs(th - threshold) for th in roc_thresholds] | |
| 697 best_idx = int(np.argmin(diffs)) | |
| 698 # dashed guides through the chosen point | |
| 699 fig.add_shape( | |
| 700 type="line", | |
| 701 x0=fpr[best_idx], | |
| 702 x1=fpr[best_idx], | |
| 703 y0=0, | |
| 704 y1=tpr[best_idx], | |
| 705 line=dict(color="gray", width=2, dash="dash"), | |
| 706 ) | |
| 707 fig.add_shape( | |
| 708 type="line", | |
| 709 x0=0, | |
| 710 x1=fpr[best_idx], | |
| 711 y0=tpr[best_idx], | |
| 712 y1=tpr[best_idx], | |
| 713 line=dict(color="gray", width=2, dash="dash"), | |
| 714 ) | |
| 715 fig.add_trace( | |
| 716 go.Scatter( | |
| 717 x=[fpr[best_idx]], | |
| 718 y=[tpr[best_idx]], | |
| 719 mode="markers", | |
| 720 marker=dict(color="black", size=10, symbol="x"), | |
| 721 name=f"Threshold={threshold}", | |
| 722 hovertemplate="FPR: %{x:.3f}<br>TPR: %{y:.3f}<br>Threshold: %{text}<extra></extra>", | |
| 723 text=[f"{threshold}"], | |
| 724 ) | |
| 725 ) | |
| 726 except Exception as exc: | |
| 727 print(f"Warning: could not add threshold marker to ROC: {exc}") | |
| 728 | |
| 779 fig.add_annotation( | 729 fig.add_annotation( |
| 780 x=0.5, | 730 x=0.5, |
| 781 y=-0.15, | 731 y=-0.15, |
| 782 xref="paper", | 732 xref="paper", |
| 783 yref="paper", | 733 yref="paper", |
| 784 showarrow=False, | 734 showarrow=False, |
| 785 text=f"<span style='font-size:12px;color:#555;'>{pos_label_txt}</span>", | 735 text=f"<span style='font-size:12px;color:#555;'>{pos_label_txt}</span>", |
| 786 xanchor="center", | 736 xanchor="center", |
| 787 ) | 737 ) |
| 788 | 738 |
| 789 return { | 739 return _wrap_plot("ROC Curve", fig, config=config) |
| 790 "title": "ROC Curve", | |
| 791 "html": pio.to_html( | |
| 792 fig, | |
| 793 full_html=False, | |
| 794 include_plotlyjs=False, | |
| 795 config=config, | |
| 796 ), | |
| 797 } | |
| 798 except Exception as e: | 740 except Exception as e: |
| 799 print(f"Error building ROC plot: {e}") | 741 print(f"Error building ROC plot: {e}") |
| 800 return None | 742 return None |
| 801 | 743 |
| 802 | 744 |
| 803 def _build_precision_recall_plot(label_stats: dict, config: dict) -> Optional[Dict[str, str]]: | 745 def _build_precision_recall_plot( |
| 746 label_stats: dict, | |
| 747 config: dict, | |
| 748 threshold: Optional[float] = None, | |
| 749 ) -> Optional[Dict[str, str]]: | |
| 804 """Build Precision-Recall curve directly from test_statistics.json.""" | 750 """Build Precision-Recall curve directly from test_statistics.json.""" |
| 805 pr_data = label_stats.get("precision_recall_curve") | 751 pr_data = label_stats.get("precision_recall_curve") |
| 806 if not isinstance(pr_data, dict): | 752 if not isinstance(pr_data, dict): |
| 807 return None | 753 return None |
| 808 | 754 |
| 809 precisions = pr_data.get("precisions") | 755 precisions = pr_data.get("precisions") |
| 810 recalls = pr_data.get("recalls") | 756 recalls = pr_data.get("recalls") |
| 811 if not precisions or not recalls or len(precisions) != len(recalls): | 757 if not precisions or not recalls or len(precisions) != len(recalls): |
| 812 return None | 758 return None |
| 759 | |
| 760 thresholds = pr_data.get("thresholds") | |
| 813 | 761 |
| 814 try: | 762 try: |
| 815 fig = go.Figure() | 763 fig = go.Figure() |
| 816 fig.add_trace( | 764 fig.add_trace( |
| 817 go.Scatter( | 765 go.Scatter( |
| 849 ) | 797 ) |
| 850 _style_fig(fig) | 798 _style_fig(fig) |
| 851 fig.update_xaxes(range=[0, 1.0]) | 799 fig.update_xaxes(range=[0, 1.0]) |
| 852 fig.update_yaxes(range=[0, 1.05]) | 800 fig.update_yaxes(range=[0, 1.05]) |
| 853 | 801 |
| 854 return { | 802 if threshold is not None and isinstance(thresholds, list) and len(thresholds) == len(recalls): |
| 855 "title": "Precision-Recall Curve", | 803 try: |
| 856 "html": pio.to_html( | 804 diffs = [abs(th - threshold) for th in thresholds] |
| 857 fig, | 805 best_idx = int(np.argmin(diffs)) |
| 858 full_html=False, | 806 fig.add_shape( |
| 859 include_plotlyjs=False, | 807 type="line", |
| 860 config=config, | 808 x0=recalls[best_idx], |
| 861 ), | 809 x1=recalls[best_idx], |
| 862 } | 810 y0=0, |
| 811 y1=precisions[best_idx], | |
| 812 line=dict(color="gray", width=2, dash="dash"), | |
| 813 ) | |
| 814 fig.add_shape( | |
| 815 type="line", | |
| 816 x0=0, | |
| 817 x1=recalls[best_idx], | |
| 818 y0=precisions[best_idx], | |
| 819 y1=precisions[best_idx], | |
| 820 line=dict(color="gray", width=2, dash="dash"), | |
| 821 ) | |
| 822 fig.add_trace( | |
| 823 go.Scatter( | |
| 824 x=[recalls[best_idx]], | |
| 825 y=[precisions[best_idx]], | |
| 826 mode="markers", | |
| 827 marker=dict(color="black", size=10, symbol="x"), | |
| 828 name=f"Threshold={threshold}", | |
| 829 hovertemplate="Recall: %{x:.3f}<br>Precision: %{y:.3f}<br>Threshold: %{text}<extra></extra>", | |
| 830 text=[f"{threshold}"], | |
| 831 ) | |
| 832 ) | |
| 833 except Exception as exc: | |
| 834 print(f"Warning: could not add threshold marker to PR: {exc}") | |
| 835 | |
| 836 return _wrap_plot("Precision-Recall Curve", fig, config=config) | |
| 863 except Exception as e: | 837 except Exception as e: |
| 864 print(f"Error building Precision-Recall plot: {e}") | 838 print(f"Error building Precision-Recall plot: {e}") |
| 865 return None | 839 return None |
| 866 | 840 |
| 867 | 841 |
| 868 def build_prediction_diagnostics( | 842 def build_prediction_diagnostics( |
| 869 predictions_path: str, | 843 predictions_path: str, |
| 870 label_data_path: Optional[str] = None, | 844 label_data_path: Optional[str] = None, |
| 871 split_value: int = 2, | 845 split_value: int = 2, |
| 872 threshold: Optional[float] = None, | |
| 873 ) -> List[Dict[str, str]]: | 846 ) -> List[Dict[str, str]]: |
| 874 """Generate diagnostic plots from predictions.csv for classification tasks.""" | 847 """Generate diagnostic plots from predictions.csv for classification tasks.""" |
| 875 preds_file = Path(predictions_path) | 848 preds_file = Path(predictions_path) |
| 876 if not preds_file.exists(): | 849 if not preds_file.exists(): |
| 877 return [] | 850 return [] |
| 881 except Exception as exc: | 854 except Exception as exc: |
| 882 print(f"Warning: Unable to read predictions CSV: {exc}") | 855 print(f"Warning: Unable to read predictions CSV: {exc}") |
| 883 return [] | 856 return [] |
| 884 | 857 |
| 885 plots: List[Dict[str, str]] = [] | 858 plots: List[Dict[str, str]] = [] |
| 859 labels_from_dataset: Optional[pd.Series] = None | |
| 860 | |
| 861 filtered_by_split = False | |
| 862 | |
| 863 # If a split column exists, focus on the requested split (e.g., validation=1, test=2). | |
| 864 # If not, but label_data_path is available and matches row count, use it to filter predictions. | |
| 865 if SPLIT_COLUMN_NAME in df_pred.columns: | |
| 866 df_pred = df_pred[df_pred[SPLIT_COLUMN_NAME] == split_value].reset_index(drop=True) | |
| 867 if df_pred.empty: | |
| 868 return [] | |
| 869 filtered_by_split = True | |
| 870 elif label_data_path and Path(label_data_path).exists(): | |
| 871 try: | |
| 872 df_labels_all = pd.read_csv(label_data_path) | |
| 873 if SPLIT_COLUMN_NAME in df_labels_all.columns and len(df_labels_all) == len(df_pred): | |
| 874 split_mask = pd.to_numeric(df_labels_all[SPLIT_COLUMN_NAME], errors="coerce") == split_value | |
| 875 labels_from_dataset = df_labels_all.loc[split_mask, LABEL_COLUMN_NAME].reset_index(drop=True) | |
| 876 df_pred = df_pred.loc[split_mask].reset_index(drop=True) | |
| 877 if df_pred.empty: | |
| 878 return [] | |
| 879 filtered_by_split = True | |
| 880 except Exception as exc: | |
| 881 print(f"Warning: Unable to filter predictions by split from label data: {exc}") | |
| 882 | |
| 883 # Fallback: no split info available. Assume the predictions file is already filtered | |
| 884 # (common for test-only exports) and avoid heuristic slicing that could discard rows. | |
| 885 if not filtered_by_split: | |
| 886 if split_value != 2: | |
| 887 return [] | |
| 888 | |
| 889 def _strip_prob_prefix(col: str) -> str: | |
| 890 if col.startswith("label_probabilities_"): | |
| 891 return col.replace("label_probabilities_", "") | |
| 892 if col.startswith("probabilities_"): | |
| 893 return col.replace("probabilities_", "") | |
| 894 return col | |
| 895 | |
| 896 def _maybe_expand_probabilities_column(df: pd.DataFrame, labels_guess: List[str]) -> List[str]: | |
| 897 """If only a single 'probabilities' column exists (list-like), expand it into per-class columns.""" | |
| 898 if "probabilities" not in df.columns: | |
| 899 return [] | |
| 900 try: | |
| 901 # Parse first non-null entry to infer length | |
| 902 first_val = df["probabilities"].dropna().iloc[0] | |
| 903 parsed = first_val | |
| 904 if isinstance(first_val, str): | |
| 905 parsed = json.loads(first_val) | |
| 906 probs = list(parsed) | |
| 907 n = len(probs) | |
| 908 if n == 0: | |
| 909 return [] | |
| 910 # Build labels: prefer provided guess; otherwise numeric | |
| 911 if labels_guess and len(labels_guess) == n: | |
| 912 labels_use = labels_guess | |
| 913 else: | |
| 914 labels_use = [str(i) for i in range(n)] | |
| 915 # Expand column | |
| 916 for idx, lbl in enumerate(labels_use): | |
| 917 df[f"probabilities_{lbl}"] = df["probabilities"].apply( | |
| 918 lambda v: (json.loads(v)[idx] if isinstance(v, str) else list(v)[idx]) if pd.notnull(v) else np.nan | |
| 919 ) | |
| 920 return [f"probabilities_{lbl}" for lbl in labels_use] | |
| 921 except Exception: | |
| 922 return [] | |
| 886 | 923 |
| 887 # Identify probability columns | 924 # Identify probability columns |
| 888 prob_cols = [ | 925 prob_cols = [ |
| 889 c for c in df_pred.columns | 926 c |
| 890 if c.startswith("label_probabilities_") and c != "label_probabilities" | 927 for c in df_pred.columns |
| 928 if ( | |
| 929 (c.startswith("label_probabilities_") or c.startswith("probabilities_")) | |
| 930 and c != "label_probabilities" | |
| 931 ) | |
| 891 ] | 932 ] |
| 933 if not prob_cols and "label_probability" in df_pred.columns: | |
| 934 prob_cols = ["label_probability"] | |
| 935 if not prob_cols and "probability" in df_pred.columns: | |
| 936 prob_cols = ["probability"] | |
| 937 if not prob_cols and "prediction_probability" in df_pred.columns: | |
| 938 prob_cols = ["prediction_probability"] | |
| 939 if not prob_cols and "probabilities" in df_pred.columns: | |
| 940 labels_guess = sorted([str(u) for u in pd.unique(df_pred[LABEL_COLUMN_NAME])]) | |
| 941 prob_cols = _maybe_expand_probabilities_column(df_pred, labels_guess) | |
| 892 prob_cols_sorted = sorted(prob_cols) | 942 prob_cols_sorted = sorted(prob_cols) |
| 893 | 943 |
| 894 def _select_positive_prob(): | 944 def _select_positive_prob(): |
| 895 if not prob_cols_sorted: | 945 if not prob_cols_sorted: |
| 896 return None, None | 946 return None, None |
| 897 # Prefer a column indicating positive/event/true/1 | 947 # Prefer a column indicating positive/event/true/1 |
| 898 preferred_keys = ("event", "true", "positive", "pos", "1") | 948 preferred_keys = ("event", "true", "positive", "pos", "1") |
| 899 for col in prob_cols_sorted: | 949 for col in prob_cols_sorted: |
| 900 suffix = col.replace("label_probabilities_", "").lower() | 950 suffix = _strip_prob_prefix(col).lower() |
| 901 if any(k in suffix for k in preferred_keys): | 951 if any(k in suffix for k in preferred_keys): |
| 902 return col, suffix | 952 return col, suffix |
| 903 if len(prob_cols_sorted) == 2: | 953 if len(prob_cols_sorted) == 2: |
| 904 col = prob_cols_sorted[1] | 954 col = prob_cols_sorted[1] |
| 905 return col, col.replace("label_probabilities_", "") | 955 return col, _strip_prob_prefix(col) |
| 906 col = prob_cols_sorted[0] | 956 col = prob_cols_sorted[0] |
| 907 return col, col.replace("label_probabilities_", "") | 957 return col, _strip_prob_prefix(col) |
| 908 | 958 |
| 909 pos_prob_col, pos_label_hint = _select_positive_prob() | 959 pos_prob_col, pos_label_hint = _select_positive_prob() |
| 910 pos_prob_series = df_pred[pos_prob_col] if pos_prob_col and pos_prob_col in df_pred else None | 960 pos_prob_series = df_pred[pos_prob_col] if pos_prob_col and pos_prob_col in df_pred else None |
| 911 | 961 |
| 912 # Confidence series: prefer label_probability, otherwise positive prob, otherwise max prob | 962 # Confidence series: prefer label_probability, otherwise positive prob, otherwise max prob |
| 918 elif prob_cols_sorted: | 968 elif prob_cols_sorted: |
| 919 confidence_series = df_pred[prob_cols_sorted].max(axis=1) | 969 confidence_series = df_pred[prob_cols_sorted].max(axis=1) |
| 920 | 970 |
| 921 # True labels | 971 # True labels |
| 922 def _extract_labels(): | 972 def _extract_labels(): |
| 973 if labels_from_dataset is not None: | |
| 974 return labels_from_dataset | |
| 923 candidates = [ | 975 candidates = [ |
| 924 LABEL_COLUMN_NAME, | 976 LABEL_COLUMN_NAME, |
| 925 f"{LABEL_COLUMN_NAME}_ground_truth", | 977 f"{LABEL_COLUMN_NAME}_ground_truth", |
| 926 f"{LABEL_COLUMN_NAME}__ground_truth", | 978 f"{LABEL_COLUMN_NAME}__ground_truth", |
| 927 f"{LABEL_COLUMN_NAME}_target", | 979 f"{LABEL_COLUMN_NAME}_target", |
| 973 bargap=0.05, | 1025 bargap=0.05, |
| 974 width=700, | 1026 width=700, |
| 975 height=500, | 1027 height=500, |
| 976 ) | 1028 ) |
| 977 _style_fig(fig_conf) | 1029 _style_fig(fig_conf) |
| 978 plots.append({ | 1030 plots.append(_wrap_plot("Prediction Confidence Distribution", fig_conf)) |
| 979 "title": "Prediction Confidence Distribution", | |
| 980 "html": pio.to_html(fig_conf, full_html=False, include_plotlyjs=False), | |
| 981 }) | |
| 982 | 1031 |
| 983 # The remaining plots require true labels and a positive-class probability | 1032 # The remaining plots require true labels and a positive-class probability |
| 984 if labels_series is None or pos_prob_series is None: | 1033 if labels_series is None or pos_prob_series is None: |
| 985 return plots | 1034 return plots |
| 986 | 1035 |
| 1002 else: | 1051 else: |
| 1003 positive_label = unique_labels_list[0] | 1052 positive_label = unique_labels_list[0] |
| 1004 | 1053 |
| 1005 y_true = (y_true_raw == positive_label).astype(int).values | 1054 y_true = (y_true_raw == positive_label).astype(int).values |
| 1006 | 1055 |
| 1007 # Plot 2: Calibration Curve | 1056 # Utility: compute calibration points |
| 1008 bins = np.linspace(0.0, 1.0, 11) | 1057 def _calibration_points(y_true_bin: np.ndarray, scores: np.ndarray): |
| 1009 bin_ids = np.digitize(y_score, bins, right=True) | 1058 bins = np.linspace(0.0, 1.0, 11) |
| 1010 bin_centers = [] | 1059 bin_ids = np.digitize(scores, bins, right=True) |
| 1011 frac_positives = [] | 1060 bin_centers, frac_positives = [], [] |
| 1012 for b in range(1, len(bins)): | 1061 for b in range(1, len(bins)): |
| 1013 mask = bin_ids == b | 1062 mask = bin_ids == b |
| 1014 if not np.any(mask): | 1063 if not np.any(mask): |
| 1015 continue | 1064 continue |
| 1016 bin_centers.append(y_score[mask].mean()) | 1065 bin_centers.append(scores[mask].mean()) |
| 1017 frac_positives.append(y_true[mask].mean()) | 1066 frac_positives.append(y_true_bin[mask].mean()) |
| 1018 if bin_centers and frac_positives: | 1067 return bin_centers, frac_positives |
| 1019 fig_cal = go.Figure() | 1068 |
| 1020 fig_cal.add_trace( | 1069 # Plot 2: Calibration Curve (multi-class aware; one-vs-rest per label) |
| 1070 label_prob_map = {} | |
| 1071 for col in prob_cols_sorted: | |
| 1072 if col.startswith("label_probabilities_"): | |
| 1073 cls = col.replace("label_probabilities_", "") | |
| 1074 label_prob_map[cls] = col | |
| 1075 | |
| 1076 unique_label_strs = [str(u) for u in unique_labels_list] | |
| 1077 if len(label_prob_map) > 1 and len(unique_label_strs) > 2: | |
| 1078 # Skip multi-class calibration curve for now (not informative in current report) | |
| 1079 pass | |
| 1080 else: | |
| 1081 # Binary/unknown fallback (previous behavior) | |
| 1082 bin_centers, frac_positives = _calibration_points(y_true, y_score) | |
| 1083 if bin_centers and frac_positives: | |
| 1084 fig_cal = go.Figure() | |
| 1085 fig_cal.add_trace( | |
| 1086 go.Scatter( | |
| 1087 x=bin_centers, | |
| 1088 y=frac_positives, | |
| 1089 mode="lines+markers", | |
| 1090 name="Calibration", | |
| 1091 line=dict(color="#2ca02c", width=4), | |
| 1092 ) | |
| 1093 ) | |
| 1094 fig_cal.add_trace( | |
| 1095 go.Scatter( | |
| 1096 x=[0, 1], | |
| 1097 y=[0, 1], | |
| 1098 mode="lines", | |
| 1099 name="Perfect Calibration", | |
| 1100 line=dict(color="gray", width=2, dash="dash"), | |
| 1101 ) | |
| 1102 ) | |
| 1103 fig_cal.update_layout( | |
| 1104 title=dict(text="Calibration Curve", x=0.5), | |
| 1105 xaxis_title="Predicted probability", | |
| 1106 yaxis_title="Observed frequency", | |
| 1107 width=700, | |
| 1108 height=500, | |
| 1109 ) | |
| 1110 _style_fig(fig_cal) | |
| 1111 plots.append( | |
| 1112 _wrap_plot( | |
| 1113 "Calibration Curve (Predicted Probability vs Observed Frequency)", | |
| 1114 fig_cal, | |
| 1115 ) | |
| 1116 ) | |
| 1117 | |
| 1118 return plots | |
| 1119 | |
| 1120 | |
| 1121 def build_binary_threshold_plot( | |
| 1122 predictions_path: str, | |
| 1123 label_data_path: Optional[str] = None, | |
| 1124 split_value: int = 1, | |
| 1125 ) -> Optional[Dict[str, str]]: | |
| 1126 """Build a binary threshold sweep plot (accuracy, precision, recall, F1) for a given split.""" | |
| 1127 preds_file = Path(predictions_path) | |
| 1128 if not preds_file.exists(): | |
| 1129 return None | |
| 1130 | |
| 1131 try: | |
| 1132 df_pred = pd.read_csv(predictions_path) | |
| 1133 except Exception as exc: | |
| 1134 print(f"Warning: Unable to read predictions CSV for threshold plot: {exc}") | |
| 1135 return None | |
| 1136 | |
| 1137 labels_from_dataset: Optional[pd.Series] = None | |
| 1138 df_full = df_pred.copy() | |
| 1139 | |
| 1140 def _filter_by_split(df: pd.DataFrame, split_val: int) -> pd.DataFrame: | |
| 1141 if SPLIT_COLUMN_NAME in df.columns: | |
| 1142 return df[df[SPLIT_COLUMN_NAME] == split_val].reset_index(drop=True) | |
| 1143 return df | |
| 1144 | |
| 1145 # Try preferred split, then fallback to others with data (val -> test -> train) | |
| 1146 candidate_splits = [split_value, 2, 0, 1] if split_value == 1 else [split_value, 1, 0, 2] | |
| 1147 df_candidate = pd.DataFrame() | |
| 1148 used_split: Optional[int] = None | |
| 1149 for sv in candidate_splits: | |
| 1150 df_candidate = _filter_by_split(df_full, sv) | |
| 1151 if not df_candidate.empty: | |
| 1152 used_split = sv | |
| 1153 break | |
| 1154 if used_split is None: | |
| 1155 df_candidate = df_full | |
| 1156 df_pred = df_candidate.reset_index(drop=True) | |
| 1157 | |
| 1158 # If still empty (e.g., split column exists but no rows for candidates), fall back to all rows | |
| 1159 if df_pred.empty: | |
| 1160 df_pred = df_full.reset_index(drop=True) | |
| 1161 labels_from_dataset = None | |
| 1162 | |
| 1163 if label_data_path and Path(label_data_path).exists(): | |
| 1164 try: | |
| 1165 df_labels_all = pd.read_csv(label_data_path) | |
| 1166 if SPLIT_COLUMN_NAME in df_labels_all.columns and len(df_labels_all) == len(df_full): | |
| 1167 mask = ( | |
| 1168 pd.to_numeric(df_labels_all[SPLIT_COLUMN_NAME], errors="coerce") == used_split | |
| 1169 if used_split is not None and SPLIT_COLUMN_NAME in df_labels_all.columns | |
| 1170 else pd.Series([True] * len(df_full)) | |
| 1171 ) | |
| 1172 labels_from_dataset = df_labels_all.loc[mask, LABEL_COLUMN_NAME].reset_index(drop=True) | |
| 1173 if len(labels_from_dataset) == len(df_pred): | |
| 1174 labels_from_dataset = labels_from_dataset.reset_index(drop=True) | |
| 1175 except Exception as exc: | |
| 1176 print(f"Warning: Unable to align labels for threshold plot: {exc}") | |
| 1177 | |
| 1178 # Identify probability columns | |
| 1179 prob_cols = [ | |
| 1180 c | |
| 1181 for c in df_pred.columns | |
| 1182 if ( | |
| 1183 (c.startswith("label_probabilities_") or c.startswith("probabilities_")) | |
| 1184 and c != "label_probabilities" | |
| 1185 ) | |
| 1186 ] | |
| 1187 if not prob_cols and "probabilities" in df_pred.columns: | |
| 1188 labels_guess = sorted([str(u) for u in pd.unique(df_pred.get(LABEL_COLUMN_NAME, []))]) | |
| 1189 # reuse expansion logic from diagnostics | |
| 1190 try: | |
| 1191 first_val = df_pred["probabilities"].dropna().iloc[0] | |
| 1192 parsed = json.loads(first_val) if isinstance(first_val, str) else list(first_val) | |
| 1193 n = len(parsed) | |
| 1194 if n > 0: | |
| 1195 if labels_guess and len(labels_guess) == n: | |
| 1196 labels_use = labels_guess | |
| 1197 else: | |
| 1198 labels_use = [str(i) for i in range(n)] | |
| 1199 for idx, lbl in enumerate(labels_use): | |
| 1200 df_pred[f"probabilities_{lbl}"] = df_pred["probabilities"].apply( | |
| 1201 lambda v: (json.loads(v)[idx] if isinstance(v, str) else list(v)[idx]) if pd.notnull(v) else np.nan | |
| 1202 ) | |
| 1203 prob_cols = [f"probabilities_{lbl}" for lbl in labels_use] | |
| 1204 except Exception: | |
| 1205 prob_cols = [] | |
| 1206 prob_cols_sorted = sorted(prob_cols) | |
| 1207 | |
| 1208 def _strip_prob_prefix(col: str) -> str: | |
| 1209 if col.startswith("label_probabilities_"): | |
| 1210 return col.replace("label_probabilities_", "") | |
| 1211 if col.startswith("probabilities_"): | |
| 1212 return col.replace("probabilities_", "") | |
| 1213 return col | |
| 1214 | |
| 1215 # True labels | |
| 1216 def _extract_labels(): | |
| 1217 if labels_from_dataset is not None: | |
| 1218 return labels_from_dataset | |
| 1219 for col in [ | |
| 1220 LABEL_COLUMN_NAME, | |
| 1221 f"{LABEL_COLUMN_NAME}_ground_truth", | |
| 1222 f"{LABEL_COLUMN_NAME}__ground_truth", | |
| 1223 f"{LABEL_COLUMN_NAME}_target", | |
| 1224 f"{LABEL_COLUMN_NAME}__target", | |
| 1225 "label", | |
| 1226 "label_true", | |
| 1227 "label_predictions", | |
| 1228 "prediction", | |
| 1229 ]: | |
| 1230 if col in df_pred.columns and col not in prob_cols_sorted: | |
| 1231 return df_pred[col] | |
| 1232 return None | |
| 1233 | |
| 1234 labels_series = _extract_labels() | |
| 1235 if labels_series is None or not prob_cols_sorted: | |
| 1236 return None | |
| 1237 | |
| 1238 # Positive prob column selection | |
| 1239 preferred_keys = ("event", "true", "positive", "pos", "1") | |
| 1240 pos_prob_col = None | |
| 1241 for col in prob_cols_sorted: | |
| 1242 suffix = _strip_prob_prefix(col).lower() | |
| 1243 if any(k in suffix for k in preferred_keys): | |
| 1244 pos_prob_col = col | |
| 1245 break | |
| 1246 if pos_prob_col is None: | |
| 1247 pos_prob_col = prob_cols_sorted[-1] | |
| 1248 | |
| 1249 min_len = min(len(labels_series), len(df_pred[pos_prob_col])) | |
| 1250 if min_len == 0: | |
| 1251 return None | |
| 1252 | |
| 1253 y_true = np.array(labels_series.iloc[:min_len]) | |
| 1254 # map to binary 0/1 | |
| 1255 unique_labels = pd.unique(y_true) | |
| 1256 if len(unique_labels) < 2: | |
| 1257 return None | |
| 1258 positive_label = unique_labels[1] if len(unique_labels) >= 2 else unique_labels[0] | |
| 1259 y_true_bin = (y_true == positive_label).astype(int) | |
| 1260 y_score = np.array(df_pred[pos_prob_col].iloc[:min_len], dtype=float) | |
| 1261 | |
| 1262 thresholds = np.linspace(0.0, 1.0, 101) | |
| 1263 accs: List[float] = [] | |
| 1264 precs: List[float] = [] | |
| 1265 recs: List[float] = [] | |
| 1266 f1s: List[float] = [] | |
| 1267 for t in thresholds: | |
| 1268 preds = (y_score >= t).astype(int) | |
| 1269 accs.append(accuracy_score(y_true_bin, preds)) | |
| 1270 precs.append(precision_score(y_true_bin, preds, zero_division=0)) | |
| 1271 recs.append(recall_score(y_true_bin, preds, zero_division=0)) | |
| 1272 f1s.append(f1_score(y_true_bin, preds, zero_division=0)) | |
| 1273 | |
| 1274 best_idx = int(np.argmax(f1s)) | |
| 1275 best_thr = thresholds[best_idx] | |
| 1276 | |
| 1277 fig = go.Figure() | |
| 1278 fig.add_trace(go.Scatter(x=thresholds, y=accs, mode="lines", name="Accuracy", line=dict(width=4))) | |
| 1279 fig.add_trace(go.Scatter(x=thresholds, y=precs, mode="lines", name="Precision", line=dict(width=4))) | |
| 1280 fig.add_trace(go.Scatter(x=thresholds, y=recs, mode="lines", name="Recall", line=dict(width=4))) | |
| 1281 fig.add_trace(go.Scatter(x=thresholds, y=f1s, mode="lines", name="F1-Score", line=dict(width=4))) | |
| 1282 fig.add_shape( | |
| 1283 type="line", | |
| 1284 x0=best_thr, | |
| 1285 x1=best_thr, | |
| 1286 y0=0, | |
| 1287 y1=1, | |
| 1288 line=dict(color="gray", width=2, dash="dash"), | |
| 1289 ) | |
| 1290 fig.update_layout( | |
| 1291 title=dict(text="Threshold plot", x=0.5), | |
| 1292 xaxis_title="Threshold", | |
| 1293 yaxis_title="Metric value", | |
| 1294 yaxis=dict(range=[0, 1]), | |
| 1295 width=760, | |
| 1296 height=520, | |
| 1297 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1), | |
| 1298 ) | |
| 1299 _style_fig(fig) | |
| 1300 return _wrap_plot("Threshold plot", fig, include_js=True) | |
| 1301 | |
| 1302 | |
| 1303 def build_multiclass_roc_pr_plots( | |
| 1304 predictions_path: str, | |
| 1305 split_value: int = 2, | |
| 1306 ) -> List[Dict[str, str]]: | |
| 1307 """Build one-vs-rest ROC and PR curves for multi-class classification from predictions.""" | |
| 1308 preds_file = Path(predictions_path) | |
| 1309 if not preds_file.exists(): | |
| 1310 return [] | |
| 1311 try: | |
| 1312 df_pred = pd.read_csv(predictions_path) | |
| 1313 except Exception as exc: | |
| 1314 print(f"Warning: Unable to read predictions CSV: {exc}") | |
| 1315 return [] | |
| 1316 | |
| 1317 if SPLIT_COLUMN_NAME in df_pred.columns: | |
| 1318 df_pred = df_pred[df_pred[SPLIT_COLUMN_NAME] == split_value].reset_index(drop=True) | |
| 1319 if df_pred.empty: | |
| 1320 return [] | |
| 1321 | |
| 1322 if LABEL_COLUMN_NAME not in df_pred.columns: | |
| 1323 return [] | |
| 1324 | |
| 1325 # Identify per-class probability columns | |
| 1326 prob_cols = [ | |
| 1327 c | |
| 1328 for c in df_pred.columns | |
| 1329 if ( | |
| 1330 (c.startswith("label_probabilities_") or c.startswith("probabilities_")) | |
| 1331 and c != "label_probabilities" | |
| 1332 ) | |
| 1333 ] | |
| 1334 if not prob_cols: | |
| 1335 return [] | |
| 1336 labels = [c.replace("label_probabilities_", "").replace("probabilities_", "") for c in prob_cols] | |
| 1337 labels_sorted = sorted(labels) | |
| 1338 | |
| 1339 # Ensure all labels are present as probability columns | |
| 1340 prob_map = { | |
| 1341 c.replace("label_probabilities_", "").replace("probabilities_", ""): c | |
| 1342 for c in prob_cols | |
| 1343 } | |
| 1344 if len(labels_sorted) < 3: | |
| 1345 return [] | |
| 1346 | |
| 1347 y_true_raw = df_pred[LABEL_COLUMN_NAME].astype(str) | |
| 1348 # Drop rows with NaN probabilities across any class to avoid metric errors | |
| 1349 prob_matrix = df_pred[[prob_map[lbl] for lbl in labels_sorted]].astype(float) | |
| 1350 mask_valid = ~prob_matrix.isnull().any(axis=1) | |
| 1351 prob_matrix = prob_matrix[mask_valid] | |
| 1352 y_true_raw = y_true_raw[mask_valid] | |
| 1353 if prob_matrix.empty: | |
| 1354 return [] | |
| 1355 | |
| 1356 y_true_bin = label_binarize(y_true_raw, classes=labels_sorted) | |
| 1357 y_score = prob_matrix.to_numpy() | |
| 1358 | |
| 1359 plots: List[Dict[str, str]] = [] | |
| 1360 | |
| 1361 # ROC: one-vs-rest + micro | |
| 1362 fig_roc = go.Figure() | |
| 1363 added_any = False | |
| 1364 for idx, lbl in enumerate(labels_sorted): | |
| 1365 if y_true_bin[:, idx].sum() == 0 or y_true_bin[:, idx].sum() == len(y_true_bin): | |
| 1366 continue # skip classes without both positives and negatives | |
| 1367 fpr, tpr, _ = roc_curve(y_true_bin[:, idx], y_score[:, idx]) | |
| 1368 fig_roc.add_trace( | |
| 1021 go.Scatter( | 1369 go.Scatter( |
| 1022 x=bin_centers, | 1370 x=fpr, |
| 1023 y=frac_positives, | 1371 y=tpr, |
| 1024 mode="lines+markers", | 1372 mode="lines", |
| 1025 name="Calibration", | 1373 name=f"{lbl} (AUC={auc(fpr, tpr):.3f})", |
| 1026 line=dict(color="#2ca02c", width=4), | 1374 line=dict(width=3), |
| 1027 ) | 1375 ) |
| 1028 ) | 1376 ) |
| 1029 fig_cal.add_trace( | 1377 added_any = True |
| 1378 # Micro-average only if we have mixed labels | |
| 1379 if y_true_bin.sum() > 0 and y_true_bin.sum() < y_true_bin.size: | |
| 1380 fpr_micro, tpr_micro, _ = roc_curve(y_true_bin.ravel(), y_score.ravel()) | |
| 1381 fig_roc.add_trace( | |
| 1030 go.Scatter( | 1382 go.Scatter( |
| 1031 x=[0, 1], | 1383 x=fpr_micro, |
| 1032 y=[0, 1], | 1384 y=tpr_micro, |
| 1033 mode="lines", | 1385 mode="lines", |
| 1034 name="Perfect Calibration", | 1386 name=f"Micro-average (AUC={auc(fpr_micro, tpr_micro):.3f})", |
| 1035 line=dict(color="gray", width=2, dash="dash"), | 1387 line=dict(width=3, dash="dash"), |
| 1036 ) | 1388 ) |
| 1037 ) | 1389 ) |
| 1038 fig_cal.update_layout( | 1390 added_any = True |
| 1039 title=dict(text="Calibration Curve", x=0.5), | 1391 if not added_any: |
| 1040 xaxis_title="Predicted probability", | 1392 return [] |
| 1041 yaxis_title="Observed frequency", | 1393 fig_roc.add_trace( |
| 1042 width=700, | 1394 go.Scatter( |
| 1043 height=500, | 1395 x=[0, 1], |
| 1044 ) | 1396 y=[0, 1], |
| 1045 _style_fig(fig_cal) | 1397 mode="lines", |
| 1046 plots.append({ | 1398 name="Random", |
| 1047 "title": "Calibration Curve (Predicted Probability vs Observed Frequency)", | 1399 line=dict(color="gray", width=2, dash="dot"), |
| 1048 "html": pio.to_html(fig_cal, full_html=False, include_plotlyjs=False), | 1400 ) |
| 1049 }) | 1401 ) |
| 1050 | 1402 fig_roc.update_layout( |
| 1051 # Plot 3: Threshold vs Metrics | 1403 title=dict(text="Multi-class ROC-AUC (one-vs-rest)", x=0.5), |
| 1052 thresholds = np.linspace(0.0, 1.0, 21) | 1404 xaxis_title="False Positive Rate", |
| 1053 accs, f1s, sens, specs = [], [], [], [] | 1405 yaxis_title="True Positive Rate", |
| 1054 for t in thresholds: | 1406 width=820, |
| 1055 y_pred = (y_score >= t).astype(int) | 1407 height=620, |
| 1056 tp = np.sum((y_true == 1) & (y_pred == 1)) | |
| 1057 tn = np.sum((y_true == 0) & (y_pred == 0)) | |
| 1058 fp = np.sum((y_true == 0) & (y_pred == 1)) | |
| 1059 fn = np.sum((y_true == 1) & (y_pred == 0)) | |
| 1060 acc = (tp + tn) / max(len(y_true), 1) | |
| 1061 prec = tp / max(tp + fp, 1e-9) | |
| 1062 rec = tp / max(tp + fn, 1e-9) | |
| 1063 f1 = 0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec) | |
| 1064 sensitivity = rec | |
| 1065 specificity = tn / max(tn + fp, 1e-9) | |
| 1066 accs.append(acc) | |
| 1067 f1s.append(f1) | |
| 1068 sens.append(sensitivity) | |
| 1069 specs.append(specificity) | |
| 1070 | |
| 1071 fig_thresh = go.Figure() | |
| 1072 fig_thresh.add_trace(go.Scatter(x=thresholds, y=accs, mode="lines", name="Accuracy", line=dict(width=4))) | |
| 1073 fig_thresh.add_trace(go.Scatter(x=thresholds, y=f1s, mode="lines", name="F1", line=dict(width=4))) | |
| 1074 fig_thresh.add_trace(go.Scatter(x=thresholds, y=sens, mode="lines", name="Sensitivity", line=dict(width=4))) | |
| 1075 fig_thresh.add_trace(go.Scatter(x=thresholds, y=specs, mode="lines", name="Specificity", line=dict(width=4))) | |
| 1076 fig_thresh.update_layout( | |
| 1077 title=dict(text="Threshold Sweep: Accuracy, F1, Sensitivity, Specificity", x=0.5), | |
| 1078 xaxis_title="Decision threshold", | |
| 1079 yaxis_title="Metric value", | |
| 1080 width=700, | |
| 1081 height=500, | |
| 1082 legend=dict( | 1408 legend=dict( |
| 1083 x=0.7, | 1409 x=0.62, |
| 1084 y=0.2, | 1410 y=0.05, |
| 1085 bgcolor="rgba(255,255,255,0.9)", | 1411 bgcolor="rgba(255,255,255,0.9)", |
| 1086 bordercolor="rgba(0,0,0,0.2)", | 1412 bordercolor="rgba(0,0,0,0.2)", |
| 1087 borderwidth=1, | 1413 borderwidth=1, |
| 1088 ), | 1414 ), |
| 1089 shapes=[ | |
| 1090 dict( | |
| 1091 type="line", | |
| 1092 x0=threshold, | |
| 1093 x1=threshold, | |
| 1094 y0=0, | |
| 1095 y1=1, | |
| 1096 xref="x", | |
| 1097 yref="paper", | |
| 1098 line=dict(color="#d62728", width=2, dash="dash"), | |
| 1099 ) | |
| 1100 ] if isinstance(threshold, (int, float)) else [], | |
| 1101 annotations=[ | |
| 1102 dict( | |
| 1103 x=threshold, | |
| 1104 y=1.02, | |
| 1105 xref="x", | |
| 1106 yref="paper", | |
| 1107 showarrow=False, | |
| 1108 text=f"Threshold = {threshold:.2f}", | |
| 1109 font=dict(size=11, color="#d62728"), | |
| 1110 ) | |
| 1111 ] if isinstance(threshold, (int, float)) else [], | |
| 1112 ) | 1415 ) |
| 1113 _style_fig(fig_thresh) | 1416 _style_fig(fig_roc) |
| 1114 plots.append({ | 1417 plots.append(_wrap_plot("Multi-class ROC-AUC (one-vs-rest)", fig_roc)) |
| 1115 "title": "Threshold Sweep: Accuracy, F1, Sensitivity, Specificity", | 1418 |
| 1116 "html": pio.to_html(fig_thresh, full_html=False, include_plotlyjs=False), | 1419 # PR: one-vs-rest + micro AP |
| 1117 }) | 1420 fig_pr = go.Figure() |
| 1421 added_pr = False | |
| 1422 for idx, lbl in enumerate(labels_sorted): | |
| 1423 if y_true_bin[:, idx].sum() == 0: | |
| 1424 continue | |
| 1425 prec, rec, _ = precision_recall_curve(y_true_bin[:, idx], y_score[:, idx]) | |
| 1426 ap = average_precision_score(y_true_bin[:, idx], y_score[:, idx]) | |
| 1427 fig_pr.add_trace( | |
| 1428 go.Scatter( | |
| 1429 x=rec, | |
| 1430 y=prec, | |
| 1431 mode="lines", | |
| 1432 name=f"{lbl} (AP={ap:.3f})", | |
| 1433 line=dict(width=3), | |
| 1434 ) | |
| 1435 ) | |
| 1436 added_pr = True | |
| 1437 if y_true_bin.sum() > 0: | |
| 1438 prec_micro, rec_micro, _ = precision_recall_curve(y_true_bin.ravel(), y_score.ravel()) | |
| 1439 ap_micro = average_precision_score(y_true_bin, y_score, average="micro") | |
| 1440 fig_pr.add_trace( | |
| 1441 go.Scatter( | |
| 1442 x=rec_micro, | |
| 1443 y=prec_micro, | |
| 1444 mode="lines", | |
| 1445 name=f"Micro-average (AP={ap_micro:.3f})", | |
| 1446 line=dict(width=3, dash="dash"), | |
| 1447 ) | |
| 1448 ) | |
| 1449 added_pr = True | |
| 1450 if not added_pr: | |
| 1451 return plots | |
| 1452 fig_pr.update_layout( | |
| 1453 title=dict(text="Multi-class Precision-Recall (one-vs-rest)", x=0.5), | |
| 1454 xaxis_title="Recall", | |
| 1455 yaxis_title="Precision", | |
| 1456 width=820, | |
| 1457 height=620, | |
| 1458 legend=dict( | |
| 1459 x=0.62, | |
| 1460 y=0.05, | |
| 1461 bgcolor="rgba(255,255,255,0.9)", | |
| 1462 bordercolor="rgba(0,0,0,0.2)", | |
| 1463 borderwidth=1, | |
| 1464 ), | |
| 1465 ) | |
| 1466 _style_fig(fig_pr) | |
| 1467 plots.append(_wrap_plot("Multi-class Precision-Recall (one-vs-rest)", fig_pr)) | |
| 1118 | 1468 |
| 1119 return plots | 1469 return plots |
| 1470 | |
| 1471 | |
| 1472 def build_multiclass_metric_plots(test_stats_path: str) -> List[Dict[str, str]]: | |
| 1473 """Alternative multi-class transparency plots using test_statistics.json per-class stats.""" | |
| 1474 ts_path = Path(test_stats_path) | |
| 1475 if not ts_path.exists(): | |
| 1476 return [] | |
| 1477 try: | |
| 1478 with open(ts_path, "r") as f: | |
| 1479 test_stats = json.load(f) | |
| 1480 except Exception: | |
| 1481 return [] | |
| 1482 | |
| 1483 label_stats = test_stats.get("label", {}) | |
| 1484 pcs = label_stats.get("per_class_stats", {}) | |
| 1485 if not pcs: | |
| 1486 return [] | |
| 1487 classes = list(pcs.keys()) | |
| 1488 if not classes: | |
| 1489 return [] | |
| 1490 | |
| 1491 metrics = ["precision", "recall", "f1_score", "specificity", "accuracy"] | |
| 1492 fig_bar = go.Figure() | |
| 1493 for metric in metrics: | |
| 1494 values = [] | |
| 1495 for cls in classes: | |
| 1496 v = pcs.get(cls, {}).get(metric) | |
| 1497 values.append(v if isinstance(v, (int, float)) else 0) | |
| 1498 fig_bar.add_trace( | |
| 1499 go.Bar( | |
| 1500 x=classes, | |
| 1501 y=values, | |
| 1502 name=metric.replace("_", " ").title(), | |
| 1503 ) | |
| 1504 ) | |
| 1505 fig_bar.update_layout( | |
| 1506 title=dict(text="Per-Class Metrics (Test)", x=0.5), | |
| 1507 xaxis_title="Class", | |
| 1508 yaxis_title="Metric value", | |
| 1509 barmode="group", | |
| 1510 width=900, | |
| 1511 height=600, | |
| 1512 legend=dict( | |
| 1513 x=1.02, | |
| 1514 y=1.0, | |
| 1515 bgcolor="rgba(255,255,255,0.9)", | |
| 1516 bordercolor="rgba(0,0,0,0.2)", | |
| 1517 borderwidth=1, | |
| 1518 ), | |
| 1519 ) | |
| 1520 _style_fig(fig_bar) | |
| 1521 | |
| 1522 return [_wrap_plot("Per-Class Metrics (Test)", fig_bar)] |
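The revision replaces the repeated `pio.to_html(...)` boilerplate with small helpers (`_fig_to_html`, `_wrap_plot`, `_line_chart`) so each plot is packaged as a titled HTML fragment and plotly.js is pulled from the CDN only once per group. A minimal sketch of that pattern, written with standalone names for illustration; only `pio.to_html` and its arguments come from the diff itself:

```python
# Sketch only: package a Plotly figure as {"title": ..., "html": ...},
# loading plotly.js from the CDN just for the first figure in a group.
from typing import Dict, Optional

import plotly.graph_objects as go
import plotly.io as pio


def fig_to_fragment(fig: go.Figure, include_js: bool = False,
                    config: Optional[dict] = None) -> str:
    # full_html=False yields only a <div> fragment; include_plotlyjs="cdn"
    # adds a single <script> tag pointing at the CDN copy of plotly.js.
    return pio.to_html(
        fig,
        full_html=False,
        include_plotlyjs="cdn" if include_js else False,
        config=config,
    )


def wrap_plot(title: str, fig: go.Figure, include_js: bool = False) -> Dict[str, str]:
    return {"title": title, "html": fig_to_fragment(fig, include_js=include_js)}


if __name__ == "__main__":
    fig = go.Figure(go.Scatter(x=[1, 2, 3], y=[0.2, 0.5, 0.9], mode="lines+markers"))
    plots = [wrap_plot("Demo curve", fig, include_js=True)]  # first plot carries the JS
    print(plots[0]["title"], len(plots[0]["html"]), "characters of HTML")
```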
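Both `_build_static_roc_plot` and `_build_precision_recall_plot` now accept an optional `threshold` and mark the curve point whose stored threshold is closest to it, using dashed guide lines and an "x" marker. A sketch of that placement on made-up ROC data (the arrays below are synthetic, not from the tool):

```python
# Sketch only: mark the operating point nearest a chosen decision threshold.
import numpy as np
import plotly.graph_objects as go

fpr = [0.0, 0.1, 0.25, 0.5, 1.0]
tpr = [0.0, 0.6, 0.8, 0.95, 1.0]
thresholds = [1.0, 0.8, 0.5, 0.3, 0.0]
threshold = 0.45  # user-chosen decision threshold

fig = go.Figure(go.Scatter(x=fpr, y=tpr, mode="lines", name="ROC"))

# Index of the recorded threshold closest to the requested one.
idx = int(np.argmin([abs(t - threshold) for t in thresholds]))

# Dashed guides through the chosen point, then an 'x' marker on top of it.
fig.add_shape(type="line", x0=fpr[idx], x1=fpr[idx], y0=0, y1=tpr[idx],
              line=dict(color="gray", width=2, dash="dash"))
fig.add_shape(type="line", x0=0, x1=fpr[idx], y0=tpr[idx], y1=tpr[idx],
              line=dict(color="gray", width=2, dash="dash"))
fig.add_trace(go.Scatter(x=[fpr[idx]], y=[tpr[idx]], mode="markers",
                         marker=dict(symbol="x", size=10, color="black"),
                         name=f"Threshold={threshold}"))
fig.write_html("roc_with_threshold.html", full_html=True)
```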
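The new `build_binary_threshold_plot` sweeps 101 candidate thresholds, scores each with accuracy, precision, recall, and F1, and highlights the F1-optimal threshold. A self-contained sketch of the same sweep on synthetic scores (the real function reads per-class probabilities from predictions.csv):

```python
# Sketch only: threshold sweep over synthetic binary scores.
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=200)                                # binary ground truth
y_score = np.clip(y_true * 0.6 + rng.normal(0.2, 0.25, 200), 0, 1)   # noisy scores

thresholds = np.linspace(0.0, 1.0, 101)
rows = []
for t in thresholds:
    pred = (y_score >= t).astype(int)
    rows.append((
        accuracy_score(y_true, pred),
        precision_score(y_true, pred, zero_division=0),
        recall_score(y_true, pred, zero_division=0),
        f1_score(y_true, pred, zero_division=0),
    ))

f1s = [r[3] for r in rows]
best = int(np.argmax(f1s))
print(f"best threshold by F1: {thresholds[best]:.2f} (F1={f1s[best]:.3f})")
```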
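`build_multiclass_roc_pr_plots` binarizes the labels one-vs-rest with `label_binarize` and draws per-class plus micro-averaged ROC and PR curves. A compact sketch of that computation on fabricated three-class scores:

```python
# Sketch only: one-vs-rest ROC/PR metrics on synthetic three-class data.
import numpy as np
from sklearn.metrics import auc, average_precision_score, precision_recall_curve, roc_curve
from sklearn.preprocessing import label_binarize

classes = ["a", "b", "c"]
rng = np.random.default_rng(1)
y_true = rng.choice(classes, size=300)
# Fake per-class probabilities that loosely favour the true class.
y_score = rng.random((300, 3))
y_score[np.arange(300), [classes.index(c) for c in y_true]] += 1.0
y_score /= y_score.sum(axis=1, keepdims=True)

y_bin = label_binarize(y_true, classes=classes)   # shape (n_samples, n_classes)

for i, cls in enumerate(classes):
    fpr, tpr, _ = roc_curve(y_bin[:, i], y_score[:, i])
    ap = average_precision_score(y_bin[:, i], y_score[:, i])
    print(f"{cls}: AUC={auc(fpr, tpr):.3f} AP={ap:.3f}")

# Micro-average: flatten all one-vs-rest problems into a single binary problem.
fpr_mi, tpr_mi, _ = roc_curve(y_bin.ravel(), y_score.ravel())
print(f"micro AUC={auc(fpr_mi, tpr_mi):.3f}")
```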
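The calibration curve in `build_prediction_diagnostics` buckets predicted probabilities into ten equal-width bins and plots the mean predicted probability of each bin against the observed positive rate. A sketch of the `_calibration_points` logic on synthetic, well-calibrated data:

```python
# Sketch only: bin-wise calibration points (mean score vs observed frequency).
import numpy as np

def calibration_points(y_true, scores, n_bins=10):
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    bin_ids = np.digitize(scores, bins, right=True)
    centers, observed = [], []
    for b in range(1, len(bins)):
        mask = bin_ids == b
        if not np.any(mask):
            continue  # skip empty bins instead of plotting a point at zero
        centers.append(scores[mask].mean())    # mean predicted probability in the bin
        observed.append(y_true[mask].mean())   # fraction of actual positives in the bin
    return centers, observed

rng = np.random.default_rng(2)
p = rng.random(500)
y = (rng.random(500) < p).astype(int)  # labels drawn so the scores are well calibrated
x_pts, y_pts = calibration_points(y, p)
print(list(zip(np.round(x_pts, 2), np.round(y_pts, 2))))
```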
