comparison plotly_plots.py @ 11:c5150cceab47 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
| author | goeckslab |
|---|---|
| date | Sat, 18 Oct 2025 03:17:09 +0000 |
| parents | 85e6f4b2ad18 |
| children | |
| 10:b0d893d04d4c | 11:c5150cceab47 |
|---|---|
| 1 import json | 1 import json |
| 2 from pathlib import Path | |
| 2 from typing import Dict, List, Optional | 3 from typing import Dict, List, Optional |
| 3 | 4 |
| 4 import numpy as np | 5 import numpy as np |
| 6 import pandas as pd | |
| 5 import plotly.graph_objects as go | 7 import plotly.graph_objects as go |
| 6 import plotly.io as pio | 8 import plotly.io as pio |
| 9 from constants import LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME | |
| 10 from sklearn.metrics import auc, roc_curve | |
| 11 from sklearn.preprocessing import label_binarize | |
| 7 | 12 |
| 8 | 13 |
| 9 def build_classification_plots( | 14 def build_classification_plots( |
| 10 test_stats_path: str, | 15 test_stats_path: str, |
| 11 training_stats_path: Optional[str] = None, | 16 training_stats_path: Optional[str] = None, |
| 35 | 40 |
| 36 plots: List[Dict[str, str]] = [] | 41 plots: List[Dict[str, str]] = [] |
| 37 | 42 |
| 38 # 0) Confusion Matrix | 43 # 0) Confusion Matrix |
| 39 cm = np.array(label_stats["confusion_matrix"], dtype=int) | 44 cm = np.array(label_stats["confusion_matrix"], dtype=int) |
| 40 labels = label_stats.get("labels", [str(i) for i in range(cm.shape[0])]) | 45 # Try to get actual class names from per_class_stats keys (which contain the real labels) |
| 46 pcs = label_stats.get("per_class_stats", {}) | |
| 47 if pcs: | |
| 48 labels = list(pcs.keys()) | |
| 49 else: | |
| 50 labels = label_stats.get("labels", [str(i) for i in range(cm.shape[0])]) | |
| 41 total = cm.sum() | 51 total = cm.sum() |
| 42 | 52 |
| 43 fig_cm = go.Figure( | 53 fig_cm = go.Figure( |
| 44 go.Heatmap( | 54 go.Heatmap( |
| 45 z=cm, | 55 z=cm, |
| 98 include_plotlyjs="cdn", | 108 include_plotlyjs="cdn", |
| 99 config=common_cfg | 109 config=common_cfg |
| 100 ) | 110 ) |
| 101 }) | 111 }) |
| 102 | 112 |
| 113 # 1) ROC-AUC Curves (Multi-class) | |
| 114 roc_plot = _build_roc_auc_plot(test_stats_path, labels, common_cfg) | |
| 115 if roc_plot: | |
| 116 plots.append(roc_plot) | |
| 117 | |
| 103 # 2) Classification Report Heatmap | 118 # 2) Classification Report Heatmap |
| 104 pcs = label_stats.get("per_class_stats", {}) | 119 pcs = label_stats.get("per_class_stats", {}) |
| 105 if pcs: | 120 if pcs: |
| 106 classes = list(pcs.keys()) | 121 classes = list(pcs.keys()) |
| 107 metrics = ["precision", "recall", "f1_score"] | 122 metrics = ["precision", "recall", "f1_score"] |
| 144 config=common_cfg | 159 config=common_cfg |
| 145 ) | 160 ) |
| 146 }) | 161 }) |
| 147 | 162 |
| 148 return plots | 163 return plots |
| 164 | |
| 165 | |
| 166 def _build_roc_auc_plot(test_stats_path: str, class_labels: List[str], config: dict) -> Optional[Dict[str, str]]: | |
| 167 """ | |
| 168 Build an interactive ROC-AUC curve plot for multi-class classification. | |
| 169 Following sklearn's ROC example with micro-average and per-class curves. | |
| 170 | |
| 171 Args: | |
| 172 test_stats_path: Path to test_statistics.json | |
| 173 class_labels: List of class label names | |
| 174 config: Plotly config dict | |
| 175 | |
| 176 Returns: | |
| 177 Dict with title and HTML, or None if data unavailable | |
| 178 """ | |
| 179 try: | |
| 180 # Get the experiment directory from test_stats_path | |
| 181 exp_dir = Path(test_stats_path).parent | |
| 182 | |
| 183 # Load predictions with probabilities | |
| 184 predictions_path = exp_dir / "predictions.csv" | |
| 185 if not predictions_path.exists(): | |
| 186 return None | |
| 187 | |
| 188 df_pred = pd.read_csv(predictions_path) | |
| 189 | |
| 190 if SPLIT_COLUMN_NAME in df_pred.columns: | |
| 191 split_series = df_pred[SPLIT_COLUMN_NAME].astype(str).str.lower() | |
| 192 test_mask = split_series.isin({"2", "test", "testing"}) | |
| 193 if test_mask.any(): | |
| 194 df_pred = df_pred[test_mask].reset_index(drop=True) | |
| 195 | |
| 196 if df_pred.empty: | |
| 197 return None | |
| 198 | |
| 199 # Extract probability columns (label_probabilities_0, label_probabilities_1, etc.) | |
| 200 # or label_probabilities_<class_name> for string labels | |
| 201 prob_cols = [col for col in df_pred.columns if col.startswith('label_probabilities_') and col != 'label_probabilities'] | |
| 202 | |
| 203 # Sort by class number if numeric, otherwise keep alphabetical order | |
| 204 if prob_cols and prob_cols[0].split('_')[-1].isdigit(): | |
| 205 prob_cols.sort(key=lambda x: int(x.split('_')[-1])) | |
| 206 else: | |
| 207 prob_cols.sort() # Alphabetical sort for string class names | |
| 208 | |
| 209 if not prob_cols: | |
| 210 return None | |
| 211 | |
| 212 # Get probabilities matrix (n_samples x n_classes) | |
| 213 y_score = df_pred[prob_cols].values | |
| 214 n_classes = len(prob_cols) | |
| 215 | |
| 216 y_true = None | |
| 217 candidate_cols = [ | |
| 218 LABEL_COLUMN_NAME, | |
| 219 f"{LABEL_COLUMN_NAME}_ground_truth", | |
| 220 f"{LABEL_COLUMN_NAME}__ground_truth", | |
| 221 f"{LABEL_COLUMN_NAME}_target", | |
| 222 f"{LABEL_COLUMN_NAME}__target", | |
| 223 ] | |
| 224 candidate_cols.extend( | |
| 225 [ | |
| 226 col | |
| 227 for col in df_pred.columns | |
| 228 if (col.startswith(f"{LABEL_COLUMN_NAME}_") or col.startswith(f"{LABEL_COLUMN_NAME}__")) | |
| 229 and "probabilities" not in col | |
| 230 and "predictions" not in col | |
| 231 ] | |
| 232 ) | |
| 233 for col in candidate_cols: | |
| 234 if col in df_pred.columns and col not in prob_cols: | |
| 235 y_true = df_pred[col].values | |
| 236 break | |
| 237 | |
| 238 if y_true is None: | |
| 239 desc_path = exp_dir / "description.json" | |
| 240 if desc_path.exists(): | |
| 241 try: | |
| 242 with open(desc_path, 'r') as f: | |
| 243 desc = json.load(f) | |
| 244 dataset_path = desc.get('dataset', '') | |
| 245 if dataset_path and Path(dataset_path).exists(): | |
| 246 df_orig = pd.read_csv(dataset_path) | |
| 247 if SPLIT_COLUMN_NAME in df_orig.columns: | |
| 248 df_orig = df_orig[df_orig[SPLIT_COLUMN_NAME] == 2].reset_index(drop=True) | |
| 249 if LABEL_COLUMN_NAME in df_orig.columns: | |
| 250 y_true = df_orig[LABEL_COLUMN_NAME].values | |
| 251 if len(y_true) != len(df_pred): | |
| 252 print( | |
| 253 f"Warning: Test set size mismatch. Truncating to {len(df_pred)} samples for ROC plot." | |
| 254 ) | |
| 255 y_true = y_true[:len(df_pred)] | |
| 256 else: | |
| 257 print("Warning: Original dataset referenced in description.json is unavailable.") | |
| 258 except Exception as exc: # pragma: no cover - defensive | |
| 259 print(f"Warning: Failed to recover labels from dataset: {exc}") | |
| 260 | |
| 261 if y_true is None or len(y_true) == 0: | |
| 262 print("Warning: Unable to locate ground-truth labels for ROC plot.") | |
| 263 return None | |
| 264 | |
| 265 if len(y_true) != len(y_score): | |
| 266 limit = min(len(y_true), len(y_score)) | |
| 267 if limit == 0: | |
| 268 return None | |
| 269 print(f"Warning: Aligning prediction and label lengths to {limit} samples for ROC plot.") | |
| 270 y_true = y_true[:limit] | |
| 271 y_score = y_score[:limit] | |
| 272 | |
| 273 # Get actual class names from probability column names | |
| 274 actual_classes = [col.replace('label_probabilities_', '') for col in prob_cols] | |
| 275 display_classes = class_labels if len(class_labels) == n_classes else actual_classes | |
| 276 | |
| 277 # Binarize the output following sklearn example | |
| 278 # Use actual class names if they're strings, otherwise use range | |
| 279 if isinstance(y_true[0], str): | |
| 280 y_test = label_binarize(y_true, classes=actual_classes) | |
| 281 else: | |
| 282 y_test = label_binarize(y_true, classes=list(range(n_classes))) | |
| 283 | |
| 284 # Handle binary classification case | |
| 285 if y_test.ndim != 2: | |
| 286 y_test = np.atleast_2d(y_test) | |
| 287 | |
| 288 if n_classes == 2: | |
| 289 if y_test.shape[1] == 1: | |
| 290 y_test = np.hstack([1 - y_test, y_test]) | |
| 291 elif y_test.shape[1] != 2: | |
| 292 print("Warning: Unexpected label binarization shape for binary ROC plot.") | |
| 293 return None | |
| 294 elif y_test.shape[1] != n_classes: | |
| 295 print("Warning: Label binarization did not produce expected class dimension; skipping ROC plot.") | |
| 296 return None | |
| 297 | |
| 298 # Compute ROC curve and ROC area for each class (following sklearn example) | |
| 299 fpr = dict() | |
| 300 tpr = dict() | |
| 301 roc_auc = dict() | |
| 302 | |
| 303 for i in range(n_classes): | |
| 304 if np.sum(y_test[:, i]) > 0: # Check if class exists in test set | |
| 305 fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i]) | |
| 306 roc_auc[i] = auc(fpr[i], tpr[i]) | |
| 307 | |
| 308 # Compute micro-average ROC curve and ROC area (sklearn example) | |
| 309 fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel()) | |
| 310 roc_auc["micro"] = auc(fpr["micro"], tpr["micro"]) | |
| 311 | |
| 312 # Create ROC curve plot | |
| 313 fig_roc = go.Figure() | |
| 314 | |
| 315 # Colors for different classes | |
| 316 colors = [ | |
| 317 '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', | |
| 318 '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf' | |
| 319 ] | |
| 320 | |
| 321 # Plot micro-average ROC curve first (most important) | |
| 322 fig_roc.add_trace(go.Scatter( | |
| 323 x=fpr["micro"], | |
| 324 y=tpr["micro"], | |
| 325 mode='lines', | |
| 326 name=f'Micro-average ROC (AUC = {roc_auc["micro"]:.3f})', | |
| 327 line=dict(color='deeppink', width=3, dash='dot'), | |
| 328 hovertemplate=('<b>Micro-average ROC</b><br>' | |
| 329 'FPR: %{x:.3f}<br>' | |
| 330 'TPR: %{y:.3f}<br>' | |
| 331 f'AUC: {roc_auc["micro"]:.3f}<extra></extra>') | |
| 332 )) | |
| 333 | |
| 334 # Plot ROC curve for each class | |
| 335 for i in range(n_classes): | |
| 336 if i in roc_auc: # Only plot if class exists in test set | |
| 337 class_name = display_classes[i] if i < len(display_classes) else f"Class {i}" | |
| 338 color = colors[i % len(colors)] | |
| 339 | |
| 340 fig_roc.add_trace(go.Scatter( | |
| 341 x=fpr[i], | |
| 342 y=tpr[i], | |
| 343 mode='lines', | |
| 344 name=f'{class_name} (AUC = {roc_auc[i]:.3f})', | |
| 345 line=dict(color=color, width=2), | |
| 346 hovertemplate=(f'<b>{class_name}</b><br>' | |
| 347 'FPR: %{x:.3f}<br>' | |
| 348 'TPR: %{y:.3f}<br>' | |
| 349 f'AUC: {roc_auc[i]:.3f}<extra></extra>') | |
| 350 )) | |
| 351 | |
| 352 # Add diagonal line (random classifier) | |
| 353 fig_roc.add_trace(go.Scatter( | |
| 354 x=[0, 1], | |
| 355 y=[0, 1], | |
| 356 mode='lines', | |
| 357 name='Random Classifier', | |
| 358 line=dict(color='gray', width=1, dash='dash'), | |
| 359 hovertemplate='Random Classifier<br>AUC = 0.500<extra></extra>' | |
| 360 )) | |
| 361 | |
| 362 # Calculate macro-average AUC | |
| 363 class_aucs = [roc_auc[i] for i in range(n_classes) if i in roc_auc] | |
| 364 if class_aucs: | |
| 365 macro_auc = np.mean(class_aucs) | |
| 366 title_text = f"ROC Curves (Micro-avg = {roc_auc['micro']:.3f}, Macro-avg = {macro_auc:.3f})" | |
| 367 else: | |
| 368 title_text = f"ROC Curves (Micro-avg = {roc_auc['micro']:.3f})" | |
| 369 | |
| 370 fig_roc.update_layout( | |
| 371 title=dict(text=title_text, x=0.5), | |
| 372 xaxis_title="False Positive Rate", | |
| 373 yaxis_title="True Positive Rate", | |
| 374 width=700, | |
| 375 height=600, | |
| 376 margin=dict(t=80, l=80, r=80, b=80), | |
| 377 legend=dict( | |
| 378 x=0.6, | |
| 379 y=0.1, | |
| 380 bgcolor="rgba(255,255,255,0.9)", | |
| 381 bordercolor="rgba(0,0,0,0.2)", | |
| 382 borderwidth=1 | |
| 383 ), | |
| 384 hovermode='closest' | |
| 385 ) | |
| 386 | |
| 387 # Set equal aspect ratio and proper range | |
| 388 fig_roc.update_xaxes(range=[0, 1.0]) | |
| 389 fig_roc.update_yaxes(range=[0, 1.05]) | |
| 390 | |
| 391 return { | |
| 392 "title": "ROC-AUC Curves", | |
| 393 "html": pio.to_html( | |
| 394 fig_roc, | |
| 395 full_html=False, | |
| 396 include_plotlyjs=False, | |
| 397 config=config | |
| 398 ) | |
| 399 } | |
| 400 | |
| 401 except Exception as e: | |
| 402 print(f"Error building ROC-AUC plot: {e}") | |
| 403 return None | |
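For reference, the `_build_roc_auc_plot` helper added in this revision follows scikit-learn's standard multi-class ROC recipe: binarize the ground-truth labels, compute one curve per class with `roc_curve`/`auc`, and pool all label/score pairs for the micro-average. The sketch below illustrates that pattern in isolation, with made-up class names and toy probabilities rather than data from this repository:

```python
# Minimal sketch of the multi-class ROC pattern used above:
# binarize labels, compute one curve per class, then a micro-average.
# Class names and scores here are illustrative only.
import numpy as np
from sklearn.metrics import auc, roc_curve
from sklearn.preprocessing import label_binarize

classes = ["cat", "dog", "bird"]            # assumed class names
y_true = np.array(["cat", "dog", "bird", "dog", "cat"])
y_score = np.array([                        # one probability column per class
    [0.8, 0.1, 0.1],
    [0.2, 0.7, 0.1],
    [0.1, 0.2, 0.7],
    [0.3, 0.4, 0.3],
    [0.6, 0.3, 0.1],
])

# One-hot encode the ground truth in the same class order as the score columns.
y_bin = label_binarize(y_true, classes=classes)

fpr, tpr, roc_auc = {}, {}, {}
for i, name in enumerate(classes):
    if y_bin[:, i].sum() > 0:               # skip classes absent from the test set
        fpr[i], tpr[i], _ = roc_curve(y_bin[:, i], y_score[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

# Micro-average: pool every (label, score) pair before computing a single curve.
fpr["micro"], tpr["micro"], _ = roc_curve(y_bin.ravel(), y_score.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

# Macro-average: unweighted mean of the per-class AUCs.
macro_auc = np.mean([roc_auc[i] for i in range(len(classes)) if i in roc_auc])
```

The macro-average reported in the plot title is the unweighted mean of the per-class AUCs, so rare classes in the test split carry the same weight as common ones.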

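The dictionaries returned by `build_classification_plots` contain ready-to-embed HTML fragments: the confusion-matrix fragment loads plotly.js from the CDN (`include_plotlyjs="cdn"`), while the ROC fragment sets `include_plotlyjs=False`, so the fragments can be concatenated into a single report page. A hypothetical usage sketch follows; the import name, the stats path, and the output file are assumptions for illustration, not part of this commit:

```python
# Hypothetical usage sketch -- module name, stats path, and output file are
# assumptions for illustration, not defined by the commit above.
from plotly_plots import build_classification_plots

plots = build_classification_plots("experiment_run/test_statistics.json")

# Each entry carries a title and an HTML fragment; the first fragment pulls
# plotly.js from the CDN and later ones reuse it, so simple concatenation
# yields a working standalone report page.
sections = "\n".join(f"<h2>{p['title']}</h2>\n{p['html']}" for p in plots)

with open("report.html", "w") as f:
    f.write(f"<html><body>{sections}</body></html>")
```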