diff ludwig_backend.py @ 12:bcfa2e234a80 draft

planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
author goeckslab
date Fri, 21 Nov 2025 15:58:13 +0000
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/ludwig_backend.py	Fri Nov 21 15:58:13 2025 +0000
@@ -0,0 +1,893 @@
+import json
+import logging
+from pathlib import Path
+from typing import Any, Dict, Optional, Protocol, Tuple
+
+import pandas as pd
+import pandas.api.types as ptypes
+import yaml
+from constants import (
+    IMAGE_PATH_COLUMN_NAME,
+    LABEL_COLUMN_NAME,
+    MODEL_ENCODER_TEMPLATES,
+    SPLIT_COLUMN_NAME,
+)
+from html_structure import (
+    build_tabbed_html,
+    encode_image_to_base64,
+    format_config_table_html,
+    format_stats_table_html,
+    format_test_merged_stats_table_html,
+    format_train_val_stats_table_html,
+    get_html_closing,
+    get_html_template,
+    get_metrics_help_modal,
+)
+from ludwig.globals import (
+    DESCRIPTION_FILE_NAME,
+    PREDICTIONS_PARQUET_FILE_NAME,
+    TEST_STATISTICS_FILE_NAME,
+    TRAIN_SET_METADATA_FILE_NAME,
+)
+from ludwig.utils.data_utils import get_split_path
+from metaformer_setup import get_visualizations_registry, META_DEFAULT_CFGS
+from plotly_plots import build_classification_plots
+from utils import detect_output_type, extract_metrics_from_json
+
+logger = logging.getLogger("ImageLearner")
+
+
+class Backend(Protocol):
+    """Interface for a machine learning backend."""
+
+    def prepare_config(
+        self,
+        config_params: Dict[str, Any],
+        split_config: Dict[str, Any],
+    ) -> str:
+        ...
+
+    def run_experiment(
+        self,
+        dataset_path: Path,
+        config_path: Path,
+        output_dir: Path,
+        random_seed: int,
+    ) -> None:
+        ...
+
+    def generate_plots(self, output_dir: Path) -> None:
+        ...
+
+    def generate_html_report(
+        self,
+        title: str,
+        output_dir: str,
+        config: Dict[str, Any],
+        split_info: str,
+    ) -> Path:
+        ...
+
+
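+# Illustrative lifecycle of a backend, assuming an orchestrating driver script
+# (hypothetical names, shown for context only):
+#     backend = LudwigDirectBackend()
+#     yaml_str = backend.prepare_config(config_params, split_config)
+#     config_path.write_text(yaml_str)
+#     backend.run_experiment(dataset_path, config_path, output_dir, random_seed)
+#     backend.generate_plots(output_dir)
+#     report = backend.generate_html_report(title, str(output_dir), config_params, split_info)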
+class LudwigDirectBackend:
+    """Backend for running Ludwig experiments directly via the internal experiment_cli function."""
+
+    def _detect_image_dimensions(self, image_zip_path: str) -> Tuple[int, int]:
+        """Detect image dimensions from the first image in the dataset."""
+        try:
+            import zipfile
+            from PIL import Image
+            import io
+
+            # Check if image_zip is provided
+            if not image_zip_path:
+                logger.warning("No image zip provided, using default 224x224")
+                return 224, 224
+
+            # Extract first image to detect dimensions
+            with zipfile.ZipFile(image_zip_path, 'r') as z:
+                image_files = [f for f in z.namelist() if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
+                if not image_files:
+                    logger.warning("No image files found in zip, using default 224x224")
+                    return 224, 224
+
+                # Check first image
+                with z.open(image_files[0]) as f:
+                    img = Image.open(io.BytesIO(f.read()))
+                    width, height = img.size
+                    logger.info(f"Detected image dimensions: {width}x{height}")
+                    return height, width  # Return as (height, width) to match encoder config
+
+        except Exception as e:
+            logger.warning(f"Error detecting image dimensions: {e}, using default 224x224")
+            return 224, 224
+
+    def prepare_config(
+        self,
+        config_params: Dict[str, Any],
+        split_config: Dict[str, Any],
+    ) -> str:
+        logger.info("LudwigDirectBackend: Preparing YAML configuration.")
+
+        model_name = config_params.get("model_name", "resnet18")
+        use_pretrained = config_params.get("use_pretrained", False)
+        fine_tune = config_params.get("fine_tune", False)
+        if use_pretrained:
+            trainable = bool(fine_tune)
+        else:
+            trainable = True
+        epochs = config_params.get("epochs", 10)
+        batch_size = config_params.get("batch_size")
+        num_processes = config_params.get("preprocessing_num_processes", 1)
+        early_stop = config_params.get("early_stop", None)
+        learning_rate = config_params.get("learning_rate")
+        learning_rate = "auto" if learning_rate is None else float(learning_rate)
+        raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name)
+
+        # --- MetaFormer detection and config logic ---
+        def _is_metaformer(name: str) -> bool:
+            return isinstance(name, str) and name.startswith(
+                (
+                    "identityformer_",
+                    "randformer_",
+                    "poolformerv2_",
+                    "convformer_",
+                    "caformer_",
+                )
+            )
+
+        # Check if this is a MetaFormer model (either direct name or in custom_model)
+        is_metaformer = (
+            _is_metaformer(model_name)
+            or (isinstance(raw_encoder, dict) and "custom_model" in raw_encoder and _is_metaformer(raw_encoder["custom_model"]))
+        )
+
+        metaformer_resize: Optional[Tuple[int, int]] = None
+        metaformer_channels = 3
+
+        if is_metaformer:
+            # Handle MetaFormer models
+            custom_model = None
+            if isinstance(raw_encoder, dict) and "custom_model" in raw_encoder:
+                custom_model = raw_encoder["custom_model"]
+            else:
+                custom_model = model_name
+
+            logger.info(f"DETECTED MetaFormer model: {custom_model}")
+            cfg_channels, cfg_height, cfg_width = 3, 224, 224
+            if META_DEFAULT_CFGS:
+                model_cfg = META_DEFAULT_CFGS.get(custom_model, {})
+                input_size = model_cfg.get("input_size")
+                if isinstance(input_size, (list, tuple)) and len(input_size) == 3:
+                    cfg_channels, cfg_height, cfg_width = (
+                        int(input_size[0]),
+                        int(input_size[1]),
+                        int(input_size[2]),
+                    )
+
+            target_height, target_width = cfg_height, cfg_width
+            resize_value = config_params.get("image_resize")
+            if resize_value and resize_value != "original":
+                try:
+                    dimensions = resize_value.split("x")
+                    if len(dimensions) == 2:
+                        target_height, target_width = int(dimensions[0]), int(dimensions[1])
+                        if target_height <= 0 or target_width <= 0:
+                            raise ValueError(
+                                f"Image resize must be positive integers, received {resize_value}."
+                            )
+                        logger.info(f"MetaFormer explicit resize: {target_height}x{target_width}")
+                    else:
+                        raise ValueError(resize_value)
+                except (ValueError, IndexError):
+                    logger.warning(
+                        "Invalid image resize format '%s'; falling back to model default %sx%s",
+                        resize_value,
+                        cfg_height,
+                        cfg_width,
+                    )
+                    target_height, target_width = cfg_height, cfg_width
+            else:
+                image_zip_path = config_params.get("image_zip", "")
+                detected_height, detected_width = self._detect_image_dimensions(image_zip_path)
+                if use_pretrained:
+                    if (detected_height, detected_width) != (cfg_height, cfg_width):
+                        logger.info(
+                            "MetaFormer pretrained weights expect %sx%s; resizing from detected %sx%s",
+                            cfg_height,
+                            cfg_width,
+                            detected_height,
+                            detected_width,
+                        )
+                else:
+                    target_height, target_width = detected_height, detected_width
+                if target_height <= 0 or target_width <= 0:
+                    raise ValueError(
+                        f"Invalid detected image dimensions for MetaFormer: {target_height}x{target_width}."
+                    )
+
+            metaformer_channels = cfg_channels
+            metaformer_resize = (target_height, target_width)
+
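+            # NOTE: the encoder is declared with Ludwig's "stacked_cnn" type as a shim; the
+            # "custom_model" key is presumably consumed by the MetaFormer integration in
+            # metaformer_setup to swap in the requested backbone.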
+            encoder_config = {
+                "type": "stacked_cnn",
+                "height": target_height,
+                "width": target_width,
+                "num_channels": metaformer_channels,
+                "output_size": 128,
+                "use_pretrained": use_pretrained,
+                "trainable": trainable,
+                "custom_model": custom_model,
+            }
+
+        elif isinstance(raw_encoder, dict):
+            # Standard encoders (e.g. ResNet) do not accept height/width on the encoder itself;
+            # any requested resize is handled at the preprocessing level by Ludwig.
+            if config_params.get("image_resize") and config_params["image_resize"] != "original":
+                logger.info(
+                    "Resize requested: %s for standard encoder; handled at preprocessing level.",
+                    config_params["image_resize"],
+                )
+
+            encoder_config = {
+                **raw_encoder,
+                "use_pretrained": use_pretrained,
+                "trainable": trainable,
+            }
+        else:
+            encoder_config = {"type": raw_encoder}
+
+        batch_size_cfg = batch_size or "auto"
+
+        label_column_path = config_params.get("label_column_data_path")
+        label_series = None
+        label_metadata_hint = config_params.get("label_metadata") or {}
+        output_type_hint = config_params.get("output_type_hint")
+        num_unique_labels = int(label_metadata_hint.get("num_unique", 2))
+        numeric_binary_labels = bool(label_metadata_hint.get("is_numeric_binary", False))
+        likely_regression = bool(label_metadata_hint.get("likely_regression", False))
+        if label_column_path is not None and Path(label_column_path).exists():
+            try:
+                label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME]
+                non_na = label_series.dropna()
+                if not non_na.empty:
+                    num_unique_labels = non_na.nunique()
+                is_numeric = ptypes.is_numeric_dtype(label_series.dtype)
+                numeric_binary_labels = is_numeric and num_unique_labels == 2
+                likely_regression = (
+                    is_numeric and not numeric_binary_labels and num_unique_labels > 10
+                )
+                if numeric_binary_labels:
+                    logger.info(
+                        "Detected numeric binary labels in '%s'; configuring Ludwig for binary classification.",
+                        LABEL_COLUMN_NAME,
+                    )
+            except Exception as e:
+                logger.warning(f"Could not read label column for task detection: {e}")
+
+        if output_type_hint == "binary":
+            num_unique_labels = 2
+            numeric_binary_labels = numeric_binary_labels or bool(
+                label_metadata_hint.get("is_numeric", False)
+            )
+
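+        # Task-type heuristic: numeric labels with exactly two unique values -> binary
+        # classification; numeric labels with more than ten unique values -> regression;
+        # everything else -> (multi-class) classification.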
+        if numeric_binary_labels:
+            task_type = "classification"
+        elif likely_regression:
+            task_type = "regression"
+        else:
+            task_type = "classification"
+
+        if task_type == "regression" and numeric_binary_labels:
+            logger.warning(
+                "Numeric binary labels detected but regression task chosen; forcing classification to avoid invalid Ludwig config."
+            )
+            task_type = "classification"
+
+        config_params["task_type"] = task_type
+
+        image_feat: Dict[str, Any] = {
+            "name": IMAGE_PATH_COLUMN_NAME,
+            "type": "image",
+        }
+        # Preprocessing dimensions must be set before the encoder config for MetaFormer models.
+        if is_metaformer:
+            if metaformer_resize is None:
+                metaformer_resize = (224, 224)
+            height, width = metaformer_resize
+
+            if "preprocessing" not in image_feat:
+                image_feat["preprocessing"] = {}
+            image_feat["preprocessing"]["height"] = height
+            image_feat["preprocessing"]["width"] = width
+            # Use infer_image_dimensions=True so Ludwig can read images for validation,
+            # but cap the inferred dimensions so the output size stays under our control.
+            image_feat["preprocessing"]["infer_image_dimensions"] = True
+            image_feat["preprocessing"]["infer_image_max_height"] = height
+            image_feat["preprocessing"]["infer_image_max_width"] = width
+            image_feat["preprocessing"]["num_channels"] = metaformer_channels
+            image_feat["preprocessing"]["resize_method"] = "interpolate"  # interpolation for better quality
+            image_feat["preprocessing"]["standardize_image"] = "imagenet1k"  # ImageNet standardization
+            # Allow non-square inputs so the requested height/width are used as-is.
+            image_feat["preprocessing"]["requires_equal_dimensions"] = False
+            logger.info(
+                "Set preprocessing dimensions for MetaFormer: %sx%s "
+                "(infer_image_dimensions=True with max dimensions for validation)",
+                height,
+                width,
+            )
+        # Now set the encoder configuration
+        image_feat["encoder"] = encoder_config
+
+        if config_params.get("augmentation") is not None:
+            image_feat["augmentation"] = config_params["augmentation"]
+
+        # Add resize configuration for standard encoders (ResNet, etc.); MetaFormer models are
+        # handled above, so both paths honor the user's resize choice.
+        if (
+            not is_metaformer
+            and config_params.get("image_resize")
+            and config_params["image_resize"] != "original"
+        ):
+            try:
+                dimensions = config_params["image_resize"].split("x")
+                if len(dimensions) == 2:
+                    height, width = int(dimensions[0]), int(dimensions[1])
+                    if height <= 0 or width <= 0:
+                        raise ValueError(
+                            f"Image resize must be positive integers, received {config_params['image_resize']}."
+                        )
+
+                    # Add resize to preprocessing for standard encoders
+                    if "preprocessing" not in image_feat:
+                        image_feat["preprocessing"] = {}
+                    image_feat["preprocessing"]["height"] = height
+                    image_feat["preprocessing"]["width"] = width
+                    # Use infer_image_dimensions=True to allow Ludwig to read images for validation
+                    # but set explicit max dimensions to control the output size
+                    image_feat["preprocessing"]["infer_image_dimensions"] = True
+                    image_feat["preprocessing"]["infer_image_max_height"] = height
+                    image_feat["preprocessing"]["infer_image_max_width"] = width
+                    logger.info(f"Added resize preprocessing: {height}x{width} for standard encoder with infer_image_dimensions=True and max dimensions")
+            except (ValueError, IndexError):
+                logger.warning(f"Invalid image resize format: {config_params['image_resize']}, skipping resize preprocessing")
+        if task_type == "regression":
+            output_feat = {
+                "name": LABEL_COLUMN_NAME,
+                "type": "number",
+                "decoder": {"type": "regressor"},
+                "loss": {"type": "mean_squared_error"},
+            }
+            val_metric = config_params.get("validation_metric", "mean_squared_error")
+
+        else:
+            if num_unique_labels == 2:
+                output_feat = {
+                    "name": LABEL_COLUMN_NAME,
+                    "type": "binary",
+                    "loss": {"type": "binary_weighted_cross_entropy"},
+                }
+                if config_params.get("threshold") is not None:
+                    output_feat["threshold"] = float(config_params["threshold"])
+            else:
+                output_feat = {
+                    "name": LABEL_COLUMN_NAME,
+                    "type": "category",
+                    "loss": {"type": "softmax_cross_entropy"},
+                }
+            val_metric = None
+
+        conf: Dict[str, Any] = {
+            "model_type": "ecd",
+            "input_features": [image_feat],
+            "output_features": [output_feat],
+            "combiner": {"type": "concat"},
+            "trainer": {
+                "epochs": epochs,
+                "early_stop": early_stop,
+                "batch_size": batch_size_cfg,
+                "learning_rate": learning_rate,
+                # only set validation_metric for regression
+                **({"validation_metric": val_metric} if val_metric else {}),
+            },
+            "preprocessing": {
+                "split": split_config,
+                "num_processes": num_processes,
+                "in_memory": False,
+            },
+        }
+
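+        # Illustrative shape of the generated YAML (values depend on config_params; column
+        # names come from the IMAGE_PATH_COLUMN_NAME / LABEL_COLUMN_NAME constants):
+        #
+        #   model_type: ecd
+        #   input_features:
+        #     - name: <IMAGE_PATH_COLUMN_NAME>
+        #       type: image
+        #       encoder: {...}
+        #   output_features:
+        #     - name: <LABEL_COLUMN_NAME>
+        #       type: binary | category | number
+        #   combiner: {type: concat}
+        #   trainer: {epochs: ..., batch_size: ..., learning_rate: ...}
+        #   preprocessing: {split: {...}, num_processes: ..., in_memory: false}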
+        logger.debug("LudwigDirectBackend: Config dict built.")
+        try:
+            yaml_str = yaml.dump(conf, sort_keys=False, indent=2)
+            logger.info("LudwigDirectBackend: YAML config generated.")
+            return yaml_str
+        except Exception:
+            logger.error(
+                "LudwigDirectBackend: Failed to serialize YAML.",
+                exc_info=True,
+            )
+            raise
+
+    def run_experiment(
+        self,
+        dataset_path: Path,
+        config_path: Path,
+        output_dir: Path,
+        random_seed: int = 42,
+    ) -> None:
+        """Invoke Ludwig's internal experiment_cli function to run the experiment."""
+        logger.info("LudwigDirectBackend: Starting experiment execution.")
+
+        try:
+            from ludwig.experiment import experiment_cli
+        except ImportError as e:
+            logger.error(
+                "LudwigDirectBackend: Could not import experiment_cli.",
+                exc_info=True,
+            )
+            raise RuntimeError("Ludwig import failed.") from e
+
+        output_dir.mkdir(parents=True, exist_ok=True)
+
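+        # skip_preprocessing=True assumes the dataset handed to Ludwig is already in the form
+        # the experiment expects (prepared and split upstream by the caller).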
+        try:
+            experiment_cli(
+                dataset=str(dataset_path),
+                config=str(config_path),
+                output_directory=str(output_dir),
+                random_seed=random_seed,
+                skip_preprocessing=True,
+            )
+            logger.info(
+                f"LudwigDirectBackend: Experiment completed. Results in {output_dir}"
+            )
+        except TypeError as e:
+            logger.error(
+                "LudwigDirectBackend: Argument mismatch in experiment_cli call.",
+                exc_info=True,
+            )
+            raise RuntimeError("Ludwig argument error.") from e
+        except Exception:
+            logger.error(
+                "LudwigDirectBackend: Experiment execution error.",
+                exc_info=True,
+            )
+            raise
+
+    def get_training_process(self, output_dir) -> Optional[Dict[str, Any]]:
+        """Retrieve the learning rate used in the most recent Ludwig run."""
+        output_dir = Path(output_dir)
+        exp_dirs = sorted(
+            output_dir.glob("experiment_run*"),
+            key=lambda p: p.stat().st_mtime,
+        )
+
+        if not exp_dirs:
+            logger.warning(f"No experiment run directories found in {output_dir}")
+            return None
+
+        progress_file = exp_dirs[-1] / "model" / "training_progress.json"
+        if not progress_file.exists():
+            logger.warning(f"No training_progress.json found in {progress_file}")
+            return None
+
+        try:
+            with progress_file.open("r", encoding="utf-8") as f:
+                data = json.load(f)
+            return {
+                "learning_rate": data.get("learning_rate"),
+                "batch_size": data.get("batch_size"),
+                "epoch": data.get("epoch"),
+            }
+        except Exception as e:
+            logger.warning(f"Failed to read training progress info: {e}")
+            return {}
+
+    def convert_parquet_to_csv(self, output_dir: Path):
+        """Convert the predictions Parquet file to CSV."""
+        output_dir = Path(output_dir)
+        exp_dirs = sorted(
+            output_dir.glob("experiment_run*"),
+            key=lambda p: p.stat().st_mtime,
+        )
+        if not exp_dirs:
+            logger.warning(f"No experiment run dirs found in {output_dir}")
+            return
+        exp_dir = exp_dirs[-1]
+        parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
+        csv_path = exp_dir / "predictions.csv"
+
+        # Check if parquet file exists before trying to convert
+        if not parquet_path.exists():
+            logger.info(f"Predictions parquet file not found at {parquet_path}, skipping conversion")
+            return
+
+        try:
+            df = pd.read_parquet(parquet_path)
+            df.to_csv(csv_path, index=False)
+            logger.info(f"Converted Parquet to CSV: {csv_path}")
+        except Exception as e:
+            logger.error(f"Error converting Parquet to CSV: {e}")
+
+    def generate_plots(self, output_dir: Path) -> None:
+        """Generate all registered Ludwig visualizations for the latest experiment run."""
+        logger.info("Generating all Ludwig visualizations…")
+
+        test_plots = {
+            "compare_performance",
+            "compare_classifiers_performance_from_prob",
+            "compare_classifiers_performance_from_pred",
+            "compare_classifiers_performance_changing_k",
+            "compare_classifiers_multiclass_multimetric",
+            "compare_classifiers_predictions",
+            "confidence_thresholding_2thresholds_2d",
+            "confidence_thresholding_2thresholds_3d",
+            "confidence_thresholding",
+            "confidence_thresholding_data_vs_acc",
+            "binary_threshold_vs_metric",
+            "roc_curves",
+            "roc_curves_from_test_statistics",
+            "calibration_1_vs_all",
+            "calibration_multiclass",
+            "confusion_matrix",
+            "frequency_vs_f1",
+        }
+        train_plots = {
+            "learning_curves",
+            "compare_classifiers_performance_subset",
+        }
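+        # The plot names above mirror entries in Ludwig's visualization registry and determine
+        # whether a figure is written to the train/ or test/ subfolder below.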
+
+        output_dir = Path(output_dir)
+        exp_dirs = sorted(
+            output_dir.glob("experiment_run*"),
+            key=lambda p: p.stat().st_mtime,
+        )
+        if not exp_dirs:
+            logger.warning(f"No experiment run dirs found in {output_dir}")
+            return
+        exp_dir = exp_dirs[-1]
+
+        viz_dir = exp_dir / "visualizations"
+        viz_dir.mkdir(exist_ok=True)
+        train_viz = viz_dir / "train"
+        test_viz = viz_dir / "test"
+        train_viz.mkdir(parents=True, exist_ok=True)
+        test_viz.mkdir(parents=True, exist_ok=True)
+
+        def _check(p: Path) -> Optional[str]:
+            return str(p) if p.exists() else None
+
+        training_stats = _check(exp_dir / "training_statistics.json")
+        test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME)
+        probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME)
+        gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME)
+
+        dataset_path = None
+        split_file = None
+        desc = exp_dir / DESCRIPTION_FILE_NAME
+        if desc.exists():
+            with open(desc, "r") as f:
+                cfg = json.load(f)
+            # Guard against an empty dataset path (Path("") would resolve to the CWD).
+            dataset_value = cfg.get("dataset") or ""
+            if dataset_value:
+                dataset_path = _check(Path(dataset_value))
+                split_file = _check(Path(get_split_path(dataset_value)))
+
+        output_feature = ""
+        if desc.exists():
+            try:
+                output_feature = cfg["config"]["output_features"][0]["name"]
+            except Exception:
+                pass
+        if not output_feature and test_stats:
+            with open(test_stats, "r") as f:
+                stats = json.load(f)
+            output_feature = next(iter(stats.keys()), "")
+
+        viz_registry = get_visualizations_registry()
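+        # Not every visualization applies to every run; failures are caught below and the
+        # corresponding plot is simply skipped.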
+        for viz_name, viz_func in viz_registry.items():
+            if viz_name in train_plots:
+                viz_dir_plot = train_viz
+            elif viz_name in test_plots:
+                viz_dir_plot = test_viz
+            else:
+                continue
+
+            try:
+                viz_func(
+                    training_statistics=[training_stats] if training_stats else [],
+                    test_statistics=[test_stats] if test_stats else [],
+                    probabilities=[probs_path] if probs_path else [],
+                    output_feature_name=output_feature,
+                    ground_truth_split=2,
+                    top_n_classes=[0],
+                    top_k=3,
+                    ground_truth_metadata=gt_metadata,
+                    ground_truth=dataset_path,
+                    split_file=split_file,
+                    output_directory=str(viz_dir_plot),
+                    normalize=False,
+                    file_format="png",
+                )
+                logger.info(f"✔ Generated {viz_name}")
+            except Exception as e:
+                logger.warning(f"✘ Skipped {viz_name}: {e}")
+
+        logger.info(f"All visualizations written to {viz_dir}")
+
+    def generate_html_report(
+        self,
+        title: str,
+        output_dir: str,
+        config: dict,
+        split_info: str,
+    ) -> Path:
+        """Assemble an HTML report from visualizations under train_val/ and test/ folders."""
+        cwd = Path.cwd()
+        report_name = title.lower().replace(" ", "_") + "_report.html"
+        report_path = cwd / report_name
+        output_dir = Path(output_dir)
+        output_type = None
+
+        exp_dirs = sorted(
+            output_dir.glob("experiment_run*"),
+            key=lambda p: p.stat().st_mtime,
+        )
+        if not exp_dirs:
+            raise RuntimeError(f"No 'experiment_run*' dirs found in {output_dir}")
+        exp_dir = exp_dirs[-1]
+
+        base_viz_dir = exp_dir / "visualizations"
+        train_viz_dir = base_viz_dir / "train"
+        test_viz_dir = base_viz_dir / "test"
+
+        html = get_html_template()
+
+        # Extra CSS & JS: center Plotly and enable CSV download for predictions table
+        html += """
+<style>
+  /* Center Plotly figures (both wrapper and native classes) */
+  .plotly-center { display: flex; justify-content: center; }
+  .plotly-center .plotly-graph-div, .plotly-center .js-plotly-plot { margin: 0 auto !important; }
+  .js-plotly-plot, .plotly-graph-div { margin-left: auto !important; margin-right: auto !important; }
+
+  /* Download button for predictions table */
+  .download-btn {
+    padding: 8px 12px;
+    border: 1px solid #4CAF50;
+    background: #4CAF50;
+    color: white;
+    border-radius: 6px;
+    cursor: pointer;
+  }
+  .download-btn:hover { filter: brightness(0.95); }
+  .preds-controls {
+    display: flex;
+    justify-content: flex-end;
+    gap: 8px;
+    margin: 8px 0;
+  }
+</style>
+<script>
+  function tableToCSV(table){
+    const rows = Array.from(table.querySelectorAll('tr'));
+    return rows.map(row =>
+      Array.from(row.querySelectorAll('th,td')).map(cell => {
+        let text = cell.innerText.replace(/\\r?\\n|\\r/g,' ').trim();
+        if (text.includes('"') || text.includes(',')) {
+          text = '"' + text.replace(/"/g,'""') + '"';
+        }
+        return text;
+      }).join(',')
+    ).join('\\n');
+  }
+  document.addEventListener('DOMContentLoaded', function(){
+    const btn = document.getElementById('downloadPredsCsv');
+    if(btn){
+      btn.addEventListener('click', function(){
+        const tbl = document.querySelector('.predictions-table');
+        if(!tbl){ alert('Predictions table not found.'); return; }
+        const csv = tableToCSV(tbl);
+        const blob = new Blob([csv], {type: 'text/csv;charset=utf-8;'});
+        const url = URL.createObjectURL(blob);
+        const a = document.createElement('a');
+        a.href = url;
+        a.download = 'ground_truth_vs_predictions.csv';
+        document.body.appendChild(a);
+        a.click();
+        document.body.removeChild(a);
+        URL.revokeObjectURL(url);
+      });
+    }
+  });
+</script>
+"""
+        html += f"<h1>{title}</h1>"
+
+        metrics_html = ""
+        train_val_metrics_html = ""
+        test_metrics_html = ""
+        try:
+            train_stats_path = exp_dir / "training_statistics.json"
+            test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME
+            if train_stats_path.exists() and test_stats_path.exists():
+                with open(train_stats_path) as f:
+                    train_stats = json.load(f)
+                with open(test_stats_path) as f:
+                    test_stats = json.load(f)
+                output_type = detect_output_type(test_stats)
+                metrics_html = format_stats_table_html(train_stats, test_stats, output_type)
+                train_val_metrics_html = format_train_val_stats_table_html(
+                    train_stats, test_stats
+                )
+                test_metrics_html = format_test_merged_stats_table_html(
+                    extract_metrics_from_json(train_stats, test_stats, output_type)[
+                        "test"
+                    ], output_type
+                )
+        except Exception as e:
+            logger.warning(
+                f"Could not load stats for HTML report: {type(e).__name__}: {e}"
+            )
+
+        config_html = ""
+        training_progress = self.get_training_process(output_dir)
+        try:
+            config_html = format_config_table_html(
+                config, split_info, training_progress, output_type
+            )
+        except Exception as e:
+            logger.warning(f"Could not load config for HTML report: {e}")
+
+        # ---------- image rendering with exclusions ----------
+        def render_img_section(
+            title: str,
+            dir_path: Path,
+            output_type: str = None,
+            exclude_names: Optional[set] = None,
+        ) -> str:
+            if not dir_path.exists():
+                return f"<h2>{title}</h2><p><em>Directory not found.</em></p>"
+
+            exclude_names = exclude_names or set()
+
+            imgs = list(dir_path.glob("*.png"))
+
+            # Exclude standard confusion matrices and duplicate entropy variants
+            # (keep only the entropy top-5 confusion matrix).
+            default_exclude = {
+                "confusion_matrix__label_top5.png",
+                "confusion_matrix__label_top10.png",
+                "confusion_matrix__label_top6.png",
+                "confusion_matrix_entropy__label_top10.png",
+                "confusion_matrix_entropy__label_top6.png",
+            }
+            title_is_test = title.lower().startswith("test")
+            if title_is_test and output_type == "binary":
+                default_exclude.update(
+                    {
+                        "confusion_matrix__label_top2.png",
+                        "confusion_matrix_entropy__label_top2.png",
+                        "roc_curves_from_prediction_statistics.png",
+                    }
+                )
+            elif title_is_test and output_type == "category":
+                default_exclude.update(
+                    {
+                        "compare_classifiers_multiclass_multimetric__label_best10.png",
+                        "compare_classifiers_multiclass_multimetric__label_sorted.png",
+                        "compare_classifiers_multiclass_multimetric__label_worst10.png",
+                    }
+                )
+
+            imgs = [
+                img
+                for img in imgs
+                if img.name not in default_exclude
+                and img.name not in exclude_names
+            ]
+
+            if not imgs:
+                return f"<h2>{title}</h2><p><em>No plots found.</em></p>"
+
+            # Sort images by name for consistent ordering (works with string and numeric labels)
+            imgs = sorted(imgs, key=lambda x: x.name)
+
+            html_section = ""
+            custom_titles = {
+                "compare_classifiers_multiclass_multimetric__label_top10": "Metric Comparison by Label",
+                "compare_classifiers_performance_from_prob": "Label Metric Comparison by Probability",
+            }
+            for img in imgs:
+                b64 = encode_image_to_base64(str(img))
+                default_title = img.stem.replace("_", " ").title()
+                img_title = custom_titles.get(img.stem, default_title)
+                html_section += (
+                    f"<h2 style='text-align: center;'>{img_title}</h2>"
+                    f'<div class="plot" style="margin-bottom:20px;text-align:center;">'
+                    f'<img src="data:image/png;base64,{b64}" '
+                    f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />'
+                    f"</div>"
+                )
+            return html_section
+
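+        # Assemble the three report tabs: (1) config and summary metrics, (2) train/validation
+        # diagnostics, (3) test metrics, predictions table, and test plots.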
+        tab1_content = config_html + metrics_html
+
+        tab2_content = train_val_metrics_html + render_img_section(
+            "Training and Validation Visualizations",
+            train_viz_dir,
+            output_type,
+            exclude_names={
+                "compare_classifiers_performance_from_prob.png",
+                "roc_curves_from_prediction_statistics.png",
+                "precision_recall_curves_from_prediction_statistics.png",
+                "precision_recall_curve.png",
+            },
+        )
+
+        # --- Predictions vs Ground Truth table (REGRESSION ONLY) ---
+        preds_section = ""
+        parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
+        if output_type == "regression" and parquet_path.exists():
+            try:
+                # 1) load predictions from Parquet
+                df_preds = pd.read_parquet(parquet_path).reset_index(drop=True)
+                # the prediction column is assumed to be named "prediction"
+                # or to contain that substring:
+                pred_col = next(
+                    (c for c in df_preds.columns if "prediction" in c.lower()),
+                    None,
+                )
+                if pred_col is None:
+                    raise ValueError("No prediction column found in Parquet output")
+                df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"})
+
+                # 2) load ground truth for the test split from prepared CSV
+                df_all = pd.read_csv(config["label_column_data_path"])
+                df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][
+                    LABEL_COLUMN_NAME
+                ].reset_index(drop=True)
+                # 3) concatenate side-by-side
+                df_table = pd.concat([df_gt, df_pred], axis=1)
+                df_table.columns = [LABEL_COLUMN_NAME, "prediction"]
+
+                # 4) render as HTML
+                preds_html = df_table.to_html(index=False, classes="predictions-table")
+                preds_section = (
+                    "<h2 style='text-align: center;'>Ground Truth vs. Predictions</h2>"
+                    "<div class='preds-controls'>"
+                    "<button id='downloadPredsCsv' class='download-btn'>Download CSV</button>"
+                    "</div>"
+                    "<div class='scroll-rows-30' style='overflow-x:auto; overflow-y:auto; max-height:900px; margin-bottom:20px;'>"
+                    + preds_html
+                    + "</div>"
+                )
+            except Exception as e:
+                logger.warning(f"Could not build Predictions vs GT table: {e}")
+
+        tab3_content = test_metrics_html + preds_section
+
+        if output_type in ("binary", "category") and test_stats_path.exists():
+            try:
+                interactive_plots = build_classification_plots(
+                    str(test_stats_path),
+                    str(train_stats_path) if train_stats_path.exists() else None,
+                )
+                for plot in interactive_plots:
+                    tab3_content += (
+                        f"<h2 style='text-align: center;'>{plot['title']}</h2>"
+                        f"<div class='plotly-center'>{plot['html']}</div>"
+                    )
+                logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots")
+            except Exception as e:
+                logger.warning(f"Could not generate Plotly plots: {e}")
+
+        # Add static TEST PNGs (with default dedupe/exclusions)
+        tab3_content += render_img_section(
+            "Test Visualizations", test_viz_dir, output_type
+        )
+
+        tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)
+        modal_html = get_metrics_help_modal()
+        html += tabbed_html + modal_html + get_html_closing()
+
+        try:
+            with open(report_path, "w") as f:
+                f.write(html)
+            logger.info(f"HTML report generated at: {report_path}")
+        except Exception as e:
+            logger.error(f"Failed to write HTML report: {e}")
+            raise
+
+        return report_path