comparison plotly_plots.py @ 15:d17e3a1b8659 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
author goeckslab
date Fri, 28 Nov 2025 15:45:49 +0000
parents c5150cceab47
children
comparison
equal deleted inserted replaced
14:94cd9ac4a9b1 15:d17e3a1b8659
5 import numpy as np 5 import numpy as np
6 import pandas as pd 6 import pandas as pd
7 import plotly.graph_objects as go 7 import plotly.graph_objects as go
8 import plotly.io as pio 8 import plotly.io as pio
9 from constants import LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME 9 from constants import LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME
10 from sklearn.metrics import auc, roc_curve 10
11 from sklearn.preprocessing import label_binarize 11
def _style_fig(fig: go.Figure, font_size: int = 12) -> go.Figure:
    """Give a Plotly figure the report's shared visual style.

    Sets the base font size, paints both the plot area and the surrounding
    paper white, and draws light-grey grid lines on the x and y axes.

    Args:
        fig: Figure to style; it is modified in place.
        font_size: Base font size applied to the whole layout.

    Returns:
        The same figure object, so calls can be chained.
    """
    white = "#ffffff"
    grid_grey = "#e8e8e8"
    fig.update_layout(
        font={"size": font_size},
        plot_bgcolor=white,
        paper_bgcolor=white,
    )
    for update_axes in (fig.update_xaxes, fig.update_yaxes):
        update_axes(gridcolor=grid_grey)
    return fig
