plotly_plots.py @ 11:c5150cceab47 (draft, default, tip)
commit: planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
author: goeckslab
date: Sat, 18 Oct 2025 03:17:09 +0000
parents: 85e6f4b2ad18
children: (none)
import json
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
from constants import LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME
from sklearn.metrics import auc, roc_curve
from sklearn.preprocessing import label_binarize


def build_classification_plots(
    test_stats_path: str,
    training_stats_path: Optional[str] = None,
) -> List[Dict[str, str]]:
    """
    Read Ludwig's test_statistics.json and build three interactive Plotly panels:

      - Confusion Matrix
      - ROC-AUC
      - Classification Report Heatmap

    Returns a list of dicts, each with:
      {
        "title": <plot title>,
        "html":  <HTML fragment for embedding>
      }
    """
    # --- Load test stats ---
    with open(test_stats_path, "r") as f:
        test_stats = json.load(f)
    label_stats = test_stats["label"]

    # common sizing
    cell = 40
    n_classes = len(label_stats["confusion_matrix"])
    side_px = max(cell * n_classes + 200, 600)
    common_cfg = {"displayModeBar": True, "scrollZoom": True}

    plots: List[Dict[str, str]] = []

    # 0) Confusion Matrix
    cm = np.array(label_stats["confusion_matrix"], dtype=int)
    # Try to get actual class names from per_class_stats keys (which contain the real labels)
    pcs = label_stats.get("per_class_stats", {})
    if pcs:
        labels = list(pcs.keys())
    else:
        labels = label_stats.get("labels", [str(i) for i in range(cm.shape[0])])
    total = cm.sum()

    fig_cm = go.Figure(
        go.Heatmap(
            z=cm,
            x=labels,
            y=labels,
            colorscale="Blues",
            showscale=True,
            colorbar=dict(title="Count"),
        )
    )
    fig_cm.update_traces(xgap=2, ygap=2)
    fig_cm.update_layout(
        title=dict(text="Confusion Matrix", x=0.5),
        xaxis_title="Predicted",
        yaxis_title="Observed",
        yaxis_autorange="reversed",
        width=side_px,
        height=side_px,
        margin=dict(t=100, l=80, r=80, b=80),
    )

    # annotate counts and percentages
    mval = cm.max() if cm.size else 0
    thresh = mval / 2
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            v = cm[i, j]
            pct = (v / total * 100) if total > 0 else 0
            color = "white" if v > thresh else "black"
            fig_cm.add_annotation(
                x=labels[j],
                y=labels[i],
                text=f"<b>{v}</b>",
                showarrow=False,
                font=dict(color=color, size=14),
                xanchor="center",
                yanchor="bottom",
                yshift=2,
            )
            fig_cm.add_annotation(
                x=labels[j],
                y=labels[i],
                text=f"{pct:.1f}%",
                showarrow=False,
                font=dict(color=color, size=13),
                xanchor="center",
                yanchor="top",
                yshift=-2,
            )

    plots.append({
        "title": "Confusion Matrix",
        "html": pio.to_html(
            fig_cm,
            full_html=False,
            include_plotlyjs="cdn",
            config=common_cfg,
        ),
    })

    # 1) ROC-AUC Curves (Multi-class)
    roc_plot = _build_roc_auc_plot(test_stats_path, labels, common_cfg)
    if roc_plot:
        plots.append(roc_plot)

    # 2) Classification Report Heatmap
    pcs = label_stats.get("per_class_stats", {})
    if pcs:
        classes = list(pcs.keys())
        metrics = ["precision", "recall", "f1_score"]
        z, txt = [], []
        for c in classes:
            row, trow = [], []
            for m in metrics:
                val = pcs[c].get(m, 0)
                row.append(val)
                trow.append(f"{val:.2f}")
            z.append(row)
            txt.append(trow)

        fig_cr = go.Figure(
            go.Heatmap(
                z=z,
                x=metrics,
                y=[str(c) for c in classes],
                text=txt,
                texttemplate="%{text}",
                colorscale="Reds",
                showscale=True,
                colorbar=dict(title="Value"),
            )
        )
        fig_cr.update_layout(
            title="Classification Report",
            xaxis_title="",
            yaxis_title="Class",
            width=side_px,
            height=side_px,
            margin=dict(t=80, l=80, r=80, b=80),
        )
        plots.append({
            "title": "Classification Report",
            "html": pio.to_html(
                fig_cr,
                full_html=False,
                include_plotlyjs=False,
                config=common_cfg,
            ),
        })

    return plots


def _build_roc_auc_plot(
    test_stats_path: str,
    class_labels: List[str],
    config: dict,
) -> Optional[Dict[str, str]]:
    """
    Build an interactive ROC-AUC curve plot for multi-class classification.

    Follows sklearn's ROC example with micro-average and per-class curves.

    Args:
        test_stats_path: Path to test_statistics.json
        class_labels: List of class label names
        config: Plotly config dict

    Returns:
        Dict with title and HTML, or None if data unavailable
    """
    try:
        # Get the experiment directory from test_stats_path
        exp_dir = Path(test_stats_path).parent

        # Load predictions with probabilities
        predictions_path = exp_dir / "predictions.csv"
        if not predictions_path.exists():
            return None

        df_pred = pd.read_csv(predictions_path)
        if SPLIT_COLUMN_NAME in df_pred.columns:
            split_series = df_pred[SPLIT_COLUMN_NAME].astype(str).str.lower()
            test_mask = split_series.isin({"2", "test", "testing"})
            if test_mask.any():
                df_pred = df_pred[test_mask].reset_index(drop=True)
        if df_pred.empty:
            return None

        # Extract probability columns (label_probabilities_0, label_probabilities_1, etc.)
        # or label_probabilities_<class_name> for string labels
        prob_cols = [
            col for col in df_pred.columns
            if col.startswith('label_probabilities_') and col != 'label_probabilities'
        ]

        # Sort by class number if numeric, otherwise keep alphabetical order
        if prob_cols and prob_cols[0].split('_')[-1].isdigit():
            prob_cols.sort(key=lambda x: int(x.split('_')[-1]))
        else:
            prob_cols.sort()  # Alphabetical sort for string class names

        if not prob_cols:
            return None

        # Get probabilities matrix (n_samples x n_classes)
        y_score = df_pred[prob_cols].values
        n_classes = len(prob_cols)

        y_true = None
        candidate_cols = [
            LABEL_COLUMN_NAME,
            f"{LABEL_COLUMN_NAME}_ground_truth",
            f"{LABEL_COLUMN_NAME}__ground_truth",
            f"{LABEL_COLUMN_NAME}_target",
            f"{LABEL_COLUMN_NAME}__target",
        ]
        candidate_cols.extend(
            [
                col for col in df_pred.columns
                if (col.startswith(f"{LABEL_COLUMN_NAME}_") or col.startswith(f"{LABEL_COLUMN_NAME}__"))
                and "probabilities" not in col
                and "predictions" not in col
            ]
        )
        for col in candidate_cols:
            if col in df_pred.columns and col not in prob_cols:
                y_true = df_pred[col].values
                break

        if y_true is None:
            desc_path = exp_dir / "description.json"
            if desc_path.exists():
                try:
                    with open(desc_path, 'r') as f:
                        desc = json.load(f)
                    dataset_path = desc.get('dataset', '')
                    if dataset_path and Path(dataset_path).exists():
                        df_orig = pd.read_csv(dataset_path)
                        if SPLIT_COLUMN_NAME in df_orig.columns:
                            df_orig = df_orig[df_orig[SPLIT_COLUMN_NAME] == 2].reset_index(drop=True)
                        if LABEL_COLUMN_NAME in df_orig.columns:
                            y_true = df_orig[LABEL_COLUMN_NAME].values
                            if len(y_true) != len(df_pred):
                                print(
                                    f"Warning: Test set size mismatch. Truncating to {len(df_pred)} samples for ROC plot."
                                )
                                y_true = y_true[:len(df_pred)]
                    else:
                        print("Warning: Original dataset referenced in description.json is unavailable.")
                except Exception as exc:  # pragma: no cover - defensive
                    print(f"Warning: Failed to recover labels from dataset: {exc}")

        if y_true is None or len(y_true) == 0:
            print("Warning: Unable to locate ground-truth labels for ROC plot.")
            return None

        if len(y_true) != len(y_score):
            limit = min(len(y_true), len(y_score))
            if limit == 0:
                return None
            print(f"Warning: Aligning prediction and label lengths to {limit} samples for ROC plot.")
            y_true = y_true[:limit]
            y_score = y_score[:limit]

        # Get actual class names from probability column names
        actual_classes = [col.replace('label_probabilities_', '') for col in prob_cols]
        display_classes = class_labels if len(class_labels) == n_classes else actual_classes

        # Binarize the output following sklearn example
        # Use actual class names if they're strings, otherwise use range
        if isinstance(y_true[0], str):
            y_test = label_binarize(y_true, classes=actual_classes)
        else:
            y_test = label_binarize(y_true, classes=list(range(n_classes)))

        # Handle binary classification case
        if y_test.ndim != 2:
            y_test = np.atleast_2d(y_test)
        if n_classes == 2:
            if y_test.shape[1] == 1:
                y_test = np.hstack([1 - y_test, y_test])
            elif y_test.shape[1] != 2:
                print("Warning: Unexpected label binarization shape for binary ROC plot.")
                return None
        elif y_test.shape[1] != n_classes:
            print("Warning: Label binarization did not produce expected class dimension; skipping ROC plot.")
            return None

        # Compute ROC curve and ROC area for each class (following sklearn example)
        fpr = dict()
        tpr = dict()
        roc_auc = dict()
        for i in range(n_classes):
            if np.sum(y_test[:, i]) > 0:  # Check if class exists in test set
                fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
                roc_auc[i] = auc(fpr[i], tpr[i])

        # Compute micro-average ROC curve and ROC area (sklearn example)
        fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())
        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

        # Create ROC curve plot
        fig_roc = go.Figure()

        # Colors for different classes
        colors = [
            '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
            '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf'
        ]

        # Plot micro-average ROC curve first (most important)
        fig_roc.add_trace(go.Scatter(
            x=fpr["micro"],
            y=tpr["micro"],
            mode='lines',
            name=f'Micro-average ROC (AUC = {roc_auc["micro"]:.3f})',
            line=dict(color='deeppink', width=3, dash='dot'),
            hovertemplate=('<b>Micro-average ROC</b><br>'
                           'FPR: %{x:.3f}<br>'
                           'TPR: %{y:.3f}<br>'
                           f'AUC: {roc_auc["micro"]:.3f}<extra></extra>')
        ))

        # Plot ROC curve for each class
        for i in range(n_classes):
            if i in roc_auc:  # Only plot if class exists in test set
                class_name = display_classes[i] if i < len(display_classes) else f"Class {i}"
                color = colors[i % len(colors)]
                fig_roc.add_trace(go.Scatter(
                    x=fpr[i],
                    y=tpr[i],
                    mode='lines',
                    name=f'{class_name} (AUC = {roc_auc[i]:.3f})',
                    line=dict(color=color, width=2),
                    hovertemplate=(f'<b>{class_name}</b><br>'
                                   'FPR: %{x:.3f}<br>'
                                   'TPR: %{y:.3f}<br>'
                                   f'AUC: {roc_auc[i]:.3f}<extra></extra>')
                ))

        # Add diagonal line (random classifier)
        fig_roc.add_trace(go.Scatter(
            x=[0, 1],
            y=[0, 1],
            mode='lines',
            name='Random Classifier',
            line=dict(color='gray', width=1, dash='dash'),
            hovertemplate='Random Classifier<br>AUC = 0.500<extra></extra>'
        ))

        # Calculate macro-average AUC
        class_aucs = [roc_auc[i] for i in range(n_classes) if i in roc_auc]
        if class_aucs:
            macro_auc = np.mean(class_aucs)
            title_text = f"ROC Curves (Micro-avg = {roc_auc['micro']:.3f}, Macro-avg = {macro_auc:.3f})"
        else:
            title_text = f"ROC Curves (Micro-avg = {roc_auc['micro']:.3f})"

        fig_roc.update_layout(
            title=dict(text=title_text, x=0.5),
            xaxis_title="False Positive Rate",
            yaxis_title="True Positive Rate",
            width=700,
            height=600,
            margin=dict(t=80, l=80, r=80, b=80),
            legend=dict(
                x=0.6,
                y=0.1,
                bgcolor="rgba(255,255,255,0.9)",
                bordercolor="rgba(0,0,0,0.2)",
                borderwidth=1,
            ),
            hovermode='closest',
        )

        # Set equal aspect ratio and proper range
        fig_roc.update_xaxes(range=[0, 1.0])
        fig_roc.update_yaxes(range=[0, 1.05])

        return {
            "title": "ROC-AUC Curves",
            "html": pio.to_html(
                fig_roc,
                full_html=False,
                include_plotlyjs=False,
                config=config,
            ),
        }

    except Exception as e:
        print(f"Error building ROC-AUC plot: {e}")
        return None
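

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the Galaxy tool wiring): shows
# how the fragments returned by build_classification_plots() could be stitched
# into a standalone HTML report. The paths "results/test_statistics.json" and
# "report.html" are hypothetical placeholders. Note that only the first
# fragment embeds plotly.js via CDN (include_plotlyjs="cdn"); the others reuse
# it, so concatenating the fragments on one page is sufficient.
if __name__ == "__main__":
    example_plots = build_classification_plots("results/test_statistics.json")
    # Each entry carries a title and an embeddable HTML fragment.
    sections = "\n".join(
        f"<h2>{p['title']}</h2>\n{p['html']}" for p in example_plots
    )
    with open("report.html", "w") as out:
        out.write(f"<html><body>{sections}</body></html>")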