diff plotly_plots.py @ 15:d17e3a1b8659 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
author goeckslab
date Fri, 28 Nov 2025 15:45:49 +0000
parents c5150cceab47
children
line wrap: on
line diff
--- a/plotly_plots.py	Wed Nov 26 22:00:32 2025 +0000
+++ b/plotly_plots.py	Fri Nov 28 15:45:49 2025 +0000
@@ -7,13 +7,105 @@
 import plotly.graph_objects as go
 import plotly.io as pio
 from constants import LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME
-from sklearn.metrics import auc, roc_curve
-from sklearn.preprocessing import label_binarize
+
+
+def _style_fig(fig: go.Figure, font_size: int = 12) -> go.Figure:
+    """Give ``fig`` the shared look and return the same object.
+
+    Applies the font size, a white plot/paper background and light-grey grid lines.
+    """
+    white, grid = "#ffffff", "#e8e8e8"
+    fig.update_layout(font=dict(size=font_size), plot_bgcolor=white, paper_bgcolor=white)
+    for restyle_axes in (fig.update_xaxes, fig.update_yaxes):
+        restyle_axes(gridcolor=grid)
+    return fig
+
+
+def _labels_from_metadata_dict(meta_dict: dict) -> List[str]:
+    """Return class names in class-index order from a label-metadata dict.
+
+    A non-empty list under "idx2str", "idx2label" or "vocab" is used as is;
+    otherwise order is rebuilt from the "str2idx" name->index map, an unnamed
+    index keeping its slot as the index text. Returns [] if nothing is usable.
+    """
+    if not isinstance(meta_dict, dict):
+        return []
+    for key in ("idx2str", "idx2label", "vocab"):
+        seq = meta_dict.get(key)
+        if isinstance(seq, list) and seq:
+            return [str(v) for v in seq]
+    str2idx = meta_dict.get("str2idx")
+    if not isinstance(str2idx, dict):
+        return []
+    usable = {
+        idx: str(name)
+        for name, idx in str2idx.items()
+        if isinstance(idx, int) and not isinstance(idx, bool) and idx >= 0  # bool is an int too
+    }
+    return [usable.get(idx, str(idx)) for idx in range(max(usable, default=-1) + 1)]
+
+
+def _resolve_confusion_labels(
+    label_stats: dict,
+    n_classes: int,
+    metadata_csv_path: Optional[str],
+    train_set_metadata_path: Optional[str],
+) -> List[str]:
+    """Choose ``n_classes`` display labels, trying the most faithful source first.
+
+    Sources, in order: train-set metadata JSON, metadata CSV (unique label
+    values as pandas returns them), ``per_class_stats`` keys, then ``labels``
+    padded with index strings. A source with fewer than ``n_classes`` names is
+    passed over; one that fails to load prints a warning and is passed over too.
+    """
+    vocab_keys = ("idx2str", "str2idx", "idx2label", "vocab")
+    if train_set_metadata_path:
+        try:
+            json_file = Path(train_set_metadata_path)
+            if json_file.exists():
+                with open(json_file, "r") as handle:
+                    metadata = json.load(handle)
+                feature_meta = metadata.get(LABEL_COLUMN_NAME)
+                if not isinstance(feature_meta, dict):
+                    # No entry under the label column: use the first vocabulary-like one.
+                    feature_meta = None
+                    for entry in metadata.values():
+                        if isinstance(entry, dict) and any(k in entry for k in vocab_keys):
+                            feature_meta = entry
+                            break
+                names = _labels_from_metadata_dict(feature_meta) if feature_meta else []
+                if names and len(names) >= n_classes:
+                    return [str(name) for name in names[:n_classes]]
+        except Exception as exc:
+            print(f"Warning: Unable to read labels from train_set_metadata: {exc}")
+
+    if metadata_csv_path:
+        try:
+            csv_file = Path(metadata_csv_path)
+            if csv_file.exists():
+                frame = pd.read_csv(csv_file)
+                if LABEL_COLUMN_NAME in frame.columns:
+                    seen = frame[LABEL_COLUMN_NAME].dropna().unique().tolist()
+                    if seen and len(seen) >= n_classes:
+                        return [str(value) for value in seen[:n_classes]]
+        except Exception as exc:
+            print(f"Warning: Unable to read labels from metadata CSV: {exc}")
+
+    per_class = label_stats.get("per_class_stats", {})
+    stat_names = [str(key) for key in per_class.keys()] if per_class else []
+    if per_class and len(stat_names) >= n_classes:
+        return stat_names[:n_classes]
+    fallback = label_stats.get("labels") or [str(i) for i in range(n_classes)]
+    if len(fallback) < n_classes:
+        fallback = fallback + [str(i) for i in range(len(fallback), n_classes)]
+    return [str(label) for label in fallback[:n_classes]]
 
 
 def build_classification_plots(
     test_stats_path: str,
     training_stats_path: Optional[str] = None,
+    metadata_csv_path: Optional[str] = None,
+    train_set_metadata_path: Optional[str] = None,
 ) -> List[Dict[str, str]]:
     """
     Read Ludwig’s test_statistics.json and build three interactive Plotly panels:
@@ -21,6 +113,9 @@
       - ROC-AUC
       - Classification Report Heatmap
 
+    If metadata paths are provided, the confusion matrix axes will use the original
+    label values from the training metadata rather than integer-encoded labels.
+
     Returns a list of dicts, each with:
       {
         "title": <plot title>,
@@ -42,12 +137,12 @@
 
     # 0) Confusion Matrix
     cm = np.array(label_stats["confusion_matrix"], dtype=int)
-    # Try to get actual class names from per_class_stats keys (which contain the real labels)
-    pcs = label_stats.get("per_class_stats", {})
-    if pcs:
-        labels = list(pcs.keys())
-    else:
-        labels = label_stats.get("labels", [str(i) for i in range(cm.shape[0])])
+    labels = _resolve_confusion_labels(
+        label_stats,
+        n_classes,
+        metadata_csv_path=metadata_csv_path,
+        train_set_metadata_path=train_set_metadata_path,
+    )
     total = cm.sum()
 
     fig_cm = go.Figure(
@@ -70,6 +165,7 @@
         height=side_px,
         margin=dict(t=100, l=80, r=80, b=80),
     )
+    _style_fig(fig_cm)
 
     # annotate counts and percentages
     mval = cm.max() if cm.size else 0
@@ -110,16 +206,28 @@
         )
     })
 
-    # 1) ROC-AUC Curves (Multi-class)
-    roc_plot = _build_roc_auc_plot(test_stats_path, labels, common_cfg)
+    # 1) ROC Curve (from test_statistics)
+    roc_plot = _build_static_roc_plot(label_stats, common_cfg, friendly_labels=labels)
     if roc_plot:
         plots.append(roc_plot)
 
+    # 2) Precision-Recall Curve (from test_statistics)
+    pr_plot = _build_precision_recall_plot(label_stats, common_cfg)
+    if pr_plot:
+        plots.append(pr_plot)
+
     # 2) Classification Report Heatmap
     pcs = label_stats.get("per_class_stats", {})
     if pcs:
         classes = list(pcs.keys())
-        metrics = ["precision", "recall", "f1_score"]
+        metrics = [
+            "precision",
+            "recall",
+            "f1_score",
+            "accuracy",
+            "matthews_correlation_coefficient",
+            "specificity",
+        ]
         z, txt = [], []
         for c in classes:
             row, trow = [], []
@@ -133,7 +241,7 @@
         fig_cr = go.Figure(
             go.Heatmap(
                 z=z,
-                x=metrics,
+                x=[m.replace("_", " ") for m in metrics],
                 y=[str(c) for c in classes],
                 text=txt,
                 texttemplate="%{text}",
@@ -143,15 +251,16 @@
             )
         )
         fig_cr.update_layout(
-            title="Classification Report",
+            title="Per-Class metrics",
             xaxis_title="",
             yaxis_title="Class",
             width=side_px,
             height=side_px,
             margin=dict(t=80, l=80, r=80, b=80),
         )
+        _style_fig(fig_cr)
         plots.append({
-            "title": "Classification Report",
+            "title": "Per-Class metrics",
             "html": pio.to_html(
                 fig_cr,
                 full_html=False,
@@ -160,68 +269,667 @@
             )
         })
 
+    # 3) Prediction Diagnostics (from predictions.csv)
+    # Note: appended separately in generate_html_report, not returned here.
+
+    return plots
+
+
+def build_train_validation_plots(train_stats_path: str) -> List[Dict[str, str]]:
+    """Build train/validation learning-curve panels from training_statistics.json.
+
+    Uses the per-epoch metric lists under ``training.label`` and
+    ``validation.label`` to render, in order: line plots of ROC-AUC, precision,
+    recall and specificity; validation precision and recall together; F1
+    derived from precision and recall; the train-minus-validation ROC-AUC gap;
+    and a bar chart of validation metrics at the epoch with the highest
+    validation ROC-AUC. A panel is left out when its data is missing.
+
+    Returns a list of ``{"title": ..., "html": ...}`` dicts in which only the
+    first panel loads Plotly.js (from the CDN). The list is empty when the
+    path is empty or missing, the file cannot be parsed, or neither split has
+    label statistics.
+    """
+    if not train_stats_path or not Path(train_stats_path).exists():
+        return []
+    try:
+        with open(train_stats_path, "r") as f:
+            train_stats = json.load(f)
+    except Exception as exc:
+        print(f"Warning: Unable to read training statistics: {exc}")
+        return []
+
+    label_train = (train_stats.get("training") or {}).get("label", {})
+    label_val = (train_stats.get("validation") or {}).get("label", {})
+    if not label_train and not label_val:
+        return []
+    plots: List[Dict[str, str]] = []
+    include_js = True  # Load Plotly.js once for this group
+
+    def _get_series(stats: dict, metric: str) -> List[float]:
+        """Return ``stats[metric]`` as floats; [] if absent or not numeric."""
+        vals = stats.get(metric, [])
+        try:
+            # A scalar is treated as a one-epoch series.
+            return [float(v) for v in vals] if isinstance(vals, list) else [float(vals)]
+        except (TypeError, ValueError, OverflowError):
+            # e.g. a JSON null in the list: drop the metric, keep the report.
+            return []
+
+    def _line_plot(metric_key: str, title: str, yaxis_title: str) -> Optional[Dict[str, str]]:
+        """Plot one metric for both splits; None when neither split has it."""
+        train_series = _get_series(label_train, metric_key)
+        val_series = _get_series(label_val, metric_key)
+        if not train_series and not val_series:
+            return None
+        fig = go.Figure()
+        for name, series in (("Train", train_series), ("Validation", val_series)):
+            if series:
+                fig.add_trace(
+                    go.Scatter(
+                        x=list(range(1, len(series) + 1)),  # epochs are 1-based
+                        y=series,
+                        mode="lines+markers",
+                        name=name,
+                        line=dict(width=4),
+                    )
+                )
+        fig.update_layout(
+            title=dict(text=title, x=0.5),
+            xaxis_title="Epoch",
+            yaxis_title=yaxis_title,
+            width=760,
+            height=520,
+            hovermode="x unified",
+        )
+        _style_fig(fig)
+        return {
+            "title": title,
+            "html": pio.to_html(
+                fig,
+                full_html=False,
+                # include_js is looked up at call time, so it is True only once.
+                include_plotlyjs="cdn" if include_js else False,
+            ),
+        }
+
+    # Core learning curves
+    for key, title in [
+        ("roc_auc", "ROC-AUC across epochs"),
+        ("precision", "Precision across epochs"),
+        ("recall", "Recall/Sensitivity across epochs"),
+        ("specificity", "Specificity across epochs"),
+    ]:
+        # The y-axis gets the bare metric name: the title minus its suffix.
+        plot = _line_plot(key, title, title.replace(" across epochs", "").strip())
+        if plot:
+            plots.append(plot)
+            include_js = False
+
+    # Precision vs Recall evolution (validation)
+    val_prec = _get_series(label_val, "precision")
+    val_rec = _get_series(label_val, "recall")
+    if val_prec and val_rec:
+        # Plot only the epochs both series cover.
+        epochs = list(range(1, min(len(val_prec), len(val_rec)) + 1))
+        fig_pr = go.Figure()
+        for name, series in (("Precision", val_prec), ("Recall", val_rec)):
+            fig_pr.add_trace(
+                go.Scatter(
+                    x=epochs,
+                    y=series[: len(epochs)],
+                    mode="lines+markers",
+                    name=name,
+                )
+            )
+        fig_pr.update_layout(
+            title=dict(text="Validation Precision and Recall by Epoch", x=0.5),
+            xaxis_title="Epoch",
+            yaxis_title="Value",
+            width=760,
+            height=520,
+            hovermode="x unified",
+        )
+        _style_fig(fig_pr)
+        plots.append({
+            "title": "Precision vs Recall Evolution",
+            "html": pio.to_html(
+                fig_pr,
+                full_html=False,
+                include_plotlyjs="cdn" if include_js else False,
+            ),
+        })
+        include_js = False
+
+    # F1-score derived
+    def _compute_f1(p: List[float], r: List[float]) -> List[float]:
+        """Per-epoch harmonic mean of precision and recall (0.0 where their sum is 0)."""
+        return [
+            0.0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec)
+            for prec, rec in zip(p, r)
+        ]
+
+    f1_train = _compute_f1(_get_series(label_train, "precision"), _get_series(label_train, "recall"))
+    f1_val = _compute_f1(val_prec, val_rec)
+    if f1_train or f1_val:
+        fig = go.Figure()
+        if f1_train:
+            fig.add_trace(go.Scatter(x=list(range(1, len(f1_train) + 1)), y=f1_train, mode="lines+markers", name="Train", line=dict(width=4)))
+        if f1_val:
+            fig.add_trace(go.Scatter(x=list(range(1, len(f1_val) + 1)), y=f1_val, mode="lines+markers", name="Validation", line=dict(width=4)))
+        fig.update_layout(
+            title=dict(text="F1-Score across epochs (derived)", x=0.5),
+            xaxis_title="Epoch",
+            yaxis_title="F1-Score",
+            width=760,
+            height=520,
+            hovermode="x unified",
+        )
+        _style_fig(fig)
+        plots.append({
+            "title": "F1-Score across epochs (derived)",
+            "html": pio.to_html(
+                fig,
+                full_html=False,
+                include_plotlyjs="cdn" if include_js else False,
+            ),
+        })
+        include_js = False
+
+    # Overfitting Gap: Train vs Val ROC-AUC (gap)
+    roc_train = _get_series(label_train, "roc_auc")
+    roc_val = _get_series(label_val, "roc_auc")
+    if roc_train and roc_val:
+        # zip stops at the shorter series, so only shared epochs are compared.
+        gaps = [t - v for t, v in zip(roc_train, roc_val)]
+        epochs_gap = list(range(1, len(gaps) + 1))
+        fig_gap = go.Figure()
+        fig_gap.add_trace(go.Scatter(x=epochs_gap, y=gaps, mode="lines+markers", name="Train - Val ROC-AUC", line=dict(width=4)))
+        fig_gap.update_layout(
+            title=dict(text="Overfitting gap: ROC-AUC across epochs", x=0.5),
+            xaxis_title="Epoch",
+            yaxis_title="Gap",
+            width=760,
+            height=520,
+            hovermode="x unified",
+        )
+        _style_fig(fig_gap)
+        plots.append({
+            "title": "Overfitting gap: ROC-AUC across epochs",
+            "html": pio.to_html(
+                fig_gap,
+                full_html=False,
+                include_plotlyjs="cdn" if include_js else False,
+            ),
+        })
+        include_js = False
+
+    # Best Epoch Dashboard (based on max val ROC-AUC)
+    if roc_val:
+        # argmax returns the first epoch that attains the maximum.
+        best_idx = int(np.argmax(roc_val))
+        best_epoch = best_idx + 1
+        spec_val = _get_series(label_val, "specificity")
+        # Series lengths may differ; a metric without a value at the best
+        # epoch is recorded as None and gets no bar.
+        metrics_at_best = {
+            "ROC-AUC": roc_val[best_idx] if best_idx < len(roc_val) else None,
+            "Precision": val_prec[best_idx] if best_idx < len(val_prec) else None,
+            "Recall": val_rec[best_idx] if best_idx < len(val_rec) else None,
+            "Specificity": spec_val[best_idx] if best_idx < len(spec_val) else None,
+            "F1-Score": f1_val[best_idx] if best_idx < len(f1_val) else None,
+        }
+        fig_best = go.Figure()
+        for name, value in metrics_at_best.items():
+            if value is not None:
+                fig_best.add_trace(go.Bar(name=name, x=[name], y=[value]))
+        fig_best.update_layout(
+            title=dict(text=f"Best Epoch Dashboard (Val ROC-AUC @ epoch {best_epoch})", x=0.5),
+            xaxis_title="Metric",
+            yaxis_title="Value",
+            width=760,
+            height=520,
+            showlegend=False,
+        )
+        _style_fig(fig_best)
+        plots.append({
+            "title": "Best Validation Epoch Snapshot (Metrics)",
+            "html": pio.to_html(
+                fig_best,
+                full_html=False,
+                include_plotlyjs="cdn" if include_js else False,
+            ),
+        })
+        include_js = False
+
     return plots
 
 
-def _build_roc_auc_plot(test_stats_path: str, class_labels: List[str], config: dict) -> Optional[Dict[str, str]]:
-    """
-    Build an interactive ROC-AUC curve plot for multi-class classification.
-    Following sklearn's ROC example with micro-average and per-class curves.
+def _get_regression_series(split_stats: dict, metric: str) -> List[float]:
+    """Return ``split_stats[metric]`` as a list of floats.
+
+    A scalar becomes a one-element list; missing or non-numeric data gives [].
+    """
+    vals = split_stats.get(metric, [])
+    try:
+        return [float(v) for v in vals] if isinstance(vals, list) else [float(vals)]
+    except (TypeError, ValueError, OverflowError):
+        return []
+
 
-    Args:
-        test_stats_path: Path to test_statistics.json
-        class_labels: List of class label names
-        config: Plotly config dict
+def _regression_line_plot(
+    train_split: dict,
+    val_split: dict,
+    metric_key: str,
+    title: str,
+    yaxis_title: str,
+    include_js: bool,
+) -> Optional[Dict[str, str]]:
+    """Render one metric's train and validation curves as a titled HTML snippet.
+
+    Each split contributes a "Train"/"Validation" line over 1-based epochs
+    when it has values for ``metric_key``. Returns None if neither does,
+    otherwise ``{"title": title, "html": ...}``; ``include_js`` decides
+    whether the snippet loads Plotly.js from the CDN.
+    """
+    curves = [
+        ("Train", _get_regression_series(train_split, metric_key)),
+        ("Validation", _get_regression_series(val_split, metric_key)),
+    ]
+    if not any(values for _, values in curves):
+        return None
+
+    fig = go.Figure()
+    # One trace per split that actually has data.
+    for trace_name, values in curves:
+        if not values:
+            continue
+        fig.add_trace(
+            go.Scatter(
+                x=list(range(1, len(values) + 1)),
+                y=values,
+                mode="lines+markers",
+                name=trace_name,
+                line=dict(width=4),
+            )
+        )
+
+    fig.update_layout(
+        title=dict(text=title, x=0.5),
+        xaxis_title="Epoch",
+        yaxis_title=yaxis_title,
+        width=760,
+        height=520,
+        hovermode="x unified",
+    )
+    _style_fig(fig)
+
+    plotlyjs = "cdn" if include_js else False
+    return {
+        "title": title,
+        "html": pio.to_html(fig, full_html=False, include_plotlyjs=plotlyjs),
+    }
+
+
+def build_regression_train_val_plots(train_stats_path: str) -> List[Dict[str, str]]:
+    """Build train/validation learning curves for a regression label.
+
+    One panel per metric present in training_statistics.json; only the first
+    loads Plotly.js. Returns [] if the file is unusable or lacks label stats.
+    """
+    if not (train_stats_path and Path(train_stats_path).exists()):
+        return []
+    try:
+        with open(train_stats_path, "r") as stats_file:
+            stats = json.load(stats_file)
+    except Exception as exc:
+        print(f"Warning: Unable to read training statistics: {exc}")
+        return []
+    train_label = (stats.get("training") or {}).get("label", {})
+    val_label = (stats.get("validation") or {}).get("label", {})
+    if not (train_label or val_label):
+        return []
+    panels: List[Dict[str, str]] = []
+    for metric_key, title, ytitle in (
+        ("mean_absolute_error", "Mean Absolute Error across epochs", "MAE"),
+        ("root_mean_squared_error", "Root Mean Squared Error across epochs", "RMSE"),
+        ("mean_absolute_percentage_error", "Mean Absolute Percentage Error across epochs", "MAPE"),
+        ("r2", "R² across epochs", "R²"),
+        ("loss", "Loss across epochs", "Loss"),
+    ):
+        panel = _regression_line_plot(train_label, val_label, metric_key, title, ytitle, not panels)
+        if panel:
+            panels.append(panel)
+    return panels
+
 
-    Returns:
-        Dict with title and HTML, or None if data unavailable
-    """
+def build_regression_test_plots(train_stats_path: str) -> List[Dict[str, str]]:
+    """Plot the test-split regression metrics of training_statistics.json by epoch.
+
+    Returns one ``{"title", "html"}`` panel per metric found under ``test.label``, else [].
+    """
+    if not train_stats_path or not Path(train_stats_path).exists():
+        return []
     try:
-        # Get the experiment directory from test_stats_path
-        exp_dir = Path(test_stats_path).parent
+        with open(train_stats_path, "r") as f:
+            train_stats = json.load(f)
+    except Exception as exc:
+        print(f"Warning: Unable to read training statistics: {exc}")
+        return []
+
+    label_test = (train_stats.get("test") or {}).get("label", {})
+    if not label_test:
+        return []
 
-        # Load predictions with probabilities
-        predictions_path = exp_dir / "predictions.csv"
-        if not predictions_path.exists():
-            return None
+    plots: List[Dict[str, str]] = []
+    include_js = True  # only the first panel loads Plotly.js
+    metrics = [
+        ("mean_absolute_error", "Mean Absolute Error Across Epochs", "MAE"),
+        ("root_mean_squared_error", "Root Mean Squared Error Across Epochs", "RMSE"),
+        ("mean_absolute_percentage_error", "Mean Absolute Percentage Error Across Epochs", "MAPE"),
+        ("r2", "R² Across Epochs", "R²"),
+        ("loss", "Loss Across Epochs", "Loss"),
+    ]
+    for metric_key, title, ytitle in metrics:
+        series = _get_regression_series(label_test, metric_key)
+        if not series:
+            continue
+        fig = go.Figure()
+        fig.add_trace(
+            go.Scatter(
+                x=list(range(1, len(series) + 1)),  # per-metric length keeps x and y aligned
+                y=series,
+                mode="lines+markers",
+                name="Test",
+                line=dict(width=4),
+            )
+        )
+        fig.update_layout(
+            title=dict(text=title, x=0.5),
+            xaxis_title="Epoch",
+            yaxis_title=ytitle,
+            width=760,
+            height=520,
+            hovermode="x unified",
+        )
+        _style_fig(fig)
+        plots.append({
+            "title": title,
+            "html": pio.to_html(
+                fig,
+                full_html=False,
+                include_plotlyjs="cdn" if include_js else False,
+            ),
+        })
+        include_js = False
+    return plots
 
-        df_pred = pd.read_csv(predictions_path)
+
+def _build_static_roc_plot(
+    label_stats: dict, config: dict, friendly_labels: Optional[List[str]] = None
+) -> Optional[Dict[str, str]]:
+    """Build the ROC panel from the ``roc_curve`` entry of the test statistics.
+
+    Returns None if that entry lacks equally long, non-empty
+    ``false_positive_rate``/``true_positive_rate`` lists or plotting fails.
+    The title shows the AUC and positive class when they can be determined.
+    """
+    roc_data = label_stats.get("roc_curve")
+    if not isinstance(roc_data, dict):
+        return None
+
+    fpr = roc_data.get("false_positive_rate")
+    tpr = roc_data.get("true_positive_rate")
+    if not fpr or not tpr or len(fpr) != len(tpr):
+        return None
+
+    try:
+        fig = go.Figure()
+        fig.add_trace(
+            go.Scatter(
+                x=fpr,
+                y=tpr,
+                mode="lines+markers",
+                name="ROC Curve",
+                line=dict(color="#1f77b4", width=4),
+                hovertemplate="FPR: %{x:.3f}<br>TPR: %{y:.3f}<extra></extra>",
+            )
+        )
+        fig.add_trace(
+            go.Scatter(
+                x=[0, 1],
+                y=[0, 1],
+                mode="lines",
+                name="Random Classifier",
+                line=dict(color="gray", width=2, dash="dash"),
+                hovertemplate="Random Classifier<extra></extra>",
+            )
+        )
+
+        auc_keys = ("roc_auc", "roc_auc_macro", "roc_auc_micro")
+        auc_val = next((label_stats[k] for k in auc_keys if label_stats.get(k) is not None), None)
+        auc_txt = f" (AUC = {auc_val:.3f})" if isinstance(auc_val, (int, float)) else ""
 
-        if SPLIT_COLUMN_NAME in df_pred.columns:
-            split_series = df_pred[SPLIT_COLUMN_NAME].astype(str).str.lower()
-            test_mask = split_series.isin({"2", "test", "testing"})
-            if test_mask.any():
-                df_pred = df_pred[test_mask].reset_index(drop=True)
+        # Determine which label is treated as positive for the curve
+        pcs = label_stats.get("per_class_stats", {})
+        label_list: List = list(pcs.keys()) if pcs else []
+        if not label_list and isinstance(label_stats.get("labels"), list):
+            label_list = label_stats["labels"]
+
+        # Try to resolve index of the positive label explicitly provided by Ludwig.
+        # `or` is avoided so that a falsy label such as 0 is not skipped.
+        pos_label_raw = roc_data.get("positive_label")
+        if pos_label_raw is None:
+            pos_label_raw = roc_data.get("positive_class")
+        if pos_label_raw is None:
+            pos_label_raw = label_stats.get("positive_label")
+        pos_label_idx = None
+        if pos_label_raw is not None:
+            # Match on text so that e.g. 1 and "1" are the same label.
+            as_text = [str(lbl) for lbl in label_list]
+            if str(pos_label_raw) in as_text:
+                pos_label_idx = as_text.index(str(pos_label_raw))
+
+        # Fallback: use the second label if available, otherwise the first
+        if pos_label_idx is None and label_list:
+            pos_label_idx = 1 if len(label_list) >= 2 else 0
+
+        if pos_label_raw is None and pos_label_idx is not None:
+            pos_label_raw = label_list[pos_label_idx]
+
+        # Map to friendly label if we have one from metadata/CSV
+        pos_label_display = pos_label_raw
+        if (
+            friendly_labels
+            and isinstance(pos_label_idx, int)
+            and 0 <= pos_label_idx < len(friendly_labels)
+        ):
+            pos_label_display = friendly_labels[pos_label_idx]
+
+        pos_label_txt = (
+            f"Positive class: {pos_label_display}"
+            if pos_label_display is not None
+            else "Positive class: (not available)"
+        )
+
+        title_label = f"ROC Curve{auc_txt}"
+        if pos_label_display is not None:
+            title_label = f"ROC Curve (Positive Class: {pos_label_display}){auc_txt}"
 
-        if df_pred.empty:
-            return None
+        fig.update_layout(
+            title=dict(text=title_label, x=0.5),
+            xaxis_title="False Positive Rate",
+            yaxis_title="True Positive Rate",
+            width=700,
+            height=600,
+            margin=dict(t=80, l=80, r=80, b=110),
+            hovermode="closest",
+            legend=dict(
+                x=0.6,
+                y=0.1,
+                bgcolor="rgba(255,255,255,0.9)",
+                bordercolor="rgba(0,0,0,0.2)",
+                borderwidth=1,
+            ),
+        )
+        _style_fig(fig)
+        fig.update_xaxes(range=[0, 1.0])
+        fig.update_yaxes(range=[0, 1.05])
 
-        # Extract probability columns (label_probabilities_0, label_probabilities_1, etc.)
-        # or label_probabilities_<class_name> for string labels
-        prob_cols = [col for col in df_pred.columns if col.startswith('label_probabilities_') and col != 'label_probabilities']
+        fig.add_annotation(
+            x=0.5,
+            y=-0.15,
+            xref="paper",
+            yref="paper",
+            showarrow=False,
+            text=f"<span style='font-size:12px;color:#555;'>{pos_label_txt}</span>",
+            xanchor="center",
+        )
+
+        return {
+            "title": "ROC Curve",
+            "html": pio.to_html(
+                fig,
+                full_html=False,
+                include_plotlyjs=False,
+                config=config,
+            ),
+        }
+    except Exception as e:
+        print(f"Error building ROC plot: {e}")
+        return None
+
+
+def _build_precision_recall_plot(label_stats: dict, config: dict) -> Optional[Dict[str, str]]:
+    """Build the precision-recall panel from ``precision_recall_curve`` in the test stats.
+
+    Returns None if that entry lacks equally long, non-empty ``precisions``/``recalls`` or plotting fails.
+    """
+    pr_data = label_stats.get("precision_recall_curve")
+    if not isinstance(pr_data, dict):
+        return None
+
+    precisions = pr_data.get("precisions")
+    recalls = pr_data.get("recalls")
+    if not precisions or not recalls or len(precisions) != len(recalls):
+        return None
 
-        # Sort by class number if numeric, otherwise keep alphabetical order
-        if prob_cols and prob_cols[0].split('_')[-1].isdigit():
-            prob_cols.sort(key=lambda x: int(x.split('_')[-1]))
-        else:
-            prob_cols.sort()  # Alphabetical sort for string class names
+    try:
+        fig = go.Figure()
+        fig.add_trace(
+            go.Scatter(
+                x=recalls,
+                y=precisions,
+                mode="lines+markers",
+                name="Precision-Recall",
+                line=dict(color="#d62728", width=4),
+                hovertemplate="Recall: %{x:.3f}<br>Precision: %{y:.3f}<extra></extra>",
+            )
+        )
+
+        # First average-precision value that is present (0.0 counts, unlike with `or`).
+        ap_keys = ("average_precision_macro", "average_precision_micro", "average_precision_samples")
+        ap_val = next((label_stats[k] for k in ap_keys if label_stats.get(k) is not None), None)
+        ap_txt = f" (AP = {ap_val:.3f})" if isinstance(ap_val, (int, float)) else ""
+        fig.update_layout(
+            title=dict(text=f"Precision-Recall Curve{ap_txt}", x=0.5),
+            xaxis_title="Recall",
+            yaxis_title="Precision",
+            width=700,
+            height=600,
+            margin=dict(t=80, l=80, r=80, b=80),
+            hovermode="closest",
+            legend=dict(
+                x=0.6,
+                y=0.1,
+                bgcolor="rgba(255,255,255,0.9)",
+                bordercolor="rgba(0,0,0,0.2)",
+                borderwidth=1,
+            ),
+        )
+        _style_fig(fig)
+        fig.update_xaxes(range=[0, 1.0])
+        fig.update_yaxes(range=[0, 1.05])
+
+        return {
+            "title": "Precision-Recall Curve",
+            "html": pio.to_html(
+                fig,
+                full_html=False,
+                include_plotlyjs=False,
+                config=config,
+            ),
+        }
+    except Exception as e:
+        print(f"Error building Precision-Recall plot: {e}")
+        return None
+
 
-        if not prob_cols:
-            return None
+def build_prediction_diagnostics(
+    predictions_path: str,
+    label_data_path: Optional[str] = None,
+    split_value: int = 2,
+    threshold: Optional[float] = None,
+) -> List[Dict[str, str]]:
+    """Generate diagnostic plots from predictions.csv for classification tasks."""
+    preds_file = Path(predictions_path)
+    if not preds_file.exists():
+        return []
+
+    try:
+        df_pred = pd.read_csv(predictions_path)
+    except Exception as exc:
+        print(f"Warning: Unable to read predictions CSV: {exc}")
+        return []
+
+    plots: List[Dict[str, str]] = []
+
+    # Identify probability columns
+    prob_cols = [
+        c for c in df_pred.columns
+        if c.startswith("label_probabilities_") and c != "label_probabilities"
+    ]
+    prob_cols_sorted = sorted(prob_cols)
 
-        # Get probabilities matrix (n_samples x n_classes)
-        y_score = df_pred[prob_cols].values
-        n_classes = len(prob_cols)
+    def _select_positive_prob():
+        if not prob_cols_sorted:
+            return None, None
+        # Prefer a column indicating positive/event/true/1
+        preferred_keys = ("event", "true", "positive", "pos", "1")
+        for col in prob_cols_sorted:
+            suffix = col.replace("label_probabilities_", "").lower()
+            if any(k in suffix for k in preferred_keys):
+                return col, suffix
+        if len(prob_cols_sorted) == 2:
+            col = prob_cols_sorted[1]
+            return col, col.replace("label_probabilities_", "")
+        col = prob_cols_sorted[0]
+        return col, col.replace("label_probabilities_", "")
 
-        y_true = None
-        candidate_cols = [
+    pos_prob_col, pos_label_hint = _select_positive_prob()
+    pos_prob_series = df_pred[pos_prob_col] if pos_prob_col and pos_prob_col in df_pred else None
+
+    # Confidence series: prefer label_probability, otherwise positive prob, otherwise max prob
+    confidence_series = None
+    if "label_probability" in df_pred.columns:
+        confidence_series = df_pred["label_probability"]
+    elif pos_prob_series is not None:
+        confidence_series = pos_prob_series
+    elif prob_cols_sorted:
+        confidence_series = df_pred[prob_cols_sorted].max(axis=1)
+
+    # True labels
+    def _extract_labels():  # returns the true-label column (from df_pred or the dataset CSV), or None
+        candidates = [  # tried in order; the first name present in df_pred (and not a probability column) wins
             LABEL_COLUMN_NAME,
             f"{LABEL_COLUMN_NAME}_ground_truth",
             f"{LABEL_COLUMN_NAME}__ground_truth",
             f"{LABEL_COLUMN_NAME}_target",
             f"{LABEL_COLUMN_NAME}__target",
+            "label",
+            "label_true",
         ]
-        candidate_cols.extend(
+        candidates.extend(
             [
                 col
                 for col in df_pred.columns
@@ -230,174 +938,182 @@
                 and "predictions" not in col
             ]
         )
-        for col in candidate_cols:
-            if col in df_pred.columns and col not in prob_cols:
-                y_true = df_pred[col].values
-                break
+        for col in candidates:
+            if col in df_pred.columns and col not in prob_cols_sorted:
+                return df_pred[col]
+        if label_data_path and Path(label_data_path).exists():  # fallback: read labels from the dataset file
+            try:
+                df_all = pd.read_csv(label_data_path)
+                if SPLIT_COLUMN_NAME in df_all.columns:
+                    df_all = df_all[df_all[SPLIT_COLUMN_NAME] == split_value].reset_index(drop=True)  # keep only the requested split
+                if LABEL_COLUMN_NAME in df_all.columns:
+                    return df_all[LABEL_COLUMN_NAME].reset_index(drop=True)
+            except Exception as exc:  # best-effort: report the failure and fall through to None
+                print(f"Warning: Unable to load labels from dataset: {exc}")
+        return None
 
-        if y_true is None:
-            desc_path = exp_dir / "description.json"
-            if desc_path.exists():
-                try:
-                    with open(desc_path, 'r') as f:
-                        desc = json.load(f)
-                    dataset_path = desc.get('dataset', '')
-                    if dataset_path and Path(dataset_path).exists():
-                        df_orig = pd.read_csv(dataset_path)
-                        if SPLIT_COLUMN_NAME in df_orig.columns:
-                            df_orig = df_orig[df_orig[SPLIT_COLUMN_NAME] == 2].reset_index(drop=True)
-                        if LABEL_COLUMN_NAME in df_orig.columns:
-                            y_true = df_orig[LABEL_COLUMN_NAME].values
-                            if len(y_true) != len(df_pred):
-                                print(
-                                    f"Warning: Test set size mismatch. Truncating to {len(df_pred)} samples for ROC plot."
-                                )
-                                y_true = y_true[:len(df_pred)]
-                    else:
-                        print("Warning: Original dataset referenced in description.json is unavailable.")
-                except Exception as exc:  # pragma: no cover - defensive
-                    print(f"Warning: Failed to recover labels from dataset: {exc}")
-
-        if y_true is None or len(y_true) == 0:
-            print("Warning: Unable to locate ground-truth labels for ROC plot.")
-            return None
-
-        if len(y_true) != len(y_score):
-            limit = min(len(y_true), len(y_score))
-            if limit == 0:
-                return None
-            print(f"Warning: Aligning prediction and label lengths to {limit} samples for ROC plot.")
-            y_true = y_true[:limit]
-            y_score = y_score[:limit]
+    labels_series = _extract_labels()
 
-        # Get actual class names from probability column names
-        actual_classes = [col.replace('label_probabilities_', '') for col in prob_cols]
-        display_classes = class_labels if len(class_labels) == n_classes else actual_classes
-
-        # Binarize the output following sklearn example
-        # Use actual class names if they're strings, otherwise use range
-        if isinstance(y_true[0], str):
-            y_test = label_binarize(y_true, classes=actual_classes)
-        else:
-            y_test = label_binarize(y_true, classes=list(range(n_classes)))
-
-        # Handle binary classification case
-        if y_test.ndim != 2:
-            y_test = np.atleast_2d(y_test)
+    # Plot 1: Confidence Histogram
+    if confidence_series is not None:
+        fig_conf = go.Figure()
+        fig_conf.add_trace(
+            go.Histogram(
+                x=confidence_series,
+                nbinsx=20,
+                marker=dict(color="#1f77b4", line=dict(color="#ffffff", width=1)),
+                opacity=0.8,
+                histnorm="percent",
+            )
+        )
+        fig_conf.update_layout(
+            title=dict(text="Prediction Confidence Distribution", x=0.5),
+            xaxis_title="Predicted probability (confidence)",
+            yaxis_title="Percentage (%)",
+            bargap=0.05,
+            width=700,
+            height=500,
+        )
+        _style_fig(fig_conf)
+        plots.append({
+            "title": "Prediction Confidence Distribution",
+            "html": pio.to_html(fig_conf, full_html=False, include_plotlyjs=False),
+        })
 
-        if n_classes == 2:
-            if y_test.shape[1] == 1:
-                y_test = np.hstack([1 - y_test, y_test])
-            elif y_test.shape[1] != 2:
-                print("Warning: Unexpected label binarization shape for binary ROC plot.")
-                return None
-        elif y_test.shape[1] != n_classes:
-            print("Warning: Label binarization did not produce expected class dimension; skipping ROC plot.")
-            return None
+    # The remaining plots require true labels and a positive-class probability
+    if labels_series is None or pos_prob_series is None:
+        return plots
+
+    # Align lengths
+    min_len = min(len(labels_series), len(pos_prob_series))
+    if min_len == 0:
+        return plots
+    y_true_raw = labels_series.iloc[:min_len]
+    y_score = np.array(pos_prob_series.iloc[:min_len], dtype=float)
 
-        # Compute ROC curve and ROC area for each class (following sklearn example)
-        fpr = dict()
-        tpr = dict()
-        roc_auc = dict()
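+    # Determine positive label. NOTE(review): pos_label_hint is a str; if the true labels are not strings, the == comparison below matches nothing - confirm label dtype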
+    unique_labels = pd.unique(y_true_raw)
+    unique_labels_list = list(unique_labels)
+    positive_label = None
+    if pos_label_hint and str(pos_label_hint) in [str(u) for u in unique_labels_list]:
+        positive_label = pos_label_hint
+    elif len(unique_labels_list) == 2:
+        positive_label = unique_labels_list[1]
+    else:
+        positive_label = unique_labels_list[0]
 
-        for i in range(n_classes):
-            if np.sum(y_test[:, i]) > 0:  # Check if class exists in test set
-                fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
-                roc_auc[i] = auc(fpr[i], tpr[i])
-
-        # Compute micro-average ROC curve and ROC area (sklearn example)
-        fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())
-        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
-
-        # Create ROC curve plot
-        fig_roc = go.Figure()
+    y_true = (y_true_raw == positive_label).astype(int).values
 
-        # Colors for different classes
-        colors = [
-            '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
-            '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf'
-        ]
-
-        # Plot micro-average ROC curve first (most important)
-        fig_roc.add_trace(go.Scatter(
-            x=fpr["micro"],
-            y=tpr["micro"],
-            mode='lines',
-            name=f'Micro-average ROC (AUC = {roc_auc["micro"]:.3f})',
-            line=dict(color='deeppink', width=3, dash='dot'),
-            hovertemplate=('<b>Micro-average ROC</b><br>'
-                           'FPR: %{x:.3f}<br>'
-                           'TPR: %{y:.3f}<br>'
-                           f'AUC: {roc_auc["micro"]:.3f}<extra></extra>')
-        ))
-
-        # Plot ROC curve for each class
-        for i in range(n_classes):
-            if i in roc_auc:  # Only plot if class exists in test set
-                class_name = display_classes[i] if i < len(display_classes) else f"Class {i}"
-                color = colors[i % len(colors)]
-
-                fig_roc.add_trace(go.Scatter(
-                    x=fpr[i],
-                    y=tpr[i],
-                    mode='lines',
-                    name=f'{class_name} (AUC = {roc_auc[i]:.3f})',
-                    line=dict(color=color, width=2),
-                    hovertemplate=(f'<b>{class_name}</b><br>'
-                                   'FPR: %{x:.3f}<br>'
-                                   'TPR: %{y:.3f}<br>'
-                                   f'AUC: {roc_auc[i]:.3f}<extra></extra>')
-                ))
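+    # Plot 2: Calibration Curve (10 equal-width bins). NOTE(review): with right=True a score of exactly 0.0 falls in bin 0, which the loop below never visits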
+    bins = np.linspace(0.0, 1.0, 11)
+    bin_ids = np.digitize(y_score, bins, right=True)
+    bin_centers = []
+    frac_positives = []
+    for b in range(1, len(bins)):
+        mask = bin_ids == b
+        if not np.any(mask):
+            continue
+        bin_centers.append(y_score[mask].mean())
+        frac_positives.append(y_true[mask].mean())
+    if bin_centers and frac_positives:
+        fig_cal = go.Figure()
+        fig_cal.add_trace(
+            go.Scatter(
+                x=bin_centers,
+                y=frac_positives,
+                mode="lines+markers",
+                name="Calibration",
+                line=dict(color="#2ca02c", width=4),
+            )
+        )
+        fig_cal.add_trace(
+            go.Scatter(
+                x=[0, 1],
+                y=[0, 1],
+                mode="lines",
+                name="Perfect Calibration",
+                line=dict(color="gray", width=2, dash="dash"),
+            )
+        )
+        fig_cal.update_layout(
+            title=dict(text="Calibration Curve", x=0.5),
+            xaxis_title="Predicted probability",
+            yaxis_title="Observed frequency",
+            width=700,
+            height=500,
+        )
+        _style_fig(fig_cal)
+        plots.append({
+            "title": "Calibration Curve (Predicted Probability vs Observed Frequency)",
+            "html": pio.to_html(fig_cal, full_html=False, include_plotlyjs=False),
+        })
 
-        # Add diagonal line (random classifier)
-        fig_roc.add_trace(go.Scatter(
-            x=[0, 1],
-            y=[0, 1],
-            mode='lines',
-            name='Random Classifier',
-            line=dict(color='gray', width=1, dash='dash'),
-            hovertemplate='Random Classifier<br>AUC = 0.500<extra></extra>'
-        ))
-
-        # Calculate macro-average AUC
-        class_aucs = [roc_auc[i] for i in range(n_classes) if i in roc_auc]
-        if class_aucs:
-            macro_auc = np.mean(class_aucs)
-            title_text = f"ROC Curves (Micro-avg = {roc_auc['micro']:.3f}, Macro-avg = {macro_auc:.3f})"
-        else:
-            title_text = f"ROC Curves (Micro-avg = {roc_auc['micro']:.3f})"
+    # Plot 3: Threshold vs Metrics
+    thresholds = np.linspace(0.0, 1.0, 21)
+    accs, f1s, sens, specs = [], [], [], []
+    for t in thresholds:
+        y_pred = (y_score >= t).astype(int)
+        tp = np.sum((y_true == 1) & (y_pred == 1))
+        tn = np.sum((y_true == 0) & (y_pred == 0))
+        fp = np.sum((y_true == 0) & (y_pred == 1))
+        fn = np.sum((y_true == 1) & (y_pred == 0))
+        acc = (tp + tn) / max(len(y_true), 1)
+        prec = tp / max(tp + fp, 1e-9)
+        rec = tp / max(tp + fn, 1e-9)
+        f1 = 0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec)
+        sensitivity = rec
+        specificity = tn / max(tn + fp, 1e-9)
+        accs.append(acc)
+        f1s.append(f1)
+        sens.append(sensitivity)
+        specs.append(specificity)
 
-        fig_roc.update_layout(
-            title=dict(text=title_text, x=0.5),
-            xaxis_title="False Positive Rate",
-            yaxis_title="True Positive Rate",
-            width=700,
-            height=600,
-            margin=dict(t=80, l=80, r=80, b=80),
-            legend=dict(
-                x=0.6,
-                y=0.1,
-                bgcolor="rgba(255,255,255,0.9)",
-                bordercolor="rgba(0,0,0,0.2)",
-                borderwidth=1
-            ),
-            hovermode='closest'
-        )
+    fig_thresh = go.Figure()
+    fig_thresh.add_trace(go.Scatter(x=thresholds, y=accs, mode="lines", name="Accuracy", line=dict(width=4)))
+    fig_thresh.add_trace(go.Scatter(x=thresholds, y=f1s, mode="lines", name="F1", line=dict(width=4)))
+    fig_thresh.add_trace(go.Scatter(x=thresholds, y=sens, mode="lines", name="Sensitivity", line=dict(width=4)))
+    fig_thresh.add_trace(go.Scatter(x=thresholds, y=specs, mode="lines", name="Specificity", line=dict(width=4)))
+    fig_thresh.update_layout(
+        title=dict(text="Threshold Sweep: Accuracy, F1, Sensitivity, Specificity", x=0.5),
+        xaxis_title="Decision threshold",
+        yaxis_title="Metric value",
+        width=700,
+        height=500,
+        legend=dict(
+            x=0.7,
+            y=0.2,
+            bgcolor="rgba(255,255,255,0.9)",
+            bordercolor="rgba(0,0,0,0.2)",
+            borderwidth=1,
+        ),
+        shapes=[
+            dict(
+                type="line",
+                x0=threshold,
+                x1=threshold,
+                y0=0,
+                y1=1,
+                xref="x",
+                yref="paper",
+                line=dict(color="#d62728", width=2, dash="dash"),
+            )
+        ] if isinstance(threshold, (int, float)) else [],
+        annotations=[
+            dict(
+                x=threshold,
+                y=1.02,
+                xref="x",
+                yref="paper",
+                showarrow=False,
+                text=f"Threshold = {threshold:.2f}",
+                font=dict(size=11, color="#d62728"),
+            )
+        ] if isinstance(threshold, (int, float)) else [],
+    )
+    _style_fig(fig_thresh)
+    plots.append({
+        "title": "Threshold Sweep: Accuracy, F1, Sensitivity, Specificity",
+        "html": pio.to_html(fig_thresh, full_html=False, include_plotlyjs=False),
+    })
 
-        # Set equal aspect ratio and proper range
-        fig_roc.update_xaxes(range=[0, 1.0])
-        fig_roc.update_yaxes(range=[0, 1.05])
-
-        return {
-            "title": "ROC-AUC Curves",
-            "html": pio.to_html(
-                fig_roc,
-                full_html=False,
-                include_plotlyjs=False,
-                config=config
-            )
-        }
-
-    except Exception as e:
-        print(f"Error building ROC-AUC plot: {e}")
-        return None
+    return plots