22
23
24 def _labels_from_metadata_dict(meta_dict: dict) -> List[str]:
25 """Extract ordered label names from Ludwig train_set_metadata."""
26 if not isinstance(meta_dict, dict):
27 return []
28
29 for key in ("idx2str", "idx2label", "vocab"):
30 seq = meta_dict.get(key)
31 if isinstance(seq, list) and seq:
32 return [str(v) for v in seq]
33
34 str2idx = meta_dict.get("str2idx")
35 if isinstance(str2idx, dict) and str2idx:
36 int_indices = [v for v in str2idx.values() if isinstance(v, int)]
37 if int_indices:
38 max_idx = max(int_indices)
39 ordered = [None] * (max_idx + 1)
40 for name, idx in str2idx.items():
41 if isinstance(idx, int) and 0 <= idx < len(ordered):
42 ordered[idx] = name
43 return [str(v) for v in ordered if v is not None]
44
45 return []
46
47
48 def _resolve_confusion_labels(
49 label_stats: dict,
50 n_classes: int,
51 metadata_csv_path: Optional[str],
52 train_set_metadata_path: Optional[str],
53 ) -> List[str]:
54 """Prefer original labels from metadata; fall back to stats if unavailable."""
55 if train_set_metadata_path:
56 try:
57 meta_path = Path(train_set_metadata_path)
58 if meta_path.exists():
59 with open(meta_path, "r") as f:
60 meta_json = json.load(f)
61 label_meta = meta_json.get(LABEL_COLUMN_NAME)
62 if not isinstance(label_meta, dict):
63 label_meta = next(
64 (
65 v
66 for v in meta_json.values()
67 if isinstance(v, dict)
68 and any(k in v for k in ("idx2str", "str2idx", "idx2label", "vocab"))
69 ),
70 None,
71 )
72 labels_from_meta = _labels_from_metadata_dict(label_meta) if label_meta else []
73 if labels_from_meta and len(labels_from_meta) >= n_classes:
74 return [str(label) for label in labels_from_meta[:n_classes]]
75 except Exception as exc:
76 print(f"Warning: Unable to read labels from train_set_metadata: {exc}")
77
78 if metadata_csv_path:
79 try:
80 csv_path = Path(metadata_csv_path)
81 if csv_path.exists():
82 df_meta = pd.read_csv(csv_path)
83 if LABEL_COLUMN_NAME in df_meta.columns:
84 uniques = df_meta[LABEL_COLUMN_NAME].dropna().unique().tolist()
85 if uniques and len(uniques) >= n_classes:
86 return [str(u) for u in uniques[:n_classes]]
87 except Exception as exc:
88 print(f"Warning: Unable to read labels from metadata CSV: {exc}")
89
90 pcs = label_stats.get("per_class_stats", {})
91 if pcs:
92 pcs_labels = [str(k) for k in pcs.keys()]
93 if len(pcs_labels) >= n_classes:
94 return pcs_labels[:n_classes]
95
96 labels = label_stats.get("labels")
97 if not labels:
98 labels = [str(i) for i in range(n_classes)]
99 if len(labels) < n_classes:
100 labels = labels + [str(i) for i in range(len(labels), n_classes)]
101 return [str(label) for label in labels[:n_classes]]
12 102
13 103
14 def build_classification_plots( 104 def build_classification_plots(
15 test_stats_path: str, 105 test_stats_path: str,
16 training_stats_path: Optional[str] = None, 106 training_stats_path: Optional[str] = None,
107 metadata_csv_path: Optional[str] = None,
108 train_set_metadata_path: Optional[str] = None,
17 ) -> List[Dict[str, str]]: 109 ) -> List[Dict[str, str]]:
18 """ 110 """
19 Read Ludwig’s test_statistics.json and build three interactive Plotly panels: 111 Read Ludwig’s test_statistics.json and build three interactive Plotly panels:
20 - Confusion Matrix 112 - Confusion Matrix
21 - ROC-AUC 113 - ROC-AUC
22 - Classification Report Heatmap 114 - Classification Report Heatmap
115
116 If metadata paths are provided, the confusion matrix axes will use the original
117 label values from the training metadata rather than integer-encoded labels.
23 118
24 Returns a list of dicts, each with: 119 Returns a list of dicts, each with:
25 { 120 {
26 "title": <plot title>, 121 "title": <plot title>,
27 "html": <HTML fragment for embedding> 122 "html": <HTML fragment for embedding>
40 135
41 plots: List[Dict[str, str]] = [] 136 plots: List[Dict[str, str]] = []
42 137
43 # 0) Confusion Matrix 138 # 0) Confusion Matrix
44 cm = np.array(label_stats["confusion_matrix"], dtype=int) 139 cm = np.array(label_stats["confusion_matrix"], dtype=int)
45 # Try to get actual class names from per_class_stats keys (which contain the real labels) 140 labels = _resolve_confusion_labels(
46 pcs = label_stats.get("per_class_stats", {}) 141 label_stats,
47 if pcs: 142 n_classes,
48 labels = list(pcs.keys()) 143 metadata_csv_path=metadata_csv_path,
49 else: 144 train_set_metadata_path=train_set_metadata_path,
50 labels = label_stats.get("labels", [str(i) for i in range(cm.shape[0])]) 145 )
51 total = cm.sum() 146 total = cm.sum()
52 147
53 fig_cm = go.Figure( 148 fig_cm = go.Figure(
54 go.Heatmap( 149 go.Heatmap(
55 z=cm, 150 z=cm,
68 yaxis_autorange="reversed", 163 yaxis_autorange="reversed",
69 width=side_px, 164 width=side_px,
70 height=side_px, 165 height=side_px,
71 margin=dict(t=100, l=80, r=80, b=80), 166 margin=dict(t=100, l=80, r=80, b=80),
72 ) 167 )
168 _style_fig(fig_cm)
73 169
74 # annotate counts and percentages 170 # annotate counts and percentages
75 mval = cm.max() if cm.size else 0 171 mval = cm.max() if cm.size else 0
76 thresh = mval / 2 172 thresh = mval / 2
77 for i in range(cm.shape[0]): 173 for i in range(cm.shape[0]):
108 include_plotlyjs="cdn", 204 include_plotlyjs="cdn",
109 config=common_cfg 205 config=common_cfg
110 ) 206 )
111 }) 207 })
112 208
113 # 1) ROC-AUC Curves (Multi-class) 209 # 1) ROC Curve (from test_statistics)
114 roc_plot = _build_roc_auc_plot(test_stats_path, labels, common_cfg) 210 roc_plot = _build_static_roc_plot(label_stats, common_cfg, friendly_labels=labels)
115 if roc_plot: 211 if roc_plot:
116 plots.append(roc_plot) 212 plots.append(roc_plot)
213
214 # 2) Precision-Recall Curve (from test_statistics)
215 pr_plot = _build_precision_recall_plot(label_stats, common_cfg)
216 if pr_plot:
217 plots.append(pr_plot)
117 218
118 # 2) Classification Report Heatmap 219 # 2) Classification Report Heatmap
119 pcs = label_stats.get("per_class_stats", {}) 220 pcs = label_stats.get("per_class_stats", {})
120 if pcs: 221 if pcs:
121 classes = list(pcs.keys()) 222 classes = list(pcs.keys())
122 metrics = ["precision", "recall", "f1_score"] 223 metrics = [
224 "precision",
225 "recall",
226 "f1_score",
227 "accuracy",
228 "matthews_correlation_coefficient",
229 "specificity",
230 ]
123 z, txt = [], [] 231 z, txt = [], []
124 for c in classes: 232 for c in classes:
125 row, trow = [], [] 233 row, trow = [], []
126 for m in metrics: 234 for m in metrics:
127 val = pcs[c].get(m, 0) 235 val = pcs[c].get(m, 0)
131 txt.append(trow) 239 txt.append(trow)
132 240
133 fig_cr = go.Figure( 241 fig_cr = go.Figure(
134 go.Heatmap( 242 go.Heatmap(
135 z=z, 243 z=z,
136 x=metrics, 244 x=[m.replace("_", " ") for m in metrics],
137 y=[str(c) for c in classes], 245 y=[str(c) for c in classes],
138 text=txt, 246 text=txt,
139 texttemplate="%{text}", 247 texttemplate="%{text}",
140 colorscale="Reds", 248 colorscale="Reds",
141 showscale=True, 249 showscale=True,
142 colorbar=dict(title="Value"), 250 colorbar=dict(title="Value"),
143 ) 251 )
144 ) 252 )
145 fig_cr.update_layout( 253 fig_cr.update_layout(
146 title="Classification Report", 254 title="Per-Class metrics",
147 xaxis_title="", 255 xaxis_title="",
148 yaxis_title="Class", 256 yaxis_title="Class",
149 width=side_px, 257 width=side_px,
150 height=side_px, 258 height=side_px,
151 margin=dict(t=80, l=80, r=80, b=80), 259 margin=dict(t=80, l=80, r=80, b=80),
152 ) 260 )
261 _style_fig(fig_cr)
153 plots.append({ 262 plots.append({
154 "title": "Classification Report", 263 "title": "Per-Class metrics",
155 "html": pio.to_html( 264 "html": pio.to_html(
156 fig_cr, 265 fig_cr,
157 full_html=False, 266 full_html=False,
158 include_plotlyjs=False, 267 include_plotlyjs=False,
159 config=common_cfg 268 config=common_cfg
160 ) 269 )
161 }) 270 })
162 271
272 # 3) Prediction Diagnostics (from predictions.csv)
273 # Note: appended separately in generate_html_report, not returned here.
274
163 return plots 275 return plots
164 276
165 277
166 def _build_roc_auc_plot(test_stats_path: str, class_labels: List[str], config: dict) -> Optional[Dict[str, str]]: 278 def build_train_validation_plots(train_stats_path: str) -> List[Dict[str, str]]:
167 """ 279 """Generate Train/Validation learning curve plots from training_statistics.json."""
168 Build an interactive ROC-AUC curve plot for multi-class classification. 280 if not train_stats_path or not Path(train_stats_path).exists():
169 Following sklearn's ROC example with micro-average and per-class curves. 281 return []
170
171 Args:
172 test_stats_path: Path to test_statistics.json
173 class_labels: List of class label names
174 config: Plotly config dict
175
176 Returns:
177 Dict with title and HTML, or None if data unavailable
178 """
179 try: 282 try:
180 # Get the experiment directory from test_stats_path 283 with open(train_stats_path, "r") as f:
181 exp_dir = Path(test_stats_path).parent 284 train_stats = json.load(f)
182 285 except Exception as exc:
183 # Load predictions with probabilities 286 print(f"Warning: Unable to read training statistics: {exc}")
184 predictions_path = exp_dir / "predictions.csv" 287 return []
185 if not predictions_path.exists(): 288
289 label_train = (train_stats.get("training") or {}).get("label", {})
290 label_val = (train_stats.get("validation") or {}).get("label", {})
291 if not label_train and not label_val:
292 return []
293 plots: List[Dict[str, str]] = []
294 include_js = True # Load Plotly.js once for this group
295
296 def _get_series(stats: dict, metric: str) -> List[float]:
297 if metric not in stats:
298 return []
299 vals = stats.get(metric, [])
300 if isinstance(vals, list):
301 return [float(v) for v in vals]
302 try:
303 return [float(vals)]
304 except Exception:
305 return []
306
307 def _line_plot(metric_key: str, title: str, yaxis_title: str) -> Optional[Dict[str, str]]:
308 train_series = _get_series(label_train, metric_key)
309 val_series = _get_series(label_val, metric_key)
310 if not train_series and not val_series:
186 return None 311 return None
187 312 epochs_train = list(range(1, len(train_series) + 1))
313 epochs_val = list(range(1, len(val_series) + 1))
314 fig = go.Figure()
315 if train_series:
316 fig.add_trace(
317 go.Scatter(
318 x=epochs_train,
319 y=train_series,
320 mode="lines+markers",
321 name="Train",
322 line=dict(width=4),
323 )
324 )
325 if val_series:
326 fig.add_trace(
327 go.Scatter(
328 x=epochs_val,
329 y=val_series,
330 mode="lines+markers",
331 name="Validation",
332 line=dict(width=4),
333 )
334 )
335 fig.update_layout(
336 title=dict(text=title, x=0.5),
337 xaxis_title="Epoch",
338 yaxis_title=yaxis_title,
339 width=760,
340 height=520,
341 hovermode="x unified",
342 )
343 _style_fig(fig)
344 return {
345 "title": title,
346 "html": pio.to_html(
347 fig,
348 full_html=False,
349 include_plotlyjs="cdn" if include_js else False,
350 ),
351 }
352
353 # Core learning curves
354 for key, title in [
355 ("roc_auc", "ROC-AUC across epochs"),
356 ("precision", "Precision across epochs"),
357 ("recall", "Recall/Sensitivity across epochs"),
358 ("specificity", "Specificity across epochs"),
359 ]:
360 plot = _line_plot(key, title, title.replace("Learning Curve", "").strip())
361 if plot:
362 plots.append(plot)
363 include_js = False
364
365 # Precision vs Recall evolution (validation)
366 val_prec = _get_series(label_val, "precision")
367 val_rec = _get_series(label_val, "recall")
368 if val_prec and val_rec:
369 epochs = list(range(1, min(len(val_prec), len(val_rec)) + 1))
370 fig_pr = go.Figure()
371 fig_pr.add_trace(
372 go.Scatter(
373 x=epochs,
374 y=val_prec[: len(epochs)],
375 mode="lines+markers",
376 name="Precision",
377 )
378 )
379 fig_pr.add_trace(
380 go.Scatter(
381 x=epochs,
382 y=val_rec[: len(epochs)],
383 mode="lines+markers",
384 name="Recall",
385 )
386 )
387 fig_pr.update_layout(
388 title=dict(text="Validation Precision and Recall by Epoch", x=0.5),
389 xaxis_title="Epoch",
390 yaxis_title="Value",
391 width=760,
392 height=520,
393 hovermode="x unified",
394 )
395 _style_fig(fig_pr)
396 plots.append({
397 "title": "Precision vs Recall Evolution",
398 "html": pio.to_html(
399 fig_pr,
400 full_html=False,
401 include_plotlyjs="cdn" if include_js else False,
402 ),
403 })
404 include_js = False
405
406 # F1-score derived
407 def _compute_f1(p: List[float], r: List[float]) -> List[float]:
408 f1_vals = []
409 for prec, rec in zip(p, r):
410 if (prec + rec) == 0:
411 f1_vals.append(0.0)
412 else:
413 f1_vals.append(2 * prec * rec / (prec + rec))
414 return f1_vals
415
416 f1_train = _compute_f1(_get_series(label_train, "precision"), _get_series(label_train, "recall"))
417 f1_val = _compute_f1(val_prec, val_rec)
418 if f1_train or f1_val:
419 fig = go.Figure()
420 if f1_train:
421 fig.add_trace(go.Scatter(x=list(range(1, len(f1_train) + 1)), y=f1_train, mode="lines+markers", name="Train", line=dict(width=4)))
422 if f1_val:
423 fig.add_trace(go.Scatter(x=list(range(1, len(f1_val) + 1)), y=f1_val, mode="lines+markers", name="Validation", line=dict(width=4)))
424 fig.update_layout(
425 title=dict(text="F1-Score across epochs (derived)", x=0.5),
426 xaxis_title="Epoch",
427 yaxis_title="F1-Score",
428 width=760,
429 height=520,
430 hovermode="x unified",
431 )
432 _style_fig(fig)
433 plots.append({
434 "title": "F1-Score across epochs (derived)",
435 "html": pio.to_html(
436 fig,
437 full_html=False,
438 include_plotlyjs="cdn" if include_js else False,
439 ),
440 })
441 include_js = False
442
443 # Overfitting Gap: Train vs Val ROC-AUC (gap)
444 roc_train = _get_series(label_train, "roc_auc")
445 roc_val = _get_series(label_val, "roc_auc")
446 if roc_train and roc_val:
447 epochs_gap = list(range(1, min(len(roc_train), len(roc_val)) + 1))
448 gaps = [t - v for t, v in zip(roc_train[:len(epochs_gap)], roc_val[:len(epochs_gap)])]
449 fig_gap = go.Figure()
450 fig_gap.add_trace(go.Scatter(x=epochs_gap, y=gaps, mode="lines+markers", name="Train - Val ROC-AUC", line=dict(width=4)))
451 fig_gap.update_layout(
452 title=dict(text="Overfitting gap: ROC-AUC across epochs", x=0.5),
453 xaxis_title="Epoch",
454 yaxis_title="Gap",
455 width=760,
456 height=520,
457 hovermode="x unified",
458 )
459 _style_fig(fig_gap)
460 plots.append({
461 "title": "Overfitting gap: ROC-AUC across epochs",
462 "html": pio.to_html(
463 fig_gap,
464 full_html=False,
465 include_plotlyjs="cdn" if include_js else False,
466 ),
467 })
468 include_js = False
469
470 # Best Epoch Dashboard (based on max val ROC-AUC)
471 if roc_val:
472 best_idx = int(np.argmax(roc_val))
473 best_epoch = best_idx + 1
474 spec_val = _get_series(label_val, "specificity")
475 metrics_at_best = {
476 "ROC-AUC": roc_val[best_idx] if best_idx < len(roc_val) else None,
477 "Precision": val_prec[best_idx] if best_idx < len(val_prec) else None,
478 "Recall": val_rec[best_idx] if best_idx < len(val_rec) else None,
479 "Specificity": spec_val[best_idx] if best_idx < len(spec_val) else None,
480 "F1-Score": f1_val[best_idx] if best_idx < len(f1_val) else None,
481 }
482 fig_best = go.Figure()
483 for name, value in metrics_at_best.items():
484 if value is not None:
485 fig_best.add_trace(go.Bar(name=name, x=[name], y=[value]))
486 fig_best.update_layout(
487 title=dict(text=f"Best Epoch Dashboard (Val ROC-AUC @ epoch {best_epoch})", x=0.5),
488 xaxis_title="Metric",
489 yaxis_title="Value",
490 width=760,
491 height=520,
492 showlegend=False,
493 )
494 _style_fig(fig_best)
495 plots.append({
496 "title": "Best Validation Epoch Snapshot (Metrics)",
497 "html": pio.to_html(
498 fig_best,
499 full_html=False,
500 include_plotlyjs="cdn" if include_js else False,
501 ),
502 })
503 include_js = False
504
505 return plots
506
507
508 def _get_regression_series(split_stats: dict, metric: str) -> List[float]:
509 if metric not in split_stats:
510 return []
511 vals = split_stats.get(metric, [])
512 if isinstance(vals, list):
513 return [float(v) for v in vals]
514 try:
515 return [float(vals)]
516 except Exception:
517 return []
518
519
def _regression_line_plot(
    train_split: dict,
    val_split: dict,
    metric_key: str,
    title: str,
    yaxis_title: str,
    include_js: bool,
) -> Optional[Dict[str, str]]:
    """Draw one regression metric as Train and Validation lines over epochs.

    Args:
        train_split: Training-split statistics (metric name -> values).
        val_split: Validation-split statistics (metric name -> values).
        metric_key: Metric to plot.
        title: Figure title, also used as the returned dict's title.
        yaxis_title: Label for the y axis.
        include_js: When true, the HTML fragment loads Plotly.js from the CDN.

    Returns:
        ``{"title": ..., "html": ...}``, or ``None`` when neither split has data
        for ``metric_key``.
    """
    splits = (("Train", _get_regression_series(train_split, metric_key)),
              ("Validation", _get_regression_series(val_split, metric_key)))
    if not any(values for _, values in splits):
        return None

    fig = go.Figure()
    for split_name, values in splits:
        if not values:
            continue
        fig.add_trace(
            go.Scatter(
                x=list(range(1, len(values) + 1)),  # epochs are 1-based
                y=values,
                mode="lines+markers",
                name=split_name,
                line=dict(width=4),
            )
        )
    fig.update_layout(
        title=dict(text=title, x=0.5),
        xaxis_title="Epoch",
        yaxis_title=yaxis_title,
        width=760,
        height=520,
        hovermode="x unified",
    )
    _style_fig(fig)

    plotly_js = "cdn" if include_js else False
    html = pio.to_html(fig, full_html=False, include_plotlyjs=plotly_js)
    return {"title": title, "html": html}
