view plotly_plots.py @ 15:d17e3a1b8659 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
author goeckslab
date Fri, 28 Nov 2025 15:45:49 +0000
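
"""Plotly figure builders for Ludwig-based model evaluation reports.

Each public helper reads Ludwig output artifacts (test_statistics.json,
training_statistics.json, predictions.csv) and returns a list of
{"title": ..., "html": ...} dicts whose "html" values are embeddable
Plotly HTML fragments.
"""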

import json
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
from constants import LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME


def _style_fig(fig: go.Figure, font_size: int = 12) -> go.Figure:
    """Apply consistent styling across Plotly figures."""
    fig.update_layout(
        font=dict(size=font_size),
        plot_bgcolor="#ffffff",
        paper_bgcolor="#ffffff",
    )
    fig.update_xaxes(gridcolor="#e8e8e8")
    fig.update_yaxes(gridcolor="#e8e8e8")
    return fig


def _labels_from_metadata_dict(meta_dict: dict) -> List[str]:
    """Extract ordered label names from Ludwig train_set_metadata."""
    if not isinstance(meta_dict, dict):
        return []

    for key in ("idx2str", "idx2label", "vocab"):
        seq = meta_dict.get(key)
        if isinstance(seq, list) and seq:
            return [str(v) for v in seq]

    str2idx = meta_dict.get("str2idx")
    if isinstance(str2idx, dict) and str2idx:
        int_indices = [v for v in str2idx.values() if isinstance(v, int)]
        if int_indices:
            max_idx = max(int_indices)
            ordered = [None] * (max_idx + 1)
            for name, idx in str2idx.items():
                if isinstance(idx, int) and 0 <= idx < len(ordered):
                    ordered[idx] = name
            return [str(v) for v in ordered if v is not None]

    return []


def _resolve_confusion_labels(
    label_stats: dict,
    n_classes: int,
    metadata_csv_path: Optional[str],
    train_set_metadata_path: Optional[str],
) -> List[str]:
    """Prefer original labels from metadata; fall back to stats if unavailable."""
    if train_set_metadata_path:
        try:
            meta_path = Path(train_set_metadata_path)
            if meta_path.exists():
                with open(meta_path, "r") as f:
                    meta_json = json.load(f)
                label_meta = meta_json.get(LABEL_COLUMN_NAME)
                if not isinstance(label_meta, dict):
                    label_meta = next(
                        (
                            v
                            for v in meta_json.values()
                            if isinstance(v, dict)
                            and any(k in v for k in ("idx2str", "str2idx", "idx2label", "vocab"))
                        ),
                        None,
                    )
                labels_from_meta = _labels_from_metadata_dict(label_meta) if label_meta else []
                if labels_from_meta and len(labels_from_meta) >= n_classes:
                    return [str(label) for label in labels_from_meta[:n_classes]]
        except Exception as exc:
            print(f"Warning: Unable to read labels from train_set_metadata: {exc}")

    if metadata_csv_path:
        try:
            csv_path = Path(metadata_csv_path)
            if csv_path.exists():
                df_meta = pd.read_csv(csv_path)
                if LABEL_COLUMN_NAME in df_meta.columns:
                    uniques = df_meta[LABEL_COLUMN_NAME].dropna().unique().tolist()
                    if uniques and len(uniques) >= n_classes:
                        return [str(u) for u in uniques[:n_classes]]
        except Exception as exc:
            print(f"Warning: Unable to read labels from metadata CSV: {exc}")

    pcs = label_stats.get("per_class_stats", {})
    if pcs:
        pcs_labels = [str(k) for k in pcs.keys()]
        if len(pcs_labels) >= n_classes:
            return pcs_labels[:n_classes]

    labels = label_stats.get("labels")
    if not labels:
        labels = [str(i) for i in range(n_classes)]
    if len(labels) < n_classes:
        labels = labels + [str(i) for i in range(len(labels), n_classes)]
    return [str(label) for label in labels[:n_classes]]


def build_classification_plots(
    test_stats_path: str,
    training_stats_path: Optional[str] = None,
    metadata_csv_path: Optional[str] = None,
    train_set_metadata_path: Optional[str] = None,
) -> List[Dict[str, str]]:
    """
    Read Ludwig's test_statistics.json and build interactive Plotly panels:
      - Confusion Matrix
      - ROC Curve
      - Precision-Recall Curve
      - Per-Class Metrics Heatmap

    If metadata paths are provided, the confusion matrix axes will use the original
    label values from the training metadata rather than integer-encoded labels.

    Returns a list of dicts, each with:
      {
        "title": <plot title>,
        "html":  <HTML fragment for embedding>
      }
    """
    # --- Load test stats ---
    with open(test_stats_path, "r") as f:
        test_stats = json.load(f)
    label_stats = test_stats["label"]

    # common sizing
    cell = 40
    n_classes = len(label_stats["confusion_matrix"])
    side_px = max(cell * n_classes + 200, 600)
    common_cfg = {"displayModeBar": True, "scrollZoom": True}

    plots: List[Dict[str, str]] = []

    # 0) Confusion Matrix
    cm = np.array(label_stats["confusion_matrix"], dtype=int)
    labels = _resolve_confusion_labels(
        label_stats,
        n_classes,
        metadata_csv_path=metadata_csv_path,
        train_set_metadata_path=train_set_metadata_path,
    )
    total = cm.sum()

    fig_cm = go.Figure(
        go.Heatmap(
            z=cm,
            x=labels,
            y=labels,
            colorscale="Blues",
            showscale=True,
            colorbar=dict(title="Count"),
        )
    )
    fig_cm.update_traces(xgap=2, ygap=2)
    fig_cm.update_layout(
        title=dict(text="Confusion Matrix", x=0.5),
        xaxis_title="Predicted",
        yaxis_title="Observed",
        yaxis_autorange="reversed",
        width=side_px,
        height=side_px,
        margin=dict(t=100, l=80, r=80, b=80),
    )
    _style_fig(fig_cm)

    # annotate counts and percentages
    mval = cm.max() if cm.size else 0
    thresh = mval / 2
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            v = cm[i, j]
            pct = (v / total * 100) if total > 0 else 0
            color = "white" if v > thresh else "black"
            fig_cm.add_annotation(
                x=labels[j],
                y=labels[i],
                text=f"<b>{v}</b>",
                showarrow=False,
                font=dict(color=color, size=14),
                xanchor="center",
                yanchor="bottom",
                yshift=2,
            )
            fig_cm.add_annotation(
                x=labels[j],
                y=labels[i],
                text=f"{pct:.1f}%",
                showarrow=False,
                font=dict(color=color, size=13),
                xanchor="center",
                yanchor="top",
                yshift=-2,
            )

    plots.append({
        "title": "Confusion Matrix",
        "html": pio.to_html(
            fig_cm,
            full_html=False,
            include_plotlyjs="cdn",
            config=common_cfg
        )
    })

    # 1) ROC Curve (from test_statistics)
    roc_plot = _build_static_roc_plot(label_stats, common_cfg, friendly_labels=labels)
    if roc_plot:
        plots.append(roc_plot)

    # 2) Precision-Recall Curve (from test_statistics)
    pr_plot = _build_precision_recall_plot(label_stats, common_cfg)
    if pr_plot:
        plots.append(pr_plot)

    # 3) Per-Class Metrics Heatmap (classification report)
    pcs = label_stats.get("per_class_stats", {})
    if pcs:
        classes = list(pcs.keys())
        metrics = [
            "precision",
            "recall",
            "f1_score",
            "accuracy",
            "matthews_correlation_coefficient",
            "specificity",
        ]
        z, txt = [], []
        for c in classes:
            row, trow = [], []
            for m in metrics:
                val = pcs[c].get(m, 0)
                row.append(val)
                trow.append(f"{val:.2f}")
            z.append(row)
            txt.append(trow)

        fig_cr = go.Figure(
            go.Heatmap(
                z=z,
                x=[m.replace("_", " ") for m in metrics],
                y=[str(c) for c in classes],
                text=txt,
                texttemplate="%{text}",
                colorscale="Reds",
                showscale=True,
                colorbar=dict(title="Value"),
            )
        )
        fig_cr.update_layout(
            title="Per-Class Metrics",
            xaxis_title="",
            yaxis_title="Class",
            width=side_px,
            height=side_px,
            margin=dict(t=80, l=80, r=80, b=80),
        )
        _style_fig(fig_cr)
        plots.append({
            "title": "Per-Class Metrics",
            "html": pio.to_html(
                fig_cr,
                full_html=False,
                include_plotlyjs=False,
                config=common_cfg
            )
        })

    # 4) Prediction Diagnostics (from predictions.csv)
    # Note: appended separately in generate_html_report, not returned here.

    return plots


def build_train_validation_plots(train_stats_path: str) -> List[Dict[str, str]]:
    """Generate Train/Validation learning curve plots from training_statistics.json."""
    if not train_stats_path or not Path(train_stats_path).exists():
        return []
    try:
        with open(train_stats_path, "r") as f:
            train_stats = json.load(f)
    except Exception as exc:
        print(f"Warning: Unable to read training statistics: {exc}")
        return []

    label_train = (train_stats.get("training") or {}).get("label", {})
    label_val = (train_stats.get("validation") or {}).get("label", {})
    if not label_train and not label_val:
        return []
    plots: List[Dict[str, str]] = []
    include_js = True  # Load Plotly.js once for this group

    def _get_series(stats: dict, metric: str) -> List[float]:
        if metric not in stats:
            return []
        vals = stats.get(metric, [])
        if isinstance(vals, list):
            return [float(v) for v in vals]
        try:
            return [float(vals)]
        except Exception:
            return []

    def _line_plot(metric_key: str, title: str, yaxis_title: str) -> Optional[Dict[str, str]]:
        train_series = _get_series(label_train, metric_key)
        val_series = _get_series(label_val, metric_key)
        if not train_series and not val_series:
            return None
        epochs_train = list(range(1, len(train_series) + 1))
        epochs_val = list(range(1, len(val_series) + 1))
        fig = go.Figure()
        if train_series:
            fig.add_trace(
                go.Scatter(
                    x=epochs_train,
                    y=train_series,
                    mode="lines+markers",
                    name="Train",
                    line=dict(width=4),
                )
            )
        if val_series:
            fig.add_trace(
                go.Scatter(
                    x=epochs_val,
                    y=val_series,
                    mode="lines+markers",
                    name="Validation",
                    line=dict(width=4),
                )
            )
        fig.update_layout(
            title=dict(text=title, x=0.5),
            xaxis_title="Epoch",
            yaxis_title=yaxis_title,
            width=760,
            height=520,
            hovermode="x unified",
        )
        _style_fig(fig)
        return {
            "title": title,
            "html": pio.to_html(
                fig,
                full_html=False,
                include_plotlyjs="cdn" if include_js else False,
            ),
        }

    # Core learning curves
    for key, title in [
        ("roc_auc", "ROC-AUC across epochs"),
        ("precision", "Precision across epochs"),
        ("recall", "Recall/Sensitivity across epochs"),
        ("specificity", "Specificity across epochs"),
    ]:
        plot = _line_plot(key, title, title.replace(" across epochs", "").strip())
        if plot:
            plots.append(plot)
            include_js = False

    # Precision vs Recall evolution (validation)
    val_prec = _get_series(label_val, "precision")
    val_rec = _get_series(label_val, "recall")
    if val_prec and val_rec:
        epochs = list(range(1, min(len(val_prec), len(val_rec)) + 1))
        fig_pr = go.Figure()
        fig_pr.add_trace(
            go.Scatter(
                x=epochs,
                y=val_prec[: len(epochs)],
                mode="lines+markers",
                name="Precision",
            )
        )
        fig_pr.add_trace(
            go.Scatter(
                x=epochs,
                y=val_rec[: len(epochs)],
                mode="lines+markers",
                name="Recall",
            )
        )
        fig_pr.update_layout(
            title=dict(text="Validation Precision and Recall by Epoch", x=0.5),
            xaxis_title="Epoch",
            yaxis_title="Value",
            width=760,
            height=520,
            hovermode="x unified",
        )
        _style_fig(fig_pr)
        plots.append({
            "title": "Precision vs Recall Evolution",
            "html": pio.to_html(
                fig_pr,
                full_html=False,
                include_plotlyjs="cdn" if include_js else False,
            ),
        })
        include_js = False

    # F1-score derived
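    # F1 is reconstructed per epoch from the logged precision/recall series:
    #   F1 = 2 * P * R / (P + R), defined as 0.0 when P + R == 0.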
    def _compute_f1(p: List[float], r: List[float]) -> List[float]:
        f1_vals = []
        for prec, rec in zip(p, r):
            if (prec + rec) == 0:
                f1_vals.append(0.0)
            else:
                f1_vals.append(2 * prec * rec / (prec + rec))
        return f1_vals

    f1_train = _compute_f1(_get_series(label_train, "precision"), _get_series(label_train, "recall"))
    f1_val = _compute_f1(val_prec, val_rec)
    if f1_train or f1_val:
        fig = go.Figure()
        if f1_train:
            fig.add_trace(go.Scatter(x=list(range(1, len(f1_train) + 1)), y=f1_train, mode="lines+markers", name="Train", line=dict(width=4)))
        if f1_val:
            fig.add_trace(go.Scatter(x=list(range(1, len(f1_val) + 1)), y=f1_val, mode="lines+markers", name="Validation", line=dict(width=4)))
        fig.update_layout(
            title=dict(text="F1-Score across epochs (derived)", x=0.5),
            xaxis_title="Epoch",
            yaxis_title="F1-Score",
            width=760,
            height=520,
            hovermode="x unified",
        )
        _style_fig(fig)
        plots.append({
            "title": "F1-Score across epochs (derived)",
            "html": pio.to_html(
                fig,
                full_html=False,
                include_plotlyjs="cdn" if include_js else False,
            ),
        })
        include_js = False

    # Overfitting gap: Train vs Validation ROC-AUC difference
    roc_train = _get_series(label_train, "roc_auc")
    roc_val = _get_series(label_val, "roc_auc")
    if roc_train and roc_val:
        epochs_gap = list(range(1, min(len(roc_train), len(roc_val)) + 1))
        gaps = [t - v for t, v in zip(roc_train[:len(epochs_gap)], roc_val[:len(epochs_gap)])]
        fig_gap = go.Figure()
        fig_gap.add_trace(go.Scatter(x=epochs_gap, y=gaps, mode="lines+markers", name="Train - Val ROC-AUC", line=dict(width=4)))
        fig_gap.update_layout(
            title=dict(text="Overfitting gap: ROC-AUC across epochs", x=0.5),
            xaxis_title="Epoch",
            yaxis_title="Gap",
            width=760,
            height=520,
            hovermode="x unified",
        )
        _style_fig(fig_gap)
        plots.append({
            "title": "Overfitting gap: ROC-AUC across epochs",
            "html": pio.to_html(
                fig_gap,
                full_html=False,
                include_plotlyjs="cdn" if include_js else False,
            ),
        })
        include_js = False

    # Best Epoch Dashboard (based on max val ROC-AUC)
    if roc_val:
        best_idx = int(np.argmax(roc_val))
        best_epoch = best_idx + 1
        spec_val = _get_series(label_val, "specificity")
        metrics_at_best = {
            "ROC-AUC": roc_val[best_idx] if best_idx < len(roc_val) else None,
            "Precision": val_prec[best_idx] if best_idx < len(val_prec) else None,
            "Recall": val_rec[best_idx] if best_idx < len(val_rec) else None,
            "Specificity": spec_val[best_idx] if best_idx < len(spec_val) else None,
            "F1-Score": f1_val[best_idx] if best_idx < len(f1_val) else None,
        }
        fig_best = go.Figure()
        for name, value in metrics_at_best.items():
            if value is not None:
                fig_best.add_trace(go.Bar(name=name, x=[name], y=[value]))
        fig_best.update_layout(
            title=dict(text=f"Best Epoch Dashboard (Val ROC-AUC @ epoch {best_epoch})", x=0.5),
            xaxis_title="Metric",
            yaxis_title="Value",
            width=760,
            height=520,
            showlegend=False,
        )
        _style_fig(fig_best)
        plots.append({
            "title": "Best Validation Epoch Snapshot (Metrics)",
            "html": pio.to_html(
                fig_best,
                full_html=False,
                include_plotlyjs="cdn" if include_js else False,
            ),
        })
        include_js = False

    return plots


def _get_regression_series(split_stats: dict, metric: str) -> List[float]:
    if metric not in split_stats:
        return []
    vals = split_stats.get(metric, [])
    if isinstance(vals, list):
        return [float(v) for v in vals]
    try:
        return [float(vals)]
    except Exception:
        return []


def _regression_line_plot(
    train_split: dict,
    val_split: dict,
    metric_key: str,
    title: str,
    yaxis_title: str,
    include_js: bool,
) -> Optional[Dict[str, str]]:
    train_series = _get_regression_series(train_split, metric_key)
    val_series = _get_regression_series(val_split, metric_key)
    if not train_series and not val_series:
        return None
    epochs_train = list(range(1, len(train_series) + 1))
    epochs_val = list(range(1, len(val_series) + 1))
    fig = go.Figure()
    if train_series:
        fig.add_trace(
            go.Scatter(
                x=epochs_train,
                y=train_series,
                mode="lines+markers",
                name="Train",
                line=dict(width=4),
            )
        )
    if val_series:
        fig.add_trace(
            go.Scatter(
                x=epochs_val,
                y=val_series,
                mode="lines+markers",
                name="Validation",
                line=dict(width=4),
            )
        )
    fig.update_layout(
        title=dict(text=title, x=0.5),
        xaxis_title="Epoch",
        yaxis_title=yaxis_title,
        width=760,
        height=520,
        hovermode="x unified",
    )
    _style_fig(fig)
    return {
        "title": title,
        "html": pio.to_html(
            fig,
            full_html=False,
            include_plotlyjs="cdn" if include_js else False,
        ),
    }


def build_regression_train_val_plots(train_stats_path: str) -> List[Dict[str, str]]:
    """Generate regression Train/Validation learning curve plots from training_statistics.json."""
    if not train_stats_path or not Path(train_stats_path).exists():
        return []
    try:
        with open(train_stats_path, "r") as f:
            train_stats = json.load(f)
    except Exception as exc:
        print(f"Warning: Unable to read training statistics: {exc}")
        return []

    label_train = (train_stats.get("training") or {}).get("label", {})
    label_val = (train_stats.get("validation") or {}).get("label", {})
    if not label_train and not label_val:
        return []

    plots: List[Dict[str, str]] = []
    include_js = True
    for metric_key, title, ytitle in [
        ("mean_absolute_error", "Mean Absolute Error across epochs", "MAE"),
        ("root_mean_squared_error", "Root Mean Squared Error across epochs", "RMSE"),
        ("mean_absolute_percentage_error", "Mean Absolute Percentage Error across epochs", "MAPE"),
        ("r2", "R² across epochs", "R²"),
        ("loss", "Loss across epochs", "Loss"),
    ]:
        plot = _regression_line_plot(label_train, label_val, metric_key, title, ytitle, include_js)
        if plot:
            plots.append(plot)
            include_js = False
    return plots


def build_regression_test_plots(train_stats_path: str) -> List[Dict[str, str]]:
    """Generate regression Test learning curves from training_statistics.json."""
    if not train_stats_path or not Path(train_stats_path).exists():
        return []
    try:
        with open(train_stats_path, "r") as f:
            train_stats = json.load(f)
    except Exception as exc:
        print(f"Warning: Unable to read training statistics: {exc}")
        return []

    label_test = (train_stats.get("test") or {}).get("label", {})
    if not label_test:
        return []

    plots: List[Dict[str, str]] = []
    include_js = True
    metrics = [
        ("mean_absolute_error", "Mean Absolute Error Across Epochs", "MAE"),
        ("root_mean_squared_error", "Root Mean Squared Error Across Epochs", "RMSE"),
        ("mean_absolute_percentage_error", "Mean Absolute Percentage Error Across Epochs", "MAPE"),
        ("r2", "R² Across Epochs", "R²"),
        ("loss", "Loss Across Epochs", "Loss"),
    ]
    epochs = None
    for metric_key, title, ytitle in metrics:
        series = _get_regression_series(label_test, metric_key)
        if not series:
            continue
        if epochs is None:
            epochs = list(range(1, len(series) + 1))
        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=epochs,
                y=series[: len(epochs)],
                mode="lines+markers",
                name="Test",
                line=dict(width=4),
            )
        )
        fig.update_layout(
            title=dict(text=title, x=0.5),
            xaxis_title="Epoch",
            yaxis_title=ytitle,
            width=760,
            height=520,
            hovermode="x unified",
        )
        _style_fig(fig)
        plots.append({
            "title": title,
            "html": pio.to_html(
                fig,
                full_html=False,
                include_plotlyjs="cdn" if include_js else False,
            ),
        })
        include_js = False
    return plots


def _build_static_roc_plot(
    label_stats: dict, config: dict, friendly_labels: Optional[List[str]] = None
) -> Optional[Dict[str, str]]:
    """Build ROC curve directly from test_statistics.json (single curve)."""
    roc_data = label_stats.get("roc_curve")
    if not isinstance(roc_data, dict):
        return None

    fpr = roc_data.get("false_positive_rate")
    tpr = roc_data.get("true_positive_rate")
    if not fpr or not tpr or len(fpr) != len(tpr):
        return None

    try:
        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=fpr,
                y=tpr,
                mode="lines+markers",
                name="ROC Curve",
                line=dict(color="#1f77b4", width=4),
                hovertemplate="FPR: %{x:.3f}<br>TPR: %{y:.3f}<extra></extra>",
            )
        )
        fig.add_trace(
            go.Scatter(
                x=[0, 1],
                y=[0, 1],
                mode="lines",
                name="Random Classifier",
                line=dict(color="gray", width=2, dash="dash"),
                hovertemplate="Random Classifier<extra></extra>",
            )
        )

        auc_val = label_stats.get("roc_auc") or label_stats.get("roc_auc_macro") or label_stats.get("roc_auc_micro")
        auc_txt = f" (AUC = {auc_val:.3f})" if isinstance(auc_val, (int, float)) else ""

        # Determine which label is treated as positive for the curve
        label_list: List = []
        pcs = label_stats.get("per_class_stats", {})
        if pcs:
            label_list = list(pcs.keys())
        if not label_list:
            labels_from_stats = label_stats.get("labels")
            if isinstance(labels_from_stats, list):
                label_list = labels_from_stats

        # Try to resolve index of the positive label explicitly provided by Ludwig
        pos_label_raw = (
            roc_data.get("positive_label")
            or roc_data.get("positive_class")
            or label_stats.get("positive_label")
        )
        pos_label_idx = None
        if pos_label_raw is not None and isinstance(label_list, list):
            try:
                pos_label_idx = label_list.index(pos_label_raw)
            except ValueError:
                pos_label_idx = None

        # Fallback: use the second label if available, otherwise the first
        if pos_label_idx is None:
            if isinstance(label_list, list) and len(label_list) >= 2:
                pos_label_idx = 1
            elif isinstance(label_list, list) and label_list:
                pos_label_idx = 0

        if pos_label_raw is None and isinstance(label_list, list) and pos_label_idx is not None:
            pos_label_raw = label_list[pos_label_idx]

        # Map to friendly label if we have one from metadata/CSV
        pos_label_display = pos_label_raw
        if (
            friendly_labels
            and isinstance(pos_label_idx, int)
            and 0 <= pos_label_idx < len(friendly_labels)
        ):
            pos_label_display = friendly_labels[pos_label_idx]

        pos_label_txt = (
            f"Positive class: {pos_label_display}"
            if pos_label_display is not None
            else "Positive class: (not available)"
        )

        title_label = f"ROC Curve{auc_txt}"
        if pos_label_display is not None:
            title_label = f"ROC Curve (Positive Class: {pos_label_display}){auc_txt}"

        fig.update_layout(
            title=dict(text=title_label, x=0.5),
            xaxis_title="False Positive Rate",
            yaxis_title="True Positive Rate",
            width=700,
            height=600,
            margin=dict(t=80, l=80, r=80, b=110),
            hovermode="closest",
            legend=dict(
                x=0.6,
                y=0.1,
                bgcolor="rgba(255,255,255,0.9)",
                bordercolor="rgba(0,0,0,0.2)",
                borderwidth=1,
            ),
        )
        _style_fig(fig)
        fig.update_xaxes(range=[0, 1.0])
        fig.update_yaxes(range=[0, 1.05])

        fig.add_annotation(
            x=0.5,
            y=-0.15,
            xref="paper",
            yref="paper",
            showarrow=False,
            text=f"<span style='font-size:12px;color:#555;'>{pos_label_txt}</span>",
            xanchor="center",
        )

        return {
            "title": "ROC Curve",
            "html": pio.to_html(
                fig,
                full_html=False,
                include_plotlyjs=False,
                config=config,
            ),
        }
    except Exception as e:
        print(f"Error building ROC plot: {e}")
        return None


def _build_precision_recall_plot(label_stats: dict, config: dict) -> Optional[Dict[str, str]]:
    """Build Precision-Recall curve directly from test_statistics.json."""
    pr_data = label_stats.get("precision_recall_curve")
    if not isinstance(pr_data, dict):
        return None

    precisions = pr_data.get("precisions")
    recalls = pr_data.get("recalls")
    if not precisions or not recalls or len(precisions) != len(recalls):
        return None

    try:
        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=recalls,
                y=precisions,
                mode="lines+markers",
                name="Precision-Recall",
                line=dict(color="#d62728", width=4),
                hovertemplate="Recall: %{x:.3f}<br>Precision: %{y:.3f}<extra></extra>",
            )
        )

        ap_val = (
            label_stats.get("average_precision_macro")
            or label_stats.get("average_precision_micro")
            or label_stats.get("average_precision_samples")
        )
        ap_txt = f" (AP = {ap_val:.3f})" if isinstance(ap_val, (int, float)) else ""

        fig.update_layout(
            title=dict(text=f"Precision-Recall Curve{ap_txt}", x=0.5),
            xaxis_title="Recall",
            yaxis_title="Precision",
            width=700,
            height=600,
            margin=dict(t=80, l=80, r=80, b=80),
            hovermode="closest",
            legend=dict(
                x=0.6,
                y=0.1,
                bgcolor="rgba(255,255,255,0.9)",
                bordercolor="rgba(0,0,0,0.2)",
                borderwidth=1,
            ),
        )
        _style_fig(fig)
        fig.update_xaxes(range=[0, 1.0])
        fig.update_yaxes(range=[0, 1.05])

        return {
            "title": "Precision-Recall Curve",
            "html": pio.to_html(
                fig,
                full_html=False,
                include_plotlyjs=False,
                config=config,
            ),
        }
    except Exception as e:
        print(f"Error building Precision-Recall plot: {e}")
        return None


def build_prediction_diagnostics(
    predictions_path: str,
    label_data_path: Optional[str] = None,
    split_value: int = 2,
    threshold: Optional[float] = None,
) -> List[Dict[str, str]]:
    """Generate diagnostic plots from predictions.csv for classification tasks."""
    preds_file = Path(predictions_path)
    if not preds_file.exists():
        return []

    try:
        df_pred = pd.read_csv(predictions_path)
    except Exception as exc:
        print(f"Warning: Unable to read predictions CSV: {exc}")
        return []

    plots: List[Dict[str, str]] = []

    # Identify probability columns
    prob_cols = [
        c for c in df_pred.columns
        if c.startswith("label_probabilities_") and c != "label_probabilities"
    ]
    prob_cols_sorted = sorted(prob_cols)

    def _select_positive_prob():
        if not prob_cols_sorted:
            return None, None
        # Prefer a column indicating positive/event/true/1
        preferred_keys = ("event", "true", "positive", "pos", "1")
        for col in prob_cols_sorted:
            suffix = col.replace("label_probabilities_", "").lower()
            if any(k in suffix for k in preferred_keys):
                return col, suffix
        if len(prob_cols_sorted) == 2:
            col = prob_cols_sorted[1]
            return col, col.replace("label_probabilities_", "")
        col = prob_cols_sorted[0]
        return col, col.replace("label_probabilities_", "")

    pos_prob_col, pos_label_hint = _select_positive_prob()
    pos_prob_series = df_pred[pos_prob_col] if pos_prob_col and pos_prob_col in df_pred else None

    # Confidence series: prefer label_probability, otherwise positive prob, otherwise max prob
    confidence_series = None
    if "label_probability" in df_pred.columns:
        confidence_series = df_pred["label_probability"]
    elif pos_prob_series is not None:
        confidence_series = pos_prob_series
    elif prob_cols_sorted:
        confidence_series = df_pred[prob_cols_sorted].max(axis=1)

    # True labels
    def _extract_labels():
        candidates = [
            LABEL_COLUMN_NAME,
            f"{LABEL_COLUMN_NAME}_ground_truth",
            f"{LABEL_COLUMN_NAME}__ground_truth",
            f"{LABEL_COLUMN_NAME}_target",
            f"{LABEL_COLUMN_NAME}__target",
            "label",
            "label_true",
        ]
        candidates.extend(
            [
                col
                for col in df_pred.columns
                if (col.startswith(f"{LABEL_COLUMN_NAME}_") or col.startswith(f"{LABEL_COLUMN_NAME}__"))
                and "probabilities" not in col
                and "predictions" not in col
            ]
        )
        for col in candidates:
            if col in df_pred.columns and col not in prob_cols_sorted:
                return df_pred[col]
        if label_data_path and Path(label_data_path).exists():
            try:
                df_all = pd.read_csv(label_data_path)
                if SPLIT_COLUMN_NAME in df_all.columns:
                    df_all = df_all[df_all[SPLIT_COLUMN_NAME] == split_value].reset_index(drop=True)
                if LABEL_COLUMN_NAME in df_all.columns:
                    return df_all[LABEL_COLUMN_NAME].reset_index(drop=True)
            except Exception as exc:
                print(f"Warning: Unable to load labels from dataset: {exc}")
        return None

    labels_series = _extract_labels()

    # Plot 1: Confidence Histogram
    if confidence_series is not None:
        fig_conf = go.Figure()
        fig_conf.add_trace(
            go.Histogram(
                x=confidence_series,
                nbinsx=20,
                marker=dict(color="#1f77b4", line=dict(color="#ffffff", width=1)),
                opacity=0.8,
                histnorm="percent",
            )
        )
        fig_conf.update_layout(
            title=dict(text="Prediction Confidence Distribution", x=0.5),
            xaxis_title="Predicted probability (confidence)",
            yaxis_title="Percentage (%)",
            bargap=0.05,
            width=700,
            height=500,
        )
        _style_fig(fig_conf)
        plots.append({
            "title": "Prediction Confidence Distribution",
            "html": pio.to_html(fig_conf, full_html=False, include_plotlyjs=False),
        })

    # The remaining plots require true labels and a positive-class probability
    if labels_series is None or pos_prob_series is None:
        return plots

    # Align lengths
    min_len = min(len(labels_series), len(pos_prob_series))
    if min_len == 0:
        return plots
    y_true_raw = labels_series.iloc[:min_len]
    y_score = np.array(pos_prob_series.iloc[:min_len], dtype=float)

    # Determine the positive label; map the column-name hint back to the actual
    # label value so integer-typed labels still compare correctly below.
    unique_labels_list = list(pd.unique(y_true_raw))
    positive_label = None
    if pos_label_hint is not None:
        hint_str = str(pos_label_hint)
        positive_label = next(
            (u for u in unique_labels_list if str(u) == hint_str), None
        )
    if positive_label is None:
        positive_label = (
            unique_labels_list[1] if len(unique_labels_list) == 2 else unique_labels_list[0]
        )

    y_true = (y_true_raw == positive_label).astype(int).values

    # Plot 2: Calibration Curve
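    # Bin predicted probabilities into 10 equal-width bins and compare the mean
    # predicted probability in each bin with the observed fraction of positives.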
    bins = np.linspace(0.0, 1.0, 11)
    bin_ids = np.digitize(y_score, bins, right=True)
    bin_centers = []
    frac_positives = []
    for b in range(1, len(bins)):
        mask = bin_ids == b
        if not np.any(mask):
            continue
        bin_centers.append(y_score[mask].mean())
        frac_positives.append(y_true[mask].mean())
    if bin_centers and frac_positives:
        fig_cal = go.Figure()
        fig_cal.add_trace(
            go.Scatter(
                x=bin_centers,
                y=frac_positives,
                mode="lines+markers",
                name="Calibration",
                line=dict(color="#2ca02c", width=4),
            )
        )
        fig_cal.add_trace(
            go.Scatter(
                x=[0, 1],
                y=[0, 1],
                mode="lines",
                name="Perfect Calibration",
                line=dict(color="gray", width=2, dash="dash"),
            )
        )
        fig_cal.update_layout(
            title=dict(text="Calibration Curve", x=0.5),
            xaxis_title="Predicted probability",
            yaxis_title="Observed frequency",
            width=700,
            height=500,
        )
        _style_fig(fig_cal)
        plots.append({
            "title": "Calibration Curve (Predicted Probability vs Observed Frequency)",
            "html": pio.to_html(fig_cal, full_html=False, include_plotlyjs=False),
        })

    # Plot 3: Threshold vs Metrics
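    # Sweep 21 decision thresholds from 0.0 to 1.0, derive confusion counts at
    # each threshold, and plot accuracy, F1, sensitivity, and specificity
    # (small denominators are guarded with 1e-9 to avoid division by zero).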
    thresholds = np.linspace(0.0, 1.0, 21)
    accs, f1s, sens, specs = [], [], [], []
    for t in thresholds:
        y_pred = (y_score >= t).astype(int)
        tp = np.sum((y_true == 1) & (y_pred == 1))
        tn = np.sum((y_true == 0) & (y_pred == 0))
        fp = np.sum((y_true == 0) & (y_pred == 1))
        fn = np.sum((y_true == 1) & (y_pred == 0))
        acc = (tp + tn) / max(len(y_true), 1)
        prec = tp / max(tp + fp, 1e-9)
        rec = tp / max(tp + fn, 1e-9)
        f1 = 0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec)
        sensitivity = rec
        specificity = tn / max(tn + fp, 1e-9)
        accs.append(acc)
        f1s.append(f1)
        sens.append(sensitivity)
        specs.append(specificity)

    fig_thresh = go.Figure()
    fig_thresh.add_trace(go.Scatter(x=thresholds, y=accs, mode="lines", name="Accuracy", line=dict(width=4)))
    fig_thresh.add_trace(go.Scatter(x=thresholds, y=f1s, mode="lines", name="F1", line=dict(width=4)))
    fig_thresh.add_trace(go.Scatter(x=thresholds, y=sens, mode="lines", name="Sensitivity", line=dict(width=4)))
    fig_thresh.add_trace(go.Scatter(x=thresholds, y=specs, mode="lines", name="Specificity", line=dict(width=4)))
    fig_thresh.update_layout(
        title=dict(text="Threshold Sweep: Accuracy, F1, Sensitivity, Specificity", x=0.5),
        xaxis_title="Decision threshold",
        yaxis_title="Metric value",
        width=700,
        height=500,
        legend=dict(
            x=0.7,
            y=0.2,
            bgcolor="rgba(255,255,255,0.9)",
            bordercolor="rgba(0,0,0,0.2)",
            borderwidth=1,
        ),
        shapes=[
            dict(
                type="line",
                x0=threshold,
                x1=threshold,
                y0=0,
                y1=1,
                xref="x",
                yref="paper",
                line=dict(color="#d62728", width=2, dash="dash"),
            )
        ] if isinstance(threshold, (int, float)) else [],
        annotations=[
            dict(
                x=threshold,
                y=1.02,
                xref="x",
                yref="paper",
                showarrow=False,
                text=f"Threshold = {threshold:.2f}",
                font=dict(size=11, color="#d62728"),
            )
        ] if isinstance(threshold, (int, float)) else [],
    )
    _style_fig(fig_thresh)
    plots.append({
        "title": "Threshold Sweep: Accuracy, F1, Sensitivity, Specificity",
        "html": pio.to_html(fig_thresh, full_html=False, include_plotlyjs=False),
    })

    return plots
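
# Minimal manual-preview sketch (a local convenience, not part of the Galaxy
# tool flow); the paths below are hypothetical and should point at an actual
# Ludwig results directory.
if __name__ == "__main__":
    preview_plots = build_classification_plots("results/test_statistics.json")
    preview_plots.extend(build_train_validation_plots("results/training_statistics.json"))
    body = "\n".join(p["html"] for p in preview_plots)
    Path("plots_preview.html").write_text(
        "<html><head><meta charset='utf-8'></head><body>" + body + "</body></html>",
        encoding="utf-8",
    )
    print(f"Wrote {len(preview_plots)} plot fragments to plots_preview.html")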