diff plotly_plots.py @ 17:db9be962dc13 draft

planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
author goeckslab
date Wed, 10 Dec 2025 00:24:13 +0000
parents d17e3a1b8659
children
--- a/plotly_plots.py	Wed Dec 03 01:28:52 2025 +0000
+++ b/plotly_plots.py	Wed Dec 10 00:24:13 2025 +0000
@@ -7,6 +7,17 @@
 import plotly.graph_objects as go
 import plotly.io as pio
 from constants import LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME
+from sklearn.metrics import (
+    accuracy_score,
+    auc,
+    average_precision_score,
+    f1_score,
+    precision_recall_curve,
+    precision_score,
+    recall_score,
+    roc_curve,
+)
+from sklearn.preprocessing import label_binarize
 
 
 def _style_fig(fig: go.Figure, font_size: int = 12) -> go.Figure:
@@ -21,6 +32,64 @@
     return fig
 
 
+def _fig_to_html(
+    fig: go.Figure, *, include_js: bool = False, config: Optional[dict] = None
+) -> str:
+    """Render a Plotly figure to a lightweight HTML fragment."""
+    include_plotlyjs = "cdn" if include_js else False
+    return pio.to_html(
+        fig,
+        full_html=False,
+        include_plotlyjs=include_plotlyjs,
+        config=config,
+    )
+
+
+def _wrap_plot(
+    title: str,
+    fig: go.Figure,
+    *,
+    include_js: bool = False,
+    config: Optional[dict] = None,
+) -> Dict[str, str]:
+    """Package a figure with its title for downstream HTML rendering."""
+    return {"title": title, "html": _fig_to_html(fig, include_js=include_js, config=config)}
+
+
+def _line_chart(
+    traces: List[tuple],
+    *,
+    title: str,
+    yaxis_title: str,
+) -> go.Figure:
+    """Build a basic epoch-indexed line chart for train/val/test curves."""
+    fig = go.Figure()
+    for name, series in traces:
+        if not series:
+            continue
+        epochs = list(range(1, len(series) + 1))
+        fig.add_trace(
+            go.Scatter(
+                x=epochs,
+                y=series,
+                mode="lines+markers",
+                name=name,
+                line=dict(width=4),
+            )
+        )
+
+    fig.update_layout(
+        title=dict(text=title, x=0.5),
+        xaxis_title="Epoch",
+        yaxis_title=yaxis_title,
+        width=760,
+        height=520,
+        hovermode="x unified",
+    )
+    _style_fig(fig)
+    return fig
+
+
 def _labels_from_metadata_dict(meta_dict: dict) -> List[str]:
     """Extract ordered label names from Ludwig train_set_metadata."""
     if not isinstance(meta_dict, dict):
@@ -106,6 +175,7 @@
     training_stats_path: Optional[str] = None,
     metadata_csv_path: Optional[str] = None,
     train_set_metadata_path: Optional[str] = None,
+    threshold: Optional[float] = None,
 ) -> List[Dict[str, str]]:
     """
     Read Ludwig’s test_statistics.json and build three interactive Plotly panels:
@@ -156,8 +226,11 @@
         )
     )
     fig_cm.update_traces(xgap=2, ygap=2)
+    cm_title = "Confusion Matrix"
+    if threshold is not None:
+        cm_title = f"Confusion Matrix (Threshold: {threshold})"
     fig_cm.update_layout(
-        title=dict(text="Confusion Matrix", x=0.5),
+        title=dict(text=cm_title, x=0.5),
         xaxis_title="Predicted",
         yaxis_title="Observed",
         yaxis_autorange="reversed",
@@ -196,25 +269,19 @@
                 yshift=-2,
             )
 
-    plots.append({
-        "title": "Confusion Matrix",
-        "html": pio.to_html(
-            fig_cm,
-            full_html=False,
-            include_plotlyjs="cdn",
-            config=common_cfg
-        )
-    })
+    plots.append(
+        _wrap_plot("Confusion Matrix", fig_cm, include_js=True, config=common_cfg)
+    )
 
-    # 1) ROC Curve (from test_statistics)
-    roc_plot = _build_static_roc_plot(label_stats, common_cfg, friendly_labels=labels)
-    if roc_plot:
-        plots.append(roc_plot)
+    # 1) ROC / PR curves only for binary tasks
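+    # (one-vs-rest multi-class curves are built from predictions by build_multiclass_roc_pr_plots)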
+    if n_classes == 2:
+        roc_plot = _build_static_roc_plot(
+            label_stats, common_cfg, friendly_labels=labels, threshold=threshold
+        )
+        if roc_plot:
+            plots.append(roc_plot)
 
-    # 2) Precision-Recall Curve (from test_statistics)
-    pr_plot = _build_precision_recall_plot(label_stats, common_cfg)
-    if pr_plot:
-        plots.append(pr_plot)
+        pr_plot = _build_precision_recall_plot(label_stats, common_cfg, threshold=threshold)
+        if pr_plot:
+            plots.append(pr_plot)
 
     # 2) Classification Report Heatmap
     pcs = label_stats.get("per_class_stats", {})
@@ -259,15 +326,9 @@
             margin=dict(t=80, l=80, r=80, b=80),
         )
         _style_fig(fig_cr)
-        plots.append({
-            "title": "Per-Class metrics",
-            "html": pio.to_html(
-                fig_cr,
-                full_html=False,
-                include_plotlyjs=False,
-                config=common_cfg
-            )
-        })
+        plots.append(
+            _wrap_plot("Per-Class metrics", fig_cr, config=common_cfg)
+        )
 
     # 3) Prediction Diagnostics (from predictions.csv)
     # Note: appended separately in generate_html_report, not returned here.
@@ -294,8 +355,6 @@
     include_js = True  # Load Plotly.js once for this group
 
     def _get_series(stats: dict, metric: str) -> List[float]:
-        if metric not in stats:
-            return []
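+        # Coerce a training-statistics entry into a list of floats ([] if missing or unparsable).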
         vals = stats.get(metric, [])
         if isinstance(vals, list):
             return [float(v) for v in vals]
@@ -304,181 +363,98 @@
         except Exception:
             return []
 
-    def _line_plot(metric_key: str, title: str, yaxis_title: str) -> Optional[Dict[str, str]]:
-        train_series = _get_series(label_train, metric_key)
-        val_series = _get_series(label_val, metric_key)
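+    # Core learning curves: (training-statistics key, panel title, y-axis label)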
+    metric_specs = [
+        ("loss", "Loss across epochs", "Loss"),
+        ("accuracy", "Accuracy across epochs", "Accuracy"),
+        ("roc_auc", "ROC-AUC across epochs", "ROC-AUC"),
+        ("precision", "Precision across epochs", "Precision"),
+        ("recall", "Recall/Sensitivity across epochs", "Recall"),
+        ("specificity", "Specificity across epochs", "Specificity"),
+    ]
+
+    for key, title, yaxis in metric_specs:
+        train_series = _get_series(label_train, key)
+        val_series = _get_series(label_val, key)
         if not train_series and not val_series:
-            return None
-        epochs_train = list(range(1, len(train_series) + 1))
-        epochs_val = list(range(1, len(val_series) + 1))
-        fig = go.Figure()
-        if train_series:
-            fig.add_trace(
-                go.Scatter(
-                    x=epochs_train,
-                    y=train_series,
-                    mode="lines+markers",
-                    name="Train",
-                    line=dict(width=4),
-                )
-            )
-        if val_series:
-            fig.add_trace(
-                go.Scatter(
-                    x=epochs_val,
-                    y=val_series,
-                    mode="lines+markers",
-                    name="Validation",
-                    line=dict(width=4),
-                )
-            )
-        fig.update_layout(
-            title=dict(text=title, x=0.5),
-            xaxis_title="Epoch",
-            yaxis_title=yaxis_title,
-            width=760,
-            height=520,
-            hovermode="x unified",
+            continue
+        fig = _line_chart(
+            [("Train", train_series), ("Validation", val_series)],
+            title=title,
+            yaxis_title=yaxis,
         )
-        _style_fig(fig)
-        return {
-            "title": title,
-            "html": pio.to_html(
-                fig,
-                full_html=False,
-                include_plotlyjs="cdn" if include_js else False,
-            ),
-        }
-
-    # Core learning curves
-    for key, title in [
-        ("roc_auc", "ROC-AUC across epochs"),
-        ("precision", "Precision across epochs"),
-        ("recall", "Recall/Sensitivity across epochs"),
-        ("specificity", "Specificity across epochs"),
-    ]:
-        plot = _line_plot(key, title, title.replace("Learning Curve", "").strip())
-        if plot:
-            plots.append(plot)
-            include_js = False
+        plots.append(_wrap_plot(title, fig, include_js=include_js))
+        include_js = False
 
     # Precision vs Recall evolution (validation)
     val_prec = _get_series(label_val, "precision")
     val_rec = _get_series(label_val, "recall")
     if val_prec and val_rec:
-        epochs = list(range(1, min(len(val_prec), len(val_rec)) + 1))
-        fig_pr = go.Figure()
-        fig_pr.add_trace(
-            go.Scatter(
-                x=epochs,
-                y=val_prec[: len(epochs)],
-                mode="lines+markers",
-                name="Precision",
-            )
-        )
-        fig_pr.add_trace(
-            go.Scatter(
-                x=epochs,
-                y=val_rec[: len(epochs)],
-                mode="lines+markers",
-                name="Recall",
-            )
+        max_len = min(len(val_prec), len(val_rec))
+        fig_pr = _line_chart(
+            [
+                ("Precision", val_prec[:max_len]),
+                ("Recall", val_rec[:max_len]),
+            ],
+            title="Validation Precision and Recall by Epoch",
+            yaxis_title="Value",
         )
-        fig_pr.update_layout(
-            title=dict(text="Validation Precision and Recall by Epoch", x=0.5),
-            xaxis_title="Epoch",
-            yaxis_title="Value",
-            width=760,
-            height=520,
-            hovermode="x unified",
-        )
-        _style_fig(fig_pr)
-        plots.append({
-            "title": "Precision vs Recall Evolution",
-            "html": pio.to_html(
-                fig_pr,
-                full_html=False,
-                include_plotlyjs="cdn" if include_js else False,
-            ),
-        })
+        plots.append(_wrap_plot("Precision vs Recall Evolution", fig_pr, include_js=include_js))
         include_js = False
 
-    # F1-score derived
     def _compute_f1(p: List[float], r: List[float]) -> List[float]:
-        f1_vals = []
-        for prec, rec in zip(p, r):
-            if (prec + rec) == 0:
-                f1_vals.append(0.0)
-            else:
-                f1_vals.append(2 * prec * rec / (prec + rec))
-        return f1_vals
+        return [
+            0.0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec)
+            for prec, rec in zip(p, r)
+        ]
 
     f1_train = _compute_f1(_get_series(label_train, "precision"), _get_series(label_train, "recall"))
     f1_val = _compute_f1(val_prec, val_rec)
     if f1_train or f1_val:
-        fig = go.Figure()
-        if f1_train:
-            fig.add_trace(go.Scatter(x=list(range(1, len(f1_train) + 1)), y=f1_train, mode="lines+markers", name="Train", line=dict(width=4)))
-        if f1_val:
-            fig.add_trace(go.Scatter(x=list(range(1, len(f1_val) + 1)), y=f1_val, mode="lines+markers", name="Validation", line=dict(width=4)))
-        fig.update_layout(
-            title=dict(text="F1-Score across epochs (derived)", x=0.5),
-            xaxis_title="Epoch",
+        fig_f1 = _line_chart(
+            [("Train", f1_train), ("Validation", f1_val)],
+            title="F1-Score across epochs (derived)",
             yaxis_title="F1-Score",
-            width=760,
-            height=520,
-            hovermode="x unified",
         )
-        _style_fig(fig)
-        plots.append({
-            "title": "F1-Score across epochs (derived)",
-            "html": pio.to_html(
-                fig,
-                full_html=False,
-                include_plotlyjs="cdn" if include_js else False,
-            ),
-        })
+        plots.append(_wrap_plot("F1-Score across epochs (derived)", fig_f1, include_js=include_js))
         include_js = False
 
     # Overfitting Gap: Train vs Val ROC-AUC (gap)
     roc_train = _get_series(label_train, "roc_auc")
     roc_val = _get_series(label_val, "roc_auc")
     if roc_train and roc_val:
-        epochs_gap = list(range(1, min(len(roc_train), len(roc_val)) + 1))
-        gaps = [t - v for t, v in zip(roc_train[:len(epochs_gap)], roc_val[:len(epochs_gap)])]
-        fig_gap = go.Figure()
-        fig_gap.add_trace(go.Scatter(x=epochs_gap, y=gaps, mode="lines+markers", name="Train - Val ROC-AUC", line=dict(width=4)))
-        fig_gap.update_layout(
-            title=dict(text="Overfitting gap: ROC-AUC across epochs", x=0.5),
-            xaxis_title="Epoch",
+        max_len = min(len(roc_train), len(roc_val))
+        gaps = [t - v for t, v in zip(roc_train[:max_len], roc_val[:max_len])]
+        fig_gap = _line_chart(
+            [("Train - Val ROC-AUC", gaps)],
+            title="Overfitting gap: ROC-AUC across epochs",
             yaxis_title="Gap",
-            width=760,
-            height=520,
-            hovermode="x unified",
         )
-        _style_fig(fig_gap)
-        plots.append({
-            "title": "Overfitting gap: ROC-AUC across epochs",
-            "html": pio.to_html(
-                fig_gap,
-                full_html=False,
-                include_plotlyjs="cdn" if include_js else False,
-            ),
-        })
+        plots.append(_wrap_plot("Overfitting gap: ROC-AUC across epochs", fig_gap, include_js=include_js))
         include_js = False
 
     # Best Epoch Dashboard (based on max val ROC-AUC)
     if roc_val:
         best_idx = int(np.argmax(roc_val))
         best_epoch = best_idx + 1
-        spec_val = _get_series(label_val, "specificity")
-        metrics_at_best = {
-            "ROC-AUC": roc_val[best_idx] if best_idx < len(roc_val) else None,
-            "Precision": val_prec[best_idx] if best_idx < len(val_prec) else None,
-            "Recall": val_rec[best_idx] if best_idx < len(val_rec) else None,
-            "Specificity": spec_val[best_idx] if best_idx < len(spec_val) else None,
-            "F1-Score": f1_val[best_idx] if best_idx < len(f1_val) else None,
+        metrics_at_best: Dict[str, Optional[float]] = {
+            "ROC-AUC": roc_val[best_idx] if best_idx < len(roc_val) else None
         }
+
+        for metric_key, label in [
+            ("accuracy", "Accuracy"),
+            ("balanced_accuracy", "Balanced Accuracy"),
+            ("precision", "Precision"),
+            ("recall", "Recall"),
+            ("specificity", "Specificity"),
+            ("loss", "Loss"),
+        ]:
+            series = _get_series(label_val, metric_key)
+            if series and best_idx < len(series):
+                metrics_at_best[label] = series[best_idx]
+
+        if f1_val and best_idx < len(f1_val):
+            metrics_at_best["F1-Score (derived)"] = f1_val[best_idx]
+
         fig_best = go.Figure()
         for name, value in metrics_at_best.items():
             if value is not None:
@@ -492,15 +468,7 @@
             showlegend=False,
         )
         _style_fig(fig_best)
-        plots.append({
-            "title": "Best Validation Epoch Snapshot (Metrics)",
-            "html": pio.to_html(
-                fig_best,
-                full_html=False,
-                include_plotlyjs="cdn" if include_js else False,
-            ),
-        })
-        include_js = False
+        plots.append(_wrap_plot("Best Validation Epoch Snapshot (Metrics)", fig_best, include_js=include_js))
 
     return plots
 
@@ -529,46 +497,13 @@
     val_series = _get_regression_series(val_split, metric_key)
     if not train_series and not val_series:
         return None
-    epochs_train = list(range(1, len(train_series) + 1))
-    epochs_val = list(range(1, len(val_series) + 1))
-    fig = go.Figure()
-    if train_series:
-        fig.add_trace(
-            go.Scatter(
-                x=epochs_train,
-                y=train_series,
-                mode="lines+markers",
-                name="Train",
-                line=dict(width=4),
-            )
-        )
-    if val_series:
-        fig.add_trace(
-            go.Scatter(
-                x=epochs_val,
-                y=val_series,
-                mode="lines+markers",
-                name="Validation",
-                line=dict(width=4),
-            )
-        )
-    fig.update_layout(
-        title=dict(text=title, x=0.5),
-        xaxis_title="Epoch",
+
+    fig = _line_chart(
+        [("Train", train_series), ("Validation", val_series)],
+        title=title,
         yaxis_title=yaxis_title,
-        width=760,
-        height=520,
-        hovermode="x unified",
     )
-    _style_fig(fig)
-    return {
-        "title": title,
-        "html": pio.to_html(
-            fig,
-            full_html=False,
-            include_plotlyjs="cdn" if include_js else False,
-        ),
-    }
+    return _wrap_plot(title, fig, include_js=include_js)
 
 
 def build_regression_train_val_plots(train_stats_path: str) -> List[Dict[str, str]]:
@@ -627,46 +562,25 @@
         ("r2", "R² Across Epochs", "R²"),
         ("loss", "Loss Across Epochs", "Loss"),
     ]
-    epochs = None
     for metric_key, title, ytitle in metrics:
         series = _get_regression_series(label_test, metric_key)
         if not series:
             continue
-        if epochs is None:
-            epochs = list(range(1, len(series) + 1))
-        fig = go.Figure()
-        fig.add_trace(
-            go.Scatter(
-                x=epochs,
-                y=series[: len(epochs)],
-                mode="lines+markers",
-                name="Test",
-                line=dict(width=4),
-            )
+        fig = _line_chart(
+            [("Test", series)],
+            title=title,
+            yaxis_title=ytitle,
         )
-        fig.update_layout(
-            title=dict(text=title, x=0.5),
-            xaxis_title="Epoch",
-            yaxis_title=ytitle,
-            width=760,
-            height=520,
-            hovermode="x unified",
-        )
-        _style_fig(fig)
-        plots.append({
-            "title": title,
-            "html": pio.to_html(
-                fig,
-                full_html=False,
-                include_plotlyjs="cdn" if include_js else False,
-            ),
-        })
+        plots.append(_wrap_plot(title, fig, include_js=include_js))
         include_js = False
     return plots
 
 
 def _build_static_roc_plot(
-    label_stats: dict, config: dict, friendly_labels: Optional[List[str]] = None
+    label_stats: dict,
+    config: dict,
+    friendly_labels: Optional[List[str]] = None,
+    threshold: Optional[float] = None,
 ) -> Optional[Dict[str, str]]:
     """Build ROC curve directly from test_statistics.json (single curve)."""
     roc_data = label_stats.get("roc_curve")
@@ -776,6 +690,42 @@
         fig.update_xaxes(range=[0, 1.0])
         fig.update_yaxes(range=[0, 1.05])
 
+        roc_thresholds = roc_data.get("thresholds")
+        if threshold is not None and isinstance(roc_thresholds, list) and len(roc_thresholds) == len(fpr):
+            try:
+                diffs = [abs(th - threshold) for th in roc_thresholds]
+                best_idx = int(np.argmin(diffs))
+                # dashed guides through the chosen point
+                fig.add_shape(
+                    type="line",
+                    x0=fpr[best_idx],
+                    x1=fpr[best_idx],
+                    y0=0,
+                    y1=tpr[best_idx],
+                    line=dict(color="gray", width=2, dash="dash"),
+                )
+                fig.add_shape(
+                    type="line",
+                    x0=0,
+                    x1=fpr[best_idx],
+                    y0=tpr[best_idx],
+                    y1=tpr[best_idx],
+                    line=dict(color="gray", width=2, dash="dash"),
+                )
+                fig.add_trace(
+                    go.Scatter(
+                        x=[fpr[best_idx]],
+                        y=[tpr[best_idx]],
+                        mode="markers",
+                        marker=dict(color="black", size=10, symbol="x"),
+                        name=f"Threshold={threshold}",
+                        hovertemplate="FPR: %{x:.3f}<br>TPR: %{y:.3f}<br>Threshold: %{text}<extra></extra>",
+                        text=[f"{threshold}"],
+                    )
+                )
+            except Exception as exc:
+                print(f"Warning: could not add threshold marker to ROC: {exc}")
+
         fig.add_annotation(
             x=0.5,
             y=-0.15,
@@ -786,21 +736,17 @@
             xanchor="center",
         )
 
-        return {
-            "title": "ROC Curve",
-            "html": pio.to_html(
-                fig,
-                full_html=False,
-                include_plotlyjs=False,
-                config=config,
-            ),
-        }
+        return _wrap_plot("ROC Curve", fig, config=config)
     except Exception as e:
         print(f"Error building ROC plot: {e}")
         return None
 
 
-def _build_precision_recall_plot(label_stats: dict, config: dict) -> Optional[Dict[str, str]]:
+def _build_precision_recall_plot(
+    label_stats: dict,
+    config: dict,
+    threshold: Optional[float] = None,
+) -> Optional[Dict[str, str]]:
     """Build Precision-Recall curve directly from test_statistics.json."""
     pr_data = label_stats.get("precision_recall_curve")
     if not isinstance(pr_data, dict):
@@ -811,6 +757,8 @@
     if not precisions or not recalls or len(precisions) != len(recalls):
         return None
 
+    thresholds = pr_data.get("thresholds")
+
     try:
         fig = go.Figure()
         fig.add_trace(
@@ -851,15 +799,41 @@
         fig.update_xaxes(range=[0, 1.0])
         fig.update_yaxes(range=[0, 1.05])
 
-        return {
-            "title": "Precision-Recall Curve",
-            "html": pio.to_html(
-                fig,
-                full_html=False,
-                include_plotlyjs=False,
-                config=config,
-            ),
-        }
+        if threshold is not None and isinstance(thresholds, list) and len(thresholds) == len(recalls):
+            try:
+                diffs = [abs(th - threshold) for th in thresholds]
+                best_idx = int(np.argmin(diffs))
+                fig.add_shape(
+                    type="line",
+                    x0=recalls[best_idx],
+                    x1=recalls[best_idx],
+                    y0=0,
+                    y1=precisions[best_idx],
+                    line=dict(color="gray", width=2, dash="dash"),
+                )
+                fig.add_shape(
+                    type="line",
+                    x0=0,
+                    x1=recalls[best_idx],
+                    y0=precisions[best_idx],
+                    y1=precisions[best_idx],
+                    line=dict(color="gray", width=2, dash="dash"),
+                )
+                fig.add_trace(
+                    go.Scatter(
+                        x=[recalls[best_idx]],
+                        y=[precisions[best_idx]],
+                        mode="markers",
+                        marker=dict(color="black", size=10, symbol="x"),
+                        name=f"Threshold={threshold}",
+                        hovertemplate="Recall: %{x:.3f}<br>Precision: %{y:.3f}<br>Threshold: %{text}<extra></extra>",
+                        text=[f"{threshold}"],
+                    )
+                )
+            except Exception as exc:
+                print(f"Warning: could not add threshold marker to PR: {exc}")
+
+        return _wrap_plot("Precision-Recall Curve", fig, config=config)
     except Exception as e:
         print(f"Error building Precision-Recall plot: {e}")
         return None
@@ -869,7 +843,6 @@
     predictions_path: str,
     label_data_path: Optional[str] = None,
     split_value: int = 2,
-    threshold: Optional[float] = None,
 ) -> List[Dict[str, str]]:
     """Generate diagnostic plots from predictions.csv for classification tasks."""
     preds_file = Path(predictions_path)
@@ -883,12 +856,89 @@
         return []
 
     plots: List[Dict[str, str]] = []
+    labels_from_dataset: Optional[pd.Series] = None
+
+    filtered_by_split = False
+
+    # If a split column exists, focus on the requested split (e.g., validation=1, test=2).
+    # If not, but label_data_path is available and matches row count, use it to filter predictions.
+    if SPLIT_COLUMN_NAME in df_pred.columns:
+        df_pred = df_pred[df_pred[SPLIT_COLUMN_NAME] == split_value].reset_index(drop=True)
+        if df_pred.empty:
+            return []
+        filtered_by_split = True
+    elif label_data_path and Path(label_data_path).exists():
+        try:
+            df_labels_all = pd.read_csv(label_data_path)
+            if SPLIT_COLUMN_NAME in df_labels_all.columns and len(df_labels_all) == len(df_pred):
+                split_mask = pd.to_numeric(df_labels_all[SPLIT_COLUMN_NAME], errors="coerce") == split_value
+                labels_from_dataset = df_labels_all.loc[split_mask, LABEL_COLUMN_NAME].reset_index(drop=True)
+                df_pred = df_pred.loc[split_mask].reset_index(drop=True)
+                if df_pred.empty:
+                    return []
+                filtered_by_split = True
+        except Exception as exc:
+            print(f"Warning: Unable to filter predictions by split from label data: {exc}")
+
+    # Fallback: no split info available. Assume the predictions file is already filtered
+    # (common for test-only exports) and avoid heuristic slicing that could discard rows.
+    if not filtered_by_split and split_value != 2:
+        return []
+
+    def _strip_prob_prefix(col: str) -> str:
+        if col.startswith("label_probabilities_"):
+            return col.replace("label_probabilities_", "")
+        if col.startswith("probabilities_"):
+            return col.replace("probabilities_", "")
+        return col
+
+    def _maybe_expand_probabilities_column(df: pd.DataFrame, labels_guess: List[str]) -> List[str]:
+        """If only a single 'probabilities' column exists (list-like), expand it into per-class columns."""
+        if "probabilities" not in df.columns:
+            return []
+        try:
+            # Parse first non-null entry to infer length
+            first_val = df["probabilities"].dropna().iloc[0]
+            parsed = first_val
+            if isinstance(first_val, str):
+                parsed = json.loads(first_val)
+            probs = list(parsed)
+            n = len(probs)
+            if n == 0:
+                return []
+            # Build labels: prefer provided guess; otherwise numeric
+            if labels_guess and len(labels_guess) == n:
+                labels_use = labels_guess
+            else:
+                labels_use = [str(i) for i in range(n)]
+            # Expand column
+            for idx, lbl in enumerate(labels_use):
+                df[f"probabilities_{lbl}"] = df["probabilities"].apply(
+                    lambda v: (json.loads(v)[idx] if isinstance(v, str) else list(v)[idx]) if pd.notnull(v) else np.nan
+                )
+            return [f"probabilities_{lbl}" for lbl in labels_use]
+        except Exception:
+            return []
 
     # Identify probability columns
     prob_cols = [
-        c for c in df_pred.columns
-        if c.startswith("label_probabilities_") and c != "label_probabilities"
+        c
+        for c in df_pred.columns
+        if (
+            (c.startswith("label_probabilities_") or c.startswith("probabilities_"))
+            and c != "label_probabilities"
+        )
     ]
+    if not prob_cols and "label_probability" in df_pred.columns:
+        prob_cols = ["label_probability"]
+    if not prob_cols and "probability" in df_pred.columns:
+        prob_cols = ["probability"]
+    if not prob_cols and "prediction_probability" in df_pred.columns:
+        prob_cols = ["prediction_probability"]
+    if not prob_cols and "probabilities" in df_pred.columns:
+        labels_guess = sorted([str(u) for u in pd.unique(df_pred.get(LABEL_COLUMN_NAME, []))])
+        prob_cols = _maybe_expand_probabilities_column(df_pred, labels_guess)
     prob_cols_sorted = sorted(prob_cols)
 
     def _select_positive_prob():
@@ -897,14 +947,14 @@
         # Prefer a column indicating positive/event/true/1
         preferred_keys = ("event", "true", "positive", "pos", "1")
         for col in prob_cols_sorted:
-            suffix = col.replace("label_probabilities_", "").lower()
+            suffix = _strip_prob_prefix(col).lower()
             if any(k in suffix for k in preferred_keys):
                 return col, suffix
         if len(prob_cols_sorted) == 2:
             col = prob_cols_sorted[1]
-            return col, col.replace("label_probabilities_", "")
+            return col, _strip_prob_prefix(col)
         col = prob_cols_sorted[0]
-        return col, col.replace("label_probabilities_", "")
+        return col, _strip_prob_prefix(col)
 
     pos_prob_col, pos_label_hint = _select_positive_prob()
     pos_prob_series = df_pred[pos_prob_col] if pos_prob_col and pos_prob_col in df_pred else None
@@ -920,6 +970,8 @@
 
     # True labels
     def _extract_labels():
+        if labels_from_dataset is not None:
+            return labels_from_dataset
         candidates = [
             LABEL_COLUMN_NAME,
             f"{LABEL_COLUMN_NAME}_ground_truth",
@@ -975,10 +1027,7 @@
             height=500,
         )
         _style_fig(fig_conf)
-        plots.append({
-            "title": "Prediction Confidence Distribution",
-            "html": pio.to_html(fig_conf, full_html=False, include_plotlyjs=False),
-        })
+        plots.append(_wrap_plot("Prediction Confidence Distribution", fig_conf))
 
     # The remaining plots require true labels and a positive-class probability
     if labels_series is None or pos_prob_series is None:
@@ -1004,116 +1053,470 @@
 
     y_true = (y_true_raw == positive_label).astype(int).values
 
-    # Plot 2: Calibration Curve
-    bins = np.linspace(0.0, 1.0, 11)
-    bin_ids = np.digitize(y_score, bins, right=True)
-    bin_centers = []
-    frac_positives = []
-    for b in range(1, len(bins)):
-        mask = bin_ids == b
-        if not np.any(mask):
-            continue
-        bin_centers.append(y_score[mask].mean())
-        frac_positives.append(y_true[mask].mean())
-    if bin_centers and frac_positives:
-        fig_cal = go.Figure()
-        fig_cal.add_trace(
+    # Utility: compute calibration points
+    def _calibration_points(y_true_bin: np.ndarray, scores: np.ndarray):
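+        # Bin scores into 10 equal-width probability bins; for each non-empty bin return the mean
+        # predicted probability and the observed fraction of positives.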
+        bins = np.linspace(0.0, 1.0, 11)
+        bin_ids = np.digitize(scores, bins, right=True)
+        bin_centers, frac_positives = [], []
+        for b in range(1, len(bins)):
+            mask = bin_ids == b
+            if not np.any(mask):
+                continue
+            bin_centers.append(scores[mask].mean())
+            frac_positives.append(y_true_bin[mask].mean())
+        return bin_centers, frac_positives
+
+    # Plot 2: Calibration Curve (binary only; multi-class inputs are detected and skipped below)
+    label_prob_map = {}
+    for col in prob_cols_sorted:
+        cls = _strip_prob_prefix(col)
+        if cls != col:
+            label_prob_map[cls] = col
+
+    unique_label_strs = [str(u) for u in unique_labels_list]
+    if len(label_prob_map) > 1 and len(unique_label_strs) > 2:
+        # Skip multi-class calibration curve for now (not informative in current report)
+        pass
+    else:
+        # Binary/unknown fallback (previous behavior)
+        bin_centers, frac_positives = _calibration_points(y_true, y_score)
+        if bin_centers and frac_positives:
+            fig_cal = go.Figure()
+            fig_cal.add_trace(
+                go.Scatter(
+                    x=bin_centers,
+                    y=frac_positives,
+                    mode="lines+markers",
+                    name="Calibration",
+                    line=dict(color="#2ca02c", width=4),
+                )
+            )
+            fig_cal.add_trace(
+                go.Scatter(
+                    x=[0, 1],
+                    y=[0, 1],
+                    mode="lines",
+                    name="Perfect Calibration",
+                    line=dict(color="gray", width=2, dash="dash"),
+                )
+            )
+            fig_cal.update_layout(
+                title=dict(text="Calibration Curve", x=0.5),
+                xaxis_title="Predicted probability",
+                yaxis_title="Observed frequency",
+                width=700,
+                height=500,
+            )
+            _style_fig(fig_cal)
+            plots.append(
+                _wrap_plot(
+                    "Calibration Curve (Predicted Probability vs Observed Frequency)",
+                    fig_cal,
+                )
+            )
+
+    return plots
+
+
+def build_binary_threshold_plot(
+    predictions_path: str,
+    label_data_path: Optional[str] = None,
+    split_value: int = 1,
+) -> Optional[Dict[str, str]]:
+    """Build a binary threshold sweep plot (accuracy, precision, recall, F1) for a given split."""
+    preds_file = Path(predictions_path)
+    if not preds_file.exists():
+        return None
+
+    try:
+        df_pred = pd.read_csv(predictions_path)
+    except Exception as exc:
+        print(f"Warning: Unable to read predictions CSV for threshold plot: {exc}")
+        return None
+
+    labels_from_dataset: Optional[pd.Series] = None
+    df_full = df_pred.copy()
+
+    def _filter_by_split(df: pd.DataFrame, split_val: int) -> pd.DataFrame:
+        if SPLIT_COLUMN_NAME in df.columns:
+            return df[df[SPLIT_COLUMN_NAME] == split_val].reset_index(drop=True)
+        return df
+
+    # Try the preferred split first, then fall back to any other split that has rows
+    # (validation -> test -> train by default).
+    preferred_order = [split_value, 2, 0, 1] if split_value == 1 else [split_value, 1, 0, 2]
+    candidate_splits = list(dict.fromkeys(preferred_order))
+    df_candidate = pd.DataFrame()
+    used_split: Optional[int] = None
+    for sv in candidate_splits:
+        df_candidate = _filter_by_split(df_full, sv)
+        if not df_candidate.empty:
+            used_split = sv
+            break
+    if used_split is None:
+        df_candidate = df_full
+    df_pred = df_candidate.reset_index(drop=True)
+
+    # If nothing survives (i.e. the predictions file itself is empty), there is nothing to plot.
+    if df_pred.empty:
+        return None
+
+    if label_data_path and Path(label_data_path).exists():
+        try:
+            df_labels_all = pd.read_csv(label_data_path)
+            if SPLIT_COLUMN_NAME in df_labels_all.columns and len(df_labels_all) == len(df_full):
+                if used_split is not None and SPLIT_COLUMN_NAME in df_full.columns:
+                    mask = pd.to_numeric(df_labels_all[SPLIT_COLUMN_NAME], errors="coerce") == used_split
+                else:
+                    mask = pd.Series([True] * len(df_full))
+                labels_from_dataset = df_labels_all.loc[mask, LABEL_COLUMN_NAME].reset_index(drop=True)
+                # Keep the dataset labels only when they line up row-for-row with the filtered predictions.
+                if len(labels_from_dataset) != len(df_pred):
+                    labels_from_dataset = None
+        except Exception as exc:
+            print(f"Warning: Unable to align labels for threshold plot: {exc}")
+
+    # Identify probability columns
+    prob_cols = [
+        c
+        for c in df_pred.columns
+        if (
+            (c.startswith("label_probabilities_") or c.startswith("probabilities_"))
+            and c != "label_probabilities"
+        )
+    ]
+    if not prob_cols and "probabilities" in df_pred.columns:
+        labels_guess = sorted([str(u) for u in pd.unique(df_pred.get(LABEL_COLUMN_NAME, []))])
+        # Expand a single list-like "probabilities" column into per-class columns
+        # (same approach as _maybe_expand_probabilities_column in build_prediction_diagnostics).
+        try:
+            first_val = df_pred["probabilities"].dropna().iloc[0]
+            parsed = json.loads(first_val) if isinstance(first_val, str) else list(first_val)
+            n = len(parsed)
+            if n > 0:
+                if labels_guess and len(labels_guess) == n:
+                    labels_use = labels_guess
+                else:
+                    labels_use = [str(i) for i in range(n)]
+                for idx, lbl in enumerate(labels_use):
+                    df_pred[f"probabilities_{lbl}"] = df_pred["probabilities"].apply(
+                        lambda v: (json.loads(v)[idx] if isinstance(v, str) else list(v)[idx]) if pd.notnull(v) else np.nan
+                    )
+                prob_cols = [f"probabilities_{lbl}" for lbl in labels_use]
+        except Exception:
+            prob_cols = []
+    prob_cols_sorted = sorted(prob_cols)
+
+    def _strip_prob_prefix(col: str) -> str:
+        if col.startswith("label_probabilities_"):
+            return col.replace("label_probabilities_", "")
+        if col.startswith("probabilities_"):
+            return col.replace("probabilities_", "")
+        return col
+
+    # True labels
+    def _extract_labels():
+        if labels_from_dataset is not None:
+            return labels_from_dataset
+        for col in [
+            LABEL_COLUMN_NAME,
+            f"{LABEL_COLUMN_NAME}_ground_truth",
+            f"{LABEL_COLUMN_NAME}__ground_truth",
+            f"{LABEL_COLUMN_NAME}_target",
+            f"{LABEL_COLUMN_NAME}__target",
+            "label",
+            "label_true",
+            "label_predictions",
+            "prediction",
+        ]:
+            if col in df_pred.columns and col not in prob_cols_sorted:
+                return df_pred[col]
+        return None
+
+    labels_series = _extract_labels()
+    if labels_series is None or not prob_cols_sorted:
+        return None
+
+    # Positive prob column selection
+    preferred_keys = ("event", "true", "positive", "pos", "1")
+    pos_prob_col = None
+    for col in prob_cols_sorted:
+        suffix = _strip_prob_prefix(col).lower()
+        if any(k in suffix for k in preferred_keys):
+            pos_prob_col = col
+            break
+    if pos_prob_col is None:
+        pos_prob_col = prob_cols_sorted[-1]
+
+    min_len = min(len(labels_series), len(df_pred[pos_prob_col]))
+    if min_len == 0:
+        return None
+
+    y_true = np.array(labels_series.iloc[:min_len])
+    # map to binary 0/1
+    unique_labels = pd.unique(y_true)
+    if len(unique_labels) < 2:
+        return None
+    # At least two classes are guaranteed above, so treat the second as the positive class.
+    positive_label = unique_labels[1]
+    y_true_bin = (y_true == positive_label).astype(int)
+    y_score = np.array(df_pred[pos_prob_col].iloc[:min_len], dtype=float)
+
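+    # Sweep 101 evenly spaced thresholds; at each one, binarize the positive-class
+    # probability and recompute accuracy, precision, recall, and F1.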
+    thresholds = np.linspace(0.0, 1.0, 101)
+    accs: List[float] = []
+    precs: List[float] = []
+    recs: List[float] = []
+    f1s: List[float] = []
+    for t in thresholds:
+        preds = (y_score >= t).astype(int)
+        accs.append(accuracy_score(y_true_bin, preds))
+        precs.append(precision_score(y_true_bin, preds, zero_division=0))
+        recs.append(recall_score(y_true_bin, preds, zero_division=0))
+        f1s.append(f1_score(y_true_bin, preds, zero_division=0))
+
+    best_idx = int(np.argmax(f1s))
+    best_thr = thresholds[best_idx]
+
+    fig = go.Figure()
+    fig.add_trace(go.Scatter(x=thresholds, y=accs, mode="lines", name="Accuracy", line=dict(width=4)))
+    fig.add_trace(go.Scatter(x=thresholds, y=precs, mode="lines", name="Precision", line=dict(width=4)))
+    fig.add_trace(go.Scatter(x=thresholds, y=recs, mode="lines", name="Recall", line=dict(width=4)))
+    fig.add_trace(go.Scatter(x=thresholds, y=f1s, mode="lines", name="F1-Score", line=dict(width=4)))
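+    # Dashed vertical guide at the threshold that maximizes F1.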
+    fig.add_shape(
+        type="line",
+        x0=best_thr,
+        x1=best_thr,
+        y0=0,
+        y1=1,
+        line=dict(color="gray", width=2, dash="dash"),
+    )
+    fig.update_layout(
+        title=dict(text="Threshold plot", x=0.5),
+        xaxis_title="Threshold",
+        yaxis_title="Metric value",
+        yaxis=dict(range=[0, 1]),
+        width=760,
+        height=520,
+        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
+    )
+    _style_fig(fig)
+    return _wrap_plot("Threshold plot", fig, include_js=True)
+
+
+def build_multiclass_roc_pr_plots(
+    predictions_path: str,
+    split_value: int = 2,
+) -> List[Dict[str, str]]:
+    """Build one-vs-rest ROC and PR curves for multi-class classification from predictions."""
+    preds_file = Path(predictions_path)
+    if not preds_file.exists():
+        return []
+    try:
+        df_pred = pd.read_csv(predictions_path)
+    except Exception as exc:
+        print(f"Warning: Unable to read predictions CSV: {exc}")
+        return []
+
+    if SPLIT_COLUMN_NAME in df_pred.columns:
+        df_pred = df_pred[df_pred[SPLIT_COLUMN_NAME] == split_value].reset_index(drop=True)
+    if df_pred.empty:
+        return []
+
+    if LABEL_COLUMN_NAME not in df_pred.columns:
+        return []
+
+    # Identify per-class probability columns
+    prob_cols = [
+        c
+        for c in df_pred.columns
+        if (
+            (c.startswith("label_probabilities_") or c.startswith("probabilities_"))
+            and c != "label_probabilities"
+        )
+    ]
+    if not prob_cols:
+        return []
+    labels = [c.replace("label_probabilities_", "").replace("probabilities_", "") for c in prob_cols]
+    labels_sorted = sorted(labels)
+
+    # Ensure all labels are present as probability columns
+    prob_map = {
+        c.replace("label_probabilities_", "").replace("probabilities_", ""): c
+        for c in prob_cols
+    }
+    if len(labels_sorted) < 3:
+        return []
+
+    y_true_raw = df_pred[LABEL_COLUMN_NAME].astype(str)
+    # Drop rows with NaN probabilities across any class to avoid metric errors
+    prob_matrix = df_pred[[prob_map[lbl] for lbl in labels_sorted]].astype(float)
+    mask_valid = ~prob_matrix.isnull().any(axis=1)
+    prob_matrix = prob_matrix[mask_valid]
+    y_true_raw = y_true_raw[mask_valid]
+    if prob_matrix.empty:
+        return []
+
+    y_true_bin = label_binarize(y_true_raw, classes=labels_sorted)
+    y_score = prob_matrix.to_numpy()
+
+    plots: List[Dict[str, str]] = []
+
+    # ROC: one-vs-rest + micro
+    fig_roc = go.Figure()
+    added_any = False
+    for idx, lbl in enumerate(labels_sorted):
+        if y_true_bin[:, idx].sum() == 0 or y_true_bin[:, idx].sum() == len(y_true_bin):
+            continue  # skip classes without both positives and negatives
+        fpr, tpr, _ = roc_curve(y_true_bin[:, idx], y_score[:, idx])
+        fig_roc.add_trace(
             go.Scatter(
-                x=bin_centers,
-                y=frac_positives,
-                mode="lines+markers",
-                name="Calibration",
-                line=dict(color="#2ca02c", width=4),
-            )
-        )
-        fig_cal.add_trace(
-            go.Scatter(
-                x=[0, 1],
-                y=[0, 1],
+                x=fpr,
+                y=tpr,
                 mode="lines",
-                name="Perfect Calibration",
-                line=dict(color="gray", width=2, dash="dash"),
+                name=f"{lbl} (AUC={auc(fpr, tpr):.3f})",
+                line=dict(width=3),
             )
         )
-        fig_cal.update_layout(
-            title=dict(text="Calibration Curve", x=0.5),
-            xaxis_title="Predicted probability",
-            yaxis_title="Observed frequency",
-            width=700,
-            height=500,
+        added_any = True
+    # Micro-average only if we have mixed labels
+    if y_true_bin.sum() > 0 and y_true_bin.sum() < y_true_bin.size:
+        fpr_micro, tpr_micro, _ = roc_curve(y_true_bin.ravel(), y_score.ravel())
+        fig_roc.add_trace(
+            go.Scatter(
+                x=fpr_micro,
+                y=tpr_micro,
+                mode="lines",
+                name=f"Micro-average (AUC={auc(fpr_micro, tpr_micro):.3f})",
+                line=dict(width=3, dash="dash"),
+            )
         )
-        _style_fig(fig_cal)
-        plots.append({
-            "title": "Calibration Curve (Predicted Probability vs Observed Frequency)",
-            "html": pio.to_html(fig_cal, full_html=False, include_plotlyjs=False),
-        })
-
-    # Plot 3: Threshold vs Metrics
-    thresholds = np.linspace(0.0, 1.0, 21)
-    accs, f1s, sens, specs = [], [], [], []
-    for t in thresholds:
-        y_pred = (y_score >= t).astype(int)
-        tp = np.sum((y_true == 1) & (y_pred == 1))
-        tn = np.sum((y_true == 0) & (y_pred == 0))
-        fp = np.sum((y_true == 0) & (y_pred == 1))
-        fn = np.sum((y_true == 1) & (y_pred == 0))
-        acc = (tp + tn) / max(len(y_true), 1)
-        prec = tp / max(tp + fp, 1e-9)
-        rec = tp / max(tp + fn, 1e-9)
-        f1 = 0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec)
-        sensitivity = rec
-        specificity = tn / max(tn + fp, 1e-9)
-        accs.append(acc)
-        f1s.append(f1)
-        sens.append(sensitivity)
-        specs.append(specificity)
-
-    fig_thresh = go.Figure()
-    fig_thresh.add_trace(go.Scatter(x=thresholds, y=accs, mode="lines", name="Accuracy", line=dict(width=4)))
-    fig_thresh.add_trace(go.Scatter(x=thresholds, y=f1s, mode="lines", name="F1", line=dict(width=4)))
-    fig_thresh.add_trace(go.Scatter(x=thresholds, y=sens, mode="lines", name="Sensitivity", line=dict(width=4)))
-    fig_thresh.add_trace(go.Scatter(x=thresholds, y=specs, mode="lines", name="Specificity", line=dict(width=4)))
-    fig_thresh.update_layout(
-        title=dict(text="Threshold Sweep: Accuracy, F1, Sensitivity, Specificity", x=0.5),
-        xaxis_title="Decision threshold",
-        yaxis_title="Metric value",
-        width=700,
-        height=500,
+        added_any = True
+    if not added_any:
+        return []
+    fig_roc.add_trace(
+        go.Scatter(
+            x=[0, 1],
+            y=[0, 1],
+            mode="lines",
+            name="Random",
+            line=dict(color="gray", width=2, dash="dot"),
+        )
+    )
+    fig_roc.update_layout(
+        title=dict(text="Multi-class ROC-AUC (one-vs-rest)", x=0.5),
+        xaxis_title="False Positive Rate",
+        yaxis_title="True Positive Rate",
+        width=820,
+        height=620,
         legend=dict(
-            x=0.7,
-            y=0.2,
+            x=0.62,
+            y=0.05,
             bgcolor="rgba(255,255,255,0.9)",
             bordercolor="rgba(0,0,0,0.2)",
             borderwidth=1,
         ),
-        shapes=[
-            dict(
-                type="line",
-                x0=threshold,
-                x1=threshold,
-                y0=0,
-                y1=1,
-                xref="x",
-                yref="paper",
-                line=dict(color="#d62728", width=2, dash="dash"),
+    )
+    _style_fig(fig_roc)
+    plots.append(_wrap_plot("Multi-class ROC-AUC (one-vs-rest)", fig_roc))
+
+    # PR: one-vs-rest + micro AP
+    fig_pr = go.Figure()
+    added_pr = False
+    for idx, lbl in enumerate(labels_sorted):
+        if y_true_bin[:, idx].sum() == 0:
+            continue
+        prec, rec, _ = precision_recall_curve(y_true_bin[:, idx], y_score[:, idx])
+        ap = average_precision_score(y_true_bin[:, idx], y_score[:, idx])
+        fig_pr.add_trace(
+            go.Scatter(
+                x=rec,
+                y=prec,
+                mode="lines",
+                name=f"{lbl} (AP={ap:.3f})",
+                line=dict(width=3),
             )
-        ] if isinstance(threshold, (int, float)) else [],
-        annotations=[
-            dict(
-                x=threshold,
-                y=1.02,
-                xref="x",
-                yref="paper",
-                showarrow=False,
-                text=f"Threshold = {threshold:.2f}",
-                font=dict(size=11, color="#d62728"),
+        )
+        added_pr = True
+    if y_true_bin.sum() > 0:
+        prec_micro, rec_micro, _ = precision_recall_curve(y_true_bin.ravel(), y_score.ravel())
+        ap_micro = average_precision_score(y_true_bin, y_score, average="micro")
+        fig_pr.add_trace(
+            go.Scatter(
+                x=rec_micro,
+                y=prec_micro,
+                mode="lines",
+                name=f"Micro-average (AP={ap_micro:.3f})",
+                line=dict(width=3, dash="dash"),
             )
-        ] if isinstance(threshold, (int, float)) else [],
+        )
+        added_pr = True
+    if not added_pr:
+        return plots
+    fig_pr.update_layout(
+        title=dict(text="Multi-class Precision-Recall (one-vs-rest)", x=0.5),
+        xaxis_title="Recall",
+        yaxis_title="Precision",
+        width=820,
+        height=620,
+        legend=dict(
+            x=0.62,
+            y=0.05,
+            bgcolor="rgba(255,255,255,0.9)",
+            bordercolor="rgba(0,0,0,0.2)",
+            borderwidth=1,
+        ),
     )
-    _style_fig(fig_thresh)
-    plots.append({
-        "title": "Threshold Sweep: Accuracy, F1, Sensitivity, Specificity",
-        "html": pio.to_html(fig_thresh, full_html=False, include_plotlyjs=False),
-    })
+    _style_fig(fig_pr)
+    plots.append(_wrap_plot("Multi-class Precision-Recall (one-vs-rest)", fig_pr))
 
     return plots
+
+
+def build_multiclass_metric_plots(test_stats_path: str) -> List[Dict[str, str]]:
+    """Alternative multi-class transparency plots using test_statistics.json per-class stats."""
+    ts_path = Path(test_stats_path)
+    if not ts_path.exists():
+        return []
+    try:
+        with open(ts_path, "r") as f:
+            test_stats = json.load(f)
+    except Exception:
+        return []
+
+    label_stats = test_stats.get("label", {})
+    pcs = label_stats.get("per_class_stats", {})
+    if not pcs:
+        return []
+    classes = list(pcs.keys())
+    if not classes:
+        return []
+
+    metrics = ["precision", "recall", "f1_score", "specificity", "accuracy"]
+    fig_bar = go.Figure()
+    for metric in metrics:
+        values = []
+        for cls in classes:
+            v = pcs.get(cls, {}).get(metric)
+            values.append(v if isinstance(v, (int, float)) else 0)
+        fig_bar.add_trace(
+            go.Bar(
+                x=classes,
+                y=values,
+                name=metric.replace("_", " ").title(),
+            )
+        )
+    fig_bar.update_layout(
+        title=dict(text="Per-Class Metrics (Test)", x=0.5),
+        xaxis_title="Class",
+        yaxis_title="Metric value",
+        barmode="group",
+        width=900,
+        height=600,
+        legend=dict(
+            x=1.02,
+            y=1.0,
+            bgcolor="rgba(255,255,255,0.9)",
+            bordercolor="rgba(0,0,0,0.2)",
+            borderwidth=1,
+        ),
+    )
+    _style_fig(fig_bar)
+
+    return [_wrap_plot("Per-Class Metrics (Test)", fig_bar)]