572
573
def build_regression_train_val_plots(train_stats_path: str) -> List[Dict[str, str]]:
    """Build Train/Validation learning-curve figures for a regression model.

    Reads ``training.label`` and ``validation.label`` from
    training_statistics.json and draws one figure per metric that has data:
    MAE, RMSE, MAPE, R² and loss. Only the first figure embeds the Plotly.js
    loader.

    Args:
        train_stats_path: Path to training_statistics.json.

    Returns:
        A list of ``{"title", "html"}`` dicts. The list is empty when the path
        is missing or unreadable (a warning is printed in the latter case), or
        when neither split has statistics.
    """
    if not train_stats_path or not Path(train_stats_path).exists():
        return []
    try:
        with open(train_stats_path, "r") as handle:
            stats = json.load(handle)
    except Exception as exc:
        print(f"Warning: Unable to read training statistics: {exc}")
        return []

    train_label = (stats.get("training") or {}).get("label", {})
    val_label = (stats.get("validation") or {}).get("label", {})
    if not (train_label or val_label):
        return []

    metric_specs = (
        ("mean_absolute_error", "Mean Absolute Error across epochs", "MAE"),
        ("root_mean_squared_error", "Root Mean Squared Error across epochs", "RMSE"),
        ("mean_absolute_percentage_error", "Mean Absolute Percentage Error across epochs", "MAPE"),
        ("r2", "R² across epochs", "R²"),
        ("loss", "Loss across epochs", "Loss"),
    )
    plots: List[Dict[str, str]] = []
    for metric_key, title, ytitle in metric_specs:
        # Plotly.js is embedded only until the first figure has been produced.
        plot = _regression_line_plot(train_label, val_label, metric_key, title, ytitle, not plots)
        if plot:
            plots.append(plot)
    return plots
