diff plotly_plots.py @ 17:db9be962dc13 draft
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
| field | value |
|---|---|
| author | goeckslab |
| date | Wed, 10 Dec 2025 00:24:13 +0000 |
| parents | d17e3a1b8659 |
| children | |
--- a/plotly_plots.py	Wed Dec 03 01:28:52 2025 +0000
+++ b/plotly_plots.py	Wed Dec 10 00:24:13 2025 +0000
@@ -7,6 +7,17 @@
 import plotly.graph_objects as go
 import plotly.io as pio
 from constants import LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME
+from sklearn.metrics import (
+    accuracy_score,
+    auc,
+    average_precision_score,
+    f1_score,
+    precision_recall_curve,
+    precision_score,
+    recall_score,
+    roc_curve,
+)
+from sklearn.preprocessing import label_binarize


 def _style_fig(fig: go.Figure, font_size: int = 12) -> go.Figure:
@@ -21,6 +32,64 @@
     return fig


+def _fig_to_html(
+    fig: go.Figure, *, include_js: bool = False, config: Optional[dict] = None
+) -> str:
+    """Render a Plotly figure to a lightweight HTML fragment."""
+    include_plotlyjs = "cdn" if include_js else False
+    return pio.to_html(
+        fig,
+        full_html=False,
+        include_plotlyjs=include_plotlyjs,
+        config=config,
+    )
+
+
+def _wrap_plot(
+    title: str,
+    fig: go.Figure,
+    *,
+    include_js: bool = False,
+    config: Optional[dict] = None,
+) -> Dict[str, str]:
+    """Package a figure with its title for downstream HTML rendering."""
+    return {"title": title, "html": _fig_to_html(fig, include_js=include_js, config=config)}
+
+
+def _line_chart(
+    traces: List[tuple],
+    *,
+    title: str,
+    yaxis_title: str,
+) -> go.Figure:
+    """Build a basic epoch-indexed line chart for train/val/test curves."""
+    fig = go.Figure()
+    for name, series in traces:
+        if not series:
+            continue
+        epochs = list(range(1, len(series) + 1))
+        fig.add_trace(
+            go.Scatter(
+                x=epochs,
+                y=series,
+                mode="lines+markers",
+                name=name,
+                line=dict(width=4),
+            )
+        )
+
+    fig.update_layout(
+        title=dict(text=title, x=0.5),
+        xaxis_title="Epoch",
+        yaxis_title=yaxis_title,
+        width=760,
+        height=520,
+        hovermode="x unified",
+    )
+    _style_fig(fig)
+    return fig
+
+
 def _labels_from_metadata_dict(meta_dict: dict) -> List[str]:
     """Extract ordered label names from Ludwig train_set_metadata."""
     if not isinstance(meta_dict, dict):
@@ -106,6 +175,7 @@
     training_stats_path: Optional[str] = None,
     metadata_csv_path: Optional[str] = None,
     train_set_metadata_path: Optional[str] = None,
+    threshold: Optional[float] = None,
 ) -> List[Dict[str, str]]:
     """
     Read Ludwig’s test_statistics.json and build three interactive Plotly panels:
@@ -156,8 +226,11 @@
             )
         )
         fig_cm.update_traces(xgap=2, ygap=2)
+        cm_title = "Confusion Matrix"
+        if threshold is not None:
+            cm_title = f"Confusion Matrix (Threshold: {threshold})"
         fig_cm.update_layout(
-            title=dict(text="Confusion Matrix", x=0.5),
+            title=dict(text=cm_title, x=0.5),
             xaxis_title="Predicted",
             yaxis_title="Observed",
             yaxis_autorange="reversed",
@@ -196,25 +269,19 @@
                 yshift=-2,
             )

-        plots.append({
-            "title": "Confusion Matrix",
-            "html": pio.to_html(
-                fig_cm,
-                full_html=False,
-                include_plotlyjs="cdn",
-                config=common_cfg
-            )
-        })
+        plots.append(
+            _wrap_plot("Confusion Matrix", fig_cm, include_js=True, config=common_cfg)
+        )

-        # 1) ROC Curve (from test_statistics)
-        roc_plot = _build_static_roc_plot(label_stats, common_cfg, friendly_labels=labels)
-        if roc_plot:
-            plots.append(roc_plot)
+        # 1) ROC / PR curves only for binary tasks
+        if n_classes == 2:
+            roc_plot = _build_static_roc_plot(label_stats, common_cfg, friendly_labels=labels)
+            if roc_plot:
+                plots.append(roc_plot)

-        # 2) Precision-Recall Curve (from test_statistics)
-        pr_plot = _build_precision_recall_plot(label_stats, common_cfg)
-        if pr_plot:
-            plots.append(pr_plot)
+            pr_plot = _build_precision_recall_plot(label_stats, common_cfg)
+            if pr_plot:
+                plots.append(pr_plot)

         # 2) Classification Report Heatmap
         pcs = label_stats.get("per_class_stats", {})
@@ -259,15 +326,9 @@
             margin=dict(t=80, l=80, r=80, b=80),
         )
         _style_fig(fig_cr)
-        plots.append({
-            "title": "Per-Class metrics",
-            "html": pio.to_html(
-                fig_cr,
-                full_html=False,
-                include_plotlyjs=False,
-                config=common_cfg
-            )
-        })
+        plots.append(
+            _wrap_plot("Per-Class metrics", fig_cr, config=common_cfg)
+        )

     # 3) Prediction Diagnostics (from predictions.csv)
     # Note: appended separately in generate_html_report, not returned here.
@@ -294,8 +355,6 @@
     include_js = True  # Load Plotly.js once for this group

     def _get_series(stats: dict, metric: str) -> List[float]:
-        if metric not in stats:
-            return []
         vals = stats.get(metric, [])
         if isinstance(vals, list):
             return [float(v) for v in vals]
@@ -304,181 +363,98 @@
         except Exception:
             return []

-    def _line_plot(metric_key: str, title: str, yaxis_title: str) -> Optional[Dict[str, str]]:
-        train_series = _get_series(label_train, metric_key)
-        val_series = _get_series(label_val, metric_key)
+    metric_specs = [
+        ("loss", "Loss across epochs", "Loss"),
+        ("accuracy", "Accuracy across epochs", "Accuracy"),
+        ("roc_auc", "ROC-AUC across epochs", "ROC-AUC"),
+        ("precision", "Precision across epochs", "Precision"),
+        ("recall", "Recall/Sensitivity across epochs", "Recall"),
+        ("specificity", "Specificity across epochs", "Specificity"),
+    ]
+
+    for key, title, yaxis in metric_specs:
+        train_series = _get_series(label_train, key)
+        val_series = _get_series(label_val, key)
         if not train_series and not val_series:
-            return None
-        epochs_train = list(range(1, len(train_series) + 1))
-        epochs_val = list(range(1, len(val_series) + 1))
-        fig = go.Figure()
-        if train_series:
-            fig.add_trace(
-                go.Scatter(
-                    x=epochs_train,
-                    y=train_series,
-                    mode="lines+markers",
-                    name="Train",
-                    line=dict(width=4),
-                )
-            )
-        if val_series:
-            fig.add_trace(
-                go.Scatter(
-                    x=epochs_val,
-                    y=val_series,
-                    mode="lines+markers",
-                    name="Validation",
-                    line=dict(width=4),
-                )
-            )
-        fig.update_layout(
-            title=dict(text=title, x=0.5),
-            xaxis_title="Epoch",
-            yaxis_title=yaxis_title,
-            width=760,
-            height=520,
-            hovermode="x unified",
+            continue
+        fig = _line_chart(
+            [("Train", train_series), ("Validation", val_series)],
+            title=title,
+            yaxis_title=yaxis,
         )
-        _style_fig(fig)
-        return {
-            "title": title,
-            "html": pio.to_html(
-                fig,
-                full_html=False,
-                include_plotlyjs="cdn" if include_js else False,
-            ),
-        }
-
-    # Core learning curves
-    for key, title in [
-        ("roc_auc", "ROC-AUC across epochs"),
-        ("precision", "Precision across epochs"),
-        ("recall", "Recall/Sensitivity across epochs"),
-        ("specificity", "Specificity across epochs"),
-    ]:
-        plot = _line_plot(key, title, title.replace("Learning Curve", "").strip())
-        if plot:
-            plots.append(plot)
-            include_js = False
+        plots.append(_wrap_plot(title, fig, include_js=include_js))
+        include_js = False

     # Precision vs Recall evolution (validation)
     val_prec = _get_series(label_val, "precision")
     val_rec = _get_series(label_val, "recall")
     if val_prec and val_rec:
-        epochs = list(range(1, min(len(val_prec), len(val_rec)) + 1))
-        fig_pr = go.Figure()
-        fig_pr.add_trace(
-            go.Scatter(
-                x=epochs,
-                y=val_prec[: len(epochs)],
-                mode="lines+markers",
-                name="Precision",
-            )
-        )
-        fig_pr.add_trace(
-            go.Scatter(
-                x=epochs,
-                y=val_rec[: len(epochs)],
-                mode="lines+markers",
-                name="Recall",
-            )
+        max_len = min(len(val_prec), len(val_rec))
+        fig_pr = _line_chart(
+            [
+                ("Precision", val_prec[:max_len]),
+                ("Recall", val_rec[:max_len]),
+            ],
+            title="Validation Precision and Recall by Epoch",
+            yaxis_title="Value",
         )
-        fig_pr.update_layout(
-            title=dict(text="Validation Precision and Recall by Epoch", x=0.5),
-            xaxis_title="Epoch",
-            yaxis_title="Value",
-            width=760,
-            height=520,
-            hovermode="x unified",
-        )
-        _style_fig(fig_pr)
-        plots.append({
-            "title": "Precision vs Recall Evolution",
-            "html": pio.to_html(
-                fig_pr,
-                full_html=False,
-                include_plotlyjs="cdn" if include_js else False,
-            ),
-        })
+        plots.append(_wrap_plot("Precision vs Recall Evolution", fig_pr, include_js=include_js))
         include_js = False

-    # F1-score derived
     def _compute_f1(p: List[float], r: List[float]) -> List[float]:
-        f1_vals = []
-        for prec, rec in zip(p, r):
-            if (prec + rec) == 0:
-                f1_vals.append(0.0)
-            else:
-                f1_vals.append(2 * prec * rec / (prec + rec))
-        return f1_vals
+        return [
+            0.0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec)
+            for prec, rec in zip(p, r)
+        ]

     f1_train = _compute_f1(_get_series(label_train, "precision"), _get_series(label_train, "recall"))
     f1_val = _compute_f1(val_prec, val_rec)
     if f1_train or f1_val:
-        fig = go.Figure()
-        if f1_train:
-            fig.add_trace(go.Scatter(x=list(range(1, len(f1_train) + 1)), y=f1_train, mode="lines+markers", name="Train", line=dict(width=4)))
-        if f1_val:
-            fig.add_trace(go.Scatter(x=list(range(1, len(f1_val) + 1)), y=f1_val, mode="lines+markers", name="Validation", line=dict(width=4)))
-        fig.update_layout(
-            title=dict(text="F1-Score across epochs (derived)", x=0.5),
-            xaxis_title="Epoch",
+        fig_f1 = _line_chart(
+            [("Train", f1_train), ("Validation", f1_val)],
+            title="F1-Score across epochs (derived)",
             yaxis_title="F1-Score",
-            width=760,
-            height=520,
-            hovermode="x unified",
         )
-        _style_fig(fig)
-        plots.append({
-            "title": "F1-Score across epochs (derived)",
-            "html": pio.to_html(
-                fig,
-                full_html=False,
-                include_plotlyjs="cdn" if include_js else False,
-            ),
-        })
+        plots.append(_wrap_plot("F1-Score across epochs (derived)", fig_f1, include_js=include_js))
         include_js = False

     # Overfitting Gap: Train vs Val ROC-AUC (gap)
     roc_train = _get_series(label_train, "roc_auc")
     roc_val = _get_series(label_val, "roc_auc")
     if roc_train and roc_val:
-        epochs_gap = list(range(1, min(len(roc_train), len(roc_val)) + 1))
-        gaps = [t - v for t, v in zip(roc_train[:len(epochs_gap)], roc_val[:len(epochs_gap)])]
-        fig_gap = go.Figure()
-        fig_gap.add_trace(go.Scatter(x=epochs_gap, y=gaps, mode="lines+markers", name="Train - Val ROC-AUC", line=dict(width=4)))
-        fig_gap.update_layout(
-            title=dict(text="Overfitting gap: ROC-AUC across epochs", x=0.5),
-            xaxis_title="Epoch",
+        max_len = min(len(roc_train), len(roc_val))
+        gaps = [t - v for t, v in zip(roc_train[:max_len], roc_val[:max_len])]
+        fig_gap = _line_chart(
+            [("Train - Val ROC-AUC", gaps)],
+            title="Overfitting gap: ROC-AUC across epochs",
            yaxis_title="Gap",
-            width=760,
-            height=520,
-            hovermode="x unified",
         )
-        _style_fig(fig_gap)
-        plots.append({
-            "title": "Overfitting gap: ROC-AUC across epochs",
-            "html": pio.to_html(
-                fig_gap,
-                full_html=False,
-                include_plotlyjs="cdn" if include_js else False,
-            ),
-        })
+        plots.append(_wrap_plot("Overfitting gap: ROC-AUC across epochs", fig_gap, include_js=include_js))
         include_js = False

     # Best Epoch Dashboard (based on max val ROC-AUC)
     if roc_val:
         best_idx = int(np.argmax(roc_val))
         best_epoch = best_idx + 1
-        spec_val = _get_series(label_val, "specificity")
-        metrics_at_best = {
-            "ROC-AUC": roc_val[best_idx] if best_idx < len(roc_val) else None,
-            "Precision": val_prec[best_idx] if best_idx < len(val_prec) else None,
-            "Recall": val_rec[best_idx] if best_idx < len(val_rec) else None,
-            "Specificity": spec_val[best_idx] if best_idx < len(spec_val) else None,
-            "F1-Score": f1_val[best_idx] if best_idx < len(f1_val) else None,
+        metrics_at_best: Dict[str, Optional[float]] = {
+            "ROC-AUC": roc_val[best_idx] if best_idx < len(roc_val) else None
         }
+
+        for metric_key, label in [
+            ("accuracy", "Accuracy"),
+            ("balanced_accuracy", "Balanced Accuracy"),
+            ("precision", "Precision"),
+            ("recall", "Recall"),
+            ("specificity", "Specificity"),
+            ("loss", "Loss"),
+        ]:
+            series = _get_series(label_val, metric_key)
+            if series and best_idx < len(series):
+                metrics_at_best[label] = series[best_idx]
+
+        if f1_val and best_idx < len(f1_val):
+            metrics_at_best["F1-Score (derived)"] = f1_val[best_idx]
+
         fig_best = go.Figure()
         for name, value in metrics_at_best.items():
             if value is not None:
@@ -492,15 +468,7 @@
             showlegend=False,
         )
         _style_fig(fig_best)
-        plots.append({
-            "title": "Best Validation Epoch Snapshot (Metrics)",
-            "html": pio.to_html(
-                fig_best,
-                full_html=False,
-                include_plotlyjs="cdn" if include_js else False,
-            ),
-        })
-        include_js = False
+        plots.append(_wrap_plot("Best Validation Epoch Snapshot (Metrics)", fig_best, include_js=include_js))

     return plots

@@ -529,46 +497,13 @@
     val_series = _get_regression_series(val_split, metric_key)
     if not train_series and not val_series:
         return None
-    epochs_train = list(range(1, len(train_series) + 1))
-    epochs_val = list(range(1, len(val_series) + 1))
-    fig = go.Figure()
-    if train_series:
-        fig.add_trace(
-            go.Scatter(
-                x=epochs_train,
-                y=train_series,
-                mode="lines+markers",
-                name="Train",
-                line=dict(width=4),
-            )
-        )
-    if val_series:
-        fig.add_trace(
-            go.Scatter(
-                x=epochs_val,
-                y=val_series,
-                mode="lines+markers",
-                name="Validation",
-                line=dict(width=4),
-            )
-        )
-    fig.update_layout(
-        title=dict(text=title, x=0.5),
-        xaxis_title="Epoch",
+
+    fig = _line_chart(
+        [("Train", train_series), ("Validation", val_series)],
+        title=title,
         yaxis_title=yaxis_title,
-        width=760,
-        height=520,
-        hovermode="x unified",
     )
-    _style_fig(fig)
-    return {
-        "title": title,
-        "html": pio.to_html(
-            fig,
-            full_html=False,
-            include_plotlyjs="cdn" if include_js else False,
-        ),
-    }
+    return _wrap_plot(title, fig, include_js=include_js)


 def build_regression_train_val_plots(train_stats_path: str) -> List[Dict[str, str]]:
@@ -627,46 +562,25 @@
         ("r2", "R² Across Epochs", "R²"),
         ("loss", "Loss Across Epochs", "Loss"),
     ]
-    epochs = None
     for metric_key, title, ytitle in metrics:
         series = _get_regression_series(label_test, metric_key)
         if not series:
             continue
-        if epochs is None:
-            epochs = list(range(1, len(series) + 1))
-        fig = go.Figure()
-        fig.add_trace(
-            go.Scatter(
-                x=epochs,
-                y=series[: len(epochs)],
-                mode="lines+markers",
-                name="Test",
-                line=dict(width=4),
-            )
+        fig = _line_chart(
+            [("Test", series)],
+            title=title,
+            yaxis_title=ytitle,
         )
-        fig.update_layout(
-            title=dict(text=title, x=0.5),
-            xaxis_title="Epoch",
-            yaxis_title=ytitle,
-            width=760,
-            height=520,
-            hovermode="x unified",
-        )
-        _style_fig(fig)
-        plots.append({
-            "title": title,
-            "html": pio.to_html(
-                fig,
-                full_html=False,
-                include_plotlyjs="cdn" if include_js else False,
-            ),
-        })
+        plots.append(_wrap_plot(title, fig, include_js=include_js))
         include_js = False

     return plots


 def _build_static_roc_plot(
-    label_stats: dict, config: dict, friendly_labels: Optional[List[str]] = None
+    label_stats: dict,
+    config: dict,
+    friendly_labels: Optional[List[str]] = None,
+    threshold: Optional[float] = None,
 ) -> Optional[Dict[str, str]]:
     """Build ROC curve directly from test_statistics.json (single curve)."""
     roc_data = label_stats.get("roc_curve")
@@ -776,6 +690,42 @@
         fig.update_xaxes(range=[0, 1.0])
         fig.update_yaxes(range=[0, 1.05])

+        roc_thresholds = roc_data.get("thresholds")
+        if threshold is not None and isinstance(roc_thresholds, list) and len(roc_thresholds) == len(fpr):
+            try:
+                diffs = [abs(th - threshold) for th in roc_thresholds]
+                best_idx = int(np.argmin(diffs))
+                # dashed guides through the chosen point
+                fig.add_shape(
+                    type="line",
+                    x0=fpr[best_idx],
+                    x1=fpr[best_idx],
+                    y0=0,
+                    y1=tpr[best_idx],
+                    line=dict(color="gray", width=2, dash="dash"),
+                )
+                fig.add_shape(
+                    type="line",
+                    x0=0,
+                    x1=fpr[best_idx],
+                    y0=tpr[best_idx],
+                    y1=tpr[best_idx],
+                    line=dict(color="gray", width=2, dash="dash"),
+                )
+                fig.add_trace(
+                    go.Scatter(
+                        x=[fpr[best_idx]],
+                        y=[tpr[best_idx]],
+                        mode="markers",
+                        marker=dict(color="black", size=10, symbol="x"),
+                        name=f"Threshold={threshold}",
+                        hovertemplate="FPR: %{x:.3f}<br>TPR: %{y:.3f}<br>Threshold: %{text}<extra></extra>",
+                        text=[f"{threshold}"],
+                    )
+                )
+            except Exception as exc:
+                print(f"Warning: could not add threshold marker to ROC: {exc}")
+
         fig.add_annotation(
             x=0.5,
             y=-0.15,
@@ -786,21 +736,17 @@
             xanchor="center",
         )

-        return {
-            "title": "ROC Curve",
-            "html": pio.to_html(
-                fig,
-                full_html=False,
-                include_plotlyjs=False,
-                config=config,
-            ),
-        }
+        return _wrap_plot("ROC Curve", fig, config=config)
     except Exception as e:
         print(f"Error building ROC plot: {e}")
         return None


-def _build_precision_recall_plot(label_stats: dict, config: dict) -> Optional[Dict[str, str]]:
+def _build_precision_recall_plot(
+    label_stats: dict,
+    config: dict,
+    threshold: Optional[float] = None,
+) -> Optional[Dict[str, str]]:
     """Build Precision-Recall curve directly from test_statistics.json."""
     pr_data = label_stats.get("precision_recall_curve")
     if not isinstance(pr_data, dict):
@@ -811,6 +757,8 @@
     if not precisions or not recalls or len(precisions) != len(recalls):
         return None

+    thresholds = pr_data.get("thresholds")
+
     try:
         fig = go.Figure()
         fig.add_trace(
@@ -851,15 +799,41 @@
         fig.update_xaxes(range=[0, 1.0])
         fig.update_yaxes(range=[0, 1.05])

-        return {
-            "title": "Precision-Recall Curve",
-            "html": pio.to_html(
-                fig,
-                full_html=False,
-                include_plotlyjs=False,
-                config=config,
-            ),
-        }
+        if threshold is not None and isinstance(thresholds, list) and len(thresholds) == len(recalls):
+            try:
+                diffs = [abs(th - threshold) for th in thresholds]
+                best_idx = int(np.argmin(diffs))
+                fig.add_shape(
+                    type="line",
+                    x0=recalls[best_idx],
+                    x1=recalls[best_idx],
+                    y0=0,
+                    y1=precisions[best_idx],
+                    line=dict(color="gray", width=2, dash="dash"),
+                )
+                fig.add_shape(
+                    type="line",
+                    x0=0,
+                    x1=recalls[best_idx],
+                    y0=precisions[best_idx],
+                    y1=precisions[best_idx],
+                    line=dict(color="gray", width=2, dash="dash"),
+                )
+                fig.add_trace(
+                    go.Scatter(
+                        x=[recalls[best_idx]],
+                        y=[precisions[best_idx]],
+                        mode="markers",
+                        marker=dict(color="black", size=10, symbol="x"),
+                        name=f"Threshold={threshold}",
+                        hovertemplate="Recall: %{x:.3f}<br>Precision: %{y:.3f}<br>Threshold: %{text}<extra></extra>",
+                        text=[f"{threshold}"],
+                    )
+                )
+            except Exception as exc:
+                print(f"Warning: could not add threshold marker to PR: {exc}")
+
+        return _wrap_plot("Precision-Recall Curve", fig, config=config)
     except Exception as e:
         print(f"Error building Precision-Recall plot: {e}")
         return None
@@ -869,7 +843,6 @@
     predictions_path: str,
     label_data_path: Optional[str] = None,
     split_value: int = 2,
-    threshold: Optional[float] = None,
 ) -> List[Dict[str, str]]:
     """Generate diagnostic plots from predictions.csv for classification tasks."""
     preds_file = Path(predictions_path)
@@ -883,12 +856,89 @@
         return []

     plots: List[Dict[str, str]] = []
+    labels_from_dataset: Optional[pd.Series] = None
+
+    filtered_by_split = False
+
+    # If a split column exists, focus on the requested split (e.g., validation=1, test=2).
+    # If not, but label_data_path is available and matches row count, use it to filter predictions.
+    if SPLIT_COLUMN_NAME in df_pred.columns:
+        df_pred = df_pred[df_pred[SPLIT_COLUMN_NAME] == split_value].reset_index(drop=True)
+        if df_pred.empty:
+            return []
+        filtered_by_split = True
+    elif label_data_path and Path(label_data_path).exists():
+        try:
+            df_labels_all = pd.read_csv(label_data_path)
+            if SPLIT_COLUMN_NAME in df_labels_all.columns and len(df_labels_all) == len(df_pred):
+                split_mask = pd.to_numeric(df_labels_all[SPLIT_COLUMN_NAME], errors="coerce") == split_value
+                labels_from_dataset = df_labels_all.loc[split_mask, LABEL_COLUMN_NAME].reset_index(drop=True)
+                df_pred = df_pred.loc[split_mask].reset_index(drop=True)
+                if df_pred.empty:
+                    return []
+                filtered_by_split = True
+        except Exception as exc:
+            print(f"Warning: Unable to filter predictions by split from label data: {exc}")
+
+    # Fallback: no split info available. Assume the predictions file is already filtered
+    # (common for test-only exports) and avoid heuristic slicing that could discard rows.
+    if not filtered_by_split:
+        if split_value != 2:
+            return []
+
+    def _strip_prob_prefix(col: str) -> str:
+        if col.startswith("label_probabilities_"):
+            return col.replace("label_probabilities_", "")
+        if col.startswith("probabilities_"):
+            return col.replace("probabilities_", "")
+        return col
+
+    def _maybe_expand_probabilities_column(df: pd.DataFrame, labels_guess: List[str]) -> List[str]:
+        """If only a single 'probabilities' column exists (list-like), expand it into per-class columns."""
+        if "probabilities" not in df.columns:
+            return []
+        try:
+            # Parse first non-null entry to infer length
+            first_val = df["probabilities"].dropna().iloc[0]
+            parsed = first_val
+            if isinstance(first_val, str):
+                parsed = json.loads(first_val)
+            probs = list(parsed)
+            n = len(probs)
+            if n == 0:
+                return []
+            # Build labels: prefer provided guess; otherwise numeric
+            if labels_guess and len(labels_guess) == n:
+                labels_use = labels_guess
+            else:
+                labels_use = [str(i) for i in range(n)]
+            # Expand column
+            for idx, lbl in enumerate(labels_use):
+                df[f"probabilities_{lbl}"] = df["probabilities"].apply(
+                    lambda v: (json.loads(v)[idx] if isinstance(v, str) else list(v)[idx]) if pd.notnull(v) else np.nan
+                )
+            return [f"probabilities_{lbl}" for lbl in labels_use]
+        except Exception:
+            return []

     # Identify probability columns
     prob_cols = [
-        c for c in df_pred.columns
-        if c.startswith("label_probabilities_") and c != "label_probabilities"
+        c
+        for c in df_pred.columns
+        if (
+            (c.startswith("label_probabilities_") or c.startswith("probabilities_"))
+            and c != "label_probabilities"
+        )
     ]
+    if not prob_cols and "label_probability" in df_pred.columns:
+        prob_cols = ["label_probability"]
+    if not prob_cols and "probability" in df_pred.columns:
+        prob_cols = ["probability"]
+    if not prob_cols and "prediction_probability" in df_pred.columns:
+        prob_cols = ["prediction_probability"]
+    if not prob_cols and "probabilities" in df_pred.columns:
+        labels_guess = sorted([str(u) for u in pd.unique(df_pred[LABEL_COLUMN_NAME])])
+        prob_cols = _maybe_expand_probabilities_column(df_pred, labels_guess)
     prob_cols_sorted = sorted(prob_cols)

     def _select_positive_prob():
@@ -897,14 +947,14 @@
         # Prefer a column indicating positive/event/true/1
         preferred_keys = ("event", "true", "positive", "pos", "1")
         for col in prob_cols_sorted:
-            suffix = col.replace("label_probabilities_", "").lower()
+            suffix = _strip_prob_prefix(col).lower()
             if any(k in suffix for k in preferred_keys):
                 return col, suffix
         if len(prob_cols_sorted) == 2:
             col = prob_cols_sorted[1]
-            return col, col.replace("label_probabilities_", "")
+            return col, _strip_prob_prefix(col)
         col = prob_cols_sorted[0]
-        return col, col.replace("label_probabilities_", "")
+        return col, _strip_prob_prefix(col)

     pos_prob_col, pos_label_hint = _select_positive_prob()
     pos_prob_series = df_pred[pos_prob_col] if pos_prob_col and pos_prob_col in df_pred else None
@@ -920,6 +970,8 @@

     # True labels
     def _extract_labels():
+        if labels_from_dataset is not None:
+            return labels_from_dataset
         candidates = [
             LABEL_COLUMN_NAME,
             f"{LABEL_COLUMN_NAME}_ground_truth",
@@ -975,10 +1027,7 @@
             height=500,
         )
         _style_fig(fig_conf)
-        plots.append({
-            "title": "Prediction Confidence Distribution",
-            "html": pio.to_html(fig_conf, full_html=False, include_plotlyjs=False),
-        })
+        plots.append(_wrap_plot("Prediction Confidence Distribution", fig_conf))

     # The remaining plots require true labels and a positive-class probability
     if labels_series is None or pos_prob_series is None:
@@ -1004,116 +1053,470 @@

     y_true = (y_true_raw == positive_label).astype(int).values

-    # Plot 2: Calibration Curve
-    bins = np.linspace(0.0, 1.0, 11)
-    bin_ids = np.digitize(y_score, bins, right=True)
-    bin_centers = []
-    frac_positives = []
-    for b in range(1, len(bins)):
-        mask = bin_ids == b
-        if not np.any(mask):
-            continue
-        bin_centers.append(y_score[mask].mean())
-        frac_positives.append(y_true[mask].mean())
-    if bin_centers and frac_positives:
-        fig_cal = go.Figure()
-        fig_cal.add_trace(
+    # Utility: compute calibration points
+    def _calibration_points(y_true_bin: np.ndarray, scores: np.ndarray):
+        bins = np.linspace(0.0, 1.0, 11)
+        bin_ids = np.digitize(scores, bins, right=True)
+        bin_centers, frac_positives = [], []
+        for b in range(1, len(bins)):
+            mask = bin_ids == b
+            if not np.any(mask):
+                continue
+            bin_centers.append(scores[mask].mean())
+            frac_positives.append(y_true_bin[mask].mean())
+        return bin_centers, frac_positives
+
+    # Plot 2: Calibration Curve (multi-class aware; one-vs-rest per label)
+    label_prob_map = {}
+    for col in prob_cols_sorted:
+        if col.startswith("label_probabilities_"):
+            cls = col.replace("label_probabilities_", "")
+            label_prob_map[cls] = col
+
+    unique_label_strs = [str(u) for u in unique_labels_list]
+    if len(label_prob_map) > 1 and len(unique_label_strs) > 2:
+        # Skip multi-class calibration curve for now (not informative in current report)
+        pass
+    else:
+        # Binary/unknown fallback (previous behavior)
+        bin_centers, frac_positives = _calibration_points(y_true, y_score)
+        if bin_centers and frac_positives:
+            fig_cal = go.Figure()
+            fig_cal.add_trace(
+                go.Scatter(
+                    x=bin_centers,
+                    y=frac_positives,
+                    mode="lines+markers",
+                    name="Calibration",
+                    line=dict(color="#2ca02c", width=4),
+                )
+            )
+            fig_cal.add_trace(
+                go.Scatter(
+                    x=[0, 1],
+                    y=[0, 1],
+                    mode="lines",
+                    name="Perfect Calibration",
+                    line=dict(color="gray", width=2, dash="dash"),
+                )
+            )
+            fig_cal.update_layout(
+                title=dict(text="Calibration Curve", x=0.5),
+                xaxis_title="Predicted probability",
+                yaxis_title="Observed frequency",
+                width=700,
+                height=500,
+            )
+            _style_fig(fig_cal)
+            plots.append(
+                _wrap_plot(
+                    "Calibration Curve (Predicted Probability vs Observed Frequency)",
+                    fig_cal,
+                )
+            )
+
+    return plots
+
+
+def build_binary_threshold_plot(
+    predictions_path: str,
+    label_data_path: Optional[str] = None,
+    split_value: int = 1,
+) -> Optional[Dict[str, str]]:
+    """Build a binary threshold sweep plot (accuracy, precision, recall, F1) for a given split."""
+    preds_file = Path(predictions_path)
+    if not preds_file.exists():
+        return None
+
+    try:
+        df_pred = pd.read_csv(predictions_path)
+    except Exception as exc:
+        print(f"Warning: Unable to read predictions CSV for threshold plot: {exc}")
+        return None
+
+    labels_from_dataset: Optional[pd.Series] = None
+    df_full = df_pred.copy()
+
+    def _filter_by_split(df: pd.DataFrame, split_val: int) -> pd.DataFrame:
+        if SPLIT_COLUMN_NAME in df.columns:
+            return df[df[SPLIT_COLUMN_NAME] == split_val].reset_index(drop=True)
+        return df
+
+    # Try preferred split, then fallback to others with data (val -> test -> train)
+    candidate_splits = [split_value, 2, 0, 1] if split_value == 1 else [split_value, 1, 0, 2]
+    df_candidate = pd.DataFrame()
+    used_split: Optional[int] = None
+    for sv in candidate_splits:
+        df_candidate = _filter_by_split(df_full, sv)
+        if not df_candidate.empty:
+            used_split = sv
+            break
+    if used_split is None:
+        df_candidate = df_full
+    df_pred = df_candidate.reset_index(drop=True)
+
+    # If still empty (e.g., split column exists but no rows for candidates), fall back to all rows
+    if df_pred.empty:
+        df_pred = df_full.reset_index(drop=True)
+        labels_from_dataset = None
+
+    if label_data_path and Path(label_data_path).exists():
+        try:
+            df_labels_all = pd.read_csv(label_data_path)
+            if SPLIT_COLUMN_NAME in df_labels_all.columns and len(df_labels_all) == len(df_full):
+                mask = (
+                    pd.to_numeric(df_labels_all[SPLIT_COLUMN_NAME], errors="coerce") == used_split
+                    if used_split is not None and SPLIT_COLUMN_NAME in df_labels_all.columns
+                    else pd.Series([True] * len(df_full))
+                )
+                labels_from_dataset = df_labels_all.loc[mask, LABEL_COLUMN_NAME].reset_index(drop=True)
+                if len(labels_from_dataset) == len(df_pred):
+                    labels_from_dataset = labels_from_dataset.reset_index(drop=True)
+        except Exception as exc:
+            print(f"Warning: Unable to align labels for threshold plot: {exc}")
+
+    # Identify probability columns
+    prob_cols = [
+        c
+        for c in df_pred.columns
+        if (
+            (c.startswith("label_probabilities_") or c.startswith("probabilities_"))
+            and c != "label_probabilities"
+        )
+    ]
+    if not prob_cols and "probabilities" in df_pred.columns:
+        labels_guess = sorted([str(u) for u in pd.unique(df_pred.get(LABEL_COLUMN_NAME, []))])
+        # reuse expansion logic from diagnostics
+        try:
+            first_val = df_pred["probabilities"].dropna().iloc[0]
+            parsed = json.loads(first_val) if isinstance(first_val, str) else list(first_val)
+            n = len(parsed)
+            if n > 0:
+                if labels_guess and len(labels_guess) == n:
+                    labels_use = labels_guess
+                else:
+                    labels_use = [str(i) for i in range(n)]
+                for idx, lbl in enumerate(labels_use):
+                    df_pred[f"probabilities_{lbl}"] = df_pred["probabilities"].apply(
+                        lambda v: (json.loads(v)[idx] if isinstance(v, str) else list(v)[idx]) if pd.notnull(v) else np.nan
+                    )
+                prob_cols = [f"probabilities_{lbl}" for lbl in labels_use]
+        except Exception:
+            prob_cols = []
+    prob_cols_sorted = sorted(prob_cols)
+
+    def _strip_prob_prefix(col: str) -> str:
+        if col.startswith("label_probabilities_"):
+            return col.replace("label_probabilities_", "")
+        if col.startswith("probabilities_"):
+            return col.replace("probabilities_", "")
+        return col
+
+    # True labels
+    def _extract_labels():
+        if labels_from_dataset is not None:
+            return labels_from_dataset
+        for col in [
+            LABEL_COLUMN_NAME,
+            f"{LABEL_COLUMN_NAME}_ground_truth",
+            f"{LABEL_COLUMN_NAME}__ground_truth",
+            f"{LABEL_COLUMN_NAME}_target",
+            f"{LABEL_COLUMN_NAME}__target",
+            "label",
+            "label_true",
+            "label_predictions",
+            "prediction",
+        ]:
+            if col in df_pred.columns and col not in prob_cols_sorted:
+                return df_pred[col]
+        return None
+
+    labels_series = _extract_labels()
+    if labels_series is None or not prob_cols_sorted:
+        return None
+
+    # Positive prob column selection
+    preferred_keys = ("event", "true", "positive", "pos", "1")
+    pos_prob_col = None
+    for col in prob_cols_sorted:
+        suffix = _strip_prob_prefix(col).lower()
+        if any(k in suffix for k in preferred_keys):
+            pos_prob_col = col
+            break
+    if pos_prob_col is None:
+        pos_prob_col = prob_cols_sorted[-1]
+
+    min_len = min(len(labels_series), len(df_pred[pos_prob_col]))
+    if min_len == 0:
+        return None
+
+    y_true = np.array(labels_series.iloc[:min_len])
+    # map to binary 0/1
+    unique_labels = pd.unique(y_true)
+    if len(unique_labels) < 2:
+        return None
+    positive_label = unique_labels[1] if len(unique_labels) >= 2 else unique_labels[0]
+    y_true_bin = (y_true == positive_label).astype(int)
+    y_score = np.array(df_pred[pos_prob_col].iloc[:min_len], dtype=float)
+
+    thresholds = np.linspace(0.0, 1.0, 101)
+    accs: List[float] = []
+    precs: List[float] = []
+    recs: List[float] = []
+    f1s: List[float] = []
+    for t in thresholds:
+        preds = (y_score >= t).astype(int)
+        accs.append(accuracy_score(y_true_bin, preds))
+        precs.append(precision_score(y_true_bin, preds, zero_division=0))
+        recs.append(recall_score(y_true_bin, preds, zero_division=0))
+        f1s.append(f1_score(y_true_bin, preds, zero_division=0))
+
+    best_idx = int(np.argmax(f1s))
+    best_thr = thresholds[best_idx]
+
+    fig = go.Figure()
+    fig.add_trace(go.Scatter(x=thresholds, y=accs, mode="lines", name="Accuracy", line=dict(width=4)))
+    fig.add_trace(go.Scatter(x=thresholds, y=precs, mode="lines", name="Precision", line=dict(width=4)))
+    fig.add_trace(go.Scatter(x=thresholds, y=recs, mode="lines", name="Recall", line=dict(width=4)))
+    fig.add_trace(go.Scatter(x=thresholds, y=f1s, mode="lines", name="F1-Score", line=dict(width=4)))
+    fig.add_shape(
+        type="line",
+        x0=best_thr,
+        x1=best_thr,
+        y0=0,
+        y1=1,
+        line=dict(color="gray", width=2, dash="dash"),
+    )
+    fig.update_layout(
+        title=dict(text="Threshold plot", x=0.5),
+        xaxis_title="Threshold",
+        yaxis_title="Metric value",
+        yaxis=dict(range=[0, 1]),
+        width=760,
+        height=520,
+        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
+    )
+    _style_fig(fig)
+    return _wrap_plot("Threshold plot", fig, include_js=True)
+
+
+def build_multiclass_roc_pr_plots(
+    predictions_path: str,
+    split_value: int = 2,
+) -> List[Dict[str, str]]:
+    """Build one-vs-rest ROC and PR curves for multi-class classification from predictions."""
+    preds_file = Path(predictions_path)
+    if not preds_file.exists():
+        return []
+    try:
+        df_pred = pd.read_csv(predictions_path)
+    except Exception as exc:
+        print(f"Warning: Unable to read predictions CSV: {exc}")
+        return []
+
+    if SPLIT_COLUMN_NAME in df_pred.columns:
+        df_pred = df_pred[df_pred[SPLIT_COLUMN_NAME] == split_value].reset_index(drop=True)
+        if df_pred.empty:
+            return []
+
+    if LABEL_COLUMN_NAME not in df_pred.columns:
+        return []
+
+    # Identify per-class probability columns
+    prob_cols = [
+        c
+        for c in df_pred.columns
+        if (
+            (c.startswith("label_probabilities_") or c.startswith("probabilities_"))
+            and c != "label_probabilities"
+        )
+    ]
+    if not prob_cols:
+        return []
+    labels = [c.replace("label_probabilities_", "").replace("probabilities_", "") for c in prob_cols]
+    labels_sorted = sorted(labels)
+
+    # Ensure all labels are present as probability columns
+    prob_map = {
+        c.replace("label_probabilities_", "").replace("probabilities_", ""): c
+        for c in prob_cols
+    }
+    if len(labels_sorted) < 3:
+        return []
+
+    y_true_raw = df_pred[LABEL_COLUMN_NAME].astype(str)
+    # Drop rows with NaN probabilities across any class to avoid metric errors
+    prob_matrix = df_pred[[prob_map[lbl] for lbl in labels_sorted]].astype(float)
+    mask_valid = ~prob_matrix.isnull().any(axis=1)
+    prob_matrix = prob_matrix[mask_valid]
+    y_true_raw = y_true_raw[mask_valid]
+    if prob_matrix.empty:
+        return []
+
+    y_true_bin = label_binarize(y_true_raw, classes=labels_sorted)
+    y_score = prob_matrix.to_numpy()
+
+    plots: List[Dict[str, str]] = []
+
+    # ROC: one-vs-rest + micro
+    fig_roc = go.Figure()
+    added_any = False
+    for idx, lbl in enumerate(labels_sorted):
+        if y_true_bin[:, idx].sum() == 0 or y_true_bin[:, idx].sum() == len(y_true_bin):
+            continue  # skip classes without both positives and negatives
+        fpr, tpr, _ = roc_curve(y_true_bin[:, idx], y_score[:, idx])
+        fig_roc.add_trace(
             go.Scatter(
-                x=bin_centers,
-                y=frac_positives,
-                mode="lines+markers",
-                name="Calibration",
-                line=dict(color="#2ca02c", width=4),
-            )
-        )
-        fig_cal.add_trace(
-            go.Scatter(
-                x=[0, 1],
-                y=[0, 1],
+                x=fpr,
+                y=tpr,
                 mode="lines",
-                name="Perfect Calibration",
-                line=dict(color="gray", width=2, dash="dash"),
+                name=f"{lbl} (AUC={auc(fpr, tpr):.3f})",
+                line=dict(width=3),
             )
         )
-        fig_cal.update_layout(
-            title=dict(text="Calibration Curve", x=0.5),
-            xaxis_title="Predicted probability",
-            yaxis_title="Observed frequency",
-            width=700,
-            height=500,
+        added_any = True
+    # Micro-average only if we have mixed labels
+    if y_true_bin.sum() > 0 and y_true_bin.sum() < y_true_bin.size:
+        fpr_micro, tpr_micro, _ = roc_curve(y_true_bin.ravel(), y_score.ravel())
+        fig_roc.add_trace(
+            go.Scatter(
+                x=fpr_micro,
+                y=tpr_micro,
+                mode="lines",
+                name=f"Micro-average (AUC={auc(fpr_micro, tpr_micro):.3f})",
+                line=dict(width=3, dash="dash"),
+            )
         )
-        _style_fig(fig_cal)
-        plots.append({
-            "title": "Calibration Curve (Predicted Probability vs Observed Frequency)",
-            "html": pio.to_html(fig_cal, full_html=False, include_plotlyjs=False),
-        })
-
-    # Plot 3: Threshold vs Metrics
-    thresholds = np.linspace(0.0, 1.0, 21)
-    accs, f1s, sens, specs = [], [], [], []
-    for t in thresholds:
-        y_pred = (y_score >= t).astype(int)
-        tp = np.sum((y_true == 1) & (y_pred == 1))
-        tn = np.sum((y_true == 0) & (y_pred == 0))
-        fp = np.sum((y_true == 0) & (y_pred == 1))
-        fn = np.sum((y_true == 1) & (y_pred == 0))
-        acc = (tp + tn) / max(len(y_true), 1)
-        prec = tp / max(tp + fp, 1e-9)
-        rec = tp / max(tp + fn, 1e-9)
-        f1 = 0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec)
-        sensitivity = rec
-        specificity = tn / max(tn + fp, 1e-9)
-        accs.append(acc)
-        f1s.append(f1)
-        sens.append(sensitivity)
-        specs.append(specificity)
-
-    fig_thresh = go.Figure()
-    fig_thresh.add_trace(go.Scatter(x=thresholds, y=accs, mode="lines", name="Accuracy", line=dict(width=4)))
-    fig_thresh.add_trace(go.Scatter(x=thresholds, y=f1s, mode="lines", name="F1", line=dict(width=4)))
-    fig_thresh.add_trace(go.Scatter(x=thresholds, y=sens, mode="lines", name="Sensitivity", line=dict(width=4)))
-    fig_thresh.add_trace(go.Scatter(x=thresholds, y=specs, mode="lines", name="Specificity", line=dict(width=4)))
-    fig_thresh.update_layout(
-        title=dict(text="Threshold Sweep: Accuracy, F1, Sensitivity, Specificity", x=0.5),
-        xaxis_title="Decision threshold",
-        yaxis_title="Metric value",
-        width=700,
-        height=500,
+        added_any = True
+    if not added_any:
+        return []
+    fig_roc.add_trace(
+        go.Scatter(
+            x=[0, 1],
+            y=[0, 1],
+            mode="lines",
+            name="Random",
+            line=dict(color="gray", width=2, dash="dot"),
+        )
+    )
+    fig_roc.update_layout(
+        title=dict(text="Multi-class ROC-AUC (one-vs-rest)", x=0.5),
+        xaxis_title="False Positive Rate",
+        yaxis_title="True Positive Rate",
+        width=820,
+        height=620,
         legend=dict(
-            x=0.7,
-            y=0.2,
+            x=0.62,
+            y=0.05,
             bgcolor="rgba(255,255,255,0.9)",
             bordercolor="rgba(0,0,0,0.2)",
             borderwidth=1,
         ),
-        shapes=[
-            dict(
-                type="line",
-                x0=threshold,
-                x1=threshold,
-                y0=0,
-                y1=1,
-                xref="x",
-                yref="paper",
-                line=dict(color="#d62728", width=2, dash="dash"),
+    )
+    _style_fig(fig_roc)
+    plots.append(_wrap_plot("Multi-class ROC-AUC (one-vs-rest)", fig_roc))
+
+    # PR: one-vs-rest + micro AP
+    fig_pr = go.Figure()
+    added_pr = False
+    for idx, lbl in enumerate(labels_sorted):
+        if y_true_bin[:, idx].sum() == 0:
+            continue
+        prec, rec, _ = precision_recall_curve(y_true_bin[:, idx], y_score[:, idx])
+        ap = average_precision_score(y_true_bin[:, idx], y_score[:, idx])
+        fig_pr.add_trace(
+            go.Scatter(
+                x=rec,
+                y=prec,
+                mode="lines",
+                name=f"{lbl} (AP={ap:.3f})",
+                line=dict(width=3),
             )
-        ] if isinstance(threshold, (int, float)) else [],
-        annotations=[
-            dict(
-                x=threshold,
-                y=1.02,
-                xref="x",
-                yref="paper",
-                showarrow=False,
-                text=f"Threshold = {threshold:.2f}",
-                font=dict(size=11, color="#d62728"),
+        )
+        added_pr = True
+    if y_true_bin.sum() > 0:
+        prec_micro, rec_micro, _ = precision_recall_curve(y_true_bin.ravel(), y_score.ravel())
+        ap_micro = average_precision_score(y_true_bin, y_score, average="micro")
+        fig_pr.add_trace(
+            go.Scatter(
+                x=rec_micro,
+                y=prec_micro,
+                mode="lines",
+                name=f"Micro-average (AP={ap_micro:.3f})",
+                line=dict(width=3, dash="dash"),
             )
-        ] if isinstance(threshold, (int, float)) else [],
+        )
+        added_pr = True
+    if not added_pr:
+        return plots
+    fig_pr.update_layout(
+        title=dict(text="Multi-class Precision-Recall (one-vs-rest)", x=0.5),
+        xaxis_title="Recall",
+        yaxis_title="Precision",
+        width=820,
+        height=620,
+        legend=dict(
+            x=0.62,
+            y=0.05,
+            bgcolor="rgba(255,255,255,0.9)",
+            bordercolor="rgba(0,0,0,0.2)",
+            borderwidth=1,
+        ),
     )
-    _style_fig(fig_thresh)
-    plots.append({
-        "title": "Threshold Sweep: Accuracy, F1, Sensitivity, Specificity",
-        "html": pio.to_html(fig_thresh, full_html=False, include_plotlyjs=False),
-    })
+    _style_fig(fig_pr)
+    plots.append(_wrap_plot("Multi-class Precision-Recall (one-vs-rest)", fig_pr))

     return plots
+
+
+def build_multiclass_metric_plots(test_stats_path: str) -> List[Dict[str, str]]:
+    """Alternative multi-class transparency plots using test_statistics.json per-class stats."""
+    ts_path = Path(test_stats_path)
+    if not ts_path.exists():
+        return []
+    try:
+        with open(ts_path, "r") as f:
+            test_stats = json.load(f)
+    except Exception:
+        return []
+
+    label_stats = test_stats.get("label", {})
+    pcs = label_stats.get("per_class_stats", {})
+    if not pcs:
+        return []
+    classes = list(pcs.keys())
+    if not classes:
+        return []
+
+    metrics = ["precision", "recall", "f1_score", "specificity", "accuracy"]
+    fig_bar = go.Figure()
+    for metric in metrics:
+        values = []
+        for cls in classes:
+            v = pcs.get(cls, {}).get(metric)
+            values.append(v if isinstance(v, (int, float)) else 0)
+        fig_bar.add_trace(
+            go.Bar(
+                x=classes,
+                y=values,
+                name=metric.replace("_", " ").title(),
+            )
+        )
+    fig_bar.update_layout(
+        title=dict(text="Per-Class Metrics (Test)", x=0.5),
+        xaxis_title="Class",
+        yaxis_title="Metric value",
+        barmode="group",
+        width=900,
+        height=600,
+        legend=dict(
+            x=1.02,
+            y=1.0,
+            bgcolor="rgba(255,255,255,0.9)",
+            bordercolor="rgba(0,0,0,0.2)",
+            borderwidth=1,
+        ),
+    )
+    _style_fig(fig_bar)
+
+    return [_wrap_plot("Per-Class Metrics (Test)", fig_bar)]
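For orientation, here is a minimal sketch of how the public builders added in this commit might be called from a report-generation script. The function names, their parameters, and the `{"title": ..., "html": ...}` dicts they return come from the diff above; the import path, the Ludwig output file locations, and the report-assembly loop are assumptions for illustration only.

```python
# Hypothetical usage of the builders introduced in this commit.
# plotly_plots is assumed to be importable from the tool directory;
# the Ludwig output paths below are placeholders.
from plotly_plots import (
    build_binary_threshold_plot,
    build_multiclass_metric_plots,
    build_multiclass_roc_pr_plots,
)

predictions_csv = "ludwig_output/predictions.csv"        # assumed path
test_stats_json = "ludwig_output/test_statistics.json"   # assumed path

panels = []

# Binary tasks: sweep decision thresholds, preferring the validation split (1);
# the function falls back to other splits internally if that split is empty.
threshold_panel = build_binary_threshold_plot(predictions_csv, split_value=1)
if threshold_panel:
    panels.append(threshold_panel)

# Multi-class tasks: one-vs-rest ROC/PR curves from per-class probability
# columns, plus grouped per-class metric bars from test_statistics.json.
panels.extend(build_multiclass_roc_pr_plots(predictions_csv, split_value=2))
panels.extend(build_multiclass_metric_plots(test_stats_json))

# Each entry is {"title": ..., "html": ...}; stitch them into one report fragment.
body = "\n".join(f"<h2>{p['title']}</h2>\n{p['html']}" for p in panels)
with open("report_fragment.html", "w") as fh:
    fh.write(body)
```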
