changeset 17:db9be962dc13 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
author goeckslab
date Wed, 10 Dec 2025 00:24:13 +0000
parents 8729f69e9207
children
files html_structure.py image_learner.xml image_workflow.py ludwig_backend.py plotly_plots.py
diffstat 5 files changed, 1322 insertions(+), 475 deletions(-) [+]
line wrap: on
line diff
--- a/html_structure.py	Wed Dec 03 01:28:52 2025 +0000
+++ b/html_structure.py	Wed Dec 10 00:24:13 2025 +0000
@@ -1,6 +1,6 @@
 import base64
 import json
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 
 from constants import METRIC_DISPLAY_NAMES
 from utils import detect_output_type, extract_metrics_from_json
@@ -23,6 +23,7 @@
 ) -> str:
     display_keys = [
         "architecture",
+        "image_size",
         "pretrained",
         "trainable",
         "target_column",
@@ -58,6 +59,15 @@
         else:
             if key == "task_type":
                 val_str = val.title() if isinstance(val, str) else "N/A"
+            elif key == "image_size":
+                if val is None:
+                    val_str = "N/A"
+                elif isinstance(val, (list, tuple)) and len(val) == 2:
+                    val_str = f"{val[0]}x{val[1]}"
+                elif isinstance(val, str) and val.lower() == "original":
+                    val_str = "Original (no resize)"
+                else:
+                    val_str = str(val)
             elif key == "batch_size":
                 if isinstance(val, (int, float)):
                     val_str = int(val)
@@ -115,6 +125,11 @@
                             "Ludwig Trainer Parameters</a> for details."
                             "</span>"
                         )
+            elif key == "validation_metric":
+                if val is not None:
+                    val_str = METRIC_DISPLAY_NAMES.get(str(val), str(val))
+                else:
+                    val_str = "N/A"
             elif key == "epochs":
                 if val is None:
                     val_str = "N/A"
@@ -729,6 +744,64 @@
     )
     return modal_html + modal_js
 
+
+def format_dataset_overview_table(rows: List[Dict[str, Any]], regression_mode: bool = False) -> str:
+    """Render a dataset overview table.
+
+    - Classification: per-label distribution across train/val/test.
+    - Regression: split counts (train/val/test).
+    """
+    heading = "<h2 style='text-align: center;'>Dataset Overview</h2>"
+    if not rows:
+        return heading + "<p style='text-align: center; color: #666;'>Dataset overview unavailable.</p><br>"
+
+    if regression_mode:
+        headers = ["Split", "Count"]
+        html = (
+            heading
+            + "<div style='display: flex; justify-content: center;'>"
+            + "<table class='performance-summary' style='border-collapse: collapse; table-layout: fixed;'>"
+            + "<thead><tr>"
+            + "".join(
+                f"<th style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>{h}</th>"
+                for h in headers
+            )
+            + "</tr></thead><tbody>"
+        )
+        for row in rows:
+            html += generate_table_row(
+                [
+                    row.get("split", "N/A"),
+                    row.get("count", 0),
+                ],
+                "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;",
+            )
+        html += "</tbody></table></div><br>"
+    else:
+        html = (
+            heading
+            + "<div style='display: flex; justify-content: center;'>"
+            + "<table class='performance-summary' style='border-collapse: collapse; table-layout: fixed;'>"
+            + "<thead><tr>"
+            + "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Label</th>"
+            + "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Train</th>"
+            + "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Validation</th>"
+            + "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Test</th>"
+            + "</tr></thead><tbody>"
+        )
+        for row in rows:
+            html += generate_table_row(
+                [
+                    row.get("label", "N/A"),
+                    row.get("train", 0),
+                    row.get("validation", 0),
+                    row.get("test", 0),
+                ],
+                "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;",
+            )
+        html += "</tbody></table></div><br>"
+    return html
+
 # -----------------------------------------
 # MODEL PERFORMANCE (Train/Val/Test) TABLE
 # -----------------------------------------
--- a/image_learner.xml	Wed Dec 03 01:28:52 2025 +0000
+++ b/image_learner.xml	Wed Dec 10 00:24:13 2025 +0000
@@ -130,13 +130,9 @@
             </when>
             <when value="regression">
                 <param name="validation_metric_regression" type="select" optional="true" label="Validation metric (regression)" help="Metrics accepted by Ludwig for regression outputs.">
-                    <option value="pearson_r" selected="true">Pearson r</option>
-                    <option value="mae">MAE</option>
+                    <option value="mae" selected="true">MAE</option>
                     <option value="mse">MSE</option>
                     <option value="rmse">RMSE</option>
-                    <option value="mape">MAPE</option>
-                    <option value="r2">R²</option>
-                    <option value="explained_variance">Explained Variance</option>
                     <option value="loss">Loss</option>
                 </param>
             </when>
--- a/image_workflow.py	Wed Dec 03 01:28:52 2025 +0000
+++ b/image_workflow.py	Wed Dec 10 00:24:13 2025 +0000
@@ -5,7 +5,7 @@
 import tempfile
 import zipfile
 from pathlib import Path
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import pandas as pd
 import pandas.api.types as ptypes
@@ -35,6 +35,8 @@
         self.image_extract_dir: Optional[Path] = None
         self.label_metadata: Dict[str, Any] = {}
         self.output_type_hint: Optional[str] = None
+        self.label_split_counts: List[Dict[str, int]] = []
+        self.split_counts: Dict[int, int] = {}
         logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}")
 
     def _create_temp_dirs(self) -> None:
@@ -186,6 +188,34 @@
             logger.error("Error saving prepared CSV", exc_info=True)
             raise
 
+        # Capture actual split counts for downstream reporting (avoids heuristic 70/10/20 tables)
+        try:
+            split_series = pd.to_numeric(df[SPLIT_COLUMN_NAME], errors="coerce")
+            split_series = split_series.dropna().astype(int)
+            self.split_counts = {int(k): int(v) for k, v in split_series.value_counts().to_dict().items()}
+            if LABEL_COLUMN_NAME in df.columns:
+                counts = (
+                    df.dropna(subset=[LABEL_COLUMN_NAME])
+                    .assign(**{SPLIT_COLUMN_NAME: split_series})
+                    .groupby([LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME])
+                    .size()
+                    .unstack(fill_value=0)
+                    .sort_index()
+                )
+                self.label_split_counts = [
+                    {
+                        "label": str(lbl),
+                        "train": int(row.get(0, 0)),
+                        "validation": int(row.get(1, 0)),
+                        "test": int(row.get(2, 0)),
+                    }
+                    for lbl, row in counts.iterrows()
+                ]
+        except Exception:
+            logger.warning("Unable to capture split counts for reporting", exc_info=True)
+            self.label_split_counts = []
+            self.split_counts = {}
+
         self._capture_label_metadata(df)
 
         return final_csv, split_config, split_info
@@ -349,6 +379,8 @@
                 "random_seed": self.args.random_seed,
                 "early_stop": self.args.early_stop,
                 "label_column_data_path": csv_path,
+                "label_split_counts": self.label_split_counts,
+                "split_counts": self.split_counts,
                 "augmentation": self.args.augmentation,
                 "image_resize": self.args.image_resize,
                 "image_zip": self.args.image_zip,
--- a/ludwig_backend.py	Wed Dec 03 01:28:52 2025 +0000
+++ b/ludwig_backend.py	Wed Dec 10 00:24:13 2025 +0000
@@ -1,8 +1,9 @@
+import inspect
 import json
 import logging
 import os
 from pathlib import Path
-from typing import Any, Dict, Optional, Protocol, Tuple
+from typing import Any, Dict, List, Optional, Protocol, Tuple
 
 import pandas as pd
 import pandas.api.types as ptypes
@@ -17,6 +18,7 @@
     build_tabbed_html,
     encode_image_to_base64,
     format_config_table_html,
+    format_dataset_overview_table,
     format_stats_table_html,
     format_test_merged_stats_table_html,
     format_train_val_stats_table_html,
@@ -33,7 +35,9 @@
 from ludwig.utils.data_utils import get_split_path
 from metaformer_setup import get_visualizations_registry, META_DEFAULT_CFGS
 from plotly_plots import (
+    build_binary_threshold_plot,
     build_classification_plots,
+    build_multiclass_metric_plots,
     build_prediction_diagnostics,
     build_regression_test_plots,
     build_regression_train_val_plots,
@@ -267,6 +271,23 @@
         else:
             encoder_config = {"type": raw_encoder}
 
+        # Set a human-friendly architecture string for reporting
+        arch_display = None
+        if is_metaformer and custom_model:
+            arch_display = str(custom_model)
+        elif isinstance(raw_encoder, dict):
+            enc_type = raw_encoder.get("type")
+            enc_variant = raw_encoder.get("model_variant")
+            if enc_type:
+                base = str(enc_type).replace("_", " ").title()
+                arch_display = f"{base} {enc_variant}" if enc_variant is not None else base
+        else:
+            arch_display = str(raw_encoder).replace("_", " ").title()
+
+        if not arch_display:
+            arch_display = str(model_name)
+        config_params["architecture"] = arch_display
+
         batch_size_cfg = batch_size or "auto"
 
         label_column_path = config_params.get("label_column_data_path")
@@ -343,6 +364,7 @@
             # Force Ludwig to respect our dimensions by setting additional parameters
             image_feat["preprocessing"]["requires_equal_dimensions"] = False
             logger.info(f"Set preprocessing dimensions for MetaFormer: {height}x{width} (infer_dimensions=True with max dimensions to allow validation)")
+            config_params["image_size"] = f"{height}x{width}"
         # Now set the encoder configuration
         image_feat["encoder"] = encoder_config
 
@@ -374,8 +396,12 @@
                     image_feat["preprocessing"]["infer_image_max_height"] = height
                     image_feat["preprocessing"]["infer_image_max_width"] = width
                     logger.info(f"Added resize preprocessing: {height}x{width} for standard encoder with infer_image_dimensions=True and max dimensions")
+                    config_params["image_size"] = f"{height}x{width}"
             except (ValueError, IndexError):
                 logger.warning(f"Invalid image resize format: {config_params['image_resize']}, skipping resize preprocessing")
+        elif not is_metaformer:
+            # No explicit resize provided; keep for reporting purposes
+            config_params.setdefault("image_size", "original")
 
         def _resolve_validation_metric(task: str, requested: Optional[str]) -> Optional[str]:
             """Pick a validation metric that Ludwig will accept for the resolved task."""
@@ -471,6 +497,9 @@
                 config_params.get("validation_metric"),
             )
 
