comparison plotly_plots.py @ 17:db9be962dc13 draft

planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
author goeckslab
date Wed, 10 Dec 2025 00:24:13 +0000
parents d17e3a1b8659
children
comparison
equal deleted inserted replaced
16:8729f69e9207 17:db9be962dc13
5 import numpy as np 5 import numpy as np
6 import pandas as pd 6 import pandas as pd
7 import plotly.graph_objects as go 7 import plotly.graph_objects as go
8 import plotly.io as pio 8 import plotly.io as pio
9 from constants import LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME 9 from constants import LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME
10 from sklearn.metrics import (
11 accuracy_score,
12 auc,
13 average_precision_score,
14 f1_score,
15 precision_recall_curve,
16 precision_score,
17 recall_score,
18 roc_curve,
19 )
20 from sklearn.preprocessing import label_binarize
10 21
11 22
12 def _style_fig(fig: go.Figure, font_size: int = 12) -> go.Figure: 23 def _style_fig(fig: go.Figure, font_size: int = 12) -> go.Figure:
13 """Apply consistent styling across Plotly figures.""" 24 """Apply consistent styling across Plotly figures."""
14 fig.update_layout( 25 fig.update_layout(
16 plot_bgcolor="#ffffff", 27 plot_bgcolor="#ffffff",
17 paper_bgcolor="#ffffff", 28 paper_bgcolor="#ffffff",
18 ) 29 )
19 fig.update_xaxes(gridcolor="#e8e8e8") 30 fig.update_xaxes(gridcolor="#e8e8e8")
20 fig.update_yaxes(gridcolor="#e8e8e8") 31 fig.update_yaxes(gridcolor="#e8e8e8")
32 return fig
33
34
35 def _fig_to_html(
36 fig: go.Figure, *, include_js: bool = False, config: Optional[dict] = None
37 ) -> str:
38 """Render a Plotly figure to a lightweight HTML fragment."""
39 include_plotlyjs = "cdn" if include_js else False
40 return pio.to_html(
41 fig,
42 full_html=False,
43 include_plotlyjs=include_plotlyjs,
44 config=config,
45 )
46
47
48 def _wrap_plot(
49 title: str,
50 fig: go.Figure,
51 *,
52 include_js: bool = False,
53 config: Optional[dict] = None,
54 ) -> Dict[str, str]:
55 """Package a figure with its title for downstream HTML rendering."""
56 return {"title": title, "html": _fig_to_html(fig, include_js=include_js, config=config)}
57
58
59 def _line_chart(
60 traces: List[tuple],
61 *,
62 title: str,
63 yaxis_title: str,
64 ) -> go.Figure:
65 """Build a basic epoch-indexed line chart for train/val/test curves."""
66 fig = go.Figure()
67 for name, series in traces:
68 if not series:
69 continue
70 epochs = list(range(1, len(series) + 1))
71 fig.add_trace(
72 go.Scatter(
73 x=epochs,
74 y=series,
75 mode="lines+markers",
76 name=name,
77 line=dict(width=4),
78 )
79 )
80
81 fig.update_layout(
82 title=dict(text=title, x=0.5),
83 xaxis_title="Epoch",
84 yaxis_title=yaxis_title,
85 width=760,
86 height=520,
87 hovermode="x unified",
88 )
89 _style_fig(fig)
21 return fig 90 return fig
22 91
23 92
24 def _labels_from_metadata_dict(meta_dict: dict) -> List[str]: 93 def _labels_from_metadata_dict(meta_dict: dict) -> List[str]:
25 """Extract ordered label names from Ludwig train_set_metadata.""" 94 """Extract ordered label names from Ludwig train_set_metadata."""
104 def build_classification_plots( 173 def build_classification_plots(
105 test_stats_path: str, 174 test_stats_path: str,
106 training_stats_path: Optional[str] = None, 175 training_stats_path: Optional[str] = None,
107 metadata_csv_path: Optional[str] = None, 176 metadata_csv_path: Optional[str] = None,
108 train_set_metadata_path: Optional[str] = None, 177 train_set_metadata_path: Optional[str] = None,
178 threshold: Optional[float] = None,
109 ) -> List[Dict[str, str]]: 179 ) -> List[Dict[str, str]]:
110 """ 180 """
111 Read Ludwig’s test_statistics.json and build three interactive Plotly panels: 181 Read Ludwig’s test_statistics.json and build three interactive Plotly panels:
112 - Confusion Matrix 182 - Confusion Matrix
113 - ROC-AUC 183 - ROC-AUC
154 showscale=True, 224 showscale=True,
155 colorbar=dict(title="Count"), 225 colorbar=dict(title="Count"),
156 ) 226 )
157 ) 227 )
158 fig_cm.update_traces(xgap=2, ygap=2) 228 fig_cm.update_traces(xgap=2, ygap=2)
229 cm_title = "Confusion Matrix"
230 if threshold is not None:
231 cm_title = f"Confusion Matrix (Threshold: {threshold})"
159 fig_cm.update_layout( 232 fig_cm.update_layout(
160 title=dict(text="Confusion Matrix", x=0.5), 233 title=dict(text=cm_title, x=0.5),
161 xaxis_title="Predicted", 234 xaxis_title="Predicted",
162 yaxis_title="Observed", 235 yaxis_title="Observed",
163 yaxis_autorange="reversed", 236 yaxis_autorange="reversed",
164 width=side_px, 237 width=side_px,
165 height=side_px, 238 height=side_px,
194 xanchor="center", 267 xanchor="center",
195 yanchor="top", 268 yanchor="top",
196 yshift=-2, 269 yshift=-2,
197 ) 270 )
198 271
199 plots.append({ 272 plots.append(
200 "title": "Confusion Matrix", 273 _wrap_plot("Confusion Matrix", fig_cm, include_js=True, config=common_cfg)
201 "html": pio.to_html( 274 )
202 fig_cm, 275
203 full_html=False, 276 # 1) ROC / PR curves only for binary tasks
204 include_plotlyjs="cdn", 277 if n_classes == 2:
205 config=common_cfg 278 roc_plot = _build_static_roc_plot(label_stats, common_cfg, friendly_labels=labels)
206 ) 279 if roc_plot:
207 }) 280 plots.append(roc_plot)
208 281
209 # 1) ROC Curve (from test_statistics) 282 pr_plot = _build_precision_recall_plot(label_stats, common_cfg)
210 roc_plot = _build_static_roc_plot(label_stats, common_cfg, friendly_labels=labels) 283 if pr_plot:
211 if roc_plot: 284 plots.append(pr_plot)
212 plots.append(roc_plot)
213
214 # 2) Precision-Recall Curve (from test_statistics)
215 pr_plot = _build_precision_recall_plot(label_stats, common_cfg)
216 if pr_plot:
217 plots.append(pr_plot)
218 285
219 # 2) Classification Report Heatmap 286 # 2) Classification Report Heatmap
220 pcs = label_stats.get("per_class_stats", {}) 287 pcs = label_stats.get("per_class_stats", {})
221 if pcs: 288 if pcs:
222 classes = list(pcs.keys()) 289 classes = list(pcs.keys())
257 width=side_px, 324 width=side_px,
258 height=side_px, 325 height=side_px,
259 margin=dict(t=80, l=80, r=80, b=80), 326 margin=dict(t=80, l=80, r=80, b=80),
260 ) 327 )
261 _style_fig(fig_cr) 328 _style_fig(fig_cr)
262 plots.append({ 329 plots.append(
263 "title": "Per-Class metrics", 330 _wrap_plot("Per-Class metrics", fig_cr, config=common_cfg)
264 "html": pio.to_html( 331 )
265 fig_cr,
266 full_html=False,
267 include_plotlyjs=False,
268 config=common_cfg
269 )
270 })
271 332
272 # 3) Prediction Diagnostics (from predictions.csv) 333 # 3) Prediction Diagnostics (from predictions.csv)
273 # Note: appended separately in generate_html_report, not returned here. 334 # Note: appended separately in generate_html_report, not returned here.
274 335
275 return plots 336 return plots
292 return [] 353 return []
293 plots: List[Dict[str, str]] = [] 354 plots: List[Dict[str, str]] = []
294 include_js = True # Load Plotly.js once for this group 355 include_js = True # Load Plotly.js once for this group
295 356
296 def _get_series(stats: dict, metric: str) -> List[float]: 357 def _get_series(stats: dict, metric: str) -> List[float]:
297 if metric not in stats:
298 return []
299 vals = stats.get(metric, []) 358 vals = stats.get(metric, [])
300 if isinstance(vals, list): 359 if isinstance(vals, list):
301 return [float(v) for v in vals] 360 return [float(v) for v in vals]
302 try: 361 try:
303 return [float(vals)] 362 return [float(vals)]
304 except Exception: 363 except Exception:
305 return [] 364 return []
306 365
307 def _line_plot(metric_key: str, title: str, yaxis_title: str) -> Optional[Dict[str, str]]: 366 metric_specs = [
308 train_series = _get_series(label_train, metric_key) 367 ("loss", "Loss across epochs", "Loss"),
309 val_series = _get_series(label_val, metric_key) 368 ("accuracy", "Accuracy across epochs", "Accuracy"),
369 ("roc_auc", "ROC-AUC across epochs", "ROC-AUC"),
370 ("precision", "Precision across epochs", "Precision"),
371 ("recall", "Recall/Sensitivity across epochs", "Recall"),
372 ("specificity", "Specificity across epochs", "Specificity"),
373 ]
374
375 for key, title, yaxis in metric_specs:
376 train_series = _get_series(label_train, key)
377 val_series = _get_series(label_val, key)
310 if not train_series and not val_series: 378 if not train_series and not val_series:
311 return None 379 continue
312 epochs_train = list(range(1, len(train_series) + 1)) 380 fig = _line_chart(
313 epochs_val = list(range(1, len(val_series) + 1)) 381 [("Train", train_series), ("Validation", val_series)],
314 fig = go.Figure() 382 title=title,
315 if train_series: 383 yaxis_title=yaxis,
316 fig.add_trace( 384 )
317 go.Scatter( 385 plots.append(_wrap_plot(title, fig, include_js=include_js))
318 x=epochs_train, 386 include_js = False
319 y=train_series,
320 mode="lines+markers",
321 name="Train",
322 line=dict(width=4),
323 )
324 )
325 if val_series:
326 fig.add_trace(
327 go.Scatter(
328 x=epochs_val,
329 y=val_series,
330 mode="lines+markers",
331 name="Validation",
332 line=dict(width=4),
333 )
334 )
335 fig.update_layout(
336 title=dict(text=title, x=0.5),
337 xaxis_title="Epoch",
338 yaxis_title=yaxis_title,
339 width=760,
340 height=520,
341 hovermode="x unified",
342 )
343 _style_fig(fig)
344 return {
345 "title": title,
346 "html": pio.to_html(
347 fig,
348 full_html=False,
349 include_plotlyjs="cdn" if include_js else False,
350 ),
351 }
352
353 # Core learning curves
354 for key, title in [
355 ("roc_auc", "ROC-AUC across epochs"),
356 ("precision", "Precision across epochs"),
357 ("recall", "Recall/Sensitivity across epochs"),
358 ("specificity", "Specificity across epochs"),
359 ]:
360 plot = _line_plot(key, title, title.replace("Learning Curve", "").strip())
361 if plot:
362 plots.append(plot)
363 include_js = False
364 387
365 # Precision vs Recall evolution (validation) 388 # Precision vs Recall evolution (validation)
366 val_prec = _get_series(label_val, "precision") 389 val_prec = _get_series(label_val, "precision")
367 val_rec = _get_series(label_val, "recall") 390 val_rec = _get_series(label_val, "recall")
368 if val_prec and val_rec: 391 if val_prec and val_rec:
369 epochs = list(range(1, min(len(val_prec), len(val_rec)) + 1)) 392 max_len = min(len(val_prec), len(val_rec))
370 fig_pr = go.Figure() 393 fig_pr = _line_chart(
371 fig_pr.add_trace( 394 [
372 go.Scatter( 395 ("Precision", val_prec[:max_len]),
373 x=epochs, 396 ("Recall", val_rec[:max_len]),
374 y=val_prec[: len(epochs)], 397 ],
375 mode="lines+markers", 398 title="Validation Precision and Recall by Epoch",
376 name="Precision",
377 )
378 )
379 fig_pr.add_trace(
380 go.Scatter(
381 x=epochs,
382 y=val_rec[: len(epochs)],
383 mode="lines+markers",
384 name="Recall",
385 )
386 )
387 fig_pr.update_layout(
388 title=dict(text="Validation Precision and Recall by Epoch", x=0.5),
389 xaxis_title="Epoch",
390 yaxis_title="Value", 399 yaxis_title="Value",
391 width=760, 400 )
392 height=520, 401 plots.append(_wrap_plot("Precision vs Recall Evolution", fig_pr, include_js=include_js))
393 hovermode="x unified",
394 )
395 _style_fig(fig_pr)
396 plots.append({
397 "title": "Precision vs Recall Evolution",
398 "html": pio.to_html(
399 fig_pr,
400 full_html=False,
401 include_plotlyjs="cdn" if include_js else False,
402 ),
403 })
404 include_js = False 402 include_js = False
405 403
406 # F1-score derived
407 def _compute_f1(p: List[float], r: List[float]) -> List[float]: 404 def _compute_f1(p: List[float], r: List[float]) -> List[float]:
408 f1_vals = [] 405 return [
409 for prec, rec in zip(p, r): 406 0.0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec)
410 if (prec + rec) == 0: 407 for prec, rec in zip(p, r)
411 f1_vals.append(0.0) 408 ]
412 else:
413 f1_vals.append(2 * prec * rec / (prec + rec))
414 return f1_vals
415 409
416 f1_train = _compute_f1(_get_series(label_train, "precision"), _get_series(label_train, "recall")) 410 f1_train = _compute_f1(_get_series(label_train, "precision"), _get_series(label_train, "recall"))
417 f1_val = _compute_f1(val_prec, val_rec) 411 f1_val = _compute_f1(val_prec, val_rec)
418 if f1_train or f1_val: 412 if f1_train or f1_val:
419 fig = go.Figure() 413 fig_f1 = _line_chart(
420 if f1_train: 414 [("Train", f1_train), ("Validation", f1_val)],
421 fig.add_trace(go.Scatter(x=list(range(1, len(f1_train) + 1)), y=f1_train, mode="lines+markers", name="Train", line=dict(width=4))) 415 title="F1-Score across epochs (derived)",
422 if f1_val:
423 fig.add_trace(go.Scatter(x=list(range(1, len(f1_val) + 1)), y=f1_val, mode="lines+markers", name="Validation", line=dict(width=4)))
424 fig.update_layout(
425 title=dict(text="F1-Score across epochs (derived)", x=0.5),
426 xaxis_title="Epoch",
427 yaxis_title="F1-Score", 416 yaxis_title="F1-Score",
428 width=760, 417 )
429 height=520, 418 plots.append(_wrap_plot("F1-Score across epochs (derived)", fig_f1, include_js=include_js))
430 hovermode="x unified",
431 )
432 _style_fig(fig)
433 plots.append({
434 "title": "F1-Score across epochs (derived)",
435 "html": pio.to_html(
436 fig,
437 full_html=False,
438 include_plotlyjs="cdn" if include_js else False,
439 ),
440 })
441 include_js = False 419 include_js = False
442 420
443 # Overfitting Gap: Train vs Val ROC-AUC (gap) 421 # Overfitting Gap: Train vs Val ROC-AUC (gap)
444 roc_train = _get_series(label_train, "roc_auc") 422 roc_train = _get_series(label_train, "roc_auc")
445 roc_val = _get_series(label_val, "roc_auc") 423 roc_val = _get_series(label_val, "roc_auc")
446 if roc_train and roc_val: 424 if roc_train and roc_val:
447 epochs_gap = list(range(1, min(len(roc_train), len(roc_val)) + 1)) 425 max_len = min(len(roc_train), len(roc_val))
448 gaps = [t - v for t, v in zip(roc_train[:len(epochs_gap)], roc_val[:len(epochs_gap)])] 426 gaps = [t - v for t, v in zip(roc_train[:max_len], roc_val[:max_len])]
449 fig_gap = go.Figure() 427 fig_gap = _line_chart(
450 fig_gap.add_trace(go.Scatter(x=epochs_gap, y=gaps, mode="lines+markers", name="Train - Val ROC-AUC", line=dict(width=4))) 428 [("Train - Val ROC-AUC", gaps)],
451 fig_gap.update_layout( 429 title="Overfitting gap: ROC-AUC across epochs",
452 title=dict(text="Overfitting gap: ROC-AUC across epochs", x=0.5),
453 xaxis_title="Epoch",
454 yaxis_title="Gap", 430 yaxis_title="Gap",
455 width=760, 431 )
456 height=520, 432 plots.append(_wrap_plot("Overfitting gap: ROC-AUC across epochs", fig_gap, include_js=include_js))
457 hovermode="x unified",
458 )
459 _style_fig(fig_gap)
460 plots.append({
461 "title": "Overfitting gap: ROC-AUC across epochs",
462 "html": pio.to_html(
463 fig_gap,
464 full_html=False,
465 include_plotlyjs="cdn" if include_js else False,
466 ),
467 })
468 include_js = False 433 include_js = False
469 434
470 # Best Epoch Dashboard (based on max val ROC-AUC) 435 # Best Epoch Dashboard (based on max val ROC-AUC)
471 if roc_val: 436 if roc_val:
472 best_idx = int(np.argmax(roc_val)) 437 best_idx = int(np.argmax(roc_val))
473 best_epoch = best_idx + 1 438 best_epoch = best_idx + 1
474 spec_val = _get_series(label_val, "specificity") 439 metrics_at_best: Dict[str, Optional[float]] = {
475 metrics_at_best = { 440 "ROC-AUC": roc_val[best_idx] if best_idx < len(roc_val) else None
476 "ROC-AUC": roc_val[best_idx] if best_idx < len(roc_val) else None,
477 "Precision": val_prec[best_idx] if best_idx < len(val_prec) else None,
478 "Recall": val_rec[best_idx] if best_idx < len(val_rec) else None,
479 "Specificity": spec_val[best_idx] if best_idx < len(spec_val) else None,
480 "F1-Score": f1_val[best_idx] if best_idx < len(f1_val) else None,
481 } 441 }
442
443 for metric_key, label in [
444 ("accuracy", "Accuracy"),
445 ("balanced_accuracy", "Balanced Accuracy"),
446 ("precision", "Precision"),
447 ("recall", "Recall"),
448 ("specificity", "Specificity"),
449 ("loss", "Loss"),
450 ]:
451 series = _get_series(label_val, metric_key)
452 if series and best_idx < len(series):
453 metrics_at_best[label] = series[best_idx]
454
455 if f1_val and best_idx < len(f1_val):
456 metrics_at_best["F1-Score (derived)"] = f1_val[best_idx]
457
482 fig_best = go.Figure() 458 fig_best = go.Figure()
483 for name, value in metrics_at_best.items(): 459 for name, value in metrics_at_best.items():
484 if value is not None: 460 if value is not None:
485 fig_best.add_trace(go.Bar(name=name, x=[name], y=[value])) 461 fig_best.add_trace(go.Bar(name=name, x=[name], y=[value]))
486 fig_best.update_layout( 462 fig_best.update_layout(
490 width=760, 466 width=760,
491 height=520, 467 height=520,
492 showlegend=False, 468 showlegend=False,
493 ) 469 )
494 _style_fig(fig_best) 470 _style_fig(fig_best)
495 plots.append({ 471 plots.append(_wrap_plot("Best Validation Epoch Snapshot (Metrics)", fig_best, include_js=include_js))
496 "title": "Best Validation Epoch Snapshot (Metrics)",
497 "html": pio.to_html(
498 fig_best,
499 full_html=False,
500 include_plotlyjs="cdn" if include_js else False,
501 ),
502 })
503 include_js = False
504 472
505 return plots 473 return plots
506 474
507 475
508 def _get_regression_series(split_stats: dict, metric: str) -> List[float]: 476 def _get_regression_series(split_stats: dict, metric: str) -> List[float]:
527 ) -> Optional[Dict[str, str]]: 495 ) -> Optional[Dict[str, str]]:
528 train_series = _get_regression_series(train_split, metric_key) 496 train_series = _get_regression_series(train_split, metric_key)
529 val_series = _get_regression_series(val_split, metric_key) 497 val_series = _get_regression_series(val_split, metric_key)
530 if not train_series and not val_series: 498 if not train_series and not val_series:
531 return None 499 return None
532 epochs_train = list(range(1, len(train_series) + 1)) 500
533 epochs_val = list(range(1, len(val_series) + 1)) 501 fig = _line_chart(
534 fig = go.Figure() 502 [("Train", train_series), ("Validation", val_series)],
535 if train_series: 503 title=title,
536 fig.add_trace(
537 go.Scatter(
538 x=epochs_train,
539 y=train_series,
540 mode="lines+markers",
541 name="Train",
542 line=dict(width=4),
543 )
544 )
545 if val_series:
546 fig.add_trace(
547 go.Scatter(
548 x=epochs_val,
549 y=val_series,
550 mode="lines+markers",
551 name="Validation",
552 line=dict(width=4),
553 )
554 )
555 fig.update_layout(
556 title=dict(text=title, x=0.5),
557 xaxis_title="Epoch",
558 yaxis_title=yaxis_title, 504 yaxis_title=yaxis_title,
559 width=760,
560 height=520,
561 hovermode="x unified",
562 ) 505 )
563 _style_fig(fig) 506 return _wrap_plot(title, fig, include_js=include_js)
564 return {
565 "title": title,
566 "html": pio.to_html(
567 fig,
568 full_html=False,
569 include_plotlyjs="cdn" if include_js else False,
570 ),
571 }
572 507
573 508
574 def build_regression_train_val_plots(train_stats_path: str) -> List[Dict[str, str]]: 509 def build_regression_train_val_plots(train_stats_path: str) -> List[Dict[str, str]]:
575 """Generate regression Train/Validation learning curve plots from training_statistics.json.""" 510 """Generate regression Train/Validation learning curve plots from training_statistics.json."""
576 if not train_stats_path or not Path(train_stats_path).exists(): 511 if not train_stats_path or not Path(train_stats_path).exists():
625 ("root_mean_squared_error", "Root Mean Squared Error Across Epochs", "RMSE"), 560 ("root_mean_squared_error", "Root Mean Squared Error Across Epochs", "RMSE"),
626 ("mean_absolute_percentage_error", "Mean Absolute Percentage Error Across Epochs", "MAPE"), 561 ("mean_absolute_percentage_error", "Mean Absolute Percentage Error Across Epochs", "MAPE"),
627 ("r2", "R² Across Epochs", "R²"), 562 ("r2", "R² Across Epochs", "R²"),
628 ("loss", "Loss Across Epochs", "Loss"), 563 ("loss", "Loss Across Epochs", "Loss"),
629 ] 564 ]
630 epochs = None
631 for metric_key, title, ytitle in metrics: 565 for metric_key, title, ytitle in metrics:
632 series = _get_regression_series(label_test, metric_key) 566 series = _get_regression_series(label_test, metric_key)
633 if not series: 567 if not series:
634 continue 568 continue
635 if epochs is None: 569 fig = _line_chart(
636 epochs = list(range(1, len(series) + 1)) 570 [("Test", series)],
637 fig = go.Figure() 571 title=title,
638 fig.add_trace(
639 go.Scatter(
640 x=epochs,
641 y=series[: len(epochs)],
642 mode="lines+markers",
643 name="Test",
644 line=dict(width=4),
645 )
646 )
647 fig.update_layout(
648 title=dict(text=title, x=0.5),
649 xaxis_title="Epoch",
650 yaxis_title=ytitle, 572 yaxis_title=ytitle,
651 width=760, 573 )
652 height=520, 574 plots.append(_wrap_plot(title, fig, include_js=include_js))
653 hovermode="x unified",
654 )
655 _style_fig(fig)
656 plots.append({
657 "title": title,
658 "html": pio.to_html(
659 fig,
660 full_html=False,
661 include_plotlyjs="cdn" if include_js else False,
662 ),
663 })
664 include_js = False 575 include_js = False
665 return plots 576 return plots
666 577
667 578
668 def _build_static_roc_plot( 579 def _build_static_roc_plot(
669 label_stats: dict, config: dict, friendly_labels: Optional[List[str]] = None 580 label_stats: dict,
581 config: dict,
582 friendly_labels: Optional[List[str]] = None,
583 threshold: Optional[float] = None,
670 ) -> Optional[Dict[str, str]]: 584 ) -> Optional[Dict[str, str]]:
671 """Build ROC curve directly from test_statistics.json (single curve).""" 585 """Build ROC curve directly from test_statistics.json (single curve)."""
672 roc_data = label_stats.get("roc_curve") 586 roc_data = label_stats.get("roc_curve")
673 if not isinstance(roc_data, dict): 587 if not isinstance(roc_data, dict):
674 return None 588 return None
774 ) 688 )
775 _style_fig(fig) 689 _style_fig(fig)
776 fig.update_xaxes(range=[0, 1.0]) 690 fig.update_xaxes(range=[0, 1.0])
777 fig.update_yaxes(range=[0, 1.05]) 691 fig.update_yaxes(range=[0, 1.05])
778 692
693 roc_thresholds = roc_data.get("thresholds")
694 if threshold is not None and isinstance(roc_thresholds, list) and len(roc_thresholds) == len(fpr):
695 try:
696 diffs = [abs(th - threshold) for th in roc_thresholds]
697 best_idx = int(np.argmin(diffs))
698 # dashed guides through the chosen point
699 fig.add_shape(
700 type="line",
701 x0=fpr[best_idx],
702 x1=fpr[best_idx],
703 y0=0,
704 y1=tpr[best_idx],
705 line=dict(color="gray", width=2, dash="dash"),
706 )
707 fig.add_shape(
708 type="line",
709 x0=0,
710 x1=fpr[best_idx],
711 y0=tpr[best_idx],
712 y1=tpr[best_idx],
713 line=dict(color="gray", width=2, dash="dash"),
714 )
715 fig.add_trace(
716 go.Scatter(
717 x=[fpr[best_idx]],
718 y=[tpr[best_idx]],
719 mode="markers",
720 marker=dict(color="black", size=10, symbol="x"),
721 name=f"Threshold={threshold}",
722 hovertemplate="FPR: %{x:.3f}<br>TPR: %{y:.3f}<br>Threshold: %{text}<extra></extra>",
723 text=[f"{threshold}"],
724 )
725 )
726 except Exception as exc:
727 print(f"Warning: could not add threshold marker to ROC: {exc}")
728
779 fig.add_annotation( 729 fig.add_annotation(
780 x=0.5, 730 x=0.5,
781 y=-0.15, 731 y=-0.15,
782 xref="paper", 732 xref="paper",
783 yref="paper", 733 yref="paper",
784 showarrow=False, 734 showarrow=False,
785 text=f"<span style='font-size:12px;color:#555;'>{pos_label_txt}</span>", 735 text=f"<span style='font-size:12px;color:#555;'>{pos_label_txt}</span>",
786 xanchor="center", 736 xanchor="center",
787 ) 737 )
788 738
789 return { 739 return _wrap_plot("ROC Curve", fig, config=config)
790 "title": "ROC Curve",
791 "html": pio.to_html(
792 fig,
793 full_html=False,
794 include_plotlyjs=False,
795 config=config,
796 ),
797 }
798 except Exception as e: 740 except Exception as e:
799 print(f"Error building ROC plot: {e}") 741 print(f"Error building ROC plot: {e}")
800 return None 742 return None
801 743
802 744
803 def _build_precision_recall_plot(label_stats: dict, config: dict) -> Optional[Dict[str, str]]: 745 def _build_precision_recall_plot(
746 label_stats: dict,
747 config: dict,
748 threshold: Optional[float] = None,
749 ) -> Optional[Dict[str, str]]:
804 """Build Precision-Recall curve directly from test_statistics.json.""" 750 """Build Precision-Recall curve directly from test_statistics.json."""
805 pr_data = label_stats.get("precision_recall_curve") 751 pr_data = label_stats.get("precision_recall_curve")
806 if not isinstance(pr_data, dict): 752 if not isinstance(pr_data, dict):
807 return None 753 return None
808 754
809 precisions = pr_data.get("precisions") 755 precisions = pr_data.get("precisions")
810 recalls = pr_data.get("recalls") 756 recalls = pr_data.get("recalls")
811 if not precisions or not recalls or len(precisions) != len(recalls): 757 if not precisions or not recalls or len(precisions) != len(recalls):
812 return None 758 return None
759
760 thresholds = pr_data.get("thresholds")
813 761
814 try: 762 try:
815 fig = go.Figure() 763 fig = go.Figure()
816 fig.add_trace( 764 fig.add_trace(
817 go.Scatter( 765 go.Scatter(
849 ) 797 )
850 _style_fig(fig) 798 _style_fig(fig)
851 fig.update_xaxes(range=[0, 1.0]) 799 fig.update_xaxes(range=[0, 1.0])
852 fig.update_yaxes(range=[0, 1.05]) 800 fig.update_yaxes(range=[0, 1.05])
853 801
854 return { 802 if threshold is not None and isinstance(thresholds, list) and len(thresholds) == len(recalls):
855 "title": "Precision-Recall Curve", 803 try:
856 "html": pio.to_html( 804 diffs = [abs(th - threshold) for th in thresholds]
857 fig, 805 best_idx = int(np.argmin(diffs))
858 full_html=False, 806 fig.add_shape(
859 include_plotlyjs=False, 807 type="line",
860 config=config, 808 x0=recalls[best_idx],
861 ), 809 x1=recalls[best_idx],
862 } 810 y0=0,
811 y1=precisions[best_idx],
812 line=dict(color="gray", width=2, dash="dash"),
813 )
814 fig.add_shape(
815 type="line",
816 x0=0,
817 x1=recalls[best_idx],
818 y0=precisions[best_idx],
819 y1=precisions[best_idx],
820 line=dict(color="gray", width=2, dash="dash"),
821 )
822 fig.add_trace(
823 go.Scatter(
824 x=[recalls[best_idx]],
825 y=[precisions[best_idx]],
826 mode="markers",
827 marker=dict(color="black", size=10, symbol="x"),
828 name=f"Threshold={threshold}",
829 hovertemplate="Recall: %{x:.3f}<br>Precision: %{y:.3f}<br>Threshold: %{text}<extra></extra>",
830 text=[f"{threshold}"],
831 )
832 )
833 except Exception as exc:
834 print(f"Warning: could not add threshold marker to PR: {exc}")
835
836 return _wrap_plot("Precision-Recall Curve", fig, config=config)
863 except Exception as e: 837 except Exception as e:
864 print(f"Error building Precision-Recall plot: {e}") 838 print(f"Error building Precision-Recall plot: {e}")
865 return None 839 return None
866 840
867 841
868 def build_prediction_diagnostics( 842 def build_prediction_diagnostics(
869 predictions_path: str, 843 predictions_path: str,
870 label_data_path: Optional[str] = None, 844 label_data_path: Optional[str] = None,
871 split_value: int = 2, 845 split_value: int = 2,
872 threshold: Optional[float] = None,
873 ) -> List[Dict[str, str]]: 846 ) -> List[Dict[str, str]]:
874 """Generate diagnostic plots from predictions.csv for classification tasks.""" 847 """Generate diagnostic plots from predictions.csv for classification tasks."""
875 preds_file = Path(predictions_path) 848 preds_file = Path(predictions_path)
876 if not preds_file.exists(): 849 if not preds_file.exists():
877 return [] 850 return []
881 except Exception as exc: 854 except Exception as exc:
882 print(f"Warning: Unable to read predictions CSV: {exc}") 855 print(f"Warning: Unable to read predictions CSV: {exc}")
883 return [] 856 return []
884 857
885 plots: List[Dict[str, str]] = [] 858 plots: List[Dict[str, str]] = []
859 labels_from_dataset: Optional[pd.Series] = None
860
861 filtered_by_split = False
862
863 # If a split column exists, focus on the requested split (e.g., validation=1, test=2).
864 # If not, but label_data_path is available and matches row count, use it to filter predictions.
865 if SPLIT_COLUMN_NAME in df_pred.columns:
866 df_pred = df_pred[df_pred[SPLIT_COLUMN_NAME] == split_value].reset_index(drop=True)
867 if df_pred.empty:
868 return []
869 filtered_by_split = True
870 elif label_data_path and Path(label_data_path).exists():
871 try:
872 df_labels_all = pd.read_csv(label_data_path)
873 if SPLIT_COLUMN_NAME in df_labels_all.columns and len(df_labels_all) == len(df_pred):
874 split_mask = pd.to_numeric(df_labels_all[SPLIT_COLUMN_NAME], errors="coerce") == split_value
875 labels_from_dataset = df_labels_all.loc[split_mask, LABEL_COLUMN_NAME].reset_index(drop=True)
876 df_pred = df_pred.loc[split_mask].reset_index(drop=True)
877 if df_pred.empty:
878 return []
879 filtered_by_split = True
880 except Exception as exc:
881 print(f"Warning: Unable to filter predictions by split from label data: {exc}")
882
883 # Fallback: no split info available. Assume the predictions file is already filtered
884 # (common for test-only exports) and avoid heuristic slicing that could discard rows.
885 if not filtered_by_split:
886 if split_value != 2:
887 return []
888
889 def _strip_prob_prefix(col: str) -> str:
890 if col.startswith("label_probabilities_"):
891 return col.replace("label_probabilities_", "")
892 if col.startswith("probabilities_"):
893 return col.replace("probabilities_", "")
894 return col
895
896 def _maybe_expand_probabilities_column(df: pd.DataFrame, labels_guess: List[str]) -> List[str]:
897 """If only a single 'probabilities' column exists (list-like), expand it into per-class columns."""
898 if "probabilities" not in df.columns:
899 return []
900 try:
901 # Parse first non-null entry to infer length
902 first_val = df["probabilities"].dropna().iloc[0]
903 parsed = first_val
904 if isinstance(first_val, str):
905 parsed = json.loads(first_val)
906 probs = list(parsed)
907 n = len(probs)
908 if n == 0:
909 return []
910 # Build labels: prefer provided guess; otherwise numeric
911 if labels_guess and len(labels_guess) == n:
912 labels_use = labels_guess
913 else:
914 labels_use = [str(i) for i in range(n)]
915 # Expand column
916 for idx, lbl in enumerate(labels_use):
917 df[f"probabilities_{lbl}"] = df["probabilities"].apply(
918 lambda v: (json.loads(v)[idx] if isinstance(v, str) else list(v)[idx]) if pd.notnull(v) else np.nan
919 )
920 return [f"probabilities_{lbl}" for lbl in labels_use]
921 except Exception:
922 return []
886 923
887 # Identify probability columns 924 # Identify probability columns
888 prob_cols = [ 925 prob_cols = [
889 c for c in df_pred.columns 926 c
890 if c.startswith("label_probabilities_") and c != "label_probabilities" 927 for c in df_pred.columns
928 if (
929 (c.startswith("label_probabilities_") or c.startswith("probabilities_"))
930 and c != "label_probabilities"
931 )
891 ] 932 ]
933 if not prob_cols and "label_probability" in df_pred.columns:
934 prob_cols = ["label_probability"]
935 if not prob_cols and "probability" in df_pred.columns:
936 prob_cols = ["probability"]
937 if not prob_cols and "prediction_probability" in df_pred.columns:
938 prob_cols = ["prediction_probability"]
939 if not prob_cols and "probabilities" in df_pred.columns:
940 labels_guess = sorted([str(u) for u in pd.unique(df_pred[LABEL_COLUMN_NAME])])
941 prob_cols = _maybe_expand_probabilities_column(df_pred, labels_guess)
892 prob_cols_sorted = sorted(prob_cols) 942 prob_cols_sorted = sorted(prob_cols)
893 943
894 def _select_positive_prob(): 944 def _select_positive_prob():
895 if not prob_cols_sorted: 945 if not prob_cols_sorted:
896 return None, None 946 return None, None
897 # Prefer a column indicating positive/event/true/1 947 # Prefer a column indicating positive/event/true/1
898 preferred_keys = ("event", "true", "positive", "pos", "1") 948 preferred_keys = ("event", "true", "positive", "pos", "1")
899 for col in prob_cols_sorted: 949 for col in prob_cols_sorted:
900 suffix = col.replace("label_probabilities_", "").lower() 950 suffix = _strip_prob_prefix(col).lower()
901 if any(k in suffix for k in preferred_keys): 951 if any(k in suffix for k in preferred_keys):
902 return col, suffix 952 return col, suffix
903 if len(prob_cols_sorted) == 2: 953 if len(prob_cols_sorted) == 2:
904 col = prob_cols_sorted[1] 954 col = prob_cols_sorted[1]
905 return col, col.replace("label_probabilities_", "") 955 return col, _strip_prob_prefix(col)
906 col = prob_cols_sorted[0] 956 col = prob_cols_sorted[0]
907 return col, col.replace("label_probabilities_", "") 957 return col, _strip_prob_prefix(col)
908 958
909 pos_prob_col, pos_label_hint = _select_positive_prob() 959 pos_prob_col, pos_label_hint = _select_positive_prob()
910 pos_prob_series = df_pred[pos_prob_col] if pos_prob_col and pos_prob_col in df_pred else None 960 pos_prob_series = df_pred[pos_prob_col] if pos_prob_col and pos_prob_col in df_pred else None
911 961
912 # Confidence series: prefer label_probability, otherwise positive prob, otherwise max prob 962 # Confidence series: prefer label_probability, otherwise positive prob, otherwise max prob
918 elif prob_cols_sorted: 968 elif prob_cols_sorted:
919 confidence_series = df_pred[prob_cols_sorted].max(axis=1) 969 confidence_series = df_pred[prob_cols_sorted].max(axis=1)
920 970
921 # True labels 971 # True labels
922 def _extract_labels(): 972 def _extract_labels():
973 if labels_from_dataset is not None:
974 return labels_from_dataset
923 candidates = [ 975 candidates = [
924 LABEL_COLUMN_NAME, 976 LABEL_COLUMN_NAME,
925 f"{LABEL_COLUMN_NAME}_ground_truth", 977 f"{LABEL_COLUMN_NAME}_ground_truth",
926 f"{LABEL_COLUMN_NAME}__ground_truth", 978 f"{LABEL_COLUMN_NAME}__ground_truth",
927 f"{LABEL_COLUMN_NAME}_target", 979 f"{LABEL_COLUMN_NAME}_target",
973 bargap=0.05, 1025 bargap=0.05,
974 width=700, 1026 width=700,
975 height=500, 1027 height=500,
976 ) 1028 )
977 _style_fig(fig_conf) 1029 _style_fig(fig_conf)
978 plots.append({ 1030 plots.append(_wrap_plot("Prediction Confidence Distribution", fig_conf))
979 "title": "Prediction Confidence Distribution",
980 "html": pio.to_html(fig_conf, full_html=False, include_plotlyjs=False),
981 })
982 1031
983 # The remaining plots require true labels and a positive-class probability 1032 # The remaining plots require true labels and a positive-class probability
984 if labels_series is None or pos_prob_series is None: 1033 if labels_series is None or pos_prob_series is None:
985 return plots 1034 return plots
986 1035
1002 else: 1051 else:
1003 positive_label = unique_labels_list[0] 1052 positive_label = unique_labels_list[0]
1004 1053
1005 y_true = (y_true_raw == positive_label).astype(int).values 1054 y_true = (y_true_raw == positive_label).astype(int).values
1006 1055
1007 # Plot 2: Calibration Curve 1056 # Utility: compute calibration points
1008 bins = np.linspace(0.0, 1.0, 11) 1057 def _calibration_points(y_true_bin: np.ndarray, scores: np.ndarray):
1009 bin_ids = np.digitize(y_score, bins, right=True) 1058 bins = np.linspace(0.0, 1.0, 11)
1010 bin_centers = [] 1059 bin_ids = np.digitize(scores, bins, right=True)
1011 frac_positives = [] 1060 bin_centers, frac_positives = [], []
1012 for b in range(1, len(bins)): 1061 for b in range(1, len(bins)):
1013 mask = bin_ids == b 1062 mask = bin_ids == b
1014 if not np.any(mask): 1063 if not np.any(mask):
1015 continue 1064 continue
1016 bin_centers.append(y_score[mask].mean()) 1065 bin_centers.append(scores[mask].mean())
1017 frac_positives.append(y_true[mask].mean()) 1066 frac_positives.append(y_true_bin[mask].mean())
1018 if bin_centers and frac_positives: 1067 return bin_centers, frac_positives
1019 fig_cal = go.Figure() 1068
1020 fig_cal.add_trace( 1069 # Plot 2: Calibration Curve (multi-class aware; one-vs-rest per label)
1070 label_prob_map = {}
1071 for col in prob_cols_sorted:
1072 if col.startswith("label_probabilities_"):
1073 cls = col.replace("label_probabilities_", "")
1074 label_prob_map[cls] = col
1075
1076 unique_label_strs = [str(u) for u in unique_labels_list]
1077 if len(label_prob_map) > 1 and len(unique_label_strs) > 2:
1078 # Skip multi-class calibration curve for now (not informative in current report)
1079 pass
1080 else:
1081 # Binary/unknown fallback (previous behavior)
1082 bin_centers, frac_positives = _calibration_points(y_true, y_score)
1083 if bin_centers and frac_positives:
1084 fig_cal = go.Figure()
1085 fig_cal.add_trace(
1086 go.Scatter(
1087 x=bin_centers,
1088 y=frac_positives,
1089 mode="lines+markers",
1090 name="Calibration",
1091 line=dict(color="#2ca02c", width=4),
1092 )
1093 )
1094 fig_cal.add_trace(
1095 go.Scatter(
1096 x=[0, 1],
1097 y=[0, 1],
1098 mode="lines",
1099 name="Perfect Calibration",
1100 line=dict(color="gray", width=2, dash="dash"),
1101 )
1102 )
1103 fig_cal.update_layout(
1104 title=dict(text="Calibration Curve", x=0.5),
1105 xaxis_title="Predicted probability",
1106 yaxis_title="Observed frequency",
1107 width=700,
1108 height=500,
1109 )
1110 _style_fig(fig_cal)
1111 plots.append(
1112 _wrap_plot(
1113 "Calibration Curve (Predicted Probability vs Observed Frequency)",
1114 fig_cal,
1115 )
1116 )
1117
1118 return plots
1119
1120
def _match_positive_label(unique_labels, pos_suffix: str):
    """Return the entry of *unique_labels* scored by the positive probability column.

    *pos_suffix* is the class name taken from the chosen probability column (e.g. ``"1"``
    from ``label_probabilities_1``). Matching is tried exactly, then case-insensitively,
    then numerically (so a float label ``1.0`` matches the suffix ``"1"``). When nothing
    matches, the second label in encounter order is returned (the historical fallback).
    The caller must pass at least one label.
    """
    labels = list(unique_labels)
    for lbl in labels:
        if str(lbl) == pos_suffix:
            return lbl
    lowered = pos_suffix.lower()
    for lbl in labels:
        if str(lbl).lower() == lowered:
            return lbl
    try:
        suffix_num: Optional[float] = float(pos_suffix)
    except ValueError:
        suffix_num = None
    if suffix_num is not None:
        for lbl in labels:
            try:
                if float(lbl) == suffix_num:
                    return lbl
            except (TypeError, ValueError):
                continue
    return labels[1] if len(labels) >= 2 else labels[0]


def _expand_probabilities_json_column(df: pd.DataFrame, labels_guess: List[str]) -> List[str]:
    """Split a list/JSON-encoded ``probabilities`` column into ``probabilities_<label>`` columns.

    The class count comes from the first non-null cell. *labels_guess* names the new
    columns when its length matches that count; otherwise positional names ``"0"``,
    ``"1"``, ... are used. Columns are added to *df* in place only when every row parses,
    and their names are returned. ``[]`` is returned when the column is absent, empty,
    or unparsable (and *df* is then left untouched).
    """
    if "probabilities" not in df.columns:
        return []

    def _parse(value):
        if isinstance(value, str):
            return json.loads(value)
        if value is None or (isinstance(value, float) and np.isnan(value)):
            return None
        return list(value)

    try:
        # Parse each cell once, instead of once per class.
        parsed_rows = [_parse(v) for v in df["probabilities"]]
        first = next((row for row in parsed_rows if row is not None), None)
        if not first:
            return []
        n_classes = len(first)
        if len(labels_guess) == n_classes:
            labels_use = labels_guess
        else:
            labels_use = [str(i) for i in range(n_classes)]
        new_columns = {
            f"probabilities_{lbl}": [np.nan if row is None else row[idx] for row in parsed_rows]
            for idx, lbl in enumerate(labels_use)
        }
    except (TypeError, ValueError, IndexError, KeyError):
        return []
    for name, values in new_columns.items():
        df[name] = values
    return list(new_columns)


def build_binary_threshold_plot(
    predictions_path: str,
    label_data_path: Optional[str] = None,
    split_value: int = 1,
) -> Optional[Dict[str, str]]:
    """Build a binary threshold sweep plot (accuracy, precision, recall, F1) for a given split.

    Args:
        predictions_path: Predictions CSV with per-class probability columns
            (``label_probabilities_<cls>`` / ``probabilities_<cls>``) or a list/JSON
            encoded ``probabilities`` column.
        label_data_path: Optional dataset CSV. Its labels are used as ground truth only when
            it has a split column, the same row count as the predictions CSV, and the
            split-filtered labels line up row-for-row with the selected predictions.
        split_value: Preferred split to plot. Other splits are tried when it has no rows.

    Returns:
        A plot dict from ``_wrap_plot``, or ``None`` when the file is missing or unreadable,
        or when ground-truth labels with at least two classes plus probabilities are
        unavailable.
    """
    preds_file = Path(predictions_path)
    if not preds_file.exists():
        return None

    try:
        df_full = pd.read_csv(predictions_path)
    except Exception as exc:
        print(f"Warning: Unable to read predictions CSV for threshold plot: {exc}")
        return None

    # Try the preferred split, then fall back to others with data (val -> test -> train).
    # used_split stays None when no split filtering happened, so dataset labels are not
    # filtered either (previously they were, misaligning them with the predictions).
    df_pred = df_full.reset_index(drop=True)
    used_split: Optional[int] = None
    if SPLIT_COLUMN_NAME in df_full.columns:
        fallbacks = [2, 0] if split_value == 1 else [1, 0, 2]
        for sv in dict.fromkeys([split_value, *fallbacks]):  # de-duplicated, order kept
            df_candidate = df_full[df_full[SPLIT_COLUMN_NAME] == sv]
            if not df_candidate.empty:
                df_pred = df_candidate.reset_index(drop=True)
                used_split = sv
                break

    labels_from_dataset: Optional[pd.Series] = None
    if label_data_path and Path(label_data_path).exists():
        try:
            df_labels_all = pd.read_csv(label_data_path)
            if SPLIT_COLUMN_NAME in df_labels_all.columns and len(df_labels_all) == len(df_full):
                if used_split is not None:
                    split_num = pd.to_numeric(df_labels_all[SPLIT_COLUMN_NAME], errors="coerce")
                    mask = split_num == used_split
                else:
                    mask = pd.Series(True, index=df_labels_all.index)
                candidate = df_labels_all.loc[mask, LABEL_COLUMN_NAME].reset_index(drop=True)
                # Only trust dataset labels when they align row-for-row with the predictions.
                if len(candidate) == len(df_pred):
                    labels_from_dataset = candidate
        except Exception as exc:
            print(f"Warning: Unable to align labels for threshold plot: {exc}")

    # Identify probability columns
    prob_cols = [
        c for c in df_pred.columns if c.startswith(("label_probabilities_", "probabilities_"))
    ]
    if not prob_cols and "probabilities" in df_pred.columns:
        labels_guess = (
            sorted(str(u) for u in df_pred[LABEL_COLUMN_NAME].unique())
            if LABEL_COLUMN_NAME in df_pred.columns
            else []
        )
        prob_cols = _expand_probabilities_json_column(df_pred, labels_guess)
    prob_cols_sorted = sorted(prob_cols)

    def _strip_prob_prefix(col: str) -> str:
        # Remove only the leading prefix; str.replace would also hit later occurrences.
        for prefix in ("label_probabilities_", "probabilities_"):
            if col.startswith(prefix):
                return col[len(prefix):]
        return col

    # True labels
    def _extract_labels() -> Optional[pd.Series]:
        if labels_from_dataset is not None:
            return labels_from_dataset
        for col in [
            LABEL_COLUMN_NAME,
            f"{LABEL_COLUMN_NAME}_ground_truth",
            f"{LABEL_COLUMN_NAME}__ground_truth",
            f"{LABEL_COLUMN_NAME}_target",
            f"{LABEL_COLUMN_NAME}__target",
            "label",
            "label_true",
            "label_predictions",
            "prediction",
        ]:
            if col in df_pred.columns and col not in prob_cols_sorted:
                return df_pred[col]
        return None

    labels_series = _extract_labels()
    if labels_series is None or not prob_cols_sorted:
        return None

    # Positive prob column: first whose class name hints at the positive/event class,
    # else the lexicographically last probability column.
    preferred_keys = ("event", "true", "positive", "pos", "1")
    pos_prob_col = next(
        (
            col
            for col in prob_cols_sorted
            if any(k in _strip_prob_prefix(col).lower() for k in preferred_keys)
        ),
        prob_cols_sorted[-1],
    )

    min_len = min(len(labels_series), len(df_pred))
    if min_len == 0:
        return None

    y_true_all = labels_series.iloc[:min_len].reset_index(drop=True)
    y_score_all = pd.to_numeric(
        df_pred[pos_prob_col].iloc[:min_len], errors="coerce"
    ).reset_index(drop=True)
    # Drop rows lacking a label or a usable score so NaN never becomes a "class".
    valid = (y_true_all.notna() & y_score_all.notna()).to_numpy()
    y_true = y_true_all.to_numpy()[valid]
    y_score = y_score_all.to_numpy(dtype=float)[valid]

    unique_labels = pd.unique(y_true)
    if len(unique_labels) < 2:
        return None
    # The positive class must be the one the chosen probability column scores;
    # otherwise every curve below is computed against the wrong class.
    positive_label = _match_positive_label(unique_labels, _strip_prob_prefix(pos_prob_col))
    y_true_bin = (y_true == positive_label).astype(int)

    thresholds = np.linspace(0.0, 1.0, 101)
    accs: List[float] = []
    precs: List[float] = []
    recs: List[float] = []
    f1s: List[float] = []
    for t in thresholds:
        preds = (y_score >= t).astype(int)
        accs.append(accuracy_score(y_true_bin, preds))
        precs.append(precision_score(y_true_bin, preds, zero_division=0))
        recs.append(recall_score(y_true_bin, preds, zero_division=0))
        f1s.append(f1_score(y_true_bin, preds, zero_division=0))

    best_idx = int(np.argmax(f1s))
    best_thr = thresholds[best_idx]

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=thresholds, y=accs, mode="lines", name="Accuracy", line=dict(width=4)))
    fig.add_trace(go.Scatter(x=thresholds, y=precs, mode="lines", name="Precision", line=dict(width=4)))
    fig.add_trace(go.Scatter(x=thresholds, y=recs, mode="lines", name="Recall", line=dict(width=4)))
    fig.add_trace(go.Scatter(x=thresholds, y=f1s, mode="lines", name="F1-Score", line=dict(width=4)))
    # Vertical marker at the F1-maximising threshold.
    fig.add_shape(
        type="line",
        x0=best_thr,
        x1=best_thr,
        y0=0,
        y1=1,
        line=dict(color="gray", width=2, dash="dash"),
    )
    fig.update_layout(
        title=dict(text="Threshold plot", x=0.5),
        xaxis_title="Threshold",
        yaxis_title="Metric value",
        yaxis=dict(range=[0, 1]),
        width=760,
        height=520,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
    )
    _style_fig(fig)
    return _wrap_plot("Threshold plot", fig, include_js=True)
1301
1302
def _multiclass_prob_column_map(columns) -> Dict[str, str]:
    """Map class name -> probability column for ``label_probabilities_<cls>`` / ``probabilities_<cls>``.

    Only the leading prefix is stripped, so class names that themselves contain
    ``probabilities_`` survive intact. When both prefixes exist for the same class, the
    ``label_probabilities_`` column wins, so each class contributes exactly one score column.
    Columns with an empty class name are ignored.
    """
    prob_map: Dict[str, str] = {}
    # The later prefix overwrites the earlier one, giving label_probabilities_* precedence.
    for prefix in ("probabilities_", "label_probabilities_"):
        for col in columns:
            col_name = str(col)
            if col_name.startswith(prefix) and len(col_name) > len(prefix):
                prob_map[col_name[len(prefix):]] = col
    return prob_map


def build_multiclass_roc_pr_plots(
    predictions_path: str,
    split_value: int = 2,
) -> List[Dict[str, str]]:
    """Build one-vs-rest ROC and PR curves for multi-class classification from predictions.

    Args:
        predictions_path: Predictions CSV containing the label column and one probability
            column per class (``label_probabilities_<cls>`` or ``probabilities_<cls>``).
        split_value: Rows whose split column equals this value are used (when the column
            exists).

    Returns:
        Up to two plot dicts (ROC, then PR). ``[]`` is returned when the file is missing or
        unreadable, the split is empty, the label column is absent, fewer than three classes
        have probability columns, or no ROC curve can be drawn.
    """
    preds_file = Path(predictions_path)
    if not preds_file.exists():
        return []
    try:
        df_pred = pd.read_csv(predictions_path)
    except Exception as exc:
        print(f"Warning: Unable to read predictions CSV: {exc}")
        return []

    if SPLIT_COLUMN_NAME in df_pred.columns:
        df_pred = df_pred[df_pred[SPLIT_COLUMN_NAME] == split_value].reset_index(drop=True)
    if df_pred.empty or LABEL_COLUMN_NAME not in df_pred.columns:
        return []

    # One probability column per class; multi-class plots need at least three classes.
    prob_map = _multiclass_prob_column_map(df_pred.columns)
    if len(prob_map) < 3:
        return []
    labels_sorted = sorted(prob_map)

    # Non-numeric scores become NaN. Rows with any missing score or a missing label are
    # dropped to avoid metric errors (a missing label would otherwise count as a
    # negative for every class).
    prob_matrix = df_pred[[prob_map[lbl] for lbl in labels_sorted]].apply(
        pd.to_numeric, errors="coerce"
    )
    labels_raw = df_pred[LABEL_COLUMN_NAME]
    mask_valid = labels_raw.notna() & prob_matrix.notna().all(axis=1)
    if not mask_valid.any():
        return []
    prob_matrix = prob_matrix[mask_valid]
    labels_valid = labels_raw[mask_valid]
    # NaNs elsewhere in the column make pandas read integer labels as floats ("1.0").
    # Restore integers so they match class names such as "1".
    if pd.api.types.is_float_dtype(labels_valid) and (labels_valid % 1 == 0).all():
        labels_valid = labels_valid.astype(np.int64)
    y_true_raw = labels_valid.astype(str)

    y_true_bin = label_binarize(y_true_raw, classes=labels_sorted)
    y_score = prob_matrix.to_numpy()

    plots: List[Dict[str, str]] = []

    # ROC: one-vs-rest + micro
    fig_roc = go.Figure()
    added_any = False
    for idx, lbl in enumerate(labels_sorted):
        if y_true_bin[:, idx].sum() == 0 or y_true_bin[:, idx].sum() == len(y_true_bin):
            continue  # skip classes without both positives and negatives
        fpr, tpr, _ = roc_curve(y_true_bin[:, idx], y_score[:, idx])
        fig_roc.add_trace(
            go.Scatter(
                x=fpr,
                y=tpr,
                mode="lines",
                name=f"{lbl} (AUC={auc(fpr, tpr):.3f})",
                line=dict(width=3),
            )
        )
        added_any = True
    # Micro-average only if we have mixed labels
    if y_true_bin.sum() > 0 and y_true_bin.sum() < y_true_bin.size:
        fpr_micro, tpr_micro, _ = roc_curve(y_true_bin.ravel(), y_score.ravel())
        fig_roc.add_trace(
            go.Scatter(
                x=fpr_micro,
                y=tpr_micro,
                mode="lines",
                name=f"Micro-average (AUC={auc(fpr_micro, tpr_micro):.3f})",
                line=dict(width=3, dash="dash"),
            )
        )
        added_any = True
    if not added_any:
        return []
    fig_roc.add_trace(
        go.Scatter(
            x=[0, 1],
            y=[0, 1],
            mode="lines",
            name="Random",
            line=dict(color="gray", width=2, dash="dot"),
        )
    )
    fig_roc.update_layout(
        title=dict(text="Multi-class ROC-AUC (one-vs-rest)", x=0.5),
        xaxis_title="False Positive Rate",
        yaxis_title="True Positive Rate",
        width=820,
        height=620,
        legend=dict(
            x=0.62,
            y=0.05,
            bgcolor="rgba(255,255,255,0.9)",
            bordercolor="rgba(0,0,0,0.2)",
            borderwidth=1,
        ),
    )
    _style_fig(fig_roc)
    plots.append(_wrap_plot("Multi-class ROC-AUC (one-vs-rest)", fig_roc))

    # PR: one-vs-rest + micro AP
    fig_pr = go.Figure()
    added_pr = False
    for idx, lbl in enumerate(labels_sorted):
        if y_true_bin[:, idx].sum() == 0:
            continue  # precision/recall is undefined without positives
        prec, rec, _ = precision_recall_curve(y_true_bin[:, idx], y_score[:, idx])
        ap = average_precision_score(y_true_bin[:, idx], y_score[:, idx])
        fig_pr.add_trace(
            go.Scatter(
                x=rec,
                y=prec,
                mode="lines",
                name=f"{lbl} (AP={ap:.3f})",
                line=dict(width=3),
            )
        )
        added_pr = True
    if y_true_bin.sum() > 0:
        prec_micro, rec_micro, _ = precision_recall_curve(y_true_bin.ravel(), y_score.ravel())
        ap_micro = average_precision_score(y_true_bin, y_score, average="micro")
        fig_pr.add_trace(
            go.Scatter(
                x=rec_micro,
                y=prec_micro,
                mode="lines",
                name=f"Micro-average (AP={ap_micro:.3f})",
                line=dict(width=3, dash="dash"),
            )
        )
        added_pr = True
    if not added_pr:
        return plots
    fig_pr.update_layout(
        title=dict(text="Multi-class Precision-Recall (one-vs-rest)", x=0.5),
        xaxis_title="Recall",
        yaxis_title="Precision",
        width=820,
        height=620,
        legend=dict(
            x=0.62,
            y=0.05,
            bgcolor="rgba(255,255,255,0.9)",
            bordercolor="rgba(0,0,0,0.2)",
            borderwidth=1,
        ),
    )
    _style_fig(fig_pr)
    plots.append(_wrap_plot("Multi-class Precision-Recall (one-vs-rest)", fig_pr))

    return plots
1470
1471
def _per_class_metric_values(
    per_class_stats: Dict[str, object], classes: List[str], metric: str
) -> List[Optional[float]]:
    """Return *metric* for each class in *classes*, or ``None`` where it is unavailable.

    A value counts as available only if the class entry is a dict and the metric is a real
    number (bools are excluded). ``None`` leaves a gap in the bar chart instead of drawing
    a misleading zero-height bar.
    """
    values: List[Optional[float]] = []
    for cls in classes:
        entry = per_class_stats.get(cls)
        value = entry.get(metric) if isinstance(entry, dict) else None
        is_number = isinstance(value, (int, float)) and not isinstance(value, bool)
        values.append(value if is_number else None)
    return values


def build_multiclass_metric_plots(test_stats_path: str) -> List[Dict[str, str]]:
    """Alternative multi-class transparency plots using test_statistics.json per-class stats.

    Reads ``["label"]["per_class_stats"]`` from the JSON file and draws one grouped bar per
    metric (precision, recall, F1, specificity, accuracy) for every class.

    Args:
        test_stats_path: Path to the test statistics JSON file.

    Returns:
        A single-element list with the wrapped bar chart. ``[]`` is returned when the file
        is missing or unreadable, or has no per-class statistics.
    """
    ts_path = Path(test_stats_path)
    if not ts_path.exists():
        return []
    try:
        with open(ts_path, "r", encoding="utf-8") as f:
            test_stats = json.load(f)
    except (OSError, ValueError) as exc:  # ValueError covers JSON and Unicode decode errors
        print(f"Warning: Unable to read test statistics JSON: {exc}")
        return []

    # Tolerate unexpected shapes (e.g. a top-level list) instead of raising AttributeError.
    label_stats = test_stats.get("label") if isinstance(test_stats, dict) else None
    pcs = label_stats.get("per_class_stats") if isinstance(label_stats, dict) else None
    if not isinstance(pcs, dict) or not pcs:
        return []
    classes = list(pcs.keys())

    metrics = ["precision", "recall", "f1_score", "specificity", "accuracy"]
    fig_bar = go.Figure()
    for metric in metrics:
        fig_bar.add_trace(
            go.Bar(
                x=classes,
                y=_per_class_metric_values(pcs, classes, metric),
                name=metric.replace("_", " ").title(),
            )
        )
    fig_bar.update_layout(
        title=dict(text="Per-Class Metrics (Test)", x=0.5),
        xaxis_title="Class",
        yaxis_title="Metric value",
        barmode="group",
        width=900,
        height=600,
        legend=dict(
            x=1.02,
            y=1.0,
            bgcolor="rgba(255,255,255,0.9)",
            bordercolor="rgba(0,0,0,0.2)",
            borderwidth=1,
        ),
    )
    _style_fig(fig_bar)

    return [_wrap_plot("Per-Class Metrics (Test)", fig_bar)]