comparison plotly_plots.py @ 11:c5150cceab47 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
author goeckslab
date Sat, 18 Oct 2025 03:17:09 +0000
parents 85e6f4b2ad18
children
comparing 10:b0d893d04d4c with 11:c5150cceab47
import json
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
from constants import LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME
from sklearn.metrics import auc, roc_curve
from sklearn.preprocessing import label_binarize


def build_classification_plots(
    test_stats_path: str,
    training_stats_path: Optional[str] = None,
# ... (unchanged lines omitted) ...

    plots: List[Dict[str, str]] = []

    # 0) Confusion Matrix
    cm = np.array(label_stats["confusion_matrix"], dtype=int)
    # Try to get actual class names from per_class_stats keys (which contain the real labels)
    pcs = label_stats.get("per_class_stats", {})
    if pcs:
        labels = list(pcs.keys())
    else:
        labels = label_stats.get("labels", [str(i) for i in range(cm.shape[0])])
    total = cm.sum()

    fig_cm = go.Figure(
        go.Heatmap(
            z=cm,
# ... (unchanged lines omitted) ...

            include_plotlyjs="cdn",
            config=common_cfg
        )
    })

    # 1) ROC-AUC Curves (Multi-class)
    roc_plot = _build_roc_auc_plot(test_stats_path, labels, common_cfg)
    if roc_plot:
        plots.append(roc_plot)

    # 2) Classification Report Heatmap
    pcs = label_stats.get("per_class_stats", {})
    if pcs:
        classes = list(pcs.keys())
        metrics = ["precision", "recall", "f1_score"]
# ... (unchanged lines omitted) ...

                config=common_cfg
            )
        })

    return plots


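For context on the helper that follows: its docstring and comments say it follows scikit-learn's multi-class ROC recipe (binarize the labels, one curve per class, plus a micro-average over the flattened arrays). A minimal, self-contained sketch of that recipe on toy data (the arrays below are hypothetical values for illustration, not taken from this repository):

import numpy as np
from sklearn.metrics import auc, roc_curve
from sklearn.preprocessing import label_binarize

# Toy inputs: 3 classes, 6 samples (illustrative values only).
classes = ["a", "b", "c"]
y_true = np.array(["a", "b", "c", "a", "b", "c"])
y_score = np.array([
    [0.8, 0.1, 0.1],
    [0.2, 0.6, 0.2],
    [0.1, 0.2, 0.7],
    [0.5, 0.3, 0.2],
    [0.3, 0.5, 0.2],
    [0.2, 0.2, 0.6],
])

# Binarize labels into an (n_samples, n_classes) indicator matrix.
y_bin = label_binarize(y_true, classes=classes)

# One ROC curve and AUC per class ...
per_class_auc = {}
for i, name in enumerate(classes):
    fpr, tpr, _ = roc_curve(y_bin[:, i], y_score[:, i])
    per_class_auc[name] = auc(fpr, tpr)

# ... plus a micro-average over the flattened indicator/score matrices.
fpr_micro, tpr_micro, _ = roc_curve(y_bin.ravel(), y_score.ravel())
micro_auc = auc(fpr_micro, tpr_micro)
print(per_class_auc, micro_auc)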
def _build_roc_auc_plot(test_stats_path: str, class_labels: List[str], config: dict) -> Optional[Dict[str, str]]:
    """
    Build an interactive ROC-AUC curve plot for multi-class classification.
    Following sklearn's ROC example with micro-average and per-class curves.

    Args:
        test_stats_path: Path to test_statistics.json
        class_labels: List of class label names
        config: Plotly config dict

    Returns:
        Dict with title and HTML, or None if data unavailable
    """
    try:
        # Get the experiment directory from test_stats_path
        exp_dir = Path(test_stats_path).parent

        # Load predictions with probabilities
        predictions_path = exp_dir / "predictions.csv"
        if not predictions_path.exists():
            return None

        df_pred = pd.read_csv(predictions_path)

        if SPLIT_COLUMN_NAME in df_pred.columns:
            split_series = df_pred[SPLIT_COLUMN_NAME].astype(str).str.lower()
            test_mask = split_series.isin({"2", "test", "testing"})
            if test_mask.any():
                df_pred = df_pred[test_mask].reset_index(drop=True)

        if df_pred.empty:
            return None

        # Extract probability columns (label_probabilities_0, label_probabilities_1, etc.)
        # or label_probabilities_<class_name> for string labels
        prob_cols = [col for col in df_pred.columns if col.startswith('label_probabilities_') and col != 'label_probabilities']

        # Sort by class number if numeric, otherwise keep alphabetical order
        if prob_cols and prob_cols[0].split('_')[-1].isdigit():
            prob_cols.sort(key=lambda x: int(x.split('_')[-1]))
        else:
            prob_cols.sort()  # Alphabetical sort for string class names

        if not prob_cols:
            return None

        # Get probabilities matrix (n_samples x n_classes)
        y_score = df_pred[prob_cols].values
        n_classes = len(prob_cols)

        y_true = None
        candidate_cols = [
            LABEL_COLUMN_NAME,
            f"{LABEL_COLUMN_NAME}_ground_truth",
            f"{LABEL_COLUMN_NAME}__ground_truth",
            f"{LABEL_COLUMN_NAME}_target",
            f"{LABEL_COLUMN_NAME}__target",
        ]
        candidate_cols.extend(
            [
                col
                for col in df_pred.columns
                if (col.startswith(f"{LABEL_COLUMN_NAME}_") or col.startswith(f"{LABEL_COLUMN_NAME}__"))
                and "probabilities" not in col
                and "predictions" not in col
            ]
        )
        for col in candidate_cols:
            if col in df_pred.columns and col not in prob_cols:
                y_true = df_pred[col].values
                break

        if y_true is None:
            desc_path = exp_dir / "description.json"
            if desc_path.exists():
                try:
                    with open(desc_path, 'r') as f:
                        desc = json.load(f)
                    dataset_path = desc.get('dataset', '')
                    if dataset_path and Path(dataset_path).exists():
                        df_orig = pd.read_csv(dataset_path)
                        if SPLIT_COLUMN_NAME in df_orig.columns:
                            df_orig = df_orig[df_orig[SPLIT_COLUMN_NAME] == 2].reset_index(drop=True)
                        if LABEL_COLUMN_NAME in df_orig.columns:
                            y_true = df_orig[LABEL_COLUMN_NAME].values
                            if len(y_true) != len(df_pred):
                                print(
                                    f"Warning: Test set size mismatch. Truncating to {len(df_pred)} samples for ROC plot."
                                )
                                y_true = y_true[:len(df_pred)]
                    else:
                        print("Warning: Original dataset referenced in description.json is unavailable.")
                except Exception as exc:  # pragma: no cover - defensive
                    print(f"Warning: Failed to recover labels from dataset: {exc}")

        if y_true is None or len(y_true) == 0:
            print("Warning: Unable to locate ground-truth labels for ROC plot.")
            return None

        if len(y_true) != len(y_score):
            limit = min(len(y_true), len(y_score))
            if limit == 0:
                return None
            print(f"Warning: Aligning prediction and label lengths to {limit} samples for ROC plot.")
            y_true = y_true[:limit]
            y_score = y_score[:limit]

        # Get actual class names from probability column names
        actual_classes = [col.replace('label_probabilities_', '') for col in prob_cols]
        display_classes = class_labels if len(class_labels) == n_classes else actual_classes

        # Binarize the output following sklearn example
        # Use actual class names if they're strings, otherwise use range
        if isinstance(y_true[0], str):
            y_test = label_binarize(y_true, classes=actual_classes)
        else:
            y_test = label_binarize(y_true, classes=list(range(n_classes)))

        # Handle binary classification case
        if y_test.ndim != 2:
            y_test = np.atleast_2d(y_test)

        if n_classes == 2:
            if y_test.shape[1] == 1:
                y_test = np.hstack([1 - y_test, y_test])
            elif y_test.shape[1] != 2:
                print("Warning: Unexpected label binarization shape for binary ROC plot.")
                return None
        elif y_test.shape[1] != n_classes:
            print("Warning: Label binarization did not produce expected class dimension; skipping ROC plot.")
            return None

        # Compute ROC curve and ROC area for each class (following sklearn example)
        fpr = dict()
        tpr = dict()
        roc_auc = dict()

        for i in range(n_classes):
            if np.sum(y_test[:, i]) > 0:  # Check if class exists in test set
                fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
                roc_auc[i] = auc(fpr[i], tpr[i])

        # Compute micro-average ROC curve and ROC area (sklearn example)
        fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())
        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

        # Create ROC curve plot
        fig_roc = go.Figure()

        # Colors for different classes
        colors = [
            '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
            '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf'
        ]

        # Plot micro-average ROC curve first (most important)
        fig_roc.add_trace(go.Scatter(
            x=fpr["micro"],
            y=tpr["micro"],
            mode='lines',
            name=f'Micro-average ROC (AUC = {roc_auc["micro"]:.3f})',
            line=dict(color='deeppink', width=3, dash='dot'),
            hovertemplate=('<b>Micro-average ROC</b><br>'
                           'FPR: %{x:.3f}<br>'
                           'TPR: %{y:.3f}<br>'
                           f'AUC: {roc_auc["micro"]:.3f}<extra></extra>')
        ))

        # Plot ROC curve for each class
        for i in range(n_classes):
            if i in roc_auc:  # Only plot if class exists in test set
                class_name = display_classes[i] if i < len(display_classes) else f"Class {i}"
                color = colors[i % len(colors)]

                fig_roc.add_trace(go.Scatter(
                    x=fpr[i],
                    y=tpr[i],
                    mode='lines',
                    name=f'{class_name} (AUC = {roc_auc[i]:.3f})',
                    line=dict(color=color, width=2),
                    hovertemplate=(f'<b>{class_name}</b><br>'
                                   'FPR: %{x:.3f}<br>'
                                   'TPR: %{y:.3f}<br>'
                                   f'AUC: {roc_auc[i]:.3f}<extra></extra>')
                ))

        # Add diagonal line (random classifier)
        fig_roc.add_trace(go.Scatter(
            x=[0, 1],
            y=[0, 1],
            mode='lines',
            name='Random Classifier',
            line=dict(color='gray', width=1, dash='dash'),
            hovertemplate='Random Classifier<br>AUC = 0.500<extra></extra>'
        ))

        # Calculate macro-average AUC
        class_aucs = [roc_auc[i] for i in range(n_classes) if i in roc_auc]
        if class_aucs:
            macro_auc = np.mean(class_aucs)
            title_text = f"ROC Curves (Micro-avg = {roc_auc['micro']:.3f}, Macro-avg = {macro_auc:.3f})"
        else:
            title_text = f"ROC Curves (Micro-avg = {roc_auc['micro']:.3f})"

        fig_roc.update_layout(
            title=dict(text=title_text, x=0.5),
            xaxis_title="False Positive Rate",
            yaxis_title="True Positive Rate",
            width=700,
            height=600,
            margin=dict(t=80, l=80, r=80, b=80),
            legend=dict(
                x=0.6,
                y=0.1,
                bgcolor="rgba(255,255,255,0.9)",
                bordercolor="rgba(0,0,0,0.2)",
                borderwidth=1
            ),
            hovermode='closest'
        )

        # Set equal aspect ratio and proper range
        fig_roc.update_xaxes(range=[0, 1.0])
        fig_roc.update_yaxes(range=[0, 1.05])

        return {
            "title": "ROC-AUC Curves",
            "html": pio.to_html(
                fig_roc,
                full_html=False,
                include_plotlyjs=False,
                config=config
            )
        }

    except Exception as e:
        print(f"Error building ROC-AUC plot: {e}")
        return None
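Usage sketch: a minimal driver showing how the returned plot fragments might be stitched into one report page, assuming the module is importable as plotly_plots and that the experiment directory holds test_statistics.json alongside predictions.csv. The paths and report layout below are illustrative assumptions; only build_classification_plots, the {"title", "html"} entries, and the include_plotlyjs settings come from this changeset.

from plotly_plots import build_classification_plots

# Hypothetical experiment directory produced by the training step.
plots = build_classification_plots("experiment_run/test_statistics.json")

# Each entry is {"title": ..., "html": ...}; concatenate them into one HTML page.
with open("classification_report.html", "w") as fh:
    fh.write("<html><body>\n")
    for plot in plots:
        fh.write(f"<h2>{plot['title']}</h2>\n{plot['html']}\n")
    fh.write("</body></html>\n")

The fragments are meant to share a single page: the confusion-matrix figure is exported with include_plotlyjs="cdn" and loads plotly.js, while the ROC figure uses include_plotlyjs=False and reuses it.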