diff plotly_plots.py @ 11:c5150cceab47 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
| author | goeckslab |
|---|---|
| date | Sat, 18 Oct 2025 03:17:09 +0000 |
| parents | 85e6f4b2ad18 |
| children | |
--- a/plotly_plots.py	Mon Sep 08 22:38:35 2025 +0000
+++ b/plotly_plots.py	Sat Oct 18 03:17:09 2025 +0000
@@ -1,9 +1,14 @@
 import json
+from pathlib import Path
 from typing import Dict, List, Optional
 
 import numpy as np
+import pandas as pd
 import plotly.graph_objects as go
 import plotly.io as pio
+from constants import LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME
+from sklearn.metrics import auc, roc_curve
+from sklearn.preprocessing import label_binarize
 
 
 def build_classification_plots(
@@ -37,7 +42,12 @@
 
     # 0) Confusion Matrix
     cm = np.array(label_stats["confusion_matrix"], dtype=int)
-    labels = label_stats.get("labels", [str(i) for i in range(cm.shape[0])])
+    # Try to get actual class names from per_class_stats keys (which contain the real labels)
+    pcs = label_stats.get("per_class_stats", {})
+    if pcs:
+        labels = list(pcs.keys())
+    else:
+        labels = label_stats.get("labels", [str(i) for i in range(cm.shape[0])])
     total = cm.sum()
 
     fig_cm = go.Figure(
@@ -100,6 +110,11 @@
         )
     })
 
+    # 1) ROC-AUC Curves (Multi-class)
+    roc_plot = _build_roc_auc_plot(test_stats_path, labels, common_cfg)
+    if roc_plot:
+        plots.append(roc_plot)
+
     # 2) Classification Report Heatmap
     pcs = label_stats.get("per_class_stats", {})
     if pcs:
@@ -146,3 +161,243 @@
     })
 
     return plots
+
+
+def _build_roc_auc_plot(test_stats_path: str, class_labels: List[str], config: dict) -> Optional[Dict[str, str]]:
+    """
+    Build an interactive ROC-AUC curve plot for multi-class classification.
+    Following sklearn's ROC example with micro-average and per-class curves.
+
+    Args:
+        test_stats_path: Path to test_statistics.json
+        class_labels: List of class label names
+        config: Plotly config dict
+
+    Returns:
+        Dict with title and HTML, or None if data unavailable
+    """
+    try:
+        # Get the experiment directory from test_stats_path
+        exp_dir = Path(test_stats_path).parent
+
+        # Load predictions with probabilities
+        predictions_path = exp_dir / "predictions.csv"
+        if not predictions_path.exists():
+            return None
+
+        df_pred = pd.read_csv(predictions_path)
+
+        if SPLIT_COLUMN_NAME in df_pred.columns:
+            split_series = df_pred[SPLIT_COLUMN_NAME].astype(str).str.lower()
+            test_mask = split_series.isin({"2", "test", "testing"})
+            if test_mask.any():
+                df_pred = df_pred[test_mask].reset_index(drop=True)
+
+        if df_pred.empty:
+            return None
+
+        # Extract probability columns (label_probabilities_0, label_probabilities_1, etc.)
+        # or label_probabilities_<class_name> for string labels
+        prob_cols = [col for col in df_pred.columns if col.startswith('label_probabilities_') and col != 'label_probabilities']
+
+        # Sort by class number if numeric, otherwise keep alphabetical order
+        if prob_cols and prob_cols[0].split('_')[-1].isdigit():
+            prob_cols.sort(key=lambda x: int(x.split('_')[-1]))
+        else:
+            prob_cols.sort()  # Alphabetical sort for string class names
+
+        if not prob_cols:
+            return None
+
+        # Get probabilities matrix (n_samples x n_classes)
+        y_score = df_pred[prob_cols].values
+        n_classes = len(prob_cols)
+
+        y_true = None
+        candidate_cols = [
+            LABEL_COLUMN_NAME,
+            f"{LABEL_COLUMN_NAME}_ground_truth",
+            f"{LABEL_COLUMN_NAME}__ground_truth",
+            f"{LABEL_COLUMN_NAME}_target",
+            f"{LABEL_COLUMN_NAME}__target",
+        ]
+        candidate_cols.extend(
+            [
+                col
+                for col in df_pred.columns
+                if (col.startswith(f"{LABEL_COLUMN_NAME}_") or col.startswith(f"{LABEL_COLUMN_NAME}__"))
+                and "probabilities" not in col
+                and "predictions" not in col
+            ]
+        )
+        for col in candidate_cols:
+            if col in df_pred.columns and col not in prob_cols:
+                y_true = df_pred[col].values
+                break
+
+        if y_true is None:
+            desc_path = exp_dir / "description.json"
+            if desc_path.exists():
+                try:
+                    with open(desc_path, 'r') as f:
+                        desc = json.load(f)
+                    dataset_path = desc.get('dataset', '')
+                    if dataset_path and Path(dataset_path).exists():
+                        df_orig = pd.read_csv(dataset_path)
+                        if SPLIT_COLUMN_NAME in df_orig.columns:
+                            df_orig = df_orig[df_orig[SPLIT_COLUMN_NAME] == 2].reset_index(drop=True)
+                        if LABEL_COLUMN_NAME in df_orig.columns:
+                            y_true = df_orig[LABEL_COLUMN_NAME].values
+                            if len(y_true) != len(df_pred):
+                                print(
+                                    f"Warning: Test set size mismatch. Truncating to {len(df_pred)} samples for ROC plot."
+                                )
+                                y_true = y_true[:len(df_pred)]
+                    else:
+                        print("Warning: Original dataset referenced in description.json is unavailable.")
+                except Exception as exc:  # pragma: no cover - defensive
+                    print(f"Warning: Failed to recover labels from dataset: {exc}")
+
+        if y_true is None or len(y_true) == 0:
+            print("Warning: Unable to locate ground-truth labels for ROC plot.")
+            return None
+
+        if len(y_true) != len(y_score):
+            limit = min(len(y_true), len(y_score))
+            if limit == 0:
+                return None
+            print(f"Warning: Aligning prediction and label lengths to {limit} samples for ROC plot.")
+            y_true = y_true[:limit]
+            y_score = y_score[:limit]
+
+        # Get actual class names from probability column names
+        actual_classes = [col.replace('label_probabilities_', '') for col in prob_cols]
+        display_classes = class_labels if len(class_labels) == n_classes else actual_classes
+
+        # Binarize the output following sklearn example
+        # Use actual class names if they're strings, otherwise use range
+        if isinstance(y_true[0], str):
+            y_test = label_binarize(y_true, classes=actual_classes)
+        else:
+            y_test = label_binarize(y_true, classes=list(range(n_classes)))
+
+        # Handle binary classification case
+        if y_test.ndim != 2:
+            y_test = np.atleast_2d(y_test)
+
+        if n_classes == 2:
+            if y_test.shape[1] == 1:
+                y_test = np.hstack([1 - y_test, y_test])
+            elif y_test.shape[1] != 2:
+                print("Warning: Unexpected label binarization shape for binary ROC plot.")
+                return None
+        elif y_test.shape[1] != n_classes:
+            print("Warning: Label binarization did not produce expected class dimension; skipping ROC plot.")
+            return None
+
+        # Compute ROC curve and ROC area for each class (following sklearn example)
+        fpr = dict()
+        tpr = dict()
+        roc_auc = dict()
+
+        for i in range(n_classes):
+            if np.sum(y_test[:, i]) > 0:  # Check if class exists in test set
+                fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
+                roc_auc[i] = auc(fpr[i], tpr[i])
+
+        # Compute micro-average ROC curve and ROC area (sklearn example)
+        fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())
+        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
+
+        # Create ROC curve plot
+        fig_roc = go.Figure()
+
+        # Colors for different classes
+        colors = [
+            '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
+            '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf'
+        ]
+
+        # Plot micro-average ROC curve first (most important)
+        fig_roc.add_trace(go.Scatter(
+            x=fpr["micro"],
+            y=tpr["micro"],
+            mode='lines',
+            name=f'Micro-average ROC (AUC = {roc_auc["micro"]:.3f})',
+            line=dict(color='deeppink', width=3, dash='dot'),
+            hovertemplate=('<b>Micro-average ROC</b><br>'
+                           'FPR: %{x:.3f}<br>'
+                           'TPR: %{y:.3f}<br>'
+                           f'AUC: {roc_auc["micro"]:.3f}<extra></extra>')
+        ))
+
+        # Plot ROC curve for each class
+        for i in range(n_classes):
+            if i in roc_auc:  # Only plot if class exists in test set
+                class_name = display_classes[i] if i < len(display_classes) else f"Class {i}"
+                color = colors[i % len(colors)]
+
+                fig_roc.add_trace(go.Scatter(
+                    x=fpr[i],
+                    y=tpr[i],
+                    mode='lines',
+                    name=f'{class_name} (AUC = {roc_auc[i]:.3f})',
+                    line=dict(color=color, width=2),
+                    hovertemplate=(f'<b>{class_name}</b><br>'
+                                   'FPR: %{x:.3f}<br>'
+                                   'TPR: %{y:.3f}<br>'
+                                   f'AUC: {roc_auc[i]:.3f}<extra></extra>')
+                ))
+
+        # Add diagonal line (random classifier)
+        fig_roc.add_trace(go.Scatter(
+            x=[0, 1],
+            y=[0, 1],
+            mode='lines',
+            name='Random Classifier',
+            line=dict(color='gray', width=1, dash='dash'),
+            hovertemplate='Random Classifier<br>AUC = 0.500<extra></extra>'
+        ))
+
+        # Calculate macro-average AUC
+        class_aucs = [roc_auc[i] for i in range(n_classes) if i in roc_auc]
+        if class_aucs:
+            macro_auc = np.mean(class_aucs)
+            title_text = f"ROC Curves (Micro-avg = {roc_auc['micro']:.3f}, Macro-avg = {macro_auc:.3f})"
+        else:
+            title_text = f"ROC Curves (Micro-avg = {roc_auc['micro']:.3f})"
+
+        fig_roc.update_layout(
+            title=dict(text=title_text, x=0.5),
+            xaxis_title="False Positive Rate",
+            yaxis_title="True Positive Rate",
+            width=700,
+            height=600,
+            margin=dict(t=80, l=80, r=80, b=80),
+            legend=dict(
+                x=0.6,
+                y=0.1,
+                bgcolor="rgba(255,255,255,0.9)",
+                bordercolor="rgba(0,0,0,0.2)",
+                borderwidth=1
+            ),
+            hovermode='closest'
+        )
+
+        # Set equal aspect ratio and proper range
+        fig_roc.update_xaxes(range=[0, 1.0])
+        fig_roc.update_yaxes(range=[0, 1.05])
+
+        return {
+            "title": "ROC-AUC Curves",
+            "html": pio.to_html(
+                fig_roc,
+                full_html=False,
+                include_plotlyjs=False,
+                config=config
            )
+        }
+
+    except Exception as e:
+        print(f"Error building ROC-AUC plot: {e}")
+        return None
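
For orientation only, a minimal sketch of how the new helper could be exercised on its own, assuming the tool's modules (`plotly_plots`, `constants`) are importable and that `predictions.csv` sits in the same experiment directory as `test_statistics.json`, as the diff expects. The directory name, class names, Plotly config, and output filename below are hypothetical, not part of the tool's interface.

```python
from pathlib import Path

from plotly_plots import _build_roc_auc_plot

# Hypothetical experiment directory produced by a previous training run.
exp_dir = Path("experiment_run")
test_stats = exp_dir / "test_statistics.json"  # predictions.csv is expected alongside it

roc = _build_roc_auc_plot(
    str(test_stats),
    class_labels=["cat", "dog", "bird"],   # placeholder class names
    config={"displayModeBar": True},       # any Plotly config dict
)

if roc is not None:
    # The helper returns {"title": ..., "html": ...}; the HTML fragment omits
    # plotly.js (include_plotlyjs=False), so the surrounding page must load it.
    (exp_dir / "roc_auc.html").write_text(roc["html"])
```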