604
605
def build_regression_test_plots(train_stats_path: str) -> List[Dict[str, str]]:
    """Build per-metric Test learning-curve figures for a regression model.

    Reads ``test.label`` from training_statistics.json and draws one figure per
    metric that has data: MAE, RMSE, MAPE, R² and loss. Only the first figure
    embeds the Plotly.js loader.

    Args:
        train_stats_path: Path to training_statistics.json.

    Returns:
        A list of ``{"title", "html"}`` dicts. The list is empty when the path
        is missing or unreadable (a warning is printed in the latter case), when
        the JSON is not an object, or when there are no test statistics.
    """
    if not train_stats_path or not Path(train_stats_path).exists():
        return []
    try:
        with open(train_stats_path, "r", encoding="utf-8") as f:
            train_stats = json.load(f)
    except Exception as exc:
        print(f"Warning: Unable to read training statistics: {exc}")
        return []
    if not isinstance(train_stats, dict):
        return []

    label_test = (train_stats.get("test") or {}).get("label", {})
    if not label_test:
        return []

    plots: List[Dict[str, str]] = []
    include_js = True
    metrics = [
        ("mean_absolute_error", "Mean Absolute Error Across Epochs", "MAE"),
        ("root_mean_squared_error", "Root Mean Squared Error Across Epochs", "RMSE"),
        ("mean_absolute_percentage_error", "Mean Absolute Percentage Error Across Epochs", "MAPE"),
        ("r2", "R² Across Epochs", "R²"),
        ("loss", "Loss Across Epochs", "Loss"),
    ]
    for metric_key, title, ytitle in metrics:
        series = _get_regression_series(label_test, metric_key)
        if not series:
            continue
        # Each metric is its own figure, so size the epoch axis from this
        # series. Reusing the first metric's length truncated longer series
        # and gave shorter ones unmatched x values.
        epochs = list(range(1, len(series) + 1))
        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=epochs,
                y=series,
                mode="lines+markers",
                name="Test",
                line=dict(width=4),
            )
        )
        fig.update_layout(
            title=dict(text=title, x=0.5),
            xaxis_title="Epoch",
            yaxis_title=ytitle,
            width=760,
            height=520,
            hovermode="x unified",
        )
        _style_fig(fig)
        plots.append({
            "title": title,
            "html": pio.to_html(
                fig,
                full_html=False,
                include_plotlyjs="cdn" if include_js else False,
            ),
        })
        include_js = False
    return plots
