changeset 2:186424a7eca7 draft

planemo upload for repository https://github.com/goeckslab/gleam.git commit 91fa4aba245520fc0680088a07cead66bcfd4ed2
author goeckslab
date Thu, 03 Jul 2025 20:43:24 +0000
parents 39202fe5cf97
children 2c3a3dfaf1a9
files constants.py image_learner.xml image_learner_cli.py test-data/age_regression.zip test-data/expected_regression.html test-data/utkface_labels.csv utils.py
diffstat 7 files changed, 900 insertions(+), 391 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/constants.py	Thu Jul 03 20:43:24 2025 +0000
@@ -0,0 +1,119 @@
+from typing import Any, Dict
+
+# --- Constants ---
+SPLIT_COLUMN_NAME = "split"
+LABEL_COLUMN_NAME = "label"
+IMAGE_PATH_COLUMN_NAME = "image_path"
+DEFAULT_SPLIT_PROBABILITIES = [0.7, 0.1, 0.2]
+TEMP_CSV_FILENAME = "processed_data_for_ludwig.csv"
+TEMP_CONFIG_FILENAME = "ludwig_config.yaml"
+TEMP_DIR_PREFIX = "ludwig_api_work_"
+MODEL_ENCODER_TEMPLATES: Dict[str, Any] = {
+    "stacked_cnn": "stacked_cnn",
+    "resnet18": {"type": "resnet", "model_variant": 18},
+    "resnet34": {"type": "resnet", "model_variant": 34},
+    "resnet50": {"type": "resnet", "model_variant": 50},
+    "resnet101": {"type": "resnet", "model_variant": 101},
+    "resnet152": {"type": "resnet", "model_variant": 152},
+    "resnext50_32x4d": {"type": "resnext", "model_variant": "50_32x4d"},
+    "resnext101_32x8d": {"type": "resnext", "model_variant": "101_32x8d"},
+    "resnext101_64x4d": {"type": "resnext", "model_variant": "101_64x4d"},
+    "resnext152_32x8d": {"type": "resnext", "model_variant": "152_32x8d"},
+    "wide_resnet50_2": {"type": "wide_resnet", "model_variant": "50_2"},
+    "wide_resnet101_2": {"type": "wide_resnet", "model_variant": "101_2"},
+    "wide_resnet103_2": {"type": "wide_resnet", "model_variant": "103_2"},
+    "efficientnet_b0": {"type": "efficientnet", "model_variant": "b0"},
+    "efficientnet_b1": {"type": "efficientnet", "model_variant": "b1"},
+    "efficientnet_b2": {"type": "efficientnet", "model_variant": "b2"},
+    "efficientnet_b3": {"type": "efficientnet", "model_variant": "b3"},
+    "efficientnet_b4": {"type": "efficientnet", "model_variant": "b4"},
+    "efficientnet_b5": {"type": "efficientnet", "model_variant": "b5"},
+    "efficientnet_b6": {"type": "efficientnet", "model_variant": "b6"},
+    "efficientnet_b7": {"type": "efficientnet", "model_variant": "b7"},
+    "efficientnet_v2_s": {"type": "efficientnet", "model_variant": "v2_s"},
+    "efficientnet_v2_m": {"type": "efficientnet", "model_variant": "v2_m"},
+    "efficientnet_v2_l": {"type": "efficientnet", "model_variant": "v2_l"},
+    "regnet_y_400mf": {"type": "regnet", "model_variant": "y_400mf"},
+    "regnet_y_800mf": {"type": "regnet", "model_variant": "y_800mf"},
+    "regnet_y_1_6gf": {"type": "regnet", "model_variant": "y_1_6gf"},
+    "regnet_y_3_2gf": {"type": "regnet", "model_variant": "y_3_2gf"},
+    "regnet_y_8gf": {"type": "regnet", "model_variant": "y_8gf"},
+    "regnet_y_16gf": {"type": "regnet", "model_variant": "y_16gf"},
+    "regnet_y_32gf": {"type": "regnet", "model_variant": "y_32gf"},
+    "regnet_y_128gf": {"type": "regnet", "model_variant": "y_128gf"},
+    "regnet_x_400mf": {"type": "regnet", "model_variant": "x_400mf"},
+    "regnet_x_800mf": {"type": "regnet", "model_variant": "x_800mf"},
+    "regnet_x_1_6gf": {"type": "regnet", "model_variant": "x_1_6gf"},
+    "regnet_x_3_2gf": {"type": "regnet", "model_variant": "x_3_2gf"},
+    "regnet_x_8gf": {"type": "regnet", "model_variant": "x_8gf"},
+    "regnet_x_16gf": {"type": "regnet", "model_variant": "x_16gf"},
+    "regnet_x_32gf": {"type": "regnet", "model_variant": "x_32gf"},
+    "vgg11": {"type": "vgg", "model_variant": 11},
+    "vgg11_bn": {"type": "vgg", "model_variant": "11_bn"},
+    "vgg13": {"type": "vgg", "model_variant": 13},
+    "vgg13_bn": {"type": "vgg", "model_variant": "13_bn"},
+    "vgg16": {"type": "vgg", "model_variant": 16},
+    "vgg16_bn": {"type": "vgg", "model_variant": "16_bn"},
+    "vgg19": {"type": "vgg", "model_variant": 19},
+    "vgg19_bn": {"type": "vgg", "model_variant": "19_bn"},
+    "shufflenet_v2_x0_5": {"type": "shufflenet_v2", "model_variant": "x0_5"},
+    "shufflenet_v2_x1_0": {"type": "shufflenet_v2", "model_variant": "x1_0"},
+    "shufflenet_v2_x1_5": {"type": "shufflenet_v2", "model_variant": "x1_5"},
+    "shufflenet_v2_x2_0": {"type": "shufflenet_v2", "model_variant": "x2_0"},
+    "squeezenet1_0": {"type": "squeezenet", "model_variant": "1_0"},
+    "squeezenet1_1": {"type": "squeezenet", "model_variant": "1_1"},
+    "swin_t": {"type": "swin_transformer", "model_variant": "t"},
+    "swin_s": {"type": "swin_transformer", "model_variant": "s"},
+    "swin_b": {"type": "swin_transformer", "model_variant": "b"},
+    "swin_v2_t": {"type": "swin_transformer", "model_variant": "v2_t"},
+    "swin_v2_s": {"type": "swin_transformer", "model_variant": "v2_s"},
+    "swin_v2_b": {"type": "swin_transformer", "model_variant": "v2_b"},
+    "vit_b_16": {"type": "vision_transformer", "model_variant": "b_16"},
+    "vit_b_32": {"type": "vision_transformer", "model_variant": "b_32"},
+    "vit_l_16": {"type": "vision_transformer", "model_variant": "l_16"},
+    "vit_l_32": {"type": "vision_transformer", "model_variant": "l_32"},
+    "vit_h_14": {"type": "vision_transformer", "model_variant": "h_14"},
+    "convnext_tiny": {"type": "convnext", "model_variant": "tiny"},
+    "convnext_small": {"type": "convnext", "model_variant": "small"},
+    "convnext_base": {"type": "convnext", "model_variant": "base"},
+    "convnext_large": {"type": "convnext", "model_variant": "large"},
+    "maxvit_t": {"type": "maxvit", "model_variant": "t"},
+    "alexnet": {"type": "alexnet"},
+    "googlenet": {"type": "googlenet"},
+    "inception_v3": {"type": "inception_v3"},
+    "mobilenet_v2": {"type": "mobilenet_v2"},
+    "mobilenet_v3_large": {"type": "mobilenet_v3_large"},
+    "mobilenet_v3_small": {"type": "mobilenet_v3_small"},
+}
+METRIC_DISPLAY_NAMES = {
+    "accuracy": "Accuracy",
+    "accuracy_micro": "Accuracy-Micro",
+    "loss": "Loss",
+    "roc_auc": "ROC-AUC",
+    "roc_auc_macro": "ROC-AUC-Macro",
+    "roc_auc_micro": "ROC-AUC-Micro",
+    "hits_at_k": "Hits at K",
+    "precision": "Precision",
+    "recall": "Recall",
+    "specificity": "Specificity",
+    "kappa_score": "Cohen's Kappa",
+    "token_accuracy": "Token Accuracy",
+    "avg_precision_macro": "Precision-Macro",
+    "avg_recall_macro": "Recall-Macro",
+    "avg_f1_score_macro": "F1-score-Macro",
+    "avg_precision_micro": "Precision-Micro",
+    "avg_recall_micro": "Recall-Micro",
+    "avg_f1_score_micro": "F1-score-Micro",
+    "avg_precision_weighted": "Precision-Weighted",
+    "avg_recall_weighted": "Recall-Weighted",
+    "avg_f1_score_weighted": "F1-score-Weighted",
+    "average_precision_macro": "Precision-Average-Macro",
+    "average_precision_micro": "Precision-Average-Micro",
+    "average_precision_samples": "Precision-Average-Samples",
+    "mean_squared_error": "Mean Squared Error",
+    "mean_absolute_error": "Mean Absolute Error",
+    "r2": "R² Score",
+    "root_mean_squared_error": "Root Mean Squared Error",
+    "mean_absolute_percentage_error": "Mean Absolute % Error",
+    "root_mean_squared_percentage_error": "Root Mean Squared % Error",
+}
--- a/image_learner.xml	Wed Jul 02 18:59:10 2025 +0000
+++ b/image_learner.xml	Thu Jul 03 20:43:24 2025 +0000
@@ -44,23 +44,26 @@
                     #if $batch_size_define == "true"
                         --batch-size "$batch_size"
                     #end if
-                    --split-probabilities "$train_split" "$val_split" "$test_split"   
+                    --split-probabilities "$train_split" "$val_split" "$test_split"
                 #end if