+        # Propagate the resolved validation metric (including any task-based fallback or alias normalization)
+        config_params["validation_metric"] = val_metric
+
         conf: Dict[str, Any] = {
             "model_type": "ecd",
             "input_features": [image_feat],
@@ -641,18 +670,62 @@
         except Exception as e:
             logger.error(f"Error converting Parquet to CSV: {e}")
 
-    def generate_plots(self, output_dir: Path) -> None:
-        """Generate all registered Ludwig visualizations for the latest experiment run."""
-        logger.info("Generating all Ludwig visualizations…")
+    @staticmethod
+    def _extract_metric_series(stats: Dict[str, Any], split: str, prefer: Optional[str] = None) -> Tuple[Optional[str], Optional[List[float]]]:
+        """Pull the first numeric metric list we can find for the requested split."""
+        if not isinstance(stats, dict):
+            return None, None
+
+        split_stats = stats.get(split, {})
+        ordered_metrics: List[Tuple[str, List[float]]] = []
+
+        def _append_metrics(metric_map: Dict[str, Any]) -> None:
+            for metric_name, values in metric_map.items():
+                if isinstance(values, list) and any(isinstance(v, (int, float)) for v in values):
+                    ordered_metrics.append((metric_name, values))
 
-        # Keep only lightweight plots (drop compare_performance/roc_curves)
-        test_plots = {
-            "roc_curves_from_test_statistics",
-            "confusion_matrix",
-        }
+        if isinstance(split_stats, dict):
+            combined = split_stats.get("combined")
+            if isinstance(combined, dict):
+                _append_metrics(combined)
+
+            for feature_name, feature_metrics in split_stats.items():
+                if feature_name == "combined" or not isinstance(feature_metrics, dict):
+                    continue
+                _append_metrics(feature_metrics)
+
+        if prefer:
+            for metric_name, values in ordered_metrics:
+                if metric_name == prefer:
+                    return metric_name, values
+
+        return ordered_metrics[0] if ordered_metrics else (None, None)
+
+    def generate_plots(self, output_dir: Path) -> None:
+        """Generate Ludwig visualizations (train/val + test) for the latest experiment run."""
+        logger.info("Generating Ludwig visualizations (train/val + test)…")
+
+        # Train/validation visualizations
         train_plots = {
             "learning_curves",
-            "compare_classifiers_performance_subset",
+        }
+
+        # Test visualizations (multi-class transparency)
+        test_plots = {
+            "confusion_matrix",
+            "compare_performance",
+            "compare_classifiers_multiclass_multimetric",
+            "frequency_vs_f1",
+            "confidence_thresholding",
+            "confidence_thresholding_data_vs_acc",
+            "confidence_thresholding_data_vs_acc_subset",
+            "confidence_thresholding_data_vs_acc_subset_per_class",
+            # Binary-only visualizations will still be attempted; multi-class replacements handled elsewhere
+            "binary_threshold_vs_metric",
+            "roc_curves",
+            "precision_recall_curves",
+            "calibration_1_vs_all",
+            "calibration_multiclass",
         }
 
         output_dir = Path(output_dir)
@@ -677,7 +750,6 @@
 
         training_stats = _check(exp_dir / "training_statistics.json")
         test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME)
-        probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME)
         gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME)
 
         dataset_path = None
@@ -688,6 +760,9 @@
                 cfg = json.load(f)
             dataset_path = _check(Path(cfg.get("dataset", "")))
             split_file = _check(Path(get_split_path(cfg.get("dataset", ""))))
+            model_name = cfg.get("model_name", "model")
+        else:
+            model_name = "model"
 
         output_feature = ""
         if desc.exists():
@@ -700,7 +775,44 @@
                 stats = json.load(f)
             output_feature = next(iter(stats.keys()), "")
 
+        probs_path = None
+        prob_candidates = [
+            exp_dir / f"{LABEL_COLUMN_NAME}_probabilities.csv",
+            exp_dir / f"{output_feature}_probabilities.csv" if output_feature else None,
+            exp_dir / "probabilities.csv",
+            exp_dir / "predictions.csv",
+            exp_dir / PREDICTIONS_PARQUET_FILE_NAME,
+        ]
+        for cand in prob_candidates:
+            if cand and Path(cand).exists():
+                probs_path = str(cand)
+                break
+
         viz_registry = get_visualizations_registry()
+        if not viz_registry:
+            logger.warning(
+                "Ludwig visualizations registry not available; train/test PNGs will be skipped."
+            )
+            return
+
+        base_kwargs = {
+            "training_statistics": [training_stats] if training_stats else [],
+            "test_statistics": [test_stats] if test_stats else [],
+            "probabilities": [probs_path] if probs_path else [],
+            "output_feature_name": output_feature,
+            "ground_truth_split": 2,
+            "top_n_classes": [20],
+            "top_k": 3,
+            "metrics": ["f1", "precision", "recall", "accuracy"],
+            "positive_label": 0,
+            "ground_truth_metadata": gt_metadata,
+            "ground_truth": dataset_path,
+            "split_file": split_file,
+            "output_directory": None,  # set per plot below
+            "normalize": False,
+            "file_format": "png",
+            "model_names": [model_name],
+        }
         for viz_name, viz_func in viz_registry.items():
             if viz_name in train_plots:
                 viz_dir_plot = train_viz
@@ -710,25 +822,22 @@
                 continue
 
             try:
+                # Build per-viz kwargs based on the function signature to avoid unexpected args
+                sig_params = set(inspect.signature(viz_func).parameters.keys())
+                call_kwargs = {
+                    k: v
+                    for k, v in base_kwargs.items()
+                    if k in sig_params and v is not None
+                }
+                if "output_directory" in sig_params:
+                    call_kwargs["output_directory"] = str(viz_dir_plot)
+
                 viz_func(
-                    training_statistics=[training_stats] if training_stats else [],
-                    test_statistics=[test_stats] if test_stats else [],
-                    probabilities=[probs_path] if probs_path else [],
-                    output_feature_name=output_feature,
-                    ground_truth_split=2,
-                    top_n_classes=[0],
-                    top_k=3,
-                    ground_truth_metadata=gt_metadata,
-                    ground_truth=dataset_path,
-                    split_file=split_file,
-                    output_directory=str(viz_dir_plot),
-                    normalize=False,
-                    file_format="png",
+                    **call_kwargs,
                 )
                 logger.info(f"✔ Generated {viz_name}")
             except Exception as e:
                 logger.warning(f"✘ Skipped {viz_name}: {e}")
-
         logger.info(f"All visualizations written to {viz_dir}")
 
     def generate_html_report(
@@ -756,6 +865,7 @@
         label_metadata_path = config.get("label_column_data_path")
         if label_metadata_path:
             label_metadata_path = Path(label_metadata_path)
+        dataset_path_from_desc: Optional[Path] = None
 
         # Pull additional config details from description.json if available
         config_for_summary = dict(config)
@@ -765,7 +875,8 @@
         if desc_path.exists():
             try:
                 with open(desc_path, "r") as f:
-                    desc_cfg = json.load(f).get("config", {})
+                    desc_json = json.load(f)
+                desc_cfg = desc_json.get("config", {}) if isinstance(desc_json, dict) else {}
                 encoder_cfg = (
                     desc_cfg.get("input_features", [{}])[0].get("encoder", {})
                     if isinstance(desc_cfg.get("input_features", [{}]), list)
@@ -783,10 +894,20 @@
 
                 arch_type = encoder_cfg.get("type")
                 arch_variant = encoder_cfg.get("model_variant")
+                arch_custom = encoder_cfg.get("custom_model")
                 arch_name = None
+                if arch_custom:
+                    arch_name = str(arch_custom)
                 if arch_type:
                     arch_base = str(arch_type).replace("_", " ").title()
-                    arch_name = f"{arch_base} {arch_variant}" if arch_variant is not None else arch_base
+                    arch_type_name = (
+                        f"{arch_base} {arch_variant}" if arch_variant is not None else arch_base
+                    )
+                    # Prefer explicit custom model names (e.g., MetaFormer) but fall back to encoder type
+                    arch_name = arch_name or arch_type_name
+                if not arch_name and config.get("model_name"):
+                    # As a last resort, show the user-selected model name (handles custom/MetaFormer cases)
+                    arch_name = str(config.get("model_name"))
 
                 summary_fields = {
                     "architecture": arch_name,
@@ -814,12 +935,22 @@
                     if k in {"target_column", "image_column"} and config_for_summary.get(k):
                         continue
                     config_for_summary.setdefault(k, v)
+
+                dataset_field = None
+                if isinstance(desc_json, dict):
+                    dataset_field = desc_json.get("dataset") or desc_cfg.get("dataset")
+                if dataset_field:
+                    try:
+                        dataset_path_from_desc = Path(dataset_field)
+                    except TypeError:
+                        dataset_path_from_desc = None
+                if dataset_path_from_desc and (not label_metadata_path or not label_metadata_path.exists()):
+                    label_metadata_path = dataset_path_from_desc
             except Exception as e:  # pragma: no cover - defensive
                 logger.warning(f"Could not merge description.json into config summary: {e}")
 
         base_viz_dir = exp_dir / "visualizations"
         train_viz_dir = base_viz_dir / "train"
-        test_viz_dir = base_viz_dir / "test"
 
         html = get_html_template()
 
@@ -880,10 +1011,164 @@
       });
     }
   });
