changeset 0:54b871dfc51e draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit b7411ff35b6228ccdfd36cd4ebd946c03ac7f7e9
author goeckslab
date Tue, 03 Jun 2025 21:22:11 +0000
parents
children
files image_learner.xml image_learner_cli.py test-data/image_classification_results_report_mnist.html test-data/mnist_subset.csv test-data/mnist_subset.zip utils.py
diffstat 6 files changed, 1813 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/image_learner.xml	Tue Jun 03 21:22:11 2025 +0000
@@ -0,0 +1,270 @@
+<tool id="image_learner" name="Image Learner for Classification" version="0.1.0" profile="22.05">
+    <description>trains and evaluates an image classification model</description>
+    <requirements>
+        <container type="docker">quay.io/goeckslab/galaxy-ludwig-gpu:0.10.1</container>
+    </requirements>
+    <required_files>
+        <include path="utils.py" />
+        <include path="image_learner_cli.py" />
+    </required_files>
+    <stdio>
+        <exit_code range="137" level="fatal_oom" description="Out of Memory" />
+        <exit_code range="1:" level="fatal" description="Error occurred. Please check Tool Standard Error" />
+    </stdio>
+    <command>
+        <![CDATA[
+            #import re
+
+            #if $input_csv
+
+            #set $sanitized_input_csv = re.sub('[^\w\-_\.]', '_', $input_csv.element_identifier.strip())
+            ln -sf '$input_csv' "./${sanitized_input_csv}";
+            #end if
+
+            python '$__tool_directory__/image_learner_cli.py'
+                --csv-file "./${sanitized_input_csv}"
+                --image-zip "$image_zip"
+                --model-name "$model_name"
+                #if $use_pretrained == "true"
+                    --use-pretrained
+                    #if $fine_tune == "true"
+                        --fine-tune
+                    #end if
+                #end if
+                #if $customize_defaults == "true"
+                    #if $epochs
+                        --epochs "$epochs"
+                    #end if
+                    #if $early_stop
+                        --early-stop "$early_stop"
+                    #end if
+                    #if $learning_rate_define == "true"
+                        --learning-rate "$learning_rate"
+                    #end if
+                    #if $batch_size_define == "true"
+                        --batch-size "$batch_size"
+                    #end if
+                    --split-probabilities "$train_split" "$val_split" "$test_split"   
+                #end if
+                --random-seed "$random_seed" 
+                --output-dir "." &&
+
+            mkdir -p '$output_model.extra_files_path' &&
+            cp -r experiment_run/model/*.json experiment_run/model/model_weights '$output_model.extra_files_path' &&
+
+            echo "Image Learner Classification Experiment is Done!"        
+        ]]>
+    </command>
+    
+    <inputs>
+        <param argument="input_csv" type="data" format="csv" optional="false" label="the metadata csv containing image_path column, label column and optional split column" />
+        <param name="image_zip" type="data" format="zip" optional="false" label="Image zip" help="Image zip file containing your image data"/>
+        <param name="model_name" type="select" optional="false" label="Select a model for your experiment" >
+           
+            <option value="resnet18">Resnet18</option>
+            <option value="resnet34">Resnet34</option>
+            <option value="resnet50">Resnet50</option>
+            <option value="resnet101">Resnet101</option>
+            <option value="resnet152">Resnet152</option>
+            <option value="resnext50_32x4d">Resnext50_32x4d</option>
+            <option value="resnext101_32x8d">Resnext101_32x8d</option>
+            <option value="resnext101_64x4d">Resnext101_64x4d</option>
+            <option value="resnext152_32x8d">Resnext152_32x8d</option>
+            <option value="wide_resnet50_2">Wide_resnet50_2</option>
+            <option value="wide_resnet101_2">Wide_resnet101_2</option>
+            <option value="wide_resnet103_2">Wide_resnet103_2</option>
+            <option value="efficientnet_b0">Efficientnet_b0</option>
+            <option value="efficientnet_b1">Efficientnet_b1</option>
+            <option value="efficientnet_b2">Efficientnet_b2</option>
+            <option value="efficientnet_b3">Efficientnet_b3</option>
+            <option value="efficientnet_b4">Efficientnet_b4</option>
+            <option value="efficientnet_b5">Efficientnet_b5</option>
+            <option value="efficientnet_b6">Efficientnet_b6</option>
+            <option value="efficientnet_b7">Efficientnet_b7</option>
+            <option value="efficientnet_v2_s">Efficientnet_v2_s</option>
+            <option value="efficientnet_v2_m">Efficientnet_v2_m</option>
+            <option value="efficientnet_v2_l">Efficientnet_v2_l</option>
+            <option value="regnet_y_400mf">Regnet_y_400mf</option>
+            <option value="regnet_y_800mf">Regnet_y_800mf</option>
+            <option value="regnet_y_1_6gf">Regnet_y_1_6gf</option>
+            <option value="regnet_y_3_2gf">Regnet_y_3_2gf</option>
+            <option value="regnet_y_8gf">Regnet_y_8gf</option>
+            <option value="regnet_y_16gf">Regnet_y_16gf</option>
+            <option value="regnet_y_32gf">Regnet_y_32gf</option>
+            <option value="regnet_y_128gf">Regnet_y_128gf</option>
+            <option value="regnet_x_400mf">Regnet_x_400mf</option>
+            <option value="regnet_x_800mf">Regnet_x_800mf</option>
+            <option value="regnet_x_1_6gf">Regnet_x_1_6gf</option>
+            <option value="regnet_x_3_2gf">Regnet_x_3_2gf</option>
+            <option value="regnet_x_8gf">Regnet_x_8gf</option>
+            <option value="regnet_x_16gf">Regnet_x_16gf</option>
+            <option value="regnet_x_32gf">Regnet_x_32gf</option>
+            <option value="vgg11">Vgg11</option>
+            <option value="vgg11_bn">Vgg11_bn</option>
+            <option value="vgg13">Vgg13</option>
+            <option value="vgg13_bn">Vgg13_bn</option>
+            <option value="vgg16">Vgg16</option>
+            <option value="vgg16_bn">Vgg16_bn</option>
+            <option value="vgg19">Vgg19</option>
+            <option value="vgg19_bn">Vgg19_bn</option>
+            <option value="shufflenet_v2_x0_5">Shufflenet_v2_x0_5</option>
+            <option value="shufflenet_v2_x1_0">Shufflenet_v2_x1_0</option>
+            <option value="shufflenet_v2_x1_5">Shufflenet_v2_x1_5</option>
+            <option value="shufflenet_v2_x2_0">Shufflenet_v2_x2_0</option>
+            <option value="squeezenet1_0">Squeezenet1_0</option>
+            <option value="squeezenet1_1">Squeezenet1_1</option>
+            <option value="swin_t">Swin_t</option>
+            <option value="swin_s">Swin_s</option>
+            <option value="swin_b">Swin_b</option>
+            <option value="swin_v2_t">Swin_v2_t</option>
+            <option value="swin_v2_s">Swin_v2_s</option>
+            <option value="swin_v2_b">Swin_v2_b</option>
+            <option value="vit_b_16">Vit_b_16</option>
+            <option value="vit_b_32">Vit_b_32</option>
+            <option value="vit_l_16">Vit_l_16</option>
+            <option value="vit_l_32">Vit_l_32</option>
+            <option value="vit_h_14">Vit_h_14</option>
+            <option value="convnext_tiny">Convnext_tiny</option>
+            <option value="convnext_small">Convnext_small</option>
+            <option value="convnext_base">Convnext_base</option>
+            <option value="convnext_large">Convnext_large</option>
+            <option value="maxvit_t">Maxvit_t</option>
+            <option value="alexnet">Alexnet</option>
+            <option value="googlenet">Googlenet</option>
+            <option value="inception_v3">Inception_v3</option>
+            <option value="mobilenet_v2">Mobilenet_v2</option>
+            <option value="mobilenet_v3_large">Mobilenet_v3_large</option>
+            <option value="mobilenet_v3_small">Mobilenet_v3_small</option>
+        </param>
+
+        <conditional name="scratch_fine_tune">
+            <param name="use_pretrained" type="select"
+                label="Use pretrained weights?"
+                help="If you select No, the encoder, combiner, and decoder will all be initialized and trained from scratch
+               (e.g. when your images are very different from ImageNet or no suitable pretrained model exists).">
+                <option value="false">No</option>
+                <option value="true" selected="true">Yes</option>
+            </param>
+            <when value="true">
+                <param name="fine_tune" type="select" label="Fine tune the encoder?"
+                    help="Whether to fine-tune the encoder (the combiner and decoder will be fine-tuned anyway)" >
+                    <option value="false" >No</option>
+                    <option value="true" selected="true">Yes</option>
+                </param>
+            </when>
+            <when value="false">
+                <!-- No additional parameters to show if the user selects 'No' -->
+            </when>
+        </conditional>
+        <param argument="random_seed" type="integer" value="42" optional="true" label="Random seed (set for reproducibility)" min="0" max="999999"/>
+        <conditional name="advanced_settings">
+            <param name="customize_defaults" type="select" label="Customize Default Settings?" help="Select yes if you want to customize the default settings of the experiment.">
+                <option value="false" selected="true">No</option>
+                <option value="true">Yes</option>
+            </param>
+            <when value="true">
+                <param name="epochs" type="integer" value="10" min="1" max="99999" label="Epochs" help="Total number of full passes through the training dataset. Higher values may improve accuracy but increase training time. Default: 10." />
+                <param name="early_stop" type="integer" value="5" min="1" max="99999" label="Early Stop" help="Number of epochs with no improvement after which training will be stopped. Default: 5." />
+                <conditional name="learning_rate_condition">
+                    <param name="learning_rate_define" type="select" label="Define an initial learning rate?" help="Want to change the initial learning rate from default to a number? See ludwig.ai for more info. Default: No" >
+                        <option value="false" selected="true" >No</option>
+                        <option value="true">Yes</option>
+                    </param>
+                    <when value="true">
+                        <param name="learning_rate" type="float" value="0.001" min="0.0001" max="1.0" label="Learning Rate" help="Initial learning rate for the optimizer. Default: 0.001." />
+                    </when>
+                    <when value="false">
+                        <!-- No additional parameters to show if the user selects 'No' -->
+                    </when>
+                </conditional>
+                <conditional name="batch_size_condition">
+                    <param name="batch_size_define" type="select" label="Define your batch size?" help="Want to change the batch size from auto to a number? See ludwig.ai for more info. Default: No" >
+                        <option value="false" selected="true" >No</option>
+                        <option value="true">Yes</option>
+                    </param>
+                    <when value="true">
+                        <param name="batch_size" type="integer" value="32" min="1" max="99999" label="Batch Size" help="Number of samples per gradient update. Default: 32." />
+                    </when>
+                    <when value="false">
+                        <!-- No additional parameters to show if the user selects 'No' -->
+                    </when>
+                </conditional>
+                <param name="train_split" type="float"
+                        label="Training split proportion (only works if no split column in the metadata csv)"
+                        value="0.7"
+                        help="Fraction of data for training (e.g., 0.7). train split + val split + test split should = 1"/>
+                <param name="val_split"   type="float"
+                        label="Validation split proportion (only works if no split column in the metadata csv)"
+                        value="0.1"
+                        help="Fraction of data for validation (e.g., 0.1). train split + val split + test split should = 1"/>
+                <param name="test_split"  type="float"
+                        label="Test split proportion (only works if no split column in the metadata csv)"
+                        value="0.2"
+                        help="Fraction of data for testing (e.g., 0.2) train split + val split + test split should = 1."/>
+            </when>
+            <when value="false">
+                <!-- No additional parameters to show if the user selects 'No' -->
+            </when>
+        </conditional>    
+    </inputs>       
+    <outputs>
+        <data format="ludwig_model" name="output_model" label="${tool.name} trained model on ${on_string}" />
+        <data format="html" name="output_report" from_work_dir="image_classification_results_report.html" label="${tool.name} report on ${on_string}" />
+        <collection type="list" name="output_pred_csv" label="${tool.name} predictions CSVs/experiment stats/plots on ${on_string}" >
+            <discover_datasets pattern="(?P&lt;designation&gt;predictions\.csv)" format="csv" directory="experiment_run" />
+            <discover_datasets pattern="(?P&lt;designation&gt;.+)\.json" format="json" directory="experiment_run" />
+            <discover_datasets pattern="(?P&lt;designation&gt;.+)\.png" format="png" directory="experiment_run/visualizations/train" />
+            <discover_datasets pattern="(?P&lt;designation&gt;.+)\.png" format="png" directory="experiment_run/visualizations/test" />
+        </collection>
+    </outputs>
+    <tests>
+        <test expect_num_outputs="3">
+            <param name="input_csv" value="mnist_subset.csv" ftype="csv" />
+            <param name="image_zip" value="mnist_subset.zip" ftype="zip" />
+            <param name="model_name" value="resnet18" />
+            <output name="output_report" file="image_classification_results_report_mnist.html" compare="sim_size" delta="20000" >
+                <assert_contents>
+                    <has_text text="Epochs" />
+                </assert_contents>
+            </output>
+
+            <output_collection name="output_pred_csv" type="list" >
+                <element name="predictions.csv" >
+                    <assert_contents>
+                        <has_n_columns n="1" />
+                    </assert_contents>
+                </element>
+            </output_collection>
+        </test>
+    </tests>
+    <help>
+        <![CDATA[
+**What it does**
+Image Learner for Classification: trains and evaluates an image classification model.
+It uses the metadata csv to find the image paths and labels. 
+The metadata csv should contain a column with the name 'image_path' and a column with the name 'label'.
+Optionally, you can also add a column with the name 'split' to specify which split each row belongs to (train, val, test). 
+If you do not provide a split column, the tool will automatically split the data into train, val, and test sets based on the proportions you specify or [0.7, 0.1, 0.2] by default.
+
+
+**Outputs**
+The tool will output a trained model in the form of a ludwig_model file,
+a report in the form of an HTML file, and a collection of CSV/json/png files containing the predictions, experiment stats and visualizations.
+The HTML report contains metrics and experiment setup parameters, train/validation plots, and test plots.
+
+        ]]>
+    </help>
+    <citations>
+            <citation type="bibtex">
+@misc{https://doi.org/10.48550/arxiv.1909.07930,
+    doi = {10.48550/ARXIV.1909.07930},
+    url = {https://arxiv.org/abs/1909.07930},
+    author = {Molino, Piero and Dudin, Yaroslav and Miryala, Sai Sumanth},
+    title = {Ludwig: a type-based declarative deep learning toolbox},
+    publisher = {arXiv},
+    year = {2019},
+    copyright = {arXiv.org perpetual, non-exclusive license}
+}
+            </citation>
+        </citations>
+</tool>
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/image_learner_cli.py	Tue Jun 03 21:22:11 2025 +0000
@@ -0,0 +1,1137 @@
+#!/usr/bin/env python3
+import argparse
+import json
+import logging
+import os
+import shutil
+import sys
+import tempfile
+import zipfile
+from pathlib import Path
+from typing import Any, Dict, Optional, Protocol, Tuple
+
+import pandas as pd
+import yaml
+from ludwig.globals import (
+    DESCRIPTION_FILE_NAME,
+    PREDICTIONS_PARQUET_FILE_NAME,
+    TEST_STATISTICS_FILE_NAME,
+    TRAIN_SET_METADATA_FILE_NAME,
+)
+from ludwig.utils.data_utils import get_split_path
+from ludwig.visualize import get_visualizations_registry
+from sklearn.model_selection import train_test_split
+from utils import encode_image_to_base64, get_html_closing, get_html_template
+
+# --- Constants ---
+# Column names expected in the metadata CSV; the tool help in
+# image_learner.xml documents 'image_path' and 'label' as required columns
+# and 'split' as optional.
+SPLIT_COLUMN_NAME = 'split'
+LABEL_COLUMN_NAME = 'label'
+IMAGE_PATH_COLUMN_NAME = 'image_path'
+# Fallback split proportions; the tool help lists them in train, val, test
+# order ([0.7, 0.1, 0.2]).
+DEFAULT_SPLIT_PROBABILITIES = [0.7, 0.1, 0.2]
+# Names of scratch artifacts; their use is outside this part of the file
+# (presumably written into a temporary working directory -- confirm below).
+TEMP_CSV_FILENAME = "processed_data_for_ludwig.csv"
+TEMP_CONFIG_FILENAME = "ludwig_config.yaml"
+TEMP_DIR_PREFIX = "ludwig_api_work_"
+# Maps a model name to an encoder specification.  Apart from 'stacked_cnn'
+# (a bare string with no matching <option> in image_learner.xml), every key
+# mirrors a `model_name` option of the tool, and every value is a dict with
+# a 'type' and, for families with several sizes, a 'model_variant'.
+# Presumably these dicts are inserted as the Ludwig image encoder config --
+# confirm against the config-building code.
+MODEL_ENCODER_TEMPLATES: Dict[str, Any] = {
+    'stacked_cnn': 'stacked_cnn',
+    'resnet18': {'type': 'resnet', 'model_variant': 18},
+    'resnet34': {'type': 'resnet', 'model_variant': 34},
+    'resnet50': {'type': 'resnet', 'model_variant': 50},
+    'resnet101': {'type': 'resnet', 'model_variant': 101},
+    'resnet152': {'type': 'resnet', 'model_variant': 152},
+    'resnext50_32x4d': {'type': 'resnext', 'model_variant': '50_32x4d'},
+    'resnext101_32x8d': {'type': 'resnext', 'model_variant': '101_32x8d'},
+    'resnext101_64x4d': {'type': 'resnext', 'model_variant': '101_64x4d'},
+    # NOTE(review): '152_32x8d' may not be a published ResNeXt variant --
+    # confirm the encoder accepts it before offering it in the UI.
+    'resnext152_32x8d': {'type': 'resnext', 'model_variant': '152_32x8d'},
+    'wide_resnet50_2': {'type': 'wide_resnet', 'model_variant': '50_2'},
+    'wide_resnet101_2': {'type': 'wide_resnet', 'model_variant': '101_2'},
+    # NOTE(review): '103_2' looks like a typo-derived variant (only 50_2 and
+    # 101_2 are commonly available) -- confirm or remove together with the
+    # matching <option> in image_learner.xml.
+    'wide_resnet103_2': {'type': 'wide_resnet', 'model_variant': '103_2'},
+    'efficientnet_b0': {'type': 'efficientnet', 'model_variant': 'b0'},
+    'efficientnet_b1': {'type': 'efficientnet', 'model_variant': 'b1'},
+    'efficientnet_b2': {'type': 'efficientnet', 'model_variant': 'b2'},
+    'efficientnet_b3': {'type': 'efficientnet', 'model_variant': 'b3'},
+    'efficientnet_b4': {'type': 'efficientnet', 'model_variant': 'b4'},
+    'efficientnet_b5': {'type': 'efficientnet', 'model_variant': 'b5'},
+    'efficientnet_b6': {'type': 'efficientnet', 'model_variant': 'b6'},
+    'efficientnet_b7': {'type': 'efficientnet', 'model_variant': 'b7'},
+    'efficientnet_v2_s': {'type': 'efficientnet', 'model_variant': 'v2_s'},
+    'efficientnet_v2_m': {'type': 'efficientnet', 'model_variant': 'v2_m'},
+    'efficientnet_v2_l': {'type': 'efficientnet', 'model_variant': 'v2_l'},
+    'regnet_y_400mf': {'type': 'regnet', 'model_variant': 'y_400mf'},
+    'regnet_y_800mf': {'type': 'regnet', 'model_variant': 'y_800mf'},
+    'regnet_y_1_6gf': {'type': 'regnet', 'model_variant': 'y_1_6gf'},
+    'regnet_y_3_2gf': {'type': 'regnet', 'model_variant': 'y_3_2gf'},
+    'regnet_y_8gf': {'type': 'regnet', 'model_variant': 'y_8gf'},
+    'regnet_y_16gf': {'type': 'regnet', 'model_variant': 'y_16gf'},
+    'regnet_y_32gf': {'type': 'regnet', 'model_variant': 'y_32gf'},
+    'regnet_y_128gf': {'type': 'regnet', 'model_variant': 'y_128gf'},
+    'regnet_x_400mf': {'type': 'regnet', 'model_variant': 'x_400mf'},
+    'regnet_x_800mf': {'type': 'regnet', 'model_variant': 'x_800mf'},
+    'regnet_x_1_6gf': {'type': 'regnet', 'model_variant': 'x_1_6gf'},
+    'regnet_x_3_2gf': {'type': 'regnet', 'model_variant': 'x_3_2gf'},
+    'regnet_x_8gf': {'type': 'regnet', 'model_variant': 'x_8gf'},
+    'regnet_x_16gf': {'type': 'regnet', 'model_variant': 'x_16gf'},
+    'regnet_x_32gf': {'type': 'regnet', 'model_variant': 'x_32gf'},
+    'vgg11': {'type': 'vgg', 'model_variant': 11},
+    'vgg11_bn': {'type': 'vgg', 'model_variant': '11_bn'},
+    'vgg13': {'type': 'vgg', 'model_variant': 13},
+    'vgg13_bn': {'type': 'vgg', 'model_variant': '13_bn'},
+    'vgg16': {'type': 'vgg', 'model_variant': 16},
+    'vgg16_bn': {'type': 'vgg', 'model_variant': '16_bn'},
+    'vgg19': {'type': 'vgg', 'model_variant': 19},
+    'vgg19_bn': {'type': 'vgg', 'model_variant': '19_bn'},
+    'shufflenet_v2_x0_5': {'type': 'shufflenet_v2', 'model_variant': 'x0_5'},
+    'shufflenet_v2_x1_0': {'type': 'shufflenet_v2', 'model_variant': 'x1_0'},
+    'shufflenet_v2_x1_5': {'type': 'shufflenet_v2', 'model_variant': 'x1_5'},
+    'shufflenet_v2_x2_0': {'type': 'shufflenet_v2', 'model_variant': 'x2_0'},
+    'squeezenet1_0': {'type': 'squeezenet', 'model_variant': '1_0'},
+    'squeezenet1_1': {'type': 'squeezenet', 'model_variant': '1_1'},
+    'swin_t': {'type': 'swin_transformer', 'model_variant': 't'},
+    'swin_s': {'type': 'swin_transformer', 'model_variant': 's'},
+    'swin_b': {'type': 'swin_transformer', 'model_variant': 'b'},
+    'swin_v2_t': {'type': 'swin_transformer', 'model_variant': 'v2_t'},
+    'swin_v2_s': {'type': 'swin_transformer', 'model_variant': 'v2_s'},
+    'swin_v2_b': {'type': 'swin_transformer', 'model_variant': 'v2_b'},
+    'vit_b_16': {'type': 'vision_transformer', 'model_variant': 'b_16'},
+    'vit_b_32': {'type': 'vision_transformer', 'model_variant': 'b_32'},
+    'vit_l_16': {'type': 'vision_transformer', 'model_variant': 'l_16'},
+    'vit_l_32': {'type': 'vision_transformer', 'model_variant': 'l_32'},
+    'vit_h_14': {'type': 'vision_transformer', 'model_variant': 'h_14'},
+    'convnext_tiny': {'type': 'convnext', 'model_variant': 'tiny'},
+    'convnext_small': {'type': 'convnext', 'model_variant': 'small'},
+    'convnext_base': {'type': 'convnext', 'model_variant': 'base'},
+    'convnext_large': {'type': 'convnext', 'model_variant': 'large'},
+    'maxvit_t': {'type': 'maxvit', 'model_variant': 't'},
+    # Single-size families: no 'model_variant' key.
+    'alexnet': {'type': 'alexnet'},
+    'googlenet': {'type': 'googlenet'},
+    'inception_v3': {'type': 'inception_v3'},
+    'mobilenet_v2': {'type': 'mobilenet_v2'},
+    'mobilenet_v3_large': {'type': 'mobilenet_v3_large'},
+    'mobilenet_v3_small': {'type': 'mobilenet_v3_small'},
+}
+
+# --- Logging Setup ---
+# Configures the root logger at import time (INFO level, timestamped
+# records) and creates the module-wide "ImageLearner" logger.
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s %(levelname)s %(name)s: %(message)s'
+)
+logger = logging.getLogger("ImageLearner")
+
+
+def format_config_table_html(
+        config: dict,
+        split_info: Optional[str] = None,
+        training_progress: Optional[dict] = None) -> str:
+    """Render the "Training Setup" section of the report as an HTML table.
+
+    Args:
+        config: Experiment settings. The keys ``model_name``, ``epochs``,
+            ``batch_size``, ``fine_tune``, ``use_pretrained``,
+            ``learning_rate``, ``random_seed`` and ``early_stop`` are shown,
+            in that order. A missing key is displayed as ``N/A``; a key whose
+            value is ``None`` is omitted, except ``batch_size`` and
+            ``learning_rate``, where ``None`` means "chosen automatically".
+        split_info: Optional text for an extra "Data Split" row.
+        training_progress: Optional mapping read for ``batch_size``,
+            ``learning_rate`` and ``epoch``; used to show the automatically
+            resolved values and to report early stopping.
+
+    Returns:
+        An HTML fragment: heading, table and a short footer paragraph.
+    """
+    display_keys = [
+        "model_name",
+        "epochs",
+        "batch_size",
+        "fine_tune",
+        "use_pretrained",
+        "learning_rate",
+        "random_seed",
+        "early_stop",
+    ]
+    trainer_docs_link = (
+        "See <a href='https://ludwig.ai/latest/configuration/trainer/#trainer-parameters' "
+        "target='_blank'>Ludwig Trainer Parameters</a> for details."
+    )
+
+    def _is_number(value: Any) -> bool:
+        # bool is an int subclass but is never a meaningful epoch count.
+        return isinstance(value, (int, float)) and not isinstance(value, bool)
+
+    def _row(label: str, value: Any) -> str:
+        return (
+            f"<tr>"
+            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>"
+            f"{label}</td>"
+            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{value}</td>"
+            f"</tr>"
+        )
+
+    rows = []
+
+    for key in display_keys:
+        val = config.get(key, "N/A")
+
+        if key == "batch_size":
+            if val is not None:
+                try:
+                    val = int(val)
+                except (TypeError, ValueError):
+                    # Non-numeric placeholder such as "N/A" (missing key) or
+                    # "auto": show it unchanged instead of crashing.
+                    pass
+            elif training_progress:
+                resolved_val = training_progress.get("batch_size")
+                val = (
+                    "Auto-selected batch size by Ludwig:<br>"
+                    f"<span style='font-size: 0.85em;'>{resolved_val}</span><br>"
+                )
+            else:
+                val = "auto"
+
+        if key == "learning_rate":
+            if val is None or val == "auto":
+                if training_progress:
+                    resolved_val = training_progress.get("learning_rate")
+                    val = (
+                        "Auto-selected learning rate by Ludwig:<br>"
+                        f"<span style='font-size: 0.85em;'>{resolved_val if resolved_val else val}</span><br>"
+                        "<span style='font-size: 0.85em;'>"
+                        "Based on model architecture and training setup (e.g., fine-tuning).<br>"
+                        + trainer_docs_link +
+                        "</span>"
+                    )
+                else:
+                    val = (
+                        "Auto-selected by Ludwig<br>"
+                        "<span style='font-size: 0.85em;'>"
+                        "Automatically tuned based on architecture and dataset.<br>"
+                        + trainer_docs_link +
+                        "</span>"
+                    )
+            else:
+                try:
+                    val = f"{float(val):.6f}"
+                except (TypeError, ValueError):
+                    # e.g. the "N/A" placeholder for a missing key.
+                    pass
+
+        if key == "epochs":
+            # Only compare when both sides are numbers; a missing "epochs"
+            # key yields the string "N/A", which cannot be ordered.
+            if (
+                training_progress
+                and "epoch" in training_progress
+                and _is_number(val)
+                and _is_number(training_progress["epoch"])
+                and val > training_progress["epoch"]
+            ):
+                val = (
+                    f"Because of early stopping: the training "
+                    f"stopped at epoch {training_progress['epoch']}"
+                )
+
+        if val is None:
+            continue
+        rows.append(_row(key.replace('_', ' ').title(), val))
+
+    if split_info:
+        rows.append(_row("Data Split", split_info))
+
+    return (
+        "<h2 style='text-align: center;'>Training Setup</h2>"
+        "<div style='display: flex; justify-content: center;'>"
+        "<table style='border-collapse: collapse; width: 60%; table-layout: auto;'>"
+        "<thead><tr>"
+        "<th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>Parameter</th>"
+        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Value</th>"
+        "</tr></thead><tbody>" + "".join(rows) + "</tbody></table></div><br>"
+        "<p style='text-align: center; font-size: 0.9em;'>"
+        "Model trained using Ludwig.<br>"
+        "If you want to learn more about Ludwig default settings, "
+        "please check their <a href='https://ludwig.ai' target='_blank'>website(ludwig.ai)</a>."
+        "</p><hr>"
+    )
+
+
+def format_stats_table_html(training_stats: dict, test_stats: dict) -> str:
+    """Build the "Model Performance Summary" HTML table.
+
+    Metrics are read from ``training_stats["training"]["label"]``,
+    ``training_stats["validation"]["label"]`` and ``test_stats["label"]``.
+    For each metric, a list contributes its last element and a bare number
+    contributes itself. A metric gets a row only when all three splits
+    provide a value; rows are sorted by metric name. When no metric
+    qualifies, a short "no metric values" paragraph is returned instead.
+    """
+    per_split = (
+        training_stats.get("training", {}).get("label", {}),
+        training_stats.get("validation", {}).get("label", {}),
+        test_stats.get("label", {}),
+    )
+
+    def _latest(raw):
+        # Non-empty list -> most recent entry; plain number -> itself.
+        if isinstance(raw, list) and raw:
+            return raw[-1]
+        if isinstance(raw, (int, float)):
+            return raw
+        return None
+
+    name_cell = "<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>"
+    value_cell = "<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>"
+
+    table_rows = []
+    for metric_name in sorted(set().union(*per_split)):
+        values = [_latest(split_metrics.get(metric_name)) for split_metrics in per_split]
+        if any(value is None for value in values):
+            continue
+        cells = "".join(f"{value_cell}{value:.4f}</td>" for value in values)
+        table_rows.append(f"<tr>{name_cell}{metric_name}</td>{cells}</tr>")
+
+    if len(table_rows) == 0:
+        return "<p><em>No metric values found.</em></p>"
+
+    header = (
+        "<h2 style='text-align: center;'>Model Performance Summary</h2>"
+        "<div style='display: flex; justify-content: center;'>"
+        "<table style='border-collapse: collapse; width: 80%; table-layout: fixed;'>"
+        "<colgroup>"
+        "<col style='width: 40%;'>"
+        "<col style='width: 20%;'>"
+        "<col style='width: 20%;'>"
+        "<col style='width: 20%;'>"
+        "</colgroup>"
+        "<thead><tr>"
+        "<th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>Metric</th>"
+        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Train</th>"
+        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Validation</th>"
+        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Test</th>"
+        "</tr></thead><tbody>"
+    )
+    return header + "".join(table_rows) + "</tbody></table></div><br>"
+
+
+def build_tabbed_html(
+        metrics_html: str,
+        train_viz_html: str,
+        test_viz_html: str) -> str:
+    """Assemble the three report sections into a tabbed HTML fragment.
+
+    The fragment carries its own CSS, a tab bar with three tabs
+    ("Config & Metrics", "Train/Validation Plots", "Test Plots"), one
+    content panel per argument, and a small ``showTab`` script that switches
+    the ``active`` class between tabs and panels. The metrics panel is the
+    one marked active initially. The arguments are inserted verbatim.
+    """
+    # Static styling and tab bar. Plain (non-f) string: braces are literal.
+    styles_and_tabs = """
+<style>
+.tabs {
+  display: flex;
+  border-bottom: 2px solid #ccc;
+  margin-bottom: 1rem;
+}
+.tab {
+  padding: 10px 20px;
+  cursor: pointer;
+  border: 1px solid #ccc;
+  border-bottom: none;
+  background: #f9f9f9;
+  margin-right: 5px;
+  border-top-left-radius: 8px;
+  border-top-right-radius: 8px;
+}
+.tab.active {
+  background: white;
+  font-weight: bold;
+}
+.tab-content {
+  display: none;
+  padding: 20px;
+  border: 1px solid #ccc;
+  border-top: none;
+}
+.tab-content.active {
+  display: block;
+}
+</style>
+
+<div class="tabs">
+  <div class="tab active" onclick="showTab('metrics')"> Config & Metrics</div>
+  <div class="tab" onclick="showTab('trainval')"> Train/Validation Plots</div>
+  <div class="tab" onclick="showTab('test')"> Test Plots</div>
+</div>
+
+"""
+    # One panel per section; ids must match the showTab(...) arguments above.
+    panels = (
+        '<div id="metrics" class="tab-content active">\n'
+        f"  {metrics_html}\n"
+        "</div>\n"
+        '<div id="trainval" class="tab-content">\n'
+        f"  {train_viz_html}\n"
+        "</div>\n"
+        '<div id="test" class="tab-content">\n'
+        f"  {test_viz_html}\n"
+        "</div>\n"
+    )
+    # Tab-switching behaviour. Plain string again, so `${id}` is literal JS.
+    switcher = """
+<script>
+function showTab(id) {
+  document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active'));
+  document.querySelectorAll('.tab').forEach(el => el.classList.remove('active'));
+  document.getElementById(id).classList.add('active');
+  document.querySelector(`.tab[onclick*="${id}"]`).classList.add('active');
+}
+</script>
+"""
+    return styles_and_tabs + panels + switcher
+
+
+def split_data_0_2(
+    df: pd.DataFrame,
+    split_column: str,
+    validation_size: float = 0.15,
+    random_state: int = 42,
+    label_column: Optional[str] = None,
+) -> pd.DataFrame:
+    """
+    Carve a validation subset (split value 1) out of the rows marked 0.
+
+    Intended for a DataFrame whose ``split_column`` holds only 0 and 2:
+    a ``validation_size`` fraction of the 0 rows is re-labelled 1, rows with
+    any other value are left untouched. The input is not modified; a copy
+    with an integer-typed split column is returned.
+
+    When ``label_column`` names an existing column, the split is stratified
+    on it if the 0 rows contain more than one class and the rarest class
+    times ``validation_size`` is at least 1; otherwise a warning is logged
+    and the split is unstratified. A ``ValueError`` from the first split
+    attempt triggers one retry without stratification.
+    """
+    result = df.copy()
+    # Coerce to numbers, then to int, so the comparisons below are exact.
+    result[split_column] = pd.to_numeric(result[split_column], errors="coerce").astype(int)
+
+    train_rows = result.index[result[split_column] == 0].tolist()
+    if len(train_rows) == 0:
+        logger.info("No rows with split=0; nothing to do.")
+        return result
+
+    # Decide whether stratifying on the label is feasible.
+    stratify_labels = None
+    if label_column and label_column in result.columns:
+        train_labels = result.loc[train_rows, label_column]
+        class_counts = train_labels.value_counts()
+        enough_per_class = class_counts.size > 1 and (class_counts.min() * validation_size) >= 1
+        if enough_per_class:
+            stratify_labels = train_labels
+        else:
+            logger.warning("Cannot stratify (too few labels); splitting without stratify.")
+
+    # Degenerate fractions: nothing, or everything, goes to validation.
+    if validation_size <= 0:
+        logger.info("validation_size <= 0; keeping all as train.")
+        return result
+    if validation_size >= 1:
+        logger.info("validation_size >= 1; moving all train → validation.")
+        result.loc[train_rows, split_column] = 1
+        return result
+
+    shared_kwargs = {"test_size": validation_size, "random_state": random_state}
+    try:
+        kept_rows, validation_rows = train_test_split(
+            train_rows, stratify=stratify_labels, **shared_kwargs
+        )
+    except ValueError as e:
+        logger.warning(f"Stratified split failed ({e}); retrying without stratify.")
+        kept_rows, validation_rows = train_test_split(
+            train_rows, stratify=None, **shared_kwargs
+        )
+
+    # Write the new assignment; rows marked 2 are never touched.
+    result.loc[kept_rows, split_column] = 0
+    result.loc[validation_rows, split_column] = 1
+
+    # Normalise the dtype once more after the assignments.
+    result[split_column] = result[split_column].astype(int)
+    return result
+
+
+class Backend(Protocol):
+    """Interface for a machine learning backend."""
+    def prepare_config(
+        self,
+        config_params: Dict[str, Any],
+        split_config: Dict[str, Any]
+    ) -> str:
+        ...
+
+    def run_experiment(
+        self,
+        dataset_path: Path,
+        config_path: Path,
+        output_dir: Path,
+        random_seed: int,
+    ) -> None:
+        ...
+
+    def generate_plots(
+        self,
+        output_dir: Path
+    ) -> None:
+        ...
+
+    def generate_html_report(
+        self,
+        title: str,
+        output_dir: str
+    ) -> Path:
+        ...
+
+
+class LudwigDirectBackend:
+    """
+    Backend for running Ludwig experiments directly via the internal experiment_cli function.
+    """
+
+    def prepare_config(
+        self,
+        config_params: Dict[str, Any],
+        split_config: Dict[str, Any],
+    ) -> str:
+        """
+        Build and serialize the Ludwig YAML configuration.
+        """
+        logger.info("LudwigDirectBackend: Preparing YAML configuration.")
+
+        model_name = config_params.get("model_name", "resnet18")
+        use_pretrained = config_params.get("use_pretrained", False)
+        fine_tune = config_params.get("fine_tune", False)
+        epochs = config_params.get("epochs", 10)
+        batch_size = config_params.get("batch_size")
+        num_processes = config_params.get("preprocessing_num_processes", 1)
+        early_stop = config_params.get("early_stop", None)
+        learning_rate = config_params.get("learning_rate")
+        learning_rate = "auto" if learning_rate is None else float(learning_rate)
+        trainable = fine_tune or (not use_pretrained)
+        if not use_pretrained and not trainable:
+            logger.warning("trainable=False; use_pretrained=False is ignored.")
+            logger.warning("Setting trainable=True to train the model from scratch.")
+            trainable = True
+
+        # Encoder setup
+        raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name)
+        if isinstance(raw_encoder, dict):
+            encoder_config = {
+                **raw_encoder,
+                "use_pretrained": use_pretrained,
+                "trainable": trainable,
+            }
+        else:
+            encoder_config = {"type": raw_encoder}
+
+        # Trainer & optimizer
+        # optimizer = {"type": "adam", "learning_rate": 5e-5} if fine_tune else {"type": "adam"}
+        batch_size_cfg = batch_size or "auto"
+
+        conf: Dict[str, Any] = {
+            "model_type": "ecd",
+            "input_features": [
+                {
+                    "name": IMAGE_PATH_COLUMN_NAME,
+                    "type": "image",
+                    "encoder": encoder_config,
+                }
+            ],
+            "output_features": [
+                {"name": LABEL_COLUMN_NAME, "type": "category"}
+            ],
+            "combiner": {"type": "concat"},
+            "trainer": {
+                "epochs": epochs,
+                "early_stop": early_stop,
+                "batch_size": batch_size_cfg,
+                "learning_rate": learning_rate,
+            },
+            "preprocessing": {
+                "split": split_config,
+                "num_processes": num_processes,
+                "in_memory": False,
+            },
+        }
+
+        logger.debug("LudwigDirectBackend: Config dict built.")
+        try:
+            yaml_str = yaml.dump(conf, sort_keys=False, indent=2)
+            logger.info("LudwigDirectBackend: YAML config generated.")
+            return yaml_str
+        except Exception:
+            logger.error("LudwigDirectBackend: Failed to serialize YAML.", exc_info=True)
+            raise
+
+    def run_experiment(
+        self,
+        dataset_path: Path,
+        config_path: Path,
+        output_dir: Path,
+        random_seed: int = 42,
+    ) -> None:
+        """
+        Invoke Ludwig's internal experiment_cli function to run the experiment.
+        """
+        logger.info("LudwigDirectBackend: Starting experiment execution.")
+
+        try:
+            from ludwig.experiment import experiment_cli
+        except ImportError as e:
+            logger.error(
+                "LudwigDirectBackend: Could not import experiment_cli.",
+                exc_info=True
+            )
+            raise RuntimeError("Ludwig import failed.") from e
+
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        try:
+            experiment_cli(
+                dataset=str(dataset_path),
+                config=str(config_path),
+                output_directory=str(output_dir),
+                random_seed=random_seed,
+            )
+            logger.info(f"LudwigDirectBackend: Experiment completed. Results in {output_dir}")
+        except TypeError as e:
+            logger.error(
+                "LudwigDirectBackend: Argument mismatch in experiment_cli call.",
+                exc_info=True
+            )
+            raise RuntimeError("Ludwig argument error.") from e
+        except Exception:
+            logger.error(
+                "LudwigDirectBackend: Experiment execution error.",
+                exc_info=True
+            )
+            raise
+
+    def get_training_process(self, output_dir) -> float:
+        """
+        Retrieve the learning rate used in the most recent Ludwig run.
+        Returns:
+            float: learning rate (or None if not found)
+        """
+        output_dir = Path(output_dir)
+        exp_dirs = sorted(
+            output_dir.glob("experiment_run*"),
+            key=lambda p: p.stat().st_mtime
+        )
+
+        if not exp_dirs:
+            logger.warning(f"No experiment run directories found in {output_dir}")
+            return None
+
+        progress_file = exp_dirs[-1] / "model" / "training_progress.json"
+        if not progress_file.exists():
+            logger.warning(f"No training_progress.json found in {progress_file}")
+            return None
+
+        try:
+            with progress_file.open("r", encoding="utf-8") as f:
+                data = json.load(f)
+            return {
+                "learning_rate": data.get("learning_rate"),
+                "batch_size": data.get("batch_size"),
+                "epoch": data.get("epoch"),
+            }
+        except Exception as e:
+            self.logger.warning(f"Failed to read training progress info: {e}")
+            return {}
+
+    def convert_parquet_to_csv(self, output_dir: Path):
+        """Convert the predictions Parquet file to CSV."""
+        output_dir = Path(output_dir)
+        exp_dirs = sorted(
+            output_dir.glob("experiment_run*"),
+            key=lambda p: p.stat().st_mtime
+        )
+        if not exp_dirs:
+            logger.warning(f"No experiment run dirs found in {output_dir}")
+            return
+        exp_dir = exp_dirs[-1]
+        parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
+        csv_path = exp_dir / "predictions.csv"
+        try:
+            df = pd.read_parquet(parquet_path)
+            df.to_csv(csv_path, index=False)
+            logger.info(f"Converted Parquet to CSV: {csv_path}")
+        except Exception as e:
+            logger.error(f"Error converting Parquet to CSV: {e}")
+
+    def generate_plots(self, output_dir: Path) -> None:
+        """
+        Generate _all_ registered Ludwig visualizations for the latest experiment run.
+        """
+        logger.info("Generating all Ludwig visualizations…")
+
+        test_plots = {
+            'compare_performance',
+            'compare_classifiers_performance_from_prob',
+            'compare_classifiers_performance_from_pred',
+            'compare_classifiers_performance_changing_k',
+            'compare_classifiers_multiclass_multimetric',
+            'compare_classifiers_predictions',
+            'confidence_thresholding_2thresholds_2d',
+            'confidence_thresholding_2thresholds_3d',
+            'confidence_thresholding',
+            'confidence_thresholding_data_vs_acc',
+            'binary_threshold_vs_metric',
+            'roc_curves',
+            'roc_curves_from_test_statistics',
+            'calibration_1_vs_all',
+            'calibration_multiclass',
+            'confusion_matrix',
+            'frequency_vs_f1',
+        }
+        train_plots = {
+            'learning_curves',
+            'compare_classifiers_performance_subset',
+        }
+
+        # 1) find the most recent experiment directory
+        output_dir = Path(output_dir)
+        exp_dirs = sorted(
+            output_dir.glob("experiment_run*"),
+            key=lambda p: p.stat().st_mtime
+        )
+        if not exp_dirs:
+            logger.warning(f"No experiment run dirs found in {output_dir}")
+            return
+        exp_dir = exp_dirs[-1]
+
+        # 2) ensure viz output subfolder exists
+        viz_dir = exp_dir / "visualizations"
+        viz_dir.mkdir(exist_ok=True)
+        train_viz = viz_dir / "train"
+        test_viz = viz_dir / "test"
+        train_viz.mkdir(parents=True, exist_ok=True)
+        test_viz.mkdir(parents=True, exist_ok=True)
+
+        # 3) helper to check file existence
+        def _check(p: Path) -> Optional[str]:
+            return str(p) if p.exists() else None
+
+        # 4) gather standard Ludwig output files
+        training_stats = _check(exp_dir / "training_statistics.json")
+        test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME)
+        probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME)
+        gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME)
+
+        # 5) try to read original dataset & split file from description.json
+        dataset_path = None
+        split_file = None
+        desc = exp_dir / DESCRIPTION_FILE_NAME
+        if desc.exists():
+            with open(desc, "r") as f:
+                cfg = json.load(f)
+            dataset_path = _check(Path(cfg.get("dataset", "")))
+            split_file = _check(Path(get_split_path(cfg.get("dataset", ""))))
+
+        # 6) infer output feature name
+        output_feature = ""
+        if desc.exists():
+            try:
+                output_feature = cfg["config"]["output_features"][0]["name"]
+            except Exception:
+                pass
+        if not output_feature and test_stats:
+            with open(test_stats, "r") as f:
+                stats = json.load(f)
+            output_feature = next(iter(stats.keys()), "")
+
+        # 7) loop through every registered viz
+        viz_registry = get_visualizations_registry()
+        for viz_name, viz_func in viz_registry.items():
+            viz_dir_plot = None
+            if viz_name in train_plots:
+                viz_dir_plot = train_viz
+            elif viz_name in test_plots:
+                viz_dir_plot = test_viz
+
+            try:
+                viz_func(
+                    training_statistics=[training_stats] if training_stats else [],
+                    test_statistics=[test_stats] if test_stats else [],
+                    probabilities=[probs_path] if probs_path else [],
+                    output_feature_name=output_feature,
+                    ground_truth_split=2,
+                    top_n_classes=[0],
+                    top_k=3,
+                    ground_truth_metadata=gt_metadata,
+                    ground_truth=dataset_path,
+                    split_file=split_file,
+                    output_directory=str(viz_dir_plot),
+                    normalize=False,
+                    file_format="png",
+                )
+                logger.info(f"✔ Generated {viz_name}")
+            except Exception as e:
+                logger.warning(f"✘ Skipped {viz_name}: {e}")
+
+        logger.info(f"All visualizations written to {viz_dir}")
+
+    def generate_html_report(
+            self,
+            title: str,
+            output_dir: str,
+            config: dict,
+            split_info: str) -> Path:
+        """
+        Assemble an HTML report from visualizations under train/ and test/ folders.
+
+        The report is written to the current working directory; its file name
+        is the lower-cased title with spaces replaced by underscores plus
+        "_report.html".  Data comes from the newest ``experiment_run*``
+        directory (by modification time) under ``output_dir``.
+
+        Args:
+            title: Report heading; also used to derive the file name.
+            output_dir: Directory containing the ``experiment_run*`` folders.
+            config: Passed to format_config_table_html for the setup table.
+            split_info: Passed to format_config_table_html alongside config.
+
+        Returns:
+            Path of the written report.
+
+        Raises:
+            RuntimeError: if no ``experiment_run*`` directory exists.
+            Exception: a failure to write the report is logged and re-raised.
+        """
+        cwd = Path.cwd()
+        report_name = title.lower().replace(" ", "_") + "_report.html"
+        report_path = cwd / report_name
+        output_dir = Path(output_dir)
+
+        # Find latest experiment dir
+        exp_dirs = sorted(output_dir.glob("experiment_run*"), key=lambda p: p.stat().st_mtime)
+        if not exp_dirs:
+            raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}")
+        exp_dir = exp_dirs[-1]
+
+        base_viz_dir = exp_dir / "visualizations"
+        train_viz_dir = base_viz_dir / "train"
+        test_viz_dir = base_viz_dir / "test"
+
+        # The document is built by plain string concatenation.
+        html = get_html_template()
+        html += f"<h1>{title}</h1>"
+
+        metrics_html = ""
+
+        # Load and embed metrics table (training/val/test stats).
+        # Both statistics files must exist; any failure only drops the table.
+        try:
+            train_stats_path = exp_dir / "training_statistics.json"
+            test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME
+            if train_stats_path.exists() and test_stats_path.exists():
+                with open(train_stats_path) as f:
+                    train_stats = json.load(f)
+                with open(test_stats_path) as f:
+                    test_stats = json.load(f)
+                # An empty training-statistics file yields no metrics table.
+                output_feature = next(iter(train_stats.keys()), "")
+                if output_feature:
+                    metrics_html += format_stats_table_html(train_stats, test_stats)
+        except Exception as e:
+            logger.warning(f"Could not load stats for HTML report: {e}")
+
+        # Configuration table; a formatting failure only drops the table.
+        config_html = ""
+        training_progress = self.get_training_process(output_dir)
+        try:
+            config_html = format_config_table_html(config, split_info, training_progress)
+        except Exception as e:
+            logger.warning(f"Could not load config for HTML report: {e}")
+
+        def render_img_section(title: str, dir_path: Path) -> str:
+            # Render every PNG in dir_path (sorted by path) as an inline
+            # base64 image under a heading derived from the file stem.
+            if not dir_path.exists():
+                return f"<h2>{title}</h2><p><em>Directory not found.</em></p>"
+            imgs = sorted(dir_path.glob("*.png"))
+            if not imgs:
+                return f"<h2>{title}</h2><p><em>No plots found.</em></p>"
+
+            section_html = f"<h2 style='text-align: center;'>{title}</h2><div>"
+            for img in imgs:
+                b64 = encode_image_to_base64(str(img))
+                section_html += (
+                    f'<div class="plot" style="margin-bottom:20px;text-align:center;">'
+                    f"<h3>{img.stem.replace('_',' ').title()}</h3>"
+                    f'<img src="data:image/png;base64,{b64}" '
+                    'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />'
+                    "</div>"
+                )
+            section_html += "</div>"
+            return section_html
+
+        train_plots_html = render_img_section("Training & Validation Visualizations", train_viz_dir)
+        test_plots_html = render_img_section("Test Visualizations", test_viz_dir)
+        # First tab: config table followed by metrics table.
+        html += build_tabbed_html(config_html + metrics_html, train_plots_html, test_plots_html)
+        html += get_html_closing()
+
+        try:
+            with open(report_path, "w") as f:
+                f.write(html)
+            logger.info(f"HTML report generated at: {report_path}")
+        except Exception as e:
+            logger.error(f"Failed to write HTML report: {e}")
+            raise
+
+        return report_path
+
+
+class WorkflowOrchestrator:
+    """
+    Manages the image-classification workflow:
+      1. Creates temp dirs
+      2. Extracts images
+      3. Prepares data (CSV + splits)
+      4. Renders a backend config
+      5. Runs the experiment
+      6. Cleans up
+    """
+
+    def __init__(self, args: argparse.Namespace, backend: Backend):
+        self.args = args
+        self.backend = backend
+        self.temp_dir: Optional[Path] = None
+        self.image_extract_dir: Optional[Path] = None
+        logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}")
+
+    def _create_temp_dirs(self) -> None:
+        """Create temporary output and image extraction directories."""
+        try:
+            self.temp_dir = Path(tempfile.mkdtemp(
+                dir=self.args.output_dir,
+                prefix=TEMP_DIR_PREFIX
+            ))
+            self.image_extract_dir = self.temp_dir / "images"
+            self.image_extract_dir.mkdir()
+            logger.info(f"Created temp directory: {self.temp_dir}")
+        except Exception:
+            logger.error("Failed to create temporary directories", exc_info=True)
+            raise
+
+    def _extract_images(self) -> None:
+        """Extract images from ZIP into the temp image directory."""
+        if self.image_extract_dir is None:
+            raise RuntimeError("Temp image directory not initialized.")
+        logger.info(f"Extracting images from {self.args.image_zip} → {self.image_extract_dir}")
+        try:
+            with zipfile.ZipFile(self.args.image_zip, "r") as z:
+                z.extractall(self.image_extract_dir)
+            logger.info("Image extraction complete.")
+        except Exception:
+            logger.error("Error extracting zip file", exc_info=True)
+            raise
+
+    def _prepare_data(self) -> Tuple[Path, Dict[str, Any]]:
+        """
+        Load CSV, update image paths, handle splits, and write prepared CSV.
+        Returns:
+            final_csv_path: Path to the prepared CSV
+            split_config: Dict for backend split settings
+        """
+        if not self.temp_dir or not self.image_extract_dir:
+            raise RuntimeError("Temp dirs not initialized before data prep.")
+
+        # 1) Load
+        try:
+            df = pd.read_csv(self.args.csv_file)
+            logger.info(f"Loaded CSV: {self.args.csv_file}")
+        except Exception:
+            logger.error("Error loading CSV file", exc_info=True)
+            raise
+
+        # 2) Validate columns
+        required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME}
+        missing = required - set(df.columns)
+        if missing:
+            raise ValueError(f"Missing CSV columns: {', '.join(missing)}")
+
+        # 3) Update image paths
+        try:
+            df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply(
+                lambda p: str((self.image_extract_dir / p).resolve())
+            )
+        except Exception:
+            logger.error("Error updating image paths", exc_info=True)
+            raise
+
+        # 4) Handle splits
+        if SPLIT_COLUMN_NAME in df.columns:
+            df, split_config, split_info = self._process_fixed_split(df)
+        else:
+            logger.info("No split column; using random split")
+            split_config = {
+                "type": "random",
+                "probabilities": self.args.split_probabilities
+            }
+            split_info = (
+                f"No split column in CSV. Used random split: "
+                f"{[int(p*100) for p in self.args.split_probabilities]}% for train/val/test."
+            )
+
+        # 5) Write out prepared CSV
+        final_csv = TEMP_CSV_FILENAME
+        try:
+            df.to_csv(final_csv, index=False)
+            logger.info(f"Saved prepared data to {final_csv}")
+        except Exception:
+            logger.error("Error saving prepared CSV", exc_info=True)
+            raise
+
+        return final_csv, split_config, split_info
+
+    def _process_fixed_split(self, df: pd.DataFrame) -> Dict[str, Any]:
+        """Process a fixed split column (0=train,1=val,2=test)."""
+        logger.info(f"Fixed split column '{SPLIT_COLUMN_NAME}' detected.")
+        try:
+            col = df[SPLIT_COLUMN_NAME]
+            df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype(pd.Int64Dtype())
+            if df[SPLIT_COLUMN_NAME].isna().any():
+                logger.warning("Split column contains non-numeric/missing values.")
+
+            unique = set(df[SPLIT_COLUMN_NAME].dropna().unique())
+            logger.info(f"Unique split values: {unique}")
+
+            if unique == {0, 2}:
+                df = split_data_0_2(
+                    df, SPLIT_COLUMN_NAME,
+                    validation_size=self.args.validation_size,
+                    label_column=LABEL_COLUMN_NAME,
+                    random_state=self.args.random_seed
+                )
+                split_info = (
+                    "Detected a split column (with values 0 and 2) in the input CSV. "
+                    f"Used this column as a base and"
+                    f"reassigned {self.args.validation_size * 100:.1f}% "
+                    "of the training set (originally labeled 0) to validation (labeled 1)."
+                )
+
+                logger.info("Applied custom 0/2 split.")
+            elif unique.issubset({0, 1, 2}):
+                split_info = "Used user-defined split column from CSV."
+                logger.info("Using fixed split as-is.")
+            else:
+                raise ValueError(f"Unexpected split values: {unique}")
+
+            return df, {"type": "fixed", "column": SPLIT_COLUMN_NAME}, split_info
+
+        except Exception:
+            logger.error("Error processing fixed split", exc_info=True)
+            raise
+
+    def _cleanup_temp_dirs(self) -> None:
+        """Remove any temporary directories."""
+        if self.temp_dir and self.temp_dir.exists():
+            logger.info(f"Cleaning up temp directory: {self.temp_dir}")
+            shutil.rmtree(self.temp_dir, ignore_errors=True)
+        self.temp_dir = None
+        self.image_extract_dir = None
+
+    def run(self) -> None:
+        """
+        Execute the full workflow end-to-end.
+
+        Steps, in order: create temp dirs, extract images, prepare the CSV,
+        render and write the backend config, run the experiment, generate
+        plots, build the HTML report, and convert predictions to CSV.  Any
+        exception is logged and re-raised; the temp directory is removed in
+        all cases.
+        """
+        logger.info("Starting workflow...")
+        # The temp directory is created inside output_dir, so it must exist first.
+        self.args.output_dir.mkdir(parents=True, exist_ok=True)
+
+        try:
+            self._create_temp_dirs()
+            self._extract_images()
+            csv_path, split_cfg, split_info = self._prepare_data()
+
+            # Fine-tuning implies starting from pretrained weights.
+            use_pretrained = self.args.use_pretrained or self.args.fine_tune
+
+            # These values feed the backend config and are also handed to the
+            # HTML report below.
+            backend_args = {
+                "model_name": self.args.model_name,
+                "fine_tune": self.args.fine_tune,
+                "use_pretrained": use_pretrained,
+                "epochs": self.args.epochs,
+                "batch_size": self.args.batch_size,
+                "preprocessing_num_processes": self.args.preprocessing_num_processes,
+                "split_probabilities": self.args.split_probabilities,
+                "learning_rate": self.args.learning_rate,
+                "random_seed": self.args.random_seed,
+                "early_stop": self.args.early_stop,
+            }
+            yaml_str = self.backend.prepare_config(backend_args, split_cfg)
+
+            config_file = self.temp_dir / TEMP_CONFIG_FILENAME
+            config_file.write_text(yaml_str)
+            logger.info(f"Wrote backend config: {config_file}")
+
+            self.backend.run_experiment(
+                csv_path,
+                config_file,
+                self.args.output_dir,
+                self.args.random_seed
+            )
+            # NOTE: this is logged once training finishes; the plot, report
+            # and CSV steps below can still fail afterwards.
+            logger.info("Workflow completed successfully.")
+            self.backend.generate_plots(self.args.output_dir)
+            report_file = self.backend.generate_html_report(
+                "Image Classification Results",
+                self.args.output_dir,
+                backend_args,
+                split_info
+            )
+            logger.info(f"HTML report generated at: {report_file}")
+            self.backend.convert_parquet_to_csv(self.args.output_dir)
+            logger.info("Converted Parquet to CSV.")
+        except Exception:
+            logger.error("Workflow execution failed", exc_info=True)
+            raise
+
+        finally:
+            # Runs on success and on failure, after all post-processing steps.
+            self._cleanup_temp_dirs()
+
+
+def parse_learning_rate(s):
+    try:
+        return float(s)
+    except (TypeError, ValueError):
+        return None
+
+
+class SplitProbAction(argparse.Action):
+    def __call__(self, parser, namespace, values, option_string=None):
+        # values is a list of three floats
+        train, val, test = values
+        total = train + val + test
+        if abs(total - 1.0) > 1e-6:
+            parser.error(
+                f"--split-probabilities must sum to 1.0; "
+                f"got {train:.3f} + {val:.3f} + {test:.3f} = {total:.3f}"
+            )
+        setattr(namespace, self.dest, values)
+
+
+def main():
+
+    parser = argparse.ArgumentParser(
+        description="Image Classification Learner with Pluggable Backends"
+    )
+    parser.add_argument(
+        "--csv-file", required=True, type=Path,
+        help="Path to the input CSV"
+    )
+    parser.add_argument(
+        "--image-zip", required=True, type=Path,
+        help="Path to the images ZIP"
+    )
+    parser.add_argument(
+        "--model-name", required=True,
+        choices=MODEL_ENCODER_TEMPLATES.keys(),
+        help="Which model template to use"
+    )
+    parser.add_argument(
+        "--use-pretrained", action="store_true",
+        help="Use pretrained weights for the model"
+    )
+    parser.add_argument(
+        "--fine-tune", action="store_true",
+        help="Enable fine-tuning"
+    )
+    parser.add_argument(
+        "--epochs", type=int, default=10,
+        help="Number of training epochs"
+    )
+    parser.add_argument(
+        "--early-stop", type=int, default=5,
+        help="Early stopping patience"
+    )
+    parser.add_argument(
+        "--batch-size", type=int,
+        help="Batch size (None = auto)"
+    )
+    parser.add_argument(
+        "--output-dir", type=Path, default=Path("learner_output"),
+        help="Where to write outputs"
+    )
+    parser.add_argument(
+        "--validation-size", type=float, default=0.15,
+        help="Fraction for validation (0.0–1.0)"
+    )
+    parser.add_argument(
+        "--preprocessing-num-processes", type=int,
+        default=max(1, os.cpu_count() // 2),
+        help="CPU processes for data prep"
+    )
+    parser.add_argument(
+        "--split-probabilities", type=float, nargs=3,
+        metavar=("train", "val", "test"),
+        action=SplitProbAction,
+        default=[0.7, 0.1, 0.2],
+        help="Random split proportions (e.g., 0.7 0.1 0.2). Only used if no split column is present."
+    )
+    parser.add_argument(
+        "--random-seed", type=int, default=42,
+        help="Random seed used for dataset splitting (default: 42)"
+    )
+    parser.add_argument(
+        "--learning-rate", type=parse_learning_rate, default=None,
+        help="Learning rate. If not provided, Ludwig will auto-select it."
+    )
+
+    args = parser.parse_args()
+
+    # -- Validation --
+    if not 0.0 <= args.validation_size <= 1.0:
+        parser.error("validation-size must be between 0.0 and 1.0")
+    if not args.csv_file.is_file():
+        parser.error(f"CSV not found: {args.csv_file}")
+    if not args.image_zip.is_file():
+        parser.error(f"ZIP not found: {args.image_zip}")
+
+    # --- Instantiate Backend and Orchestrator ---
+    # Use the new LudwigDirectBackend
+    backend_instance = LudwigDirectBackend()
+    orchestrator = WorkflowOrchestrator(args, backend_instance)
+
+    # --- Run Workflow ---
+    exit_code = 0
+    try:
+        orchestrator.run()
+        logger.info("Main script finished successfully.")
+    except Exception as e:
+        logger.error(f"Main script failed.{e}")
+        exit_code = 1
+    finally:
+        sys.exit(exit_code)
+
+
+if __name__ == '__main__':
+    # Check that Ludwig is importable before doing any work, so a missing
+    # installation produces an actionable message and exit status 1.
+    try:
+        import ludwig
+        logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}")
+    except ImportError:
+        logger.error("Ludwig library not found. Please ensure Ludwig is installed ('pip install ludwig[image]')")
+        sys.exit(1)
+
+    main()
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test-data/image_classification_results_report_mnist.html	Tue Jun 03 21:22:11 2025 +0000
@@ -0,0 +1,129 @@
+
+    <html>
+    <head>
+        <title>Galaxy-Ludwig Report</title>
+        <style>
+          body {
+              font-family: Arial, sans-serif;
+              margin: 0;
+              padding: 20px;
+              background-color: #f4f4f4;
+          }
+          .container {
+              max-width: 800px;
+              margin: auto;
+              background: white;
+              padding: 20px;
+              box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
+              overflow-x: auto;
+          }
+          h1 {
+              text-align: center;
+              color: #333;
+          }
+          h2 {
+              border-bottom: 2px solid #4CAF50;
+              color: #4CAF50;
+              padding-bottom: 5px;
+          }
+          table {
+              border-collapse: collapse;
+              margin: 20px 0;
+              width: 100%;
+              table-layout: fixed; /* Enforces consistent column widths */
+          }
+          table, th, td {
+              border: 1px solid #ddd;
+          }
+          th, td {
+              padding: 8px;
+              text-align: center; /* Center-align text */
+              vertical-align: middle; /* Center-align content vertically */
+              word-wrap: break-word; /* Break long words to avoid overflow */
+          }
+          th:first-child, td:first-child {
+              width: 5%; /* Smaller width for the first column */
+          }
+          th:nth-child(2), td:nth-child(2) {
+              width: 50%; /* Wider for the metric/description column */
+          }
+          th:last-child, td:last-child {
+              width: 25%; /* Value column gets remaining space */
+          }
+          th {
+              background-color: #4CAF50;
+              color: white;
+          }
+          .plot {
+              text-align: center;
+              margin: 20px 0;
+          }
+          .plot img {
+              max-width: 100%;
+              height: auto;
+          }
+        </style>
+    </head>
+    <body>
+    <div class="container">
+    <h1>Image Classification Results</h1>
+<style>
+.tabs {
+  display: flex;
+  border-bottom: 2px solid #ccc;
+  margin-bottom: 1rem;
+}
+.tab {
+  padding: 10px 20px;
+  cursor: pointer;
+  border: 1px solid #ccc;
+  border-bottom: none;
+  background: #f9f9f9;
+  margin-right: 5px;
+  border-top-left-radius: 8px;
+  border-top-right-radius: 8px;
+}
+.tab.active {
+  background: white;
+  font-weight: bold;
+}
+.tab-content {
+  display: none;
+  padding: 20px;
+  border: 1px solid #ccc;
+  border-top: none;
+}
+.tab-content.active {
+  display: block;
+}
+</style>
+
+<div class="tabs">
+  <div class="tab active" onclick="showTab('metrics')"> Config & Metrics</div>
+  <div class="tab" onclick="showTab('trainval')"> Train/Validation Plots</div>
+  <div class="tab" onclick="showTab('test')"> Test Plots</div>
+</div>
+
+<div id="metrics" class="tab-content active">
+  <h2 style='text-align: center;'>Training Setup</h2><div style='display: flex; justify-content: center;'><table style='border-collapse: collapse; width: 60%; table-layout: auto;'><thead><tr><th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>Parameter</th><th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Value</th></tr></thead><tbody><tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Model Name</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>resnet18</td></tr><tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Epochs</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>Because of early stopping: the trainingstopped at epoch 7</td></tr><tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Batch Size</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>Auto-selected batch size by Ludwig:<br><span style='font-size: 0.85em;'>16</span><br></td></tr><tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Fine Tune</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>True</td></tr><tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Use Pretrained</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>True</td></tr><tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Learning Rate</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>Auto-selected learning rate by Ludwig:<br><span style='font-size: 0.85em;'>1e-05</span><br><span style='font-size: 0.85em;'>Based on model architecture and training setup (e.g., fine-tuning).<br>See <a href='https://ludwig.ai/latest/configuration/trainer/#trainer-parameters' target='_blank'>Ludwig Trainer Parameters</a> for details.</span></td></tr><tr><td style='padding: 6px 12px; border: 1px solid #ccc; 
text-align: left;'>Random Seedearly Stop</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>N/A</td></tr><tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Data Split</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>Detected a split column (with values 0 and 2) in the input CSV. Used this column as a base andreassigned 15.0% of the training set (originally labeled 0) to validation (labeled 1).</td></tr></tbody></table></div><br><p style='text-align: center; font-size: 0.9em;'>Model trained using Ludwig.<br>If want to learn more about Ludwig default settings,please check the their <a href='https://ludwig.ai' target='_blank'>website(ludwig.ai)</a>.</p><hr><h2 style='text-align: center;'>Model Performance Summary</h2><div style='display: flex; justify-content: center;'><table style='border-collapse: collapse; width: 80%; table-layout: fixed;'><colgroup><col style='width: 40%;'><col style='width: 20%;'><col style='width: 20%;'><col style='width: 20%;'></colgroup><thead><tr><th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>Metric</th><th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Train</th><th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Validation</th><th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Test</th></tr></thead><tbody><tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>accuracy</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>0.8417</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>0.1500</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>0.2000</td></tr><tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>accuracy_micro</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>0.8471</td><td style='padding: 6px 12px; border: 1px solid 
#ccc; text-align: center;'>0.2000</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>0.2000</td></tr><tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>hits_at_k</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>0.9250</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>0.4500</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>0.3000</td></tr><tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>loss</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>0.6749</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>2.7907</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>2.8261</td></tr><tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>roc_auc</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>0.9998</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>0.7824</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>0.6917</td></tr></tbody></table></div><br>
+</div>
+<div id="trainval" class="tab-content">
+  <h2 style='text-align: center;'>Training & Validation Visualizations</h2><div><div class="plot" style="margin-bottom:20px;text-align:center;"><h3>Learning Curves Label Accuracy</h3><img src="" style="max-width:90%;max-height:600px;border:1px solid #ddd;" /></div><div class="plot" style="margin-bottom:20px;text-align:center;"><h3>Learning Curves Label Hits At K</h3><img src="" style="max-width:90%;max-height:600px;border:1px solid #ddd;" /></div><div class="plot" style="margin-bottom:20px;text-align:center;"><h3>Learning Curves Label Loss</h3><img src="" style="max-width:90%;max-height:600px;border:1px solid #ddd;" /></div></div>
+</div>
+<div id="test" class="tab-content">
+  <h2 style='text-align: center;'>Test Visualizations</h2><div><div class="plot" style="margin-bottom:20px;text-align:center;"><h3>Compare Classifiers Multiclass Multimetric  Label Best10</h3><img src="" style="max-width:90%;max-height:600px;border:1px solid #ddd;" /></div><div class="plot" style="margin-bottom:20px;text-align:center;"><h3>Compare Classifiers Multiclass Multimetric  Label Sorted</h3><img src="" style="max-width:90%;max-height:600px;border:1px solid #ddd;" /></div><div class="plot" style="margin-bottom:20px;text-align:center;"><h3>Compare Classifiers Multiclass Multimetric  Label Top10</h3><img src="" style="max-width:90%;max-height:600px;border:1px solid #ddd;" /></div><div class="plot" style="margin-bottom:20px;text-align:center;"><h3>Compare Classifiers Multiclass Multimetric  Label Worst10</h3><img src="" style="max-width:90%;max-height:600px;border:1px solid #ddd;" /></div><div class="plot" style="margin-bottom:20px;text-align:center;"><h3>Compare Classifiers Performance From Prob</h3><img src="" style="max-width:90%;max-height:600px;border:1px solid #ddd;" /></div><div class="plot" style="margin-bottom:20px;text-align:center;"><h3>Compare Performance Label</h3><img src="" style="max-width:90%;max-height:600px;border:1px solid #ddd;" /></div><div class="plot" style="margin-bottom:20px;text-align:center;"><h3>Confusion Matrix  Label Top10</h3><img src="" style="max-width:90%;max-height:600px;border:1px solid #ddd;" /></div><div class="plot" style="margin-bottom:20px;text-align:center;"><h3>Confusion Matrix Entropy  Label Top10</h3><img src="" style="max-width:90%;max-height:600px;border:1px solid #ddd;" /></div><div class="plot" style="margin-bottom:20px;text-align:center;"><h3>Frequency Vs F1  Label</h3><img src="" style="max-width:90%;max-height:600px;border:1px solid #ddd;" /></div><div class="plot" style="margin-bottom:20px;text-align:center;"><h3>Roc Curves</h3><img src="" style="max-width:90%;max-height:600px;border:1px solid #ddd;" 
/></div></div>
+</div>
+
+<script>
+function showTab(id) {
+  document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active'));
+  document.querySelectorAll('.tab').forEach(el => el.classList.remove('active'));
+  document.getElementById(id).classList.add('active');
+  document.querySelector(`.tab[onclick*="${id}"]`).classList.add('active');
+}
+</script>
+
+    </div>
+    </body>
+    </html>
+    
\ No newline at end of file
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test-data/mnist_subset.csv	Tue Jun 03 21:22:11 2025 +0000
@@ -0,0 +1,121 @@
+image_path,label,split
+training/0/5680.jpg,0,0
+training/0/5699.jpg,0,0
+training/0/5766.jpg,0,0
+training/0/5524.jpg,0,0
+training/0/5003.jpg,0,0
+training/0/5527.jpg,0,0
+training/0/5359.jpg,0,0
+training/0/5452.jpg,0,0
+training/0/5010.jpg,0,0
+training/0/5405.jpg,0,0
+training/1/6100.jpg,1,0
+training/1/6015.jpg,1,0
+training/1/5754.jpg,1,0
+training/1/6275.jpg,1,0
+training/1/6247.jpg,1,0
+training/1/6552.jpg,1,0
+training/1/6129.jpg,1,0
+training/1/6733.jpg,1,0
+training/1/6590.jpg,1,0
+training/1/6727.jpg,1,0
+training/2/5585.jpg,2,0
+training/2/5865.jpg,2,0
+training/2/4984.jpg,2,0
+training/2/4992.jpg,2,0
+training/2/5008.jpg,2,0
+training/2/5325.jpg,2,0
+training/2/5438.jpg,2,0
+training/2/5807.jpg,2,0
+training/2/5323.jpg,2,0
+training/2/5407.jpg,2,0
+training/3/5869.jpg,3,0
+training/3/5333.jpg,3,0
+training/3/5813.jpg,3,0
+training/3/6093.jpg,3,0
+training/3/5714.jpg,3,0
+training/3/5519.jpg,3,0
+training/3/5586.jpg,3,0
+training/3/5410.jpg,3,0
+training/3/5577.jpg,3,0
+training/3/5710.jpg,3,0
+training/4/5092.jpg,4,0
+training/4/5793.jpg,4,0
+training/4/5610.jpg,4,0
+training/4/5123.jpg,4,0
+training/4/5685.jpg,4,0
+training/4/4972.jpg,4,0
+training/4/4887.jpg,4,0
+training/4/5052.jpg,4,0
+training/4/5348.jpg,4,0
+training/4/5368.jpg,4,0
+training/5/5100.jpg,5,0
+training/5/4442.jpg,5,0
+training/5/4745.jpg,5,0
+training/5/4592.jpg,5,0
+training/5/4707.jpg,5,0
+training/5/5305.jpg,5,0
+training/5/4506.jpg,5,0
+training/5/5118.jpg,5,0
+training/5/4888.jpg,5,0
+training/5/5282.jpg,5,0
+training/6/5553.jpg,6,0
+training/6/5260.jpg,6,0
+training/6/5899.jpg,6,0
+training/6/5231.jpg,6,0
+training/6/5743.jpg,6,0
+training/6/5567.jpg,6,0
+training/6/5823.jpg,6,0
+training/6/5849.jpg,6,0
+training/6/5076.jpg,6,0
+training/6/5435.jpg,6,0
+training/7/6036.jpg,7,0
+training/7/5488.jpg,7,0
+training/7/5506.jpg,7,0
+training/7/6194.jpg,7,0
+training/7/5934.jpg,7,0
+training/7/5634.jpg,7,0
+training/7/5834.jpg,7,0
+training/7/5721.jpg,7,0
+training/7/6204.jpg,7,0
+training/7/5481.jpg,7,0
+training/8/5844.jpg,8,0
+training/8/5001.jpg,8,0
+training/8/5785.jpg,8,0
+training/8/5462.jpg,8,0
+training/8/4938.jpg,8,0
+training/8/4933.jpg,8,0
+training/8/5341.jpg,8,0
+training/8/5057.jpg,8,0
+training/8/4880.jpg,8,0
+training/8/5039.jpg,8,0
+training/9/5193.jpg,9,0
+training/9/5870.jpg,9,0
+training/9/5756.jpg,9,0
+training/9/5186.jpg,9,0
+training/9/5688.jpg,9,0
+training/9/5579.jpg,9,0
+training/9/5444.jpg,9,0
+training/9/5931.jpg,9,0
+training/9/5541.jpg,9,0
+training/9/5786.jpg,9,0
+test/0/833.jpg,0,2
+test/0/855.jpg,0,2
+test/1/1110.jpg,1,2
+test/1/969.jpg,1,2
+test/2/961.jpg,2,2
+test/2/971.jpg,2,2
+test/3/895.jpg,3,2
+test/3/1005.jpg,3,2
+test/4/940.jpg,4,2
+test/4/975.jpg,4,2
+test/5/780.jpg,5,2
+test/5/834.jpg,5,2
+test/6/932.jpg,6,2
+test/6/796.jpg,6,2
+test/7/835.jpg,7,2
+test/7/863.jpg,7,2
+test/8/899.jpg,8,2
+test/8/898.jpg,8,2
+test/9/1007.jpg,9,2
+test/9/954.jpg,9,2
Binary file test-data/mnist_subset.zip has changed
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/utils.py	Tue Jun 03 21:22:11 2025 +0000
@@ -0,0 +1,156 @@
+import base64
+import json
+
+
+def get_html_template():
+    """
+    Build the opening part of the HTML report page.
+
+    The markup contains the <head> (page title and embedded stylesheet) and
+    opens <html>, <body> and <div class="container"> without closing them;
+    report content and the closing tags have to be appended by the caller.
+
+    Returns:
+        str: The opening HTML, exactly as written in the literal below.
+    """
+    page_opening = """
+    <html>
+    <head>
+        <title>Galaxy-Ludwig Report</title>
+        <style>
+          body {
+              font-family: Arial, sans-serif;
+              margin: 0;
+              padding: 20px;
+              background-color: #f4f4f4;
+          }
+          .container {
+              max-width: 800px;
+              margin: auto;
+              background: white;
+              padding: 20px;
+              box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
+              overflow-x: auto;
+          }
+          h1 {
+              text-align: center;
+              color: #333;
+          }
+          h2 {
+              border-bottom: 2px solid #4CAF50;
+              color: #4CAF50;
+              padding-bottom: 5px;
+          }
+          table {
+              border-collapse: collapse;
+              margin: 20px 0;
+              width: 100%;
+              table-layout: fixed; /* Enforces consistent column widths */
+          }
+          table, th, td {
+              border: 1px solid #ddd;
+          }
+          th, td {
+              padding: 8px;
+              text-align: center; /* Center-align text */
+              vertical-align: middle; /* Center-align content vertically */
+              word-wrap: break-word; /* Break long words to avoid overflow */
+          }
+          th:first-child, td:first-child {
+              width: 5%; /* Smaller width for the first column */
+          }
+          th:nth-child(2), td:nth-child(2) {
+              width: 50%; /* Wider for the metric/description column */
+          }
+          th:last-child, td:last-child {
+              width: 25%; /* Value column gets remaining space */
+          }
+          th {
+              background-color: #4CAF50;
+              color: white;
+          }
+          .plot {
+              text-align: center;
+              margin: 20px 0;
+          }
+          .plot img {
+              max-width: 100%;
+              height: auto;
+          }
+        </style>
+    </head>
+    <body>
+    <div class="container">
+    """
+    return page_opening
+
+
+def get_html_closing():
+    """
+    Build the closing part of the HTML report page.
+
+    Returns:
+        str: The closing </div>, </body> and </html> tags, in that order,
+        exactly as written in the literal below.
+    """
+    closing_markup = """
+    </div>
+    </body>
+    </html>
+    """
+    return closing_markup
+
+
+def encode_image_to_base64(image_path):
+    """
+    Convert an image file to a base64 encoded string.
+
+    Parameters:
+        image_path: Path of the file to read; it is opened in binary mode.
+
+    Returns:
+        str: Base64 encoding of the file's bytes, decoded as UTF-8 text.
+    """
+    with open(image_path, "rb") as handle:
+        raw_bytes = handle.read()
+    encoded_bytes = base64.b64encode(raw_bytes)
+    return encoded_bytes.decode("utf-8")
+
+
+def json_to_nested_html_table(json_data, depth=0):
+    """
+    Convert JSON object to an HTML nested table.
+
+    A dict or list whose members are all simple values becomes a flat
+    two-column table (dict keys, or "Index i" for lists, in the header
+    cell). If any member is itself a dict or list, every row is rendered
+    with a left padding of ``depth * 20`` px on its header cell ("[i]" for
+    lists) and the member is converted recursively. Anything else is
+    rendered as its plain text.
+
+    All keys and simple values are HTML-escaped (``&``, ``<``, ``>``) so
+    that data containing markup characters cannot break or inject into the
+    generated table. Quotes are left alone because the text is only ever
+    placed in element content, never inside an attribute.
+
+    Parameters:
+        json_data (dict or list): The JSON data to convert.
+        depth (int): Current depth level for indentation.
+
+    Returns:
+        str: HTML string for the nested table.
+    """
+    # Local import keeps this function usable without touching the
+    # module-level import block.
+    from html import escape
+
+    def _text(value):
+        # Same text an f-string would produce, made safe for element content.
+        return escape(f"{value}", quote=False)
+
+    def _is_simple(value):
+        return not isinstance(value, (dict, list))
+
+    if isinstance(json_data, dict):
+        entries = list(json_data.items())
+        flat_label = nested_label = _text
+    elif isinstance(json_data, list):
+        entries = list(enumerate(json_data))
+        # Labels are built from integer indexes, so they need no escaping.
+        flat_label = "Index {}".format
+        nested_label = "[{}]".format
+    else:
+        # Base case: simple value
+        return _text(json_data)
+
+    if all(_is_simple(value) for _, value in entries):
+        # Flat container (an empty one also lands here): plain rows.
+        rows = [
+            f"<tr><th>{flat_label(key)}</th><td>{_text(value)}</td></tr>"
+            for key, value in entries
+        ]
+    else:
+        # At least one nested member: indent headers and recurse per value.
+        indent = depth * 20
+        rows = [
+            f"<tr><th style='padding-left:{indent}px;'>{nested_label(key)}</th>"
+            f"<td>{json_to_nested_html_table(value, depth + 1)}</td></tr>"
+            for key, value in entries
+        ]
+    return f"<table>{''.join(rows)}</table>"
+
+
+def json_to_html_table(json_data):
+    """
+    Convert JSON to a vertically oriented HTML table.
+
+    A string argument is first parsed with ``json.loads``; any other
+    argument is handed to ``json_to_nested_html_table`` unchanged.
+
+    Parameters:
+        json_data (str or dict): JSON string or dictionary.
+
+    Returns:
+        str: HTML table representation.
+    """
+    parsed = json.loads(json_data) if isinstance(json_data, str) else json_data
+    return json_to_nested_html_table(parsed)