diff ludwig_backend.py @ 17:db9be962dc13 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
author goeckslab
date Wed, 10 Dec 2025 00:24:13 +0000
parents 8729f69e9207
--- a/ludwig_backend.py	Wed Dec 03 01:28:52 2025 +0000
+++ b/ludwig_backend.py	Wed Dec 10 00:24:13 2025 +0000
@@ -1,8 +1,9 @@
+import inspect
 import json
 import logging
 import os
 from pathlib import Path
-from typing import Any, Dict, Optional, Protocol, Tuple
+from typing import Any, Dict, List, Optional, Protocol, Tuple
 
 import pandas as pd
 import pandas.api.types as ptypes
@@ -17,6 +18,7 @@
     build_tabbed_html,
     encode_image_to_base64,
     format_config_table_html,
+    format_dataset_overview_table,
     format_stats_table_html,
     format_test_merged_stats_table_html,
     format_train_val_stats_table_html,
@@ -33,7 +35,9 @@
 from ludwig.utils.data_utils import get_split_path
 from metaformer_setup import get_visualizations_registry, META_DEFAULT_CFGS
 from plotly_plots import (
+    build_binary_threshold_plot,
     build_classification_plots,
+    build_multiclass_metric_plots,
     build_prediction_diagnostics,
     build_regression_test_plots,
     build_regression_train_val_plots,
@@ -267,6 +271,23 @@
         else:
             encoder_config = {"type": raw_encoder}
 
+        # Set a human-friendly architecture string for reporting
+        arch_display = None
+        if is_metaformer and custom_model:
+            arch_display = str(custom_model)
+        elif isinstance(raw_encoder, dict):
+            enc_type = raw_encoder.get("type")
+            enc_variant = raw_encoder.get("model_variant")
+            if enc_type:
+                base = str(enc_type).replace("_", " ").title()
+                arch_display = f"{base} {enc_variant}" if enc_variant is not None else base
+        else:
+            arch_display = str(raw_encoder).replace("_", " ").title()
+
+        if not arch_display:
+            arch_display = str(model_name)
+        config_params["architecture"] = arch_display
+
         batch_size_cfg = batch_size or "auto"
 
         label_column_path = config_params.get("label_column_data_path")
@@ -343,6 +364,7 @@
             # Force Ludwig to respect our dimensions by setting additional parameters
             image_feat["preprocessing"]["requires_equal_dimensions"] = False
             logger.info(f"Set preprocessing dimensions for MetaFormer: {height}x{width} (infer_dimensions=True with max dimensions to allow validation)")
+            config_params["image_size"] = f"{height}x{width}"
         # Now set the encoder configuration
         image_feat["encoder"] = encoder_config
 
@@ -374,8 +396,12 @@
                     image_feat["preprocessing"]["infer_image_max_height"] = height
                     image_feat["preprocessing"]["infer_image_max_width"] = width
                     logger.info(f"Added resize preprocessing: {height}x{width} for standard encoder with infer_image_dimensions=True and max dimensions")
+                    config_params["image_size"] = f"{height}x{width}"
             except (ValueError, IndexError):
                 logger.warning(f"Invalid image resize format: {config_params['image_resize']}, skipping resize preprocessing")
+        elif not is_metaformer:
+            # No explicit resize provided; record the original image size for reporting
+            config_params.setdefault("image_size", "original")
 
         def _resolve_validation_metric(task: str, requested: Optional[str]) -> Optional[str]:
             """Pick a validation metric that Ludwig will accept for the resolved task."""
@@ -471,6 +497,9 @@
                 config_params.get("validation_metric"),
             )
 
+        # Propagate the resolved validation metric (including any task-based fallback or alias normalization)
+        config_params["validation_metric"] = val_metric
+
         conf: Dict[str, Any] = {
             "model_type": "ecd",
             "input_features": [image_feat],
@@ -641,18 +670,62 @@
         except Exception as e:
             logger.error(f"Error converting Parquet to CSV: {e}")
 
-    def generate_plots(self, output_dir: Path) -> None:
-        """Generate all registered Ludwig visualizations for the latest experiment run."""
-        logger.info("Generating all Ludwig visualizations…")
+    @staticmethod
+    def _extract_metric_series(stats: Dict[str, Any], split: str, prefer: Optional[str] = None) -> Tuple[Optional[str], Optional[List[float]]]:
+        """Pull the first numeric metric list we can find for the requested split."""
+        if not isinstance(stats, dict):
+            return None, None
+
+        split_stats = stats.get(split, {})
+        ordered_metrics: List[Tuple[str, List[float]]] = []
+
+        def _append_metrics(metric_map: Dict[str, Any]) -> None:
+            for metric_name, values in metric_map.items():
+                if isinstance(values, list) and any(isinstance(v, (int, float)) for v in values):
+                    ordered_metrics.append((metric_name, values))
 
-        # Keep only lightweight plots (drop compare_performance/roc_curves)
-        test_plots = {
-            "roc_curves_from_test_statistics",
-            "confusion_matrix",
-        }
+        if isinstance(split_stats, dict):
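+            # Collect the "combined" metrics first, then per-feature metrics, in discovery order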
+            combined = split_stats.get("combined")
+            if isinstance(combined, dict):
+                _append_metrics(combined)
+
+            for feature_name, feature_metrics in split_stats.items():
+                if feature_name == "combined" or not isinstance(feature_metrics, dict):
+                    continue
+                _append_metrics(feature_metrics)
+
+        if prefer:
+            for metric_name, values in ordered_metrics:
+                if metric_name == prefer:
+                    return metric_name, values
+
+        return ordered_metrics[0] if ordered_metrics else (None, None)
+
+    def generate_plots(self, output_dir: Path) -> None:
+        """Generate Ludwig visualizations (train/val + test) for the latest experiment run."""
+        logger.info("Generating Ludwig visualizations (train/val + test)…")
+
+        # Train/validation visualizations
         train_plots = {
             "learning_curves",
-            "compare_classifiers_performance_subset",
+        }
+
+        # Test visualizations (multi-class transparency)
+        test_plots = {
+            "confusion_matrix",
+            "compare_performance",
+            "compare_classifiers_multiclass_multimetric",
+            "frequency_vs_f1",
+            "confidence_thresholding",
+            "confidence_thresholding_data_vs_acc",
+            "confidence_thresholding_data_vs_acc_subset",
+            "confidence_thresholding_data_vs_acc_subset_per_class",
+            # Binary-only visualizations will still be attempted; multi-class replacements handled elsewhere
+            "binary_threshold_vs_metric",
+            "roc_curves",
+            "precision_recall_curves",
+            "calibration_1_vs_all",
+            "calibration_multiclass",
         }
 
         output_dir = Path(output_dir)
@@ -677,7 +750,6 @@
 
         training_stats = _check(exp_dir / "training_statistics.json")
         test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME)
-        probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME)
         gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME)
 
         dataset_path = None
@@ -688,6 +760,9 @@
                 cfg = json.load(f)
             dataset_path = _check(Path(cfg.get("dataset", "")))
             split_file = _check(Path(get_split_path(cfg.get("dataset", ""))))
+            model_name = cfg.get("model_name", "model")
+        else:
+            model_name = "model"
 
         output_feature = ""
         if desc.exists():
@@ -700,7 +775,44 @@
                 stats = json.load(f)
             output_feature = next(iter(stats.keys()), "")
 
+        probs_path = None
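+        # Prefer CSV probability exports; fall back to the Parquet predictions file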
+        prob_candidates = [
+            exp_dir / f"{LABEL_COLUMN_NAME}_probabilities.csv",
+            exp_dir / f"{output_feature}_probabilities.csv" if output_feature else None,
+            exp_dir / "probabilities.csv",
+            exp_dir / "predictions.csv",
+            exp_dir / PREDICTIONS_PARQUET_FILE_NAME,
+        ]
+        for cand in prob_candidates:
+            if cand and Path(cand).exists():
+                probs_path = str(cand)
+                break
+
         viz_registry = get_visualizations_registry()
+        if not viz_registry:
+            logger.warning(
+                "Ludwig visualizations registry not available; train/test PNGs will be skipped."
+            )
+            return
+
+        base_kwargs = {
+            "training_statistics": [training_stats] if training_stats else [],
+            "test_statistics": [test_stats] if test_stats else [],
+            "probabilities": [probs_path] if probs_path else [],
+            "output_feature_name": output_feature,
+            "ground_truth_split": 2,
+            "top_n_classes": [20],
+            "top_k": 3,
+            "metrics": ["f1", "precision", "recall", "accuracy"],
+            "positive_label": 0,
+            "ground_truth_metadata": gt_metadata,
+            "ground_truth": dataset_path,
+            "split_file": split_file,
+            "output_directory": None,  # set per plot below
+            "normalize": False,
+            "file_format": "png",
+            "model_names": [model_name],
+        }
         for viz_name, viz_func in viz_registry.items():
             if viz_name in train_plots:
                 viz_dir_plot = train_viz
@@ -710,25 +822,22 @@
                 continue
 
             try:
+                # Build per-viz kwargs based on the function signature to avoid unexpected args
+                sig_params = set(inspect.signature(viz_func).parameters.keys())
+                call_kwargs = {
+                    k: v
+                    for k, v in base_kwargs.items()
+                    if k in sig_params and v is not None
+                }
+                if "output_directory" in sig_params:
+                    call_kwargs["output_directory"] = str(viz_dir_plot)
+
                 viz_func(
-                    training_statistics=[training_stats] if training_stats else [],
-                    test_statistics=[test_stats] if test_stats else [],
-                    probabilities=[probs_path] if probs_path else [],
-                    output_feature_name=output_feature,
-                    ground_truth_split=2,
-                    top_n_classes=[0],
-                    top_k=3,
-                    ground_truth_metadata=gt_metadata,
-                    ground_truth=dataset_path,
-                    split_file=split_file,
-                    output_directory=str(viz_dir_plot),
-                    normalize=False,
-                    file_format="png",
+                    **call_kwargs,
                 )
                 logger.info(f"✔ Generated {viz_name}")
             except Exception as e:
                 logger.warning(f"✘ Skipped {viz_name}: {e}")
-
         logger.info(f"All visualizations written to {viz_dir}")
 
     def generate_html_report(
@@ -756,6 +865,7 @@
         label_metadata_path = config.get("label_column_data_path")
         if label_metadata_path:
             label_metadata_path = Path(label_metadata_path)
+        dataset_path_from_desc: Optional[Path] = None
 
         # Pull additional config details from description.json if available
         config_for_summary = dict(config)
@@ -765,7 +875,8 @@
         if desc_path.exists():
             try:
                 with open(desc_path, "r") as f:
-                    desc_cfg = json.load(f).get("config", {})
+                    desc_json = json.load(f)
+                desc_cfg = desc_json.get("config", {}) if isinstance(desc_json, dict) else {}
                 encoder_cfg = (
                     desc_cfg.get("input_features", [{}])[0].get("encoder", {})
                     if isinstance(desc_cfg.get("input_features", [{}]), list)
@@ -783,10 +894,20 @@
 
                 arch_type = encoder_cfg.get("type")
                 arch_variant = encoder_cfg.get("model_variant")
+                arch_custom = encoder_cfg.get("custom_model")
                 arch_name = None
+                if arch_custom:
+                    arch_name = str(arch_custom)
                 if arch_type:
                     arch_base = str(arch_type).replace("_", " ").title()
-                    arch_name = f"{arch_base} {arch_variant}" if arch_variant is not None else arch_base
+                    arch_type_name = (
+                        f"{arch_base} {arch_variant}" if arch_variant is not None else arch_base
+                    )
+                    # Prefer explicit custom model names (e.g., MetaFormer) but fall back to encoder type
+                    arch_name = arch_name or arch_type_name
+                if not arch_name and config.get("model_name"):
+                    # As a last resort, show the user-selected model name (handles custom/MetaFormer cases)
+                    arch_name = str(config.get("model_name"))
 
                 summary_fields = {
                     "architecture": arch_name,
@@ -814,12 +935,22 @@
                     if k in {"target_column", "image_column"} and config_for_summary.get(k):
                         continue
                     config_for_summary.setdefault(k, v)
+
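+                # Fall back to the dataset referenced in description.json when the label metadata CSV is missing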
+                dataset_field = None
+                if isinstance(desc_json, dict):
+                    dataset_field = desc_json.get("dataset") or desc_cfg.get("dataset")
+                if dataset_field:
+                    try:
+                        dataset_path_from_desc = Path(dataset_field)
+                    except TypeError:
+                        dataset_path_from_desc = None
+                if dataset_path_from_desc and (not label_metadata_path or not label_metadata_path.exists()):
+                    label_metadata_path = dataset_path_from_desc
             except Exception as e:  # pragma: no cover - defensive
                 logger.warning(f"Could not merge description.json into config summary: {e}")
 
         base_viz_dir = exp_dir / "visualizations"
         train_viz_dir = base_viz_dir / "train"
-        test_viz_dir = base_viz_dir / "test"
 
         html = get_html_template()
 
@@ -880,10 +1011,164 @@
       });
     }
   });
-</script>
+        </script>
 """
         html += f"<h1>{title}</h1>"
 
+        def append_plot_blocks(tab_html: str, plots: List[Dict[str, str]], title_suffix: str = "") -> str:
+            """Append Plotly blocks to a tab with consistent markup."""
+            if not plots:
+                return tab_html
+            suffix = title_suffix or ""
+            for plot in plots:
+                tab_html += (
+                    f"<h2 style='text-align: center;'>{plot['title']}{suffix}</h2>"
+                    f"<div class='plotly-center'>{plot['html']}</div>"
+                )
+            return tab_html
+
+        def build_dataset_overview(
+            label_metadata: Optional[Path],
+            output_type: Optional[str],
+            split_probabilities: Optional[List[float]],
+            label_split_counts: Optional[List[Dict[str, int]]] = None,
+            split_counts: Optional[Dict[int, int]] = None,
+            fallback_dataset: Optional[Path] = None,
+        ) -> str:
+            """Summarize dataset distribution across splits using the actual split config."""
+            if label_split_counts:
+                # Use the actual counts captured during data prep instead of heuristics.
+                return format_dataset_overview_table(label_split_counts, regression_mode=False)
+
+            if output_type == "regression" and split_counts:
+                rows = [
+                    {"split": "train", "count": int(split_counts.get(0, 0))},
+                    {"split": "validation", "count": int(split_counts.get(1, 0))},
+                    {"split": "test", "count": int(split_counts.get(2, 0))},
+                ]
+                return format_dataset_overview_table(rows, regression_mode=True)
+
+            candidate_paths: List[Path] = []
+            if label_metadata and label_metadata.exists():
+                candidate_paths.append(label_metadata)
+            if fallback_dataset and fallback_dataset.exists():
+                candidate_paths.append(fallback_dataset)
+            if not candidate_paths:
+                return format_dataset_overview_table([])
+
+            def _normalize_split_probabilities(
+                probs: Optional[List[float]],
+            ) -> Optional[List[float]]:
+                if not probs or len(probs) != 3:
+                    return None
+                try:
+                    probs = [float(p) for p in probs]
+                except (TypeError, ValueError):
+                    return None
+                total = sum(probs)
+                if total <= 0:
+                    return None
+                return [p / total for p in probs]
+
+            def _split_counts_from_column(df: pd.DataFrame) -> Dict[int, int]:
+                if SPLIT_COLUMN_NAME not in df.columns:
+                    return {}
+                split_series = pd.to_numeric(
+                    df[SPLIT_COLUMN_NAME], errors="coerce"
+                ).dropna()
+                if split_series.empty:
+                    return {}
+                split_series = split_series.astype(int)
+                return split_series.value_counts().to_dict()
+
+            def _split_counts_from_probs(total: int, probs: List[float]) -> Dict[int, int]:
+                train_n = int(total * probs[0])
+                val_n = int(total * probs[1])
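+                # Any rounding remainder goes to the test split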
+                test_n = max(0, total - train_n - val_n)
+                return {0: train_n, 1: val_n, 2: test_n}
+
+            fallback_rows: Optional[List[Dict[str, int]]] = None
+            for meta_path in candidate_paths:
+                try:
+                    df_labels = pd.read_csv(meta_path)
+                    probs = _normalize_split_probabilities(split_probabilities)
+
+                    # Regression (or missing label column): only need split counts
+                    if output_type == "regression" or LABEL_COLUMN_NAME not in df_labels.columns:
+                        split_counts_found = _split_counts_from_column(df_labels)
+                        if split_counts_found:
+                            rows = [
+                                {"split": "train", "count": int(split_counts_found.get(0, 0))},
+                                {"split": "validation", "count": int(split_counts_found.get(1, 0))},
+                                {"split": "test", "count": int(split_counts_found.get(2, 0))},
+                            ]
+                            return format_dataset_overview_table(rows, regression_mode=True)
+                        if probs and fallback_rows is None:
+                            split_counts_found = _split_counts_from_probs(len(df_labels), probs)
+                            fallback_rows = [
+                                {"split": "train", "count": int(split_counts_found.get(0, 0))},
+                                {"split": "validation", "count": int(split_counts_found.get(1, 0))},
+                                {"split": "test", "count": int(split_counts_found.get(2, 0))},
+                            ]
+                        continue
+
+                    # Classification: prefer actual split assignments; fall back to configured probabilities
+                    if SPLIT_COLUMN_NAME in df_labels.columns:
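+                        # Cross-tabulate label x split to get per-class train/validation/test counts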
+                        df_counts = df_labels[[LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME]].copy()
+                        df_counts[SPLIT_COLUMN_NAME] = pd.to_numeric(
+                            df_counts[SPLIT_COLUMN_NAME], errors="coerce"
+                        )
+                        df_counts = df_counts.dropna(subset=[SPLIT_COLUMN_NAME])
+                        if df_counts.empty:
+                            continue
+
+                        df_counts[SPLIT_COLUMN_NAME] = df_counts[SPLIT_COLUMN_NAME].astype(int)
+                        df_counts = df_counts.dropna(subset=[LABEL_COLUMN_NAME])
+                        counts = (
+                            df_counts.groupby([LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME])
+                            .size()
+                            .unstack(fill_value=0)
+                            .sort_index()
+                        )
+                        rows = []
+                        for lbl, row in counts.iterrows():
+                            rows.append(
+                                {
+                                    "label": str(lbl),
+                                    "train": int(row.get(0, 0)),
+                                    "validation": int(row.get(1, 0)),
+                                    "test": int(row.get(2, 0)),
+                                }
+                            )
+                        return format_dataset_overview_table(rows)
+
+                    if probs:
+                        label_series = df_labels[LABEL_COLUMN_NAME].dropna()
+                        label_counts = label_series.value_counts().sort_index()
+                        if label_counts.empty:
+                            continue
+                        rows = []
+                        for lbl, count in label_counts.items():
+                            train_n = int(count * probs[0])
+                            val_n = int(count * probs[1])
+                            test_n = max(0, count - train_n - val_n)
+                            rows.append(
+                                {
+                                    "label": str(lbl),
+                                    "train": train_n,
+                                    "validation": val_n,
+                                    "test": test_n,
+                                }
+                            )
+                        fallback_rows = fallback_rows or rows
+                except Exception as exc:
+                    logger.warning("Failed to build dataset overview from %s: %s", meta_path, exc)
+                    continue
+
+            if fallback_rows:
+                return format_dataset_overview_table(fallback_rows, regression_mode=output_type == "regression")
+            return format_dataset_overview_table([])
+
         metrics_html = ""
         train_val_metrics_html = ""
         test_metrics_html = ""
@@ -911,6 +1196,23 @@
                 f"Could not load stats for HTML report: {type(e).__name__}: {e}"
             )
 
+        if not output_type:
+            # Fallback to configured task type when stats are unavailable (e.g., failed run).
+            output_type = (
+                str(config_for_summary.get("task_type")).lower()
+                if config_for_summary.get("task_type")
+                else None
+            )
+
+        dataset_overview_html = build_dataset_overview(
+            label_metadata_path,
+            output_type,
+            config.get("split_probabilities"),
+            config.get("label_split_counts"),
+            config.get("split_counts"),
+            dataset_path_from_desc,
+        )
+
         config_html = ""
         training_progress = self.get_training_process(output_dir)
         try:
@@ -937,11 +1239,12 @@
             exclude_names: Optional[set] = None,
         ) -> str:
             if not dir_path.exists():
-                return f"<h2>{title}</h2><p><em>Directory not found.</em></p>"
+                return ""
 
             exclude_names = exclude_names or set()
 
-            imgs = list(dir_path.glob("*.png"))
+            # Search recursively because Ludwig can nest figures under per-feature folders
+            imgs = list(dir_path.rglob("*.png"))
 
             # Exclude ROC curves and standard confusion matrices (keep only entropy version)
             default_exclude = {
@@ -983,7 +1286,7 @@
             ]
 
             if not imgs:
-                return f"<h2>{title}</h2><p><em>No plots found.</em></p>"
+                return ""
 
             # Sort images by name for consistent ordering (works with string and numeric labels)
             imgs = sorted(imgs, key=lambda x: x.name)
@@ -1006,36 +1309,86 @@
                 )
             return html_section
 
-        # Show performance first, then config
-        tab1_content = metrics_html + config_html
+        # Show dataset overview first, then performance metrics, then config
+        predictions_csv_path = exp_dir / "predictions.csv"
+
+        tab1_content = dataset_overview_html + metrics_html + config_html
 
-        tab2_content = train_val_metrics_html + render_img_section(
-            "Training and Validation Visualizations",
-            train_viz_dir,
-            output_type,
-            exclude_names={
-                "compare_classifiers_performance_from_prob.png",
-                "roc_curves_from_prediction_statistics.png",
-                "precision_recall_curves_from_prediction_statistics.png",
-                "precision_recall_curve.png",
-            },
+        tab2_content = train_val_metrics_html
+        # Preload binary threshold plot so it appears first in Train/Val tab
+        threshold_plot = None
+        threshold_value = (
+            config_for_summary.get("threshold")
+            if config_for_summary.get("threshold") is not None
+            else config.get("threshold")
         )
+        if threshold_value is None and output_type == "binary":
+            threshold_value = 0.5
+        if output_type == "binary" and predictions_csv_path.exists():
+            try:
+                threshold_plot = build_binary_threshold_plot(
+                    str(predictions_csv_path),
+                    label_data_path=str(config.get("label_column_data_path"))
+                    if config.get("label_column_data_path")
+                    else None,
+                    split_value=1,
+                )
+            except Exception as e:
+                logger.warning(f"Could not generate validation threshold plot: {e}")
+
         if train_stats_path.exists():
             try:
                 if output_type == "regression":
                     tv_plots = build_regression_train_val_plots(str(train_stats_path))
+                    tab2_content = append_plot_blocks(tab2_content, tv_plots)
                 else:
                     tv_plots = build_train_validation_plots(str(train_stats_path))
-                for plot in tv_plots:
-                    tab2_content += (
-                        f"<h2 style='text-align: center;'>{plot['title']}</h2>"
-                        f"<div class='plotly-center'>{plot['html']}</div>"
-                    )
-                if tv_plots:
-                    logger.info(f"Generated {len(tv_plots)} train/val diagnostic plots")
+                    # Add threshold plot first, then other train/val plots
+                    added_plot_count = len(tv_plots)
+                    if threshold_plot:
+                        tab2_content = append_plot_blocks(tab2_content, [threshold_plot])
+                        added_plot_count += 1
+                        # Only append once; avoid duplicates if added elsewhere
+                        threshold_plot = None
+                    tab2_content = append_plot_blocks(tab2_content, tv_plots)
+                    if added_plot_count:
+                        logger.info(
+                            f"Added {added_plot_count} train/val diagnostic plots"
+                        )
             except Exception as e:
                 logger.warning(f"Could not generate train/val plots: {e}")
 
+        # Only include training PNGs for regression; classification is handled by filtered Plotly plots
+        if output_type == "regression":
+            tab2_content += render_img_section(
+                "Training and Validation Visualizations",
+                train_viz_dir,
+                output_type,
+                exclude_names={
+                    "compare_classifiers_performance_from_prob.png",
+                    "roc_curves_from_prediction_statistics.png",
+                    "precision_recall_curves_from_prediction_statistics.png",
+                    "precision_recall_curve.png",
+                },
+            )
+
+        # Validation diagnostics (prediction confidence distribution) from predictions.csv, using split=1
+        if output_type in ("binary", "category") and predictions_csv_path.exists():
+            try:
+                val_diag_plots = build_prediction_diagnostics(
+                    str(predictions_csv_path),
+                    label_data_path=str(config.get("label_column_data_path"))
+                    if config.get("label_column_data_path")
+                    else None,
+                    split_value=1,
+                )
+                val_conf_plots = [
+                    p for p in val_diag_plots if "Prediction Confidence Distribution" in p.get("title", "")
+                ]
+                tab2_content = append_plot_blocks(
+                    tab2_content, val_conf_plots, " (Validation)"
+                )
+            except Exception as e:
+                logger.warning(f"Could not generate validation diagnostics: {e}")
+
         # --- Predictions vs Ground Truth table (REGRESSION ONLY) ---
         preds_section = ""
         parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
@@ -1077,18 +1430,12 @@
                 logger.warning(f"Could not build Predictions vs GT table: {e}")
 
         tab3_content = test_metrics_html + preds_section
-        test_plotly_added = False
 
         if output_type == "regression" and train_stats_path.exists():
             try:
                 test_plots = build_regression_test_plots(str(train_stats_path))
-                for plot in test_plots:
-                    tab3_content += (
-                        f"<h2 style='text-align: center;'>{plot['title']}</h2>"
-                        f"<div class='plotly-center'>{plot['html']}</div>"
-                    )
+                tab3_content = append_plot_blocks(tab3_content, test_plots)
                 if test_plots:
-                    test_plotly_added = True
                     logger.info(f"Generated {len(test_plots)} regression test plots")
             except Exception as e:
                 logger.warning(f"Could not generate regression test plots: {e}")
@@ -1104,46 +1451,42 @@
                     train_set_metadata_path=str(train_set_metadata_path)
                     if train_set_metadata_path.exists()
                     else None,
+                    threshold=threshold_value,
                 )
-                for plot in interactive_plots:
-                    tab3_content += (
-                        f"<h2 style='text-align: center;'>{plot['title']}</h2>"
-                        f"<div class='plotly-center'>{plot['html']}</div>"
-                    )
+                tab3_content = append_plot_blocks(tab3_content, interactive_plots)
                 if interactive_plots:
-                    test_plotly_added = True
-                logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots")
+                    logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots")
             except Exception as e:
                 logger.warning(f"Could not generate Plotly plots: {e}")
 
-            # Add prediction diagnostics from predictions.csv
-            predictions_csv_path = exp_dir / "predictions.csv"
-            try:
-                diag_plots = build_prediction_diagnostics(
-                    str(predictions_csv_path),
-                    label_data_path=str(config.get("label_column_data_path"))
-                    if config.get("label_column_data_path")
-                    else None,
-                    threshold=config.get("threshold"),
-                )
-                for plot in diag_plots:
-                    tab3_content += (
-                        f"<h2 style='text-align: center;'>{plot['title']}</h2>"
-                        f"<div class='plotly-center'>{plot['html']}</div>"
+            # Multi-class transparency plots from test stats (replace ROC/PR for multi-class)
+            if output_type == "category" and test_stats_path.exists():
+                try:
+                    multi_curves = build_multiclass_metric_plots(str(test_stats_path))
+                    tab3_content = append_plot_blocks(tab3_content, multi_curves)
+                    if multi_curves:
+                        logger.info("Added multi-class per-class metric plots to test tab")
+                except Exception as e:
+                    logger.warning(f"Could not generate multi-class metric plots: {e}")
+
+            # Test diagnostics (confidence histogram) from predictions.csv, using split=2
+            if predictions_csv_path.exists():
+                try:
+                    test_diag_plots = build_prediction_diagnostics(
+                        str(predictions_csv_path),
+                        label_data_path=str(config.get("label_column_data_path"))
+                        if config.get("label_column_data_path")
+                        else None,
+                        split_value=2,
                     )
-                if diag_plots:
-                    test_plotly_added = True
-                    logger.info(f"Generated {len(diag_plots)} prediction diagnostic plots")
-            except Exception as e:
-                logger.warning(f"Could not generate prediction diagnostics: {e}")
-
-        # Fallback: include static PNGs if no interactive plots were added
-        if not test_plotly_added:
-            tab3_content += render_img_section(
-                "Test Visualizations (PNG fallback)",
-                test_viz_dir,
-                output_type,
-            )
+                    test_conf_plots = [
+                        p for p in test_diag_plots if "Prediction Confidence Distribution" in p.get("title", "")
+                    ]
+                    if test_conf_plots:
+                        tab3_content = append_plot_blocks(tab3_content, test_conf_plots)
+                        logger.info("Added test prediction confidence plot")
+                except Exception as e:
+                    logger.warning(f"Could not generate test diagnostics: {e}")
 
         # Add static TEST PNGs (with default dedupe/exclusions)
         tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)