comparison: plotly_plots.py @ 11:c5150cceab47 (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
author | goeckslab |
---|---|
date | Sat, 18 Oct 2025 03:17:09 +0000 |
parents | 85e6f4b2ad18 |
children | (none) |
10:b0d893d04d4c | 11:c5150cceab47 |
---|---|
1 import json | 1 import json |
2 from pathlib import Path | |
2 from typing import Dict, List, Optional | 3 from typing import Dict, List, Optional |
3 | 4 |
4 import numpy as np | 5 import numpy as np |
6 import pandas as pd | |
5 import plotly.graph_objects as go | 7 import plotly.graph_objects as go |
6 import plotly.io as pio | 8 import plotly.io as pio |
9 from constants import LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME | |
10 from sklearn.metrics import auc, roc_curve | |
11 from sklearn.preprocessing import label_binarize | |
7 | 12 |
8 | 13 |
9 def build_classification_plots( | 14 def build_classification_plots( |
10 test_stats_path: str, | 15 test_stats_path: str, |
11 training_stats_path: Optional[str] = None, | 16 training_stats_path: Optional[str] = None, |
35 | 40 |
36 plots: List[Dict[str, str]] = [] | 41 plots: List[Dict[str, str]] = [] |
37 | 42 |
38 # 0) Confusion Matrix | 43 # 0) Confusion Matrix |
39 cm = np.array(label_stats["confusion_matrix"], dtype=int) | 44 cm = np.array(label_stats["confusion_matrix"], dtype=int) |
40 labels = label_stats.get("labels", [str(i) for i in range(cm.shape[0])]) | 45 # Try to get actual class names from per_class_stats keys (which contain the real labels) |
46 pcs = label_stats.get("per_class_stats", {}) | |
47 if pcs: | |
48 labels = list(pcs.keys()) | |
49 else: | |
50 labels = label_stats.get("labels", [str(i) for i in range(cm.shape[0])]) | |
41 total = cm.sum() | 51 total = cm.sum() |
42 | 52 |
43 fig_cm = go.Figure( | 53 fig_cm = go.Figure( |
44 go.Heatmap( | 54 go.Heatmap( |
45 z=cm, | 55 z=cm, |
98 include_plotlyjs="cdn", | 108 include_plotlyjs="cdn", |
99 config=common_cfg | 109 config=common_cfg |
100 ) | 110 ) |
101 }) | 111 }) |
102 | 112 |
113 # 1) ROC-AUC Curves (Multi-class) | |
114 roc_plot = _build_roc_auc_plot(test_stats_path, labels, common_cfg) | |
115 if roc_plot: | |
116 plots.append(roc_plot) | |
117 | |
103 # 2) Classification Report Heatmap | 118 # 2) Classification Report Heatmap |
104 pcs = label_stats.get("per_class_stats", {}) | 119 pcs = label_stats.get("per_class_stats", {}) |
105 if pcs: | 120 if pcs: |
106 classes = list(pcs.keys()) | 121 classes = list(pcs.keys()) |
107 metrics = ["precision", "recall", "f1_score"] | 122 metrics = ["precision", "recall", "f1_score"] |
144 config=common_cfg | 159 config=common_cfg |
145 ) | 160 ) |
146 }) | 161 }) |
147 | 162 |
148 return plots | 163 return plots |
164 | |
165 | |
166 def _build_roc_auc_plot(test_stats_path: str, class_labels: List[str], config: dict) -> Optional[Dict[str, str]]: | |
167 """ | |
168 Build an interactive ROC-AUC curve plot for multi-class classification. | |
169 Following sklearn's ROC example with micro-average and per-class curves. | |
170 | |
171 Args: | |
172 test_stats_path: Path to test_statistics.json | |
173 class_labels: List of class label names | |
174 config: Plotly config dict | |
175 | |
176 Returns: | |
177 Dict with title and HTML, or None if data unavailable | |
178 """ | |
179 try: | |
180 # Get the experiment directory from test_stats_path | |
181 exp_dir = Path(test_stats_path).parent | |
182 | |
183 # Load predictions with probabilities | |
184 predictions_path = exp_dir / "predictions.csv" | |
185 if not predictions_path.exists(): | |
186 return None | |
187 | |
188 df_pred = pd.read_csv(predictions_path) | |
189 | |
190 if SPLIT_COLUMN_NAME in df_pred.columns: | |
191 split_series = df_pred[SPLIT_COLUMN_NAME].astype(str).str.lower() | |
192 test_mask = split_series.isin({"2", "test", "testing"}) | |
193 if test_mask.any(): | |
194 df_pred = df_pred[test_mask].reset_index(drop=True) | |
195 | |
196 if df_pred.empty: | |
197 return None | |
198 | |
199 # Extract probability columns (label_probabilities_0, label_probabilities_1, etc.) | |
200 # or label_probabilities_<class_name> for string labels | |
201 prob_cols = [col for col in df_pred.columns if col.startswith('label_probabilities_') and col != 'label_probabilities'] | |
202 | |
203 # Sort by class number if numeric, otherwise keep alphabetical order | |
204 if prob_cols and prob_cols[0].split('_')[-1].isdigit(): | |
205 prob_cols.sort(key=lambda x: int(x.split('_')[-1])) | |
206 else: | |
207 prob_cols.sort() # Alphabetical sort for string class names | |
208 | |
209 if not prob_cols: | |
210 return None | |
211 | |
212 # Get probabilities matrix (n_samples x n_classes) | |
213 y_score = df_pred[prob_cols].values | |
214 n_classes = len(prob_cols) | |
215 | |
216 y_true = None | |
217 candidate_cols = [ | |
218 LABEL_COLUMN_NAME, | |
219 f"{LABEL_COLUMN_NAME}_ground_truth", | |
220 f"{LABEL_COLUMN_NAME}__ground_truth", | |
221 f"{LABEL_COLUMN_NAME}_target", | |
222 f"{LABEL_COLUMN_NAME}__target", | |
223 ] | |
224 candidate_cols.extend( | |
225 [ | |
226 col | |
227 for col in df_pred.columns | |
228 if (col.startswith(f"{LABEL_COLUMN_NAME}_") or col.startswith(f"{LABEL_COLUMN_NAME}__")) | |
229 and "probabilities" not in col | |
230 and "predictions" not in col | |
231 ] | |
232 ) | |
233 for col in candidate_cols: | |
234 if col in df_pred.columns and col not in prob_cols: | |
235 y_true = df_pred[col].values | |
236 break | |
237 | |
238 if y_true is None: | |
239 desc_path = exp_dir / "description.json" | |
240 if desc_path.exists(): | |
241 try: | |
242 with open(desc_path, 'r') as f: | |
243 desc = json.load(f) | |
244 dataset_path = desc.get('dataset', '') | |
245 if dataset_path and Path(dataset_path).exists(): | |
246 df_orig = pd.read_csv(dataset_path) | |
247 if SPLIT_COLUMN_NAME in df_orig.columns: | |
248 df_orig = df_orig[df_orig[SPLIT_COLUMN_NAME] == 2].reset_index(drop=True) | |
249 if LABEL_COLUMN_NAME in df_orig.columns: | |
250 y_true = df_orig[LABEL_COLUMN_NAME].values | |
251 if len(y_true) != len(df_pred): | |
252 print( | |
253 f"Warning: Test set size mismatch. Truncating to {len(df_pred)} samples for ROC plot." | |
254 ) | |
255 y_true = y_true[:len(df_pred)] | |
256 else: | |
257 print("Warning: Original dataset referenced in description.json is unavailable.") | |
258 except Exception as exc: # pragma: no cover - defensive | |
259 print(f"Warning: Failed to recover labels from dataset: {exc}") | |
260 | |
261 if y_true is None or len(y_true) == 0: | |
262 print("Warning: Unable to locate ground-truth labels for ROC plot.") | |
263 return None | |
264 | |
265 if len(y_true) != len(y_score): | |
266 limit = min(len(y_true), len(y_score)) | |
267 if limit == 0: | |
268 return None | |
269 print(f"Warning: Aligning prediction and label lengths to {limit} samples for ROC plot.") | |
270 y_true = y_true[:limit] | |
271 y_score = y_score[:limit] | |
272 | |
273 # Get actual class names from probability column names | |
274 actual_classes = [col.replace('label_probabilities_', '') for col in prob_cols] | |
275 display_classes = class_labels if len(class_labels) == n_classes else actual_classes | |
276 | |
277 # Binarize the output following sklearn example | |
278 # Use actual class names if they're strings, otherwise use range | |
279 if isinstance(y_true[0], str): | |
280 y_test = label_binarize(y_true, classes=actual_classes) | |
281 else: | |
282 y_test = label_binarize(y_true, classes=list(range(n_classes))) | |
283 | |
284 # Handle binary classification case | |
285 if y_test.ndim != 2: | |
286 y_test = np.atleast_2d(y_test) | |
287 | |
288 if n_classes == 2: | |
289 if y_test.shape[1] == 1: | |
290 y_test = np.hstack([1 - y_test, y_test]) | |
291 elif y_test.shape[1] != 2: | |
292 print("Warning: Unexpected label binarization shape for binary ROC plot.") | |
293 return None | |
294 elif y_test.shape[1] != n_classes: | |
295 print("Warning: Label binarization did not produce expected class dimension; skipping ROC plot.") | |
296 return None | |
297 | |
298 # Compute ROC curve and ROC area for each class (following sklearn example) | |
299 fpr = dict() | |
300 tpr = dict() | |
301 roc_auc = dict() | |
302 | |
303 for i in range(n_classes): | |
304 if np.sum(y_test[:, i]) > 0: # Check if class exists in test set | |
305 fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i]) | |
306 roc_auc[i] = auc(fpr[i], tpr[i]) | |
307 | |
308 # Compute micro-average ROC curve and ROC area (sklearn example) | |
309 fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel()) | |
310 roc_auc["micro"] = auc(fpr["micro"], tpr["micro"]) | |
311 | |
312 # Create ROC curve plot | |
313 fig_roc = go.Figure() | |
314 | |
315 # Colors for different classes | |
316 colors = [ | |
317 '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', | |
318 '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf' | |
319 ] | |
320 | |
321 # Plot micro-average ROC curve first (most important) | |
322 fig_roc.add_trace(go.Scatter( | |
323 x=fpr["micro"], | |
324 y=tpr["micro"], | |
325 mode='lines', | |
326 name=f'Micro-average ROC (AUC = {roc_auc["micro"]:.3f})', | |
327 line=dict(color='deeppink', width=3, dash='dot'), | |
328 hovertemplate=('<b>Micro-average ROC</b><br>' | |
329 'FPR: %{x:.3f}<br>' | |
330 'TPR: %{y:.3f}<br>' | |
331 f'AUC: {roc_auc["micro"]:.3f}<extra></extra>') | |
332 )) | |
333 | |
334 # Plot ROC curve for each class | |
335 for i in range(n_classes): | |
336 if i in roc_auc: # Only plot if class exists in test set | |
337 class_name = display_classes[i] if i < len(display_classes) else f"Class {i}" | |
338 color = colors[i % len(colors)] | |
339 | |
340 fig_roc.add_trace(go.Scatter( | |
341 x=fpr[i], | |
342 y=tpr[i], | |
343 mode='lines', | |
344 name=f'{class_name} (AUC = {roc_auc[i]:.3f})', | |
345 line=dict(color=color, width=2), | |
346 hovertemplate=(f'<b>{class_name}</b><br>' | |
347 'FPR: %{x:.3f}<br>' | |
348 'TPR: %{y:.3f}<br>' | |
349 f'AUC: {roc_auc[i]:.3f}<extra></extra>') | |
350 )) | |
351 | |
352 # Add diagonal line (random classifier) | |
353 fig_roc.add_trace(go.Scatter( | |
354 x=[0, 1], | |
355 y=[0, 1], | |
356 mode='lines', | |
357 name='Random Classifier', | |
358 line=dict(color='gray', width=1, dash='dash'), | |
359 hovertemplate='Random Classifier<br>AUC = 0.500<extra></extra>' | |
360 )) | |
361 | |
362 # Calculate macro-average AUC | |
363 class_aucs = [roc_auc[i] for i in range(n_classes) if i in roc_auc] | |
364 if class_aucs: | |
365 macro_auc = np.mean(class_aucs) | |
366 title_text = f"ROC Curves (Micro-avg = {roc_auc['micro']:.3f}, Macro-avg = {macro_auc:.3f})" | |
367 else: | |
368 title_text = f"ROC Curves (Micro-avg = {roc_auc['micro']:.3f})" | |
369 | |
370 fig_roc.update_layout( | |
371 title=dict(text=title_text, x=0.5), | |
372 xaxis_title="False Positive Rate", | |
373 yaxis_title="True Positive Rate", | |
374 width=700, | |
375 height=600, | |
376 margin=dict(t=80, l=80, r=80, b=80), | |
377 legend=dict( | |
378 x=0.6, | |
379 y=0.1, | |
380 bgcolor="rgba(255,255,255,0.9)", | |
381 bordercolor="rgba(0,0,0,0.2)", | |
382 borderwidth=1 | |
383 ), | |
384 hovermode='closest' | |
385 ) | |
386 | |
387 # Set equal aspect ratio and proper range | |
388 fig_roc.update_xaxes(range=[0, 1.0]) | |
389 fig_roc.update_yaxes(range=[0, 1.05]) | |
390 | |
391 return { | |
392 "title": "ROC-AUC Curves", | |
393 "html": pio.to_html( | |
394 fig_roc, | |
395 full_html=False, | |
396 include_plotlyjs=False, | |
397 config=config | |
398 ) | |
399 } | |
400 | |
401 except Exception as e: | |
402 print(f"Error building ROC-AUC plot: {e}") | |
403 return None |