diff image_learner_cli.py @ 2:186424a7eca7 draft

planemo upload for repository https://github.com/goeckslab/gleam.git commit 91fa4aba245520fc0680088a07cead66bcfd4ed2
author goeckslab
date Thu, 03 Jul 2025 20:43:24 +0000
parents 39202fe5cf97
children 09904b1f61f5
--- a/image_learner_cli.py	Wed Jul 02 18:59:10 2025 +0000
+++ b/image_learner_cli.py	Thu Jul 03 20:43:24 2025 +0000
@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 import argparse
 import json
 import logging
@@ -11,7 +10,18 @@
 from typing import Any, Dict, Optional, Protocol, Tuple
 
 import pandas as pd
+import pandas.api.types as ptypes
 import yaml
+from constants import (
+    IMAGE_PATH_COLUMN_NAME,
+    LABEL_COLUMN_NAME,
+    METRIC_DISPLAY_NAMES,
+    MODEL_ENCODER_TEMPLATES,
+    SPLIT_COLUMN_NAME,
+    TEMP_CONFIG_FILENAME,
+    TEMP_CSV_FILENAME,
+    TEMP_DIR_PREFIX
+)
 from ludwig.globals import (
     DESCRIPTION_FILE_NAME,
     PREDICTIONS_PARQUET_FILE_NAME,
@@ -21,258 +31,29 @@
 from ludwig.utils.data_utils import get_split_path
 from ludwig.visualize import get_visualizations_registry
 from sklearn.model_selection import train_test_split
-from utils import encode_image_to_base64, get_html_closing, get_html_template
-
-# --- Constants ---
-SPLIT_COLUMN_NAME = "split"
-LABEL_COLUMN_NAME = "label"
-IMAGE_PATH_COLUMN_NAME = "image_path"
-DEFAULT_SPLIT_PROBABILITIES = [0.7, 0.1, 0.2]
-TEMP_CSV_FILENAME = "processed_data_for_ludwig.csv"
-TEMP_CONFIG_FILENAME = "ludwig_config.yaml"
-TEMP_DIR_PREFIX = "ludwig_api_work_"
-MODEL_ENCODER_TEMPLATES: Dict[str, Any] = {
-    "stacked_cnn": "stacked_cnn",
-    "resnet18": {"type": "resnet", "model_variant": 18},
-    "resnet34": {"type": "resnet", "model_variant": 34},
-    "resnet50": {"type": "resnet", "model_variant": 50},
-    "resnet101": {"type": "resnet", "model_variant": 101},
-    "resnet152": {"type": "resnet", "model_variant": 152},
-    "resnext50_32x4d": {"type": "resnext", "model_variant": "50_32x4d"},
-    "resnext101_32x8d": {"type": "resnext", "model_variant": "101_32x8d"},
-    "resnext101_64x4d": {"type": "resnext", "model_variant": "101_64x4d"},
-    "resnext152_32x8d": {"type": "resnext", "model_variant": "152_32x8d"},
-    "wide_resnet50_2": {"type": "wide_resnet", "model_variant": "50_2"},
-    "wide_resnet101_2": {"type": "wide_resnet", "model_variant": "101_2"},
-    "wide_resnet103_2": {"type": "wide_resnet", "model_variant": "103_2"},
-    "efficientnet_b0": {"type": "efficientnet", "model_variant": "b0"},
-    "efficientnet_b1": {"type": "efficientnet", "model_variant": "b1"},
-    "efficientnet_b2": {"type": "efficientnet", "model_variant": "b2"},
-    "efficientnet_b3": {"type": "efficientnet", "model_variant": "b3"},
-    "efficientnet_b4": {"type": "efficientnet", "model_variant": "b4"},
-    "efficientnet_b5": {"type": "efficientnet", "model_variant": "b5"},
-    "efficientnet_b6": {"type": "efficientnet", "model_variant": "b6"},
-    "efficientnet_b7": {"type": "efficientnet", "model_variant": "b7"},
-    "efficientnet_v2_s": {"type": "efficientnet", "model_variant": "v2_s"},
-    "efficientnet_v2_m": {"type": "efficientnet", "model_variant": "v2_m"},
-    "efficientnet_v2_l": {"type": "efficientnet", "model_variant": "v2_l"},
-    "regnet_y_400mf": {"type": "regnet", "model_variant": "y_400mf"},
-    "regnet_y_800mf": {"type": "regnet", "model_variant": "y_800mf"},
-    "regnet_y_1_6gf": {"type": "regnet", "model_variant": "y_1_6gf"},
-    "regnet_y_3_2gf": {"type": "regnet", "model_variant": "y_3_2gf"},
-    "regnet_y_8gf": {"type": "regnet", "model_variant": "y_8gf"},
-    "regnet_y_16gf": {"type": "regnet", "model_variant": "y_16gf"},
-    "regnet_y_32gf": {"type": "regnet", "model_variant": "y_32gf"},
-    "regnet_y_128gf": {"type": "regnet", "model_variant": "y_128gf"},
-    "regnet_x_400mf": {"type": "regnet", "model_variant": "x_400mf"},
-    "regnet_x_800mf": {"type": "regnet", "model_variant": "x_800mf"},
-    "regnet_x_1_6gf": {"type": "regnet", "model_variant": "x_1_6gf"},
-    "regnet_x_3_2gf": {"type": "regnet", "model_variant": "x_3_2gf"},
-    "regnet_x_8gf": {"type": "regnet", "model_variant": "x_8gf"},
-    "regnet_x_16gf": {"type": "regnet", "model_variant": "x_16gf"},
-    "regnet_x_32gf": {"type": "regnet", "model_variant": "x_32gf"},
-    "vgg11": {"type": "vgg", "model_variant": 11},
-    "vgg11_bn": {"type": "vgg", "model_variant": "11_bn"},
-    "vgg13": {"type": "vgg", "model_variant": 13},
-    "vgg13_bn": {"type": "vgg", "model_variant": "13_bn"},
-    "vgg16": {"type": "vgg", "model_variant": 16},
-    "vgg16_bn": {"type": "vgg", "model_variant": "16_bn"},
-    "vgg19": {"type": "vgg", "model_variant": 19},
-    "vgg19_bn": {"type": "vgg", "model_variant": "19_bn"},
-    "shufflenet_v2_x0_5": {"type": "shufflenet_v2", "model_variant": "x0_5"},
-    "shufflenet_v2_x1_0": {"type": "shufflenet_v2", "model_variant": "x1_0"},
-    "shufflenet_v2_x1_5": {"type": "shufflenet_v2", "model_variant": "x1_5"},
-    "shufflenet_v2_x2_0": {"type": "shufflenet_v2", "model_variant": "x2_0"},
-    "squeezenet1_0": {"type": "squeezenet", "model_variant": "1_0"},
-    "squeezenet1_1": {"type": "squeezenet", "model_variant": "1_1"},
-    "swin_t": {"type": "swin_transformer", "model_variant": "t"},
-    "swin_s": {"type": "swin_transformer", "model_variant": "s"},
-    "swin_b": {"type": "swin_transformer", "model_variant": "b"},
-    "swin_v2_t": {"type": "swin_transformer", "model_variant": "v2_t"},
-    "swin_v2_s": {"type": "swin_transformer", "model_variant": "v2_s"},
-    "swin_v2_b": {"type": "swin_transformer", "model_variant": "v2_b"},
-    "vit_b_16": {"type": "vision_transformer", "model_variant": "b_16"},
-    "vit_b_32": {"type": "vision_transformer", "model_variant": "b_32"},
-    "vit_l_16": {"type": "vision_transformer", "model_variant": "l_16"},
-    "vit_l_32": {"type": "vision_transformer", "model_variant": "l_32"},
-    "vit_h_14": {"type": "vision_transformer", "model_variant": "h_14"},
-    "convnext_tiny": {"type": "convnext", "model_variant": "tiny"},
-    "convnext_small": {"type": "convnext", "model_variant": "small"},
-    "convnext_base": {"type": "convnext", "model_variant": "base"},
-    "convnext_large": {"type": "convnext", "model_variant": "large"},
-    "maxvit_t": {"type": "maxvit", "model_variant": "t"},
-    "alexnet": {"type": "alexnet"},
-    "googlenet": {"type": "googlenet"},
-    "inception_v3": {"type": "inception_v3"},
-    "mobilenet_v2": {"type": "mobilenet_v2"},
-    "mobilenet_v3_large": {"type": "mobilenet_v3_large"},
-    "mobilenet_v3_small": {"type": "mobilenet_v3_small"},
-}
-METRIC_DISPLAY_NAMES = {
-    "accuracy": "Accuracy",
-    "accuracy_micro": "Accuracy-Micro",
-    "loss": "Loss",
-    "roc_auc": "ROC-AUC",
-    "roc_auc_macro": "ROC-AUC-Macro",
-    "roc_auc_micro": "ROC-AUC-Micro",
-    "hits_at_k": "Hits at K",
-    "precision": "Precision",
-    "recall": "Recall",
-    "specificity": "Specificity",
-    "kappa_score": "Cohen's Kappa",
-    "token_accuracy": "Token Accuracy",
-    "avg_precision_macro": "Precision-Macro",
-    "avg_recall_macro": "Recall-Macro",
-    "avg_f1_score_macro": "F1-score-Macro",
-    "avg_precision_micro": "Precision-Micro",
-    "avg_recall_micro": "Recall-Micro",
-    "avg_f1_score_micro": "F1-score-Micro",
-    "avg_precision_weighted": "Precision-Weighted",
-    "avg_recall_weighted": "Recall-Weighted",
-    "avg_f1_score_weighted": "F1-score-Weighted",
-    "average_precision_macro": " Precision-Average-Macro",
-    "average_precision_micro": "Precision-Average-Micro",
-    "average_precision_samples": "Precision-Average-Samples",
-}
+from utils import (
+    build_tabbed_html,
+    encode_image_to_base64,
+    get_html_closing,
+    get_html_template,
+    get_metrics_help_modal
+)
 
 # --- Logging Setup ---
 logging.basicConfig(
     level=logging.INFO,
-    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
+    format='%(asctime)s %(levelname)s %(name)s: %(message)s',
 )
 logger = logging.getLogger("ImageLearner")
 
 
-def get_metrics_help_modal() -> str:
-    modal_html = """
-<div id="metricsHelpModal" class="modal">
-  <div class="modal-content">
-    <span class="close">×</span>
-    <h2>Model Evaluation Metrics — Help Guide</h2>
-    <div class="metrics-guide">
-      <h3>1) General Metrics</h3>
-      <p><strong>Loss:</strong> Measures the difference between predicted and actual values. Lower is better. Often used for optimization during training.</p>
-      <p><strong>Accuracy:</strong> Proportion of correct predictions among all predictions. Simple but can be misleading for imbalanced datasets.</p>
-      <p><strong>Micro Accuracy:</strong> Calculates accuracy by summing up all individual true positives and true negatives across all classes, making it suitable for multiclass or multilabel problems.</p>
-      <p><strong>Token Accuracy:</strong> Measures how often the predicted tokens (e.g., in sequences) match the true tokens. Useful in sequence prediction tasks like NLP.</p>
-      <h3>2) Precision, Recall & Specificity</h3>
-      <p><strong>Precision:</strong> Out of all positive predictions, how many were correct. Precision = TP / (TP + FP). Helps when false positives are costly.</p>
-      <p><strong>Recall (Sensitivity):</strong> Out of all actual positives, how many were predicted correctly. Recall = TP / (TP + FN). Important when missing positives is risky.</p>
-      <p><strong>Specificity:</strong> True negative rate. Measures how well the model identifies negatives. Specificity = TN / (TN + FP). Useful in medical testing to avoid false alarms.</p>
-      <h3>3) Macro, Micro, and Weighted Averages</h3>
-      <p><strong>Macro Precision / Recall / F1:</strong> Averages the metric across all classes, treating each class equally, regardless of class frequency. Best when class sizes are balanced.</p>
-      <p><strong>Micro Precision / Recall / F1:</strong> Aggregates TP, FP, FN across all classes before computing the metric. Gives a global view and is ideal for class-imbalanced problems.</p>
-      <p><strong>Weighted Precision / Recall / F1:</strong> Averages each metric across classes, weighted by the number of true instances per class. Balances importance of classes based on frequency.</p>
-      <h3>4) Average Precision (PR-AUC Variants)</h3>
-      <p><strong>Average Precision Macro:</strong> Precision-Recall AUC averaged across all classes equally. Useful for balanced multi-class problems.</p>
-      <p><strong>Average Precision Micro:</strong> Global Precision-Recall AUC using all instances. Best for imbalanced data or multi-label classification.</p>
-      <p><strong>Average Precision Samples:</strong> Precision-Recall AUC averaged across individual samples (not classes). Ideal for multi-label problems where each sample can belong to multiple classes.</p>
-      <h3>5) ROC-AUC Variants</h3>
-      <p><strong>ROC-AUC:</strong> Measures model's ability to distinguish between classes. AUC = 1 is perfect; 0.5 is random guessing. Use for binary classification.</p>
-      <p><strong>Macro ROC-AUC:</strong> Averages the AUC across all classes equally. Suitable when classes are balanced and of equal importance.</p>
-      <p><strong>Micro ROC-AUC:</strong> Computes AUC from aggregated predictions across all classes. Useful in multiclass or multilabel settings with imbalance.</p>
-      <h3>6) Ranking Metrics</h3>
-      <p><strong>Hits at K:</strong> Measures whether the true label is among the top-K predictions. Common in recommendation systems and retrieval tasks.</p>
-      <h3>7) Confusion Matrix Stats (Per Class)</h3>
-      <p><strong>True Positives / Negatives (TP / TN):</strong> Correct predictions for positives and negatives respectively.</p>
-      <p><strong>False Positives / Negatives (FP / FN):</strong> Incorrect predictions — false alarms and missed detections.</p>
-      <h3>8) Other Useful Metrics</h3>
-      <p><strong>Cohen's Kappa:</strong> Measures agreement between predicted and actual values adjusted for chance. Useful for multiclass classification with imbalanced labels.</p>
-      <p><strong>Matthews Correlation Coefficient (MCC):</strong> Balanced measure of prediction quality that takes into account TP, TN, FP, and FN. Particularly effective for imbalanced datasets.</p>
-      <h3>9) Metric Recommendations</h3>
-      <ul>
-        <li>Use <strong>Accuracy + F1</strong> for balanced data.</li>
-        <li>Use <strong>Precision, Recall, ROC-AUC</strong> for imbalanced datasets.</li>
-        <li>Use <strong>Average Precision Micro</strong> for multilabel or class-imbalanced problems.</li>
-        <li>Use <strong>Macro scores</strong> when all classes should be treated equally.</li>
-        <li>Use <strong>Weighted scores</strong> when class imbalance should be accounted for without ignoring small classes.</li>
-        <li>Use <strong>Confusion Matrix stats</strong> to analyze class-wise performance.</li>
-        <li>Use <strong>Hits at K</strong> for recommendation or ranking-based tasks.</li>
-      </ul>
-    </div>
-  </div>
-</div>
-"""
-    modal_css = """
-<style>
-.modal {
-  display: none;
-  position: fixed;
-  z-index: 1;
-  left: 0;
-  top: 0;
-  width: 100%;
-  height: 100%;
-  overflow: auto;
-  background-color: rgba(0,0,0,0.4);
-}
-.modal-content {
-  background-color: #fefefe;
-  margin: 15% auto;
-  padding: 20px;
-  border: 1px solid #888;
-  width: 80%;
-  max-width: 800px;
-}
-.close {
-  color: #aaa;
-  float: right;
-  font-size: 28px;
-  font-weight: bold;
-}
-.close:hover,
-.close:focus {
-  color: black;
-  text-decoration: none;
-  cursor: pointer;
-}
-.metrics-guide h3 {
-  margin-top: 20px;
-}
-.metrics-guide p {
-  margin: 5px 0;
-}
-.metrics-guide ul {
-  margin: 10px 0;
-  padding-left: 20px;
-}
-</style>
-"""
-    modal_js = """
-<script>
-document.addEventListener("DOMContentLoaded", function() {
-  var modal = document.getElementById("metricsHelpModal");
-  var closeBtn = document.getElementsByClassName("close")[0];
-
-  document.querySelectorAll(".openMetricsHelp").forEach(btn => {
-    btn.onclick = function() {
-      modal.style.display = "block";
-    };
-  });
-
-  if (closeBtn) {
-    closeBtn.onclick = function() {
-      modal.style.display = "none";
-    };
-  }
-
-  window.onclick = function(event) {
-    if (event.target == modal) {
-      modal.style.display = "none";
-    }
-  }
-});
-</script>
-"""
-    return modal_css + modal_html + modal_js
-
-
 def format_config_table_html(
     config: dict,
     split_info: Optional[str] = None,
     training_progress: dict = None,
 ) -> str:
     display_keys = [
+        "task_type",
         "model_name",
         "epochs",
         "batch_size",
@@ -287,6 +68,8 @@
 
     for key in display_keys:
         val = config.get(key, "N/A")
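+        # display task_type in title case (e.g. "Regression")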
+        if key == "task_type":
+            val = val.title() if isinstance(val, str) else val
         if key == "batch_size":
             if val is not None:
                 val = int(val)
@@ -348,6 +131,18 @@
             f"</tr>"
         )
 
+    aug_cfg = config.get("augmentation")
+    if aug_cfg:
+        types = [str(a.get("type", "")) for a in aug_cfg]
+        aug_val = ", ".join(types)
+        rows.append(
+            "<tr>"
+            "<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Augmentation</td>"
+            "<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>"
+            f"{aug_val}</td>"
+            "</tr>"
+        )
+
     if split_info:
         rows.append(
             f"<tr>"
@@ -371,7 +166,7 @@
         "<p style='text-align: center; font-size: 0.9em;'>"
         "Model trained using Ludwig.<br>"
         "If want to learn more about Ludwig default settings,"
-        "please check the their <a href='https://ludwig.ai' target='_blank'>"
+        "please check their <a href='https://ludwig.ai' target='_blank'>"
         "website(ludwig.ai)</a>."
         "</p><hr>"
     )
@@ -380,6 +175,8 @@
 def detect_output_type(test_stats):
     """Detects if the output type is 'binary' or 'category' based on test statistics."""
     label_stats = test_stats.get("label", {})
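+    # regression runs report mean_squared_error in their label stats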
+    if "mean_squared_error" in label_stats:
+        return "regression"
     per_class = label_stats.get("per_class_stats", {})
     if len(per_class) == 2:
         return "binary"
@@ -420,6 +217,24 @@
                 "specificity": get_last_value(label_stats, "specificity"),
                 "roc_auc": get_last_value(label_stats, "roc_auc"),
             }
+        elif output_type == "regression":
+            metrics[split] = {
+                "loss": get_last_value(label_stats, "loss"),
+                "mean_absolute_error": get_last_value(
+                    label_stats, "mean_absolute_error"
+                ),
+                "mean_absolute_percentage_error": get_last_value(
+                    label_stats, "mean_absolute_percentage_error"
+                ),
+                "mean_squared_error": get_last_value(label_stats, "mean_squared_error"),
+                "root_mean_squared_error": get_last_value(
+                    label_stats, "root_mean_squared_error"
+                ),
+                "root_mean_squared_percentage_error": get_last_value(
+                    label_stats, "root_mean_squared_percentage_error"
+                ),
+                "r2": get_last_value(label_stats, "r2"),
+            }
         else:
             metrics[split] = {
                 "accuracy": get_last_value(label_stats, "accuracy"),
@@ -565,7 +380,9 @@
     return html
 
 
-def format_test_merged_stats_table_html(test_metrics: Dict[str, Optional[float]]) -> str:
+def format_test_merged_stats_table_html(
+    test_metrics: Dict[str, Optional[float]],
+) -> str:
     """Formats an HTML table for test metrics."""
     rows = []
     for key in sorted(test_metrics.keys()):
@@ -598,63 +415,6 @@
     return html
 
 
-def build_tabbed_html(metrics_html: str, train_val_html: str, test_html: str) -> str:
-    return f"""
-<style>
-.tabs {{
-  display: flex;
-  border-bottom: 2px solid #ccc;
-  margin-bottom: 1rem;
-}}
-.tab {{
-  padding: 10px 20px;
-  cursor: pointer;
-  border: 1px solid #ccc;
-  border-bottom: none;
-  background: #f9f9f9;
-  margin-right: 5px;
-  border-top-left-radius: 8px;
-  border-top-right-radius: 8px;
-}}
-.tab.active {{
-  background: white;
-  font-weight: bold;
-}}
-.tab-content {{
-  display: none;
-  padding: 20px;
-  border: 1px solid #ccc;
-  border-top: none;
-}}
-.tab-content.active {{
-  display: block;
-}}
-</style>
-<div class="tabs">
-  <div class="tab active" onclick="showTab('metrics')"> Config & Results Summary</div>
-  <div class="tab" onclick="showTab('trainval')"> Train/Validation Results</div>
-  <div class="tab" onclick="showTab('test')"> Test Results</div>
-</div>
-<div id="metrics" class="tab-content active">
-  {metrics_html}
-</div>
-<div id="trainval" class="tab-content">
-  {train_val_html}
-</div>
-<div id="test" class="tab-content">
-  {test_html}
-</div>
-<script>
-function showTab(id) {{
-  document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active'));
-  document.querySelectorAll('.tab').forEach(el => el.classList.remove('active'));
-  document.getElementById(id).classList.add('active');
-  document.querySelector(`.tab[onclick*="${{id}}"]`).classList.add('active');
-}}
-</script>
-"""
-
-
 def split_data_0_2(
     df: pd.DataFrame,
     split_column: str,
@@ -727,16 +487,15 @@
     ) -> None:
         ...
 
-    def generate_plots(
-        self,
-        output_dir: Path
-    ) -> None:
+    def generate_plots(self, output_dir: Path) -> None:
         ...
 
     def generate_html_report(
         self,
         title: str,
-        output_dir: str
+        output_dir: str,
+        config: Dict[str, Any],
+        split_info: str,
     ) -> Path:
         ...
 
@@ -749,23 +508,21 @@
         config_params: Dict[str, Any],
         split_config: Dict[str, Any],
     ) -> str:
-        """Build and serialize the Ludwig YAML configuration."""
         logger.info("LudwigDirectBackend: Preparing YAML configuration.")
 
         model_name = config_params.get("model_name", "resnet18")
         use_pretrained = config_params.get("use_pretrained", False)
         fine_tune = config_params.get("fine_tune", False)
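+        # With pretrained weights, fine_tune controls whether the encoder is
+        # updated; training from scratch always needs a trainable encoder.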
+        if use_pretrained:
+            trainable = bool(fine_tune)
+        else:
+            trainable = True
         epochs = config_params.get("epochs", 10)
         batch_size = config_params.get("batch_size")
         num_processes = config_params.get("preprocessing_num_processes", 1)
         early_stop = config_params.get("early_stop", None)
         learning_rate = config_params.get("learning_rate")
         learning_rate = "auto" if learning_rate is None else float(learning_rate)
-        trainable = fine_tune or (not use_pretrained)
-        if not use_pretrained and not trainable:
-            logger.warning("trainable=False; use_pretrained=False is ignored.")
-            logger.warning("Setting trainable=True to train the model from scratch.")
-            trainable = True
         raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name)
         if isinstance(raw_encoder, dict):
             encoder_config = {
@@ -779,39 +536,68 @@
         batch_size_cfg = batch_size or "auto"
 
         label_column_path = config_params.get("label_column_data_path")
+        label_series = None
         if label_column_path is not None and Path(label_column_path).exists():
             try:
                 label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME]
-                num_unique_labels = label_series.nunique()
             except Exception as e:
-                logger.warning(
-                    f"Could not determine label cardinality, defaulting to 'binary': {e}"
-                )
-                num_unique_labels = 2
+                logger.warning(f"Could not read label column for task detection: {e}")
+
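+        # Heuristic: a numeric label column with more than 10 distinct values is
+        # treated as a regression target; otherwise the task is classification.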
+        if (
+            label_series is not None
+            and ptypes.is_numeric_dtype(label_series.dtype)
+            and label_series.nunique() > 10
+        ):
+            task_type = "regression"
         else:
-            logger.warning(
-                "label_column_data_path not provided, defaulting to 'binary'"
+            task_type = "classification"
+
+        config_params["task_type"] = task_type
+
+        image_feat: Dict[str, Any] = {
+            "name": IMAGE_PATH_COLUMN_NAME,
+            "type": "image",
+            "encoder": encoder_config,
+        }
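+        # attach the parsed augmentation pipeline, if any, to the image feature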
+        if config_params.get("augmentation") is not None:
+            image_feat["augmentation"] = config_params["augmentation"]
+
+        if task_type == "regression":
+            output_feat = {
+                "name": LABEL_COLUMN_NAME,
+                "type": "number",
+                "decoder": {"type": "regressor"},
+                "loss": {"type": "mean_squared_error"},
+                "evaluation": {
+                    "metrics": [
+                        "mean_squared_error",
+                        "mean_absolute_error",
+                        "r2",
+                    ]
+                },
+            }
+            val_metric = config_params.get("validation_metric", "mean_squared_error")
+
+        else:
+            num_unique_labels = (
+                label_series.nunique() if label_series is not None else 2
             )
-            num_unique_labels = 2
-
-        output_type = "binary" if num_unique_labels == 2 else "category"
+            output_type = "binary" if num_unique_labels == 2 else "category"
+            output_feat = {"name": LABEL_COLUMN_NAME, "type": output_type}
+            val_metric = None
 
         conf: Dict[str, Any] = {
             "model_type": "ecd",
-            "input_features": [
-                {
-                    "name": IMAGE_PATH_COLUMN_NAME,
-                    "type": "image",
-                    "encoder": encoder_config,
-                }
-            ],
-            "output_features": [{"name": LABEL_COLUMN_NAME, "type": output_type}],
+            "input_features": [image_feat],
+            "output_features": [output_feat],
             "combiner": {"type": "concat"},
             "trainer": {
                 "epochs": epochs,
                 "early_stop": early_stop,
                 "batch_size": batch_size_cfg,
                 "learning_rate": learning_rate,
+                # only set validation_metric for regression
+                **({"validation_metric": val_metric} if val_metric else {}),
             },
             "preprocessing": {
                 "split": split_config,
@@ -876,7 +662,7 @@
             )
             raise
 
-    def get_training_process(self, output_dir) -> float:
+    def get_training_process(self, output_dir) -> Optional[Dict[str, Any]]:
         """Retrieve the learning rate used in the most recent Ludwig run."""
         output_dir = Path(output_dir)
         exp_dirs = sorted(
@@ -1000,11 +786,12 @@
 
         viz_registry = get_visualizations_registry()
         for viz_name, viz_func in viz_registry.items():
-            viz_dir_plot = None
             if viz_name in train_plots:
                 viz_dir_plot = train_viz
             elif viz_name in test_plots:
                 viz_dir_plot = test_viz
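+            # skip visualizations that belong to neither train nor test plots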
+            else:
+                continue
 
             try:
                 viz_func(
@@ -1040,6 +827,7 @@
         report_name = title.lower().replace(" ", "_") + "_report.html"
         report_path = cwd / report_name
         output_dir = Path(output_dir)
+        output_type = None
 
         exp_dirs = sorted(
             output_dir.glob("experiment_run*"),
@@ -1059,7 +847,6 @@
         metrics_html = ""
         train_val_metrics_html = ""
         test_metrics_html = ""
-
         try:
             train_stats_path = exp_dir / "training_statistics.json"
             test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME
@@ -1069,18 +856,14 @@
                 with open(test_stats_path) as f:
                     test_stats = json.load(f)
                 output_type = detect_output_type(test_stats)
-                all_metrics = extract_metrics_from_json(
-                    train_stats,
-                    test_stats,
-                    output_type,
-                )
                 metrics_html = format_stats_table_html(train_stats, test_stats)
                 train_val_metrics_html = format_train_val_stats_table_html(
-                    train_stats,
-                    test_stats,
+                    train_stats, test_stats
                 )
                 test_metrics_html = format_test_merged_stats_table_html(
-                    all_metrics["test"],
+                    extract_metrics_from_json(train_stats, test_stats, output_type)[
+                        "test"
+                    ]
                 )
         except Exception as e:
             logger.warning(
@@ -1090,11 +873,15 @@
         config_html = ""
         training_progress = self.get_training_process(output_dir)
         try:
-            config_html = format_config_table_html(config, split_info, training_progress)
+            config_html = format_config_table_html(
+                config, split_info, training_progress
+            )
         except Exception as e:
             logger.warning(f"Could not load config for HTML report: {e}")
 
-        def render_img_section(title: str, dir_path: Path, output_type: str = None) -> str:
+        def render_img_section(
+            title: str, dir_path: Path, output_type: str = None
+        ) -> str:
             if not dir_path.exists():
                 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>"
 
@@ -1141,11 +928,7 @@
                     img_names[fname] for fname in display_order if fname in img_names
                 ]
                 remaining = sorted(
-                    [
-                        img
-                        for img in img_names.values()
-                        if img.name not in display_order
-                    ]
+                    [img for img in img_names.values() if img.name not in display_order]
                 )
                 imgs = ordered_imgs + remaining
 
@@ -1173,46 +956,61 @@
             section_html += "</div>"
             return section_html
 
-        button_html = """
-        <button class="help-modal-btn openMetricsHelp">Model Evaluation Metrics — Help Guide</button>
-        <br><br>
-        <style>
-        .help-modal-btn {
-            background-color: #17623b;
-            color: #fff;
-            border: none;
-            border-radius: 24px;
-            padding: 10px 28px;
-            font-size: 1.1rem;
-            font-weight: bold;
-            letter-spacing: 0.03em;
-            cursor: pointer;
-            transition: background 0.2s, box-shadow 0.2s;
-            box-shadow: 0 2px 8px rgba(23,98,59,0.07);
-        }
-        .help-modal-btn:hover, .help-modal-btn:focus {
-            background-color: #21895e;
-            outline: none;
-            box-shadow: 0 4px 16px rgba(23,98,59,0.14);
-        }
-        </style>
-        """
-        tab1_content = button_html + config_html + metrics_html
-        tab2_content = (
-            button_html
-            + train_val_metrics_html
-            + render_img_section("Training & Validation Visualizations", train_viz_dir)
+        tab1_content = config_html + metrics_html
+
+        tab2_content = train_val_metrics_html + render_img_section(
+            "Training & Validation Visualizations", train_viz_dir
         )
+
+        # --- Predictions vs Ground Truth table ---
+        preds_section = ""
+        parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
+        if parquet_path.exists():
+            try:
+                # 1) load predictions from Parquet
+                df_preds = pd.read_parquet(parquet_path).reset_index(drop=True)
+                # assume the column containing your model's prediction is named "prediction"
+                # or contains that substring:
+                pred_col = next(
+                    (c for c in df_preds.columns if "prediction" in c.lower()),
+                    None,
+                )
+                if pred_col is None:
+                    raise ValueError("No prediction column found in Parquet output")
+                df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"})
+
+                # 2) load ground truth for the test split from prepared CSV
+                df_all = pd.read_csv(config["label_column_data_path"])
+                df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][
+                    LABEL_COLUMN_NAME
+                ].reset_index(drop=True)
+
+                # 3) concatenate side-by-side
+                df_table = pd.concat([df_gt, df_pred], axis=1)
+                df_table.columns = [LABEL_COLUMN_NAME, "prediction"]
+
+                # 4) render as HTML
+                preds_html = df_table.to_html(index=False, classes="predictions-table")
+                preds_section = (
+                    "<h2 style='text-align: center;'>Predictions vs. Ground Truth</h2>"
+                    "<div style='overflow-x:auto; margin-bottom:20px;'>"
+                    + preds_html
+                    + "</div>"
+                )
+            except Exception as e:
+                logger.warning(f"Could not build Predictions vs GT table: {e}")
+        # Test tab = Metrics + Preds table + Visualizations
+
         tab3_content = (
-            button_html
-            + test_metrics_html
+            test_metrics_html
+            + preds_section
             + render_img_section("Test Visualizations", test_viz_dir, output_type)
         )
 
+        # assemble the tabs and help modal
         tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)
         modal_html = get_metrics_help_modal()
-        html += tabbed_html + modal_html
-        html += get_html_closing()
+        html += tabbed_html + modal_html + get_html_closing()
 
         try:
             with open(report_path, "w") as f:
@@ -1263,7 +1061,7 @@
             logger.error("Error extracting zip file", exc_info=True)
             raise
 
-    def _prepare_data(self) -> Tuple[Path, Dict[str, Any]]:
+    def _prepare_data(self) -> Tuple[Path, Dict[str, Any], str]:
         """Load CSV, update image paths, handle splits, and write prepared CSV."""
         if not self.temp_dir or not self.image_extract_dir:
             raise RuntimeError("Temp dirs not initialized before data prep.")
@@ -1302,8 +1100,9 @@
                 f"for train/val/test."
             )
 
-        final_csv = TEMP_CSV_FILENAME
+        final_csv = self.temp_dir / TEMP_CSV_FILENAME
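+        # write the prepared CSV inside the managed temp dir rather than the CWD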
         try:
             df.to_csv(final_csv, index=False)
             logger.info(f"Saved prepared data to {final_csv}")
         except Exception:
@@ -1312,7 +1111,9 @@
 
         return final_csv, split_config, split_info
 
-    def _process_fixed_split(self, df: pd.DataFrame) -> Dict[str, Any]:
+    def _process_fixed_split(
+        self, df: pd.DataFrame
+    ) -> Tuple[pd.DataFrame, Dict[str, Any], str]:
         """Process a fixed split column (0=train,1=val,2=test)."""
         logger.info(f"Fixed split column '{SPLIT_COLUMN_NAME}' detected.")
         try:
@@ -1384,6 +1185,7 @@
                 "random_seed": self.args.random_seed,
                 "early_stop": self.args.early_stop,
                 "label_column_data_path": csv_path,
+                "augmentation": self.args.augmentation,
             }
             yaml_str = self.backend.prepare_config(backend_args, split_cfg)
 
@@ -1422,6 +1224,29 @@
         return None
 
 
+def aug_parse(aug_string: str):
+    """
+    Parse comma-separated augmentation keys into Ludwig augmentation dicts.
+    Raises ValueError on unknown key.
+    """
+    mapping = {
+        "random_horizontal_flip": {"type": "random_horizontal_flip"},
+        "random_vertical_flip": {"type": "random_vertical_flip"},
+        "random_rotate": {"type": "random_rotate", "degree": 10},
+        "random_blur": {"type": "random_blur", "kernel_size": 3},
+        "random_brightness": {"type": "random_brightness", "min": 0.5, "max": 2.0},
+        "random_contrast": {"type": "random_contrast", "min": 0.5, "max": 2.0},
+    }
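+    # e.g. "random_rotate,random_blur" ->
+    #   [{"type": "random_rotate", "degree": 10},
+    #    {"type": "random_blur", "kernel_size": 3}]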
+    aug_list = []
+    for tok in aug_string.split(","):
+        key = tok.strip()
+        if key not in mapping:
+            valid = ", ".join(mapping.keys())
+            raise ValueError(f"Unknown augmentation '{key}'. Valid choices: {valid}")
+        aug_list.append(mapping[key])
+    return aug_list
+
+
 class SplitProbAction(argparse.Action):
     def __call__(self, parser, namespace, values, option_string=None):
         train, val, test = values
@@ -1508,7 +1333,10 @@
         metavar=("train", "val", "test"),
         action=SplitProbAction,
         default=[0.7, 0.1, 0.2],
-        help="Random split proportions (e.g., 0.7 0.1 0.2). Only used if no split column.",
+        help=(
+            "Random split proportions (e.g., 0.7 0.1 0.2)."
+            "Only used if no split column."
+        ),
     )
     parser.add_argument(
         "--random-seed",
@@ -1522,6 +1350,17 @@
         default=None,
         help="Learning rate. If not provided, Ludwig will auto-select it.",
     )
+    parser.add_argument(
+        "--augmentation",
+        type=str,
+        default=None,
+        help=(
+            "Comma-separated list (in order) of any of: "
+            "random_horizontal_flip, random_vertical_flip, random_rotate, "
+            "random_blur, random_brightness, random_contrast. "
+            "E.g. --augmentation random_horizontal_flip,random_rotate"
+        ),
+    )
 
     args = parser.parse_args()
 
@@ -1531,6 +1370,12 @@
         parser.error(f"CSV not found: {args.csv_file}")
     if not args.image_zip.is_file():
         parser.error(f"ZIP not found: {args.image_zip}")
+    if args.augmentation is not None:
+        try:
+            augmentation_setup = aug_parse(args.augmentation)
+            setattr(args, "augmentation", augmentation_setup)
+        except ValueError as e:
+            parser.error(str(e))
 
     backend_instance = LudwigDirectBackend()
     orchestrator = WorkflowOrchestrator(args, backend_instance)