666
667
668 def _build_static_roc_plot(
669 label_stats: dict, config: dict, friendly_labels: Optional[List[str]] = None
670 ) -> Optional[Dict[str, str]]:
671 """Build ROC curve directly from test_statistics.json (single curve)."""
672 roc_data = label_stats.get("roc_curve")
673 if not isinstance(roc_data, dict):
674 return None
675
676 fpr = roc_data.get("false_positive_rate")
677 tpr = roc_data.get("true_positive_rate")
678 if not fpr or not tpr or len(fpr) != len(tpr):
679 return None
680
681 try:
682 fig = go.Figure()
683 fig.add_trace(
684 go.Scatter(
685 x=fpr,
686 y=tpr,
687 mode="lines+markers",
688 name="ROC Curve",
689 line=dict(color="#1f77b4", width=4),
690 hovertemplate="FPR: %{x:.3f}<br>TPR: %{y:.3f}<extra></extra>",
691 )
692 )
693 fig.add_trace(
694 go.Scatter(
695 x=[0, 1],
696 y=[0, 1],
697 mode="lines",
698 name="Random Classifier",
699 line=dict(color="gray", width=2, dash="dash"),
700 hovertemplate="Random Classifier<extra></extra>",
701 )
702 )
703
704 auc_val = label_stats.get("roc_auc") or label_stats.get("roc_auc_macro") or label_stats.get("roc_auc_micro")
705 auc_txt = f" (AUC = {auc_val:.3f})" if isinstance(auc_val, (int, float)) else ""
706
707 # Determine which label is treated as positive for the curve
708 label_list: List = []
709 pcs = label_stats.get("per_class_stats", {})
710 if pcs:
711 label_list = list(pcs.keys())
712 if not label_list:
713 labels_from_stats = label_stats.get("labels")
714 if isinstance(labels_from_stats, list):
715 label_list = labels_from_stats
716
717 # Try to resolve index of the positive label explicitly provided by Ludwig
718 pos_label_raw = (
719 roc_data.get("positive_label")
720 or roc_data.get("positive_class")
721 or label_stats.get("positive_label")
722 )
723 pos_label_idx = None
724 if pos_label_raw is not None and isinstance(label_list, list):
725 try:
726 pos_label_idx = label_list.index(pos_label_raw)
727 except ValueError:
728 pos_label_idx = None
729
730 # Fallback: use the second label if available, otherwise the first
731 if pos_label_idx is None:
732 if isinstance(label_list, list) and len(label_list) >= 2:
733 pos_label_idx = 1
734 elif isinstance(label_list, list) and label_list:
735 pos_label_idx = 0
736
737 if pos_label_raw is None and isinstance(label_list, list) and pos_label_idx is not None:
738 pos_label_raw = label_list[pos_label_idx]
739
740 # Map to friendly label if we have one from metadata/CSV
741 pos_label_display = pos_label_raw
742 if (
743 friendly_labels
744 and isinstance(pos_label_idx, int)
745 and 0 <= pos_label_idx < len(friendly_labels)
746 ):
747 pos_label_display = friendly_labels[pos_label_idx]
748
749 pos_label_txt = (
750 f"Positive class: {pos_label_display}"
751 if pos_label_display is not None
752 else "Positive class: (not available)"
753 )
754
755 title_label = f"ROC Curve{auc_txt}"
756 if pos_label_display is not None:
757 title_label = f"ROC Curve (Positive Class: {pos_label_display}){auc_txt}"
758
759 fig.update_layout(
760 title=dict(text=title_label, x=0.5),
761 xaxis_title="False Positive Rate",
762 yaxis_title="True Positive Rate",
763 width=700,
764 height=600,
765 margin=dict(t=80, l=80, r=80, b=110),
766 hovermode="closest",
767 legend=dict(
768 x=0.6,
769 y=0.1,
770 bgcolor="rgba(255,255,255,0.9)",
771 bordercolor="rgba(0,0,0,0.2)",
772 borderwidth=1,
773 ),
774 )
775 _style_fig(fig)
776 fig.update_xaxes(range=[0, 1.0])
777 fig.update_yaxes(range=[0, 1.05])
778
779 fig.add_annotation(
780 x=0.5,
781 y=-0.15,
782 xref="paper",
783 yref="paper",
784 showarrow=False,
785 text=f"<span style='font-size:12px;color:#555;'>{pos_label_txt}</span>",
786 xanchor="center",
787 )
788
789 return {
790 "title": "ROC Curve",
791 "html": pio.to_html(
792 fig,
793 full_html=False,
794 include_plotlyjs=False,
795 config=config,
796 ),
797 }
798 except Exception as e:
799 print(f"Error building ROC plot: {e}")
800 return None
801
802
803 def _build_precision_recall_plot(label_stats: dict, config: dict) -> Optional[Dict[str, str]]:
804 """Build Precision-Recall curve directly from test_statistics.json."""
805 pr_data = label_stats.get("precision_recall_curve")
806 if not isinstance(pr_data, dict):
807 return None
808
809 precisions = pr_data.get("precisions")
810 recalls = pr_data.get("recalls")
811 if not precisions or not recalls or len(precisions) != len(recalls):
812 return None
813
814 try:
815 fig = go.Figure()
816 fig.add_trace(
817 go.Scatter(
818 x=recalls,
819 y=precisions,
820 mode="lines+markers",
821 name="Precision-Recall",
822 line=dict(color="#d62728", width=4),
823 hovertemplate="Recall: %{x:.3f}<br>Precision: %{y:.3f}<extra></extra>",
824 )
825 )
826
827 ap_val = (
828 label_stats.get("average_precision_macro")
829 or label_stats.get("average_precision_micro")
830 or label_stats.get("average_precision_samples")
831 )
832 ap_txt = f" (AP = {ap_val:.3f})" if isinstance(ap_val, (int, float)) else ""
833
834 fig.update_layout(
835 title=dict(text=f"Precision-Recall Curve{ap_txt}", x=0.5),
836 xaxis_title="Recall",
837 yaxis_title="Precision",
838 width=700,
839 height=600,
840 margin=dict(t=80, l=80, r=80, b=80),
841 hovermode="closest",
842 legend=dict(
843 x=0.6,
844 y=0.1,
845 bgcolor="rgba(255,255,255,0.9)",
846 bordercolor="rgba(0,0,0,0.2)",
847 borderwidth=1,
848 ),
849 )
850 _style_fig(fig)
851 fig.update_xaxes(range=[0, 1.0])
852 fig.update_yaxes(range=[0, 1.05])
853
854 return {
855 "title": "Precision-Recall Curve",
856 "html": pio.to_html(
857 fig,
858 full_html=False,
859 include_plotlyjs=False,
860 config=config,
861 ),
862 }
863 except Exception as e:
864 print(f"Error building Precision-Recall plot: {e}")
865 return None
866
867
868 def build_prediction_diagnostics(
869 predictions_path: str,
870 label_data_path: Optional[str] = None,
871 split_value: int = 2,
872 threshold: Optional[float] = None,
873 ) -> List[Dict[str, str]]:
874 """Generate diagnostic plots from predictions.csv for classification tasks."""
875 preds_file = Path(predictions_path)
876 if not preds_file.exists():
877 return []
878
879 try:
188 df_pred = pd.read_csv(predictions_path) 880 df_pred = pd.read_csv(predictions_path)
189 881 except Exception as exc:
190 if SPLIT_COLUMN_NAME in df_pred.columns: 882 print(f"Warning: Unable to read predictions CSV: {exc}")
191 split_series = df_pred[SPLIT_COLUMN_NAME].astype(str).str.lower() 883 return []
192 test_mask = split_series.isin({"2", "test", "testing"}) 884
193 if test_mask.any(): 885 plots: List[Dict[str, str]] = []
194 df_pred = df_pred[test_mask].reset_index(drop=True) 886
195 887 # Identify probability columns
196 if df_pred.empty: 888 prob_cols = [
197 return None 889 c for c in df_pred.columns
198 890 if c.startswith("label_probabilities_") and c != "label_probabilities"
199 # Extract probability columns (label_probabilities_0, label_probabilities_1, etc.) 891 ]
200 # or label_probabilities_<class_name> for string labels 892 prob_cols_sorted = sorted(prob_cols)
201 prob_cols = [col for col in df_pred.columns if col.startswith('label_probabilities_') and col != 'label_probabilities'] 893
202 894 def _select_positive_prob():
203 # Sort by class number if numeric, otherwise keep alphabetical order 895 if not prob_cols_sorted:
204 if prob_cols and prob_cols[0].split('_')[-1].isdigit(): 896 return None, None
205 prob_cols.sort(key=lambda x: int(x.split('_')[-1])) 897 # Prefer a column indicating positive/event/true/1
206 else: 898 preferred_keys = ("event", "true", "positive", "pos", "1")
207 prob_cols.sort() # Alphabetical sort for string class names 899 for col in prob_cols_sorted:
208 900 suffix = col.replace("label_probabilities_", "").lower()
209 if not prob_cols: 901 if any(k in suffix for k in preferred_keys):
210 return None 902 return col, suffix
211 903 if len(prob_cols_sorted) == 2:
212 # Get probabilities matrix (n_samples x n_classes) 904 col = prob_cols_sorted[1]
213 y_score = df_pred[prob_cols].values 905 return col, col.replace("label_probabilities_", "")
214 n_classes = len(prob_cols) 906 col = prob_cols_sorted[0]
215 907 return col, col.replace("label_probabilities_", "")
216 y_true = None 908
217 candidate_cols = [ 909 pos_prob_col, pos_label_hint = _select_positive_prob()
910 pos_prob_series = df_pred[pos_prob_col] if pos_prob_col and pos_prob_col in df_pred else None
911
912 # Confidence series: prefer label_probability, otherwise positive prob, otherwise max prob
913 confidence_series = None
914 if "label_probability" in df_pred.columns:
915 confidence_series = df_pred["label_probability"]
916 elif pos_prob_series is not None:
917 confidence_series = pos_prob_series
918 elif prob_cols_sorted:
919 confidence_series = df_pred[prob_cols_sorted].max(axis=1)
920
921 # True labels
922 def _extract_labels():
923 candidates = [
218 LABEL_COLUMN_NAME, 924 LABEL_COLUMN_NAME,
219 f"{LABEL_COLUMN_NAME}_ground_truth", 925 f"{LABEL_COLUMN_NAME}_ground_truth",
220 f"{LABEL_COLUMN_NAME}__ground_truth", 926 f"{LABEL_COLUMN_NAME}__ground_truth",
221 f"{LABEL_COLUMN_NAME}_target", 927 f"{LABEL_COLUMN_NAME}_target",
222 f"{LABEL_COLUMN_NAME}__target", 928 f"{LABEL_COLUMN_NAME}__target",
929 "label",
930 "label_true",
223 ] 931 ]
224 candidate_cols.extend( 932 candidates.extend(
225 [ 933 [
226 col 934 col
227 for col in df_pred.columns 935 for col in df_pred.columns
228 if (col.startswith(f"{LABEL_COLUMN_NAME}_") or col.startswith(f"{LABEL_COLUMN_NAME}__")) 936 if (col.startswith(f"{LABEL_COLUMN_NAME}_") or col.startswith(f"{LABEL_COLUMN_NAME}__"))
229 and "probabilities" not in col 937 and "probabilities" not in col
230 and "predictions" not in col 938 and "predictions" not in col
231 ] 939 ]
232 ) 940 )
233 for col in candidate_cols: 941 for col in candidates:
234 if col in df_pred.columns and col not in prob_cols: 942 if col in df_pred.columns and col not in prob_cols_sorted:
235 y_true = df_pred[col].values 943 return df_pred[col]
236 break 944 if label_data_path and Path(label_data_path).exists():
237 945 try:
238 if y_true is None: 946 df_all = pd.read_csv(label_data_path)
239 desc_path = exp_dir / "description.json" 947 if SPLIT_COLUMN_NAME in df_all.columns:
240 if desc_path.exists(): 948 df_all = df_all[df_all[SPLIT_COLUMN_NAME] == split_value].reset_index(drop=True)
241 try: 949 if LABEL_COLUMN_NAME in df_all.columns:
242 with open(desc_path, 'r') as f: 950 return df_all[LABEL_COLUMN_NAME].reset_index(drop=True)
243 desc = json.load(f) 951 except Exception as exc:
244 dataset_path = desc.get('dataset', '') 952 print(f"Warning: Unable to load labels from dataset: {exc}")
245 if dataset_path and Path(dataset_path).exists(): 953 return None
246 df_orig = pd.read_csv(dataset_path) 954
247 if SPLIT_COLUMN_NAME in df_orig.columns: 955 labels_series = _extract_labels()
248 df_orig = df_orig[df_orig[SPLIT_COLUMN_NAME] == 2].reset_index(drop=True) 956
249 if LABEL_COLUMN_NAME in df_orig.columns: 957 # Plot 1: Confidence Histogram
250 y_true = df_orig[LABEL_COLUMN_NAME].values 958 if confidence_series is not None:
251 if len(y_true) != len(df_pred): 959 fig_conf = go.Figure()
252 print( 960 fig_conf.add_trace(
253 f"Warning: Test set size mismatch. Truncating to {len(df_pred)} samples for ROC plot." 961 go.Histogram(
254 ) 962 x=confidence_series,
255 y_true = y_true[:len(df_pred)] 963 nbinsx=20,
256 else: 964 marker=dict(color="#1f77b4", line=dict(color="#ffffff", width=1)),
257 print("Warning: Original dataset referenced in description.json is unavailable.") 965 opacity=0.8,
258 except Exception as exc: # pragma: no cover - defensive 966 histnorm="percent",
259 print(f"Warning: Failed to recover labels from dataset: {exc}") 967 )
260 968 )
261 if y_true is None or len(y_true) == 0: 969 fig_conf.update_layout(
262 print("Warning: Unable to locate ground-truth labels for ROC plot.") 970 title=dict(text="Prediction Confidence Distribution", x=0.5),
263 return None 971 xaxis_title="Predicted probability (confidence)",
264 972 yaxis_title="Percentage (%)",
265 if len(y_true) != len(y_score): 973 bargap=0.05,
266 limit = min(len(y_true), len(y_score))
267 if limit == 0:
268 return None
269 print(f"Warning: Aligning prediction and label lengths to {limit} samples for ROC plot.")
270 y_true = y_true[:limit]
271 y_score = y_score[:limit]
272
273 # Get actual class names from probability column names
274 actual_classes = [col.replace('label_probabilities_', '') for col in prob_cols]
275 display_classes = class_labels if len(class_labels) == n_classes else actual_classes
276
277 # Binarize the output following sklearn example
278 # Use actual class names if they're strings, otherwise use range
279 if isinstance(y_true[0], str):
280 y_test = label_binarize(y_true, classes=actual_classes)
281 else:
282 y_test = label_binarize(y_true, classes=list(range(n_classes)))
283
284 # Handle binary classification case
285 if y_test.ndim != 2:
286 y_test = np.atleast_2d(y_test)
287
288 if n_classes == 2:
289 if y_test.shape[1] == 1:
290 y_test = np.hstack([1 - y_test, y_test])
291 elif y_test.shape[1] != 2:
292 print("Warning: Unexpected label binarization shape for binary ROC plot.")
293 return None
294 elif y_test.shape[1] != n_classes:
295 print("Warning: Label binarization did not produce expected class dimension; skipping ROC plot.")
296 return None
297
298 # Compute ROC curve and ROC area for each class (following sklearn example)
299 fpr = dict()
300 tpr = dict()
301 roc_auc = dict()
302
303 for i in range(n_classes):
304 if np.sum(y_test[:, i]) > 0: # Check if class exists in test set
305 fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
306 roc_auc[i] = auc(fpr[i], tpr[i])
307
308 # Compute micro-average ROC curve and ROC area (sklearn example)
309 fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())
310 roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
311
312 # Create ROC curve plot
313 fig_roc = go.Figure()
314
315 # Colors for different classes
316 colors = [
317 '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
318 '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf'
319 ]
320
321 # Plot micro-average ROC curve first (most important)
322 fig_roc.add_trace(go.Scatter(
323 x=fpr["micro"],
324 y=tpr["micro"],
325 mode='lines',
326 name=f'Micro-average ROC (AUC = {roc_auc["micro"]:.3f})',
327 line=dict(color='deeppink', width=3, dash='dot'),
328 hovertemplate=('<b>Micro-average ROC</b><br>'
329 'FPR: %{x:.3f}<br>'
330 'TPR: %{y:.3f}<br>'
331 f'AUC: {roc_auc["micro"]:.3f}<extra></extra>')
332 ))
333
334 # Plot ROC curve for each class
335 for i in range(n_classes):
336 if i in roc_auc: # Only plot if class exists in test set
337 class_name = display_classes[i] if i < len(display_classes) else f"Class {i}"
338 color = colors[i % len(colors)]
339
340 fig_roc.add_trace(go.Scatter(
341 x=fpr[i],
342 y=tpr[i],
343 mode='lines',
344 name=f'{class_name} (AUC = {roc_auc[i]:.3f})',
345 line=dict(color=color, width=2),
346 hovertemplate=(f'<b>{class_name}</b><br>'
347 'FPR: %{x:.3f}<br>'
348 'TPR: %{y:.3f}<br>'
349 f'AUC: {roc_auc[i]:.3f}<extra></extra>')
350 ))
351
352 # Add diagonal line (random classifier)
353 fig_roc.add_trace(go.Scatter(
354 x=[0, 1],
355 y=[0, 1],
356 mode='lines',
357 name='Random Classifier',
358 line=dict(color='gray', width=1, dash='dash'),
359 hovertemplate='Random Classifier<br>AUC = 0.500<extra></extra>'
360 ))
361
362 # Calculate macro-average AUC
363 class_aucs = [roc_auc[i] for i in range(n_classes) if i in roc_auc]
364 if class_aucs:
365 macro_auc = np.mean(class_aucs)
366 title_text = f"ROC Curves (Micro-avg = {roc_auc['micro']:.3f}, Macro-avg = {macro_auc:.3f})"
367 else:
368 title_text = f"ROC Curves (Micro-avg = {roc_auc['micro']:.3f})"
369
370 fig_roc.update_layout(
371 title=dict(text=title_text, x=0.5),
372 xaxis_title="False Positive Rate",
373 yaxis_title="True Positive Rate",
374 width=700, 974 width=700,
375 height=600, 975 height=500,
376 margin=dict(t=80, l=80, r=80, b=80), 976 )
377 legend=dict( 977 _style_fig(fig_conf)
378 x=0.6, 978 plots.append({
379 y=0.1, 979 "title": "Prediction Confidence Distribution",
380 bgcolor="rgba(255,255,255,0.9)", 980 "html": pio.to_html(fig_conf, full_html=False, include_plotlyjs=False),
381 bordercolor="rgba(0,0,0,0.2)", 981 })
382 borderwidth=1 982
383 ), 983 # The remaining plots require true labels and a positive-class probability
384 hovermode='closest' 984 if labels_series is None or pos_prob_series is None:
385 ) 985 return plots
386 986
387 # Set equal aspect ratio and proper range 987 # Align lengths
388 fig_roc.update_xaxes(range=[0, 1.0]) 988 min_len = min(len(labels_series), len(pos_prob_series))
389 fig_roc.update_yaxes(range=[0, 1.05]) 989 if min_len == 0:
390 990 return plots
391 return { 991 y_true_raw = labels_series.iloc[:min_len]
392 "title": "ROC-AUC Curves", 992 y_score = np.array(pos_prob_series.iloc[:min_len], dtype=float)
393 "html": pio.to_html( 993
394 fig_roc, 994 # Determine positive label
395 full_html=False, 995 unique_labels = pd.unique(y_true_raw)
396 include_plotlyjs=False, 996 unique_labels_list = list(unique_labels)
397 config=config 997 positive_label = None
398 ) 998 if pos_label_hint and str(pos_label_hint) in [str(u) for u in unique_labels_list]:
399 } 999 positive_label = pos_label_hint
400 1000 elif len(unique_labels_list) == 2:
401 except Exception as e: 1001 positive_label = unique_labels_list[1]
402 print(f"Error building ROC-AUC plot: {e}") 1002 else:
403 return None 1003 positive_label = unique_labels_list[0]
1004
1005 y_true = (y_true_raw == positive_label).astype(int).values
1006
1007 # Plot 2: Calibration Curve
1008 bins = np.linspace(0.0, 1.0, 11)
1009 bin_ids = np.digitize(y_score, bins, right=True)
1010 bin_centers = []
1011 frac_positives = []
1012 for b in range(1, len(bins)):
1013 mask = bin_ids == b
1014 if not np.any(mask):
1015 continue
1016 bin_centers.append(y_score[mask].mean())
1017 frac_positives.append(y_true[mask].mean())
1018 if bin_centers and frac_positives:
1019 fig_cal = go.Figure()
1020 fig_cal.add_trace(
1021 go.Scatter(
1022 x=bin_centers,
1023 y=frac_positives,
1024 mode="lines+markers",
1025 name="Calibration",
1026 line=dict(color="#2ca02c", width=4),
1027 )
1028 )
1029 fig_cal.add_trace(
1030 go.Scatter(
1031 x=[0, 1],
1032 y=[0, 1],
1033 mode="lines",
1034 name="Perfect Calibration",
1035 line=dict(color="gray", width=2, dash="dash"),
1036 )
1037 )
1038 fig_cal.update_layout(
1039 title=dict(text="Calibration Curve", x=0.5),
1040 xaxis_title="Predicted probability",
1041 yaxis_title="Observed frequency",
1042 width=700,
1043 height=500,
1044 )
1045 _style_fig(fig_cal)
1046 plots.append({
1047 "title": "Calibration Curve (Predicted Probability vs Observed Frequency)",
1048 "html": pio.to_html(fig_cal, full_html=False, include_plotlyjs=False),
1049 })
1050
1051 # Plot 3: Threshold vs Metrics
1052 thresholds = np.linspace(0.0, 1.0, 21)
1053 accs, f1s, sens, specs = [], [], [], []
1054 for t in thresholds:
1055 y_pred = (y_score >= t).astype(int)
1056 tp = np.sum((y_true == 1) & (y_pred == 1))
1057 tn = np.sum((y_true == 0) & (y_pred == 0))
1058 fp = np.sum((y_true == 0) & (y_pred == 1))
1059 fn = np.sum((y_true == 1) & (y_pred == 0))
1060 acc = (tp + tn) / max(len(y_true), 1)
1061 prec = tp / max(tp + fp, 1e-9)
1062 rec = tp / max(tp + fn, 1e-9)
1063 f1 = 0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec)
1064 sensitivity = rec
1065 specificity = tn / max(tn + fp, 1e-9)
1066 accs.append(acc)
1067 f1s.append(f1)
1068 sens.append(sensitivity)
1069 specs.append(specificity)
1070
1071 fig_thresh = go.Figure()
1072 fig_thresh.add_trace(go.Scatter(x=thresholds, y=accs, mode="lines", name="Accuracy", line=dict(width=4)))
1073 fig_thresh.add_trace(go.Scatter(x=thresholds, y=f1s, mode="lines", name="F1", line=dict(width=4)))
1074 fig_thresh.add_trace(go.Scatter(x=thresholds, y=sens, mode="lines", name="Sensitivity", line=dict(width=4)))
1075 fig_thresh.add_trace(go.Scatter(x=thresholds, y=specs, mode="lines", name="Specificity", line=dict(width=4)))
1076 fig_thresh.update_layout(
1077 title=dict(text="Threshold Sweep: Accuracy, F1, Sensitivity, Specificity", x=0.5),
1078 xaxis_title="Decision threshold",
1079 yaxis_title="Metric value",
1080 width=700,
1081 height=500,
1082 legend=dict(
1083 x=0.7,
1084 y=0.2,
1085 bgcolor="rgba(255,255,255,0.9)",
1086 bordercolor="rgba(0,0,0,0.2)",
1087 borderwidth=1,
1088 ),
1089 shapes=[
1090 dict(
1091 type="line",
1092 x0=threshold,
1093 x1=threshold,
1094 y0=0,
1095 y1=1,
1096 xref="x",
1097 yref="paper",
1098 line=dict(color="#d62728", width=2, dash="dash"),
1099 )
1100 ] if isinstance(threshold, (int, float)) else [],
1101 annotations=[
1102 dict(
1103 x=threshold,
1104 y=1.02,
1105 xref="x",
1106 yref="paper",
1107 showarrow=False,
1108 text=f"Threshold = {threshold:.2f}",
1109 font=dict(size=11, color="#d62728"),
1110 )
1111 ] if isinstance(threshold, (int, float)) else [],
1112 )
1113 _style_fig(fig_thresh)
1114 plots.append({
1115 "title": "Threshold Sweep: Accuracy, F1, Sensitivity, Specificity",
1116 "html": pio.to_html(fig_thresh, full_html=False, include_plotlyjs=False),
1117 })
1118
1119 return plots