diff image_learner_cli.py @ 0:54b871dfc51e draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit b7411ff35b6228ccdfd36cd4ebd946c03ac7f7e9
author goeckslab
date Tue, 03 Jun 2025 21:22:11 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/image_learner_cli.py	Tue Jun 03 21:22:11 2025 +0000
@@ -0,0 +1,1137 @@
+#!/usr/bin/env python3
+import argparse
+import json
+import logging
+import os
+import shutil
+import sys
+import tempfile
+import zipfile
+from pathlib import Path
+from typing import Any, Dict, Optional, Protocol, Tuple
+
+import pandas as pd
+import yaml
+from ludwig.globals import (
+    DESCRIPTION_FILE_NAME,
+    PREDICTIONS_PARQUET_FILE_NAME,
+    TEST_STATISTICS_FILE_NAME,
+    TRAIN_SET_METADATA_FILE_NAME,
+)
+from ludwig.utils.data_utils import get_split_path
+from ludwig.visualize import get_visualizations_registry
+from sklearn.model_selection import train_test_split
+from utils import encode_image_to_base64, get_html_closing, get_html_template
+
+# --- Constants ---
+SPLIT_COLUMN_NAME = 'split'
+LABEL_COLUMN_NAME = 'label'
+IMAGE_PATH_COLUMN_NAME = 'image_path'
+DEFAULT_SPLIT_PROBABILITIES = [0.7, 0.1, 0.2]
+TEMP_CSV_FILENAME = "processed_data_for_ludwig.csv"
+TEMP_CONFIG_FILENAME = "ludwig_config.yaml"
+TEMP_DIR_PREFIX = "ludwig_api_work_"
+MODEL_ENCODER_TEMPLATES: Dict[str, Any] = {
+    'stacked_cnn': 'stacked_cnn',
+    'resnet18': {'type': 'resnet', 'model_variant': 18},
+    'resnet34': {'type': 'resnet', 'model_variant': 34},
+    'resnet50': {'type': 'resnet', 'model_variant': 50},
+    'resnet101': {'type': 'resnet', 'model_variant': 101},
+    'resnet152': {'type': 'resnet', 'model_variant': 152},
+    'resnext50_32x4d': {'type': 'resnext', 'model_variant': '50_32x4d'},
+    'resnext101_32x8d': {'type': 'resnext', 'model_variant': '101_32x8d'},
+    'resnext101_64x4d': {'type': 'resnext', 'model_variant': '101_64x4d'},
+    'resnext152_32x8d': {'type': 'resnext', 'model_variant': '152_32x8d'},
+    'wide_resnet50_2': {'type': 'wide_resnet', 'model_variant': '50_2'},
+    'wide_resnet101_2': {'type': 'wide_resnet', 'model_variant': '101_2'},
+    'wide_resnet103_2': {'type': 'wide_resnet', 'model_variant': '103_2'},
+    'efficientnet_b0': {'type': 'efficientnet', 'model_variant': 'b0'},
+    'efficientnet_b1': {'type': 'efficientnet', 'model_variant': 'b1'},
+    'efficientnet_b2': {'type': 'efficientnet', 'model_variant': 'b2'},
+    'efficientnet_b3': {'type': 'efficientnet', 'model_variant': 'b3'},
+    'efficientnet_b4': {'type': 'efficientnet', 'model_variant': 'b4'},
+    'efficientnet_b5': {'type': 'efficientnet', 'model_variant': 'b5'},
+    'efficientnet_b6': {'type': 'efficientnet', 'model_variant': 'b6'},
+    'efficientnet_b7': {'type': 'efficientnet', 'model_variant': 'b7'},
+    'efficientnet_v2_s': {'type': 'efficientnet', 'model_variant': 'v2_s'},
+    'efficientnet_v2_m': {'type': 'efficientnet', 'model_variant': 'v2_m'},
+    'efficientnet_v2_l': {'type': 'efficientnet', 'model_variant': 'v2_l'},
+    'regnet_y_400mf': {'type': 'regnet', 'model_variant': 'y_400mf'},
+    'regnet_y_800mf': {'type': 'regnet', 'model_variant': 'y_800mf'},
+    'regnet_y_1_6gf': {'type': 'regnet', 'model_variant': 'y_1_6gf'},
+    'regnet_y_3_2gf': {'type': 'regnet', 'model_variant': 'y_3_2gf'},
+    'regnet_y_8gf': {'type': 'regnet', 'model_variant': 'y_8gf'},
+    'regnet_y_16gf': {'type': 'regnet', 'model_variant': 'y_16gf'},
+    'regnet_y_32gf': {'type': 'regnet', 'model_variant': 'y_32gf'},
+    'regnet_y_128gf': {'type': 'regnet', 'model_variant': 'y_128gf'},
+    'regnet_x_400mf': {'type': 'regnet', 'model_variant': 'x_400mf'},
+    'regnet_x_800mf': {'type': 'regnet', 'model_variant': 'x_800mf'},
+    'regnet_x_1_6gf': {'type': 'regnet', 'model_variant': 'x_1_6gf'},
+    'regnet_x_3_2gf': {'type': 'regnet', 'model_variant': 'x_3_2gf'},
+    'regnet_x_8gf': {'type': 'regnet', 'model_variant': 'x_8gf'},
+    'regnet_x_16gf': {'type': 'regnet', 'model_variant': 'x_16gf'},
+    'regnet_x_32gf': {'type': 'regnet', 'model_variant': 'x_32gf'},
+    'vgg11': {'type': 'vgg', 'model_variant': 11},
+    'vgg11_bn': {'type': 'vgg', 'model_variant': '11_bn'},
+    'vgg13': {'type': 'vgg', 'model_variant': 13},
+    'vgg13_bn': {'type': 'vgg', 'model_variant': '13_bn'},
+    'vgg16': {'type': 'vgg', 'model_variant': 16},
+    'vgg16_bn': {'type': 'vgg', 'model_variant': '16_bn'},
+    'vgg19': {'type': 'vgg', 'model_variant': 19},
+    'vgg19_bn': {'type': 'vgg', 'model_variant': '19_bn'},
+    'shufflenet_v2_x0_5': {'type': 'shufflenet_v2', 'model_variant': 'x0_5'},
+    'shufflenet_v2_x1_0': {'type': 'shufflenet_v2', 'model_variant': 'x1_0'},
+    'shufflenet_v2_x1_5': {'type': 'shufflenet_v2', 'model_variant': 'x1_5'},
+    'shufflenet_v2_x2_0': {'type': 'shufflenet_v2', 'model_variant': 'x2_0'},
+    'squeezenet1_0': {'type': 'squeezenet', 'model_variant': '1_0'},
+    'squeezenet1_1': {'type': 'squeezenet', 'model_variant': '1_1'},
+    'swin_t': {'type': 'swin_transformer', 'model_variant': 't'},
+    'swin_s': {'type': 'swin_transformer', 'model_variant': 's'},
+    'swin_b': {'type': 'swin_transformer', 'model_variant': 'b'},
+    'swin_v2_t': {'type': 'swin_transformer', 'model_variant': 'v2_t'},
+    'swin_v2_s': {'type': 'swin_transformer', 'model_variant': 'v2_s'},
+    'swin_v2_b': {'type': 'swin_transformer', 'model_variant': 'v2_b'},
+    'vit_b_16': {'type': 'vision_transformer', 'model_variant': 'b_16'},
+    'vit_b_32': {'type': 'vision_transformer', 'model_variant': 'b_32'},
+    'vit_l_16': {'type': 'vision_transformer', 'model_variant': 'l_16'},
+    'vit_l_32': {'type': 'vision_transformer', 'model_variant': 'l_32'},
+    'vit_h_14': {'type': 'vision_transformer', 'model_variant': 'h_14'},
+    'convnext_tiny': {'type': 'convnext', 'model_variant': 'tiny'},
+    'convnext_small': {'type': 'convnext', 'model_variant': 'small'},
+    'convnext_base': {'type': 'convnext', 'model_variant': 'base'},
+    'convnext_large': {'type': 'convnext', 'model_variant': 'large'},
+    'maxvit_t': {'type': 'maxvit', 'model_variant': 't'},
+    'alexnet': {'type': 'alexnet'},
+    'googlenet': {'type': 'googlenet'},
+    'inception_v3': {'type': 'inception_v3'},
+    'mobilenet_v2': {'type': 'mobilenet_v2'},
+    'mobilenet_v3_large': {'type': 'mobilenet_v3_large'},
+    'mobilenet_v3_small': {'type': 'mobilenet_v3_small'},
+}
+
+# --- Logging Setup ---
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s %(levelname)s %(name)s: %(message)s'
+)
+logger = logging.getLogger("ImageLearner")
+
+
+def format_config_table_html(
+        config: dict,
+        split_info: Optional[str] = None,
+        training_progress: dict = None) -> str:
+    display_keys = [
+        "model_name",
+        "epochs",
+        "batch_size",
+        "fine_tune",
+        "use_pretrained",
+        "learning_rate",
+        "random_seed",
+        "early_stop",
+    ]
+
+    rows = []
+
+    for key in display_keys:
+        val = config.get(key, "N/A")
+        if key == "batch_size":
+            if val is not None:
+                val = int(val)
+            else:
+                if training_progress:
+                    val = "Auto-selected batch size by Ludwig:<br>"
+                    resolved_val = training_progress.get("batch_size")
+                    val += (
+                        f"<span style='font-size: 0.85em;'>{resolved_val}</span><br>"
+                    )
+                else:
+                    val = "auto"
+        if key == "learning_rate":
+            resolved_val = None
+            if val is None or val == "auto":
+                if training_progress:
+                    resolved_val = training_progress.get("learning_rate")
+                    val = (
+                        "Auto-selected learning rate by Ludwig:<br>"
+                        f"<span style='font-size: 0.85em;'>{resolved_val if resolved_val else val}</span><br>"
+                        "<span style='font-size: 0.85em;'>"
+                        "Based on model architecture and training setup (e.g., fine-tuning).<br>"
+                        "See <a href='https://ludwig.ai/latest/configuration/trainer/#trainer-parameters' "
+                        "target='_blank'>Ludwig Trainer Parameters</a> for details."
+                        "</span>"
+                    )
+                else:
+                    val = (
+                        "Auto-selected by Ludwig<br>"
+                        "<span style='font-size: 0.85em;'>"
+                        "Automatically tuned based on architecture and dataset.<br>"
+                        "See <a href='https://ludwig.ai/latest/configuration/trainer/#trainer-parameters' "
+                        "target='_blank'>Ludwig Trainer Parameters</a> for details."
+                        "</span>"
+                    )
+            else:
+                val = f"{val:.6f}"
+        if key == "epochs":
+            if training_progress and "epoch" in training_progress and val > training_progress["epoch"]:
+                val = (
+                    f"Because of early stopping: the training"
+                    f"stopped at epoch {training_progress['epoch']}"
+                )
+
+        if val is None:
+            continue
+        rows.append(
+            f"<tr>"
+            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>"
+            f"{key.replace('_', ' ').title()}</td>"
+            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{val}</td>"
+            f"</tr>"
+        )
+
+    if split_info:
+        rows.append(
+            f"<tr>"
+            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Data Split</td>"
+            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{split_info}</td>"
+            f"</tr>"
+        )
+
+    return (
+        "<h2 style='text-align: center;'>Training Setup</h2>"
+        "<div style='display: flex; justify-content: center;'>"
+        "<table style='border-collapse: collapse; width: 60%; table-layout: auto;'>"
+        "<thead><tr>"
+        "<th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>Parameter</th>"
+        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Value</th>"
+        "</tr></thead><tbody>" + "".join(rows) + "</tbody></table></div><br>"
+        "<p style='text-align: center; font-size: 0.9em;'>"
+        "Model trained using Ludwig.<br>"
+        "If want to learn more about Ludwig default settings,"
+        "please check the their <a href='https://ludwig.ai' target='_blank'>website(ludwig.ai)</a>."
+        "</p><hr>"
+    )
+
+
+def format_stats_table_html(training_stats: dict, test_stats: dict) -> str:
+    train_metrics = training_stats.get("training", {}).get("label", {})
+    val_metrics = training_stats.get("validation", {}).get("label", {})
+    test_metrics = test_stats.get("label", {})
+
+    all_metrics = set(train_metrics) | set(val_metrics) | set(test_metrics)
+
+    def get_last_value(stats, key):
+        val = stats.get(key)
+        if isinstance(val, list) and val:
+            return val[-1]
+        elif isinstance(val, (int, float)):
+            return val
+        return None
+
+    rows = []
+    for metric in sorted(all_metrics):
+        t = get_last_value(train_metrics, metric)
+        v = get_last_value(val_metrics, metric)
+        te = get_last_value(test_metrics, metric)
+        if all(x is not None for x in [t, v, te]):
+            row = (
+                f"<tr>"
+                f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>{metric}</td>"
+                f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{t:.4f}</td>"
+                f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{v:.4f}</td>"
+                f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{te:.4f}</td>"
+                f"</tr>"
+            )
+            rows.append(row)
+
+    if not rows:
+        return "<p><em>No metric values found.</em></p>"
+
+    return (
+        "<h2 style='text-align: center;'>Model Performance Summary</h2>"
+        "<div style='display: flex; justify-content: center;'>"
+        "<table style='border-collapse: collapse; width: 80%; table-layout: fixed;'>"
+        "<colgroup>"
+        "<col style='width: 40%;'>"
+        "<col style='width: 20%;'>"
+        "<col style='width: 20%;'>"
+        "<col style='width: 20%;'>"
+        "</colgroup>"
+        "<thead><tr>"
+        "<th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>Metric</th>"
+        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Train</th>"
+        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Validation</th>"
+        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Test</th>"
+        "</tr></thead><tbody>" + "".join(rows) + "</tbody></table></div><br>"
+    )
+
+
+def build_tabbed_html(
+        metrics_html: str,
+        train_viz_html: str,
+        test_viz_html: str) -> str:
+    return f"""
+<style>
+.tabs {{
+  display: flex;
+  border-bottom: 2px solid #ccc;
+  margin-bottom: 1rem;
+}}
+.tab {{
+  padding: 10px 20px;
+  cursor: pointer;
+  border: 1px solid #ccc;
+  border-bottom: none;
+  background: #f9f9f9;
+  margin-right: 5px;
+  border-top-left-radius: 8px;
+  border-top-right-radius: 8px;
+}}
+.tab.active {{
+  background: white;
+  font-weight: bold;
+}}
+.tab-content {{
+  display: none;
+  padding: 20px;
+  border: 1px solid #ccc;
+  border-top: none;
+}}
+.tab-content.active {{
+  display: block;
+}}
+</style>
+
+<div class="tabs">
+  <div class="tab active" onclick="showTab('metrics')"> Config & Metrics</div>
+  <div class="tab" onclick="showTab('trainval')"> Train/Validation Plots</div>
+  <div class="tab" onclick="showTab('test')"> Test Plots</div>
+</div>
+
+<div id="metrics" class="tab-content active">
+  {metrics_html}
+</div>
+<div id="trainval" class="tab-content">
+  {train_viz_html}
+</div>
+<div id="test" class="tab-content">
+  {test_viz_html}
+</div>
+
+<script>
+function showTab(id) {{
+  document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active'));
+  document.querySelectorAll('.tab').forEach(el => el.classList.remove('active'));
+  document.getElementById(id).classList.add('active');
+  document.querySelector(`.tab[onclick*="${{id}}"]`).classList.add('active');
+}}
+</script>
+"""
+
+
+def split_data_0_2(
+    df: pd.DataFrame,
+    split_column: str,
+    validation_size: float = 0.15,
+    random_state: int = 42,
+    label_column: Optional[str] = None,
+) -> pd.DataFrame:
+    """
+    Given a DataFrame whose split_column only contains {0,2}, re-assign
+    a portion of the 0s to become 1s (validation). Returns a fresh DataFrame.
+    """
+    # Work on a copy
+    out = df.copy()
+    # Ensure split col is integer dtype
+    out[split_column] = pd.to_numeric(out[split_column], errors="coerce").astype(int)
+
+    idx_train = out.index[out[split_column] == 0].tolist()
+
+    if not idx_train:
+        logger.info("No rows with split=0; nothing to do.")
+        return out
+
+    # Determine stratify array if possible
+    stratify_arr = None
+    if label_column and label_column in out.columns:
+        # Only stratify if at least two classes and enough samples
+        label_counts = out.loc[idx_train, label_column].value_counts()
+        if label_counts.size > 1 and (label_counts.min() * validation_size) >= 1:
+            stratify_arr = out.loc[idx_train, label_column]
+        else:
+            logger.warning("Cannot stratify (too few labels); splitting without stratify.")
+
+    # Edge cases
+    if validation_size <= 0:
+        logger.info("validation_size <= 0; keeping all as train.")
+        return out
+    if validation_size >= 1:
+        logger.info("validation_size >= 1; moving all train → validation.")
+        out.loc[idx_train, split_column] = 1
+        return out
+
+    # Do the split
+    try:
+        train_idx, val_idx = train_test_split(
+            idx_train,
+            test_size=validation_size,
+            random_state=random_state,
+            stratify=stratify_arr
+        )
+    except ValueError as e:
+        logger.warning(f"Stratified split failed ({e}); retrying without stratify.")
+        train_idx, val_idx = train_test_split(
+            idx_train,
+            test_size=validation_size,
+            random_state=random_state,
+            stratify=None
+        )
+
+    # Assign new splits
+    out.loc[train_idx, split_column] = 0
+    out.loc[val_idx, split_column] = 1
+    # idx_test stays at 2
+
+    # Cast back to a clean integer type
+    out[split_column] = out[split_column].astype(int)
+    # print(out)
+    return out
+
+
+class Backend(Protocol):
+    """Interface for a machine learning backend."""
+    def prepare_config(
+        self,
+        config_params: Dict[str, Any],
+        split_config: Dict[str, Any]
+    ) -> str:
+        ...
+
+    def run_experiment(
+        self,
+        dataset_path: Path,
+        config_path: Path,
+        output_dir: Path,
+        random_seed: int,
+    ) -> None:
+        ...
+
+    def generate_plots(
+        self,
+        output_dir: Path
+    ) -> None:
+        ...
+
+    def generate_html_report(
+        self,
+        title: str,
+        output_dir: str
+    ) -> Path:
+        ...
+
+
+class LudwigDirectBackend:
+    """
+    Backend for running Ludwig experiments directly via the internal experiment_cli function.
+    """
+
+    def prepare_config(
+        self,
+        config_params: Dict[str, Any],
+        split_config: Dict[str, Any],
+    ) -> str:
+        """
+        Build and serialize the Ludwig YAML configuration.
+        """
+        logger.info("LudwigDirectBackend: Preparing YAML configuration.")
+
+        model_name = config_params.get("model_name", "resnet18")
+        use_pretrained = config_params.get("use_pretrained", False)
+        fine_tune = config_params.get("fine_tune", False)
+        epochs = config_params.get("epochs", 10)
+        batch_size = config_params.get("batch_size")
+        num_processes = config_params.get("preprocessing_num_processes", 1)
+        early_stop = config_params.get("early_stop", None)
+        learning_rate = config_params.get("learning_rate")
+        learning_rate = "auto" if learning_rate is None else float(learning_rate)
+        trainable = fine_tune or (not use_pretrained)
+        if not use_pretrained and not trainable:
+            logger.warning("trainable=False; use_pretrained=False is ignored.")
+            logger.warning("Setting trainable=True to train the model from scratch.")
+            trainable = True
+
+        # Encoder setup
+        raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name)
+        if isinstance(raw_encoder, dict):
+            encoder_config = {
+                **raw_encoder,
+                "use_pretrained": use_pretrained,
+                "trainable": trainable,
+            }
+        else:
+            encoder_config = {"type": raw_encoder}
+
+        # Trainer & optimizer
+        # optimizer = {"type": "adam", "learning_rate": 5e-5} if fine_tune else {"type": "adam"}
+        batch_size_cfg = batch_size or "auto"
+
+        conf: Dict[str, Any] = {
+            "model_type": "ecd",
+            "input_features": [
+                {
+                    "name": IMAGE_PATH_COLUMN_NAME,
+                    "type": "image",
+                    "encoder": encoder_config,
+                }
+            ],
+            "output_features": [
+                {"name": LABEL_COLUMN_NAME, "type": "category"}
+            ],
+            "combiner": {"type": "concat"},
+            "trainer": {
+                "epochs": epochs,
+                "early_stop": early_stop,
+                "batch_size": batch_size_cfg,
+                "learning_rate": learning_rate,
+            },
+            "preprocessing": {
+                "split": split_config,
+                "num_processes": num_processes,
+                "in_memory": False,
+            },
+        }
+
+        logger.debug("LudwigDirectBackend: Config dict built.")
+        try:
+            yaml_str = yaml.dump(conf, sort_keys=False, indent=2)
+            logger.info("LudwigDirectBackend: YAML config generated.")
+            return yaml_str
+        except Exception:
+            logger.error("LudwigDirectBackend: Failed to serialize YAML.", exc_info=True)
+            raise
+
+    def run_experiment(
+        self,
+        dataset_path: Path,
+        config_path: Path,
+        output_dir: Path,
+        random_seed: int = 42,
+    ) -> None:
+        """
+        Invoke Ludwig's internal experiment_cli function to run the experiment.
+        """
+        logger.info("LudwigDirectBackend: Starting experiment execution.")
+
+        try:
+            from ludwig.experiment import experiment_cli
+        except ImportError as e:
+            logger.error(
+                "LudwigDirectBackend: Could not import experiment_cli.",
+                exc_info=True
+            )
+            raise RuntimeError("Ludwig import failed.") from e
+
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        try:
+            experiment_cli(
+                dataset=str(dataset_path),
+                config=str(config_path),
+                output_directory=str(output_dir),
+                random_seed=random_seed,
+            )
+            logger.info(f"LudwigDirectBackend: Experiment completed. Results in {output_dir}")
+        except TypeError as e:
+            logger.error(
+                "LudwigDirectBackend: Argument mismatch in experiment_cli call.",
+                exc_info=True
+            )
+            raise RuntimeError("Ludwig argument error.") from e
+        except Exception:
+            logger.error(
+                "LudwigDirectBackend: Experiment execution error.",
+                exc_info=True
+            )
+            raise
+
+    def get_training_process(self, output_dir) -> float:
+        """
+        Retrieve the learning rate used in the most recent Ludwig run.
+        Returns:
+            float: learning rate (or None if not found)
+        """
+        output_dir = Path(output_dir)
+        exp_dirs = sorted(
+            output_dir.glob("experiment_run*"),
+            key=lambda p: p.stat().st_mtime
+        )
+
+        if not exp_dirs:
+            logger.warning(f"No experiment run directories found in {output_dir}")
+            return None
+
+        progress_file = exp_dirs[-1] / "model" / "training_progress.json"
+        if not progress_file.exists():
+            logger.warning(f"No training_progress.json found in {progress_file}")
+            return None
+
+        try:
+            with progress_file.open("r", encoding="utf-8") as f:
+                data = json.load(f)
+            return {
+                "learning_rate": data.get("learning_rate"),
+                "batch_size": data.get("batch_size"),
+                "epoch": data.get("epoch"),
+            }
+        except Exception as e:
+            self.logger.warning(f"Failed to read training progress info: {e}")
+            return {}
+
+    def convert_parquet_to_csv(self, output_dir: Path):
+        """Convert the predictions Parquet file to CSV."""
+        output_dir = Path(output_dir)
+        exp_dirs = sorted(
+            output_dir.glob("experiment_run*"),
+            key=lambda p: p.stat().st_mtime
+        )
+        if not exp_dirs:
+            logger.warning(f"No experiment run dirs found in {output_dir}")
+            return
+        exp_dir = exp_dirs[-1]
+        parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
+        csv_path = exp_dir / "predictions.csv"
+        try:
+            df = pd.read_parquet(parquet_path)
+            df.to_csv(csv_path, index=False)
+            logger.info(f"Converted Parquet to CSV: {csv_path}")
+        except Exception as e:
+            logger.error(f"Error converting Parquet to CSV: {e}")
+
+    def generate_plots(self, output_dir: Path) -> None:
+        """
+        Generate _all_ registered Ludwig visualizations for the latest experiment run.
+        """
+        logger.info("Generating all Ludwig visualizations…")
+
+        test_plots = {
+            'compare_performance',
+            'compare_classifiers_performance_from_prob',
+            'compare_classifiers_performance_from_pred',
+            'compare_classifiers_performance_changing_k',
+            'compare_classifiers_multiclass_multimetric',
+            'compare_classifiers_predictions',
+            'confidence_thresholding_2thresholds_2d',
+            'confidence_thresholding_2thresholds_3d',
+            'confidence_thresholding',
+            'confidence_thresholding_data_vs_acc',
+            'binary_threshold_vs_metric',
+            'roc_curves',
+            'roc_curves_from_test_statistics',
+            'calibration_1_vs_all',
+            'calibration_multiclass',
+            'confusion_matrix',
+            'frequency_vs_f1',
+        }
+        train_plots = {
+            'learning_curves',
+            'compare_classifiers_performance_subset',
+        }
+
+        # 1) find the most recent experiment directory
+        output_dir = Path(output_dir)
+        exp_dirs = sorted(
+            output_dir.glob("experiment_run*"),
+            key=lambda p: p.stat().st_mtime
+        )
+        if not exp_dirs:
+            logger.warning(f"No experiment run dirs found in {output_dir}")
+            return
+        exp_dir = exp_dirs[-1]
+
+        # 2) ensure viz output subfolder exists
+        viz_dir = exp_dir / "visualizations"
+        viz_dir.mkdir(exist_ok=True)
+        train_viz = viz_dir / "train"
+        test_viz = viz_dir / "test"
+        train_viz.mkdir(parents=True, exist_ok=True)
+        test_viz.mkdir(parents=True, exist_ok=True)
+
+        # 3) helper to check file existence
+        def _check(p: Path) -> Optional[str]:
+            return str(p) if p.exists() else None
+
+        # 4) gather standard Ludwig output files
+        training_stats = _check(exp_dir / "training_statistics.json")
+        test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME)
+        probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME)
+        gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME)
+
+        # 5) try to read original dataset & split file from description.json
+        dataset_path = None
+        split_file = None
+        desc = exp_dir / DESCRIPTION_FILE_NAME
+        if desc.exists():
+            with open(desc, "r") as f:
+                cfg = json.load(f)
+            dataset_path = _check(Path(cfg.get("dataset", "")))
+            split_file = _check(Path(get_split_path(cfg.get("dataset", ""))))
+
+        # 6) infer output feature name
+        output_feature = ""
+        if desc.exists():
+            try:
+                output_feature = cfg["config"]["output_features"][0]["name"]
+            except Exception:
+                pass
+        if not output_feature and test_stats:
+            with open(test_stats, "r") as f:
+                stats = json.load(f)
+            output_feature = next(iter(stats.keys()), "")
+
+        # 7) loop through every registered viz
+        viz_registry = get_visualizations_registry()
+        for viz_name, viz_func in viz_registry.items():
+            viz_dir_plot = None
+            if viz_name in train_plots:
+                viz_dir_plot = train_viz
+            elif viz_name in test_plots:
+                viz_dir_plot = test_viz
+
+            try:
+                viz_func(
+                    training_statistics=[training_stats] if training_stats else [],
+                    test_statistics=[test_stats] if test_stats else [],
+                    probabilities=[probs_path] if probs_path else [],
+                    output_feature_name=output_feature,
+                    ground_truth_split=2,
+                    top_n_classes=[0],
+                    top_k=3,
+                    ground_truth_metadata=gt_metadata,
+                    ground_truth=dataset_path,
+                    split_file=split_file,
+                    output_directory=str(viz_dir_plot),
+                    normalize=False,
+                    file_format="png",
+                )
+                logger.info(f"✔ Generated {viz_name}")
+            except Exception as e:
+                logger.warning(f"✘ Skipped {viz_name}: {e}")
+
+        logger.info(f"All visualizations written to {viz_dir}")
+
+    def generate_html_report(
+            self,
+            title: str,
+            output_dir: str,
+            config: dict,
+            split_info: str) -> Path:
+        """
+        Assemble an HTML report from visualizations under train_val/ and test/ folders.
+        """
+        cwd = Path.cwd()
+        report_name = title.lower().replace(" ", "_") + "_report.html"
+        report_path = cwd / report_name
+        output_dir = Path(output_dir)
+
+        # Find latest experiment dir
+        exp_dirs = sorted(output_dir.glob("experiment_run*"), key=lambda p: p.stat().st_mtime)
+        if not exp_dirs:
+            raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}")
+        exp_dir = exp_dirs[-1]
+
+        base_viz_dir = exp_dir / "visualizations"
+        train_viz_dir = base_viz_dir / "train"
+        test_viz_dir = base_viz_dir / "test"
+
+        html = get_html_template()
+        html += f"<h1>{title}</h1>"
+
+        metrics_html = ""
+
+        # Load and embed metrics table (training/val/test stats)
+        try:
+            train_stats_path = exp_dir / "training_statistics.json"
+            test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME
+            if train_stats_path.exists() and test_stats_path.exists():
+                with open(train_stats_path) as f:
+                    train_stats = json.load(f)
+                with open(test_stats_path) as f:
+                    test_stats = json.load(f)
+                output_feature = next(iter(train_stats.keys()), "")
+                if output_feature:
+                    metrics_html += format_stats_table_html(train_stats, test_stats)
+        except Exception as e:
+            logger.warning(f"Could not load stats for HTML report: {e}")
+
+        config_html = ""
+        training_progress = self.get_training_process(output_dir)
+        try:
+            config_html = format_config_table_html(config, split_info, training_progress)
+        except Exception as e:
+            logger.warning(f"Could not load config for HTML report: {e}")
+
+        def render_img_section(title: str, dir_path: Path) -> str:
+            if not dir_path.exists():
+                return f"<h2>{title}</h2><p><em>Directory not found.</em></p>"
+            imgs = sorted(dir_path.glob("*.png"))
+            if not imgs:
+                return f"<h2>{title}</h2><p><em>No plots found.</em></p>"
+
+            section_html = f"<h2 style='text-align: center;'>{title}</h2><div>"
+            for img in imgs:
+                b64 = encode_image_to_base64(str(img))
+                section_html += (
+                    f'<div class="plot" style="margin-bottom:20px;text-align:center;">'
+                    f"<h3>{img.stem.replace('_',' ').title()}</h3>"
+                    f'<img src="data:image/png;base64,{b64}" '
+                    'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />'
+                    "</div>"
+                )
+            section_html += "</div>"
+            return section_html
+
+        train_plots_html = render_img_section("Training & Validation Visualizations", train_viz_dir)
+        test_plots_html = render_img_section("Test Visualizations", test_viz_dir)
+        html += build_tabbed_html(config_html + metrics_html, train_plots_html, test_plots_html)
+        html += get_html_closing()
+
+        try:
+            with open(report_path, "w") as f:
+                f.write(html)
+            logger.info(f"HTML report generated at: {report_path}")
+        except Exception as e:
+            logger.error(f"Failed to write HTML report: {e}")
+            raise
+
+        return report_path
+
+
+class WorkflowOrchestrator:
+    """
+    Manages the image-classification workflow:
+      1. Creates temp dirs
+      2. Extracts images
+      3. Prepares data (CSV + splits)
+      4. Renders a backend config
+      5. Runs the experiment
+      6. Cleans up
+    """
+
+    def __init__(self, args: argparse.Namespace, backend: Backend):
+        self.args = args
+        self.backend = backend
+        self.temp_dir: Optional[Path] = None
+        self.image_extract_dir: Optional[Path] = None
+        logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}")
+
+    def _create_temp_dirs(self) -> None:
+        """Create temporary output and image extraction directories."""
+        try:
+            self.temp_dir = Path(tempfile.mkdtemp(
+                dir=self.args.output_dir,
+                prefix=TEMP_DIR_PREFIX
+            ))
+            self.image_extract_dir = self.temp_dir / "images"
+            self.image_extract_dir.mkdir()
+            logger.info(f"Created temp directory: {self.temp_dir}")
+        except Exception:
+            logger.error("Failed to create temporary directories", exc_info=True)
+            raise
+
+    def _extract_images(self) -> None:
+        """Extract images from ZIP into the temp image directory."""
+        if self.image_extract_dir is None:
+            raise RuntimeError("Temp image directory not initialized.")
+        logger.info(f"Extracting images from {self.args.image_zip} → {self.image_extract_dir}")
+        try:
+            with zipfile.ZipFile(self.args.image_zip, "r") as z:
+                z.extractall(self.image_extract_dir)
+            logger.info("Image extraction complete.")
+        except Exception:
+            logger.error("Error extracting zip file", exc_info=True)
+            raise
+
+    def _prepare_data(self) -> Tuple[Path, Dict[str, Any]]:
+        """
+        Load CSV, update image paths, handle splits, and write prepared CSV.
+        Returns:
+            final_csv_path: Path to the prepared CSV
+            split_config: Dict for backend split settings
+        """
+        if not self.temp_dir or not self.image_extract_dir:
+            raise RuntimeError("Temp dirs not initialized before data prep.")
+
+        # 1) Load
+        try:
+            df = pd.read_csv(self.args.csv_file)
+            logger.info(f"Loaded CSV: {self.args.csv_file}")
+        except Exception:
+            logger.error("Error loading CSV file", exc_info=True)
+            raise
+
+        # 2) Validate columns
+        required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME}
+        missing = required - set(df.columns)
+        if missing:
+            raise ValueError(f"Missing CSV columns: {', '.join(missing)}")
+
+        # 3) Update image paths
+        try:
+            df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply(
+                lambda p: str((self.image_extract_dir / p).resolve())
+            )
+        except Exception:
+            logger.error("Error updating image paths", exc_info=True)
+            raise
+
+        # 4) Handle splits
+        if SPLIT_COLUMN_NAME in df.columns:
+            df, split_config, split_info = self._process_fixed_split(df)
+        else:
+            logger.info("No split column; using random split")
+            split_config = {
+                "type": "random",
+                "probabilities": self.args.split_probabilities
+            }
+            split_info = (
+                f"No split column in CSV. Used random split: "
+                f"{[int(p*100) for p in self.args.split_probabilities]}% for train/val/test."
+            )
+
+        # 5) Write out prepared CSV
+        final_csv = TEMP_CSV_FILENAME
+        try:
+            df.to_csv(final_csv, index=False)
+            logger.info(f"Saved prepared data to {final_csv}")
+        except Exception:
+            logger.error("Error saving prepared CSV", exc_info=True)
+            raise
+
+        return final_csv, split_config, split_info
+
+    def _process_fixed_split(self, df: pd.DataFrame) -> Dict[str, Any]:
+        """Process a fixed split column (0=train,1=val,2=test)."""
+        logger.info(f"Fixed split column '{SPLIT_COLUMN_NAME}' detected.")
+        try:
+            col = df[SPLIT_COLUMN_NAME]
+            df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype(pd.Int64Dtype())
+            if df[SPLIT_COLUMN_NAME].isna().any():
+                logger.warning("Split column contains non-numeric/missing values.")
+
+            unique = set(df[SPLIT_COLUMN_NAME].dropna().unique())
+            logger.info(f"Unique split values: {unique}")
+
+            if unique == {0, 2}:
+                df = split_data_0_2(
+                    df, SPLIT_COLUMN_NAME,
+                    validation_size=self.args.validation_size,
+                    label_column=LABEL_COLUMN_NAME,
+                    random_state=self.args.random_seed
+                )
+                split_info = (
+                    "Detected a split column (with values 0 and 2) in the input CSV. "
+                    f"Used this column as a base and"
+                    f"reassigned {self.args.validation_size * 100:.1f}% "
+                    "of the training set (originally labeled 0) to validation (labeled 1)."
+                )
+
+                logger.info("Applied custom 0/2 split.")
+            elif unique.issubset({0, 1, 2}):
+                split_info = "Used user-defined split column from CSV."
+                logger.info("Using fixed split as-is.")
+            else:
+                raise ValueError(f"Unexpected split values: {unique}")
+
+            return df, {"type": "fixed", "column": SPLIT_COLUMN_NAME}, split_info
+
+        except Exception:
+            logger.error("Error processing fixed split", exc_info=True)
+            raise
+
+    def _cleanup_temp_dirs(self) -> None:
+        """Remove any temporary directories."""
+        if self.temp_dir and self.temp_dir.exists():
+            logger.info(f"Cleaning up temp directory: {self.temp_dir}")
+            shutil.rmtree(self.temp_dir, ignore_errors=True)
+        self.temp_dir = None
+        self.image_extract_dir = None
+
+    def run(self) -> None:
+        """Execute the full workflow end-to-end."""
+        logger.info("Starting workflow...")
+        self.args.output_dir.mkdir(parents=True, exist_ok=True)
+
+        try:
+            self._create_temp_dirs()
+            self._extract_images()
+            csv_path, split_cfg, split_info = self._prepare_data()
+
+            use_pretrained = self.args.use_pretrained or self.args.fine_tune
+
+            backend_args = {
+                "model_name": self.args.model_name,
+                "fine_tune": self.args.fine_tune,
+                "use_pretrained": use_pretrained,
+                "epochs": self.args.epochs,
+                "batch_size": self.args.batch_size,
+                "preprocessing_num_processes": self.args.preprocessing_num_processes,
+                "split_probabilities": self.args.split_probabilities,
+                "learning_rate": self.args.learning_rate,
+                "random_seed": self.args.random_seed,
+                "early_stop": self.args.early_stop,
+            }
+            yaml_str = self.backend.prepare_config(backend_args, split_cfg)
+
+            config_file = self.temp_dir / TEMP_CONFIG_FILENAME
+            config_file.write_text(yaml_str)
+            logger.info(f"Wrote backend config: {config_file}")
+
+            self.backend.run_experiment(
+                csv_path,
+                config_file,
+                self.args.output_dir,
+                self.args.random_seed
+            )
+            logger.info("Workflow completed successfully.")
+            self.backend.generate_plots(self.args.output_dir)
+            report_file = self.backend.generate_html_report(
+                "Image Classification Results",
+                self.args.output_dir,
+                backend_args,
+                split_info
+            )
+            logger.info(f"HTML report generated at: {report_file}")
+            self.backend.convert_parquet_to_csv(self.args.output_dir)
+            logger.info("Converted Parquet to CSV.")
+        except Exception:
+            logger.error("Workflow execution failed", exc_info=True)
+            raise
+
+        finally:
+            self._cleanup_temp_dirs()
+
+
+def parse_learning_rate(s):
+    try:
+        return float(s)
+    except (TypeError, ValueError):
+        return None
+
+
+class SplitProbAction(argparse.Action):
+    def __call__(self, parser, namespace, values, option_string=None):
+        # values is a list of three floats
+        train, val, test = values
+        total = train + val + test
+        if abs(total - 1.0) > 1e-6:
+            parser.error(
+                f"--split-probabilities must sum to 1.0; "
+                f"got {train:.3f} + {val:.3f} + {test:.3f} = {total:.3f}"
+            )
+        setattr(namespace, self.dest, values)
+
+
+def main():
+
+    parser = argparse.ArgumentParser(
+        description="Image Classification Learner with Pluggable Backends"
+    )
+    parser.add_argument(
+        "--csv-file", required=True, type=Path,
+        help="Path to the input CSV"
+    )
+    parser.add_argument(
+        "--image-zip", required=True, type=Path,
+        help="Path to the images ZIP"
+    )
+    parser.add_argument(
+        "--model-name", required=True,
+        choices=MODEL_ENCODER_TEMPLATES.keys(),
+        help="Which model template to use"
+    )
+    parser.add_argument(
+        "--use-pretrained", action="store_true",
+        help="Use pretrained weights for the model"
+    )
+    parser.add_argument(
+        "--fine-tune", action="store_true",
+        help="Enable fine-tuning"
+    )
+    parser.add_argument(
+        "--epochs", type=int, default=10,
+        help="Number of training epochs"
+    )
+    parser.add_argument(
+        "--early-stop", type=int, default=5,
+        help="Early stopping patience"
+    )
+    parser.add_argument(
+        "--batch-size", type=int,
+        help="Batch size (None = auto)"
+    )
+    parser.add_argument(
+        "--output-dir", type=Path, default=Path("learner_output"),
+        help="Where to write outputs"
+    )
+    parser.add_argument(
+        "--validation-size", type=float, default=0.15,
+        help="Fraction for validation (0.0–1.0)"
+    )
+    parser.add_argument(
+        "--preprocessing-num-processes", type=int,
+        default=max(1, os.cpu_count() // 2),
+        help="CPU processes for data prep"
+    )
+    parser.add_argument(
+        "--split-probabilities", type=float, nargs=3,
+        metavar=("train", "val", "test"),
+        action=SplitProbAction,
+        default=[0.7, 0.1, 0.2],
+        help="Random split proportions (e.g., 0.7 0.1 0.2). Only used if no split column is present."
+    )
+    parser.add_argument(
+        "--random-seed", type=int, default=42,
+        help="Random seed used for dataset splitting (default: 42)"
+    )
+    parser.add_argument(
+        "--learning-rate", type=parse_learning_rate, default=None,
+        help="Learning rate. If not provided, Ludwig will auto-select it."
+    )
+
+    args = parser.parse_args()
+
+    # -- Validation --
+    if not 0.0 <= args.validation_size <= 1.0:
+        parser.error("validation-size must be between 0.0 and 1.0")
+    if not args.csv_file.is_file():
+        parser.error(f"CSV not found: {args.csv_file}")
+    if not args.image_zip.is_file():
+        parser.error(f"ZIP not found: {args.image_zip}")
+
+    # --- Instantiate Backend and Orchestrator ---
+    # Use the new LudwigDirectBackend
+    backend_instance = LudwigDirectBackend()
+    orchestrator = WorkflowOrchestrator(args, backend_instance)
+
+    # --- Run Workflow ---
+    exit_code = 0
+    try:
+        orchestrator.run()
+        logger.info("Main script finished successfully.")
+    except Exception as e:
+        logger.error(f"Main script failed.{e}")
+        exit_code = 1
+    finally:
+        sys.exit(exit_code)
+
+
+if __name__ == '__main__':
+    try:
+        import ludwig
+        logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}")
+    except ImportError:
+        logger.error("Ludwig library not found. Please ensure Ludwig is installed ('pip install ludwig[image]')")
+        sys.exit(1)
+
+    main()