-</script>
+        </script>
 """
         html += f"<h1>{title}</h1>"
 
+        def append_plot_blocks(tab_html: str, plots: List[Dict[str, str]], title_suffix: str = "") -> str:
+            """Append Plotly blocks to a tab with consistent markup."""
+            if not plots:
+                return tab_html
+            suffix = title_suffix or ""
+            for plot in plots:
+                tab_html += (
+                    f"<h2 style='text-align: center;'>{plot['title']}{suffix}</h2>"
+                    f"<div class='plotly-center'>{plot['html']}</div>"
+                )
+            return tab_html
+
+        def build_dataset_overview(
+            label_metadata: Optional[Path],
+            output_type: Optional[str],
+            split_probabilities: Optional[List[float]],
+            label_split_counts: Optional[List[Dict[str, int]]] = None,
+            split_counts: Optional[Dict[int, int]] = None,
+            fallback_dataset: Optional[Path] = None,
+        ) -> str:
+            """Summarize dataset distribution across splits using the actual split config."""
+            if label_split_counts:
+                # Use the actual counts captured during data prep instead of heuristics.
+                return format_dataset_overview_table(label_split_counts, regression_mode=False)
+
+            if output_type == "regression" and split_counts:
+                rows = [
+                    {"split": "train", "count": int(split_counts.get(0, 0))},
+                    {"split": "validation", "count": int(split_counts.get(1, 0))},
+                    {"split": "test", "count": int(split_counts.get(2, 0))},
+                ]
+                return format_dataset_overview_table(rows, regression_mode=True)
+
+            candidate_paths: List[Path] = []
+            if label_metadata and label_metadata.exists():
+                candidate_paths.append(label_metadata)
+            if fallback_dataset and fallback_dataset.exists():
+                candidate_paths.append(fallback_dataset)
+            if not candidate_paths:
+                return format_dataset_overview_table([])
+
+            def _normalize_split_probabilities(
+                probs: Optional[List[float]],
+            ) -> Optional[List[float]]:
+                if not probs or len(probs) != 3:
+                    return None
+                try:
+                    probs = [float(p) for p in probs]
+                except (TypeError, ValueError):
+                    return None
+                total = sum(probs)
+                if total <= 0:
+                    return None
+                return [p / total for p in probs]
+
+            def _split_counts_from_column(df: pd.DataFrame) -> Dict[int, int]:
+                if SPLIT_COLUMN_NAME not in df.columns:
+                    return {}
+                split_series = pd.to_numeric(
+                    df[SPLIT_COLUMN_NAME], errors="coerce"
+                ).dropna()
+                if split_series.empty:
+                    return {}
+                split_series = split_series.astype(int)
+                return split_series.value_counts().to_dict()
+
+            def _split_counts_from_probs(total: int, probs: List[float]) -> Dict[int, int]:
+                train_n = int(total * probs[0])
+                val_n = int(total * probs[1])
+                test_n = max(0, total - train_n - val_n)
+                return {0: train_n, 1: val_n, 2: test_n}
+
+            fallback_rows: Optional[List[Dict[str, int]]] = None
+            for meta_path in candidate_paths:
+                try:
+                    df_labels = pd.read_csv(meta_path)
+                    probs = _normalize_split_probabilities(split_probabilities)
+
+                    # Regression (or missing label column): only need split counts
+                    if output_type == "regression" or LABEL_COLUMN_NAME not in df_labels.columns:
+                        split_counts_found = _split_counts_from_column(df_labels)
+                        if split_counts_found:
+                            rows = [
+                                {"split": "train", "count": int(split_counts_found.get(0, 0))},
+                                {"split": "validation", "count": int(split_counts_found.get(1, 0))},
+                                {"split": "test", "count": int(split_counts_found.get(2, 0))},
+                            ]
+                            return format_dataset_overview_table(rows, regression_mode=True)
+                        if probs and fallback_rows is None:
+                            split_counts_found = _split_counts_from_probs(len(df_labels), probs)
+                            fallback_rows = [
+                                {"split": "train", "count": int(split_counts_found.get(0, 0))},
+                                {"split": "validation", "count": int(split_counts_found.get(1, 0))},
+                                {"split": "test", "count": int(split_counts_found.get(2, 0))},
+                            ]
+                        continue
+
+                    # Classification: prefer actual split assignments; fall back to configured probabilities
+                    if SPLIT_COLUMN_NAME in df_labels.columns:
+                        df_counts = df_labels[[LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME]].copy()
+                        df_counts[SPLIT_COLUMN_NAME] = pd.to_numeric(
+                            df_counts[SPLIT_COLUMN_NAME], errors="coerce"
+                        )
+                        df_counts = df_counts.dropna(subset=[SPLIT_COLUMN_NAME])
+                        if df_counts.empty:
+                            continue
+
+                        df_counts[SPLIT_COLUMN_NAME] = df_counts[SPLIT_COLUMN_NAME].astype(int)
+                        df_counts = df_counts.dropna(subset=[LABEL_COLUMN_NAME])
+                        counts = (
+                            df_counts.groupby([LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME])
+                            .size()
+                            .unstack(fill_value=0)
+                            .sort_index()
+                        )
+                        rows = []
+                        for lbl, row in counts.iterrows():
+                            rows.append(
+                                {
+                                    "label": str(lbl),
+                                    "train": int(row.get(0, 0)),
+                                    "validation": int(row.get(1, 0)),
+                                    "test": int(row.get(2, 0)),
+                                }
+                            )
+                        return format_dataset_overview_table(rows)
+
+                    if probs:
+                        label_series = df_labels[LABEL_COLUMN_NAME].dropna()
+                        label_counts = label_series.value_counts().sort_index()
+                        if label_counts.empty:
+                            continue
+                        rows = []
+                        for lbl, count in label_counts.items():
+                            train_n = int(count * probs[0])
+                            val_n = int(count * probs[1])
+                            test_n = max(0, count - train_n - val_n)
+                            rows.append(
+                                {
+                                    "label": str(lbl),
+                                    "train": train_n,
+                                    "validation": val_n,
+                                    "test": test_n,
+                                }
+                            )
+                        fallback_rows = fallback_rows or rows
+                except Exception as exc:
+                    logger.warning("Failed to build dataset overview from %s: %s", meta_path, exc)
+                    continue
+
+            if fallback_rows:
+                return format_dataset_overview_table(fallback_rows, regression_mode=output_type == "regression")
+            return format_dataset_overview_table([])
+
         metrics_html = ""
         train_val_metrics_html = ""
         test_metrics_html = ""
@@ -911,6 +1196,23 @@
                 f"Could not load stats for HTML report: {type(e).__name__}: {e}"
             )
 
+        if not output_type:
+            # Fallback to configured task type when stats are unavailable (e.g., failed run).
+            output_type = (
+                str(config_for_summary.get("task_type")).lower()
+                if config_for_summary.get("task_type")
+                else None
+            )
+
+        dataset_overview_html = build_dataset_overview(
+            label_metadata_path,
+            output_type,
+            config.get("split_probabilities"),
+            config.get("label_split_counts"),
+            config.get("split_counts"),
+            dataset_path_from_desc,
+        )
+
         config_html = ""
         training_progress = self.get_training_process(output_dir)
         try:
@@ -937,11 +1239,12 @@
             exclude_names: Optional[set] = None,
         ) -> str:
             if not dir_path.exists():
-                return f"<h2>{title}</h2><p><em>Directory not found.</em></p>"
+                return ""
 
             exclude_names = exclude_names or set()
 
-            imgs = list(dir_path.glob("*.png"))
+            # Search recursively because Ludwig can nest figures under per-feature folders
+            imgs = list(dir_path.rglob("*.png"))
 
             # Exclude ROC curves and standard confusion matrices (keep only entropy version)
             default_exclude = {
@@ -983,7 +1286,7 @@
             ]
 
             if not imgs:
-                return f"<h2>{title}</h2><p><em>No plots found.</em></p>"
+                return ""
 
             # Sort images by name for consistent ordering (works with string and numeric labels)
             imgs = sorted(imgs, key=lambda x: x.name)
@@ -1006,36 +1309,86 @@
                 )
             return html_section
 
-        # Show performance first, then config
-        tab1_content = metrics_html + config_html
+        # Show dataset overview, performance first, then config
+        predictions_csv_path = exp_dir / "predictions.csv"
+
+        tab1_content = dataset_overview_html + metrics_html + config_html
 
-        tab2_content = train_val_metrics_html + render_img_section(
-            "Training and Validation Visualizations",
-            train_viz_dir,
-            output_type,
-            exclude_names={
-                "compare_classifiers_performance_from_prob.png",
-                "roc_curves_from_prediction_statistics.png",
-                "precision_recall_curves_from_prediction_statistics.png",
-                "precision_recall_curve.png",
-            },
+        tab2_content = train_val_metrics_html
+        # Preload binary threshold plot so it appears first in Train/Val tab
+        threshold_plot = None
+        threshold_value = (
+            config_for_summary.get("threshold")
+            if config_for_summary.get("threshold") is not None
+            else config.get("threshold")
         )
+        if threshold_value is None and output_type == "binary":
+            threshold_value = 0.5
+        if output_type == "binary" and predictions_csv_path.exists():
+            try:
+                threshold_plot = build_binary_threshold_plot(
+                    str(predictions_csv_path),
+                    label_data_path=str(config.get("label_column_data_path"))
+                    if config.get("label_column_data_path")
+                    else None,
+                    split_value=1,
+                )
+            except Exception as e:
+                logger.warning(f"Could not generate validation threshold plot: {e}")
+
         if train_stats_path.exists():
             try:
                 if output_type == "regression":
                     tv_plots = build_regression_train_val_plots(str(train_stats_path))
+                    tab2_content = append_plot_blocks(tab2_content, tv_plots)
                 else:
                     tv_plots = build_train_validation_plots(str(train_stats_path))
-                for plot in tv_plots:
-                    tab2_content += (
-                        f"<h2 style='text-align: center;'>{plot['title']}</h2>"
-                        f"<div class='plotly-center'>{plot['html']}</div>"
-                    )
-                if tv_plots:
-                    logger.info(f"Generated {len(tv_plots)} train/val diagnostic plots")
+                    # Add threshold plot first, then other train/val plots
+                    if threshold_plot:
+                        tab2_content = append_plot_blocks(tab2_content, [threshold_plot])
+                        # Only append once; avoid duplicates if added elsewhere
+                        threshold_plot = None
+                    tab2_content = append_plot_blocks(tab2_content, tv_plots)
+                    if threshold_plot or tv_plots:
+                        logger.info(
+                            f"Added {len(tv_plots) + (1 if threshold_plot else 0)} train/val diagnostic plots"
+                        )
             except Exception as e:
                 logger.warning(f"Could not generate train/val plots: {e}")
 
+        # Only include training PNGs for regression; classification is handled by filtered Plotly plots
+        if output_type == "regression":
+            tab2_content += render_img_section(
+                "Training and Validation Visualizations",
+                train_viz_dir,
+                output_type,
+                exclude_names={
+                    "compare_classifiers_performance_from_prob.png",
+                    "roc_curves_from_prediction_statistics.png",
+                    "precision_recall_curves_from_prediction_statistics.png",
+                    "precision_recall_curve.png",
+                },
+            )
+
+        # Validation diagnostics (calibration/threshold) from predictions.csv, using split=1
+        if output_type in ("binary", "category") and predictions_csv_path.exists():
+            try:
+                val_diag_plots = build_prediction_diagnostics(
+                    str(predictions_csv_path),
+                    label_data_path=str(config.get("label_column_data_path"))
+                    if config.get("label_column_data_path")
+                    else None,
+                    split_value=1,
+                )
+                val_conf_plots = [
+                    p for p in val_diag_plots if "Prediction Confidence Distribution" in p.get("title", "")
+                ]
+                tab2_content = append_plot_blocks(
+                    tab2_content, val_conf_plots, " (Validation)"
+                )
+            except Exception as e:
+                logger.warning(f"Could not generate validation diagnostics: {e}")
+
         # --- Predictions vs Ground Truth table (REGRESSION ONLY) ---
         preds_section = ""
         parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
@@ -1077,18 +1430,12 @@
                 logger.warning(f"Could not build Predictions vs GT table: {e}")
 
         tab3_content = test_metrics_html + preds_section
-        test_plotly_added = False
 
         if output_type == "regression" and train_stats_path.exists():
             try:
                 test_plots = build_regression_test_plots(str(train_stats_path))
-                for plot in test_plots:
-                    tab3_content += (
-                        f"<h2 style='text-align: center;'>{plot['title']}</h2>"
-                        f"<div class='plotly-center'>{plot['html']}</div>"
-                    )
+                tab3_content = append_plot_blocks(tab3_content, test_plots)
                 if test_plots:
-                    test_plotly_added = True
                     logger.info(f"Generated {len(test_plots)} regression test plots")
             except Exception as e:
                 logger.warning(f"Could not generate regression test plots: {e}")
@@ -1104,46 +1451,42 @@
                     train_set_metadata_path=str(train_set_metadata_path)
                     if train_set_metadata_path.exists()
                     else None,
+                    threshold=threshold_value,
                 )
-                for plot in interactive_plots:
-                    tab3_content += (
-                        f"<h2 style='text-align: center;'>{plot['title']}</h2>"
-                        f"<div class='plotly-center'>{plot['html']}</div>"
-                    )
+                tab3_content = append_plot_blocks(tab3_content, interactive_plots)
                 if interactive_plots:
-                    test_plotly_added = True
-                logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots")
+                    logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots")
             except Exception as e:
                 logger.warning(f"Could not generate Plotly plots: {e}")
 
-            # Add prediction diagnostics from predictions.csv
-            predictions_csv_path = exp_dir / "predictions.csv"
-            try:
-                diag_plots = build_prediction_diagnostics(
-                    str(predictions_csv_path),
-                    label_data_path=str(config.get("label_column_data_path"))
-                    if config.get("label_column_data_path")
-                    else None,
-                    threshold=config.get("threshold"),
-                )
-                for plot in diag_plots:
-                    tab3_content += (
-                        f"<h2 style='text-align: center;'>{plot['title']}</h2>"
-                        f"<div class='plotly-center'>{plot['html']}</div>"
+            # Multi-class transparency plots from test stats (replace ROC/PR for multi-class)
+            if output_type == "category" and test_stats_path.exists():
+                try:
+                    multi_curves = build_multiclass_metric_plots(str(test_stats_path))
+                    tab3_content = append_plot_blocks(tab3_content, multi_curves)
+                    if multi_curves:
+                        logger.info("Added multi-class per-class metric plots to test tab")
+                except Exception as e:
+                    logger.warning(f"Could not generate multi-class metric plots: {e}")
+
+            # Test diagnostics (confidence histogram) from predictions.csv, using split=2
+            if predictions_csv_path.exists():
+                try:
+                    test_diag_plots = build_prediction_diagnostics(
+                        str(predictions_csv_path),
+                        label_data_path=str(config.get("label_column_data_path"))
+                        if config.get("label_column_data_path")
+                        else None,
+                        split_value=2,
                     )
-                if diag_plots:
-                    test_plotly_added = True
-                    logger.info(f"Generated {len(diag_plots)} prediction diagnostic plots")
-            except Exception as e:
-                logger.warning(f"Could not generate prediction diagnostics: {e}")
-
-        # Fallback: include static PNGs if no interactive plots were added
-        if not test_plotly_added:
-            tab3_content += render_img_section(
-                "Test Visualizations (PNG fallback)",
-                test_viz_dir,
-                output_type,
-            )
+                    test_conf_plots = [
+                        p for p in test_diag_plots if "Prediction Confidence Distribution" in p.get("title", "")
+                    ]
+                    if test_conf_plots:
+                        tab3_content = append_plot_blocks(tab3_content, test_conf_plots)
+                        logger.info("Added test prediction confidence plot")
+                except Exception as e:
+                    logger.warning(f"Could not generate test diagnostics: {e}")
 
         # Add static TEST PNGs (with default dedupe/exclusions)
         tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)
--- a/plotly_plots.py	Wed Dec 03 01:28:52 2025 +0000
+++ b/plotly_plots.py	Wed Dec 10 00:24:13 2025 +0000
@@ -7,6 +7,17 @@
 import plotly.graph_objects as go
 import plotly.io as pio
 from constants import LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME
+from sklearn.metrics import (
+    accuracy_score,
+    auc,
+    average_precision_score,
+    f1_score,
+    precision_recall_curve,
+    precision_score,
+    recall_score,
+    roc_curve,
+)
+from sklearn.preprocessing import label_binarize
 
 
 def _style_fig(fig: go.Figure, font_size: int = 12) -> go.Figure:
@@ -21,6 +32,64 @@
     return fig
 
 
+def _fig_to_html(
+    fig: go.Figure, *, include_js: bool = False, config: Optional[dict] = None
+) -> str:
+    """Render a Plotly figure to a lightweight HTML fragment."""
+    include_plotlyjs = "cdn" if include_js else False
+    return pio.to_html(
+        fig,
+        full_html=False,
+        include_plotlyjs=include_plotlyjs,
+        config=config,
+    )
+
+
+def _wrap_plot(
+    title: str,
+    fig: go.Figure,
+    *,
+    include_js: bool = False,
+    config: Optional[dict] = None,
+) -> Dict[str, str]:
+    """Package a figure with its title for downstream HTML rendering."""
+    return {"title": title, "html": _fig_to_html(fig, include_js=include_js, config=config)}
+
+
+def _line_chart(
+    traces: List[tuple],
+    *,
+    title: str,
+    yaxis_title: str,
+) -> go.Figure:
+    """Build a basic epoch-indexed line chart for train/val/test curves."""
+    fig = go.Figure()
+    for name, series in traces:
+        if not series:
+            continue
+        epochs = list(range(1, len(series) + 1))
+        fig.add_trace(
+            go.Scatter(
+                x=epochs,
+                y=series,
+                mode="lines+markers",
+                name=name,
+                line=dict(width=4),
+            )
+        )
+
+    fig.update_layout(
+        title=dict(text=title, x=0.5),
+        xaxis_title="Epoch",
+        yaxis_title=yaxis_title,
+        width=760,
+        height=520,
+        hovermode="x unified",
+    )
+    _style_fig(fig)
+    return fig
+
+
 def _labels_from_metadata_dict(meta_dict: dict) -> List[str]:
     """Extract ordered label names from Ludwig train_set_metadata."""
     if not isinstance(meta_dict, dict):
@@ -106,6 +175,7 @@
     training_stats_path: Optional[str] = None,
     metadata_csv_path: Optional[str] = None,
     train_set_metadata_path: Optional[str] = None,
+    threshold: Optional[float] = None,
 ) -> List[Dict[str, str]]:
     """
     Read Ludwig’s test_statistics.json and build three interactive Plotly panels:
