plotly_plots.py @ 18:bbf30253c99f (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
| author | goeckslab |
|---|---|
| date | Sun, 14 Dec 2025 03:27:12 +0000 |
| parents | db9be962dc13 |
| children | |
import json
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
from constants import LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME
from sklearn.metrics import (
    accuracy_score,
    auc,
    average_precision_score,
    f1_score,
    precision_recall_curve,
    precision_score,
    recall_score,
    roc_curve,
)
from sklearn.preprocessing import label_binarize


def _style_fig(fig: go.Figure, font_size: int = 12) -> go.Figure:
    """Apply consistent styling across Plotly figures."""
    fig.update_layout(
        font=dict(size=font_size),
        plot_bgcolor="#ffffff",
        paper_bgcolor="#ffffff",
    )
    fig.update_xaxes(gridcolor="#e8e8e8")
    fig.update_yaxes(gridcolor="#e8e8e8")
    return fig


def _fig_to_html(
    fig: go.Figure, *, include_js: bool = False, config: Optional[dict] = None
) -> str:
    """Render a Plotly figure to a lightweight HTML fragment."""
    include_plotlyjs = "cdn" if include_js else False
    return pio.to_html(
        fig,
        full_html=False,
        include_plotlyjs=include_plotlyjs,
        config=config,
    )


def _wrap_plot(
    title: str,
    fig: go.Figure,
    *,
    include_js: bool = False,
    config: Optional[dict] = None,
) -> Dict[str, str]:
    """Package a figure with its title for downstream HTML rendering."""
    return {"title": title, "html": _fig_to_html(fig, include_js=include_js, config=config)}


def _line_chart(
    traces: List[tuple],
    *,
    title: str,
    yaxis_title: str,
) -> go.Figure:
    """Build a basic epoch-indexed line chart for train/val/test curves."""
    fig = go.Figure()
    for name, series in traces:
        if not series:
            continue
        epochs = list(range(1, len(series) + 1))
        fig.add_trace(
            go.Scatter(
                x=epochs,
                y=series,
                mode="lines+markers",
                name=name,
                line=dict(width=4),
            )
        )
    fig.update_layout(
        title=dict(text=title, x=0.5),
        xaxis_title="Epoch",
        yaxis_title=yaxis_title,
        width=760,
        height=520,
        hovermode="x unified",
    )
    _style_fig(fig)
    return fig


def _labels_from_metadata_dict(meta_dict: dict) -> List[str]:
    """Extract ordered label names from Ludwig train_set_metadata."""
    if not isinstance(meta_dict, dict):
        return []
    for key in ("idx2str", "idx2label", "vocab"):
        seq = meta_dict.get(key)
        if isinstance(seq, list) and seq:
            return [str(v) for v in seq]
    str2idx = meta_dict.get("str2idx")
    if isinstance(str2idx, dict) and str2idx:
        int_indices = [v for v in str2idx.values() if isinstance(v, int)]
        if int_indices:
            max_idx = max(int_indices)
            ordered = [None] * (max_idx + 1)
            for name, idx in str2idx.items():
                if isinstance(idx, int) and 0 <= idx < len(ordered):
                    ordered[idx] = name
            return [str(v) for v in ordered if v is not None]
    return []


def _resolve_confusion_labels(
    label_stats: dict,
    n_classes: int,
    metadata_csv_path: Optional[str],
    train_set_metadata_path: Optional[str],
) -> List[str]:
    """Prefer original labels from metadata; fall back to stats if unavailable."""
    if train_set_metadata_path:
        try:
            meta_path = Path(train_set_metadata_path)
            if meta_path.exists():
                with open(meta_path, "r") as f:
                    meta_json = json.load(f)
                label_meta = meta_json.get(LABEL_COLUMN_NAME)
                if not isinstance(label_meta, dict):
                    label_meta = next(
                        (
                            v
                            for v in meta_json.values()
                            if isinstance(v, dict)
                            and any(k in v for k in ("idx2str", "str2idx", "idx2label", "vocab"))
                        ),
                        None,
                    )
                labels_from_meta = _labels_from_metadata_dict(label_meta) if label_meta else []
                if labels_from_meta and len(labels_from_meta) >= n_classes:
                    return [str(label) for label in labels_from_meta[:n_classes]]
        except Exception as exc:
            print(f"Warning: Unable to read labels from train_set_metadata: {exc}")

    if metadata_csv_path:
        try:
            csv_path = Path(metadata_csv_path)
            if csv_path.exists():
                df_meta = pd.read_csv(csv_path)
                if LABEL_COLUMN_NAME in df_meta.columns:
                    uniques = df_meta[LABEL_COLUMN_NAME].dropna().unique().tolist()
                    if uniques and len(uniques) >= n_classes:
                        return [str(u) for u in uniques[:n_classes]]
        except Exception as exc:
            print(f"Warning: Unable to read labels from metadata CSV: {exc}")

    pcs = label_stats.get("per_class_stats", {})
    if pcs:
        pcs_labels = [str(k) for k in pcs.keys()]
        if len(pcs_labels) >= n_classes:
            return pcs_labels[:n_classes]

    labels = label_stats.get("labels")
    if not labels:
        labels = [str(i) for i in range(n_classes)]
    if len(labels) < n_classes:
        labels = labels + [str(i) for i in range(len(labels), n_classes)]
    return [str(label) for label in labels[:n_classes]]


def build_classification_plots(
    test_stats_path: str,
    training_stats_path: Optional[str] = None,
    metadata_csv_path: Optional[str] = None,
    train_set_metadata_path: Optional[str] = None,
    threshold: Optional[float] = None,
) -> List[Dict[str, str]]:
    """
    Read Ludwig’s test_statistics.json and build three interactive Plotly panels:
      - Confusion Matrix
      - ROC-AUC
      - Classification Report Heatmap

    If metadata paths are provided, the confusion matrix axes will use the
    original label values from the training metadata rather than
    integer-encoded labels.

    Returns a list of dicts, each with:
      {
        "title": <plot title>,
        "html": <HTML fragment for embedding>
      }
    """
    # --- Load test stats ---
    with open(test_stats_path, "r") as f:
        test_stats = json.load(f)
    label_stats = test_stats["label"]

    # common sizing
    cell = 40
    n_classes = len(label_stats["confusion_matrix"])
    side_px = max(cell * n_classes + 200, 600)
    common_cfg = {"displayModeBar": True, "scrollZoom": True}

    plots: List[Dict[str, str]] = []

    # 0) Confusion Matrix
    cm = np.array(label_stats["confusion_matrix"], dtype=int)
    labels = _resolve_confusion_labels(
        label_stats,
        n_classes,
        metadata_csv_path=metadata_csv_path,
        train_set_metadata_path=train_set_metadata_path,
    )
    total = cm.sum()

    fig_cm = go.Figure(
        go.Heatmap(
            z=cm,
            x=labels,
            y=labels,
            colorscale="Blues",
            showscale=True,
            colorbar=dict(title="Count"),
        )
    )
    fig_cm.update_traces(xgap=2, ygap=2)
    cm_title = "Confusion Matrix"
    if threshold is not None:
        cm_title = f"Confusion Matrix (Threshold: {threshold})"
    fig_cm.update_layout(
        title=dict(text=cm_title, x=0.5),
        xaxis_title="Predicted",
        yaxis_title="Observed",
        yaxis_autorange="reversed",
        width=side_px,
        height=side_px,
        margin=dict(t=100, l=80, r=80, b=80),
    )
    _style_fig(fig_cm)

    # annotate counts and percentages
    mval = cm.max() if cm.size else 0
    thresh = mval / 2
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            v = cm[i, j]
            pct = (v / total * 100) if total > 0 else 0
            color = "white" if v > thresh else "black"
            fig_cm.add_annotation(
                x=labels[j],
                y=labels[i],
                text=f"<b>{v}</b>",
                showarrow=False,
                font=dict(color=color, size=14),
                xanchor="center",
                yanchor="bottom",
                yshift=2,
            )
            fig_cm.add_annotation(
                x=labels[j],
                y=labels[i],
                text=f"{pct:.1f}%",
                showarrow=False,
                font=dict(color=color, size=13),
                xanchor="center",
                yanchor="top",
                yshift=-2,
            )
    plots.append(
        _wrap_plot("Confusion Matrix", fig_cm, include_js=True, config=common_cfg)
    )

    # 1) ROC / PR curves only for binary tasks
    if n_classes == 2:
        roc_plot = _build_static_roc_plot(label_stats, common_cfg, friendly_labels=labels)
        if roc_plot:
            plots.append(roc_plot)
        pr_plot = _build_precision_recall_plot(label_stats, common_cfg)
        if pr_plot:
            plots.append(pr_plot)

    # 2) Classification Report Heatmap
    pcs = label_stats.get("per_class_stats", {})
    if pcs:
        classes = list(pcs.keys())
metrics = [ "precision", "recall", "f1_score", "accuracy", "matthews_correlation_coefficient", "specificity", ] z, txt = [], [] for c in classes: row, trow = [], [] for m in metrics: val = pcs[c].get(m, 0) row.append(val) trow.append(f"{val:.2f}") z.append(row) txt.append(trow) fig_cr = go.Figure( go.Heatmap( z=z, x=[m.replace("_", " ") for m in metrics], y=[str(c) for c in classes], text=txt, texttemplate="%{text}", colorscale="Reds", showscale=True, colorbar=dict(title="Value"), ) ) fig_cr.update_layout( title="Per-Class metrics", xaxis_title="", yaxis_title="Class", width=side_px, height=side_px, margin=dict(t=80, l=80, r=80, b=80), ) _style_fig(fig_cr) plots.append( _wrap_plot("Per-Class metrics", fig_cr, config=common_cfg) ) # 3) Prediction Diagnostics (from predictions.csv) # Note: appended separately in generate_html_report, not returned here. return plots def build_train_validation_plots(train_stats_path: str) -> List[Dict[str, str]]: """Generate Train/Validation learning curve plots from training_statistics.json.""" if not train_stats_path or not Path(train_stats_path).exists(): return [] try: with open(train_stats_path, "r") as f: train_stats = json.load(f) except Exception as exc: print(f"Warning: Unable to read training statistics: {exc}") return [] label_train = (train_stats.get("training") or {}).get("label", {}) label_val = (train_stats.get("validation") or {}).get("label", {}) if not label_train and not label_val: return [] plots: List[Dict[str, str]] = [] include_js = True # Load Plotly.js once for this group def _get_series(stats: dict, metric: str) -> List[float]: vals = stats.get(metric, []) if isinstance(vals, list): return [float(v) for v in vals] try: return [float(vals)] except Exception: return [] metric_specs = [ ("loss", "Loss across epochs", "Loss"), ("accuracy", "Accuracy across epochs", "Accuracy"), ("roc_auc", "ROC-AUC across epochs", "ROC-AUC"), ("precision", "Precision across epochs", "Precision"), ("recall", "Recall/Sensitivity across epochs", "Recall"), ("specificity", "Specificity across epochs", "Specificity"), ] for key, title, yaxis in metric_specs: train_series = _get_series(label_train, key) val_series = _get_series(label_val, key) if not train_series and not val_series: continue fig = _line_chart( [("Train", train_series), ("Validation", val_series)], title=title, yaxis_title=yaxis, ) plots.append(_wrap_plot(title, fig, include_js=include_js)) include_js = False # Precision vs Recall evolution (validation) val_prec = _get_series(label_val, "precision") val_rec = _get_series(label_val, "recall") if val_prec and val_rec: max_len = min(len(val_prec), len(val_rec)) fig_pr = _line_chart( [ ("Precision", val_prec[:max_len]), ("Recall", val_rec[:max_len]), ], title="Validation Precision and Recall by Epoch", yaxis_title="Value", ) plots.append(_wrap_plot("Precision vs Recall Evolution", fig_pr, include_js=include_js)) include_js = False def _compute_f1(p: List[float], r: List[float]) -> List[float]: return [ 0.0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec) for prec, rec in zip(p, r) ] f1_train = _compute_f1(_get_series(label_train, "precision"), _get_series(label_train, "recall")) f1_val = _compute_f1(val_prec, val_rec) if f1_train or f1_val: fig_f1 = _line_chart( [("Train", f1_train), ("Validation", f1_val)], title="F1-Score across epochs (derived)", yaxis_title="F1-Score", ) plots.append(_wrap_plot("F1-Score across epochs (derived)", fig_f1, include_js=include_js)) include_js = False # Overfitting Gap: Train vs Val ROC-AUC (gap) roc_train = 
_get_series(label_train, "roc_auc") roc_val = _get_series(label_val, "roc_auc") if roc_train and roc_val: max_len = min(len(roc_train), len(roc_val)) gaps = [t - v for t, v in zip(roc_train[:max_len], roc_val[:max_len])] fig_gap = _line_chart( [("Train - Val ROC-AUC", gaps)], title="Overfitting gap: ROC-AUC across epochs", yaxis_title="Gap", ) plots.append(_wrap_plot("Overfitting gap: ROC-AUC across epochs", fig_gap, include_js=include_js)) include_js = False # Best Epoch Dashboard (based on max val ROC-AUC) if roc_val: best_idx = int(np.argmax(roc_val)) best_epoch = best_idx + 1 metrics_at_best: Dict[str, Optional[float]] = { "ROC-AUC": roc_val[best_idx] if best_idx < len(roc_val) else None } for metric_key, label in [ ("accuracy", "Accuracy"), ("balanced_accuracy", "Balanced Accuracy"), ("precision", "Precision"), ("recall", "Recall"), ("specificity", "Specificity"), ("loss", "Loss"), ]: series = _get_series(label_val, metric_key) if series and best_idx < len(series): metrics_at_best[label] = series[best_idx] if f1_val and best_idx < len(f1_val): metrics_at_best["F1-Score (derived)"] = f1_val[best_idx] fig_best = go.Figure() for name, value in metrics_at_best.items(): if value is not None: fig_best.add_trace(go.Bar(name=name, x=[name], y=[value])) fig_best.update_layout( title=dict(text=f"Best Epoch Dashboard (Val ROC-AUC @ epoch {best_epoch})", x=0.5), xaxis_title="Metric", yaxis_title="Value", width=760, height=520, showlegend=False, ) _style_fig(fig_best) plots.append(_wrap_plot("Best Validation Epoch Snapshot (Metrics)", fig_best, include_js=include_js)) return plots def _get_regression_series(split_stats: dict, metric: str) -> List[float]: if metric not in split_stats: return [] vals = split_stats.get(metric, []) if isinstance(vals, list): return [float(v) for v in vals] try: return [float(vals)] except Exception: return [] def _regression_line_plot( train_split: dict, val_split: dict, metric_key: str, title: str, yaxis_title: str, include_js: bool, ) -> Optional[Dict[str, str]]: train_series = _get_regression_series(train_split, metric_key) val_series = _get_regression_series(val_split, metric_key) if not train_series and not val_series: return None fig = _line_chart( [("Train", train_series), ("Validation", val_series)], title=title, yaxis_title=yaxis_title, ) return _wrap_plot(title, fig, include_js=include_js) def build_regression_train_val_plots(train_stats_path: str) -> List[Dict[str, str]]: """Generate regression Train/Validation learning curve plots from training_statistics.json.""" if not train_stats_path or not Path(train_stats_path).exists(): return [] try: with open(train_stats_path, "r") as f: train_stats = json.load(f) except Exception as exc: print(f"Warning: Unable to read training statistics: {exc}") return [] label_train = (train_stats.get("training") or {}).get("label", {}) label_val = (train_stats.get("validation") or {}).get("label", {}) if not label_train and not label_val: return [] plots: List[Dict[str, str]] = [] include_js = True for metric_key, title, ytitle in [ ("mean_absolute_error", "Mean Absolute Error across epochs", "MAE"), ("root_mean_squared_error", "Root Mean Squared Error across epochs", "RMSE"), ("mean_absolute_percentage_error", "Mean Absolute Percentage Error across epochs", "MAPE"), ("r2", "R² across epochs", "R²"), ("loss", "Loss across epochs", "Loss"), ]: plot = _regression_line_plot(label_train, label_val, metric_key, title, ytitle, include_js) if plot: plots.append(plot) include_js = False return plots def 
def build_regression_test_plots(train_stats_path: str) -> List[Dict[str, str]]:
    """Generate regression Test learning curves from training_statistics.json."""
    if not train_stats_path or not Path(train_stats_path).exists():
        return []
    try:
        with open(train_stats_path, "r") as f:
            train_stats = json.load(f)
    except Exception as exc:
        print(f"Warning: Unable to read training statistics: {exc}")
        return []

    label_test = (train_stats.get("test") or {}).get("label", {})
    if not label_test:
        return []

    plots: List[Dict[str, str]] = []
    include_js = True
    metrics = [
        ("mean_absolute_error", "Mean Absolute Error Across Epochs", "MAE"),
        ("root_mean_squared_error", "Root Mean Squared Error Across Epochs", "RMSE"),
        ("mean_absolute_percentage_error", "Mean Absolute Percentage Error Across Epochs", "MAPE"),
        ("r2", "R² Across Epochs", "R²"),
        ("loss", "Loss Across Epochs", "Loss"),
    ]
    for metric_key, title, ytitle in metrics:
        series = _get_regression_series(label_test, metric_key)
        if not series:
            continue
        fig = _line_chart(
            [("Test", series)],
            title=title,
            yaxis_title=ytitle,
        )
        plots.append(_wrap_plot(title, fig, include_js=include_js))
        include_js = False
    return plots


def _build_static_roc_plot(
    label_stats: dict,
    config: dict,
    friendly_labels: Optional[List[str]] = None,
    threshold: Optional[float] = None,
) -> Optional[Dict[str, str]]:
    """Build ROC curve directly from test_statistics.json (single curve)."""
    roc_data = label_stats.get("roc_curve")
    if not isinstance(roc_data, dict):
        return None
    fpr = roc_data.get("false_positive_rate")
    tpr = roc_data.get("true_positive_rate")
    if not fpr or not tpr or len(fpr) != len(tpr):
        return None
    try:
        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=fpr,
                y=tpr,
                mode="lines+markers",
                name="ROC Curve",
                line=dict(color="#1f77b4", width=4),
                hovertemplate="FPR: %{x:.3f}<br>TPR: %{y:.3f}<extra></extra>",
            )
        )
        fig.add_trace(
            go.Scatter(
                x=[0, 1],
                y=[0, 1],
                mode="lines",
                name="Random Classifier",
                line=dict(color="gray", width=2, dash="dash"),
                hovertemplate="Random Classifier<extra></extra>",
            )
        )
        auc_val = label_stats.get("roc_auc") or label_stats.get("roc_auc_macro") or label_stats.get("roc_auc_micro")
        auc_txt = f" (AUC = {auc_val:.3f})" if isinstance(auc_val, (int, float)) else ""

        # Determine which label is treated as positive for the curve
        label_list: List = []
        pcs = label_stats.get("per_class_stats", {})
        if pcs:
            label_list = list(pcs.keys())
        if not label_list:
            labels_from_stats = label_stats.get("labels")
            if isinstance(labels_from_stats, list):
                label_list = labels_from_stats

        # Try to resolve index of the positive label explicitly provided by Ludwig
        pos_label_raw = (
            roc_data.get("positive_label")
            or roc_data.get("positive_class")
            or label_stats.get("positive_label")
        )
        pos_label_idx = None
        if pos_label_raw is not None and isinstance(label_list, list):
            try:
                pos_label_idx = label_list.index(pos_label_raw)
            except ValueError:
                pos_label_idx = None

        # Fallback: use the second label if available, otherwise the first
        if pos_label_idx is None:
            if isinstance(label_list, list) and len(label_list) >= 2:
                pos_label_idx = 1
            elif isinstance(label_list, list) and label_list:
                pos_label_idx = 0
        if pos_label_raw is None and isinstance(label_list, list) and pos_label_idx is not None:
            pos_label_raw = label_list[pos_label_idx]

        # Map to friendly label if we have one from metadata/CSV
        pos_label_display = pos_label_raw
        if (
            friendly_labels
            and isinstance(pos_label_idx, int)
            and 0 <= pos_label_idx < len(friendly_labels)
        ):
            pos_label_display = friendly_labels[pos_label_idx]

        pos_label_txt = (
            f"Positive class: {pos_label_display}"
            if pos_label_display is not None
            else "Positive class: (not available)"
        )

        title_label = f"ROC Curve{auc_txt}"
        if pos_label_display is not None:
            title_label = f"ROC Curve (Positive Class: {pos_label_display}){auc_txt}"

        fig.update_layout(
            title=dict(text=title_label, x=0.5),
            xaxis_title="False Positive Rate",
            yaxis_title="True Positive Rate",
            width=700,
            height=600,
            margin=dict(t=80, l=80, r=80, b=110),
            hovermode="closest",
            legend=dict(
                x=0.6,
                y=0.1,
                bgcolor="rgba(255,255,255,0.9)",
                bordercolor="rgba(0,0,0,0.2)",
                borderwidth=1,
            ),
        )
        _style_fig(fig)
        fig.update_xaxes(range=[0, 1.0])
        fig.update_yaxes(range=[0, 1.05])

        roc_thresholds = roc_data.get("thresholds")
        if threshold is not None and isinstance(roc_thresholds, list) and len(roc_thresholds) == len(fpr):
            try:
                diffs = [abs(th - threshold) for th in roc_thresholds]
                best_idx = int(np.argmin(diffs))
                # dashed guides through the chosen point
                fig.add_shape(
                    type="line",
                    x0=fpr[best_idx],
                    x1=fpr[best_idx],
                    y0=0,
                    y1=tpr[best_idx],
                    line=dict(color="gray", width=2, dash="dash"),
                )
                fig.add_shape(
                    type="line",
                    x0=0,
                    x1=fpr[best_idx],
                    y0=tpr[best_idx],
                    y1=tpr[best_idx],
                    line=dict(color="gray", width=2, dash="dash"),
                )
                fig.add_trace(
                    go.Scatter(
                        x=[fpr[best_idx]],
                        y=[tpr[best_idx]],
                        mode="markers",
                        marker=dict(color="black", size=10, symbol="x"),
                        name=f"Threshold={threshold}",
                        hovertemplate="FPR: %{x:.3f}<br>TPR: %{y:.3f}<br>Threshold: %{text}<extra></extra>",
                        text=[f"{threshold}"],
                    )
                )
            except Exception as exc:
                print(f"Warning: could not add threshold marker to ROC: {exc}")

        fig.add_annotation(
            x=0.5,
            y=-0.15,
            xref="paper",
            yref="paper",
            showarrow=False,
            text=f"<span style='font-size:12px;color:#555;'>{pos_label_txt}</span>",
            xanchor="center",
        )
        return _wrap_plot("ROC Curve", fig, config=config)
    except Exception as e:
        print(f"Error building ROC plot: {e}")
        return None


def _build_precision_recall_plot(
    label_stats: dict,
    config: dict,
    threshold: Optional[float] = None,
) -> Optional[Dict[str, str]]:
    """Build Precision-Recall curve directly from test_statistics.json."""
    pr_data = label_stats.get("precision_recall_curve")
    if not isinstance(pr_data, dict):
        return None
    precisions = pr_data.get("precisions")
    recalls = pr_data.get("recalls")
    if not precisions or not recalls or len(precisions) != len(recalls):
        return None
    thresholds = pr_data.get("thresholds")
    try:
        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=recalls,
                y=precisions,
                mode="lines+markers",
                name="Precision-Recall",
                line=dict(color="#d62728", width=4),
                hovertemplate="Recall: %{x:.3f}<br>Precision: %{y:.3f}<extra></extra>",
            )
        )
        ap_val = (
            label_stats.get("average_precision_macro")
            or label_stats.get("average_precision_micro")
            or label_stats.get("average_precision_samples")
        )
        ap_txt = f" (AP = {ap_val:.3f})" if isinstance(ap_val, (int, float)) else ""
        fig.update_layout(
            title=dict(text=f"Precision-Recall Curve{ap_txt}", x=0.5),
            xaxis_title="Recall",
            yaxis_title="Precision",
            width=700,
            height=600,
            margin=dict(t=80, l=80, r=80, b=80),
            hovermode="closest",
            legend=dict(
                x=0.6,
                y=0.1,
                bgcolor="rgba(255,255,255,0.9)",
                bordercolor="rgba(0,0,0,0.2)",
                borderwidth=1,
            ),
        )
        _style_fig(fig)
        fig.update_xaxes(range=[0, 1.0])
        fig.update_yaxes(range=[0, 1.05])

        if threshold is not None and isinstance(thresholds, list) and len(thresholds) == len(recalls):
            try:
                diffs = [abs(th - threshold) for th in thresholds]
                best_idx = int(np.argmin(diffs))
                fig.add_shape(
                    type="line",
                    x0=recalls[best_idx],
                    x1=recalls[best_idx],
                    y0=0,
                    y1=precisions[best_idx],
                    line=dict(color="gray", width=2, dash="dash"),
                )
fig.add_shape( type="line", x0=0, x1=recalls[best_idx], y0=precisions[best_idx], y1=precisions[best_idx], line=dict(color="gray", width=2, dash="dash"), ) fig.add_trace( go.Scatter( x=[recalls[best_idx]], y=[precisions[best_idx]], mode="markers", marker=dict(color="black", size=10, symbol="x"), name=f"Threshold={threshold}", hovertemplate="Recall: %{x:.3f}<br>Precision: %{y:.3f}<br>Threshold: %{text}<extra></extra>", text=[f"{threshold}"], ) ) except Exception as exc: print(f"Warning: could not add threshold marker to PR: {exc}") return _wrap_plot("Precision-Recall Curve", fig, config=config) except Exception as e: print(f"Error building Precision-Recall plot: {e}") return None def build_prediction_diagnostics( predictions_path: str, label_data_path: Optional[str] = None, split_value: int = 2, ) -> List[Dict[str, str]]: """Generate diagnostic plots from predictions.csv for classification tasks.""" preds_file = Path(predictions_path) if not preds_file.exists(): return [] try: df_pred = pd.read_csv(predictions_path) except Exception as exc: print(f"Warning: Unable to read predictions CSV: {exc}") return [] plots: List[Dict[str, str]] = [] labels_from_dataset: Optional[pd.Series] = None filtered_by_split = False # If a split column exists, focus on the requested split (e.g., validation=1, test=2). # If not, but label_data_path is available and matches row count, use it to filter predictions. if SPLIT_COLUMN_NAME in df_pred.columns: df_pred = df_pred[df_pred[SPLIT_COLUMN_NAME] == split_value].reset_index(drop=True) if df_pred.empty: return [] filtered_by_split = True elif label_data_path and Path(label_data_path).exists(): try: df_labels_all = pd.read_csv(label_data_path) if SPLIT_COLUMN_NAME in df_labels_all.columns and len(df_labels_all) == len(df_pred): split_mask = pd.to_numeric(df_labels_all[SPLIT_COLUMN_NAME], errors="coerce") == split_value labels_from_dataset = df_labels_all.loc[split_mask, LABEL_COLUMN_NAME].reset_index(drop=True) df_pred = df_pred.loc[split_mask].reset_index(drop=True) if df_pred.empty: return [] filtered_by_split = True except Exception as exc: print(f"Warning: Unable to filter predictions by split from label data: {exc}") # Fallback: no split info available. Assume the predictions file is already filtered # (common for test-only exports) and avoid heuristic slicing that could discard rows. 
if not filtered_by_split: if split_value != 2: return [] def _strip_prob_prefix(col: str) -> str: if col.startswith("label_probabilities_"): return col.replace("label_probabilities_", "") if col.startswith("probabilities_"): return col.replace("probabilities_", "") return col def _maybe_expand_probabilities_column(df: pd.DataFrame, labels_guess: List[str]) -> List[str]: """If only a single 'probabilities' column exists (list-like), expand it into per-class columns.""" if "probabilities" not in df.columns: return [] try: # Parse first non-null entry to infer length first_val = df["probabilities"].dropna().iloc[0] parsed = first_val if isinstance(first_val, str): parsed = json.loads(first_val) probs = list(parsed) n = len(probs) if n == 0: return [] # Build labels: prefer provided guess; otherwise numeric if labels_guess and len(labels_guess) == n: labels_use = labels_guess else: labels_use = [str(i) for i in range(n)] # Expand column for idx, lbl in enumerate(labels_use): df[f"probabilities_{lbl}"] = df["probabilities"].apply( lambda v: (json.loads(v)[idx] if isinstance(v, str) else list(v)[idx]) if pd.notnull(v) else np.nan ) return [f"probabilities_{lbl}" for lbl in labels_use] except Exception: return [] # Identify probability columns prob_cols = [ c for c in df_pred.columns if ( (c.startswith("label_probabilities_") or c.startswith("probabilities_")) and c != "label_probabilities" ) ] if not prob_cols and "label_probability" in df_pred.columns: prob_cols = ["label_probability"] if not prob_cols and "probability" in df_pred.columns: prob_cols = ["probability"] if not prob_cols and "prediction_probability" in df_pred.columns: prob_cols = ["prediction_probability"] if not prob_cols and "probabilities" in df_pred.columns: labels_guess = sorted([str(u) for u in pd.unique(df_pred[LABEL_COLUMN_NAME])]) prob_cols = _maybe_expand_probabilities_column(df_pred, labels_guess) prob_cols_sorted = sorted(prob_cols) def _select_positive_prob(): if not prob_cols_sorted: return None, None # Prefer a column indicating positive/event/true/1 preferred_keys = ("event", "true", "positive", "pos", "1") for col in prob_cols_sorted: suffix = _strip_prob_prefix(col).lower() if any(k in suffix for k in preferred_keys): return col, suffix if len(prob_cols_sorted) == 2: col = prob_cols_sorted[1] return col, _strip_prob_prefix(col) col = prob_cols_sorted[0] return col, _strip_prob_prefix(col) pos_prob_col, pos_label_hint = _select_positive_prob() pos_prob_series = df_pred[pos_prob_col] if pos_prob_col and pos_prob_col in df_pred else None # Confidence series: prefer label_probability, otherwise positive prob, otherwise max prob confidence_series = None if "label_probability" in df_pred.columns: confidence_series = df_pred["label_probability"] elif pos_prob_series is not None: confidence_series = pos_prob_series elif prob_cols_sorted: confidence_series = df_pred[prob_cols_sorted].max(axis=1) # True labels def _extract_labels(): if labels_from_dataset is not None: return labels_from_dataset candidates = [ LABEL_COLUMN_NAME, f"{LABEL_COLUMN_NAME}_ground_truth", f"{LABEL_COLUMN_NAME}__ground_truth", f"{LABEL_COLUMN_NAME}_target", f"{LABEL_COLUMN_NAME}__target", "label", "label_true", ] candidates.extend( [ col for col in df_pred.columns if (col.startswith(f"{LABEL_COLUMN_NAME}_") or col.startswith(f"{LABEL_COLUMN_NAME}__")) and "probabilities" not in col and "predictions" not in col ] ) for col in candidates: if col in df_pred.columns and col not in prob_cols_sorted: return df_pred[col] if label_data_path and 
        if label_data_path and Path(label_data_path).exists():
            try:
                df_all = pd.read_csv(label_data_path)
                if SPLIT_COLUMN_NAME in df_all.columns:
                    df_all = df_all[df_all[SPLIT_COLUMN_NAME] == split_value].reset_index(drop=True)
                if LABEL_COLUMN_NAME in df_all.columns:
                    return df_all[LABEL_COLUMN_NAME].reset_index(drop=True)
            except Exception as exc:
                print(f"Warning: Unable to load labels from dataset: {exc}")
        return None

    labels_series = _extract_labels()

    # Plot 1: Confidence Histogram
    if confidence_series is not None:
        fig_conf = go.Figure()
        fig_conf.add_trace(
            go.Histogram(
                x=confidence_series,
                nbinsx=20,
                marker=dict(color="#1f77b4", line=dict(color="#ffffff", width=1)),
                opacity=0.8,
                histnorm="percent",
            )
        )
        fig_conf.update_layout(
            title=dict(text="Prediction Confidence Distribution", x=0.5),
            xaxis_title="Predicted probability (confidence)",
            yaxis_title="Percentage (%)",
            bargap=0.05,
            width=700,
            height=500,
        )
        _style_fig(fig_conf)
        plots.append(_wrap_plot("Prediction Confidence Distribution", fig_conf))

    # The remaining plots require true labels and a positive-class probability
    if labels_series is None or pos_prob_series is None:
        return plots

    # Align lengths
    min_len = min(len(labels_series), len(pos_prob_series))
    if min_len == 0:
        return plots
    y_true_raw = labels_series.iloc[:min_len]
    y_score = np.array(pos_prob_series.iloc[:min_len], dtype=float)

    # Determine positive label
    unique_labels = pd.unique(y_true_raw)
    unique_labels_list = list(unique_labels)
    positive_label = None
    if pos_label_hint and str(pos_label_hint) in [str(u) for u in unique_labels_list]:
        positive_label = pos_label_hint
    elif len(unique_labels_list) == 2:
        positive_label = unique_labels_list[1]
    else:
        positive_label = unique_labels_list[0]
    y_true = (y_true_raw == positive_label).astype(int).values

    # Utility: compute calibration points
    def _calibration_points(y_true_bin: np.ndarray, scores: np.ndarray):
        bins = np.linspace(0.0, 1.0, 11)
        bin_ids = np.digitize(scores, bins, right=True)
        bin_centers, frac_positives = [], []
        for b in range(1, len(bins)):
            mask = bin_ids == b
            if not np.any(mask):
                continue
            bin_centers.append(scores[mask].mean())
            frac_positives.append(y_true_bin[mask].mean())
        return bin_centers, frac_positives

    # Plot 2: Calibration Curve (multi-class aware; one-vs-rest per label)
    label_prob_map = {}
    for col in prob_cols_sorted:
        if col.startswith("label_probabilities_"):
            cls = col.replace("label_probabilities_", "")
            label_prob_map[cls] = col
    unique_label_strs = [str(u) for u in unique_labels_list]
    if len(label_prob_map) > 1 and len(unique_label_strs) > 2:
        # Skip multi-class calibration curve for now (not informative in current report)
        pass
    else:
        # Binary/unknown fallback (previous behavior)
        bin_centers, frac_positives = _calibration_points(y_true, y_score)
        if bin_centers and frac_positives:
            fig_cal = go.Figure()
            fig_cal.add_trace(
                go.Scatter(
                    x=bin_centers,
                    y=frac_positives,
                    mode="lines+markers",
                    name="Calibration",
                    line=dict(color="#2ca02c", width=4),
                )
            )
            fig_cal.add_trace(
                go.Scatter(
                    x=[0, 1],
                    y=[0, 1],
                    mode="lines",
                    name="Perfect Calibration",
                    line=dict(color="gray", width=2, dash="dash"),
                )
            )
            fig_cal.update_layout(
                title=dict(text="Calibration Curve", x=0.5),
                xaxis_title="Predicted probability",
                yaxis_title="Observed frequency",
                width=700,
                height=500,
            )
            _style_fig(fig_cal)
            plots.append(
                _wrap_plot(
                    "Calibration Curve (Predicted Probability vs Observed Frequency)",
                    fig_cal,
                )
            )

    return plots


def build_binary_threshold_plot(
    predictions_path: str,
    label_data_path: Optional[str] = None,
    split_value: int = 1,
) -> Optional[Dict[str, str]]:
    """Build a binary threshold sweep plot (accuracy, precision, recall, F1) for a given split."""
    preds_file = Path(predictions_path)
    if not preds_file.exists():
        return None
    try:
        df_pred = pd.read_csv(predictions_path)
    except Exception as exc:
        print(f"Warning: Unable to read predictions CSV for threshold plot: {exc}")
        return None

    labels_from_dataset: Optional[pd.Series] = None
    df_full = df_pred.copy()

    def _filter_by_split(df: pd.DataFrame, split_val: int) -> pd.DataFrame:
        if SPLIT_COLUMN_NAME in df.columns:
            return df[df[SPLIT_COLUMN_NAME] == split_val].reset_index(drop=True)
        return df

    # Try preferred split, then fallback to others with data (val -> test -> train)
    candidate_splits = [split_value, 2, 0, 1] if split_value == 1 else [split_value, 1, 0, 2]
    df_candidate = pd.DataFrame()
    used_split: Optional[int] = None
    for sv in candidate_splits:
        df_candidate = _filter_by_split(df_full, sv)
        if not df_candidate.empty:
            used_split = sv
            break
    if used_split is None:
        df_candidate = df_full
    df_pred = df_candidate.reset_index(drop=True)

    # If still empty (e.g., split column exists but no rows for candidates), fall back to all rows
    if df_pred.empty:
        df_pred = df_full.reset_index(drop=True)

    labels_from_dataset = None
    if label_data_path and Path(label_data_path).exists():
        try:
            df_labels_all = pd.read_csv(label_data_path)
            if SPLIT_COLUMN_NAME in df_labels_all.columns and len(df_labels_all) == len(df_full):
                mask = (
                    pd.to_numeric(df_labels_all[SPLIT_COLUMN_NAME], errors="coerce") == used_split
                    if used_split is not None and SPLIT_COLUMN_NAME in df_labels_all.columns
                    else pd.Series([True] * len(df_full))
                )
                labels_from_dataset = df_labels_all.loc[mask, LABEL_COLUMN_NAME].reset_index(drop=True)
                if len(labels_from_dataset) == len(df_pred):
                    labels_from_dataset = labels_from_dataset.reset_index(drop=True)
        except Exception as exc:
            print(f"Warning: Unable to align labels for threshold plot: {exc}")

    # Identify probability columns
    prob_cols = [
        c
        for c in df_pred.columns
        if (
            (c.startswith("label_probabilities_") or c.startswith("probabilities_"))
            and c != "label_probabilities"
        )
    ]
    if not prob_cols and "probabilities" in df_pred.columns:
        labels_guess = sorted([str(u) for u in pd.unique(df_pred.get(LABEL_COLUMN_NAME, []))])
        # reuse expansion logic from diagnostics
        try:
            first_val = df_pred["probabilities"].dropna().iloc[0]
            parsed = json.loads(first_val) if isinstance(first_val, str) else list(first_val)
            n = len(parsed)
            if n > 0:
                if labels_guess and len(labels_guess) == n:
                    labels_use = labels_guess
                else:
                    labels_use = [str(i) for i in range(n)]
                for idx, lbl in enumerate(labels_use):
                    df_pred[f"probabilities_{lbl}"] = df_pred["probabilities"].apply(
                        lambda v: (json.loads(v)[idx] if isinstance(v, str) else list(v)[idx])
                        if pd.notnull(v)
                        else np.nan
                    )
                prob_cols = [f"probabilities_{lbl}" for lbl in labels_use]
        except Exception:
            prob_cols = []

    prob_cols_sorted = sorted(prob_cols)

    def _strip_prob_prefix(col: str) -> str:
        if col.startswith("label_probabilities_"):
            return col.replace("label_probabilities_", "")
        if col.startswith("probabilities_"):
            return col.replace("probabilities_", "")
        return col

    # True labels
    def _extract_labels():
        if labels_from_dataset is not None:
            return labels_from_dataset
        for col in [
            LABEL_COLUMN_NAME,
            f"{LABEL_COLUMN_NAME}_ground_truth",
            f"{LABEL_COLUMN_NAME}__ground_truth",
            f"{LABEL_COLUMN_NAME}_target",
            f"{LABEL_COLUMN_NAME}__target",
            "label",
            "label_true",
            "label_predictions",
            "prediction",
        ]:
            if col in df_pred.columns and col not in prob_cols_sorted:
                return df_pred[col]
        return None

    labels_series = _extract_labels()
    if labels_series is None or not prob_cols_sorted:
        return None

    # Positive prob column selection
    preferred_keys = ("event", "true", "positive", "pos", "1")
    pos_prob_col = None
    for col in prob_cols_sorted:
        suffix = _strip_prob_prefix(col).lower()
        if any(k in suffix for k in preferred_keys):
            pos_prob_col = col
            break
    if pos_prob_col is None:
        pos_prob_col = prob_cols_sorted[-1]

    min_len = min(len(labels_series), len(df_pred[pos_prob_col]))
    if min_len == 0:
        return None
    y_true = np.array(labels_series.iloc[:min_len])
    # map to binary 0/1
    unique_labels = pd.unique(y_true)
    if len(unique_labels) < 2:
        return None
    positive_label = unique_labels[1] if len(unique_labels) >= 2 else unique_labels[0]
    y_true_bin = (y_true == positive_label).astype(int)
    y_score = np.array(df_pred[pos_prob_col].iloc[:min_len], dtype=float)

    thresholds = np.linspace(0.0, 1.0, 101)
    accs: List[float] = []
    precs: List[float] = []
    recs: List[float] = []
    f1s: List[float] = []
    for t in thresholds:
        preds = (y_score >= t).astype(int)
        accs.append(accuracy_score(y_true_bin, preds))
        precs.append(precision_score(y_true_bin, preds, zero_division=0))
        recs.append(recall_score(y_true_bin, preds, zero_division=0))
        f1s.append(f1_score(y_true_bin, preds, zero_division=0))

    best_idx = int(np.argmax(f1s))
    best_thr = thresholds[best_idx]

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=thresholds, y=accs, mode="lines", name="Accuracy", line=dict(width=4)))
    fig.add_trace(go.Scatter(x=thresholds, y=precs, mode="lines", name="Precision", line=dict(width=4)))
    fig.add_trace(go.Scatter(x=thresholds, y=recs, mode="lines", name="Recall", line=dict(width=4)))
    fig.add_trace(go.Scatter(x=thresholds, y=f1s, mode="lines", name="F1-Score", line=dict(width=4)))
    fig.add_shape(
        type="line",
        x0=best_thr,
        x1=best_thr,
        y0=0,
        y1=1,
        line=dict(color="gray", width=2, dash="dash"),
    )
    fig.update_layout(
        title=dict(text="Threshold plot", x=0.5),
        xaxis_title="Threshold",
        yaxis_title="Metric value",
        yaxis=dict(range=[0, 1]),
        width=760,
        height=520,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
    )
    _style_fig(fig)
    return _wrap_plot("Threshold plot", fig, include_js=True)


def build_multiclass_roc_pr_plots(
    predictions_path: str,
    split_value: int = 2,
) -> List[Dict[str, str]]:
    """Build one-vs-rest ROC and PR curves for multi-class classification from predictions."""
    preds_file = Path(predictions_path)
    if not preds_file.exists():
        return []
    try:
        df_pred = pd.read_csv(predictions_path)
    except Exception as exc:
        print(f"Warning: Unable to read predictions CSV: {exc}")
        return []

    if SPLIT_COLUMN_NAME in df_pred.columns:
        df_pred = df_pred[df_pred[SPLIT_COLUMN_NAME] == split_value].reset_index(drop=True)
        if df_pred.empty:
            return []
    if LABEL_COLUMN_NAME not in df_pred.columns:
        return []

    # Identify per-class probability columns
    prob_cols = [
        c
        for c in df_pred.columns
        if (
            (c.startswith("label_probabilities_") or c.startswith("probabilities_"))
            and c != "label_probabilities"
        )
    ]
    if not prob_cols:
        return []
    labels = [c.replace("label_probabilities_", "").replace("probabilities_", "") for c in prob_cols]
    labels_sorted = sorted(labels)
    # Ensure all labels are present as probability columns
    prob_map = {
        c.replace("label_probabilities_", "").replace("probabilities_", ""): c
        for c in prob_cols
    }
    if len(labels_sorted) < 3:
        return []

    y_true_raw = df_pred[LABEL_COLUMN_NAME].astype(str)
    # Drop rows with NaN probabilities across any class to avoid metric errors
    prob_matrix = df_pred[[prob_map[lbl] for lbl in labels_sorted]].astype(float)
    mask_valid = ~prob_matrix.isnull().any(axis=1)
    prob_matrix = prob_matrix[mask_valid]
    y_true_raw = y_true_raw[mask_valid]
    if prob_matrix.empty:
        return []

    y_true_bin = label_binarize(y_true_raw, classes=labels_sorted)
    y_score = prob_matrix.to_numpy()

    plots: List[Dict[str, str]] = []

    # ROC: one-vs-rest + micro
    fig_roc = go.Figure()
    added_any = False
    for idx, lbl in enumerate(labels_sorted):
        if y_true_bin[:, idx].sum() == 0 or y_true_bin[:, idx].sum() == len(y_true_bin):
            continue  # skip classes without both positives and negatives
        fpr, tpr, _ = roc_curve(y_true_bin[:, idx], y_score[:, idx])
        fig_roc.add_trace(
            go.Scatter(
                x=fpr,
                y=tpr,
                mode="lines",
                name=f"{lbl} (AUC={auc(fpr, tpr):.3f})",
                line=dict(width=3),
            )
        )
        added_any = True

    # Micro-average only if we have mixed labels
    if y_true_bin.sum() > 0 and y_true_bin.sum() < y_true_bin.size:
        fpr_micro, tpr_micro, _ = roc_curve(y_true_bin.ravel(), y_score.ravel())
        fig_roc.add_trace(
            go.Scatter(
                x=fpr_micro,
                y=tpr_micro,
                mode="lines",
                name=f"Micro-average (AUC={auc(fpr_micro, tpr_micro):.3f})",
                line=dict(width=3, dash="dash"),
            )
        )
        added_any = True

    if not added_any:
        return []

    fig_roc.add_trace(
        go.Scatter(
            x=[0, 1],
            y=[0, 1],
            mode="lines",
            name="Random",
            line=dict(color="gray", width=2, dash="dot"),
        )
    )
    fig_roc.update_layout(
        title=dict(text="Multi-class ROC-AUC (one-vs-rest)", x=0.5),
        xaxis_title="False Positive Rate",
        yaxis_title="True Positive Rate",
        width=820,
        height=620,
        legend=dict(
            x=0.62,
            y=0.05,
            bgcolor="rgba(255,255,255,0.9)",
            bordercolor="rgba(0,0,0,0.2)",
            borderwidth=1,
        ),
    )
    _style_fig(fig_roc)
    plots.append(_wrap_plot("Multi-class ROC-AUC (one-vs-rest)", fig_roc))

    # PR: one-vs-rest + micro AP
    fig_pr = go.Figure()
    added_pr = False
    for idx, lbl in enumerate(labels_sorted):
        if y_true_bin[:, idx].sum() == 0:
            continue
        prec, rec, _ = precision_recall_curve(y_true_bin[:, idx], y_score[:, idx])
        ap = average_precision_score(y_true_bin[:, idx], y_score[:, idx])
        fig_pr.add_trace(
            go.Scatter(
                x=rec,
                y=prec,
                mode="lines",
                name=f"{lbl} (AP={ap:.3f})",
                line=dict(width=3),
            )
        )
        added_pr = True

    if y_true_bin.sum() > 0:
        prec_micro, rec_micro, _ = precision_recall_curve(y_true_bin.ravel(), y_score.ravel())
        ap_micro = average_precision_score(y_true_bin, y_score, average="micro")
        fig_pr.add_trace(
            go.Scatter(
                x=rec_micro,
                y=prec_micro,
                mode="lines",
                name=f"Micro-average (AP={ap_micro:.3f})",
                line=dict(width=3, dash="dash"),
            )
        )
        added_pr = True

    if not added_pr:
        return plots

    fig_pr.update_layout(
        title=dict(text="Multi-class Precision-Recall (one-vs-rest)", x=0.5),
        xaxis_title="Recall",
        yaxis_title="Precision",
        width=820,
        height=620,
        legend=dict(
            x=0.62,
            y=0.05,
            bgcolor="rgba(255,255,255,0.9)",
            bordercolor="rgba(0,0,0,0.2)",
            borderwidth=1,
        ),
    )
    _style_fig(fig_pr)
    plots.append(_wrap_plot("Multi-class Precision-Recall (one-vs-rest)", fig_pr))

    return plots


def build_multiclass_metric_plots(test_stats_path: str) -> List[Dict[str, str]]:
    """Alternative multi-class transparency plots using test_statistics.json per-class stats."""
    ts_path = Path(test_stats_path)
    if not ts_path.exists():
        return []
    try:
        with open(ts_path, "r") as f:
            test_stats = json.load(f)
    except Exception:
        return []

    label_stats = test_stats.get("label", {})
    pcs = label_stats.get("per_class_stats", {})
    if not pcs:
        return []
    classes = list(pcs.keys())
    if not classes:
        return []

    metrics = ["precision", "recall", "f1_score", "specificity", "accuracy"]
    fig_bar = go.Figure()
    for metric in metrics:
        values = []
        for cls in classes:
            v = pcs.get(cls, {}).get(metric)
            values.append(v if isinstance(v, (int, float)) else 0)
        fig_bar.add_trace(
            go.Bar(
                x=classes,
                y=values,
                name=metric.replace("_", " ").title(),
            )
        )
    fig_bar.update_layout(
        title=dict(text="Per-Class Metrics (Test)", x=0.5),
        xaxis_title="Class",
        yaxis_title="Metric value",
        barmode="group",
        width=900,
        height=600,
        legend=dict(
            x=1.02,
            y=1.0,
            bgcolor="rgba(255,255,255,0.9)",
            bordercolor="rgba(0,0,0,0.2)",
            borderwidth=1,
        ),
    )
    _style_fig(fig_bar)
    return [_wrap_plot("Per-Class Metrics (Test)", fig_bar)]
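
Below is a minimal usage sketch, not part of plotly_plots.py, showing how the builders above might be stitched into a single HTML report. It assumes plotly_plots.py and its constants module are importable, and the input paths and the report.html filename are hypothetical placeholders.

# usage_sketch.py -- illustrative only; assumes Ludwig output files exist locally.
from plotly_plots import (
    build_classification_plots,
    build_prediction_diagnostics,
    build_train_validation_plots,
)

panels = []
# Confusion matrix, ROC/PR (binary only), and per-class heatmap from test_statistics.json.
panels += build_classification_plots(
    "test_statistics.json",
    metadata_csv_path="metadata.csv",
    train_set_metadata_path="train_set_metadata.json",
)
# Train/validation learning curves from training_statistics.json.
panels += build_train_validation_plots("training_statistics.json")
# Confidence histogram and calibration curve from predictions.csv.
panels += build_prediction_diagnostics("predictions.csv", label_data_path="metadata.csv")

# Each builder returns dicts of the form {"title": ..., "html": ...}; concatenate the fragments.
with open("report.html", "w") as fh:
    fh.write("<html><body>\n")
    for panel in panels:
        fh.write(f"<h2>{panel['title']}</h2>\n{panel['html']}\n")
    fh.write("</body></html>\n")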