-                --random-seed "$random_seed" 
+                #if $augmentation
+                    --augmentation "$augmentation"
+                #end if
+                --random-seed "$random_seed"
                 --output-dir "." &&
 
             mkdir -p '$output_model.extra_files_path' &&
             cp -r experiment_run/model/*.json experiment_run/model/model_weights '$output_model.extra_files_path' &&
 
-            echo "Image Learner Classification Experiment is Done!"        
+            echo "Image Learner Classification Experiment is Done!"
         ]]>
     </command>
-    
+
     <inputs>
         <param argument="input_csv" type="data" format="csv" optional="false" label="the metadata csv containing image_path column, label column and optional split column" />
         <param name="image_zip" type="data" format="zip" optional="false" label="Image zip" help="Image zip file containing your image data"/>
         <param name="model_name" type="select" optional="false" label="Select a model for your experiment" >
-           
+
             <option value="resnet18">Resnet18</option>
             <option value="resnet34">Resnet34</option>
             <option value="resnet50">Resnet50</option>
@@ -140,8 +143,7 @@
         <conditional name="scratch_fine_tune">
             <param name="use_pretrained" type="select"
                 label="Use pretrained weights?"
-                help="If select no, the encoder, combiner, and decoder will all be initialized and trained from scratch.  
-               (e.g. when your images are very different from ImageNet or no suitable pretrained model exists.)">
+                help="If you select No, the encoder, combiner, and decoder will all be initialized and trained from scratch (e.g. when your images are very different from ImageNet or no suitable pretrained model exists).">
                 <option value="false">No</option>
                 <option value="true" selected="true">Yes</option>
             </param>
@@ -156,6 +158,20 @@
                 <!-- No additional parameters to show if the user selects 'No' -->
             </when>
         </conditional>
+        <param argument="augmentation"
+               name="augmentation"
+               type="select"
+               multiple="true"
+               display="checkboxes"
+               label="Image Augmentation"
+               help="Pick any combination of augmentations to apply">
+            <option value="random_horizontal_flip">Random Horizontal Flip</option>
+            <option value="random_vertical_flip">Random Vertical Flip</option>
+            <option value="random_rotate">Random Rotate</option>
+            <option value="random_blur">Random Blur</option>
+            <option value="random_brightness">Random Brightness</option>
+            <option value="random_contrast">Random Contrast</option>
+        </param>
         <param argument="random_seed" type="integer" value="42" optional="true" label="Random seed (set for reproducibility)" min="0" max="999999"/>
         <conditional name="advanced_settings">
             <param name="customize_defaults" type="select" label="Customize Default Settings?" help="Select yes if you want to customize the default settings of the experiment.">
@@ -205,8 +221,8 @@
             <when value="false">
                 <!-- No additional parameters to show if the user selects 'No' -->
             </when>
-        </conditional>    
-    </inputs>       
+        </conditional>
+    </inputs>
     <outputs>
         <data format="ludwig_model" name="output_model" label="${tool.name} trained model on ${on_string}" />
         <data format="html" name="output_report" from_work_dir="image_classification_results_report.html" label="${tool.name} report on ${on_string}" />
@@ -238,6 +254,48 @@
                 </element>
             </output_collection>
         </test>
+        <test expect_num_outputs="3">
+            <param name="input_csv" value="mnist_subset.csv" ftype="csv" />
+            <param name="image_zip" value="mnist_subset.zip" ftype="zip" />
+            <param name="model_name" value="resnet18" />
+            <param name="augmentation" value="random_horizontal_flip,random_vertical_flip,random_rotate" />
+            <output name="output_report">
+                <assert_contents>
+                    <has_text text="Results Summary" />
+                    <has_text text="Train/Validation Results" />
+                    <has_text text="Test Results" />
+                </assert_contents>
+            </output>
+
+            <output_collection name="output_pred_csv" type="list" >
+                <element name="predictions.csv" >
+                    <assert_contents>
+                        <has_n_columns n="1" />
+                    </assert_contents>
+                </element>
+            </output_collection>
+        </test>
+        <test expect_num_outputs="3">
+            <param name="input_csv" value="utkface_labels.csv" ftype="csv" />
+            <param name="image_zip" value="age_regression.zip" ftype="zip" />
+            <param name="model_name" value="resnet18" />
+            <output name="output_report">
+                <assert_contents>
+                    <has_text text="Results Summary" />
+                    <has_text text="Train/Validation Results" />
+                    <has_text text="Test Results" />
+                </assert_contents>
+            </output>
+            <output name="output_report" file="expected_regression.html" compare="sim_size"/>
+
+            <output_collection name="output_pred_csv" type="list" >
+                <element name="predictions.csv" >
+                    <assert_contents>
+                        <has_n_columns n="1" />
+                    </assert_contents>
+                </element>
+            </output_collection>
+        </test>
     </tests>
     <help>
         <![CDATA[
@@ -248,6 +306,8 @@
 Optionally, you can also add a column with the name 'split' to specify which split each row belongs to (train, val, test). 
 If you do not provide a split column, the tool will automatically split the data into train, val, and test sets based on the proportions you specify or [0.7, 0.1, 0.2] by default.
 
+**If the selected label column has more than 10 unique values, the tool will automatically treat the task as a regression problem and apply appropriate metrics (e.g., MSE, RMSE, R²).**
+
 
 **Outputs**
 The tool will output a trained model in the form of a ludwig_model file,
--- a/image_learner_cli.py	Wed Jul 02 18:59:10 2025 +0000
+++ b/image_learner_cli.py	Thu Jul 03 20:43:24 2025 +0000
@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 import argparse
 import json
 import logging
@@ -11,7 +10,18 @@
 from typing import Any, Dict, Optional, Protocol, Tuple
 
 import pandas as pd
+import pandas.api.types as ptypes
 import yaml
+from constants import (
+    IMAGE_PATH_COLUMN_NAME,
+    LABEL_COLUMN_NAME,
+    METRIC_DISPLAY_NAMES,
+    MODEL_ENCODER_TEMPLATES,
+    SPLIT_COLUMN_NAME,
+    TEMP_CONFIG_FILENAME,
+    TEMP_CSV_FILENAME,
+    TEMP_DIR_PREFIX
+)
 from ludwig.globals import (
     DESCRIPTION_FILE_NAME,
     PREDICTIONS_PARQUET_FILE_NAME,
@@ -21,258 +31,29 @@
 from ludwig.utils.data_utils import get_split_path
 from ludwig.visualize import get_visualizations_registry
 from sklearn.model_selection import train_test_split
-from utils import encode_image_to_base64, get_html_closing, get_html_template
-
-# --- Constants ---
-SPLIT_COLUMN_NAME = "split"
-LABEL_COLUMN_NAME = "label"
-IMAGE_PATH_COLUMN_NAME = "image_path"
-DEFAULT_SPLIT_PROBABILITIES = [0.7, 0.1, 0.2]
-TEMP_CSV_FILENAME = "processed_data_for_ludwig.csv"
-TEMP_CONFIG_FILENAME = "ludwig_config.yaml"
-TEMP_DIR_PREFIX = "ludwig_api_work_"
-MODEL_ENCODER_TEMPLATES: Dict[str, Any] = {
-    "stacked_cnn": "stacked_cnn",
-    "resnet18": {"type": "resnet", "model_variant": 18},
-    "resnet34": {"type": "resnet", "model_variant": 34},
-    "resnet50": {"type": "resnet", "model_variant": 50},
-    "resnet101": {"type": "resnet", "model_variant": 101},
-    "resnet152": {"type": "resnet", "model_variant": 152},
-    "resnext50_32x4d": {"type": "resnext", "model_variant": "50_32x4d"},
-    "resnext101_32x8d": {"type": "resnext", "model_variant": "101_32x8d"},
-    "resnext101_64x4d": {"type": "resnext", "model_variant": "101_64x4d"},
-    "resnext152_32x8d": {"type": "resnext", "model_variant": "152_32x8d"},
-    "wide_resnet50_2": {"type": "wide_resnet", "model_variant": "50_2"},
-    "wide_resnet101_2": {"type": "wide_resnet", "model_variant": "101_2"},
-    "wide_resnet103_2": {"type": "wide_resnet", "model_variant": "103_2"},
-    "efficientnet_b0": {"type": "efficientnet", "model_variant": "b0"},
-    "efficientnet_b1": {"type": "efficientnet", "model_variant": "b1"},
-    "efficientnet_b2": {"type": "efficientnet", "model_variant": "b2"},
-    "efficientnet_b3": {"type": "efficientnet", "model_variant": "b3"},
-    "efficientnet_b4": {"type": "efficientnet", "model_variant": "b4"},
-    "efficientnet_b5": {"type": "efficientnet", "model_variant": "b5"},
-    "efficientnet_b6": {"type": "efficientnet", "model_variant": "b6"},
-    "efficientnet_b7": {"type": "efficientnet", "model_variant": "b7"},
-    "efficientnet_v2_s": {"type": "efficientnet", "model_variant": "v2_s"},
-    "efficientnet_v2_m": {"type": "efficientnet", "model_variant": "v2_m"},
-    "efficientnet_v2_l": {"type": "efficientnet", "model_variant": "v2_l"},
-    "regnet_y_400mf": {"type": "regnet", "model_variant": "y_400mf"},
-    "regnet_y_800mf": {"type": "regnet", "model_variant": "y_800mf"},
-    "regnet_y_1_6gf": {"type": "regnet", "model_variant": "y_1_6gf"},
-    "regnet_y_3_2gf": {"type": "regnet", "model_variant": "y_3_2gf"},
-    "regnet_y_8gf": {"type": "regnet", "model_variant": "y_8gf"},
-    "regnet_y_16gf": {"type": "regnet", "model_variant": "y_16gf"},
-    "regnet_y_32gf": {"type": "regnet", "model_variant": "y_32gf"},
-    "regnet_y_128gf": {"type": "regnet", "model_variant": "y_128gf"},
-    "regnet_x_400mf": {"type": "regnet", "model_variant": "x_400mf"},
-    "regnet_x_800mf": {"type": "regnet", "model_variant": "x_800mf"},
-    "regnet_x_1_6gf": {"type": "regnet", "model_variant": "x_1_6gf"},
-    "regnet_x_3_2gf": {"type": "regnet", "model_variant": "x_3_2gf"},
-    "regnet_x_8gf": {"type": "regnet", "model_variant": "x_8gf"},
-    "regnet_x_16gf": {"type": "regnet", "model_variant": "x_16gf"},
-    "regnet_x_32gf": {"type": "regnet", "model_variant": "x_32gf"},
-    "vgg11": {"type": "vgg", "model_variant": 11},
-    "vgg11_bn": {"type": "vgg", "model_variant": "11_bn"},
-    "vgg13": {"type": "vgg", "model_variant": 13},
-    "vgg13_bn": {"type": "vgg", "model_variant": "13_bn"},
-    "vgg16": {"type": "vgg", "model_variant": 16},
-    "vgg16_bn": {"type": "vgg", "model_variant": "16_bn"},
-    "vgg19": {"type": "vgg", "model_variant": 19},
-    "vgg19_bn": {"type": "vgg", "model_variant": "19_bn"},
-    "shufflenet_v2_x0_5": {"type": "shufflenet_v2", "model_variant": "x0_5"},
-    "shufflenet_v2_x1_0": {"type": "shufflenet_v2", "model_variant": "x1_0"},
-    "shufflenet_v2_x1_5": {"type": "shufflenet_v2", "model_variant": "x1_5"},
-    "shufflenet_v2_x2_0": {"type": "shufflenet_v2", "model_variant": "x2_0"},
-    "squeezenet1_0": {"type": "squeezenet", "model_variant": "1_0"},
-    "squeezenet1_1": {"type": "squeezenet", "model_variant": "1_1"},
-    "swin_t": {"type": "swin_transformer", "model_variant": "t"},
-    "swin_s": {"type": "swin_transformer", "model_variant": "s"},
-    "swin_b": {"type": "swin_transformer", "model_variant": "b"},
-    "swin_v2_t": {"type": "swin_transformer", "model_variant": "v2_t"},
-    "swin_v2_s": {"type": "swin_transformer", "model_variant": "v2_s"},
-    "swin_v2_b": {"type": "swin_transformer", "model_variant": "v2_b"},
-    "vit_b_16": {"type": "vision_transformer", "model_variant": "b_16"},
-    "vit_b_32": {"type": "vision_transformer", "model_variant": "b_32"},
-    "vit_l_16": {"type": "vision_transformer", "model_variant": "l_16"},
-    "vit_l_32": {"type": "vision_transformer", "model_variant": "l_32"},
-    "vit_h_14": {"type": "vision_transformer", "model_variant": "h_14"},
-    "convnext_tiny": {"type": "convnext", "model_variant": "tiny"},
-    "convnext_small": {"type": "convnext", "model_variant": "small"},
-    "convnext_base": {"type": "convnext", "model_variant": "base"},
-    "convnext_large": {"type": "convnext", "model_variant": "large"},
-    "maxvit_t": {"type": "maxvit", "model_variant": "t"},
-    "alexnet": {"type": "alexnet"},
-    "googlenet": {"type": "googlenet"},
-    "inception_v3": {"type": "inception_v3"},
-    "mobilenet_v2": {"type": "mobilenet_v2"},
-    "mobilenet_v3_large": {"type": "mobilenet_v3_large"},
-    "mobilenet_v3_small": {"type": "mobilenet_v3_small"},
-}
-METRIC_DISPLAY_NAMES = {
-    "accuracy": "Accuracy",
-    "accuracy_micro": "Accuracy-Micro",
-    "loss": "Loss",
-    "roc_auc": "ROC-AUC",
-    "roc_auc_macro": "ROC-AUC-Macro",
-    "roc_auc_micro": "ROC-AUC-Micro",
-    "hits_at_k": "Hits at K",
-    "precision": "Precision",
-    "recall": "Recall",
-    "specificity": "Specificity",
-    "kappa_score": "Cohen's Kappa",
-    "token_accuracy": "Token Accuracy",
-    "avg_precision_macro": "Precision-Macro",
-    "avg_recall_macro": "Recall-Macro",
-    "avg_f1_score_macro": "F1-score-Macro",
-    "avg_precision_micro": "Precision-Micro",
-    "avg_recall_micro": "Recall-Micro",
-    "avg_f1_score_micro": "F1-score-Micro",
-    "avg_precision_weighted": "Precision-Weighted",
-    "avg_recall_weighted": "Recall-Weighted",
-    "avg_f1_score_weighted": "F1-score-Weighted",
-    "average_precision_macro": " Precision-Average-Macro",
-    "average_precision_micro": "Precision-Average-Micro",
-    "average_precision_samples": "Precision-Average-Samples",
-}
+from utils import (
+    build_tabbed_html,
+    encode_image_to_base64,
+    get_html_closing,
+    get_html_template,
+    get_metrics_help_modal
+)
 
 # --- Logging Setup ---
 logging.basicConfig(
     level=logging.INFO,
-    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
+    format='%(asctime)s %(levelname)s %(name)s: %(message)s',
 )
 logger = logging.getLogger("ImageLearner")
 
 
-def get_metrics_help_modal() -> str:
-    modal_html = """
-<div id="metricsHelpModal" class="modal">
-  <div class="modal-content">
-    <span class="close">×</span>
-    <h2>Model Evaluation Metrics — Help Guide</h2>
-    <div class="metrics-guide">
-      <h3>1) General Metrics</h3>
-      <p><strong>Loss:</strong> Measures the difference between predicted and actual values. Lower is better. Often used for optimization during training.</p>
-      <p><strong>Accuracy:</strong> Proportion of correct predictions among all predictions. Simple but can be misleading for imbalanced datasets.</p>
-      <p><strong>Micro Accuracy:</strong> Calculates accuracy by summing up all individual true positives and true negatives across all classes, making it suitable for multiclass or multilabel problems.</p>
-      <p><strong>Token Accuracy:</strong> Measures how often the predicted tokens (e.g., in sequences) match the true tokens. Useful in sequence prediction tasks like NLP.</p>
-      <h3>2) Precision, Recall & Specificity</h3>
-      <p><strong>Precision:</strong> Out of all positive predictions, how many were correct. Precision = TP / (TP + FP). Helps when false positives are costly.</p>
-      <p><strong>Recall (Sensitivity):</strong> Out of all actual positives, how many were predicted correctly. Recall = TP / (TP + FN). Important when missing positives is risky.</p>
-      <p><strong>Specificity:</strong> True negative rate. Measures how well the model identifies negatives. Specificity = TN / (TN + FP). Useful in medical testing to avoid false alarms.</p>
-      <h3>3) Macro, Micro, and Weighted Averages</h3>
-      <p><strong>Macro Precision / Recall / F1:</strong> Averages the metric across all classes, treating each class equally, regardless of class frequency. Best when class sizes are balanced.</p>
-      <p><strong>Micro Precision / Recall / F1:</strong> Aggregates TP, FP, FN across all classes before computing the metric. Gives a global view and is ideal for class-imbalanced problems.</p>
-      <p><strong>Weighted Precision / Recall / F1:</strong> Averages each metric across classes, weighted by the number of true instances per class. Balances importance of classes based on frequency.</p>
-      <h3>4) Average Precision (PR-AUC Variants)</h3>
-      <p><strong>Average Precision Macro:</strong> Precision-Recall AUC averaged across all classes equally. Useful for balanced multi-class problems.</p>
-      <p><strong>Average Precision Micro:</strong> Global Precision-Recall AUC using all instances. Best for imbalanced data or multi-label classification.</p>
-      <p><strong>Average Precision Samples:</strong> Precision-Recall AUC averaged across individual samples (not classes). Ideal for multi-label problems where each sample can belong to multiple classes.</p>
-      <h3>5) ROC-AUC Variants</h3>
-      <p><strong>ROC-AUC:</strong> Measures model's ability to distinguish between classes. AUC = 1 is perfect; 0.5 is random guessing. Use for binary classification.</p>
-      <p><strong>Macro ROC-AUC:</strong> Averages the AUC across all classes equally. Suitable when classes are balanced and of equal importance.</p>
-      <p><strong>Micro ROC-AUC:</strong> Computes AUC from aggregated predictions across all classes. Useful in multiclass or multilabel settings with imbalance.</p>
-      <h3>6) Ranking Metrics</h3>
-      <p><strong>Hits at K:</strong> Measures whether the true label is among the top-K predictions. Common in recommendation systems and retrieval tasks.</p>
-      <h3>7) Confusion Matrix Stats (Per Class)</h3>
-      <p><strong>True Positives / Negatives (TP / TN):</strong> Correct predictions for positives and negatives respectively.</p>
-      <p><strong>False Positives / Negatives (FP / FN):</strong> Incorrect predictions — false alarms and missed detections.</p>
-      <h3>8) Other Useful Metrics</h3>
-      <p><strong>Cohen's Kappa:</strong> Measures agreement between predicted and actual values adjusted for chance. Useful for multiclass classification with imbalanced labels.</p>
-      <p><strong>Matthews Correlation Coefficient (MCC):</strong> Balanced measure of prediction quality that takes into account TP, TN, FP, and FN. Particularly effective for imbalanced datasets.</p>
-      <h3>9) Metric Recommendations</h3>
-      <ul>
-        <li>Use <strong>Accuracy + F1</strong> for balanced data.</li>
-        <li>Use <strong>Precision, Recall, ROC-AUC</strong> for imbalanced datasets.</li>
-        <li>Use <strong>Average Precision Micro</strong> for multilabel or class-imbalanced problems.</li>
-        <li>Use <strong>Macro scores</strong> when all classes should be treated equally.</li>
-        <li>Use <strong>Weighted scores</strong> when class imbalance should be accounted for without ignoring small classes.</li>
-        <li>Use <strong>Confusion Matrix stats</strong> to analyze class-wise performance.</li>
-        <li>Use <strong>Hits at K</strong> for recommendation or ranking-based tasks.</li>
-      </ul>
-    </div>
-  </div>
-</div>
-"""
-    modal_css = """
-<style>
-.modal {
-  display: none;
-  position: fixed;
-  z-index: 1;
-  left: 0;
-  top: 0;
-  width: 100%;
-  height: 100%;
-  overflow: auto;
-  background-color: rgba(0,0,0,0.4);
-}
-.modal-content {
-  background-color: #fefefe;
-  margin: 15% auto;
-  padding: 20px;
-  border: 1px solid #888;
-  width: 80%;
-  max-width: 800px;
-}
-.close {
-  color: #aaa;
-  float: right;
-  font-size: 28px;
-  font-weight: bold;
-}
-.close:hover,
-.close:focus {
-  color: black;
-  text-decoration: none;
-  cursor: pointer;
-}
-.metrics-guide h3 {
-  margin-top: 20px;
-}
-.metrics-guide p {
-  margin: 5px 0;
-}
-.metrics-guide ul {
-  margin: 10px 0;
-  padding-left: 20px;
-}
-</style>
-"""
-    modal_js = """
-<script>
-document.addEventListener("DOMContentLoaded", function() {
-  var modal = document.getElementById("metricsHelpModal");
-  var closeBtn = document.getElementsByClassName("close")[0];
-
-  document.querySelectorAll(".openMetricsHelp").forEach(btn => {
-    btn.onclick = function() {
-      modal.style.display = "block";
-    };
-  });
-
-  if (closeBtn) {
-    closeBtn.onclick = function() {
-      modal.style.display = "none";
-    };
-  }
-
-  window.onclick = function(event) {
-    if (event.target == modal) {
-      modal.style.display = "none";
-    }
-  }
-});
-</script>
-"""
-    return modal_css + modal_html + modal_js
-
-
 def format_config_table_html(
     config: dict,
     split_info: Optional[str] = None,
     training_progress: dict = None,
 ) -> str:
     display_keys = [
+        "task_type",
         "model_name",
         "epochs",
         "batch_size",
@@ -287,6 +68,8 @@
 
     for key in display_keys:
         val = config.get(key, "N/A")
+        if key == "task_type":
+            val = val.title() if isinstance(val, str) else val
         if key == "batch_size":
             if val is not None:
                 val = int(val)
@@ -348,6 +131,18 @@
             f"</tr>"
         )
 
+    aug_cfg = config.get("augmentation")
+    if aug_cfg:
+        types = [str(a.get("type", "")) for a in aug_cfg]
+        aug_val = ", ".join(types)
+        rows.append(
+            "<tr>"
+            "<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Augmentation</td>"
+            "<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>"
+            f"{aug_val}</td>"
+            "</tr>"
+        )
+
     if split_info:
         rows.append(
             f"<tr>"
@@ -371,7 +166,7 @@
         "<p style='text-align: center; font-size: 0.9em;'>"
         "Model trained using Ludwig.<br>"
         "If want to learn more about Ludwig default settings,"
-        "please check the their <a href='https://ludwig.ai' target='_blank'>"
+        "please check their <a href='https://ludwig.ai' target='_blank'>"
         "website(ludwig.ai)</a>."
         "</p><hr>"
     )
@@ -380,6 +175,8 @@
 def detect_output_type(test_stats):
     """Detects if the output type is 'binary' or 'category' based on test statistics."""
     label_stats = test_stats.get("label", {})
+    if "mean_squared_error" in label_stats:
+        return "regression"
     per_class = label_stats.get("per_class_stats", {})
     if len(per_class) == 2:
         return "binary"
@@ -420,6 +217,24 @@
                 "specificity": get_last_value(label_stats, "specificity"),
                 "roc_auc": get_last_value(label_stats, "roc_auc"),
             }
+        elif output_type == "regression":
+            metrics[split] = {
+                "loss": get_last_value(label_stats, "loss"),
+                "mean_absolute_error": get_last_value(
+                    label_stats, "mean_absolute_error"
+                ),
+                "mean_absolute_percentage_error": get_last_value(
+                    label_stats, "mean_absolute_percentage_error"
+                ),
+                "mean_squared_error": get_last_value(label_stats, "mean_squared_error"),
+                "root_mean_squared_error": get_last_value(
+                    label_stats, "root_mean_squared_error"
+                ),
+                "root_mean_squared_percentage_error": get_last_value(
+                    label_stats, "root_mean_squared_percentage_error"
+                ),
+                "r2": get_last_value(label_stats, "r2"),
+            }
         else:
             metrics[split] = {
                 "accuracy": get_last_value(label_stats, "accuracy"),
@@ -565,7 +380,9 @@
     return html
 
 
-def format_test_merged_stats_table_html(test_metrics: Dict[str, Optional[float]]) -> str:
+def format_test_merged_stats_table_html(
+    test_metrics: Dict[str, Optional[float]],
+) -> str:
     """Formats an HTML table for test metrics."""
     rows = []
     for key in sorted(test_metrics.keys()):
@@ -598,63 +415,6 @@
     return html
 
 
-def build_tabbed_html(metrics_html: str, train_val_html: str, test_html: str) -> str:
-    return f"""
-<style>
-.tabs {{
-  display: flex;
-  border-bottom: 2px solid #ccc;
-  margin-bottom: 1rem;
-}}
-.tab {{
-  padding: 10px 20px;
-  cursor: pointer;
-  border: 1px solid #ccc;
-  border-bottom: none;
-  background: #f9f9f9;
-  margin-right: 5px;
-  border-top-left-radius: 8px;
-  border-top-right-radius: 8px;
-}}
-.tab.active {{
-  background: white;
-  font-weight: bold;
-}}
-.tab-content {{
-  display: none;
-  padding: 20px;
-  border: 1px solid #ccc;
-  border-top: none;
-}}
-.tab-content.active {{
-  display: block;
-}}
-</style>
-<div class="tabs">
-  <div class="tab active" onclick="showTab('metrics')"> Config & Results Summary</div>
-  <div class="tab" onclick="showTab('trainval')"> Train/Validation Results</div>
-  <div class="tab" onclick="showTab('test')"> Test Results</div>
-</div>
-<div id="metrics" class="tab-content active">
-  {metrics_html}
-</div>
-<div id="trainval" class="tab-content">
-  {train_val_html}
-</div>
-<div id="test" class="tab-content">
-  {test_html}
-</div>
-<script>
-function showTab(id) {{
-  document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active'));
-  document.querySelectorAll('.tab').forEach(el => el.classList.remove('active'));
-  document.getElementById(id).classList.add('active');
-  document.querySelector(`.tab[onclick*="${{id}}"]`).classList.add('active');
-}}
-</script>
-"""
-
-
 def split_data_0_2(
     df: pd.DataFrame,
     split_column: str,
@@ -727,16 +487,15 @@
     ) -> None:
         ...
 
-    def generate_plots(
-        self,
-        output_dir: Path
-    ) -> None:
+    def generate_plots(self, output_dir: Path) -> None:
         ...
 
     def generate_html_report(
         self,
         title: str,
-        output_dir: str
+        output_dir: str,
+        config: Dict[str, Any],
+        split_info: str,
     ) -> Path:
         ...
 
@@ -749,23 +508,21 @@
         config_params: Dict[str, Any],
         split_config: Dict[str, Any],
     ) -> str:
-        """Build and serialize the Ludwig YAML configuration."""
         logger.info("LudwigDirectBackend: Preparing YAML configuration.")
 
         model_name = config_params.get("model_name", "resnet18")
         use_pretrained = config_params.get("use_pretrained", False)
         fine_tune = config_params.get("fine_tune", False)
+        if use_pretrained:
+            trainable = bool(fine_tune)
+        else:
+            trainable = True
         epochs = config_params.get("epochs", 10)
         batch_size = config_params.get("batch_size")
         num_processes = config_params.get("preprocessing_num_processes", 1)
         early_stop = config_params.get("early_stop", None)
         learning_rate = config_params.get("learning_rate")
         learning_rate = "auto" if learning_rate is None else float(learning_rate)
-        trainable = fine_tune or (not use_pretrained)
-        if not use_pretrained and not trainable:
-            logger.warning("trainable=False; use_pretrained=False is ignored.")
-            logger.warning("Setting trainable=True to train the model from scratch.")
-            trainable = True
         raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name)
         if isinstance(raw_encoder, dict):
             encoder_config = {
@@ -779,39 +536,68 @@
         batch_size_cfg = batch_size or "auto"
 
         label_column_path = config_params.get("label_column_data_path")
+        label_series = None
         if label_column_path is not None and Path(label_column_path).exists():
             try:
                 label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME]
-                num_unique_labels = label_series.nunique()
             except Exception as e:
-                logger.warning(
-                    f"Could not determine label cardinality, defaulting to 'binary': {e}"
-                )
-                num_unique_labels = 2
+                logger.warning(f"Could not read label column for task detection: {e}")
+
+        if (
+            label_series is not None
+            and ptypes.is_numeric_dtype(label_series.dtype)
+            and label_series.nunique() > 10
+        ):
+            task_type = "regression"
         else:
-            logger.warning(
-                "label_column_data_path not provided, defaulting to 'binary'"
+            task_type = "classification"
+
+        config_params["task_type"] = task_type
+
+        image_feat: Dict[str, Any] = {
+            "name": IMAGE_PATH_COLUMN_NAME,
+            "type": "image",
+            "encoder": encoder_config,
+        }
+        if config_params.get("augmentation") is not None:
+            image_feat["augmentation"] = config_params["augmentation"]
+
+        if task_type == "regression":
+            output_feat = {
+                "name": LABEL_COLUMN_NAME,
+                "type": "number",
+                "decoder": {"type": "regressor"},
+                "loss": {"type": "mean_squared_error"},
+                "evaluation": {
+                    "metrics": [
+                        "mean_squared_error",
+                        "mean_absolute_error",
+                        "r2",
+                    ]
+                },
+            }
+            val_metric = config_params.get("validation_metric", "mean_squared_error")
+
+        else:
+            num_unique_labels = (
+                label_series.nunique() if label_series is not None else 2
             )
-            num_unique_labels = 2
-
-        output_type = "binary" if num_unique_labels == 2 else "category"
+            output_type = "binary" if num_unique_labels == 2 else "category"
+            output_feat = {"name": LABEL_COLUMN_NAME, "type": output_type}
+            val_metric = None
 
         conf: Dict[str, Any] = {
             "model_type": "ecd",
-            "input_features": [
-                {
-                    "name": IMAGE_PATH_COLUMN_NAME,
-                    "type": "image",
-                    "encoder": encoder_config,
-                }
-            ],
-            "output_features": [{"name": LABEL_COLUMN_NAME, "type": output_type}],
+            "input_features": [image_feat],
+            "output_features": [output_feat],
             "combiner": {"type": "concat"},
             "trainer": {
                 "epochs": epochs,
                 "early_stop": early_stop,
                 "batch_size": batch_size_cfg,
                 "learning_rate": learning_rate,
+                # only set validation_metric for regression
+                **({"validation_metric": val_metric} if val_metric else {}),
             },
             "preprocessing": {
                 "split": split_config,
@@ -876,7 +662,7 @@
             )
             raise
 
-    def get_training_process(self, output_dir) -> float:
+    def get_training_process(self, output_dir) -> Optional[Dict[str, Any]]:
         """Retrieve the learning rate used in the most recent Ludwig run."""
         output_dir = Path(output_dir)
         exp_dirs = sorted(
@@ -1000,11 +786,12 @@
 
         viz_registry = get_visualizations_registry()
         for viz_name, viz_func in viz_registry.items():
-            viz_dir_plot = None
             if viz_name in train_plots:
                 viz_dir_plot = train_viz
             elif viz_name in test_plots:
                 viz_dir_plot = test_viz
+            else:
+                continue
 
             try:
                 viz_func(
@@ -1040,6 +827,7 @@
         report_name = title.lower().replace(" ", "_") + "_report.html"
         report_path = cwd / report_name
         output_dir = Path(output_dir)
+        output_type = None
 
         exp_dirs = sorted(
             output_dir.glob("experiment_run*"),
@@ -1059,7 +847,6 @@
         metrics_html = ""
         train_val_metrics_html = ""
         test_metrics_html = ""
-
         try:
             train_stats_path = exp_dir / "training_statistics.json"
             test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME
@@ -1069,18 +856,14 @@
                 with open(test_stats_path) as f:
                     test_stats = json.load(f)
                 output_type = detect_output_type(test_stats)
-                all_metrics = extract_metrics_from_json(
-                    train_stats,
-                    test_stats,
-                    output_type,
-                )
                 metrics_html = format_stats_table_html(train_stats, test_stats)
                 train_val_metrics_html = format_train_val_stats_table_html(
-                    train_stats,
-                    test_stats,
+                    train_stats, test_stats
                 )
                 test_metrics_html = format_test_merged_stats_table_html(
-                    all_metrics["test"],
+                    extract_metrics_from_json(train_stats, test_stats, output_type)[
+                        "test"
+                    ]
                 )
         except Exception as e:
             logger.warning(
@@ -1090,11 +873,15 @@
         config_html = ""
         training_progress = self.get_training_process(output_dir)
         try:
-            config_html = format_config_table_html(config, split_info, training_progress)
+            config_html = format_config_table_html(
+                config, split_info, training_progress
+            )
         except Exception as e:
             logger.warning(f"Could not load config for HTML report: {e}")
 
-        def render_img_section(title: str, dir_path: Path, output_type: str = None) -> str:
+        def render_img_section(
+            title: str, dir_path: Path, output_type: str = None
+        ) -> str:
             if not dir_path.exists():
                 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>"
 
@@ -1141,11 +928,7 @@
                     img_names[fname] for fname in display_order if fname in img_names
                 ]
                 remaining = sorted(
-                    [
-                        img
-                        for img in img_names.values()
-                        if img.name not in display_order
-                    ]
+                    [img for img in img_names.values() if img.name not in display_order]
                 )
                 imgs = ordered_imgs + remaining
 
@@ -1173,46 +956,61 @@
             section_html += "</div>"
             return section_html
 
-        button_html = """
-        <button class="help-modal-btn openMetricsHelp">Model Evaluation Metrics — Help Guide</button>
-        <br><br>
-        <style>
-        .help-modal-btn {
-            background-color: #17623b;
-            color: #fff;
-            border: none;
-            border-radius: 24px;
-            padding: 10px 28px;
-            font-size: 1.1rem;
-            font-weight: bold;
-            letter-spacing: 0.03em;
-            cursor: pointer;
-            transition: background 0.2s, box-shadow 0.2s;
-            box-shadow: 0 2px 8px rgba(23,98,59,0.07);
-        }
-        .help-modal-btn:hover, .help-modal-btn:focus {
-            background-color: #21895e;
-            outline: none;
-            box-shadow: 0 4px 16px rgba(23,98,59,0.14);
-        }
-        </style>
-        """
-        tab1_content = button_html + config_html + metrics_html
-        tab2_content = (
-            button_html
-            + train_val_metrics_html
-            + render_img_section("Training & Validation Visualizations", train_viz_dir)
+        tab1_content = config_html + metrics_html
+
+        tab2_content = train_val_metrics_html + render_img_section(
+            "Training & Validation Visualizations", train_viz_dir
         )
+
+        # --- Predictions vs Ground Truth table ---
+        preds_section = ""
+        parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
+        if parquet_path.exists():
+            try:
+                # 1) load predictions from Parquet
+                df_preds = pd.read_parquet(parquet_path).reset_index(drop=True)
+                # pick the first column whose name contains "prediction"
+                # (case-insensitive match):
+                pred_col = next(
+                    (c for c in df_preds.columns if "prediction" in c.lower()),
+                    None,
+                )
+                if pred_col is None:
+                    raise ValueError("No prediction column found in Parquet output")
+                df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"})
+
+                # 2) load ground truth for the test split from prepared CSV
+                df_all = pd.read_csv(config["label_column_data_path"])
+                df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][
+                    LABEL_COLUMN_NAME
+                ].reset_index(drop=True)
+
+                # 3) concatenate side‐by‐side
+                df_table = pd.concat([df_gt, df_pred], axis=1)
+                df_table.columns = [LABEL_COLUMN_NAME, "prediction"]
+
+                # 4) render as HTML
+                preds_html = df_table.to_html(index=False, classes="predictions-table")
+                preds_section = (
+                    "<h2 style='text-align: center;'>Predictions vs. Ground Truth</h2>"
+                    "<div style='overflow-x:auto; margin-bottom:20px;'>"
+                    + preds_html
+                    + "</div>"
+                )
+            except Exception as e:
+                logger.warning(f"Could not build Predictions vs GT table: {e}")
+        # Test tab = Metrics + Preds table + Visualizations
+
         tab3_content = (
-            button_html
-            + test_metrics_html
+            test_metrics_html
+            + preds_section
             + render_img_section("Test Visualizations", test_viz_dir, output_type)
         )
 
+        # assemble the tabs and help modal
         tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)
         modal_html = get_metrics_help_modal()
-        html += tabbed_html + modal_html
-        html += get_html_closing()
+        html += tabbed_html + modal_html + get_html_closing()
 
         try:
             with open(report_path, "w") as f:
@@ -1263,7 +1061,7 @@
             logger.error("Error extracting zip file", exc_info=True)
             raise
 
-    def _prepare_data(self) -> Tuple[Path, Dict[str, Any]]:
+    def _prepare_data(self) -> Tuple[Path, Dict[str, Any], str]:
         """Load CSV, update image paths, handle splits, and write prepared CSV."""
         if not self.temp_dir or not self.image_extract_dir:
             raise RuntimeError("Temp dirs not initialized before data prep.")
@@ -1302,8 +1100,9 @@
                 f"for train/val/test."
             )
 
-        final_csv = TEMP_CSV_FILENAME
+        final_csv = self.temp_dir / TEMP_CSV_FILENAME
         try:
+
             df.to_csv(final_csv, index=False)
             logger.info(f"Saved prepared data to {final_csv}")
         except Exception:
@@ -1312,7 +1111,9 @@
 
         return final_csv, split_config, split_info
 
-    def _process_fixed_split(self, df: pd.DataFrame) -> Dict[str, Any]:
+    def _process_fixed_split(
+        self, df: pd.DataFrame
+    ) -> Tuple[pd.DataFrame, Dict[str, Any], str]:
         """Process a fixed split column (0=train,1=val,2=test)."""
         logger.info(f"Fixed split column '{SPLIT_COLUMN_NAME}' detected.")
         try:
@@ -1384,6 +1185,7 @@
                 "random_seed": self.args.random_seed,
                 "early_stop": self.args.early_stop,
                 "label_column_data_path": csv_path,
+                "augmentation": self.args.augmentation,
             }
             yaml_str = self.backend.prepare_config(backend_args, split_cfg)
 
@@ -1422,6 +1224,29 @@
         return None
 
 
+def aug_parse(aug_string: str):
+    """
+    Parse comma-separated augmentation keys into Ludwig augmentation dicts.
+    Raises ValueError on unknown key.
+    """
+    mapping = {
+        "random_horizontal_flip": {"type": "random_horizontal_flip"},
+        "random_vertical_flip": {"type": "random_vertical_flip"},
+        "random_rotate": {"type": "random_rotate", "degree": 10},
+        "random_blur": {"type": "random_blur", "kernel_size": 3},
+        "random_brightness": {"type": "random_brightness", "min": 0.5, "max": 2.0},
+        "random_contrast": {"type": "random_contrast", "min": 0.5, "max": 2.0},
+    }
+    aug_list = []
+    for tok in aug_string.split(","):
+        key = tok.strip()
+        if key not in mapping:
+            valid = ", ".join(mapping.keys())
+            raise ValueError(f"Unknown augmentation '{key}'. Valid choices: {valid}")
+        aug_list.append(mapping[key])
+    return aug_list
+
+
 class SplitProbAction(argparse.Action):
     def __call__(self, parser, namespace, values, option_string=None):
         train, val, test = values
@@ -1508,7 +1333,10 @@
         metavar=("train", "val", "test"),
         action=SplitProbAction,
         default=[0.7, 0.1, 0.2],
-        help="Random split proportions (e.g., 0.7 0.1 0.2). Only used if no split column.",
+        help=(
+            "Random split proportions (e.g., 0.7 0.1 0.2). "
+            "Only used if no split column."
+        ),
     )
     parser.add_argument(
         "--random-seed",
@@ -1522,6 +1350,17 @@
         default=None,
         help="Learning rate. If not provided, Ludwig will auto-select it.",
     )
+    parser.add_argument(
+        "--augmentation",
+        type=str,
+        default=None,
+        help=(
+            "Comma-separated list (in order) of any of: "
+            "random_horizontal_flip, random_vertical_flip, random_rotate, "
+            "random_blur, random_brightness, random_contrast. "
+            "E.g. --augmentation random_horizontal_flip,random_rotate"
+        ),
+    )
 
     args = parser.parse_args()
 
@@ -1531,6 +1370,12 @@
         parser.error(f"CSV not found: {args.csv_file}")
     if not args.image_zip.is_file():
         parser.error(f"ZIP not found: {args.image_zip}")
+    if args.augmentation is not None:
+        try:
+            augmentation_setup = aug_parse(args.augmentation)
+            setattr(args, "augmentation", augmentation_setup)
+        except ValueError as e:
+            parser.error(str(e))
 
     backend_instance = LudwigDirectBackend()
     orchestrator = WorkflowOrchestrator(args, backend_instance)
Binary file test-data/age_regression.zip has changed
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test-data/expected_regression.html	Thu Jul 03 20:43:24 2025 +0000
@@ -0,0 +1,276 @@
+
+    <html>
+    <head>
+        <meta charset="UTF-8">
+        <title>Galaxy-Ludwig Report</title>
+        <style>
+          body {
+              font-family: Arial, sans-serif;
+              margin: 0;
+              padding: 20px;
+              background-color: #f4f4f4;
+          }
+          .container {
+              max-width: 800px;
+              margin: auto;
+              background: white;
+              padding: 20px;
+              box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
+              overflow-x: auto;
+          }
+          h1 {
+              text-align: center;
+              color: #333;
+          }
+          h2 {
+              border-bottom: 2px solid #4CAF50;
+              color: #4CAF50;
+              padding-bottom: 5px;
+          }
+          table {
+              border-collapse: collapse;
+              margin: 20px 0;
+              width: 100%;
+              table-layout: fixed; /* Enforces consistent column widths */
+          }
+          table, th, td {
+              border: 1px solid #ddd;
+          }
+          th, td {
+              padding: 8px;
+              text-align: center; /* Center-align text */
+              vertical-align: middle; /* Center-align content vertically */
+              word-wrap: break-word; /* Break long words to avoid overflow */
+          }
+          th:first-child, td:first-child {
+              width: 5%; /* Smaller width for the first column */
+          }
+          th:nth-child(2), td:nth-child(2) {
+              width: 50%; /* Wider for the metric/description column */
+          }
+          th:last-child, td:last-child {
+              width: 25%; /* Value column gets remaining space */
+          }
+          th {
+              background-color: #4CAF50;
+              color: white;
+          }
+          .plot {
+              text-align: center;
+              margin: 20px 0;
+          }
+          .plot img {
+              max-width: 100%;
+              height: auto;
+          }
+        </style>
+    </head>
+    <body>
+    <div class="container">
+    <h1>Image Classification Results</h1>
+<style>
+  .tabs {
+    display: flex;
+    align-items: center;
+    border-bottom: 2px solid #ccc;
+    margin-bottom: 1rem;
+  }
+  .tab {
+    padding: 10px 20px;
+    cursor: pointer;
+    border: 1px solid #ccc;
+    border-bottom: none;
+    background: #f9f9f9;
+    margin-right: 5px;
+    border-top-left-radius: 8px;
+    border-top-right-radius: 8px;
+  }
+  .tab.active {
+    background: white;
+    font-weight: bold;
+  }
+  /* new help-button styling */
+  .help-btn {
+    margin-left: auto;
+    padding: 6px 12px;
+    font-size: 0.9rem;
+    border: 1px solid #4CAF50;
+    border-radius: 4px;
+    background: #4CAF50;
+    color: white;
+    cursor: pointer;
+  }
+  .tab-content {
+    display: none;
+    padding: 20px;
+    border: 1px solid #ccc;
+    border-top: none;
+  }
+  .tab-content.active {
+    display: block;
+  }
+</style>
+
+<div class="tabs">
+  <div class="tab active" onclick="showTab('metrics')">Config &amp; Results Summary</div>
+  <div class="tab" onclick="showTab('trainval')">Train/Validation Results</div>
+  <div class="tab" onclick="showTab('test')">Test Results</div>
+  <!-- always-visible help button -->
+  <button id="openMetricsHelp" class="help-btn">Help</button>
+</div>
+
+<div id="metrics" class="tab-content active">
+  <h2 style='text-align: center;'>Training Setup</h2><div style='display: flex; justify-content: center;'><table style='border-collapse: collapse; width: 60%; table-layout: auto;'><thead><tr><th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>Parameter</th><th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Value</th></tr></thead><tbody><tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Task Type</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>Regression</td></tr><tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Model Name</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>resnet18</td></tr><tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Epochs</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>10</td></tr><tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Batch Size</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>Auto-selected batch size by Ludwig:<br><span style='font-size: 0.85em;'>1</span><br></td></tr><tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Fine Tune</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>True</td></tr><tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Use Pretrained</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>True</td></tr><tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Learning Rate</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>Auto-selected learning rate by Ludwig:<br><span style='font-size: 0.85em;'>1e-05</span><br><span style='font-size: 0.85em;'>Based on model architecture and training setup (e.g., fine-tuning).<br>See <a href='https://ludwig.ai/latest/configuration/trainer/#trainer-parameters' 
target='_blank'>Ludwig Trainer Parameters</a> for details.</span></td></tr><tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Random Seed</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>42</td></tr><tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Early Stop</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>5</td></tr><tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Data Split</td><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>Used user-defined split column from CSV.</td></tr></tbody></table></div><br><p style='text-align: center; font-size: 0.9em;'>Model trained using Ludwig.<br>If want to learn more about Ludwig default settings,please check their <a href='https://ludwig.ai' target='_blank'>website(ludwig.ai)</a>.</p><hr><h2 style='text-align: center;'>Model Performance Summary</h2><div style='display: flex; justify-content: center;'><table style='border-collapse: collapse; table-layout: auto;'><thead><tr><th style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Metric</th><th style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Train</th><th style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Validation</th><th style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Test</th></tr></thead><tbody><tr><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Loss</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>420.7510</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>2060.3052</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>8205.5977</td></tr><tr><td style='padding: 10px; border: 1px solid 
#ccc; text-align: center; white-space: nowrap;'>Mean Absolute Error</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>17.3023</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>45.1572</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>86.0225</td></tr><tr><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Mean Absolute % Error</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>0.6416</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>0.9613</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>0.9580</td></tr><tr><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Mean Squared Error</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>420.7510</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>2060.3052</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>8205.5977</td></tr><tr><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>R² Score</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>-2.1257</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>-81.4122</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>-10.2560</td></tr><tr><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Root Mean Squared Error</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>20.5122</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; 
white-space: nowrap;'>45.3906</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>90.5848</td></tr><tr><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Root Mean Squared % Error</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>0.6416</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>0.9613</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>0.9580</td></tr></tbody></table></div><br>
+</div>
+<div id="trainval" class="tab-content">
+  <h2 style='text-align: center;'>Train/Validation Performance Summary</h2><div style='display: flex; justify-content: center;'><table style='border-collapse: collapse; table-layout: auto;'><thead><tr><th style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Metric</th><th style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Train</th><th style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Validation</th></tr></thead><tbody><tr><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Loss</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>420.7510</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>2060.3052</td></tr><tr><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Mean Absolute Error</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>17.3023</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>45.1572</td></tr><tr><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Mean Absolute % Error</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>0.6416</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>0.9613</td></tr><tr><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Mean Squared Error</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>420.7510</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>2060.3052</td></tr><tr><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>R² Score</td><td style='padding: 10px; border: 
1px solid #ccc; text-align: center; white-space: nowrap;'>-2.1257</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>-81.4122</td></tr><tr><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Root Mean Squared Error</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>20.5122</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>45.3906</td></tr><tr><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Root Mean Squared % Error</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>0.6416</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>0.9613</td></tr></tbody></table></div><br><h2 style='text-align: center;'>Training & Validation Visualizations</h2><div><div class="plot" style="margin-bottom:20px;text-align:center;"><h3>Learning Curves Label Loss</h3><img src="" style="max-width:90%;max-height:600px;border:1px solid #ddd;" /></div></div>
+</div>
+<div id="test" class="tab-content">
+  <h2 style='text-align: center;'>Test Performance Summary</h2><div style='display: flex; justify-content: center;'><table style='border-collapse: collapse; table-layout: auto;'><thead><tr><th style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Metric</th><th style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Test</th></tr></thead><tbody><tr><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Loss</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>8205.5977</td></tr><tr><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Mean Absolute Error</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>86.0225</td></tr><tr><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Mean Absolute % Error</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>0.9580</td></tr><tr><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Mean Squared Error</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>8205.5977</td></tr><tr><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>R² Score</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>-10.2560</td></tr><tr><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Root Mean Squared Error</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>90.5848</td></tr><tr><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Root Mean Squared % Error</td><td style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: 
nowrap;'>0.9580</td></tr></tbody></table></div><br><h2 style='text-align: center;'>Predictions vs. Ground Truth</h2><div style='overflow-x:auto; margin-bottom:20px;'><table border="1" class="dataframe predictions-table">
+  <thead>
+    <tr style="text-align: right;">
+      <th>label</th>
+      <th>prediction</th>
+    </tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>62</td>
+      <td>4.362794</td>
+    </tr>
+    <tr>
+      <td>116</td>
+      <td>1.592189</td>
+    </tr>
+  </tbody>
+</table></div><h2 style='text-align: center;'>Test Visualizations</h2><div><div class="plot" style="margin-bottom:20px;text-align:center;"><h3>Compare Performance Label</h3><img src="" style="max-width:90%;max-height:600px;border:1px solid #ddd;" /></div></div>
+</div>
+
+<script>
+function showTab(id) {
+  document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active'));
+  document.querySelectorAll('.tab').forEach(el => el.classList.remove('active'));
+  document.getElementById(id).classList.add('active');
+  document.querySelector(`.tab[onclick*="${id}"]`).classList.add('active');
+}
+</script>
+
+<style>
+.modal {
+  display: none;
+  position: fixed;
+  z-index: 1;
+  left: 0;
+  top: 0;
+  width: 100%;
+  height: 100%;
+  overflow: auto;
+  background-color: rgba(0,0,0,0.4);
+}
+.modal-content {
+  background-color: #fefefe;
+  margin: 15% auto;
+  padding: 20px;
+  border: 1px solid #888;
+  width: 80%;
+  max-width: 800px;
+}
+.close {
+  color: #aaa;
+  float: right;
+  font-size: 28px;
+  font-weight: bold;
+}
+.close:hover,
+.close:focus {
+  color: black;
+  text-decoration: none;
+  cursor: pointer;
+}
+.metrics-guide h3 {
+  margin-top: 20px;
+}
+.metrics-guide p {
+  margin: 5px 0;
+}
+.metrics-guide ul {
+  margin: 10px 0;
+  padding-left: 20px;
+}
+</style>
+
+<div id="metricsHelpModal" class="modal">
+  <div class="modal-content">
+    <span class="close">×</span>
+    <h2>Model Evaluation Metrics — Help Guide</h2>
+    <div class="metrics-guide">
+      <h3>1) General Metrics</h3>
+      <p><strong>Loss:</strong> Measures the difference between predicted and actual values. Lower is better. Often used for optimization during training.</p>
+      <p><strong>Accuracy:</strong> Proportion of correct predictions among all predictions. Simple but can be misleading for imbalanced datasets.</p>
+      <p><strong>Micro Accuracy:</strong> Calculates accuracy by summing up all individual true positives and true negatives across all classes, making it suitable for multiclass or multilabel problems.</p>
+      <p><strong>Token Accuracy:</strong> Measures how often the predicted tokens (e.g., in sequences) match the true tokens. Useful in sequence prediction tasks like NLP.</p>
+      <h3>2) Precision, Recall & Specificity</h3>
+      <p><strong>Precision:</strong> Out of all positive predictions, how many were correct. Precision = TP / (TP + FP). Helps when false positives are costly.</p>
+      <p><strong>Recall (Sensitivity):</strong> Out of all actual positives, how many were predicted correctly. Recall = TP / (TP + FN). Important when missing positives is risky.</p>
+      <p><strong>Specificity:</strong> True negative rate. Measures how well the model identifies negatives. Specificity = TN / (TN + FP). Useful in medical testing to avoid false alarms.</p>
+      <h3>3) Macro, Micro, and Weighted Averages</h3>
+      <p><strong>Macro Precision / Recall / F1:</strong> Averages the metric across all classes, treating each class equally, regardless of class frequency. Best when class sizes are balanced.</p>
+      <p><strong>Micro Precision / Recall / F1:</strong> Aggregates TP, FP, FN across all classes before computing the metric. Gives a global view and is ideal for class-imbalanced problems.</p>
+      <p><strong>Weighted Precision / Recall / F1:</strong> Averages each metric across classes, weighted by the number of true instances per class. Balances importance of classes based on frequency.</p>
+      <h3>4) Average Precision (PR-AUC Variants)</h3>
+      <p><strong>Average Precision Macro:</strong> Precision-Recall AUC averaged across all classes equally. Useful for balanced multi-class problems.</p>
+      <p><strong>Average Precision Micro:</strong> Global Precision-Recall AUC using all instances. Best for imbalanced data or multi-label classification.</p>
+      <p><strong>Average Precision Samples:</strong> Precision-Recall AUC averaged across individual samples (not classes). Ideal for multi-label problems where each sample can belong to multiple classes.</p>
+      <h3>5) ROC-AUC Variants</h3>
+      <p><strong>ROC-AUC:</strong> Measures model's ability to distinguish between classes. AUC = 1 is perfect; 0.5 is random guessing. Use for binary classification.</p>
+      <p><strong>Macro ROC-AUC:</strong> Averages the AUC across all classes equally. Suitable when classes are balanced and of equal importance.</p>
+      <p><strong>Micro ROC-AUC:</strong> Computes AUC from aggregated predictions across all classes. Useful in multiclass or multilabel settings with imbalance.</p>
+      <h3>6) Ranking Metrics</h3>
+      <p><strong>Hits at K:</strong> Measures whether the true label is among the top-K predictions. Common in recommendation systems and retrieval tasks.</p>
+      <h3>7) Confusion Matrix Stats (Per Class)</h3>
+      <p><strong>True Positives / Negatives (TP / TN):</strong> Correct predictions for positives and negatives respectively.</p>
+      <p><strong>False Positives / Negatives (FP / FN):</strong> Incorrect predictions — false alarms and missed detections.</p>
+      <h3>8) Other Useful Metrics</h3>
+      <p><strong>Cohen's Kappa:</strong> Measures agreement between predicted and actual values adjusted for chance. Useful for multiclass classification with imbalanced labels.</p>
+      <p><strong>Matthews Correlation Coefficient (MCC):</strong> Balanced measure of prediction quality that takes into account TP, TN, FP, and FN. Particularly effective for imbalanced datasets.</p>
+      <h3>9) Metric Recommendations</h3>
+      <ul>
+        <li>Use <strong>Accuracy + F1</strong> for balanced data.</li>
+        <li>Use <strong>Precision, Recall, ROC-AUC</strong> for imbalanced datasets.</li>
+        <li>Use <strong>Average Precision Micro</strong> for multilabel or class-imbalanced problems.</li>
+        <li>Use <strong>Macro scores</strong> when all classes should be treated equally.</li>
+        <li>Use <strong>Weighted scores</strong> when class imbalance should be accounted for without ignoring small classes.</li>
+        <li>Use <strong>Confusion Matrix stats</strong> to analyze class-wise performance.</li>
+        <li>Use <strong>Hits at K</strong> for recommendation or ranking-based tasks.</li>
+      </ul>
+    </div>
+  </div>
+</div>
+
+<script>
+document.addEventListener("DOMContentLoaded", function() {
+  var modal = document.getElementById("metricsHelpModal");
+  var openBtn = document.getElementById("openMetricsHelp");
+  var span = document.getElementsByClassName("close")[0];
+  if (openBtn && modal) {
+    openBtn.onclick = function() {
+      modal.style.display = "block";
+    };
+  }
+  if (span && modal) {
+    span.onclick = function() {
+      modal.style.display = "none";
+    };
+  }
+  window.onclick = function(event) {
+    if (event.target == modal) {
+      modal.style.display = "none";
+    }
+  }
+});
+</script>
+
+    </div>
+    </body>
+    </html>
+    
\ No newline at end of file
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test-data/utkface_labels.csv	Thu Jul 03 20:43:24 2025 +0000
@@ -0,0 +1,13 @@
+image_path,label,split
+1_0_4_20161221195047839.jpg.chip.jpg,1,0
+4_0_0_20170110213253207.jpg.chip.jpg,4,0
+18_1_0_20170109212647587.jpg.chip.jpg,18,0
+24_1_2_20170116164628750.jpg.chip.jpg,24,0
+26_1_0_20170116164911648.jpg.chip.jpg,26,0
+28_1_0_20170109141748286.jpg.chip.jpg,28,0
+31_0_4_20170120133240958.jpg.chip.jpg,31,0
+35_0_0_20170104202556995.jpg.chip.jpg,35,0
+42_0_2_20170117130543345.jpg.chip.jpg,42,1
+52_0_1_20170117175021585.jpg.chip.jpg,52,1
+62_1_3_20170109132000815.jpg.chip.jpg,62,2
+116_1_2_20170112220255503.jpg.chip.jpg,116,2
--- a/utils.py	Wed Jul 02 18:59:10 2025 +0000
+++ b/utils.py	Thu Jul 03 20:43:24 2025 +0000
@@ -155,3 +155,199 @@
     if isinstance(json_data, str):
         json_data = json.loads(json_data)
     return json_to_nested_html_table(json_data)