@@ -156,8 +226,11 @@
         )
     )
     fig_cm.update_traces(xgap=2, ygap=2)
+    cm_title = "Confusion Matrix"
+    if threshold is not None:
+        cm_title = f"Confusion Matrix (Threshold: {threshold})"
     fig_cm.update_layout(
-        title=dict(text="Confusion Matrix", x=0.5),
+        title=dict(text=cm_title, x=0.5),
         xaxis_title="Predicted",
         yaxis_title="Observed",
         yaxis_autorange="reversed",
@@ -196,25 +269,19 @@
                 yshift=-2,
             )
 
-    plots.append({
-        "title": "Confusion Matrix",
-        "html": pio.to_html(
-            fig_cm,
-            full_html=False,
-            include_plotlyjs="cdn",
-            config=common_cfg
-        )
-    })
+    plots.append(
+        _wrap_plot("Confusion Matrix", fig_cm, include_js=True, config=common_cfg)
+    )
 
-    # 1) ROC Curve (from test_statistics)
-    roc_plot = _build_static_roc_plot(label_stats, common_cfg, friendly_labels=labels)
-    if roc_plot:
-        plots.append(roc_plot)
+    # 1) ROC / PR curves only for binary tasks
+    if n_classes == 2:
+        roc_plot = _build_static_roc_plot(label_stats, common_cfg, friendly_labels=labels)
+        if roc_plot:
+            plots.append(roc_plot)
 
-    # 2) Precision-Recall Curve (from test_statistics)
-    pr_plot = _build_precision_recall_plot(label_stats, common_cfg)
-    if pr_plot:
-        plots.append(pr_plot)
+        pr_plot = _build_precision_recall_plot(label_stats, common_cfg)
+        if pr_plot:
+            plots.append(pr_plot)
 
     # 2) Classification Report Heatmap
     pcs = label_stats.get("per_class_stats", {})
@@ -259,15 +326,9 @@
             margin=dict(t=80, l=80, r=80, b=80),
         )
         _style_fig(fig_cr)
-        plots.append({
-            "title": "Per-Class metrics",
-            "html": pio.to_html(
-                fig_cr,
-                full_html=False,
-                include_plotlyjs=False,
-                config=common_cfg
-            )
-        })
+        plots.append(
+            _wrap_plot("Per-Class metrics", fig_cr, config=common_cfg)
+        )
 
     # 3) Prediction Diagnostics (from predictions.csv)
     # Note: appended separately in generate_html_report, not returned here.
@@ -294,8 +355,6 @@
     include_js = True  # Load Plotly.js once for this group
 
     def _get_series(stats: dict, metric: str) -> List[float]:
-        if metric not in stats:
-            return []
         vals = stats.get(metric, [])
         if isinstance(vals, list):
             return [float(v) for v in vals]
@@ -304,181 +363,98 @@
         except Exception:
             return []
 
-    def _line_plot(metric_key: str, title: str, yaxis_title: str) -> Optional[Dict[str, str]]:
-        train_series = _get_series(label_train, metric_key)
-        val_series = _get_series(label_val, metric_key)
+    metric_specs = [
+        ("loss", "Loss across epochs", "Loss"),
+        ("accuracy", "Accuracy across epochs", "Accuracy"),
+        ("roc_auc", "ROC-AUC across epochs", "ROC-AUC"),
+        ("precision", "Precision across epochs", "Precision"),
+        ("recall", "Recall/Sensitivity across epochs", "Recall"),
+        ("specificity", "Specificity across epochs", "Specificity"),
+    ]
+
+    for key, title, yaxis in metric_specs:
+        train_series = _get_series(label_train, key)
+        val_series = _get_series(label_val, key)
         if not train_series and not val_series:
