+ "Configuration details unavailable."
+ )
+ if not config_html:
+ config_html = (
+ "No configuration details found."
+ )
# ---------- image rendering with exclusions ----------
def render_img_section(
@@ -776,6 +951,11 @@
for img in imgs
if img.name not in default_exclude
and img.name not in exclude_names
+ and not (
+ "learning_curves" in img.stem
+ and "loss" in img.stem
+ and "label" in img.stem
+ )
]
if not imgs:
@@ -802,7 +982,8 @@
)
return html_section
- tab1_content = config_html + metrics_html
+ # Show performance first, then config
+ tab1_content = metrics_html + config_html
tab2_content = train_val_metrics_html + render_img_section(
"Training and Validation Visualizations",
@@ -815,6 +996,21 @@
"precision_recall_curve.png",
},
)
+ if train_stats_path.exists():
+ try:
+ if output_type == "regression":
+ tv_plots = build_regression_train_val_plots(str(train_stats_path))
+ else:
+ tv_plots = build_train_validation_plots(str(train_stats_path))
+ for plot in tv_plots:
+ tab2_content += (
+ f"<h2>{plot['title']}</h2>"
+ f"<div>{plot['html']}</div>"
+ )
+ except Exception as e:
+ logger.warning(f"Could not generate train/validation plots: {e}")
"<div>"
+ preds_html
+ "</div>"
)
@@ -857,27 +1053,75 @@
logger.warning(f"Could not build Predictions vs GT table: {e}")
tab3_content = test_metrics_html + preds_section
+ test_plotly_added = False
+
+ if output_type == "regression" and train_stats_path.exists():
+ try:
+ test_plots = build_regression_test_plots(str(train_stats_path))
+ for plot in test_plots:
+ tab3_content += (
+ f"
{plot['title']}
"
+ f"
{plot['html']}
"
+ )
+ if test_plots:
+ test_plotly_added = True
+ logger.info(f"Generated {len(test_plots)} regression test plots")
+ except Exception as e:
+ logger.warning(f"Could not generate regression test plots: {e}")
if output_type in ("binary", "category") and test_stats_path.exists():
try:
interactive_plots = build_classification_plots(
str(test_stats_path),
str(train_stats_path) if train_stats_path.exists() else None,
+ metadata_csv_path=str(label_metadata_path)
+ if label_metadata_path and label_metadata_path.exists()
+ else None,
+ train_set_metadata_path=str(train_set_metadata_path)
+ if train_set_metadata_path.exists()
+ else None,
)
for plot in interactive_plots:
tab3_content += (
f"
{plot['title']}
"
f"
{plot['html']}
"
)
+ if interactive_plots:
+ test_plotly_added = True
logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots")
except Exception as e:
logger.warning(f"Could not generate Plotly plots: {e}")
+ # Add prediction diagnostics from predictions.csv
+ predictions_csv_path = exp_dir / "predictions.csv"
+ try:
+ diag_plots = build_prediction_diagnostics(
+ str(predictions_csv_path),
+ label_data_path=str(config.get("label_column_data_path"))
+ if config.get("label_column_data_path")
+ else None,
+ threshold=config.get("threshold"),
+ )
+ for plot in diag_plots:
+ tab3_content += (
+ f"
{plot['title']}
"
+ f"
{plot['html']}
"
+ )
+ if diag_plots:
+ test_plotly_added = True
+ logger.info(f"Generated {len(diag_plots)} prediction diagnostic plots")
+ except Exception as e:
+ logger.warning(f"Could not generate prediction diagnostics: {e}")
+
+ # Fallback: include static PNGs if no interactive plots were added
+ if not test_plotly_added:
+ tab3_content += render_img_section(
+ "Test Visualizations (PNG fallback)",
+ test_viz_dir,
+ output_type,
+ )
+
# Add static TEST PNGs (with default dedupe/exclusions)
- tab3_content += render_img_section(
- "Test Visualizations", test_viz_dir, output_type
- )
-
tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)
modal_html = get_metrics_help_modal()
html += tabbed_html + modal_html + get_html_closing()
diff -r 94cd9ac4a9b1 -r d17e3a1b8659 plotly_plots.py
--- a/plotly_plots.py Wed Nov 26 22:00:32 2025 +0000
+++ b/plotly_plots.py Fri Nov 28 15:45:49 2025 +0000
@@ -7,13 +7,105 @@
import plotly.graph_objects as go
import plotly.io as pio
from constants import LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME
-from sklearn.metrics import auc, roc_curve
-from sklearn.preprocessing import label_binarize
+
+
+def _style_fig(fig: go.Figure, font_size: int = 12) -> go.Figure:
+ """Apply consistent styling across Plotly figures."""
+ fig.update_layout(
+ font=dict(size=font_size),
+ plot_bgcolor="#ffffff",
+ paper_bgcolor="#ffffff",
+ )
+ fig.update_xaxes(gridcolor="#e8e8e8")
+ fig.update_yaxes(gridcolor="#e8e8e8")
+ return fig
+
+
+def _labels_from_metadata_dict(meta_dict: dict) -> List[str]:
+ """Extract ordered label names from Ludwig train_set_metadata."""
+ if not isinstance(meta_dict, dict):
+ return []
+
+ for key in ("idx2str", "idx2label", "vocab"):
+ seq = meta_dict.get(key)
+ if isinstance(seq, list) and seq:
+ return [str(v) for v in seq]
+
+ str2idx = meta_dict.get("str2idx")
+ if isinstance(str2idx, dict) and str2idx:
+ int_indices = [v for v in str2idx.values() if isinstance(v, int)]
+ if int_indices:
+ max_idx = max(int_indices)
+ ordered = [None] * (max_idx + 1)
+ for name, idx in str2idx.items():
+ if isinstance(idx, int) and 0 <= idx < len(ordered):
+ ordered[idx] = name
+ return [str(v) for v in ordered if v is not None]
+
+ return []
+
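+ # Illustrative example (hypothetical metadata, assuming Ludwig's usual idx2str/str2idx layout):
+ # _labels_from_metadata_dict({"idx2str": ["benign", "malignant"],
+ # "str2idx": {"benign": 0, "malignant": 1}}) returns ["benign", "malignant"].
+ # When only "str2idx" is present, names are placed at their integer indices and
+ # any unfilled slots are dropped.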
+
+def _resolve_confusion_labels(
+ label_stats: dict,
+ n_classes: int,
+ metadata_csv_path: Optional[str],
+ train_set_metadata_path: Optional[str],
+) -> List[str]:
+ """Prefer original labels from metadata; fall back to stats if unavailable."""
+ if train_set_metadata_path:
+ try:
+ meta_path = Path(train_set_metadata_path)
+ if meta_path.exists():
+ with open(meta_path, "r") as f:
+ meta_json = json.load(f)
+ label_meta = meta_json.get(LABEL_COLUMN_NAME)
+ if not isinstance(label_meta, dict):
+ label_meta = next(
+ (
+ v
+ for v in meta_json.values()
+ if isinstance(v, dict)
+ and any(k in v for k in ("idx2str", "str2idx", "idx2label", "vocab"))
+ ),
+ None,
+ )
+ labels_from_meta = _labels_from_metadata_dict(label_meta) if label_meta else []
+ if labels_from_meta and len(labels_from_meta) >= n_classes:
+ return [str(label) for label in labels_from_meta[:n_classes]]
+ except Exception as exc:
+ print(f"Warning: Unable to read labels from train_set_metadata: {exc}")
+
+ if metadata_csv_path:
+ try:
+ csv_path = Path(metadata_csv_path)
+ if csv_path.exists():
+ df_meta = pd.read_csv(csv_path)
+ if LABEL_COLUMN_NAME in df_meta.columns:
+ uniques = df_meta[LABEL_COLUMN_NAME].dropna().unique().tolist()
+ if uniques and len(uniques) >= n_classes:
+ return [str(u) for u in uniques[:n_classes]]
+ except Exception as exc:
+ print(f"Warning: Unable to read labels from metadata CSV: {exc}")
+
+ pcs = label_stats.get("per_class_stats", {})
+ if pcs:
+ pcs_labels = [str(k) for k in pcs.keys()]
+ if len(pcs_labels) >= n_classes:
+ return pcs_labels[:n_classes]
+
+ labels = label_stats.get("labels")
+ if not labels:
+ labels = [str(i) for i in range(n_classes)]
+ if len(labels) < n_classes:
+ labels = labels + [str(i) for i in range(len(labels), n_classes)]
+ return [str(label) for label in labels[:n_classes]]
def build_classification_plots(
test_stats_path: str,
training_stats_path: Optional[str] = None,
+ metadata_csv_path: Optional[str] = None,
+ train_set_metadata_path: Optional[str] = None,
) -> List[Dict[str, str]]:
"""
Read Ludwig’s test_statistics.json and build three interactive Plotly panels:
@@ -21,6 +113,9 @@
- ROC-AUC
- Classification Report Heatmap
+ If metadata paths are provided, the confusion matrix axes will use the original
+ label values from the training metadata rather than integer-encoded labels.
+
Returns a list of dicts, each with:
{
"title":
,
@@ -42,12 +137,12 @@
# 0) Confusion Matrix
cm = np.array(label_stats["confusion_matrix"], dtype=int)
- # Try to get actual class names from per_class_stats keys (which contain the real labels)
- pcs = label_stats.get("per_class_stats", {})
- if pcs:
- labels = list(pcs.keys())
- else:
- labels = label_stats.get("labels", [str(i) for i in range(cm.shape[0])])
+ labels = _resolve_confusion_labels(
+ label_stats,
+ n_classes,
+ metadata_csv_path=metadata_csv_path,
+ train_set_metadata_path=train_set_metadata_path,
+ )
total = cm.sum()
fig_cm = go.Figure(
@@ -70,6 +165,7 @@
height=side_px,
margin=dict(t=100, l=80, r=80, b=80),
)
+ _style_fig(fig_cm)
# annotate counts and percentages
mval = cm.max() if cm.size else 0
@@ -110,16 +206,28 @@
)
})
- # 1) ROC-AUC Curves (Multi-class)
- roc_plot = _build_roc_auc_plot(test_stats_path, labels, common_cfg)
+ # 1) ROC Curve (from test_statistics)
+ roc_plot = _build_static_roc_plot(label_stats, common_cfg, friendly_labels=labels)
if roc_plot:
plots.append(roc_plot)
+ # 2) Precision-Recall Curve (from test_statistics)
+ pr_plot = _build_precision_recall_plot(label_stats, common_cfg)
+ if pr_plot:
+ plots.append(pr_plot)
+
# 2) Classification Report Heatmap
pcs = label_stats.get("per_class_stats", {})
if pcs:
classes = list(pcs.keys())
- metrics = ["precision", "recall", "f1_score"]
+ metrics = [
+ "precision",
+ "recall",
+ "f1_score",
+ "accuracy",
+ "matthews_correlation_coefficient",
+ "specificity",
+ ]
z, txt = [], []
for c in classes:
row, trow = [], []
@@ -133,7 +241,7 @@
fig_cr = go.Figure(
go.Heatmap(
z=z,
- x=metrics,
+ x=[m.replace("_", " ") for m in metrics],
y=[str(c) for c in classes],
text=txt,
texttemplate="%{text}",
@@ -143,15 +251,16 @@
)
)
fig_cr.update_layout(
- title="Classification Report",
+ title="Per-Class metrics",
xaxis_title="",
yaxis_title="Class",
width=side_px,
height=side_px,
margin=dict(t=80, l=80, r=80, b=80),
)
+ _style_fig(fig_cr)
plots.append({
- "title": "Classification Report",
+ "title": "Per-Class metrics",
"html": pio.to_html(
fig_cr,
full_html=False,
@@ -160,68 +269,667 @@
)
})
+ # 3) Prediction Diagnostics (from predictions.csv)
+ # Note: appended separately in generate_html_report, not returned here.
+
+ return plots
+
+
+def build_train_validation_plots(train_stats_path: str) -> List[Dict[str, str]]:
+ """Generate Train/Validation learning curve plots from training_statistics.json."""
+ if not train_stats_path or not Path(train_stats_path).exists():
+ return []
+ try:
+ with open(train_stats_path, "r") as f:
+ train_stats = json.load(f)
+ except Exception as exc:
+ print(f"Warning: Unable to read training statistics: {exc}")
+ return []
+
+ label_train = (train_stats.get("training") or {}).get("label", {})
+ label_val = (train_stats.get("validation") or {}).get("label", {})
+ if not label_train and not label_val:
+ return []
+ plots: List[Dict[str, str]] = []
+ include_js = True # Load Plotly.js once for this group
+
+ def _get_series(stats: dict, metric: str) -> List[float]:
+ if metric not in stats:
+ return []
+ vals = stats.get(metric, [])
+ if isinstance(vals, list):
+ return [float(v) for v in vals]
+ try:
+ return [float(vals)]
+ except Exception:
+ return []
+
+ def _line_plot(metric_key: str, title: str, yaxis_title: str) -> Optional[Dict[str, str]]:
+ train_series = _get_series(label_train, metric_key)
+ val_series = _get_series(label_val, metric_key)
+ if not train_series and not val_series:
+ return None
+ epochs_train = list(range(1, len(train_series) + 1))
+ epochs_val = list(range(1, len(val_series) + 1))
+ fig = go.Figure()
+ if train_series:
+ fig.add_trace(
+ go.Scatter(
+ x=epochs_train,
+ y=train_series,
+ mode="lines+markers",
+ name="Train",
+ line=dict(width=4),
+ )
+ )
+ if val_series:
+ fig.add_trace(
+ go.Scatter(
+ x=epochs_val,
+ y=val_series,
+ mode="lines+markers",
+ name="Validation",
+ line=dict(width=4),
+ )
+ )
+ fig.update_layout(
+ title=dict(text=title, x=0.5),
+ xaxis_title="Epoch",
+ yaxis_title=yaxis_title,
+ width=760,
+ height=520,
+ hovermode="x unified",
+ )
+ _style_fig(fig)
+ return {
+ "title": title,
+ "html": pio.to_html(
+ fig,
+ full_html=False,
+ include_plotlyjs="cdn" if include_js else False,
+ ),
+ }
+
+ # Core learning curves
+ for key, title in [
+ ("roc_auc", "ROC-AUC across epochs"),
+ ("precision", "Precision across epochs"),
+ ("recall", "Recall/Sensitivity across epochs"),
+ ("specificity", "Specificity across epochs"),
+ ]:
+ plot = _line_plot(key, title, title.replace("Learning Curve", "").strip())
+ if plot:
+ plots.append(plot)
+ include_js = False
+
+ # Precision vs Recall evolution (validation)
+ val_prec = _get_series(label_val, "precision")
+ val_rec = _get_series(label_val, "recall")
+ if val_prec and val_rec:
+ epochs = list(range(1, min(len(val_prec), len(val_rec)) + 1))
+ fig_pr = go.Figure()
+ fig_pr.add_trace(
+ go.Scatter(
+ x=epochs,
+ y=val_prec[: len(epochs)],
+ mode="lines+markers",
+ name="Precision",
+ )
+ )
+ fig_pr.add_trace(
+ go.Scatter(
+ x=epochs,
+ y=val_rec[: len(epochs)],
+ mode="lines+markers",
+ name="Recall",
+ )
+ )
+ fig_pr.update_layout(
+ title=dict(text="Validation Precision and Recall by Epoch", x=0.5),
+ xaxis_title="Epoch",
+ yaxis_title="Value",
+ width=760,
+ height=520,
+ hovermode="x unified",
+ )
+ _style_fig(fig_pr)
+ plots.append({
+ "title": "Precision vs Recall Evolution",
+ "html": pio.to_html(
+ fig_pr,
+ full_html=False,
+ include_plotlyjs="cdn" if include_js else False,
+ ),
+ })
+ include_js = False
+
+ # F1-score derived
+ def _compute_f1(p: List[float], r: List[float]) -> List[float]:
+ f1_vals = []
+ for prec, rec in zip(p, r):
+ if (prec + rec) == 0:
+ f1_vals.append(0.0)
+ else:
+ f1_vals.append(2 * prec * rec / (prec + rec))
+ return f1_vals
+
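+ # Derived metric: F1 = 2 * precision * recall / (precision + recall), computed per epoch
+ # from the train/validation series above; epochs where precision + recall == 0 are
+ # reported as 0.0 to avoid division by zero.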
+ f1_train = _compute_f1(_get_series(label_train, "precision"), _get_series(label_train, "recall"))
+ f1_val = _compute_f1(val_prec, val_rec)
+ if f1_train or f1_val:
+ fig = go.Figure()
+ if f1_train:
+ fig.add_trace(go.Scatter(x=list(range(1, len(f1_train) + 1)), y=f1_train, mode="lines+markers", name="Train", line=dict(width=4)))
+ if f1_val:
+ fig.add_trace(go.Scatter(x=list(range(1, len(f1_val) + 1)), y=f1_val, mode="lines+markers", name="Validation", line=dict(width=4)))
+ fig.update_layout(
+ title=dict(text="F1-Score across epochs (derived)", x=0.5),
+ xaxis_title="Epoch",
+ yaxis_title="F1-Score",
+ width=760,
+ height=520,
+ hovermode="x unified",
+ )
+ _style_fig(fig)
+ plots.append({
+ "title": "F1-Score across epochs (derived)",
+ "html": pio.to_html(
+ fig,
+ full_html=False,
+ include_plotlyjs="cdn" if include_js else False,
+ ),
+ })
+ include_js = False
+
+ # Overfitting Gap: Train vs Val ROC-AUC (gap)
+ roc_train = _get_series(label_train, "roc_auc")
+ roc_val = _get_series(label_val, "roc_auc")
+ if roc_train and roc_val:
+ epochs_gap = list(range(1, min(len(roc_train), len(roc_val)) + 1))
+ gaps = [t - v for t, v in zip(roc_train[:len(epochs_gap)], roc_val[:len(epochs_gap)])]
+ fig_gap = go.Figure()
+ fig_gap.add_trace(go.Scatter(x=epochs_gap, y=gaps, mode="lines+markers", name="Train - Val ROC-AUC", line=dict(width=4)))
+ fig_gap.update_layout(
+ title=dict(text="Overfitting gap: ROC-AUC across epochs", x=0.5),
+ xaxis_title="Epoch",
+ yaxis_title="Gap",
+ width=760,
+ height=520,
+ hovermode="x unified",
+ )
+ _style_fig(fig_gap)
+ plots.append({
+ "title": "Overfitting gap: ROC-AUC across epochs",
+ "html": pio.to_html(
+ fig_gap,
+ full_html=False,
+ include_plotlyjs="cdn" if include_js else False,
+ ),
+ })
+ include_js = False
+
+ # Best Epoch Dashboard (based on max val ROC-AUC)
+ if roc_val:
+ best_idx = int(np.argmax(roc_val))
+ best_epoch = best_idx + 1
+ spec_val = _get_series(label_val, "specificity")
+ metrics_at_best = {
+ "ROC-AUC": roc_val[best_idx] if best_idx < len(roc_val) else None,
+ "Precision": val_prec[best_idx] if best_idx < len(val_prec) else None,
+ "Recall": val_rec[best_idx] if best_idx < len(val_rec) else None,
+ "Specificity": spec_val[best_idx] if best_idx < len(spec_val) else None,
+ "F1-Score": f1_val[best_idx] if best_idx < len(f1_val) else None,
+ }
+ fig_best = go.Figure()
+ for name, value in metrics_at_best.items():
+ if value is not None:
+ fig_best.add_trace(go.Bar(name=name, x=[name], y=[value]))
+ fig_best.update_layout(
+ title=dict(text=f"Best Epoch Dashboard (Val ROC-AUC @ epoch {best_epoch})", x=0.5),
+ xaxis_title="Metric",
+ yaxis_title="Value",
+ width=760,
+ height=520,
+ showlegend=False,
+ )
+ _style_fig(fig_best)
+ plots.append({
+ "title": "Best Validation Epoch Snapshot (Metrics)",
+ "html": pio.to_html(
+ fig_best,
+ full_html=False,
+ include_plotlyjs="cdn" if include_js else False,
+ ),
+ })
+ include_js = False
+
return plots
-def _build_roc_auc_plot(test_stats_path: str, class_labels: List[str], config: dict) -> Optional[Dict[str, str]]:
- """
- Build an interactive ROC-AUC curve plot for multi-class classification.
- Following sklearn's ROC example with micro-average and per-class curves.
+def _get_regression_series(split_stats: dict, metric: str) -> List[float]:
+ if metric not in split_stats:
+ return []
+ vals = split_stats.get(metric, [])
+ if isinstance(vals, list):
+ return [float(v) for v in vals]
+ try:
+ return [float(vals)]
+ except Exception:
+ return []
+
- Args:
- test_stats_path: Path to test_statistics.json
- class_labels: List of class label names
- config: Plotly config dict
+def _regression_line_plot(
+ train_split: dict,
+ val_split: dict,
+ metric_key: str,
+ title: str,
+ yaxis_title: str,
+ include_js: bool,
+) -> Optional[Dict[str, str]]:
+ train_series = _get_regression_series(train_split, metric_key)
+ val_series = _get_regression_series(val_split, metric_key)
+ if not train_series and not val_series:
+ return None
+ epochs_train = list(range(1, len(train_series) + 1))
+ epochs_val = list(range(1, len(val_series) + 1))
+ fig = go.Figure()
+ if train_series:
+ fig.add_trace(
+ go.Scatter(
+ x=epochs_train,
+ y=train_series,
+ mode="lines+markers",
+ name="Train",
+ line=dict(width=4),
+ )
+ )
+ if val_series:
+ fig.add_trace(
+ go.Scatter(
+ x=epochs_val,
+ y=val_series,
+ mode="lines+markers",
+ name="Validation",
+ line=dict(width=4),
+ )
+ )
+ fig.update_layout(
+ title=dict(text=title, x=0.5),
+ xaxis_title="Epoch",
+ yaxis_title=yaxis_title,
+ width=760,
+ height=520,
+ hovermode="x unified",
+ )
+ _style_fig(fig)
+ return {
+ "title": title,
+ "html": pio.to_html(
+ fig,
+ full_html=False,
+ include_plotlyjs="cdn" if include_js else False,
+ ),
+ }
+
+
+def build_regression_train_val_plots(train_stats_path: str) -> List[Dict[str, str]]:
+ """Generate regression Train/Validation learning curve plots from training_statistics.json."""
+ if not train_stats_path or not Path(train_stats_path).exists():
+ return []
+ try:
+ with open(train_stats_path, "r") as f:
+ train_stats = json.load(f)
+ except Exception as exc:
+ print(f"Warning: Unable to read training statistics: {exc}")
+ return []
+
+ label_train = (train_stats.get("training") or {}).get("label", {})
+ label_val = (train_stats.get("validation") or {}).get("label", {})
+ if not label_train and not label_val:
+ return []
+
+ plots: List[Dict[str, str]] = []
+ include_js = True
+ for metric_key, title, ytitle in [
+ ("mean_absolute_error", "Mean Absolute Error across epochs", "MAE"),
+ ("root_mean_squared_error", "Root Mean Squared Error across epochs", "RMSE"),
+ ("mean_absolute_percentage_error", "Mean Absolute Percentage Error across epochs", "MAPE"),
+ ("r2", "R² across epochs", "R²"),
+ ("loss", "Loss across epochs", "Loss"),
+ ]:
+ plot = _regression_line_plot(label_train, label_val, metric_key, title, ytitle, include_js)
+ if plot:
+ plots.append(plot)
+ include_js = False
+ return plots
+
- Returns:
- Dict with title and HTML, or None if data unavailable
- """
+def build_regression_test_plots(train_stats_path: str) -> List[Dict[str, str]]:
+ """Generate regression Test learning curves from training_statistics.json."""
+ if not train_stats_path or not Path(train_stats_path).exists():
+ return []
try:
- # Get the experiment directory from test_stats_path
- exp_dir = Path(test_stats_path).parent
+ with open(train_stats_path, "r") as f:
+ train_stats = json.load(f)
+ except Exception as exc:
+ print(f"Warning: Unable to read training statistics: {exc}")
+ return []
+
+ label_test = (train_stats.get("test") or {}).get("label", {})
+ if not label_test:
+ return []
- # Load predictions with probabilities
- predictions_path = exp_dir / "predictions.csv"
- if not predictions_path.exists():
- return None
+ plots: List[Dict[str, str]] = []
+ include_js = True
+ metrics = [
+ ("mean_absolute_error", "Mean Absolute Error Across Epochs", "MAE"),
+ ("root_mean_squared_error", "Root Mean Squared Error Across Epochs", "RMSE"),
+ ("mean_absolute_percentage_error", "Mean Absolute Percentage Error Across Epochs", "MAPE"),
+ ("r2", "R² Across Epochs", "R²"),
+ ("loss", "Loss Across Epochs", "Loss"),
+ ]
+ epochs = None
+ for metric_key, title, ytitle in metrics:
+ series = _get_regression_series(label_test, metric_key)
+ if not series:
+ continue
+ if epochs is None:
+ epochs = list(range(1, len(series) + 1))
+ fig = go.Figure()
+ fig.add_trace(
+ go.Scatter(
+ x=epochs,
+ y=series[: len(epochs)],
+ mode="lines+markers",
+ name="Test",
+ line=dict(width=4),
+ )
+ )
+ fig.update_layout(
+ title=dict(text=title, x=0.5),
+ xaxis_title="Epoch",
+ yaxis_title=ytitle,
+ width=760,
+ height=520,
+ hovermode="x unified",
+ )
+ _style_fig(fig)
+ plots.append({
+ "title": title,
+ "html": pio.to_html(
+ fig,
+ full_html=False,
+ include_plotlyjs="cdn" if include_js else False,
+ ),
+ })
+ include_js = False
+ return plots
- df_pred = pd.read_csv(predictions_path)
+
+def _build_static_roc_plot(
+ label_stats: dict, config: dict, friendly_labels: Optional[List[str]] = None
+) -> Optional[Dict[str, str]]:
+ """Build ROC curve directly from test_statistics.json (single curve)."""
+ roc_data = label_stats.get("roc_curve")
+ if not isinstance(roc_data, dict):
+ return None
+
+ fpr = roc_data.get("false_positive_rate")
+ tpr = roc_data.get("true_positive_rate")
+ if not fpr or not tpr or len(fpr) != len(tpr):
+ return None
+
+ try:
+ fig = go.Figure()
+ fig.add_trace(
+ go.Scatter(
+ x=fpr,
+ y=tpr,
+ mode="lines+markers",
+ name="ROC Curve",
+ line=dict(color="#1f77b4", width=4),
+ hovertemplate="FPR: %{x:.3f}
TPR: %{y:.3f}",
+ )
+ )
+ fig.add_trace(
+ go.Scatter(
+ x=[0, 1],
+ y=[0, 1],
+ mode="lines",
+ name="Random Classifier",
+ line=dict(color="gray", width=2, dash="dash"),
+ hovertemplate="Random Classifier",
+ )
+ )
+
+ auc_val = label_stats.get("roc_auc") or label_stats.get("roc_auc_macro") or label_stats.get("roc_auc_micro")
+ auc_txt = f" (AUC = {auc_val:.3f})" if isinstance(auc_val, (int, float)) else ""
- if SPLIT_COLUMN_NAME in df_pred.columns:
- split_series = df_pred[SPLIT_COLUMN_NAME].astype(str).str.lower()
- test_mask = split_series.isin({"2", "test", "testing"})
- if test_mask.any():
- df_pred = df_pred[test_mask].reset_index(drop=True)
+ # Determine which label is treated as positive for the curve
+ label_list: List = []
+ pcs = label_stats.get("per_class_stats", {})
+ if pcs:
+ label_list = list(pcs.keys())
+ if not label_list:
+ labels_from_stats = label_stats.get("labels")
+ if isinstance(labels_from_stats, list):
+ label_list = labels_from_stats
+
+ # Try to resolve index of the positive label explicitly provided by Ludwig
+ pos_label_raw = (
+ roc_data.get("positive_label")
+ or roc_data.get("positive_class")
+ or label_stats.get("positive_label")
+ )
+ pos_label_idx = None
+ if pos_label_raw is not None and isinstance(label_list, list):
+ try:
+ pos_label_idx = label_list.index(pos_label_raw)
+ except ValueError:
+ pos_label_idx = None
+
+ # Fallback: use the second label if available, otherwise the first
+ if pos_label_idx is None:
+ if isinstance(label_list, list) and len(label_list) >= 2:
+ pos_label_idx = 1
+ elif isinstance(label_list, list) and label_list:
+ pos_label_idx = 0
+
+ if pos_label_raw is None and isinstance(label_list, list) and pos_label_idx is not None:
+ pos_label_raw = label_list[pos_label_idx]
+
+ # Map to friendly label if we have one from metadata/CSV
+ pos_label_display = pos_label_raw
+ if (
+ friendly_labels
+ and isinstance(pos_label_idx, int)
+ and 0 <= pos_label_idx < len(friendly_labels)
+ ):
+ pos_label_display = friendly_labels[pos_label_idx]
+
+ pos_label_txt = (
+ f"Positive class: {pos_label_display}"
+ if pos_label_display is not None
+ else "Positive class: (not available)"
+ )
+
+ title_label = f"ROC Curve{auc_txt}"
+ if pos_label_display is not None:
+ title_label = f"ROC Curve (Positive Class: {pos_label_display}){auc_txt}"
- if df_pred.empty:
- return None
+ fig.update_layout(
+ title=dict(text=title_label, x=0.5),
+ xaxis_title="False Positive Rate",
+ yaxis_title="True Positive Rate",
+ width=700,
+ height=600,
+ margin=dict(t=80, l=80, r=80, b=110),
+ hovermode="closest",
+ legend=dict(
+ x=0.6,
+ y=0.1,
+ bgcolor="rgba(255,255,255,0.9)",
+ bordercolor="rgba(0,0,0,0.2)",
+ borderwidth=1,
+ ),
+ )
+ _style_fig(fig)
+ fig.update_xaxes(range=[0, 1.0])
+ fig.update_yaxes(range=[0, 1.05])
- # Extract probability columns (label_probabilities_0, label_probabilities_1, etc.)
- # or label_probabilities_ for string labels
- prob_cols = [col for col in df_pred.columns if col.startswith('label_probabilities_') and col != 'label_probabilities']
+ fig.add_annotation(
+ x=0.5,
+ y=-0.15,
+ xref="paper",
+ yref="paper",
+ showarrow=False,
+ text=f"{pos_label_txt}",
+ xanchor="center",
+ )
+
+ return {
+ "title": "ROC Curve",
+ "html": pio.to_html(
+ fig,
+ full_html=False,
+ include_plotlyjs=False,
+ config=config,
+ ),
+ }
+ except Exception as e:
+ print(f"Error building ROC plot: {e}")
+ return None
+
+
+def _build_precision_recall_plot(label_stats: dict, config: dict) -> Optional[Dict[str, str]]:
+ """Build Precision-Recall curve directly from test_statistics.json."""
+ pr_data = label_stats.get("precision_recall_curve")
+ if not isinstance(pr_data, dict):
+ return None
+
+ precisions = pr_data.get("precisions")
+ recalls = pr_data.get("recalls")
+ if not precisions or not recalls or len(precisions) != len(recalls):
+ return None
- # Sort by class number if numeric, otherwise keep alphabetical order
- if prob_cols and prob_cols[0].split('_')[-1].isdigit():
- prob_cols.sort(key=lambda x: int(x.split('_')[-1]))
- else:
- prob_cols.sort() # Alphabetical sort for string class names
+ try:
+ fig = go.Figure()
+ fig.add_trace(
+ go.Scatter(
+ x=recalls,
+ y=precisions,
+ mode="lines+markers",
+ name="Precision-Recall",
+ line=dict(color="#d62728", width=4),
+ hovertemplate="Recall: %{x:.3f}
Precision: %{y:.3f}",
+ )
+ )
+
+ ap_val = (
+ label_stats.get("average_precision_macro")
+ or label_stats.get("average_precision_micro")
+ or label_stats.get("average_precision_samples")
+ )
+ ap_txt = f" (AP = {ap_val:.3f})" if isinstance(ap_val, (int, float)) else ""
+
+ fig.update_layout(
+ title=dict(text=f"Precision-Recall Curve{ap_txt}", x=0.5),
+ xaxis_title="Recall",
+ yaxis_title="Precision",
+ width=700,
+ height=600,
+ margin=dict(t=80, l=80, r=80, b=80),
+ hovermode="closest",
+ legend=dict(
+ x=0.6,
+ y=0.1,
+ bgcolor="rgba(255,255,255,0.9)",
+ bordercolor="rgba(0,0,0,0.2)",
+ borderwidth=1,
+ ),
+ )
+ _style_fig(fig)
+ fig.update_xaxes(range=[0, 1.0])
+ fig.update_yaxes(range=[0, 1.05])
+
+ return {
+ "title": "Precision-Recall Curve",
+ "html": pio.to_html(
+ fig,
+ full_html=False,
+ include_plotlyjs=False,
+ config=config,
+ ),
+ }
+ except Exception as e:
+ print(f"Error building Precision-Recall plot: {e}")
+ return None
+
- if not prob_cols:
- return None
+def build_prediction_diagnostics(
+ predictions_path: str,
+ label_data_path: Optional[str] = None,
+ split_value: int = 2,
+ threshold: Optional[float] = None,
+) -> List[Dict[str, str]]:
+ """Generate diagnostic plots from predictions.csv for classification tasks."""
+ preds_file = Path(predictions_path)
+ if not preds_file.exists():
+ return []
+
+ try:
+ df_pred = pd.read_csv(predictions_path)
+ except Exception as exc:
+ print(f"Warning: Unable to read predictions CSV: {exc}")
+ return []
+
+ plots: List[Dict[str, str]] = []
+
+ # Identify probability columns
+ prob_cols = [
+ c for c in df_pred.columns
+ if c.startswith("label_probabilities_") and c != "label_probabilities"
+ ]
+ prob_cols_sorted = sorted(prob_cols)
- # Get probabilities matrix (n_samples x n_classes)
- y_score = df_pred[prob_cols].values
- n_classes = len(prob_cols)
+ def _select_positive_prob():
+ if not prob_cols_sorted:
+ return None, None
+ # Prefer a column indicating positive/event/true/1
+ preferred_keys = ("event", "true", "positive", "pos", "1")
+ for col in prob_cols_sorted:
+ suffix = col.replace("label_probabilities_", "").lower()
+ if any(k in suffix for k in preferred_keys):
+ return col, suffix
+ if len(prob_cols_sorted) == 2:
+ col = prob_cols_sorted[1]
+ return col, col.replace("label_probabilities_", "")
+ col = prob_cols_sorted[0]
+ return col, col.replace("label_probabilities_", "")
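+
+ # Illustrative example (hypothetical column names): given
+ # ["label_probabilities_False", "label_probabilities_True"], the suffix "true"
+ # matches a preferred key, so ("label_probabilities_True", "true") is returned;
+ # with no keyword match, the second column of a two-column set is used.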
- y_true = None
- candidate_cols = [
+ pos_prob_col, pos_label_hint = _select_positive_prob()
+ pos_prob_series = df_pred[pos_prob_col] if pos_prob_col and pos_prob_col in df_pred else None
+
+ # Confidence series: prefer label_probability, otherwise positive prob, otherwise max prob
+ confidence_series = None
+ if "label_probability" in df_pred.columns:
+ confidence_series = df_pred["label_probability"]
+ elif pos_prob_series is not None:
+ confidence_series = pos_prob_series
+ elif prob_cols_sorted:
+ confidence_series = df_pred[prob_cols_sorted].max(axis=1)
+
+ # True labels
+ def _extract_labels():
+ candidates = [
LABEL_COLUMN_NAME,
f"{LABEL_COLUMN_NAME}_ground_truth",
f"{LABEL_COLUMN_NAME}__ground_truth",
f"{LABEL_COLUMN_NAME}_target",
f"{LABEL_COLUMN_NAME}__target",
+ "label",
+ "label_true",
]
- candidate_cols.extend(
+ candidates.extend(
[
col
for col in df_pred.columns
@@ -230,174 +938,182 @@
and "predictions" not in col
]
)
- for col in candidate_cols:
- if col in df_pred.columns and col not in prob_cols:
- y_true = df_pred[col].values
- break
+ for col in candidates:
+ if col in df_pred.columns and col not in prob_cols_sorted:
+ return df_pred[col]
+ if label_data_path and Path(label_data_path).exists():
+ try:
+ df_all = pd.read_csv(label_data_path)
+ if SPLIT_COLUMN_NAME in df_all.columns:
+ df_all = df_all[df_all[SPLIT_COLUMN_NAME] == split_value].reset_index(drop=True)
+ if LABEL_COLUMN_NAME in df_all.columns:
+ return df_all[LABEL_COLUMN_NAME].reset_index(drop=True)
+ except Exception as exc:
+ print(f"Warning: Unable to load labels from dataset: {exc}")
+ return None
- if y_true is None:
- desc_path = exp_dir / "description.json"
- if desc_path.exists():
- try:
- with open(desc_path, 'r') as f:
- desc = json.load(f)
- dataset_path = desc.get('dataset', '')
- if dataset_path and Path(dataset_path).exists():
- df_orig = pd.read_csv(dataset_path)
- if SPLIT_COLUMN_NAME in df_orig.columns:
- df_orig = df_orig[df_orig[SPLIT_COLUMN_NAME] == 2].reset_index(drop=True)
- if LABEL_COLUMN_NAME in df_orig.columns:
- y_true = df_orig[LABEL_COLUMN_NAME].values
- if len(y_true) != len(df_pred):
- print(
- f"Warning: Test set size mismatch. Truncating to {len(df_pred)} samples for ROC plot."
- )
- y_true = y_true[:len(df_pred)]
- else:
- print("Warning: Original dataset referenced in description.json is unavailable.")
- except Exception as exc: # pragma: no cover - defensive
- print(f"Warning: Failed to recover labels from dataset: {exc}")
-
- if y_true is None or len(y_true) == 0:
- print("Warning: Unable to locate ground-truth labels for ROC plot.")
- return None
-
- if len(y_true) != len(y_score):
- limit = min(len(y_true), len(y_score))
- if limit == 0:
- return None
- print(f"Warning: Aligning prediction and label lengths to {limit} samples for ROC plot.")
- y_true = y_true[:limit]
- y_score = y_score[:limit]
+ labels_series = _extract_labels()
- # Get actual class names from probability column names
- actual_classes = [col.replace('label_probabilities_', '') for col in prob_cols]
- display_classes = class_labels if len(class_labels) == n_classes else actual_classes
-
- # Binarize the output following sklearn example
- # Use actual class names if they're strings, otherwise use range
- if isinstance(y_true[0], str):
- y_test = label_binarize(y_true, classes=actual_classes)
- else:
- y_test = label_binarize(y_true, classes=list(range(n_classes)))
-
- # Handle binary classification case
- if y_test.ndim != 2:
- y_test = np.atleast_2d(y_test)
+ # Plot 1: Confidence Histogram
+ if confidence_series is not None:
+ fig_conf = go.Figure()
+ fig_conf.add_trace(
+ go.Histogram(
+ x=confidence_series,
+ nbinsx=20,
+ marker=dict(color="#1f77b4", line=dict(color="#ffffff", width=1)),
+ opacity=0.8,
+ histnorm="percent",
+ )
+ )
+ fig_conf.update_layout(
+ title=dict(text="Prediction Confidence Distribution", x=0.5),
+ xaxis_title="Predicted probability (confidence)",
+ yaxis_title="Percentage (%)",
+ bargap=0.05,
+ width=700,
+ height=500,
+ )
+ _style_fig(fig_conf)
+ plots.append({
+ "title": "Prediction Confidence Distribution",
+ "html": pio.to_html(fig_conf, full_html=False, include_plotlyjs=False),
+ })
- if n_classes == 2:
- if y_test.shape[1] == 1:
- y_test = np.hstack([1 - y_test, y_test])
- elif y_test.shape[1] != 2:
- print("Warning: Unexpected label binarization shape for binary ROC plot.")
- return None
- elif y_test.shape[1] != n_classes:
- print("Warning: Label binarization did not produce expected class dimension; skipping ROC plot.")
- return None
+ # The remaining plots require true labels and a positive-class probability
+ if labels_series is None or pos_prob_series is None:
+ return plots
+
+ # Align lengths
+ min_len = min(len(labels_series), len(pos_prob_series))
+ if min_len == 0:
+ return plots
+ y_true_raw = labels_series.iloc[:min_len]
+ y_score = np.array(pos_prob_series.iloc[:min_len], dtype=float)
- # Compute ROC curve and ROC area for each class (following sklearn example)
- fpr = dict()
- tpr = dict()
- roc_auc = dict()
+ # Determine positive label
+ unique_labels = pd.unique(y_true_raw)
+ unique_labels_list = list(unique_labels)
+ positive_label = None
+ hint_str = str(pos_label_hint).lower() if pos_label_hint is not None else None
+ if hint_str is not None and hint_str in [str(u).lower() for u in unique_labels_list]:
+ positive_label = next(u for u in unique_labels_list if str(u).lower() == hint_str)
+ elif len(unique_labels_list) == 2:
+ positive_label = unique_labels_list[1]
+ else:
+ positive_label = unique_labels_list[0]
- for i in range(n_classes):
- if np.sum(y_test[:, i]) > 0: # Check if class exists in test set
- fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
- roc_auc[i] = auc(fpr[i], tpr[i])
-
- # Compute micro-average ROC curve and ROC area (sklearn example)
- fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())
- roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
-
- # Create ROC curve plot
- fig_roc = go.Figure()
+ y_true = (y_true_raw == positive_label).astype(int).values
- # Colors for different classes
- colors = [
- '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
- '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf'
- ]
-
- # Plot micro-average ROC curve first (most important)
- fig_roc.add_trace(go.Scatter(
- x=fpr["micro"],
- y=tpr["micro"],
- mode='lines',
- name=f'Micro-average ROC (AUC = {roc_auc["micro"]:.3f})',
- line=dict(color='deeppink', width=3, dash='dot'),
- hovertemplate=('Micro-average ROC<br>'
- 'FPR: %{x:.3f}<br>'
- 'TPR: %{y:.3f}<br>'
- f'AUC: {roc_auc["micro"]:.3f}')
- ))
-
- # Plot ROC curve for each class
- for i in range(n_classes):
- if i in roc_auc: # Only plot if class exists in test set
- class_name = display_classes[i] if i < len(display_classes) else f"Class {i}"
- color = colors[i % len(colors)]
-
- fig_roc.add_trace(go.Scatter(
- x=fpr[i],
- y=tpr[i],
- mode='lines',
- name=f'{class_name} (AUC = {roc_auc[i]:.3f})',
- line=dict(color=color, width=2),
- hovertemplate=(f'{class_name}<br>'
- 'FPR: %{x:.3f}<br>'
- 'TPR: %{y:.3f}<br>'
- f'AUC: {roc_auc[i]:.3f}')
- ))
+ # Plot 2: Calibration Curve
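+ # Binning scheme: 10 equal-width probability bins; each plotted point is the mean
+ # predicted probability vs. the observed positive rate within that bin (empty bins skipped).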
+ bins = np.linspace(0.0, 1.0, 11)
+ bin_ids = np.digitize(y_score, bins, right=True)
+ bin_centers = []
+ frac_positives = []
+ for b in range(1, len(bins)):
+ mask = bin_ids == b
+ if not np.any(mask):
+ continue
+ bin_centers.append(y_score[mask].mean())
+ frac_positives.append(y_true[mask].mean())
+ if bin_centers and frac_positives:
+ fig_cal = go.Figure()
+ fig_cal.add_trace(
+ go.Scatter(
+ x=bin_centers,
+ y=frac_positives,
+ mode="lines+markers",
+ name="Calibration",
+ line=dict(color="#2ca02c", width=4),
+ )
+ )
+ fig_cal.add_trace(
+ go.Scatter(
+ x=[0, 1],
+ y=[0, 1],
+ mode="lines",
+ name="Perfect Calibration",
+ line=dict(color="gray", width=2, dash="dash"),
+ )
+ )
+ fig_cal.update_layout(
+ title=dict(text="Calibration Curve", x=0.5),
+ xaxis_title="Predicted probability",
+ yaxis_title="Observed frequency",
+ width=700,
+ height=500,
+ )
+ _style_fig(fig_cal)
+ plots.append({
+ "title": "Calibration Curve (Predicted Probability vs Observed Frequency)",
+ "html": pio.to_html(fig_cal, full_html=False, include_plotlyjs=False),
+ })
- # Add diagonal line (random classifier)
- fig_roc.add_trace(go.Scatter(
- x=[0, 1],
- y=[0, 1],
- mode='lines',
- name='Random Classifier',
- line=dict(color='gray', width=1, dash='dash'),
- hovertemplate='Random Classifier<br>AUC = 0.500'
- ))
-
- # Calculate macro-average AUC
- class_aucs = [roc_auc[i] for i in range(n_classes) if i in roc_auc]
- if class_aucs:
- macro_auc = np.mean(class_aucs)
- title_text = f"ROC Curves (Micro-avg = {roc_auc['micro']:.3f}, Macro-avg = {macro_auc:.3f})"
- else:
- title_text = f"ROC Curves (Micro-avg = {roc_auc['micro']:.3f})"
+ # Plot 3: Threshold vs Metrics
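+ # Sweep 21 evenly spaced thresholds in [0, 1]; at each threshold the positive-class
+ # probability is binarized and accuracy, F1, sensitivity and specificity are recomputed
+ # from the resulting confusion counts (the 1e-9 guards avoid division by zero).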
+ thresholds = np.linspace(0.0, 1.0, 21)
+ accs, f1s, sens, specs = [], [], [], []
+ for t in thresholds:
+ y_pred = (y_score >= t).astype(int)
+ tp = np.sum((y_true == 1) & (y_pred == 1))
+ tn = np.sum((y_true == 0) & (y_pred == 0))
+ fp = np.sum((y_true == 0) & (y_pred == 1))
+ fn = np.sum((y_true == 1) & (y_pred == 0))
+ acc = (tp + tn) / max(len(y_true), 1)
+ prec = tp / max(tp + fp, 1e-9)
+ rec = tp / max(tp + fn, 1e-9)
+ f1 = 0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec)
+ sensitivity = rec
+ specificity = tn / max(tn + fp, 1e-9)
+ accs.append(acc)
+ f1s.append(f1)
+ sens.append(sensitivity)
+ specs.append(specificity)
- fig_roc.update_layout(
- title=dict(text=title_text, x=0.5),
- xaxis_title="False Positive Rate",
- yaxis_title="True Positive Rate",
- width=700,
- height=600,
- margin=dict(t=80, l=80, r=80, b=80),
- legend=dict(
- x=0.6,
- y=0.1,
- bgcolor="rgba(255,255,255,0.9)",
- bordercolor="rgba(0,0,0,0.2)",
- borderwidth=1
- ),
- hovermode='closest'
- )
+ fig_thresh = go.Figure()
+ fig_thresh.add_trace(go.Scatter(x=thresholds, y=accs, mode="lines", name="Accuracy", line=dict(width=4)))
+ fig_thresh.add_trace(go.Scatter(x=thresholds, y=f1s, mode="lines", name="F1", line=dict(width=4)))
+ fig_thresh.add_trace(go.Scatter(x=thresholds, y=sens, mode="lines", name="Sensitivity", line=dict(width=4)))
+ fig_thresh.add_trace(go.Scatter(x=thresholds, y=specs, mode="lines", name="Specificity", line=dict(width=4)))
+ fig_thresh.update_layout(
+ title=dict(text="Threshold Sweep: Accuracy, F1, Sensitivity, Specificity", x=0.5),
+ xaxis_title="Decision threshold",
+ yaxis_title="Metric value",
+ width=700,
+ height=500,
+ legend=dict(
+ x=0.7,
+ y=0.2,
+ bgcolor="rgba(255,255,255,0.9)",
+ bordercolor="rgba(0,0,0,0.2)",
+ borderwidth=1,
+ ),
+ shapes=[
+ dict(
+ type="line",
+ x0=threshold,
+ x1=threshold,
+ y0=0,
+ y1=1,
+ xref="x",
+ yref="paper",
+ line=dict(color="#d62728", width=2, dash="dash"),
+ )
+ ] if isinstance(threshold, (int, float)) else [],
+ annotations=[
+ dict(
+ x=threshold,
+ y=1.02,
+ xref="x",
+ yref="paper",
+ showarrow=False,
+ text=f"Threshold = {threshold:.2f}",
+ font=dict(size=11, color="#d62728"),
+ )
+ ] if isinstance(threshold, (int, float)) else [],
+ )
+ _style_fig(fig_thresh)
+ plots.append({
+ "title": "Threshold Sweep: Accuracy, F1, Sensitivity, Specificity",
+ "html": pio.to_html(fig_thresh, full_html=False, include_plotlyjs=False),
+ })
- # Set equal aspect ratio and proper range
- fig_roc.update_xaxes(range=[0, 1.0])
- fig_roc.update_yaxes(range=[0, 1.05])
-
- return {
- "title": "ROC-AUC Curves",
- "html": pio.to_html(
- fig_roc,
- full_html=False,
- include_plotlyjs=False,
- config=config
- )
- }
-
- except Exception as e:
- print(f"Error building ROC-AUC plot: {e}")
- return None
+ return plots