diff plotly_plots.py @ 8:85e6f4b2ad18 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit 8a42eb9b33df7e1df5ad5153b380e20b910a05b6
author goeckslab
date Thu, 14 Aug 2025 14:53:10 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/plotly_plots.py	Thu Aug 14 14:53:10 2025 +0000
@@ -0,0 +1,148 @@
+import json
+from typing import Dict, List, Optional
+
+import numpy as np
+import plotly.graph_objects as go
+import plotly.io as pio
+
+
+def build_classification_plots(
+    test_stats_path: str,
+    training_stats_path: Optional[str] = None,
+) -> List[Dict[str, str]]:
+    """
+    Read Ludwig’s test_statistics.json and build three interactive Plotly panels:
+      - Confusion Matrix
+      - ROC-AUC
+      - Classification Report Heatmap
+
+    Returns a list of dicts, each with:
+      {
+        "title": <plot title>,
+        "html":  <HTML fragment for embedding>
+      }
+    """
+    # --- Load test stats ---
+    with open(test_stats_path, "r") as f:
+        test_stats = json.load(f)
+    label_stats = test_stats["label"]
+
+    # common sizing
+    cell = 40
+    n_classes = len(label_stats["confusion_matrix"])
+    side_px = max(cell * n_classes + 200, 600)
+    common_cfg = {"displayModeBar": True, "scrollZoom": True}
+
+    plots: List[Dict[str, str]] = []
+
+    # 0) Confusion Matrix
+    cm = np.array(label_stats["confusion_matrix"], dtype=int)
+    labels = label_stats.get("labels", [str(i) for i in range(cm.shape[0])])
+    total = cm.sum()
+
+    fig_cm = go.Figure(
+        go.Heatmap(
+            z=cm,
+            x=labels,
+            y=labels,
+            colorscale="Blues",
+            showscale=True,
+            colorbar=dict(title="Count"),
+        )
+    )
+    fig_cm.update_traces(xgap=2, ygap=2)
+    fig_cm.update_layout(
+        title=dict(text="Confusion Matrix", x=0.5),
+        xaxis_title="Predicted",
+        yaxis_title="Observed",
+        yaxis_autorange="reversed",
+        width=side_px,
+        height=side_px,
+        margin=dict(t=100, l=80, r=80, b=80),
+    )
+
+    # annotate counts and percentages
+    mval = cm.max() if cm.size else 0
+    thresh = mval / 2
+    for i in range(cm.shape[0]):
+        for j in range(cm.shape[1]):
+            v = cm[i, j]
+            pct = (v / total * 100) if total > 0 else 0
+            color = "white" if v > thresh else "black"
+            fig_cm.add_annotation(
+                x=labels[j],
+                y=labels[i],
+                text=f"<b>{v}</b>",
+                showarrow=False,
+                font=dict(color=color, size=14),
+                xanchor="center",
+                yanchor="bottom",
+                yshift=2,
+            )
+            fig_cm.add_annotation(
+                x=labels[j],
+                y=labels[i],
+                text=f"{pct:.1f}%",
+                showarrow=False,
+                font=dict(color=color, size=13),
+                xanchor="center",
+                yanchor="top",
+                yshift=-2,
+            )
+
+    plots.append({
+        "title": "Confusion Matrix",
+        "html": pio.to_html(
+            fig_cm,
+            full_html=False,
+            include_plotlyjs="cdn",
+            config=common_cfg
+        )
+    })
+
+    # 2) Classification Report Heatmap
+    pcs = label_stats.get("per_class_stats", {})
+    if pcs:
+        classes = list(pcs.keys())
+        metrics = ["precision", "recall", "f1_score"]
+        z, txt = [], []
+        for c in classes:
+            row, trow = [], []
+            for m in metrics:
+                val = pcs[c].get(m, 0)
+                row.append(val)
+                trow.append(f"{val:.2f}")
+            z.append(row)
+            txt.append(trow)
+
+        fig_cr = go.Figure(
+            go.Heatmap(
+                z=z,
+                x=metrics,
+                y=[str(c) for c in classes],
+                text=txt,
+                texttemplate="%{text}",
+                colorscale="Reds",
+                showscale=True,
+                colorbar=dict(title="Value"),
+            )
+        )
+        fig_cr.update_layout(
+            title="Classification Report",
+            xaxis_title="",
+            yaxis_title="Class",
+            width=side_px,
+            height=side_px,
+            margin=dict(t=80, l=80, r=80, b=80),
+        )
+        plots.append({
+            "title": "Classification Report",
+            "html": pio.to_html(
+                fig_cr,
+                full_html=False,
+                include_plotlyjs=False,
+                config=common_cfg
+            )
+        })
+
+    return plots