diff image_learner_cli.py @ 12:bcfa2e234a80 draft

planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
author goeckslab
date Fri, 21 Nov 2025 15:58:13 +0000
parents c5150cceab47
children
--- a/image_learner_cli.py	Sat Oct 18 03:17:09 2025 +0000
+++ b/image_learner_cli.py	Fri Nov 21 15:58:13 2025 +0000
@@ -1,45 +1,15 @@
 import argparse
-import json
 import logging
 import os
-import shutil
 import sys
-import tempfile
-import zipfile
 from pathlib import Path
-from typing import Any, Dict, Optional, Protocol, Tuple
 
 import matplotlib
-import numpy as np
-import pandas as pd
-import pandas.api.types as ptypes
-import yaml
-from constants import (
-    IMAGE_PATH_COLUMN_NAME,
-    LABEL_COLUMN_NAME,
-    METRIC_DISPLAY_NAMES,
-    MODEL_ENCODER_TEMPLATES,
-    SPLIT_COLUMN_NAME,
-    TEMP_CONFIG_FILENAME,
-    TEMP_CSV_FILENAME,
-    TEMP_DIR_PREFIX,
-)
-from ludwig.globals import (
-    DESCRIPTION_FILE_NAME,
-    PREDICTIONS_PARQUET_FILE_NAME,
-    TEST_STATISTICS_FILE_NAME,
-    TRAIN_SET_METADATA_FILE_NAME,
-)
-from ludwig.utils.data_utils import get_split_path
-from plotly_plots import build_classification_plots
-from sklearn.model_selection import train_test_split
-from utils import (
-    build_tabbed_html,
-    encode_image_to_base64,
-    get_html_closing,
-    get_html_template,
-    get_metrics_help_modal,
-)
+from constants import MODEL_ENCODER_TEMPLATES
+from image_workflow import ImageLearnerCLI
+from ludwig_backend import LudwigDirectBackend
+from split_data import SplitProbAction
+from utils import argument_checker, parse_learning_rate
 
 # Set matplotlib backend after imports
 matplotlib.use('Agg')
@@ -51,1839 +21,6 @@
 )
 logger = logging.getLogger("ImageLearner")
 
-# Optional MetaFormer configuration registry
-META_DEFAULT_CFGS: Dict[str, Any] = {}
-try:
-    from MetaFormer import default_cfgs as META_DEFAULT_CFGS  # type: ignore[attr-defined]
-except Exception as e:
-    logger.debug("MetaFormer default configs unavailable: %s", e)
-    META_DEFAULT_CFGS = {}
-
-# Try to import Ludwig visualization registry (may fail due to optional dependencies)
-# This must come AFTER logger is defined
-_ludwig_viz_available = False
-get_visualizations_registry = None
-try:
-    from ludwig.visualize import get_visualizations_registry
-    _ludwig_viz_available = True
-    logger.info("Ludwig visualizations available")
-except ImportError as e:
-    logger.warning(f"Ludwig visualizations not available: {e}. Will use fallback plots only.")
-except Exception as e:
-    logger.warning(f"Ludwig visualizations not available due to dependency issues: {e}. Will use fallback plots only.")
-
-# --- MetaFormer patching integration ---
-_metaformer_patch_ok = False
-try:
-    from MetaFormer.metaformer_stacked_cnn import patch_ludwig_stacked_cnn as _mf_patch
-    if _mf_patch():
-        _metaformer_patch_ok = True
-        logger.info("MetaFormer patching applied for Ludwig stacked_cnn encoder.")
-except Exception as e:
-    logger.warning(f"MetaFormer stacked CNN not available: {e}")
-    _metaformer_patch_ok = False
-
-# Note: CAFormer models are now handled through MetaFormer framework
-
-
-def format_config_table_html(
-    config: dict,
-    split_info: Optional[str] = None,
-    training_progress: Optional[dict] = None,
-    output_type: Optional[str] = None,
-) -> str:
-    display_keys = [
-        "task_type",
-        "model_name",
-        "epochs",
-        "batch_size",
-        "fine_tune",
-        "use_pretrained",
-        "learning_rate",
-        "random_seed",
-        "early_stop",
-        "threshold",
-    ]
-
-    rows = []
-
-    for key in display_keys:
-        val = config.get(key, None)
-        if key == "threshold":
-            if output_type != "binary":
-                continue
-            val = val if val is not None else 0.5
-            val_str = f"{val:.2f}"
-            if val == 0.5:
-                val_str += " (default)"
-        else:
-            if key == "task_type":
-                val_str = val.title() if isinstance(val, str) else "N/A"
-            elif key == "batch_size":
-                if val is not None:
-                    val_str = int(val)
-                else:
-                    val = "auto"
-                    val_str = "auto"
-            resolved_val = None
-            if key == "batch_size" and (val is None or val == "auto"):
-                if training_progress:
-                    resolved_val = training_progress.get("batch_size")
-                    val_str = (
-                        "Auto-selected batch size by Ludwig:<br>"
-                        f"<span style='font-size: 0.85em;'>"
-                        f"{resolved_val if resolved_val else val}</span><br>"
-                        "<span style='font-size: 0.85em;'>"
-                        "Based on model architecture and training setup "
-                        "(e.g., fine-tuning).<br>"
-                        "See <a href='https://ludwig.ai/latest/configuration/trainer/"
-                        "#trainer-parameters' target='_blank'>"
-                        "Ludwig Trainer Parameters</a> for details."
-                        "</span>"
-                    )
-                else:
-                    val_str = (
-                        "Auto-selected by Ludwig<br>"
-                        "<span style='font-size: 0.85em;'>"
-                        "Automatically tuned based on architecture and dataset.<br>"
-                        "See <a href='https://ludwig.ai/latest/configuration/trainer/"
-                        "#trainer-parameters' target='_blank'>"
-                        "Ludwig Trainer Parameters</a> for details."
-                        "</span>"
-                    )
-            elif key == "learning_rate":
-                if val is not None and val != "auto":
-                    val_str = f"{val:.6f}"
-                else:
-                    if training_progress:
-                        resolved_val = training_progress.get("learning_rate")
-                        val_str = (
-                            "Auto-selected learning rate by Ludwig:<br>"
-                            f"<span style='font-size: 0.85em;'>"
-                            f"{resolved_val if resolved_val else 'auto'}</span><br>"
-                            "<span style='font-size: 0.85em;'>"
-                            "Based on model architecture and training setup "
-                            "(e.g., fine-tuning).<br>"
-                            "</span>"
-                        )
-                    else:
-                        val_str = (
-                            "Auto-selected by Ludwig<br>"
-                            "<span style='font-size: 0.85em;'>"
-                            "Automatically tuned based on architecture and dataset.<br>"
-                            "See <a href='https://ludwig.ai/latest/configuration/trainer/"
-                            "#trainer-parameters' target='_blank'>"
-                            "Ludwig Trainer Parameters</a> for details."
-                            "</span>"
-                        )
-            elif key == "epochs":
-                if val is None:
-                    val_str = "N/A"
-                else:
-                    if (
-                        training_progress
-                        and "epoch" in training_progress
-                        and val > training_progress["epoch"]
-                    ):
-                        val_str = (
-                            f"Because of early stopping: the training "
-                            f"stopped at epoch {training_progress['epoch']}"
-                        )
-                    else:
-                        val_str = val
-            else:
-                val_str = val if val is not None else "N/A"
-            if val_str == "N/A" and key not in ["task_type"]:
-                continue
-        rows.append(
-            f"<tr>"
-            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left; "
-            f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>"
-            f"{key.replace('_', ' ').title()}</td>"
-            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center; "
-            f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>"
-            f"{val_str}</td>"
-            f"</tr>"
-        )
-
-    aug_cfg = config.get("augmentation")
-    if aug_cfg:
-        types = [str(a.get("type", "")) for a in aug_cfg]
-        aug_val = ", ".join(types)
-        rows.append(
-            f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left; "
-            f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>Augmentation</td>"
-            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center; "
-            f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>{aug_val}</td></tr>"
-        )
-
-    if split_info:
-        rows.append(
-            f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left; "
-            f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>Data Split</td>"
-            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center; "
-            f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>{split_info}</td></tr>"
-        )
-
-    html = f"""
-        <h2 style="text-align: center;">Model and Training Summary</h2>
-        <div style="display: flex; justify-content: center;">
-          <table style="border-collapse: collapse; width: 100%; table-layout: fixed;">
-            <thead><tr>
-              <th style="padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Parameter</th>
-              <th style="padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Value</th>
-            </tr></thead>
-            <tbody>
-              {"".join(rows)}
-            </tbody>
-          </table>
-        </div><br>
-        <p style="text-align: center; font-size: 0.9em;">
-          Model trained using <a href="https://ludwig.ai/" target="_blank" rel="noopener noreferrer">Ludwig</a>.
-          <a href="https://ludwig.ai/latest/configuration/" target="_blank" rel="noopener noreferrer">
-            Ludwig documentation provides detailed information about default model and training parameters
-          </a>
-        </p><hr>
-        """
-    return html
-
-
-def detect_output_type(test_stats):
-    """Detect whether the output type is 'regression', 'binary', or 'category' based on test statistics."""
-    label_stats = test_stats.get("label", {})
-    if "mean_squared_error" in label_stats:
-        return "regression"
-    per_class = label_stats.get("per_class_stats", {})
-    if len(per_class) == 2:
-        return "binary"
-    return "category"
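For orientation, here is a minimal sketch of the statistics shapes this helper distinguishes; the dictionary layouts are inferred from the function body above, not from Ludwig's documentation.

# Regression: label stats carry a mean_squared_error entry.
assert detect_output_type({"label": {"mean_squared_error": 0.42}}) == "regression"
# Binary: exactly two per-class entries.
assert detect_output_type({"label": {"per_class_stats": {"cat": {}, "dog": {}}}}) == "binary"
# Anything else falls through to multi-class.
assert detect_output_type({"label": {"per_class_stats": {"a": {}, "b": {}, "c": {}}}}) == "category"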
-
-
-def extract_metrics_from_json(
-    train_stats: dict,
-    test_stats: dict,
-    output_type: str,
-) -> dict:
-    """Extracts relevant metrics from training and test statistics based on the output type."""
-    metrics = {"training": {}, "validation": {}, "test": {}}
-
-    def get_last_value(stats, key):
-        val = stats.get(key)
-        if isinstance(val, list) and val:
-            return val[-1]
-        elif isinstance(val, (int, float)):
-            return val
-        return None
-
-    for split in ["training", "validation"]:
-        split_stats = train_stats.get(split, {})
-        if not split_stats:
-            logging.warning(f"No statistics found for {split} split")
-            continue
-        label_stats = split_stats.get("label", {})
-        if not label_stats:
-            logging.warning(f"No label statistics found for {split} split")
-            continue
-        if output_type == "binary":
-            metrics[split] = {
-                "accuracy": get_last_value(label_stats, "accuracy"),
-                "loss": get_last_value(label_stats, "loss"),
-                "precision": get_last_value(label_stats, "precision"),
-                "recall": get_last_value(label_stats, "recall"),
-                "specificity": get_last_value(label_stats, "specificity"),
-                "roc_auc": get_last_value(label_stats, "roc_auc"),
-            }
-        elif output_type == "regression":
-            metrics[split] = {
-                "loss": get_last_value(label_stats, "loss"),
-                "mean_absolute_error": get_last_value(
-                    label_stats, "mean_absolute_error"
-                ),
-                "mean_absolute_percentage_error": get_last_value(
-                    label_stats, "mean_absolute_percentage_error"
-                ),
-                "mean_squared_error": get_last_value(label_stats, "mean_squared_error"),
-                "root_mean_squared_error": get_last_value(
-                    label_stats, "root_mean_squared_error"
-                ),
-                "root_mean_squared_percentage_error": get_last_value(
-                    label_stats, "root_mean_squared_percentage_error"
-                ),
-                "r2": get_last_value(label_stats, "r2"),
-            }
-        else:
-            metrics[split] = {
-                "accuracy": get_last_value(label_stats, "accuracy"),
-                "accuracy_micro": get_last_value(label_stats, "accuracy_micro"),
-                "loss": get_last_value(label_stats, "loss"),
-                "roc_auc": get_last_value(label_stats, "roc_auc"),
-                "hits_at_k": get_last_value(label_stats, "hits_at_k"),
-            }
-
-    # Test metrics: dynamic extraction according to exclusions
-    test_label_stats = test_stats.get("label", {})
-    if not test_label_stats:
-        logging.warning("No label statistics found for test split")
-    else:
-        combined_stats = test_stats.get("combined", {})
-        overall_stats = test_label_stats.get("overall_stats", {})
-
-        # Define exclusions
-        if output_type == "binary":
-            exclude = {"per_class_stats", "precision_recall_curve", "roc_curve"}
-        else:
-            exclude = {"per_class_stats", "confusion_matrix"}
-
-        # 1. Get all scalar test_label_stats not excluded
-        test_metrics = {}
-        for k, v in test_label_stats.items():
-            if k in exclude:
-                continue
-            if k == "overall_stats":
-                continue
-            if isinstance(v, (int, float, str, bool)):
-                test_metrics[k] = v
-
-        # 2. Add overall_stats (flattened)
-        for k, v in overall_stats.items():
-            test_metrics[k] = v
-
-        # 3. Optionally include combined/loss if present and not already
-        if "loss" in combined_stats and "loss" not in test_metrics:
-            test_metrics["loss"] = combined_stats["loss"]
-        metrics["test"] = test_metrics
-    return metrics
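The return value is a three-way mapping keyed by split. A sketch for the binary case, with placeholder numbers and train_stats/test_stats assumed to be loaded from training_statistics.json and the test statistics file:

metrics = extract_metrics_from_json(train_stats, test_stats, "binary")
# {
#   "training":   {"accuracy": 0.93, "loss": 0.21, "precision": ..., "recall": ..., "specificity": ..., "roc_auc": ...},
#   "validation": {... same keys ...},
#   "test":       {... scalar test metrics plus the flattened overall_stats ...},
# }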
-
-
-def generate_table_row(cells, styles):
-    """Helper function to generate an HTML table row."""
-    return (
-        "<tr>"
-        + "".join(f"<td style='{styles}'>{cell}</td>" for cell in cells)
-        + "</tr>"
-    )
-
-
-# -----------------------------------------
-# 2) MODEL PERFORMANCE (Train/Val/Test) TABLE
-# -----------------------------------------
-def format_stats_table_html(train_stats: dict, test_stats: dict, output_type: str) -> str:
-    """Formats a combined HTML table for training, validation, and test metrics."""
-    all_metrics = extract_metrics_from_json(train_stats, test_stats, output_type)
-    rows = []
-    for metric_key in sorted(all_metrics["training"].keys()):
-        if (
-            metric_key in all_metrics["validation"]
-            and metric_key in all_metrics["test"]
-        ):
-            display_name = METRIC_DISPLAY_NAMES.get(
-                metric_key,
-                metric_key.replace("_", " ").title(),
-            )
-            t = all_metrics["training"].get(metric_key)
-            v = all_metrics["validation"].get(metric_key)
-            te = all_metrics["test"].get(metric_key)
-            if all(x is not None for x in [t, v, te]):
-                rows.append([display_name, f"{t:.4f}", f"{v:.4f}", f"{te:.4f}"])
-
-    if not rows:
-        return "<table><tr><td>No metric values found.</td></tr></table>"
-
-    html = (
-        "<h2 style='text-align: center;'>Model Performance Summary</h2>"
-        "<div style='display: flex; justify-content: center;'>"
-        "<table class='performance-summary' style='border-collapse: collapse;'>"
-        "<thead><tr>"
-        "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Metric</th>"
-        "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Train</th>"
-        "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Validation</th>"
-        "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Test</th>"
-        "</tr></thead><tbody>"
-    )
-    for row in rows:
-        html += generate_table_row(
-            row,
-            "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;",
-        )
-    html += "</tbody></table></div><br>"
-    return html
-
-
-# -------------------------------------------
-# 3) TRAIN/VALIDATION PERFORMANCE SUMMARY TABLE
-# -------------------------------------------
-def format_train_val_stats_table_html(train_stats: dict, test_stats: dict) -> str:
-    """Format train/validation metrics into an HTML table."""
-    all_metrics = extract_metrics_from_json(train_stats, test_stats, detect_output_type(test_stats))
-    rows = []
-    for metric_key in sorted(all_metrics["training"].keys()):
-        if metric_key in all_metrics["validation"]:
-            display_name = METRIC_DISPLAY_NAMES.get(
-                metric_key,
-                metric_key.replace("_", " ").title(),
-            )
-            t = all_metrics["training"].get(metric_key)
-            v = all_metrics["validation"].get(metric_key)
-            if t is not None and v is not None:
-                rows.append([display_name, f"{t:.4f}", f"{v:.4f}"])
-
-    if not rows:
-        return "<table><tr><td>No metric values found for Train/Validation.</td></tr></table>"
-
-    html = (
-        "<h2 style='text-align: center;'>Train/Validation Performance Summary</h2>"
-        "<div style='display: flex; justify-content: center;'>"
-        "<table class='performance-summary' style='border-collapse: collapse;'>"
-        "<thead><tr>"
-        "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Metric</th>"
-        "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Train</th>"
-        "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Validation</th>"
-        "</tr></thead><tbody>"
-    )
-    for row in rows:
-        html += generate_table_row(
-            row,
-            "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;",
-        )
-    html += "</tbody></table></div><br>"
-    return html
-
-
-# -----------------------------------------
-# 4) TEST‐ONLY PERFORMANCE SUMMARY TABLE
-# -----------------------------------------
-def format_test_merged_stats_table_html(
-    test_metrics: Dict[str, Any], output_type: str
-) -> str:
-    """Format test metrics into an HTML table."""
-    rows = []
-    for key in sorted(test_metrics.keys()):
-        display_name = METRIC_DISPLAY_NAMES.get(key, key.replace("_", " ").title())
-        value = test_metrics[key]
-        if value is not None:
-            rows.append([display_name, f"{value:.4f}"])
-
-    if not rows:
-        return "<table><tr><td>No test metric values found.</td></tr></table>"
-
-    html = (
-        "<h2 style='text-align: center;'>Test Performance Summary</h2>"
-        "<div style='display: flex; justify-content: center;'>"
-        "<table class='performance-summary' style='border-collapse: collapse;'>"
-        "<thead><tr>"
-        "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Metric</th>"
-        "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Test</th>"
-        "</tr></thead><tbody>"
-    )
-    for row in rows:
-        html += generate_table_row(
-            row,
-            "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;",
-        )
-    html += "</tbody></table></div><br>"
-    return html
-
-
-def split_data_0_2(
-    df: pd.DataFrame,
-    split_column: str,
-    validation_size: float = 0.1,
-    random_state: int = 42,
-    label_column: Optional[str] = None,
-) -> pd.DataFrame:
-    """Given a DataFrame whose split_column only contains {0,2}, re-assign a portion of the 0s to become 1s (validation)."""
-    out = df.copy()
-    out[split_column] = pd.to_numeric(out[split_column], errors="coerce").astype(int)
-
-    idx_train = out.index[out[split_column] == 0].tolist()
-
-    if not idx_train:
-        logger.info("No rows with split=0; nothing to do.")
-        return out
-    stratify_arr = None
-    if label_column and label_column in out.columns:
-        label_counts = out.loc[idx_train, label_column].value_counts()
-        if label_counts.size > 1:
-            # Force stratify even with fewer samples - adjust validation_size if needed
-            min_samples_per_class = label_counts.min()
-            if min_samples_per_class * validation_size < 1:
-                # Adjust validation_size to ensure at least 1 sample per class, but do not exceed original validation_size
-                adjusted_validation_size = min(
-                    validation_size, 1.0 / min_samples_per_class
-                )
-                if adjusted_validation_size != validation_size:
-                    validation_size = adjusted_validation_size
-                    logger.info(
-                        f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation"
-                    )
-            stratify_arr = out.loc[idx_train, label_column]
-            logger.info("Using stratified split for validation set")
-        else:
-            logger.warning("Only one label class found; cannot stratify")
-    if validation_size <= 0:
-        logger.info("validation_size <= 0; keeping all as train.")
-        return out
-    if validation_size >= 1:
-        logger.info("validation_size >= 1; moving all train → validation.")
-        out.loc[idx_train, split_column] = 1
-        return out
-    # Always try stratified split first
-    try:
-        train_idx, val_idx = train_test_split(
-            idx_train,
-            test_size=validation_size,
-            random_state=random_state,
-            stratify=stratify_arr,
-        )
-        logger.info("Successfully applied stratified split")
-    except ValueError as e:
-        logger.warning(f"Stratified split failed ({e}); falling back to random split.")
-        train_idx, val_idx = train_test_split(
-            idx_train,
-            test_size=validation_size,
-            random_state=random_state,
-            stratify=None,
-        )
-    out.loc[train_idx, split_column] = 0
-    out.loc[val_idx, split_column] = 1
-    out[split_column] = out[split_column].astype(int)
-    return out
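A small usage sketch; the frame below is made up, and the column names simply mirror the ones used elsewhere in this file.

import pandas as pd

sample = pd.DataFrame({
    "image_path": [f"img_{i}.png" for i in range(8)],
    "label": ["a", "a", "a", "b", "b", "b", "a", "b"],
    "split": [0, 0, 0, 0, 0, 0, 2, 2],   # only train (0) and test (2) provided
})
resplit = split_data_0_2(sample, "split", validation_size=0.2, label_column="label")
# Roughly 20% of the split=0 rows are reassigned to split=1 (validation),
# stratified by label when the class counts allow it.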
-
-
-def create_stratified_random_split(
-    df: pd.DataFrame,
-    split_column: str,
-    split_probabilities: list = [0.7, 0.1, 0.2],
-    random_state: int = 42,
-    label_column: Optional[str] = None,
-) -> pd.DataFrame:
-    """Create a stratified random split when no split column exists."""
-    out = df.copy()
-
-    # initialize split column
-    out[split_column] = 0
-
-    if not label_column or label_column not in out.columns:
-        logger.warning(
-            "No label column found; using random split without stratification"
-        )
-        # fall back to simple random assignment
-        indices = out.index.tolist()
-        np.random.seed(random_state)
-        np.random.shuffle(indices)
-
-        n_total = len(indices)
-        n_train = int(n_total * split_probabilities[0])
-        n_val = int(n_total * split_probabilities[1])
-
-        out.loc[indices[:n_train], split_column] = 0
-        out.loc[indices[n_train:n_train + n_val], split_column] = 1
-        out.loc[indices[n_train + n_val:], split_column] = 2
-
-        return out.astype({split_column: int})
-
-    # check if stratification is possible
-    label_counts = out[label_column].value_counts()
-    min_samples_per_class = label_counts.min()
-
-    # ensure we have enough samples for stratification:
-    # Each class must have at least as many samples as the number of splits,
-    # so that each split can receive at least one sample per class.
-    min_samples_required = len(split_probabilities)
-    if min_samples_per_class < min_samples_required:
-        logger.warning(
-            f"Insufficient samples per class for stratification (min: {min_samples_per_class}, required: {min_samples_required}); using random split"
-        )
-        # fall back to simple random assignment
-        indices = out.index.tolist()
-        np.random.seed(random_state)
-        np.random.shuffle(indices)
-
-        n_total = len(indices)
-        n_train = int(n_total * split_probabilities[0])
-        n_val = int(n_total * split_probabilities[1])
-
-        out.loc[indices[:n_train], split_column] = 0
-        out.loc[indices[n_train:n_train + n_val], split_column] = 1
-        out.loc[indices[n_train + n_val:], split_column] = 2
-
-        return out.astype({split_column: int})
-
-    logger.info("Using stratified random split for train/validation/test sets")
-
-    # first split: separate test set
-    train_val_idx, test_idx = train_test_split(
-        out.index.tolist(),
-        test_size=split_probabilities[2],
-        random_state=random_state,
-        stratify=out[label_column],
-    )
-
-    # second split: separate training and validation from remaining data
-    val_size_adjusted = split_probabilities[1] / (
-        split_probabilities[0] + split_probabilities[1]
-    )
-    train_idx, val_idx = train_test_split(
-        train_val_idx,
-        test_size=val_size_adjusted,
-        random_state=random_state,
-        stratify=out.loc[train_val_idx, label_column] if label_column and label_column in out.columns else None,
-    )
-
-    # assign split values
-    out.loc[train_idx, split_column] = 0
-    out.loc[val_idx, split_column] = 1
-    out.loc[test_idx, split_column] = 2
-
-    logger.info("Successfully applied stratified random split")
-    logger.info(
-        f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}"
-    )
-    return out.astype({split_column: int})
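The companion helper covers the case where no split column exists at all. Another made-up frame, large enough for stratification to apply:

unsplit = pd.DataFrame({
    "image_path": [f"img_{i}.png" for i in range(20)],
    "label": ["a"] * 10 + ["b"] * 10,
})
assigned = create_stratified_random_split(
    unsplit,
    split_column="split",
    split_probabilities=[0.7, 0.1, 0.2],
    label_column="label",
)
# assigned["split"] now holds 0/1/2 values (roughly 14/2/4 rows here),
# stratified by label because each class has at least as many samples as there are splits.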
-
-
-class Backend(Protocol):
-    """Interface for a machine learning backend."""
-
-    def prepare_config(
-        self,
-        config_params: Dict[str, Any],
-        split_config: Dict[str, Any],
-    ) -> str:
-        ...
-
-    def run_experiment(
-        self,
-        dataset_path: Path,
-        config_path: Path,
-        output_dir: Path,
-        random_seed: int,
-    ) -> None:
-        ...
-
-    def generate_plots(self, output_dir: Path) -> None:
-        ...
-
-    def generate_html_report(
-        self,
-        title: str,
-        output_dir: str,
-        config: Dict[str, Any],
-        split_info: str,
-    ) -> Path:
-        ...
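Because Backend is a typing.Protocol, any object exposing these four methods satisfies it. A do-nothing skeleton, purely illustrative, might look like this:

class DryRunBackend:
    """Satisfies the Backend protocol without training anything (illustrative only)."""

    def prepare_config(self, config_params: Dict[str, Any], split_config: Dict[str, Any]) -> str:
        return "model_type: ecd\n"

    def run_experiment(self, dataset_path: Path, config_path: Path,
                       output_dir: Path, random_seed: int) -> None:
        logger.info("Dry run: would train on %s with config %s", dataset_path, config_path)

    def generate_plots(self, output_dir: Path) -> None:
        pass

    def generate_html_report(self, title: str, output_dir: str,
                             config: Dict[str, Any], split_info: str) -> Path:
        return Path(output_dir) / "report.html"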
-
-
-class LudwigDirectBackend:
-    """Backend for running Ludwig experiments directly via the internal experiment_cli function."""
-
-    def _detect_image_dimensions(self, image_zip_path: str) -> Tuple[int, int]:
-        """Detect image dimensions from the first image in the dataset."""
-        try:
-            import zipfile
-            from PIL import Image
-            import io
-
-            # Check if image_zip is provided
-            if not image_zip_path:
-                logger.warning("No image zip provided, using default 224x224")
-                return 224, 224
-
-            # Extract first image to detect dimensions
-            with zipfile.ZipFile(image_zip_path, 'r') as z:
-                image_files = [f for f in z.namelist() if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
-                if not image_files:
-                    logger.warning("No image files found in zip, using default 224x224")
-                    return 224, 224
-
-                # Check first image
-                with z.open(image_files[0]) as f:
-                    img = Image.open(io.BytesIO(f.read()))
-                    width, height = img.size
-                    logger.info(f"Detected image dimensions: {width}x{height}")
-                    return height, width  # Return as (height, width) to match encoder config
-
-        except Exception as e:
-            logger.warning(f"Error detecting image dimensions: {e}, using default 224x224")
-            return 224, 224
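Note the return order: the helper logs width x height but returns (height, width) so the values can be dropped straight into the encoder config. A usage sketch with a hypothetical archive name:

backend = LudwigDirectBackend()
height, width = backend._detect_image_dimensions("images.zip")  # hypothetical zip path
# Falls back to (224, 224) when no zip is given or it contains no PNG/JPEG files.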
-
-    def prepare_config(
-        self,
-        config_params: Dict[str, Any],
-        split_config: Dict[str, Any],
-    ) -> str:
-        logger.info("LudwigDirectBackend: Preparing YAML configuration.")
-
-        model_name = config_params.get("model_name", "resnet18")
-        use_pretrained = config_params.get("use_pretrained", False)
-        fine_tune = config_params.get("fine_tune", False)
-        if use_pretrained:
-            trainable = bool(fine_tune)
-        else:
-            trainable = True
-        epochs = config_params.get("epochs", 10)
-        batch_size = config_params.get("batch_size")
-        num_processes = config_params.get("preprocessing_num_processes", 1)
-        early_stop = config_params.get("early_stop", None)
-        learning_rate = config_params.get("learning_rate")
-        learning_rate = "auto" if learning_rate is None else float(learning_rate)
-        raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name)
-
-        # --- MetaFormer detection and config logic ---
-        def _is_metaformer(name: str) -> bool:
-            return isinstance(name, str) and name.startswith(
-                (
-                    "identityformer_",
-                    "randformer_",
-                    "poolformerv2_",
-                    "convformer_",
-                    "caformer_",
-                )
-            )
-
-        # Check if this is a MetaFormer model (either direct name or in custom_model)
-        is_metaformer = (
-            _is_metaformer(model_name)
-            or (isinstance(raw_encoder, dict) and "custom_model" in raw_encoder and _is_metaformer(raw_encoder["custom_model"]))
-        )
-
-        metaformer_resize: Optional[Tuple[int, int]] = None
-        metaformer_channels = 3
-
-        if is_metaformer:
-            # Handle MetaFormer models
-            custom_model = None
-            if isinstance(raw_encoder, dict) and "custom_model" in raw_encoder:
-                custom_model = raw_encoder["custom_model"]
-            else:
-                custom_model = model_name
-
-            logger.info(f"DETECTED MetaFormer model: {custom_model}")
-            cfg_channels, cfg_height, cfg_width = 3, 224, 224
-            if META_DEFAULT_CFGS:
-                model_cfg = META_DEFAULT_CFGS.get(custom_model, {})
-                input_size = model_cfg.get("input_size")
-                if isinstance(input_size, (list, tuple)) and len(input_size) == 3:
-                    cfg_channels, cfg_height, cfg_width = (
-                        int(input_size[0]),
-                        int(input_size[1]),
-                        int(input_size[2]),
-                    )
-
-            target_height, target_width = cfg_height, cfg_width
-            resize_value = config_params.get("image_resize")
-            if resize_value and resize_value != "original":
-                try:
-                    dimensions = resize_value.split("x")
-                    if len(dimensions) == 2:
-                        target_height, target_width = int(dimensions[0]), int(dimensions[1])
-                        if target_height <= 0 or target_width <= 0:
-                            raise ValueError(
-                                f"Image resize must be positive integers, received {resize_value}."
-                            )
-                        logger.info(f"MetaFormer explicit resize: {target_height}x{target_width}")
-                    else:
-                        raise ValueError(resize_value)
-                except (ValueError, IndexError):
-                    logger.warning(
-                        "Invalid image resize format '%s'; falling back to model default %sx%s",
-                        resize_value,
-                        cfg_height,
-                        cfg_width,
-                    )
-                    target_height, target_width = cfg_height, cfg_width
-            else:
-                image_zip_path = config_params.get("image_zip", "")
-                detected_height, detected_width = self._detect_image_dimensions(image_zip_path)
-                if use_pretrained:
-                    if (detected_height, detected_width) != (cfg_height, cfg_width):
-                        logger.info(
-                            "MetaFormer pretrained weights expect %sx%s; resizing from detected %sx%s",
-                            cfg_height,
-                            cfg_width,
-                            detected_height,
-                            detected_width,
-                        )
-                else:
-                    target_height, target_width = detected_height, detected_width
-                if target_height <= 0 or target_width <= 0:
-                    raise ValueError(
-                        f"Invalid detected image dimensions for MetaFormer: {target_height}x{target_width}."
-                    )
-
-            metaformer_channels = cfg_channels
-            metaformer_resize = (target_height, target_width)
-
-            encoder_config = {
-                "type": "stacked_cnn",
-                "height": target_height,
-                "width": target_width,
-                "num_channels": metaformer_channels,
-                "output_size": 128,
-                "use_pretrained": use_pretrained,
-                "trainable": trainable,
-                "custom_model": custom_model,
-            }
-
-        elif isinstance(raw_encoder, dict):
-            # Handle image resize for regular encoders
-            # Note: Standard encoders like ResNet don't support height/width parameters
-            # Resize will be handled at the preprocessing level by Ludwig
-            if config_params.get("image_resize") and config_params["image_resize"] != "original":
-                logger.info(f"Resize requested: {config_params['image_resize']} for standard encoder. Resize will be handled at preprocessing level.")
-
-            encoder_config = {
-                **raw_encoder,
-                "use_pretrained": use_pretrained,
-                "trainable": trainable,
-            }
-        else:
-            encoder_config = {"type": raw_encoder}
-
-        batch_size_cfg = batch_size or "auto"
-
-        label_column_path = config_params.get("label_column_data_path")
-        label_series = None
-        if label_column_path is not None and Path(label_column_path).exists():
-            try:
-                label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME]
-            except Exception as e:
-                logger.warning(f"Could not read label column for task detection: {e}")
-
-        if (
-            label_series is not None
-            and ptypes.is_numeric_dtype(label_series.dtype)
-            and label_series.nunique() > 10
-        ):
-            task_type = "regression"
-        else:
-            task_type = "classification"
-
-        config_params["task_type"] = task_type
-
-        image_feat: Dict[str, Any] = {
-            "name": IMAGE_PATH_COLUMN_NAME,
-            "type": "image",
-        }
-        # Set preprocessing dimensions FIRST for MetaFormer models
-        if is_metaformer:
-            if metaformer_resize is None:
-                metaformer_resize = (224, 224)
-            height, width = metaformer_resize
-
-            # CRITICAL: Set preprocessing dimensions FIRST for MetaFormer models
-            # This is essential for MetaFormer models to work properly
-            if "preprocessing" not in image_feat:
-                image_feat["preprocessing"] = {}
-            image_feat["preprocessing"]["height"] = height
-            image_feat["preprocessing"]["width"] = width
-            # Use infer_image_dimensions=True to allow Ludwig to read images for validation
-            # but set explicit max dimensions to control the output size
-            image_feat["preprocessing"]["infer_image_dimensions"] = True
-            image_feat["preprocessing"]["infer_image_max_height"] = height
-            image_feat["preprocessing"]["infer_image_max_width"] = width
-            image_feat["preprocessing"]["num_channels"] = metaformer_channels
-            image_feat["preprocessing"]["resize_method"] = "interpolate"  # Use interpolation for better quality
-            image_feat["preprocessing"]["standardize_image"] = "imagenet1k"  # Use ImageNet standardization
-            # Force Ludwig to respect our dimensions by setting additional parameters
-            image_feat["preprocessing"]["requires_equal_dimensions"] = False
-            logger.info(f"Set preprocessing dimensions for MetaFormer: {height}x{width} (infer_dimensions=True with max dimensions to allow validation)")
-        # Now set the encoder configuration
-        image_feat["encoder"] = encoder_config
-
-        if config_params.get("augmentation") is not None:
-            image_feat["augmentation"] = config_params["augmentation"]
-
-        # Add resize configuration for standard encoders (ResNet, etc.)
-        # FIXED: MetaFormer models now respect user dimensions completely
-        # Previously there was a double resize issue where MetaFormer would force 224x224
-        # Now both MetaFormer and standard encoders respect user's resize choice
-        if (not is_metaformer) and config_params.get("image_resize") and config_params["image_resize"] != "original":
-            try:
-                dimensions = config_params["image_resize"].split("x")
-                if len(dimensions) == 2:
-                    height, width = int(dimensions[0]), int(dimensions[1])
-                    if height <= 0 or width <= 0:
-                        raise ValueError(
-                            f"Image resize must be positive integers, received {config_params['image_resize']}."
-                        )
-
-                    # Add resize to preprocessing for standard encoders
-                    if "preprocessing" not in image_feat:
-                        image_feat["preprocessing"] = {}
-                    image_feat["preprocessing"]["height"] = height
-                    image_feat["preprocessing"]["width"] = width
-                    # Use infer_image_dimensions=True to allow Ludwig to read images for validation
-                    # but set explicit max dimensions to control the output size
-                    image_feat["preprocessing"]["infer_image_dimensions"] = True
-                    image_feat["preprocessing"]["infer_image_max_height"] = height
-                    image_feat["preprocessing"]["infer_image_max_width"] = width
-                    logger.info(f"Added resize preprocessing: {height}x{width} for standard encoder with infer_image_dimensions=True and max dimensions")
-            except (ValueError, IndexError):
-                logger.warning(f"Invalid image resize format: {config_params['image_resize']}, skipping resize preprocessing")
-        if task_type == "regression":
-            output_feat = {
-                "name": LABEL_COLUMN_NAME,
-                "type": "number",
-                "decoder": {"type": "regressor", "input_size": 1},
-                "loss": {"type": "mean_squared_error"},
-                "evaluation": {
-                    "metrics": [
-                        "mean_squared_error",
-                        "mean_absolute_error",
-                        "r2",
-                    ]
-                },
-            }
-            val_metric = config_params.get("validation_metric", "mean_squared_error")
-
-        else:
-            num_unique_labels = (
-                label_series.nunique() if label_series is not None else 2
-            )
-            output_type = "binary" if num_unique_labels == 2 else "category"
-            # Determine if this is regression or classification based on label type
-            is_regression = (
-                label_series is not None
-                and ptypes.is_numeric_dtype(label_series.dtype)
-                and label_series.nunique() > 10
-            )
-
-            if is_regression:
-                output_feat = {
-                    "name": LABEL_COLUMN_NAME,
-                    "type": "number",
-                    "decoder": {"type": "regressor", "input_size": 1},
-                    "loss": {"type": "mean_squared_error"},
-                }
-            else:
-                if num_unique_labels == 2:
-                    output_feat = {
-                        "name": LABEL_COLUMN_NAME,
-                        "type": "binary",
-                        "decoder": {"type": "classifier", "input_size": 1},
-                        "loss": {"type": "softmax_cross_entropy"},
-                    }
-                else:
-                    output_feat = {
-                        "name": LABEL_COLUMN_NAME,
-                        "type": "category",
-                        "decoder": {"type": "classifier", "input_size": num_unique_labels},
-                        "loss": {"type": "softmax_cross_entropy"},
-                    }
-            if output_type == "binary" and config_params.get("threshold") is not None:
-                output_feat["threshold"] = float(config_params["threshold"])
-            val_metric = None
-
-        conf: Dict[str, Any] = {
-            "model_type": "ecd",
-            "input_features": [image_feat],
-            "output_features": [output_feat],
-            "combiner": {"type": "concat"},
-            "trainer": {
-                "epochs": epochs,
-                "early_stop": early_stop,
-                "batch_size": batch_size_cfg,
-                "learning_rate": learning_rate,
-                # only set validation_metric for regression
-                **({"validation_metric": val_metric} if val_metric else {}),
-            },
-            "preprocessing": {
-                "split": split_config,
-                "num_processes": num_processes,
-                "in_memory": False,
-            },
-        }
-
-        logger.debug("LudwigDirectBackend: Config dict built.")
-        try:
-            yaml_str = yaml.dump(conf, sort_keys=False, indent=2)
-            logger.info("LudwigDirectBackend: YAML config generated.")
-            return yaml_str
-        except Exception:
-            logger.error(
-                "LudwigDirectBackend: Failed to serialize YAML.",
-                exc_info=True,
-            )
-            raise
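A sketch of driving prepare_config for a plain ResNet run; every value below is a made-up placeholder, and split_config is passed through verbatim into preprocessing.split.

backend = LudwigDirectBackend()
yaml_str = backend.prepare_config(
    config_params={
        "model_name": "resnet18",
        "use_pretrained": True,
        "fine_tune": False,
        "epochs": 5,
        "batch_size": None,                              # rendered as "auto"
        "learning_rate": 0.001,
        "label_column_data_path": "prepared_data.csv",   # hypothetical prepared CSV
    },
    split_config={"type": "fixed", "column": "split"},
)
# yaml_str contains the input_features / output_features / trainer / preprocessing sections.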
-
-    def run_experiment(
-        self,
-        dataset_path: Path,
-        config_path: Path,
-        output_dir: Path,
-        random_seed: int = 42,
-    ) -> None:
-        """Invoke Ludwig's internal experiment_cli function to run the experiment."""
-        logger.info("LudwigDirectBackend: Starting experiment execution.")
-
-        try:
-            from ludwig.experiment import experiment_cli
-        except ImportError as e:
-            logger.error(
-                "LudwigDirectBackend: Could not import experiment_cli.",
-                exc_info=True,
-            )
-            raise RuntimeError("Ludwig import failed.") from e
-
-        output_dir.mkdir(parents=True, exist_ok=True)
-
-        try:
-            experiment_cli(
-                dataset=str(dataset_path),
-                config=str(config_path),
-                output_directory=str(output_dir),
-                random_seed=random_seed,
-                skip_preprocessing=True,
-            )
-            logger.info(
-                f"LudwigDirectBackend: Experiment completed. Results in {output_dir}"
-            )
-        except TypeError as e:
-            logger.error(
-                "LudwigDirectBackend: Argument mismatch in experiment_cli call.",
-                exc_info=True,
-            )
-            raise RuntimeError("Ludwig argument error.") from e
-        except Exception:
-            logger.error(
-                "LudwigDirectBackend: Experiment execution error.",
-                exc_info=True,
-            )
-            raise
-
-    def get_training_process(self, output_dir) -> Optional[Dict[str, Any]]:
-        """Retrieve the learning rate, batch size, and epoch from the most recent Ludwig run."""
-        output_dir = Path(output_dir)
-        exp_dirs = sorted(
-            output_dir.glob("experiment_run*"),
-            key=lambda p: p.stat().st_mtime,
-        )
-
-        if not exp_dirs:
-            logger.warning(f"No experiment run directories found in {output_dir}")
-            return None
-
-        progress_file = exp_dirs[-1] / "model" / "training_progress.json"
-        if not progress_file.exists():
-            logger.warning(f"No training_progress.json found in {progress_file}")
-            return None
-
-        try:
-            with progress_file.open("r", encoding="utf-8") as f:
-                data = json.load(f)
-            return {
-                "learning_rate": data.get("learning_rate"),
-                "batch_size": data.get("batch_size"),
-                "epoch": data.get("epoch"),
-            }
-        except Exception as e:
-            logger.warning(f"Failed to read training progress info: {e}")
-            return {}
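The mapping mirrors a few fields of Ludwig's training_progress.json and is what generate_html_report later passes to format_config_table_html as training_progress. An illustrative result:

progress = backend.get_training_process("trained_model_output")  # hypothetical output dir
# e.g. {"learning_rate": 0.001, "batch_size": 16, "epoch": 7}
# or None when no experiment_run* directory exists yet.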
-
-    def convert_parquet_to_csv(self, output_dir: Path):
-        """Convert the predictions Parquet file to CSV."""
-        output_dir = Path(output_dir)
-        exp_dirs = sorted(
-            output_dir.glob("experiment_run*"),
-            key=lambda p: p.stat().st_mtime,
-        )
-        if not exp_dirs:
-            logger.warning(f"No experiment run dirs found in {output_dir}")
-            return
-        exp_dir = exp_dirs[-1]
-        parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
-        csv_path = exp_dir / "predictions.csv"
-
-        # Check if parquet file exists before trying to convert
-        if not parquet_path.exists():
-            logger.info(f"Predictions parquet file not found at {parquet_path}, skipping conversion")
-            return
-
-        try:
-            df = pd.read_parquet(parquet_path)
-            df.to_csv(csv_path, index=False)
-            logger.info(f"Converted Parquet to CSV: {csv_path}")
-        except Exception as e:
-            logger.error(f"Error converting Parquet to CSV: {e}")
-
-    def generate_plots(self, output_dir: Path) -> None:
-        """Generate all registered Ludwig visualizations for the latest experiment run."""
-        logger.info("Generating all Ludwig visualizations…")
-
-        test_plots = {
-            "compare_performance",
-            "compare_classifiers_performance_from_prob",
-            "compare_classifiers_performance_from_pred",
-            "compare_classifiers_performance_changing_k",
-            "compare_classifiers_multiclass_multimetric",
-            "compare_classifiers_predictions",
-            "confidence_thresholding_2thresholds_2d",
-            "confidence_thresholding_2thresholds_3d",
-            "confidence_thresholding",
-            "confidence_thresholding_data_vs_acc",
-            "binary_threshold_vs_metric",
-            "roc_curves",
-            "roc_curves_from_test_statistics",
-            "calibration_1_vs_all",
-            "calibration_multiclass",
-            "confusion_matrix",
-            "frequency_vs_f1",
-        }
-        train_plots = {
-            "learning_curves",
-            "compare_classifiers_performance_subset",
-        }
-
-        output_dir = Path(output_dir)
-        exp_dirs = sorted(
-            output_dir.glob("experiment_run*"),
-            key=lambda p: p.stat().st_mtime,
-        )
-        if not exp_dirs:
-            logger.warning(f"No experiment run dirs found in {output_dir}")
-            return
-        exp_dir = exp_dirs[-1]
-
-        viz_dir = exp_dir / "visualizations"
-        viz_dir.mkdir(exist_ok=True)
-        train_viz = viz_dir / "train"
-        test_viz = viz_dir / "test"
-        train_viz.mkdir(parents=True, exist_ok=True)
-        test_viz.mkdir(parents=True, exist_ok=True)
-
-        def _check(p: Path) -> Optional[str]:
-            return str(p) if p.exists() else None
-
-        training_stats = _check(exp_dir / "training_statistics.json")
-        test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME)
-        probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME)
-        gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME)
-
-        dataset_path = None
-        split_file = None
-        desc = exp_dir / DESCRIPTION_FILE_NAME
-        if desc.exists():
-            with open(desc, "r") as f:
-                cfg = json.load(f)
-            dataset_path = _check(Path(cfg.get("dataset", "")))
-            split_file = _check(Path(get_split_path(cfg.get("dataset", ""))))
-
-        output_feature = ""
-        if desc.exists():
-            try:
-                output_feature = cfg["config"]["output_features"][0]["name"]
-            except Exception:
-                pass
-        if not output_feature and test_stats:
-            with open(test_stats, "r") as f:
-                stats = json.load(f)
-            output_feature = next(iter(stats.keys()), "")
-
-        if not _ludwig_viz_available or get_visualizations_registry is None:
-            logger.warning("Ludwig visualization registry unavailable; skipping Ludwig-generated plots.")
-            return
-
-        viz_registry = get_visualizations_registry()
-        for viz_name, viz_func in viz_registry.items():
-            if viz_name in train_plots:
-                viz_dir_plot = train_viz
-            elif viz_name in test_plots:
-                viz_dir_plot = test_viz
-            else:
-                continue
-
-            try:
-                viz_func(
-                    training_statistics=[training_stats] if training_stats else [],
-                    test_statistics=[test_stats] if test_stats else [],
-                    probabilities=[probs_path] if probs_path else [],
-                    output_feature_name=output_feature,
-                    ground_truth_split=2,
-                    top_n_classes=[0],
-                    top_k=3,
-                    ground_truth_metadata=gt_metadata,
-                    ground_truth=dataset_path,
-                    split_file=split_file,
-                    output_directory=str(viz_dir_plot),
-                    normalize=False,
-                    file_format="png",
-                )
-                logger.info(f"✔ Generated {viz_name}")
-            except Exception as e:
-                logger.warning(f"✘ Skipped {viz_name}: {e}")
-
-        logger.info(f"All visualizations written to {viz_dir}")
-
-    def generate_html_report(
-        self,
-        title: str,
-        output_dir: str,
-        config: dict,
-        split_info: str,
-    ) -> Path:
-        """Assemble an HTML report from visualizations under train_val/ and test/ folders."""
-        cwd = Path.cwd()
-        report_name = title.lower().replace(" ", "_") + "_report.html"
-        report_path = cwd / report_name
-        output_dir = Path(output_dir)
-        output_type = None
-
-        exp_dirs = sorted(
-            output_dir.glob("experiment_run*"),
-            key=lambda p: p.stat().st_mtime,
-        )
-        if not exp_dirs:
-            raise RuntimeError(f"No 'experiment_run*' dirs found in {output_dir}")
-        exp_dir = exp_dirs[-1]
-
-        base_viz_dir = exp_dir / "visualizations"
-        train_viz_dir = base_viz_dir / "train"
-        test_viz_dir = base_viz_dir / "test"
-
-        html = get_html_template()
-
-        # Extra CSS & JS: center Plotly and enable CSV download for predictions table
-        html += """
-<style>
-  /* Center Plotly figures (both wrapper and native classes) */
-  .plotly-center { display: flex; justify-content: center; }
-  .plotly-center .plotly-graph-div, .plotly-center .js-plotly-plot { margin: 0 auto !important; }
-  .js-plotly-plot, .plotly-graph-div { margin-left: auto !important; margin-right: auto !important; }
-
-  /* Download button for predictions table */
-  .download-btn {
-    padding: 8px 12px;
-    border: 1px solid #4CAF50;
-    background: #4CAF50;
-    color: white;
-    border-radius: 6px;
-    cursor: pointer;
-  }
-  .download-btn:hover { filter: brightness(0.95); }
-  .preds-controls {
-    display: flex;
-    justify-content: flex-end;
-    gap: 8px;
-    margin: 8px 0;
-  }
-</style>
-<script>
-  function tableToCSV(table){
-    const rows = Array.from(table.querySelectorAll('tr'));
-    return rows.map(row =>
-      Array.from(row.querySelectorAll('th,td')).map(cell => {
-        let text = cell.innerText.replace(/\\r?\\n|\\r/g,' ').trim();
-        if (text.includes('"') || text.includes(',')) {
-          text = '"' + text.replace(/"/g,'""') + '"';
-        }
-        return text;
-      }).join(',')
-    ).join('\\n');
-  }
-  document.addEventListener('DOMContentLoaded', function(){
-    const btn = document.getElementById('downloadPredsCsv');
-    if(btn){
-      btn.addEventListener('click', function(){
-        const tbl = document.querySelector('.predictions-table');
-        if(!tbl){ alert('Predictions table not found.'); return; }
-        const csv = tableToCSV(tbl);
-        const blob = new Blob([csv], {type: 'text/csv;charset=utf-8;'});
-        const url = URL.createObjectURL(blob);
-        const a = document.createElement('a');
-        a.href = url;
-        a.download = 'ground_truth_vs_predictions.csv';
-        document.body.appendChild(a);
-        a.click();
-        document.body.removeChild(a);
-        URL.revokeObjectURL(url);
-      });
-    }
-  });
-</script>
-"""
-        html += f"<h1>{title}</h1>"
-
-        metrics_html = ""
-        train_val_metrics_html = ""
-        test_metrics_html = ""
-        try:
-            train_stats_path = exp_dir / "training_statistics.json"
-            test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME
-            if train_stats_path.exists() and test_stats_path.exists():
-                with open(train_stats_path) as f:
-                    train_stats = json.load(f)
-                with open(test_stats_path) as f:
-                    test_stats = json.load(f)
-                output_type = detect_output_type(test_stats)
-                metrics_html = format_stats_table_html(train_stats, test_stats, output_type)
-                train_val_metrics_html = format_train_val_stats_table_html(
-                    train_stats, test_stats
-                )
-                test_metrics_html = format_test_merged_stats_table_html(
-                    extract_metrics_from_json(train_stats, test_stats, output_type)[
-                        "test"
-                    ], output_type
-                )
-        except Exception as e:
-            logger.warning(
-                f"Could not load stats for HTML report: {type(e).__name__}: {e}"
-            )
-
-        config_html = ""
-        training_progress = self.get_training_process(output_dir)
-        try:
-            config_html = format_config_table_html(
-                config, split_info, training_progress, output_type
-            )
-        except Exception as e:
-            logger.warning(f"Could not load config for HTML report: {e}")
-
-        # ---------- image rendering with exclusions ----------
-        def render_img_section(
-            title: str,
-            dir_path: Path,
-            output_type: str = None,
-            exclude_names: Optional[set] = None,
-        ) -> str:
-            if not dir_path.exists():
-                return f"<h2>{title}</h2><p><em>Directory not found.</em></p>"
-
-            exclude_names = exclude_names or set()
-
-            imgs = list(dir_path.glob("*.png"))
-
-            # Exclude duplicate confusion matrices so only the top-5 entropy version is shown
-            default_exclude = {
-                # "roc_curves.png",  # Remove ROC curves from test tab
-                "confusion_matrix__label_top5.png",  # Remove standard confusion matrix
-                "confusion_matrix__label_top10.png",  # Remove duplicate
-                "confusion_matrix__label_top6.png",   # Remove duplicate
-                "confusion_matrix_entropy__label_top10.png",  # Keep only top5
-                "confusion_matrix_entropy__label_top6.png",   # Keep only top5
-            }
-
-            imgs = [
-                img
-                for img in imgs
-                if img.name not in default_exclude
-                and img.name not in exclude_names
-            ]
-
-            if not imgs:
-                return f"<h2>{title}</h2><p><em>No plots found.</em></p>"
-
-            # Sort images by name for consistent ordering (works with string and numeric labels)
-            imgs = sorted(imgs, key=lambda x: x.name)
-
-            html_section = ""
-            for img in imgs:
-                b64 = encode_image_to_base64(str(img))
-                img_title = img.stem.replace("_", " ").title()
-                html_section += (
-                    f"<h2 style='text-align: center;'>{img_title}</h2>"
-                    f'<div class="plot" style="margin-bottom:20px;text-align:center;">'
-                    f'<img src="data:image/png;base64,{b64}" '
-                    f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />'
-                    f"</div>"
-                )
-            return html_section
-
-        tab1_content = config_html + metrics_html
-
-        tab2_content = train_val_metrics_html + render_img_section(
-            "Training and Validation Visualizations",
-            train_viz_dir,
-            output_type,
-            exclude_names={
-                "compare_classifiers_performance_from_prob.png",
-                "roc_curves_from_prediction_statistics.png",
-                "precision_recall_curves_from_prediction_statistics.png",
-                "precision_recall_curve.png",
-            },
-        )
-
-        # --- Predictions vs Ground Truth table (REGRESSION ONLY) ---
-        preds_section = ""
-        parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
-        if output_type == "regression" and parquet_path.exists():
-            try:
-                # 1) load predictions from Parquet
-                df_preds = pd.read_parquet(parquet_path).reset_index(drop=True)
-                # assume the column containing your model's prediction is named "prediction"
-                # or contains that substring:
-                pred_col = next(
-                    (c for c in df_preds.columns if "prediction" in c.lower()),
-                    None,
-                )
-                if pred_col is None:
-                    raise ValueError("No prediction column found in Parquet output")
-                df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"})
-
-                # 2) load ground truth for the test split from prepared CSV
-                df_all = pd.read_csv(config["label_column_data_path"])
-                df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][
-                    LABEL_COLUMN_NAME
-                ].reset_index(drop=True)
-                # 3) concatenate side-by-side
-                df_table = pd.concat([df_gt, df_pred], axis=1)
-                df_table.columns = [LABEL_COLUMN_NAME, "prediction"]
-
-                # 4) render as HTML
-                preds_html = df_table.to_html(index=False, classes="predictions-table")
-                preds_section = (
-                    "<h2 style='text-align: center;'>Ground Truth vs. Predictions</h2>"
-                    "<div class='preds-controls'>"
-                    "<button id='downloadPredsCsv' class='download-btn'>Download CSV</button>"
-                    "</div>"
-                    "<div class='scroll-rows-30' style='overflow-x:auto; overflow-y:auto; max-height:900px; margin-bottom:20px;'>"
-                    + preds_html
-                    + "</div>"
-                )
-            except Exception as e:
-                logger.warning(f"Could not build Predictions vs GT table: {e}")
-
-        tab3_content = test_metrics_html + preds_section
-
-        if output_type in ("binary", "category") and test_stats_path.exists():
-            try:
-                interactive_plots = build_classification_plots(
-                    str(test_stats_path),
-                    str(train_stats_path) if train_stats_path.exists() else None,
-                )
-                for plot in interactive_plots:
-                    tab3_content += (
-                        f"<h2 style='text-align: center;'>{plot['title']}</h2>"
-                        f"<div class='plotly-center'>{plot['html']}</div>"
-                    )
-                logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots")
-            except Exception as e:
-                logger.warning(f"Could not generate Plotly plots: {e}")
-
-        # Add static TEST PNGs (with default dedupe/exclusions)
-        tab3_content += render_img_section(
-            "Test Visualizations", test_viz_dir, output_type
-        )
-
-        tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)
-        modal_html = get_metrics_help_modal()
-        html += tabbed_html + modal_html + get_html_closing()
-
-        try:
-            with open(report_path, "w") as f:
-                f.write(html)
-            logger.info(f"HTML report generated at: {report_path}")
-        except Exception as e:
-            logger.error(f"Failed to write HTML report: {e}")
-            raise
-
-        return report_path
-
-
-class WorkflowOrchestrator:
-    """Manages the image-classification workflow."""
-
-    def __init__(self, args: argparse.Namespace, backend: Backend):
-        self.args = args
-        self.backend = backend
-        self.temp_dir: Optional[Path] = None
-        self.image_extract_dir: Optional[Path] = None
-        logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}")
-
-    def run(self) -> None:
-        """Execute the full workflow end-to-end."""
-        # Delegate to the backend's run_experiment method
-        self.backend.run_experiment()
-
-
-class ImageLearnerCLI:
-    """Manages the image-classification workflow."""
-
-    def __init__(self, args: argparse.Namespace, backend: Backend):
-        self.args = args
-        self.backend = backend
-        self.temp_dir: Optional[Path] = None
-        self.image_extract_dir: Optional[Path] = None
-        logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}")
-
-    def _create_temp_dirs(self) -> None:
-        """Create temporary output and image extraction directories."""
-        try:
-            self.temp_dir = Path(
-                tempfile.mkdtemp(dir=self.args.output_dir, prefix=TEMP_DIR_PREFIX)
-            )
-            self.image_extract_dir = self.temp_dir / "images"
-            self.image_extract_dir.mkdir()
-            logger.info(f"Created temp directory: {self.temp_dir}")
-        except Exception:
-            logger.error("Failed to create temporary directories", exc_info=True)
-            raise
-
-    def _extract_images(self) -> None:
-        """Extract images into the temp image directory.
-        - If a ZIP file is provided, extract it
-        - If a directory is provided, copy its contents
-        """
-        if self.image_extract_dir is None:
-            raise RuntimeError("Temp image directory not initialized.")
-        src = Path(self.args.image_zip)
-        logger.info(f"Preparing images from {src} → {self.image_extract_dir}")
-        try:
-            if src.is_dir():
-                # copy directory tree
-                for root, dirs, files in os.walk(src):
-                    rel = Path(root).relative_to(src)
-                    target_root = self.image_extract_dir / rel
-                    target_root.mkdir(parents=True, exist_ok=True)
-                    for fn in files:
-                        shutil.copy2(Path(root) / fn, target_root / fn)
-                logger.info("Image directory copied.")
-            else:
-                with zipfile.ZipFile(src, "r") as z:
-                    z.extractall(self.image_extract_dir)
-                logger.info("Image extraction complete.")
-        except Exception:
-            logger.error("Error preparing images", exc_info=True)
-            raise
-
-    def _process_fixed_split(
-        self, df: pd.DataFrame
-    ) -> Tuple[pd.DataFrame, Dict[str, Any], str]:
-        """Process datasets that already have a split column."""
-        unique = set(df[SPLIT_COLUMN_NAME].unique())
-        if unique == {0, 2}:
-            # Split 0/2 detected, create validation set
-            df = split_data_0_2(
-                df=df,
-                split_column=SPLIT_COLUMN_NAME,
-                validation_size=self.args.validation_size,
-                random_state=self.args.random_seed,
-                label_column=LABEL_COLUMN_NAME,
-            )
-            split_config = {"type": "fixed", "column": SPLIT_COLUMN_NAME}
-            split_info = (
-                "Detected a split column (with values 0 and 2) in the input CSV. "
-                f"Used this column as a base and reassigned "
-                f"{self.args.validation_size * 100:.1f}% "
-                "of the training set (originally labeled 0) to validation (labeled 1) using stratified sampling."
-            )
-            logger.info("Applied custom 0/2 split.")
-        elif unique.issubset({0, 1, 2}):
-            # Standard 0/1/2 split
-            split_config = {"type": "fixed", "column": SPLIT_COLUMN_NAME}
-            split_info = (
-                "Detected a split column with train(0)/validation(1)/test(2) "
-                "values in the input CSV. Used this column as-is."
-            )
-            logger.info("Fixed split column detected.")
-        else:
-            raise ValueError(
-                f"Split column contains unexpected values: {unique}. "
-                "Expected: {{0,1,2}} or {{0,2}}"
-            )
-
-        return df, split_config, split_info
-
-    def _prepare_data(self) -> Tuple[Path, Dict[str, Any], str]:
-        """Load CSV, update image paths, handle splits, and write prepared CSV."""
-        if not self.temp_dir or not self.image_extract_dir:
-            raise RuntimeError("Temp dirs not initialized before data prep.")
-
-        try:
-            df = pd.read_csv(self.args.csv_file)
-            logger.info(f"Loaded CSV: {self.args.csv_file}")
-        except Exception:
-            logger.error("Error loading CSV file", exc_info=True)
-            raise
-
-        required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME}
-        missing = required - set(df.columns)
-        if missing:
-            raise ValueError(f"Missing CSV columns: {', '.join(missing)}")
-
-        try:
-            # Use relative paths that Ludwig can resolve from its internal working directory
-            df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply(
-                lambda p: str(Path("images") / p)
-            )
-        except Exception:
-            logger.error("Error updating image paths", exc_info=True)
-            raise
-
-        if SPLIT_COLUMN_NAME in df.columns:
-            df, split_config, split_info = self._process_fixed_split(df)
-        else:
-            logger.info("No split column; creating stratified random split")
-            df = create_stratified_random_split(
-                df=df,
-                split_column=SPLIT_COLUMN_NAME,
-                split_probabilities=self.args.split_probabilities,
-                random_state=self.args.random_seed,
-                label_column=LABEL_COLUMN_NAME,
-            )
-            split_config = {
-                "type": "fixed",
-                "column": SPLIT_COLUMN_NAME,
-            }
-            split_info = (
-                f"No split column in CSV. Created stratified random split: "
-                f"{[int(p * 100) for p in self.args.split_probabilities]}% "
-                f"for train/val/test with balanced label distribution."
-            )
-
-        final_csv = self.temp_dir / TEMP_CSV_FILENAME
-
-        try:
-            df.to_csv(final_csv, index=False)
-            logger.info(f"Saved prepared data to {final_csv}")
-        except Exception:
-            logger.error("Error saving prepared CSV", exc_info=True)
-            raise
-
-        return final_csv, split_config, split_info
-
-# Removed duplicate method
-
-    def _detect_image_dimensions(self) -> Tuple[int, int]:
-        """Detect image dimensions from the first image in the dataset."""
-        try:
-            import zipfile
-            from PIL import Image
-            import io
-
-            # Check if image_zip is provided
-            if not self.args.image_zip:
-                logger.warning("No image zip provided, using default 224x224")
-                return 224, 224
-
-            # Extract first image to detect dimensions
-            with zipfile.ZipFile(self.args.image_zip, 'r') as z:
-                image_files = [f for f in z.namelist() if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
-                if not image_files:
-                    logger.warning("No image files found in zip, using default 224x224")
-                    return 224, 224
-
-                # Check first image
-                with z.open(image_files[0]) as f:
-                    img = Image.open(io.BytesIO(f.read()))
-                    width, height = img.size
-                    logger.info(f"Detected image dimensions: {width}x{height}")
-                    return height, width  # Return as (height, width) to match encoder config
-
-        except Exception as e:
-            logger.warning(f"Error detecting image dimensions: {e}, using default 224x224")
-            return 224, 224
-
-    def _cleanup_temp_dirs(self) -> None:
-        if self.temp_dir and self.temp_dir.exists():
-            logger.info(f"Cleaning up temp directory: {self.temp_dir}")
-            # Remove the temporary tree; ignore_errors avoids raising during cleanup
-            shutil.rmtree(self.temp_dir, ignore_errors=True)
-        self.temp_dir = None
-        self.image_extract_dir = None
-
-    def run(self) -> None:
-        """Execute the full workflow end-to-end."""
-        logger.info("Starting workflow...")
-        self.args.output_dir.mkdir(parents=True, exist_ok=True)
-
-        try:
-            self._create_temp_dirs()
-            self._extract_images()
-            csv_path, split_cfg, split_info = self._prepare_data()
-
-            use_pretrained = self.args.use_pretrained or self.args.fine_tune
-
-            backend_args = {
-                "model_name": self.args.model_name,
-                "fine_tune": self.args.fine_tune,
-                "use_pretrained": use_pretrained,
-                "epochs": self.args.epochs,
-                "batch_size": self.args.batch_size,
-                "preprocessing_num_processes": self.args.preprocessing_num_processes,
-                "split_probabilities": self.args.split_probabilities,
-                "learning_rate": self.args.learning_rate,
-                "random_seed": self.args.random_seed,
-                "early_stop": self.args.early_stop,
-                "label_column_data_path": csv_path,
-                "augmentation": self.args.augmentation,
-                "image_resize": self.args.image_resize,
-                "image_zip": self.args.image_zip,
-                "threshold": self.args.threshold,
-            }
-            yaml_str = self.backend.prepare_config(backend_args, split_cfg)
-
-            config_file = self.temp_dir / TEMP_CONFIG_FILENAME
-            config_file.write_text(yaml_str)
-            logger.info(f"Wrote backend config: {config_file}")
-
-            ran_ok = True
-            try:
-                # Run Ludwig experiment with absolute paths to avoid working directory issues
-                self.backend.run_experiment(
-                    csv_path,
-                    config_file,
-                    self.args.output_dir,
-                    self.args.random_seed,
-                )
-            except Exception:
-                logger.error("Workflow execution failed", exc_info=True)
-                ran_ok = False
-
-            if ran_ok:
-                logger.info("Workflow completed successfully.")
-                # Generate a very small set of plots to conserve disk space
-                self.backend.generate_plots(self.args.output_dir)
-                # Build HTML report (robust to missing metrics)
-                report_file = self.backend.generate_html_report(
-                    "Image Classification Results",
-                    self.args.output_dir,
-                    backend_args,
-                    split_info,
-                )
-                logger.info(f"HTML report generated at: {report_file}")
-                # Convert predictions parquet → csv
-                self.backend.convert_parquet_to_csv(self.args.output_dir)
-                logger.info("Converted Parquet to CSV.")
-                # Post-process cleanup to reduce disk footprint for subsequent tests
-                try:
-                    self._postprocess_cleanup(self.args.output_dir)
-                except Exception as cleanup_err:
-                    logger.warning(f"Cleanup step failed: {cleanup_err}")
-            else:
-                # Fallback: create minimal outputs so downstream steps can proceed
-                logger.warning("Falling back to minimal outputs due to runtime failure.")
-                try:
-                    self._create_minimal_outputs(self.args.output_dir, csv_path)
-                    # Even in fallback, produce an HTML shell so tests find required text
-                    report_file = self.backend.generate_html_report(
-                        "Image Classification Results",
-                        self.args.output_dir,
-                        backend_args,
-                        split_info,
-                    )
-                    logger.info(f"HTML report (fallback) generated at: {report_file}")
-                except Exception as fb_err:
-                    logger.error(f"Failed to build fallback outputs: {fb_err}")
-                    raise
-
-        except Exception:
-            logger.error("Workflow execution failed", exc_info=True)
-            raise
-        finally:
-            self._cleanup_temp_dirs()
-
-    def _postprocess_cleanup(self, output_dir: Path) -> None:
-        """Remove large intermediates and caches to conserve disk space across tests."""
-        output_dir = Path(output_dir)
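-        # Prune only the most recent experiment_run* directory (sorted by mtime below).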
-        exp_dirs = sorted(
-            output_dir.glob("experiment_run*"),
-            key=lambda p: p.stat().st_mtime,
-        )
-        if exp_dirs:
-            exp_dir = exp_dirs[-1]
-            # Remove training checkpoints directory if present
-            ckpt_dir = exp_dir / "model" / "training_checkpoints"
-            if ckpt_dir.exists():
-                shutil.rmtree(ckpt_dir, ignore_errors=True)
-            # Remove predictions parquet once CSV is generated
-            parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
-            if parquet_path.exists():
-                try:
-                    parquet_path.unlink()
-                except Exception:
-                    pass
-
-        # Clear torch hub cache under the job-scoped home, if present
-        job_home_torch_hub = Path.cwd() / "home" / ".cache" / "torch" / "hub"
-        if job_home_torch_hub.exists():
-            shutil.rmtree(job_home_torch_hub, ignore_errors=True)
-
-        # Also try the default user cache as a best-effort (may not exist in job sandbox)
-        user_home_torch_hub = Path.home() / ".cache" / "torch" / "hub"
-        if user_home_torch_hub.exists():
-            shutil.rmtree(user_home_torch_hub, ignore_errors=True)
-
-        # Clear huggingface cache if present in the job sandbox
-        job_home_hf = Path.cwd() / "home" / ".cache" / "huggingface"
-        if job_home_hf.exists():
-            shutil.rmtree(job_home_hf, ignore_errors=True)
-
-    def _create_minimal_outputs(self, output_dir: Path, prepared_csv_path: Path) -> None:
-        """Create a minimal set of outputs so Galaxy can collect expected artifacts.
-
-        - experiment_run/
-            - predictions.csv (1 column)
-            - visualizations/train/ (empty)
-            - visualizations/test/ (empty)
-            - model/
-                - model_weights/ (empty)
-                - model_hyperparameters.json (stub)
-        """
-        output_dir = Path(output_dir)
-        exp_dir = output_dir / "experiment_run"
-        (exp_dir / "visualizations" / "train").mkdir(parents=True, exist_ok=True)
-        (exp_dir / "visualizations" / "test").mkdir(parents=True, exist_ok=True)
-        model_dir = exp_dir / "model"
-        (model_dir / "model_weights").mkdir(parents=True, exist_ok=True)
-
-        # Stub JSON so the tool's copy step succeeds
-        try:
-            (model_dir / "model_hyperparameters.json").write_text("{}\n")
-        except Exception:
-            pass
-
-        # Create a small predictions.csv with exactly 1 column
-        try:
-            df_all = pd.read_csv(prepared_csv_path)
-            from constants import SPLIT_COLUMN_NAME  # local import to avoid cycle at top
-            num_rows = int((df_all[SPLIT_COLUMN_NAME] == 2).sum()) if SPLIT_COLUMN_NAME in df_all.columns else 1
-        except Exception:
-            num_rows = 1
-        num_rows = max(1, num_rows)
-        pd.DataFrame({"prediction": [0] * num_rows}).to_csv(exp_dir / "predictions.csv", index=False)
-
-
-def parse_learning_rate(s):
-    try:
-        return float(s)
-    except (TypeError, ValueError):
-        return None
-
-
-def aug_parse(aug_string: str):
-    """
-    Parse comma-separated augmentation keys into Ludwig augmentation dicts.
-    Raises ValueError on unknown key.
-    """
-    mapping = {
-        "random_horizontal_flip": {"type": "random_horizontal_flip"},
-        "random_vertical_flip": {"type": "random_vertical_flip"},
-        "random_rotate": {"type": "random_rotate", "degree": 10},
-        "random_blur": {"type": "random_blur", "kernel_size": 3},
-        "random_brightness": {"type": "random_brightness", "min": 0.5, "max": 2.0},
-        "random_contrast": {"type": "random_contrast", "min": 0.5, "max": 2.0},
-    }
-    aug_list = []
-    for tok in aug_string.split(","):
-        key = tok.strip()
-        if not key:
-            continue
-        if key not in mapping:
-            valid = ", ".join(mapping.keys())
-            raise ValueError(f"Unknown augmentation '{key}'. Valid choices: {valid}")
-        aug_list.append(mapping[key])
-    return aug_list
-
-
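-# argparse Action validating that the three --split-probabilities values sum to 1.0.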
-class SplitProbAction(argparse.Action):
-    def __call__(self, parser, namespace, values, option_string=None):
-        train, val, test = values
-        total = train + val + test
-        if abs(total - 1.0) > 1e-6:
-            parser.error(
-                f"--split-probabilities must sum to 1.0; "
-                f"got {train:.3f} + {val:.3f} + {test:.3f} = {total:.3f}"
-            )
-        setattr(namespace, self.dest, values)
-
 
 def main():
     parser = argparse.ArgumentParser(
@@ -1893,7 +30,7 @@
         "--csv-file",
         required=True,
         type=Path,
-        help="Path to the input CSV",
+        help="Path to the input metadata file (CSV, TSV, etc)",
     )
     parser.add_argument(
         "--image-zip",
@@ -2008,18 +145,7 @@
 
     args = parser.parse_args()
 
-    if not 0.0 <= args.validation_size <= 1.0:
-        parser.error("validation-size must be between 0.0 and 1.0")
-    if not args.csv_file.is_file():
-        parser.error(f"CSV not found: {args.csv_file}")
-    if not (args.image_zip.is_file() or args.image_zip.is_dir()):
-        parser.error(f"ZIP or directory not found: {args.image_zip}")
-    if args.augmentation is not None:
-        try:
-            augmentation_setup = aug_parse(args.augmentation)
-            setattr(args, "augmentation", augmentation_setup)
-        except ValueError as e:
-            parser.error(str(e))
+    argument_checker(args, parser)
 
     backend_instance = LudwigDirectBackend()
     orchestrator = ImageLearnerCLI(args, backend_instance)