-            return None
-        epochs_train = list(range(1, len(train_series) + 1))
-        epochs_val = list(range(1, len(val_series) + 1))
-        fig = go.Figure()
-        if train_series:
-            fig.add_trace(
-                go.Scatter(
-                    x=epochs_train,
-                    y=train_series,
-                    mode="lines+markers",
-                    name="Train",
-                    line=dict(width=4),
-                )
-            )
-        if val_series:
-            fig.add_trace(
-                go.Scatter(
-                    x=epochs_val,
-                    y=val_series,
-                    mode="lines+markers",
-                    name="Validation",
-                    line=dict(width=4),
-                )
-            )
-        fig.update_layout(
-            title=dict(text=title, x=0.5),
-            xaxis_title="Epoch",
-            yaxis_title=yaxis_title,
-            width=760,
-            height=520,
-            hovermode="x unified",
+            continue
+        fig = _line_chart(
+            [("Train", train_series), ("Validation", val_series)],
+            title=title,
+            yaxis_title=yaxis,
         )
-        _style_fig(fig)
-        return {
-            "title": title,
-            "html": pio.to_html(
-                fig,
-                full_html=False,
-                include_plotlyjs="cdn" if include_js else False,
-            ),
-        }
-
-    # Core learning curves
-    for key, title in [
-        ("roc_auc", "ROC-AUC across epochs"),
-        ("precision", "Precision across epochs"),
-        ("recall", "Recall/Sensitivity across epochs"),
-        ("specificity", "Specificity across epochs"),
-    ]:
-        plot = _line_plot(key, title, title.replace("Learning Curve", "").strip())
-        if plot:
-            plots.append(plot)
-            include_js = False
+        plots.append(_wrap_plot(title, fig, include_js=include_js))
+        include_js = False
 
     # Precision vs Recall evolution (validation)
     val_prec = _get_series(label_val, "precision")
     val_rec = _get_series(label_val, "recall")
     if val_prec and val_rec:
-        epochs = list(range(1, min(len(val_prec), len(val_rec)) + 1))
-        fig_pr = go.Figure()
-        fig_pr.add_trace(
-            go.Scatter(
-                x=epochs,
-                y=val_prec[: len(epochs)],
-                mode="lines+markers",
-                name="Precision",
-            )
-        )
-        fig_pr.add_trace(
-            go.Scatter(
-                x=epochs,
-                y=val_rec[: len(epochs)],
-                mode="lines+markers",
-                name="Recall",
-            )
+        max_len = min(len(val_prec), len(val_rec))
+        fig_pr = _line_chart(
+            [
+                ("Precision", val_prec[:max_len]),
+                ("Recall", val_rec[:max_len]),
+            ],
+            title="Validation Precision and Recall by Epoch",
+            yaxis_title="Value",
         )
-        fig_pr.update_layout(
-            title=dict(text="Validation Precision and Recall by Epoch", x=0.5),
-            xaxis_title="Epoch",
-            yaxis_title="Value",
-            width=760,
-            height=520,
-            hovermode="x unified",
-        )
-        _style_fig(fig_pr)
-        plots.append({
-            "title": "Precision vs Recall Evolution",
-            "html": pio.to_html(
-                fig_pr,
-                full_html=False,
-                include_plotlyjs="cdn" if include_js else False,
-            ),
-        })
+        plots.append(_wrap_plot("Precision vs Recall Evolution", fig_pr, include_js=include_js))
         include_js = False
 
-    # F1-score derived
     def _compute_f1(p: List[float], r: List[float]) -> List[float]:
-        f1_vals = []
-        for prec, rec in zip(p, r):
-            if (prec + rec) == 0:
-                f1_vals.append(0.0)
-            else:
-                f1_vals.append(2 * prec * rec / (prec + rec))
-        return f1_vals
+        return [
+            0.0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec)
+            for prec, rec in zip(p, r)
+        ]
 
     f1_train = _compute_f1(_get_series(label_train, "precision"), _get_series(label_train, "recall"))
     f1_val = _compute_f1(val_prec, val_rec)
     if f1_train or f1_val:
-        fig = go.Figure()
-        if f1_train:
-            fig.add_trace(go.Scatter(x=list(range(1, len(f1_train) + 1)), y=f1_train, mode="lines+markers", name="Train", line=dict(width=4)))
-        if f1_val:
-            fig.add_trace(go.Scatter(x=list(range(1, len(f1_val) + 1)), y=f1_val, mode="lines+markers", name="Validation", line=dict(width=4)))
-        fig.update_layout(
-            title=dict(text="F1-Score across epochs (derived)", x=0.5),
-            xaxis_title="Epoch",
+        fig_f1 = _line_chart(
+            [("Train", f1_train), ("Validation", f1_val)],
+            title="F1-Score across epochs (derived)",
             yaxis_title="F1-Score",
-            width=760,
-            height=520,
-            hovermode="x unified",
         )
-        _style_fig(fig)
-        plots.append({
-            "title": "F1-Score across epochs (derived)",
-            "html": pio.to_html(
-                fig,
-                full_html=False,
-                include_plotlyjs="cdn" if include_js else False,
-            ),
-        })
+        plots.append(_wrap_plot("F1-Score across epochs (derived)", fig_f1, include_js=include_js))
         include_js = False
 
     # Overfitting Gap: Train vs Val ROC-AUC (gap)
     roc_train = _get_series(label_train, "roc_auc")
     roc_val = _get_series(label_val, "roc_auc")
     if roc_train and roc_val:
-        epochs_gap = list(range(1, min(len(roc_train), len(roc_val)) + 1))
-        gaps = [t - v for t, v in zip(roc_train[:len(epochs_gap)], roc_val[:len(epochs_gap)])]
-        fig_gap = go.Figure()
-        fig_gap.add_trace(go.Scatter(x=epochs_gap, y=gaps, mode="lines+markers", name="Train - Val ROC-AUC", line=dict(width=4)))
-        fig_gap.update_layout(
-            title=dict(text="Overfitting gap: ROC-AUC across epochs", x=0.5),
-            xaxis_title="Epoch",
+        max_len = min(len(roc_train), len(roc_val))
+        gaps = [t - v for t, v in zip(roc_train[:max_len], roc_val[:max_len])]
+        fig_gap = _line_chart(
+            [("Train - Val ROC-AUC", gaps)],
+            title="Overfitting gap: ROC-AUC across epochs",
             yaxis_title="Gap",
-            width=760,
-            height=520,
-            hovermode="x unified",
         )
-        _style_fig(fig_gap)
-        plots.append({
-            "title": "Overfitting gap: ROC-AUC across epochs",
-            "html": pio.to_html(
-                fig_gap,
-                full_html=False,
-                include_plotlyjs="cdn" if include_js else False,
-            ),
-        })
+        plots.append(_wrap_plot("Overfitting gap: ROC-AUC across epochs", fig_gap, include_js=include_js))
         include_js = False
 
     # Best Epoch Dashboard (based on max val ROC-AUC)
     if roc_val:
         best_idx = int(np.argmax(roc_val))
         best_epoch = best_idx + 1
-        spec_val = _get_series(label_val, "specificity")
-        metrics_at_best = {
-            "ROC-AUC": roc_val[best_idx] if best_idx < len(roc_val) else None,
-            "Precision": val_prec[best_idx] if best_idx < len(val_prec) else None,
-            "Recall": val_rec[best_idx] if best_idx < len(val_rec) else None,
-            "Specificity": spec_val[best_idx] if best_idx < len(spec_val) else None,
-            "F1-Score": f1_val[best_idx] if best_idx < len(f1_val) else None,
+        metrics_at_best: Dict[str, Optional[float]] = {
+            "ROC-AUC": roc_val[best_idx] if best_idx < len(roc_val) else None
         }
+
+        for metric_key, label in [
+            ("accuracy", "Accuracy"),
+            ("balanced_accuracy", "Balanced Accuracy"),
+            ("precision", "Precision"),
+            ("recall", "Recall"),
+            ("specificity", "Specificity"),
+            ("loss", "Loss"),
+        ]:
+            series = _get_series(label_val, metric_key)
+            if series and best_idx < len(series):
+                metrics_at_best[label] = series[best_idx]
+
+        if f1_val and best_idx < len(f1_val):
+            metrics_at_best["F1-Score (derived)"] = f1_val[best_idx]
+
         fig_best = go.Figure()
         for name, value in metrics_at_best.items():
             if value is not None:
@@ -492,15 +468,7 @@
             showlegend=False,
         )
         _style_fig(fig_best)
-        plots.append({
-            "title": "Best Validation Epoch Snapshot (Metrics)",
-            "html": pio.to_html(
-                fig_best,
-                full_html=False,
-                include_plotlyjs="cdn" if include_js else False,
-            ),
-        })
-        include_js = False
+        plots.append(_wrap_plot("Best Validation Epoch Snapshot (Metrics)", fig_best, include_js=include_js))
 
     return plots
 
@@ -529,46 +497,13 @@
     val_series = _get_regression_series(val_split, metric_key)
     if not train_series and not val_series:
         return None
-    epochs_train = list(range(1, len(train_series) + 1))
-    epochs_val = list(range(1, len(val_series) + 1))
-    fig = go.Figure()
-    if train_series:
-        fig.add_trace(
-            go.Scatter(
-                x=epochs_train,
-                y=train_series,
-                mode="lines+markers",
-                name="Train",
-                line=dict(width=4),
-            )
-        )
-    if val_series:
-        fig.add_trace(
-            go.Scatter(
-                x=epochs_val,
-                y=val_series,
-                mode="lines+markers",
-                name="Validation",
-                line=dict(width=4),
-            )
-        )
-    fig.update_layout(
-        title=dict(text=title, x=0.5),
-        xaxis_title="Epoch",
+
+    fig = _line_chart(
+        [("Train", train_series), ("Validation", val_series)],
+        title=title,
         yaxis_title=yaxis_title,
-        width=760,
-        height=520,
-        hovermode="x unified",
     )
-    _style_fig(fig)
-    return {
-        "title": title,
-        "html": pio.to_html(
-            fig,
-            full_html=False,
-            include_plotlyjs="cdn" if include_js else False,
-        ),
-    }
+    return _wrap_plot(title, fig, include_js=include_js)
 
 
 def build_regression_train_val_plots(train_stats_path: str) -> List[Dict[str, str]]:
@@ -627,46 +562,25 @@
         ("r2", "R² Across Epochs", "R²"),
         ("loss", "Loss Across Epochs", "Loss"),
     ]
-    epochs = None
     for metric_key, title, ytitle in metrics:
         series = _get_regression_series(label_test, metric_key)
         if not series:
             continue
-        if epochs is None:
-            epochs = list(range(1, len(series) + 1))
-        fig = go.Figure()
-        fig.add_trace(
-            go.Scatter(
-                x=epochs,
-                y=series[: len(epochs)],
-                mode="lines+markers",
-                name="Test",
-                line=dict(width=4),
-            )
+        fig = _line_chart(
+            [("Test", series)],
+            title=title,
+            yaxis_title=ytitle,
         )
-        fig.update_layout(
-            title=dict(text=title, x=0.5),
-            xaxis_title="Epoch",
-            yaxis_title=ytitle,
-            width=760,
-            height=520,
-            hovermode="x unified",
-        )
-        _style_fig(fig)
-        plots.append({
-            "title": title,
-            "html": pio.to_html(
-                fig,
-                full_html=False,
-                include_plotlyjs="cdn" if include_js else False,
-            ),
-        })
+        plots.append(_wrap_plot(title, fig, include_js=include_js))
         include_js = False
     return plots
 
 
 def _build_static_roc_plot(
-    label_stats: dict, config: dict, friendly_labels: Optional[List[str]] = None
+    label_stats: dict,
+    config: dict,
+    friendly_labels: Optional[List[str]] = None,
+    threshold: Optional[float] = None,
 ) -> Optional[Dict[str, str]]:
     """Build ROC curve directly from test_statistics.json (single curve)."""
     roc_data = label_stats.get("roc_curve")
@@ -776,6 +690,42 @@
         fig.update_xaxes(range=[0, 1.0])
         fig.update_yaxes(range=[0, 1.05])
 
+        roc_thresholds = roc_data.get("thresholds")
+        if threshold is not None and isinstance(roc_thresholds, list) and len(roc_thresholds) == len(fpr):
+            try:
+                diffs = [abs(th - threshold) for th in roc_thresholds]
+                best_idx = int(np.argmin(diffs))
+                # dashed guides through the chosen point
+                fig.add_shape(
+                    type="line",
+                    x0=fpr[best_idx],
+                    x1=fpr[best_idx],
+                    y0=0,
+                    y1=tpr[best_idx],
+                    line=dict(color="gray", width=2, dash="dash"),
+                )
+                fig.add_shape(
+                    type="line",
+                    x0=0,
+                    x1=fpr[best_idx],
+                    y0=tpr[best_idx],
+                    y1=tpr[best_idx],
+                    line=dict(color="gray", width=2, dash="dash"),
+                )
+                fig.add_trace(
+                    go.Scatter(
+                        x=[fpr[best_idx]],
+                        y=[tpr[best_idx]],
+                        mode="markers",
+                        marker=dict(color="black", size=10, symbol="x"),
+                        name=f"Threshold={threshold}",
+                        hovertemplate="FPR: %{x:.3f}<br>TPR: %{y:.3f}<br>Threshold: %{text}<extra></extra>",
+                        text=[f"{threshold}"],
+                    )
+                )
+            except Exception as exc:
+                print(f"Warning: could not add threshold marker to ROC: {exc}")
+
         fig.add_annotation(
             x=0.5,
             y=-0.15,
@@ -786,21 +736,17 @@
             xanchor="center",
         )
 
-        return {
-            "title": "ROC Curve",
-            "html": pio.to_html(
-                fig,
-                full_html=False,
-                include_plotlyjs=False,
-                config=config,
-            ),
-        }
+        return _wrap_plot("ROC Curve", fig, config=config)
     except Exception as e:
         print(f"Error building ROC plot: {e}")
         return None
 
 
-def _build_precision_recall_plot(label_stats: dict, config: dict) -> Optional[Dict[str, str]]:
+def _build_precision_recall_plot(
+    label_stats: dict,
+    config: dict,
+    threshold: Optional[float] = None,
+) -> Optional[Dict[str, str]]:
     """Build Precision-Recall curve directly from test_statistics.json."""
     pr_data = label_stats.get("precision_recall_curve")
     if not isinstance(pr_data, dict):
@@ -811,6 +757,8 @@
     if not precisions or not recalls or len(precisions) != len(recalls):
         return None
 
+    thresholds = pr_data.get("thresholds")
+
     try:
         fig = go.Figure()
         fig.add_trace(
@@ -851,15 +799,41 @@
         fig.update_xaxes(range=[0, 1.0])
         fig.update_yaxes(range=[0, 1.05])
 
-        return {
-            "title": "Precision-Recall Curve",
-            "html": pio.to_html(
-                fig,
-                full_html=False,
-                include_plotlyjs=False,
-                config=config,
-            ),
-        }
+        if threshold is not None and isinstance(thresholds, list) and len(thresholds) == len(recalls):
+            try:
+                diffs = [abs(th - threshold) for th in thresholds]
+                best_idx = int(np.argmin(diffs))
+                fig.add_shape(
+                    type="line",
+                    x0=recalls[best_idx],
+                    x1=recalls[best_idx],
+                    y0=0,
+                    y1=precisions[best_idx],
+                    line=dict(color="gray", width=2, dash="dash"),
+                )
+                fig.add_shape(
+                    type="line",
+                    x0=0,
+                    x1=recalls[best_idx],
+                    y0=precisions[best_idx],
+                    y1=precisions[best_idx],
+                    line=dict(color="gray", width=2, dash="dash"),
+                )
+                fig.add_trace(
+                    go.Scatter(
+                        x=[recalls[best_idx]],
+                        y=[precisions[best_idx]],
+                        mode="markers",
+                        marker=dict(color="black", size=10, symbol="x"),
+                        name=f"Threshold={threshold}",
+                        hovertemplate="Recall: %{x:.3f}<br>Precision: %{y:.3f}<br>Threshold: %{text}<extra></extra>",
+                        text=[f"{threshold}"],
+                    )
+                )
+            except Exception as exc:
+                print(f"Warning: could not add threshold marker to PR: {exc}")
+
+        return _wrap_plot("Precision-Recall Curve", fig, config=config)
     except Exception as e:
         print(f"Error building Precision-Recall plot: {e}")
         return None
@@ -869,7 +843,6 @@
     predictions_path: str,
     label_data_path: Optional[str] = None,
     split_value: int = 2,
-    threshold: Optional[float] = None,
 ) -> List[Dict[str, str]]:
     """Generate diagnostic plots from predictions.csv for classification tasks."""
     preds_file = Path(predictions_path)
@@ -883,12 +856,89 @@
         return []
 
     plots: List[Dict[str, str]] = []
+    labels_from_dataset: Optional[pd.Series] = None
+
+    filtered_by_split = False
+
+    # If a split column exists, focus on the requested split (e.g., validation=1, test=2).
+    # If not, but label_data_path is available and matches row count, use it to filter predictions.
+    if SPLIT_COLUMN_NAME in df_pred.columns:
+        df_pred = df_pred[df_pred[SPLIT_COLUMN_NAME] == split_value].reset_index(drop=True)
+        if df_pred.empty:
+            return []
+        filtered_by_split = True
+    elif label_data_path and Path(label_data_path).exists():
+        try:
+            df_labels_all = pd.read_csv(label_data_path)
+            if SPLIT_COLUMN_NAME in df_labels_all.columns and len(df_labels_all) == len(df_pred):
+                split_mask = pd.to_numeric(df_labels_all[SPLIT_COLUMN_NAME], errors="coerce") == split_value
+                labels_from_dataset = df_labels_all.loc[split_mask, LABEL_COLUMN_NAME].reset_index(drop=True)
+                df_pred = df_pred.loc[split_mask].reset_index(drop=True)
+                if df_pred.empty:
+                    return []
+                filtered_by_split = True
+        except Exception as exc:
+            print(f"Warning: Unable to filter predictions by split from label data: {exc}")
+
+    # Fallback: no split info available. Assume the predictions file is already filtered
+    # (common for test-only exports) and avoid heuristic slicing that could discard rows.
+    if not filtered_by_split:
+        if split_value != 2:
+            return []
+
+    def _strip_prob_prefix(col: str) -> str:
+        if col.startswith("label_probabilities_"):
+            return col.replace("label_probabilities_", "")
+        if col.startswith("probabilities_"):
+            return col.replace("probabilities_", "")
+        return col
+
+    def _maybe_expand_probabilities_column(df: pd.DataFrame, labels_guess: List[str]) -> List[str]:
+        """If only a single 'probabilities' column exists (list-like), expand it into per-class columns."""
+        if "probabilities" not in df.columns:
+            return []
+        try:
+            # Parse first non-null entry to infer length
+            first_val = df["probabilities"].dropna().iloc[0]
+            parsed = first_val
+            if isinstance(first_val, str):
+                parsed = json.loads(first_val)
+            probs = list(parsed)
+            n = len(probs)
+            if n == 0:
+                return []
+            # Build labels: prefer provided guess; otherwise numeric
+            if labels_guess and len(labels_guess) == n:
+                labels_use = labels_guess
+            else:
+                labels_use = [str(i) for i in range(n)]
+            # Expand column
+            for idx, lbl in enumerate(labels_use):
+                df[f"probabilities_{lbl}"] = df["probabilities"].apply(
+                    lambda v: (json.loads(v)[idx] if isinstance(v, str) else list(v)[idx]) if pd.notnull(v) else np.nan
+                )
+            return [f"probabilities_{lbl}" for lbl in labels_use]
+        except Exception:
+            return []
 
     # Identify probability columns
     prob_cols = [
-        c for c in df_pred.columns
-        if c.startswith("label_probabilities_") and c != "label_probabilities"
+        c
+        for c in df_pred.columns
+        if (
+            (c.startswith("label_probabilities_") or c.startswith("probabilities_"))
+            and c != "label_probabilities"
+        )
     ]
+    if not prob_cols and "label_probability" in df_pred.columns:
+        prob_cols = ["label_probability"]
+    if not prob_cols and "probability" in df_pred.columns:
+        prob_cols = ["probability"]
+    if not prob_cols and "prediction_probability" in df_pred.columns:
+        prob_cols = ["prediction_probability"]
+    if not prob_cols and "probabilities" in df_pred.columns:
+        labels_guess = sorted([str(u) for u in pd.unique(df_pred[LABEL_COLUMN_NAME])])
+        prob_cols = _maybe_expand_probabilities_column(df_pred, labels_guess)
     prob_cols_sorted = sorted(prob_cols)
 
     def _select_positive_prob():
@@ -897,14 +947,14 @@
         # Prefer a column indicating positive/event/true/1
         preferred_keys = ("event", "true", "positive", "pos", "1")
         for col in prob_cols_sorted:
-            suffix = col.replace("label_probabilities_", "").lower()
+            suffix = _strip_prob_prefix(col).lower()
             if any(k in suffix for k in preferred_keys):
                 return col, suffix
         if len(prob_cols_sorted) == 2:
             col = prob_cols_sorted[1]
-            return col, col.replace("label_probabilities_", "")
+            return col, _strip_prob_prefix(col)
         col = prob_cols_sorted[0]
-        return col, col.replace("label_probabilities_", "")
+        return col, _strip_prob_prefix(col)
 
     pos_prob_col, pos_label_hint = _select_positive_prob()
     pos_prob_series = df_pred[pos_prob_col] if pos_prob_col and pos_prob_col in df_pred else None
@@ -920,6 +970,8 @@
 
     # True labels
     def _extract_labels():
+        if labels_from_dataset is not None:
+            return labels_from_dataset
         candidates = [
             LABEL_COLUMN_NAME,
             f"{LABEL_COLUMN_NAME}_ground_truth",
@@ -975,10 +1027,7 @@
             height=500,
         )
         _style_fig(fig_conf)
-        plots.append({
-            "title": "Prediction Confidence Distribution",
-            "html": pio.to_html(fig_conf, full_html=False, include_plotlyjs=False),
-        })
+        plots.append(_wrap_plot("Prediction Confidence Distribution", fig_conf))
 
     # The remaining plots require true labels and a positive-class probability
     if labels_series is None or pos_prob_series is None:
@@ -1004,116 +1053,470 @@
 
     y_true = (y_true_raw == positive_label).astype(int).values
 
-    # Plot 2: Calibration Curve
-    bins = np.linspace(0.0, 1.0, 11)
-    bin_ids = np.digitize(y_score, bins, right=True)
-    bin_centers = []
-    frac_positives = []
-    for b in range(1, len(bins)):
-        mask = bin_ids == b
-        if not np.any(mask):
-            continue
-        bin_centers.append(y_score[mask].mean())
-        frac_positives.append(y_true[mask].mean())
-    if bin_centers and frac_positives:
-        fig_cal = go.Figure()
-        fig_cal.add_trace(
+    # Utility: compute calibration points
+    def _calibration_points(y_true_bin: np.ndarray, scores: np.ndarray):
+        bins = np.linspace(0.0, 1.0, 11)
+        bin_ids = np.digitize(scores, bins, right=True)
+        bin_centers, frac_positives = [], []
+        for b in range(1, len(bins)):
+            mask = bin_ids == b
+            if not np.any(mask):
+                continue
+            bin_centers.append(scores[mask].mean())
+            frac_positives.append(y_true_bin[mask].mean())
+        return bin_centers, frac_positives
+
+    # Plot 2: Calibration Curve (multi-class aware; one-vs-rest per label)
+    label_prob_map = {}
+    for col in prob_cols_sorted:
+        if col.startswith("label_probabilities_"):
+            cls = col.replace("label_probabilities_", "")
+            label_prob_map[cls] = col
+
+    unique_label_strs = [str(u) for u in unique_labels_list]
+    if len(label_prob_map) > 1 and len(unique_label_strs) > 2:
+        # Skip multi-class calibration curve for now (not informative in current report)
+        pass
+    else:
+        # Binary/unknown fallback (previous behavior)
+        bin_centers, frac_positives = _calibration_points(y_true, y_score)
+        if bin_centers and frac_positives:
+            fig_cal = go.Figure()
+            fig_cal.add_trace(
+                go.Scatter(
+                    x=bin_centers,
+                    y=frac_positives,
+                    mode="lines+markers",
+                    name="Calibration",
+                    line=dict(color="#2ca02c", width=4),
+                )
+            )
+            fig_cal.add_trace(
+                go.Scatter(
+                    x=[0, 1],
+                    y=[0, 1],
+                    mode="lines",
+                    name="Perfect Calibration",
+                    line=dict(color="gray", width=2, dash="dash"),
+                )
+            )
+            fig_cal.update_layout(
+                title=dict(text="Calibration Curve", x=0.5),
+                xaxis_title="Predicted probability",
+                yaxis_title="Observed frequency",
+                width=700,
+                height=500,
+            )
+            _style_fig(fig_cal)
+            plots.append(
+                _wrap_plot(
+                    "Calibration Curve (Predicted Probability vs Observed Frequency)",
+                    fig_cal,
+                )
+            )
+
+    return plots
+
+
+def build_binary_threshold_plot(
+    predictions_path: str,
+    label_data_path: Optional[str] = None,
+    split_value: int = 1,
+) -> Optional[Dict[str, str]]:
+    """Build a binary threshold sweep plot (accuracy, precision, recall, F1) for a given split."""
+    preds_file = Path(predictions_path)
+    if not preds_file.exists():
+        return None
+
+    try:
+        df_pred = pd.read_csv(predictions_path)
+    except Exception as exc:
+        print(f"Warning: Unable to read predictions CSV for threshold plot: {exc}")
+        return None
+
+    labels_from_dataset: Optional[pd.Series] = None
+    df_full = df_pred.copy()
+
+    def _filter_by_split(df: pd.DataFrame, split_val: int) -> pd.DataFrame:
+        if SPLIT_COLUMN_NAME in df.columns:
+            return df[df[SPLIT_COLUMN_NAME] == split_val].reset_index(drop=True)
+        return df
+
+    # Try preferred split, then fallback to others with data (val -> test -> train)
+    candidate_splits = [split_value, 2, 0, 1] if split_value == 1 else [split_value, 1, 0, 2]
+    df_candidate = pd.DataFrame()
+    used_split: Optional[int] = None
+    for sv in candidate_splits:
+        df_candidate = _filter_by_split(df_full, sv)
+        if not df_candidate.empty:
+            used_split = sv
+            break
+    if used_split is None:
+        df_candidate = df_full
+    df_pred = df_candidate.reset_index(drop=True)
+
+    # If still empty (e.g., split column exists but no rows for candidates), fall back to all rows
+    if df_pred.empty:
+        df_pred = df_full.reset_index(drop=True)
+        labels_from_dataset = None
+
+    if label_data_path and Path(label_data_path).exists():
+        try:
+            df_labels_all = pd.read_csv(label_data_path)
+            if SPLIT_COLUMN_NAME in df_labels_all.columns and len(df_labels_all) == len(df_full):
+                mask = (
+                    pd.to_numeric(df_labels_all[SPLIT_COLUMN_NAME], errors="coerce") == used_split
+                    if used_split is not None and SPLIT_COLUMN_NAME in df_labels_all.columns
+                    else pd.Series([True] * len(df_full))
+                )
+                labels_from_dataset = df_labels_all.loc[mask, LABEL_COLUMN_NAME].reset_index(drop=True)
+                if len(labels_from_dataset) == len(df_pred):
+                    labels_from_dataset = labels_from_dataset.reset_index(drop=True)
+        except Exception as exc:
+            print(f"Warning: Unable to align labels for threshold plot: {exc}")
+
+    # Identify probability columns
+    prob_cols = [
+        c
+        for c in df_pred.columns
+        if (
+            (c.startswith("label_probabilities_") or c.startswith("probabilities_"))
+            and c != "label_probabilities"
+        )
+    ]
+    if not prob_cols and "probabilities" in df_pred.columns:
+        labels_guess = sorted([str(u) for u in pd.unique(df_pred.get(LABEL_COLUMN_NAME, []))])
+        # reuse expansion logic from diagnostics
+        try:
+            first_val = df_pred["probabilities"].dropna().iloc[0]
+            parsed = json.loads(first_val) if isinstance(first_val, str) else list(first_val)
+            n = len(parsed)
+            if n > 0:
+                if labels_guess and len(labels_guess) == n:
+                    labels_use = labels_guess
+                else:
+                    labels_use = [str(i) for i in range(n)]
+                for idx, lbl in enumerate(labels_use):
+                    df_pred[f"probabilities_{lbl}"] = df_pred["probabilities"].apply(
+                        lambda v: (json.loads(v)[idx] if isinstance(v, str) else list(v)[idx]) if pd.notnull(v) else np.nan
+                    )
+                prob_cols = [f"probabilities_{lbl}" for lbl in labels_use]
+        except Exception:
+            prob_cols = []
+    prob_cols_sorted = sorted(prob_cols)
+
+    def _strip_prob_prefix(col: str) -> str:
+        if col.startswith("label_probabilities_"):
+            return col.replace("label_probabilities_", "")
+        if col.startswith("probabilities_"):
+            return col.replace("probabilities_", "")
+        return col
+
+    # True labels
+    def _extract_labels():
+        if labels_from_dataset is not None:
+            return labels_from_dataset
+        for col in [
+            LABEL_COLUMN_NAME,
+            f"{LABEL_COLUMN_NAME}_ground_truth",
+            f"{LABEL_COLUMN_NAME}__ground_truth",
+            f"{LABEL_COLUMN_NAME}_target",
+            f"{LABEL_COLUMN_NAME}__target",
+            "label",
+            "label_true",
+            "label_predictions",
+            "prediction",
+        ]:
+            if col in df_pred.columns and col not in prob_cols_sorted:
+                return df_pred[col]
+        return None
+
+    labels_series = _extract_labels()
+    if labels_series is None or not prob_cols_sorted:
+        return None
+
+    # Positive prob column selection
+    preferred_keys = ("event", "true", "positive", "pos", "1")
+    pos_prob_col = None
+    for col in prob_cols_sorted:
+        suffix = _strip_prob_prefix(col).lower()
+        if any(k in suffix for k in preferred_keys):
+            pos_prob_col = col
+            break
+    if pos_prob_col is None:
+        pos_prob_col = prob_cols_sorted[-1]
+
+    min_len = min(len(labels_series), len(df_pred[pos_prob_col]))
+    if min_len == 0:
+        return None
+
+    y_true = np.array(labels_series.iloc[:min_len])
+    # map to binary 0/1
+    unique_labels = pd.unique(y_true)
+    if len(unique_labels) < 2:
+        return None
+    positive_label = unique_labels[1] if len(unique_labels) >= 2 else unique_labels[0]
+    y_true_bin = (y_true == positive_label).astype(int)
+    y_score = np.array(df_pred[pos_prob_col].iloc[:min_len], dtype=float)
+
+    thresholds = np.linspace(0.0, 1.0, 101)
+    accs: List[float] = []
+    precs: List[float] = []
+    recs: List[float] = []
+    f1s: List[float] = []
+    for t in thresholds:
+        preds = (y_score >= t).astype(int)
+        accs.append(accuracy_score(y_true_bin, preds))
+        precs.append(precision_score(y_true_bin, preds, zero_division=0))
+        recs.append(recall_score(y_true_bin, preds, zero_division=0))
+        f1s.append(f1_score(y_true_bin, preds, zero_division=0))
+
+    best_idx = int(np.argmax(f1s))
+    best_thr = thresholds[best_idx]
+
+    fig = go.Figure()
+    fig.add_trace(go.Scatter(x=thresholds, y=accs, mode="lines", name="Accuracy", line=dict(width=4)))
+    fig.add_trace(go.Scatter(x=thresholds, y=precs, mode="lines", name="Precision", line=dict(width=4)))
+    fig.add_trace(go.Scatter(x=thresholds, y=recs, mode="lines", name="Recall", line=dict(width=4)))
+    fig.add_trace(go.Scatter(x=thresholds, y=f1s, mode="lines", name="F1-Score", line=dict(width=4)))
+    fig.add_shape(
+        type="line",
+        x0=best_thr,
+        x1=best_thr,
+        y0=0,
+        y1=1,
+        line=dict(color="gray", width=2, dash="dash"),
+    )
+    fig.update_layout(
+        title=dict(text="Threshold plot", x=0.5),
+        xaxis_title="Threshold",
+        yaxis_title="Metric value",
+        yaxis=dict(range=[0, 1]),
+        width=760,
+        height=520,
+        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
+    )
+    _style_fig(fig)
+    return _wrap_plot("Threshold plot", fig, include_js=True)
+
+
+def build_multiclass_roc_pr_plots(
+    predictions_path: str,
+    split_value: int = 2,
+) -> List[Dict[str, str]]:
+    """Build one-vs-rest ROC and PR curves for multi-class classification from predictions."""
+    preds_file = Path(predictions_path)
+    if not preds_file.exists():
+        return []
+    try:
+        df_pred = pd.read_csv(predictions_path)
+    except Exception as exc:
+        print(f"Warning: Unable to read predictions CSV: {exc}")
+        return []
+
+    if SPLIT_COLUMN_NAME in df_pred.columns:
+        df_pred = df_pred[df_pred[SPLIT_COLUMN_NAME] == split_value].reset_index(drop=True)
+    if df_pred.empty:
+        return []
+
+    if LABEL_COLUMN_NAME not in df_pred.columns:
+        return []
+
+    # Identify per-class probability columns
+    prob_cols = [
+        c
+        for c in df_pred.columns
+        if (
+            (c.startswith("label_probabilities_") or c.startswith("probabilities_"))
+            and c != "label_probabilities"
+        )
+    ]
+    if not prob_cols:
+        return []
+    labels = [c.replace("label_probabilities_", "").replace("probabilities_", "") for c in prob_cols]
+    labels_sorted = sorted(labels)
+
+    # Ensure all labels are present as probability columns
+    prob_map = {
+        c.replace("label_probabilities_", "").replace("probabilities_", ""): c
+        for c in prob_cols
+    }
+    if len(labels_sorted) < 3:
+        return []
+
+    y_true_raw = df_pred[LABEL_COLUMN_NAME].astype(str)
+    # Drop rows with NaN probabilities across any class to avoid metric errors
+    prob_matrix = df_pred[[prob_map[lbl] for lbl in labels_sorted]].astype(float)
+    mask_valid = ~prob_matrix.isnull().any(axis=1)
+    prob_matrix = prob_matrix[mask_valid]
+    y_true_raw = y_true_raw[mask_valid]
+    if prob_matrix.empty:
+        return []
+
+    y_true_bin = label_binarize(y_true_raw, classes=labels_sorted)
+    y_score = prob_matrix.to_numpy()
+
+    plots: List[Dict[str, str]] = []
+
+    # ROC: one-vs-rest + micro
+    fig_roc = go.Figure()
+    added_any = False
+    for idx, lbl in enumerate(labels_sorted):
+        if y_true_bin[:, idx].sum() == 0 or y_true_bin[:, idx].sum() == len(y_true_bin):
+            continue  # skip classes without both positives and negatives
+        fpr, tpr, _ = roc_curve(y_true_bin[:, idx], y_score[:, idx])
+        fig_roc.add_trace(
             go.Scatter(
-                x=bin_centers,
-                y=frac_positives,
-                mode="lines+markers",
-                name="Calibration",
-                line=dict(color="#2ca02c", width=4),
-            )
-        )
-        fig_cal.add_trace(
-            go.Scatter(
-                x=[0, 1],
-                y=[0, 1],
+                x=fpr,
+                y=tpr,
                 mode="lines",
-                name="Perfect Calibration",
-                line=dict(color="gray", width=2, dash="dash"),
+                name=f"{lbl} (AUC={auc(fpr, tpr):.3f})",
+                line=dict(width=3),
             )
         )
-        fig_cal.update_layout(
-            title=dict(text="Calibration Curve", x=0.5),
-            xaxis_title="Predicted probability",
-            yaxis_title="Observed frequency",
-            width=700,
-            height=500,
+        added_any = True
+    # Micro-average only if we have mixed labels
+    if y_true_bin.sum() > 0 and y_true_bin.sum() < y_true_bin.size:
+        fpr_micro, tpr_micro, _ = roc_curve(y_true_bin.ravel(), y_score.ravel())
+        fig_roc.add_trace(
+            go.Scatter(
+                x=fpr_micro,
+                y=tpr_micro,
+                mode="lines",
+                name=f"Micro-average (AUC={auc(fpr_micro, tpr_micro):.3f})",
+                line=dict(width=3, dash="dash"),
+            )
         )
-        _style_fig(fig_cal)
-        plots.append({
-            "title": "Calibration Curve (Predicted Probability vs Observed Frequency)",
-            "html": pio.to_html(fig_cal, full_html=False, include_plotlyjs=False),
-        })
-
-    # Plot 3: Threshold vs Metrics
-    thresholds = np.linspace(0.0, 1.0, 21)
-    accs, f1s, sens, specs = [], [], [], []
-    for t in thresholds:
-        y_pred = (y_score >= t).astype(int)
-        tp = np.sum((y_true == 1) & (y_pred == 1))
-        tn = np.sum((y_true == 0) & (y_pred == 0))
-        fp = np.sum((y_true == 0) & (y_pred == 1))
-        fn = np.sum((y_true == 1) & (y_pred == 0))
-        acc = (tp + tn) / max(len(y_true), 1)
-        prec = tp / max(tp + fp, 1e-9)
-        rec = tp / max(tp + fn, 1e-9)
-        f1 = 0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec)
-        sensitivity = rec
-        specificity = tn / max(tn + fp, 1e-9)
-        accs.append(acc)
-        f1s.append(f1)
-        sens.append(sensitivity)
-        specs.append(specificity)
-
-    fig_thresh = go.Figure()
-    fig_thresh.add_trace(go.Scatter(x=thresholds, y=accs, mode="lines", name="Accuracy", line=dict(width=4)))
-    fig_thresh.add_trace(go.Scatter(x=thresholds, y=f1s, mode="lines", name="F1", line=dict(width=4)))
-    fig_thresh.add_trace(go.Scatter(x=thresholds, y=sens, mode="lines", name="Sensitivity", line=dict(width=4)))
-    fig_thresh.add_trace(go.Scatter(x=thresholds, y=specs, mode="lines", name="Specificity", line=dict(width=4)))
-    fig_thresh.update_layout(
-        title=dict(text="Threshold Sweep: Accuracy, F1, Sensitivity, Specificity", x=0.5),
-        xaxis_title="Decision threshold",
-        yaxis_title="Metric value",
-        width=700,
-        height=500,
+        added_any = True
+    if not added_any:
+        return []
+    fig_roc.add_trace(
+        go.Scatter(
+            x=[0, 1],
+            y=[0, 1],
+            mode="lines",
+            name="Random",
+            line=dict(color="gray", width=2, dash="dot"),
+        )
+    )
+    fig_roc.update_layout(
+        title=dict(text="Multi-class ROC-AUC (one-vs-rest)", x=0.5),
+        xaxis_title="False Positive Rate",
+        yaxis_title="True Positive Rate",
+        width=820,
+        height=620,
         legend=dict(
-            x=0.7,
-            y=0.2,
+            x=0.62,
+            y=0.05,
             bgcolor="rgba(255,255,255,0.9)",
             bordercolor="rgba(0,0,0,0.2)",
             borderwidth=1,
         ),
-        shapes=[
-            dict(
-                type="line",
-                x0=threshold,
-                x1=threshold,
-                y0=0,
-                y1=1,
-                xref="x",
-                yref="paper",
-                line=dict(color="#d62728", width=2, dash="dash"),
+    )
+    _style_fig(fig_roc)
+    plots.append(_wrap_plot("Multi-class ROC-AUC (one-vs-rest)", fig_roc))
+
+    # PR: one-vs-rest + micro AP
+    fig_pr = go.Figure()
+    added_pr = False
+    for idx, lbl in enumerate(labels_sorted):
+        if y_true_bin[:, idx].sum() == 0:
+            continue
+        prec, rec, _ = precision_recall_curve(y_true_bin[:, idx], y_score[:, idx])
+        ap = average_precision_score(y_true_bin[:, idx], y_score[:, idx])
+        fig_pr.add_trace(
+            go.Scatter(
+                x=rec,
+                y=prec,
+                mode="lines",
+                name=f"{lbl} (AP={ap:.3f})",
+                line=dict(width=3),
             )
-        ] if isinstance(threshold, (int, float)) else [],
-        annotations=[
-            dict(
-                x=threshold,
-                y=1.02,
-                xref="x",
-                yref="paper",
-                showarrow=False,
-                text=f"Threshold = {threshold:.2f}",
-                font=dict(size=11, color="#d62728"),
+        )
+        added_pr = True
+    if y_true_bin.sum() > 0:
+        prec_micro, rec_micro, _ = precision_recall_curve(y_true_bin.ravel(), y_score.ravel())
+        ap_micro = average_precision_score(y_true_bin, y_score, average="micro")
+        fig_pr.add_trace(
+            go.Scatter(
+                x=rec_micro,
+                y=prec_micro,
+                mode="lines",
+                name=f"Micro-average (AP={ap_micro:.3f})",
+                line=dict(width=3, dash="dash"),
             )
-        ] if isinstance(threshold, (int, float)) else [],
+        )
+        added_pr = True
+    if not added_pr:
+        return plots
+    fig_pr.update_layout(
+        title=dict(text="Multi-class Precision-Recall (one-vs-rest)", x=0.5),
+        xaxis_title="Recall",
+        yaxis_title="Precision",
+        width=820,
+        height=620,
+        legend=dict(
+            x=0.62,
+            y=0.05,
+            bgcolor="rgba(255,255,255,0.9)",
+            bordercolor="rgba(0,0,0,0.2)",
+            borderwidth=1,
+        ),
     )
-    _style_fig(fig_thresh)
-    plots.append({
-        "title": "Threshold Sweep: Accuracy, F1, Sensitivity, Specificity",
-        "html": pio.to_html(fig_thresh, full_html=False, include_plotlyjs=False),
-    })
+    _style_fig(fig_pr)
+    plots.append(_wrap_plot("Multi-class Precision-Recall (one-vs-rest)", fig_pr))
 
     return plots
+
+
+def build_multiclass_metric_plots(test_stats_path: str) -> List[Dict[str, str]]:
+    """Alternative multi-class transparency plots using test_statistics.json per-class stats."""
+    ts_path = Path(test_stats_path)
+    if not ts_path.exists():
+        return []
+    try:
+        with open(ts_path, "r") as f:
+            test_stats = json.load(f)
+    except Exception:
+        return []
+
+    label_stats = test_stats.get("label", {})
+    pcs = label_stats.get("per_class_stats", {})
+    if not pcs:
+        return []
+    classes = list(pcs.keys())
+    if not classes:
+        return []
+
+    metrics = ["precision", "recall", "f1_score", "specificity", "accuracy"]
+    fig_bar = go.Figure()
+    for metric in metrics:
+        values = []
+        for cls in classes:
+            v = pcs.get(cls, {}).get(metric)
+            values.append(v if isinstance(v, (int, float)) else 0)
+        fig_bar.add_trace(
+            go.Bar(
+                x=classes,
+                y=values,
+                name=metric.replace("_", " ").title(),
+            )
+        )
+    fig_bar.update_layout(
+        title=dict(text="Per-Class Metrics (Test)", x=0.5),
+        xaxis_title="Class",
+        yaxis_title="Metric value",
+        barmode="group",
+        width=900,
+        height=600,
+        legend=dict(
+            x=1.02,
+            y=1.0,
+            bgcolor="rgba(255,255,255,0.9)",
+            bordercolor="rgba(0,0,0,0.2)",
+            borderwidth=1,
+        ),
+    )
+    _style_fig(fig_bar)
+
+    return [_wrap_plot("Per-Class Metrics (Test)", fig_bar)]