diff plotly_plots.py @ 11:c5150cceab47 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
author goeckslab
date Sat, 18 Oct 2025 03:17:09 +0000
parents 85e6f4b2ad18
children
--- a/plotly_plots.py	Mon Sep 08 22:38:35 2025 +0000
+++ b/plotly_plots.py	Sat Oct 18 03:17:09 2025 +0000
@@ -1,9 +1,14 @@
 import json
+from pathlib import Path
 from typing import Dict, List, Optional
 
 import numpy as np
+import pandas as pd
 import plotly.graph_objects as go
 import plotly.io as pio
+from constants import LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME
+from sklearn.metrics import auc, roc_curve
+from sklearn.preprocessing import label_binarize
 
 
 def build_classification_plots(
@@ -37,7 +42,12 @@
 
     # 0) Confusion Matrix
     cm = np.array(label_stats["confusion_matrix"], dtype=int)
-    labels = label_stats.get("labels", [str(i) for i in range(cm.shape[0])])
+    # Try to get actual class names from per_class_stats keys (which contain the real labels)
+    pcs = label_stats.get("per_class_stats", {})
+    if pcs:
+        labels = list(pcs.keys())
+    else:
+        labels = label_stats.get("labels", [str(i) for i in range(cm.shape[0])])
     total = cm.sum()
 
     fig_cm = go.Figure(
@@ -100,6 +110,11 @@
         )
     })
 
+    # 1) ROC-AUC Curves (Multi-class)
+    roc_plot = _build_roc_auc_plot(test_stats_path, labels, common_cfg)
+    if roc_plot:
+        plots.append(roc_plot)
+
     # 2) Classification Report Heatmap
     pcs = label_stats.get("per_class_stats", {})
     if pcs:
@@ -146,3 +161,243 @@
         })
 
     return plots
+
+
+def _build_roc_auc_plot(
+    test_stats_path: str,
+    class_labels: List[str],
+    config: dict,
+) -> Optional[Dict[str, str]]:
+    """
+    Build an interactive ROC-AUC curve plot for multi-class classification,
+    following sklearn's ROC example with micro-average and per-class curves.
+
+    Args:
+        test_stats_path: Path to test_statistics.json
+        class_labels: List of class label names
+        config: Plotly config dict
+
+    Returns:
+        Dict with title and HTML, or None if data unavailable
+    """
+    try:
+        # Get the experiment directory from test_stats_path
+        exp_dir = Path(test_stats_path).parent
+
+        # Load predictions with probabilities
+        predictions_path = exp_dir / "predictions.csv"
+        if not predictions_path.exists():
+            return None
+
+        df_pred = pd.read_csv(predictions_path)
+
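+        # Keep only test-split rows when a split column is present
+        # (values 2/"test"/"testing" are treated as the test partition here).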
+        if SPLIT_COLUMN_NAME in df_pred.columns:
+            split_series = df_pred[SPLIT_COLUMN_NAME].astype(str).str.lower()
+            test_mask = split_series.isin({"2", "test", "testing"})
+            if test_mask.any():
+                df_pred = df_pred[test_mask].reset_index(drop=True)
+
+        if df_pred.empty:
+            return None
+
+        # Extract probability columns (label_probabilities_0, label_probabilities_1, etc.)
+        # or label_probabilities_<class_name> for string labels
+        prob_cols = [
+            col for col in df_pred.columns
+            if col.startswith('label_probabilities_') and col != 'label_probabilities'
+        ]
+
+        # Sort numerically by class index when the column suffixes are digits, otherwise alphabetically
+        if prob_cols and prob_cols[0].split('_')[-1].isdigit():
+            prob_cols.sort(key=lambda x: int(x.split('_')[-1]))
+        else:
+            prob_cols.sort()  # Alphabetical sort for string class names
+
+        if not prob_cols:
+            return None
+
+        # Get probabilities matrix (n_samples x n_classes)
+        y_score = df_pred[prob_cols].values
+        n_classes = len(prob_cols)
+
+        y_true = None
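+        # Search for a ground-truth column among common label-column naming patterns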
+        candidate_cols = [
+            LABEL_COLUMN_NAME,
+            f"{LABEL_COLUMN_NAME}_ground_truth",
+            f"{LABEL_COLUMN_NAME}__ground_truth",
+            f"{LABEL_COLUMN_NAME}_target",
+            f"{LABEL_COLUMN_NAME}__target",
+        ]
+        candidate_cols.extend(
+            [
+                col
+                for col in df_pred.columns
+                if (col.startswith(f"{LABEL_COLUMN_NAME}_") or col.startswith(f"{LABEL_COLUMN_NAME}__"))
+                and "probabilities" not in col
+                and "predictions" not in col
+            ]
+        )
+        for col in candidate_cols:
+            if col in df_pred.columns and col not in prob_cols:
+                y_true = df_pred[col].values
+                break
+
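+        # Fall back to recovering labels from the original dataset referenced in description.json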
+        if y_true is None:
+            desc_path = exp_dir / "description.json"
+            if desc_path.exists():
+                try:
+                    with open(desc_path, 'r') as f:
+                        desc = json.load(f)
+                    dataset_path = desc.get('dataset', '')
+                    if dataset_path and Path(dataset_path).exists():
+                        df_orig = pd.read_csv(dataset_path)
+                        if SPLIT_COLUMN_NAME in df_orig.columns:
+                            split_series = df_orig[SPLIT_COLUMN_NAME].astype(str).str.lower()
+                            df_orig = df_orig[split_series.isin({"2", "test", "testing"})].reset_index(drop=True)
+                        if LABEL_COLUMN_NAME in df_orig.columns:
+                            y_true = df_orig[LABEL_COLUMN_NAME].values
+                            if len(y_true) != len(df_pred):
+                                print(
+                                    f"Warning: Test set size mismatch. Truncating to {len(df_pred)} samples for ROC plot."
+                                )
+                                y_true = y_true[:len(df_pred)]
+                    else:
+                        print("Warning: Original dataset referenced in description.json is unavailable.")
+                except Exception as exc:  # pragma: no cover - defensive
+                    print(f"Warning: Failed to recover labels from dataset: {exc}")
+
+        if y_true is None or len(y_true) == 0:
+            print("Warning: Unable to locate ground-truth labels for ROC plot.")
+            return None
+
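+        # Align lengths defensively in case predictions and recovered labels drifted apart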
+        if len(y_true) != len(y_score):
+            limit = min(len(y_true), len(y_score))
+            if limit == 0:
+                return None
+            print(f"Warning: Aligning prediction and label lengths to {limit} samples for ROC plot.")
+            y_true = y_true[:limit]
+            y_score = y_score[:limit]
+
+        # Get actual class names from probability column names
+        actual_classes = [col.replace('label_probabilities_', '') for col in prob_cols]
+        display_classes = class_labels if len(class_labels) == n_classes else actual_classes
+
+        # Binarize the output following sklearn example
+        # Use actual class names if they're strings, otherwise use range
+        if isinstance(y_true[0], str):
+            y_test = label_binarize(y_true, classes=actual_classes)
+        else:
+            y_test = label_binarize(y_true, classes=list(range(n_classes)))
+
+        # Handle binary classification case
+        if y_test.ndim != 2:
+            y_test = np.atleast_2d(y_test)
+
+        if n_classes == 2:
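+            # label_binarize returns a single column for binary problems; add the complement column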
+            if y_test.shape[1] == 1:
+                y_test = np.hstack([1 - y_test, y_test])
+            elif y_test.shape[1] != 2:
+                print("Warning: Unexpected label binarization shape for binary ROC plot.")
+                return None
+        elif y_test.shape[1] != n_classes:
+            print("Warning: Label binarization did not produce expected class dimension; skipping ROC plot.")
+            return None
+
+        # Compute ROC curve and ROC area for each class (following sklearn example)
+        fpr = dict()
+        tpr = dict()
+        roc_auc = dict()
+
+        for i in range(n_classes):
+            if np.sum(y_test[:, i]) > 0:  # Check if class exists in test set
+                fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
+                roc_auc[i] = auc(fpr[i], tpr[i])
+
+        # Compute micro-average ROC curve and ROC area (sklearn example)
+        fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())
+        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
+
+        # Create ROC curve plot
+        fig_roc = go.Figure()
+
+        # Colors for different classes
+        colors = [
+            '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
+            '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf'
+        ]
+
+        # Plot micro-average ROC curve first (most important)
+        fig_roc.add_trace(go.Scatter(
+            x=fpr["micro"],
+            y=tpr["micro"],
+            mode='lines',
+            name=f'Micro-average ROC (AUC = {roc_auc["micro"]:.3f})',
+            line=dict(color='deeppink', width=3, dash='dot'),
+            hovertemplate=('<b>Micro-average ROC</b><br>'
+                           'FPR: %{x:.3f}<br>'
+                           'TPR: %{y:.3f}<br>'
+                           f'AUC: {roc_auc["micro"]:.3f}<extra></extra>')
+        ))
+
+        # Plot ROC curve for each class
+        for i in range(n_classes):
+            if i in roc_auc:  # Only plot if class exists in test set
+                class_name = display_classes[i] if i < len(display_classes) else f"Class {i}"
+                color = colors[i % len(colors)]
+
+                fig_roc.add_trace(go.Scatter(
+                    x=fpr[i],
+                    y=tpr[i],
+                    mode='lines',
+                    name=f'{class_name} (AUC = {roc_auc[i]:.3f})',
+                    line=dict(color=color, width=2),
+                    hovertemplate=(f'<b>{class_name}</b><br>'
+                                   'FPR: %{x:.3f}<br>'
+                                   'TPR: %{y:.3f}<br>'
+                                   f'AUC: {roc_auc[i]:.3f}<extra></extra>')
+                ))
+
+        # Add diagonal line (random classifier)
+        fig_roc.add_trace(go.Scatter(
+            x=[0, 1],
+            y=[0, 1],
+            mode='lines',
+            name='Random Classifier',
+            line=dict(color='gray', width=1, dash='dash'),
+            hovertemplate='Random Classifier<br>AUC = 0.500<extra></extra>'
+        ))
+
+        # Calculate macro-average AUC
+        class_aucs = [roc_auc[i] for i in range(n_classes) if i in roc_auc]
+        if class_aucs:
+            macro_auc = np.mean(class_aucs)
+            title_text = f"ROC Curves (Micro-avg = {roc_auc['micro']:.3f}, Macro-avg = {macro_auc:.3f})"
+        else:
+            title_text = f"ROC Curves (Micro-avg = {roc_auc['micro']:.3f})"
+
+        fig_roc.update_layout(
+            title=dict(text=title_text, x=0.5),
+            xaxis_title="False Positive Rate",
+            yaxis_title="True Positive Rate",
+            width=700,
+            height=600,
+            margin=dict(t=80, l=80, r=80, b=80),
+            legend=dict(
+                x=0.6,
+                y=0.1,
+                bgcolor="rgba(255,255,255,0.9)",
+                bordercolor="rgba(0,0,0,0.2)",
+                borderwidth=1
+            ),
+            hovermode='closest'
+        )
+
+        # Constrain axes to the standard ROC ranges
+        fig_roc.update_xaxes(range=[0, 1.0])
+        fig_roc.update_yaxes(range=[0, 1.05])
+
+        return {
+            "title": "ROC-AUC Curves",
+            "html": pio.to_html(
+                fig_roc,
+                full_html=False,
+                include_plotlyjs=False,
+                config=config
+            )
+        }
+
+    except Exception as e:
+        print(f"Error building ROC-AUC plot: {e}")
+        return None