+
+
+def build_tabbed_html(metrics_html: str, train_val_html: str, test_html: str) -> str:
+    """Wrap three pre-rendered HTML fragments in a three-tab layout.
+
+    The returned fragment consists of, in order:
+
+    * a ``<style>`` block for the tab bar, the help button and the panes;
+    * a tab bar with the tabs "Config & Results Summary",
+      "Train/Validation Results" and "Test Results", followed by a
+      ``<button id="openMetricsHelp">``;
+    * three panes with ids ``metrics``, ``trainval`` and ``test``; only
+      ``metrics`` carries the ``active`` class initially;
+    * a ``<script>`` defining ``showTab(id)``, which clears ``active`` from
+      every tab and pane, then re-adds it to the pane with that id and to
+      the tab whose ``onclick`` attribute contains the id as a substring.
+
+    No click handler for the help button is emitted here; this function
+    only creates the element with id ``openMetricsHelp``.
+
+    Args:
+        metrics_html: HTML placed inside the ``metrics`` pane.
+        train_val_html: HTML placed inside the ``trainval`` pane.
+        test_html: HTML placed inside the ``test`` pane.
+
+    Returns:
+        The complete HTML fragment as a string. The three arguments are
+        interpolated as-is, without any escaping.
+    """
+    # The template is an f-string, so every literal CSS/JS brace is doubled
+    # ("{{" / "}}"); "${{id}}" therefore renders as the JS template
+    # placeholder "${id}".
+    return f"""
+<style>
+  .tabs {{
+    display: flex;
+    align-items: center;
+    border-bottom: 2px solid #ccc;
+    margin-bottom: 1rem;
+  }}
+  .tab {{
+    padding: 10px 20px;
+    cursor: pointer;
+    border: 1px solid #ccc;
+    border-bottom: none;
+    background: #f9f9f9;
+    margin-right: 5px;
+    border-top-left-radius: 8px;
+    border-top-right-radius: 8px;
+  }}
+  .tab.active {{
+    background: white;
+    font-weight: bold;
+  }}
+  /* new help-button styling */
+  .help-btn {{
+    margin-left: auto;
+    padding: 6px 12px;
+    font-size: 0.9rem;
+    border: 1px solid #4CAF50;
+    border-radius: 4px;
+    background: #4CAF50;
+    color: white;
+    cursor: pointer;
+  }}
+  .tab-content {{
+    display: none;
+    padding: 20px;
+    border: 1px solid #ccc;
+    border-top: none;
+  }}
+  .tab-content.active {{
+    display: block;
+  }}
+</style>
+
+<div class="tabs">
+  <div class="tab active" onclick="showTab('metrics')">Config &amp; Results Summary</div>
+  <div class="tab" onclick="showTab('trainval')">Train/Validation Results</div>
+  <div class="tab" onclick="showTab('test')">Test Results</div>
+  <!-- always-visible help button -->
+  <button id="openMetricsHelp" class="help-btn">Help</button>
+</div>
+
+<div id="metrics" class="tab-content active">
+  {metrics_html}
+</div>
+<div id="trainval" class="tab-content">
+  {train_val_html}
+</div>
+<div id="test" class="tab-content">
+  {test_html}
+</div>
+
+<script>
+function showTab(id) {{
+  document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active'));
+  document.querySelectorAll('.tab').forEach(el => el.classList.remove('active'));
+  document.getElementById(id).classList.add('active');
+  document.querySelector(`.tab[onclick*="${{id}}"]`).classList.add('active');
+}}
+</script>
+"""
+
+
+def get_metrics_help_modal() -> str:
+    """Return the CSS, markup and script for the metrics help modal.
+
+    Three literal strings are built and concatenated in the order
+    CSS + HTML + JS:
+
+    * the CSS styles ``.modal`` (hidden by default via ``display: none``),
+      ``.modal-content``, ``.close`` and the ``.metrics-guide`` text;
+    * the HTML is a ``<div id="metricsHelpModal">`` holding a close
+      ``<span class="close">`` and a nine-section guide to evaluation
+      metrics;
+    * the JS runs on ``DOMContentLoaded`` and, when the elements exist,
+      shows the modal when the element with id ``openMetricsHelp`` is
+      clicked and hides it when the first element with class ``close`` is
+      clicked. It also sets ``window.onclick`` to hide the modal when the
+      click target is the modal backdrop itself.
+
+    The guide text covers loss plus classification/ranking metrics
+    (accuracy, precision, recall, specificity, averages, ROC-AUC, hits at
+    K, confusion-matrix stats, Cohen's kappa, MCC); it contains no entries
+    for regression metrics such as MAE or R².
+
+    Returns:
+        A single HTML string: ``<style>`` block, modal ``<div>``, then
+        ``<script>`` block.
+    """
+    # Plain (non-f) strings: braces below are literal and need no doubling.
+    modal_html = """
+<div id="metricsHelpModal" class="modal">
+  <div class="modal-content">
+    <span class="close">×</span>
+    <h2>Model Evaluation Metrics — Help Guide</h2>
+    <div class="metrics-guide">
+      <h3>1) General Metrics</h3>
+      <p><strong>Loss:</strong> Measures the difference between predicted and actual values. Lower is better. Often used for optimization during training.</p>
+      <p><strong>Accuracy:</strong> Proportion of correct predictions among all predictions. Simple but can be misleading for imbalanced datasets.</p>
+      <p><strong>Micro Accuracy:</strong> Calculates accuracy by summing up all individual true positives and true negatives across all classes, making it suitable for multiclass or multilabel problems.</p>
+      <p><strong>Token Accuracy:</strong> Measures how often the predicted tokens (e.g., in sequences) match the true tokens. Useful in sequence prediction tasks like NLP.</p>
+      <h3>2) Precision, Recall & Specificity</h3>
+      <p><strong>Precision:</strong> Out of all positive predictions, how many were correct. Precision = TP / (TP + FP). Helps when false positives are costly.</p>
+      <p><strong>Recall (Sensitivity):</strong> Out of all actual positives, how many were predicted correctly. Recall = TP / (TP + FN). Important when missing positives is risky.</p>
+      <p><strong>Specificity:</strong> True negative rate. Measures how well the model identifies negatives. Specificity = TN / (TN + FP). Useful in medical testing to avoid false alarms.</p>
+      <h3>3) Macro, Micro, and Weighted Averages</h3>
+      <p><strong>Macro Precision / Recall / F1:</strong> Averages the metric across all classes, treating each class equally, regardless of class frequency. Best when class sizes are balanced.</p>
+      <p><strong>Micro Precision / Recall / F1:</strong> Aggregates TP, FP, FN across all classes before computing the metric. Gives a global view and is ideal for class-imbalanced problems.</p>
+      <p><strong>Weighted Precision / Recall / F1:</strong> Averages each metric across classes, weighted by the number of true instances per class. Balances importance of classes based on frequency.</p>
+      <h3>4) Average Precision (PR-AUC Variants)</h3>
+      <p><strong>Average Precision Macro:</strong> Precision-Recall AUC averaged across all classes equally. Useful for balanced multi-class problems.</p>
+      <p><strong>Average Precision Micro:</strong> Global Precision-Recall AUC using all instances. Best for imbalanced data or multi-label classification.</p>
+      <p><strong>Average Precision Samples:</strong> Precision-Recall AUC averaged across individual samples (not classes). Ideal for multi-label problems where each sample can belong to multiple classes.</p>
+      <h3>5) ROC-AUC Variants</h3>
+      <p><strong>ROC-AUC:</strong> Measures model's ability to distinguish between classes. AUC = 1 is perfect; 0.5 is random guessing. Use for binary classification.</p>
+      <p><strong>Macro ROC-AUC:</strong> Averages the AUC across all classes equally. Suitable when classes are balanced and of equal importance.</p>
+      <p><strong>Micro ROC-AUC:</strong> Computes AUC from aggregated predictions across all classes. Useful in multiclass or multilabel settings with imbalance.</p>
+      <h3>6) Ranking Metrics</h3>
+      <p><strong>Hits at K:</strong> Measures whether the true label is among the top-K predictions. Common in recommendation systems and retrieval tasks.</p>
+      <h3>7) Confusion Matrix Stats (Per Class)</h3>
+      <p><strong>True Positives / Negatives (TP / TN):</strong> Correct predictions for positives and negatives respectively.</p>
+      <p><strong>False Positives / Negatives (FP / FN):</strong> Incorrect predictions — false alarms and missed detections.</p>
+      <h3>8) Other Useful Metrics</h3>
+      <p><strong>Cohen's Kappa:</strong> Measures agreement between predicted and actual values adjusted for chance. Useful for multiclass classification with imbalanced labels.</p>
+      <p><strong>Matthews Correlation Coefficient (MCC):</strong> Balanced measure of prediction quality that takes into account TP, TN, FP, and FN. Particularly effective for imbalanced datasets.</p>
+      <h3>9) Metric Recommendations</h3>
+      <ul>
+        <li>Use <strong>Accuracy + F1</strong> for balanced data.</li>
+        <li>Use <strong>Precision, Recall, ROC-AUC</strong> for imbalanced datasets.</li>
+        <li>Use <strong>Average Precision Micro</strong> for multilabel or class-imbalanced problems.</li>
+        <li>Use <strong>Macro scores</strong> when all classes should be treated equally.</li>
+        <li>Use <strong>Weighted scores</strong> when class imbalance should be accounted for without ignoring small classes.</li>
+        <li>Use <strong>Confusion Matrix stats</strong> to analyze class-wise performance.</li>
+        <li>Use <strong>Hits at K</strong> for recommendation or ranking-based tasks.</li>
+      </ul>
+    </div>
+  </div>
+</div>
+"""
+    # Full-viewport overlay (.modal) with a centered box (.modal-content).
+    modal_css = """
+<style>
+.modal {
+  display: none;
+  position: fixed;
+  z-index: 1;
+  left: 0;
+  top: 0;
+  width: 100%;
+  height: 100%;
+  overflow: auto;
+  background-color: rgba(0,0,0,0.4);
+}
+.modal-content {
+  background-color: #fefefe;
+  margin: 15% auto;
+  padding: 20px;
+  border: 1px solid #888;
+  width: 80%;
+  max-width: 800px;
+}
+.close {
+  color: #aaa;
+  float: right;
+  font-size: 28px;
+  font-weight: bold;
+}
+.close:hover,
+.close:focus {
+  color: black;
+  text-decoration: none;
+  cursor: pointer;
+}
+.metrics-guide h3 {
+  margin-top: 20px;
+}
+.metrics-guide p {
+  margin: 5px 0;
+}
+.metrics-guide ul {
+  margin: 10px 0;
+  padding-left: 20px;
+}
+</style>
+"""
+    # Open/close wiring; the open button (id "openMetricsHelp") is looked up
+    # by id and is not created by this function.
+    modal_js = """
+<script>
+document.addEventListener("DOMContentLoaded", function() {
+  var modal = document.getElementById("metricsHelpModal");
+  var openBtn = document.getElementById("openMetricsHelp");
+  var span = document.getElementsByClassName("close")[0];
+  if (openBtn && modal) {
+    openBtn.onclick = function() {
+      modal.style.display = "block";
+    };
+  }
+  if (span && modal) {
+    span.onclick = function() {
+      modal.style.display = "none";
+    };
+  }
+  window.onclick = function(event) {
+    if (event.target == modal) {
+      modal.style.display = "none";
+    }
+  }
+});
+</script>
+"""
+    # Order matters for readers of the output only: styles, markup, script.
+    return modal_css + modal_html + modal_js