goeckslab / image_learner: diff image_learner_cli.py @ 12:bcfa2e234a80 (draft)
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
| author | goeckslab |
|---|---|
| date | Fri, 21 Nov 2025 15:58:13 +0000 |
| parents | c5150cceab47 |
| children | |
--- a/image_learner_cli.py Sat Oct 18 03:17:09 2025 +0000 +++ b/image_learner_cli.py Fri Nov 21 15:58:13 2025 +0000 @@ -1,45 +1,15 @@ import argparse -import json import logging import os -import shutil import sys -import tempfile -import zipfile from pathlib import Path -from typing import Any, Dict, Optional, Protocol, Tuple import matplotlib -import numpy as np -import pandas as pd -import pandas.api.types as ptypes -import yaml -from constants import ( - IMAGE_PATH_COLUMN_NAME, - LABEL_COLUMN_NAME, - METRIC_DISPLAY_NAMES, - MODEL_ENCODER_TEMPLATES, - SPLIT_COLUMN_NAME, - TEMP_CONFIG_FILENAME, - TEMP_CSV_FILENAME, - TEMP_DIR_PREFIX, -) -from ludwig.globals import ( - DESCRIPTION_FILE_NAME, - PREDICTIONS_PARQUET_FILE_NAME, - TEST_STATISTICS_FILE_NAME, - TRAIN_SET_METADATA_FILE_NAME, -) -from ludwig.utils.data_utils import get_split_path -from plotly_plots import build_classification_plots -from sklearn.model_selection import train_test_split -from utils import ( - build_tabbed_html, - encode_image_to_base64, - get_html_closing, - get_html_template, - get_metrics_help_modal, -) +from constants import MODEL_ENCODER_TEMPLATES +from image_workflow import ImageLearnerCLI +from ludwig_backend import LudwigDirectBackend +from split_data import SplitProbAction +from utils import argument_checker, parse_learning_rate # Set matplotlib backend after imports matplotlib.use('Agg') @@ -51,1839 +21,6 @@ ) logger = logging.getLogger("ImageLearner") -# Optional MetaFormer configuration registry -META_DEFAULT_CFGS: Dict[str, Any] = {} -try: - from MetaFormer import default_cfgs as META_DEFAULT_CFGS # type: ignore[attr-defined] -except Exception as e: - logger.debug("MetaFormer default configs unavailable: %s", e) - META_DEFAULT_CFGS = {} - -# Try to import Ludwig visualization registry (may fail due to optional dependencies) -# This must come AFTER logger is defined -_ludwig_viz_available = False -get_visualizations_registry = None -try: - from ludwig.visualize import get_visualizations_registry - _ludwig_viz_available = True - logger.info("Ludwig visualizations available") -except ImportError as e: - logger.warning(f"Ludwig visualizations not available: {e}. Will use fallback plots only.") -except Exception as e: - logger.warning(f"Ludwig visualizations not available due to dependency issues: {e}. 
Will use fallback plots only.") - -# --- MetaFormer patching integration --- -_metaformer_patch_ok = False -try: - from MetaFormer.metaformer_stacked_cnn import patch_ludwig_stacked_cnn as _mf_patch - if _mf_patch(): - _metaformer_patch_ok = True - logger.info("MetaFormer patching applied for Ludwig stacked_cnn encoder.") -except Exception as e: - logger.warning(f"MetaFormer stacked CNN not available: {e}") - _metaformer_patch_ok = False - -# Note: CAFormer models are now handled through MetaFormer framework - - -def format_config_table_html( - config: dict, - split_info: Optional[str] = None, - training_progress: dict = None, - output_type: Optional[str] = None, -) -> str: - display_keys = [ - "task_type", - "model_name", - "epochs", - "batch_size", - "fine_tune", - "use_pretrained", - "learning_rate", - "random_seed", - "early_stop", - "threshold", - ] - - rows = [] - - for key in display_keys: - val = config.get(key, None) - if key == "threshold": - if output_type != "binary": - continue - val = val if val is not None else 0.5 - val_str = f"{val:.2f}" - if val == 0.5: - val_str += " (default)" - else: - if key == "task_type": - val_str = val.title() if isinstance(val, str) else "N/A" - elif key == "batch_size": - if val is not None: - val_str = int(val) - else: - val = "auto" - val_str = "auto" - resolved_val = None - if val is None or val == "auto": - if training_progress: - resolved_val = training_progress.get("batch_size") - val = ( - "Auto-selected batch size by Ludwig:<br>" - f"<span style='font-size: 0.85em;'>" - f"{resolved_val if resolved_val else val}</span><br>" - "<span style='font-size: 0.85em;'>" - "Based on model architecture and training setup " - "(e.g., fine-tuning).<br>" - "See <a href='https://ludwig.ai/latest/configuration/trainer/" - "#trainer-parameters' target='_blank'>" - "Ludwig Trainer Parameters</a> for details." - "</span>" - ) - else: - val = ( - "Auto-selected by Ludwig<br>" - "<span style='font-size: 0.85em;'>" - "Automatically tuned based on architecture and dataset.<br>" - "See <a href='https://ludwig.ai/latest/configuration/trainer/" - "#trainer-parameters' target='_blank'>" - "Ludwig Trainer Parameters</a> for details." - "</span>" - ) - elif key == "learning_rate": - if val is not None and val != "auto": - val_str = f"{val:.6f}" - else: - if training_progress: - resolved_val = training_progress.get("learning_rate") - val_str = ( - "Auto-selected learning rate by Ludwig:<br>" - f"<span style='font-size: 0.85em;'>" - f"{resolved_val if resolved_val else 'auto'}</span><br>" - "<span style='font-size: 0.85em;'>" - "Based on model architecture and training setup " - "(e.g., fine-tuning).<br>" - "</span>" - ) - else: - val_str = ( - "Auto-selected by Ludwig<br>" - "<span style='font-size: 0.85em;'>" - "Automatically tuned based on architecture and dataset.<br>" - "See <a href='https://ludwig.ai/latest/configuration/trainer/" - "#trainer-parameters' target='_blank'>" - "Ludwig Trainer Parameters</a> for details." 
- "</span>" - ) - elif key == "epochs": - if val is None: - val_str = "N/A" - else: - if ( - training_progress - and "epoch" in training_progress - and val > training_progress["epoch"] - ): - val_str = ( - f"Because of early stopping: the training " - f"stopped at epoch {training_progress['epoch']}" - ) - else: - val_str = val - else: - val_str = val if val is not None else "N/A" - if val_str == "N/A" and key not in ["task_type"]: - continue - rows.append( - f"<tr>" - f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left; " - f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>" - f"{key.replace('_', ' ').title()}</td>" - f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center; " - f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>" - f"{val_str}</td>" - f"</tr>" - ) - - aug_cfg = config.get("augmentation") - if aug_cfg: - types = [str(a.get("type", "")) for a in aug_cfg] - aug_val = ", ".join(types) - rows.append( - f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left; " - f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>Augmentation</td>" - f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center; " - f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>{aug_val}</td></tr>" - ) - - if split_info: - rows.append( - f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left; " - f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>Data Split</td>" - f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center; " - f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>{split_info}</td></tr>" - ) - - html = f""" - <h2 style="text-align: center;">Model and Training Summary</h2> - <div style="display: flex; justify-content: center;"> - <table style="border-collapse: collapse; width: 100%; table-layout: fixed;"> - <thead><tr> - <th style="padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Parameter</th> - <th style="padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Value</th> - </tr></thead> - <tbody> - {"".join(rows)} - </tbody> - </table> - </div><br> - <p style="text-align: center; font-size: 0.9em;"> - Model trained using <a href="https://ludwig.ai/" target="_blank" rel="noopener noreferrer">Ludwig</a>. 
- <a href="https://ludwig.ai/latest/configuration/" target="_blank" rel="noopener noreferrer"> - Ludwig documentation provides detailed information about default model and training parameters - </a> - </p><hr> - """ - return html - - -def detect_output_type(test_stats): - """Detects if the output type is 'binary' or 'category' based on test statistics.""" - label_stats = test_stats.get("label", {}) - if "mean_squared_error" in label_stats: - return "regression" - per_class = label_stats.get("per_class_stats", {}) - if len(per_class) == 2: - return "binary" - return "category" - - -def extract_metrics_from_json( - train_stats: dict, - test_stats: dict, - output_type: str, -) -> dict: - """Extracts relevant metrics from training and test statistics based on the output type.""" - metrics = {"training": {}, "validation": {}, "test": {}} - - def get_last_value(stats, key): - val = stats.get(key) - if isinstance(val, list) and val: - return val[-1] - elif isinstance(val, (int, float)): - return val - return None - - for split in ["training", "validation"]: - split_stats = train_stats.get(split, {}) - if not split_stats: - logging.warning(f"No statistics found for {split} split") - continue - label_stats = split_stats.get("label", {}) - if not label_stats: - logging.warning(f"No label statistics found for {split} split") - continue - if output_type == "binary": - metrics[split] = { - "accuracy": get_last_value(label_stats, "accuracy"), - "loss": get_last_value(label_stats, "loss"), - "precision": get_last_value(label_stats, "precision"), - "recall": get_last_value(label_stats, "recall"), - "specificity": get_last_value(label_stats, "specificity"), - "roc_auc": get_last_value(label_stats, "roc_auc"), - } - elif output_type == "regression": - metrics[split] = { - "loss": get_last_value(label_stats, "loss"), - "mean_absolute_error": get_last_value( - label_stats, "mean_absolute_error" - ), - "mean_absolute_percentage_error": get_last_value( - label_stats, "mean_absolute_percentage_error" - ), - "mean_squared_error": get_last_value(label_stats, "mean_squared_error"), - "root_mean_squared_error": get_last_value( - label_stats, "root_mean_squared_error" - ), - "root_mean_squared_percentage_error": get_last_value( - label_stats, "root_mean_squared_percentage_error" - ), - "r2": get_last_value(label_stats, "r2"), - } - else: - metrics[split] = { - "accuracy": get_last_value(label_stats, "accuracy"), - "accuracy_micro": get_last_value(label_stats, "accuracy_micro"), - "loss": get_last_value(label_stats, "loss"), - "roc_auc": get_last_value(label_stats, "roc_auc"), - "hits_at_k": get_last_value(label_stats, "hits_at_k"), - } - - # Test metrics: dynamic extraction according to exclusions - test_label_stats = test_stats.get("label", {}) - if not test_label_stats: - logging.warning("No label statistics found for test split") - else: - combined_stats = test_stats.get("combined", {}) - overall_stats = test_label_stats.get("overall_stats", {}) - - # Define exclusions - if output_type == "binary": - exclude = {"per_class_stats", "precision_recall_curve", "roc_curve"} - else: - exclude = {"per_class_stats", "confusion_matrix"} - - # 1. Get all scalar test_label_stats not excluded - test_metrics = {} - for k, v in test_label_stats.items(): - if k in exclude: - continue - if k == "overall_stats": - continue - if isinstance(v, (int, float, str, bool)): - test_metrics[k] = v - - # 2. Add overall_stats (flattened) - for k, v in overall_stats.items(): - test_metrics[k] = v - - # 3. 
Optionally include combined/loss if present and not already - if "loss" in combined_stats and "loss" not in test_metrics: - test_metrics["loss"] = combined_stats["loss"] - metrics["test"] = test_metrics - return metrics - - -def generate_table_row(cells, styles): - """Helper function to generate an HTML table row.""" - return ( - "<tr>" - + "".join(f"<td style='{styles}'>{cell}</td>" for cell in cells) - + "</tr>" - ) - - -# ----------------------------------------- -# 2) MODEL PERFORMANCE (Train/Val/Test) TABLE -# ----------------------------------------- -def format_stats_table_html(train_stats: dict, test_stats: dict, output_type: str) -> str: - """Formats a combined HTML table for training, validation, and test metrics.""" - all_metrics = extract_metrics_from_json(train_stats, test_stats, output_type) - rows = [] - for metric_key in sorted(all_metrics["training"].keys()): - if ( - metric_key in all_metrics["validation"] - and metric_key in all_metrics["test"] - ): - display_name = METRIC_DISPLAY_NAMES.get( - metric_key, - metric_key.replace("_", " ").title(), - ) - t = all_metrics["training"].get(metric_key) - v = all_metrics["validation"].get(metric_key) - te = all_metrics["test"].get(metric_key) - if all(x is not None for x in [t, v, te]): - rows.append([display_name, f"{t:.4f}", f"{v:.4f}", f"{te:.4f}"]) - - if not rows: - return "<table><tr><td>No metric values found.</td></tr></table>" - - html = ( - "<h2 style='text-align: center;'>Model Performance Summary</h2>" - "<div style='display: flex; justify-content: center;'>" - "<table class='performance-summary' style='border-collapse: collapse;'>" - "<thead><tr>" - "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Metric</th>" - "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Train</th>" - "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Validation</th>" - "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Test</th>" - "</tr></thead><tbody>" - ) - for row in rows: - html += generate_table_row( - row, - "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;", - ) - html += "</tbody></table></div><br>" - return html - - -# ------------------------------------------- -# 3) TRAIN/VALIDATION PERFORMANCE SUMMARY TABLE -# ------------------------------------------- -def format_train_val_stats_table_html(train_stats: dict, test_stats: dict) -> str: - """Format train/validation metrics into an HTML table.""" - all_metrics = extract_metrics_from_json(train_stats, test_stats, detect_output_type(test_stats)) - rows = [] - for metric_key in sorted(all_metrics["training"].keys()): - if metric_key in all_metrics["validation"]: - display_name = METRIC_DISPLAY_NAMES.get( - metric_key, - metric_key.replace("_", " ").title(), - ) - t = all_metrics["training"].get(metric_key) - v = all_metrics["validation"].get(metric_key) - if t is not None and v is not None: - rows.append([display_name, f"{t:.4f}", f"{v:.4f}"]) - - if not rows: - return "<table><tr><td>No metric values found for Train/Validation.</td></tr></table>" - - html = ( - "<h2 style='text-align: center;'>Train/Validation Performance Summary</h2>" - "<div style='display: flex; justify-content: center;'>" - "<table class='performance-summary' style='border-collapse: collapse;'>" - "<thead><tr>" - "<th class='sortable' style='padding: 
10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Metric</th>" - "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Train</th>" - "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Validation</th>" - "</tr></thead><tbody>" - ) - for row in rows: - html += generate_table_row( - row, - "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;", - ) - html += "</tbody></table></div><br>" - return html - - -# ----------------------------------------- -# 4) TEST‐ONLY PERFORMANCE SUMMARY TABLE -# ----------------------------------------- -def format_test_merged_stats_table_html( - test_metrics: Dict[str, Any], output_type: str -) -> str: - """Format test metrics into an HTML table.""" - rows = [] - for key in sorted(test_metrics.keys()): - display_name = METRIC_DISPLAY_NAMES.get(key, key.replace("_", " ").title()) - value = test_metrics[key] - if value is not None: - rows.append([display_name, f"{value:.4f}"]) - - if not rows: - return "<table><tr><td>No test metric values found.</td></tr></table>" - - html = ( - "<h2 style='text-align: center;'>Test Performance Summary</h2>" - "<div style='display: flex; justify-content: center;'>" - "<table class='performance-summary' style='border-collapse: collapse;'>" - "<thead><tr>" - "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Metric</th>" - "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Test</th>" - "</tr></thead><tbody>" - ) - for row in rows: - html += generate_table_row( - row, - "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;", - ) - html += "</tbody></table></div><br>" - return html - - -def split_data_0_2( - df: pd.DataFrame, - split_column: str, - validation_size: float = 0.1, - random_state: int = 42, - label_column: Optional[str] = None, -) -> pd.DataFrame: - """Given a DataFrame whose split_column only contains {0,2}, re-assign a portion of the 0s to become 1s (validation).""" - out = df.copy() - out[split_column] = pd.to_numeric(out[split_column], errors="coerce").astype(int) - - idx_train = out.index[out[split_column] == 0].tolist() - - if not idx_train: - logger.info("No rows with split=0; nothing to do.") - return out - stratify_arr = None - if label_column and label_column in out.columns: - label_counts = out.loc[idx_train, label_column].value_counts() - if label_counts.size > 1: - # Force stratify even with fewer samples - adjust validation_size if needed - min_samples_per_class = label_counts.min() - if min_samples_per_class * validation_size < 1: - # Adjust validation_size to ensure at least 1 sample per class, but do not exceed original validation_size - adjusted_validation_size = min( - validation_size, 1.0 / min_samples_per_class - ) - if adjusted_validation_size != validation_size: - validation_size = adjusted_validation_size - logger.info( - f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation" - ) - stratify_arr = out.loc[idx_train, label_column] - logger.info("Using stratified split for validation set") - else: - logger.warning("Only one label class found; cannot stratify") - if validation_size <= 0: - logger.info("validation_size <= 0; keeping all as train.") - return out - if validation_size >= 1: - logger.info("validation_size >= 1; moving all train → 
validation.") - out.loc[idx_train, split_column] = 1 - return out - # Always try stratified split first - try: - train_idx, val_idx = train_test_split( - idx_train, - test_size=validation_size, - random_state=random_state, - stratify=stratify_arr, - ) - logger.info("Successfully applied stratified split") - except ValueError as e: - logger.warning(f"Stratified split failed ({e}); falling back to random split.") - train_idx, val_idx = train_test_split( - idx_train, - test_size=validation_size, - random_state=random_state, - stratify=None, - ) - out.loc[train_idx, split_column] = 0 - out.loc[val_idx, split_column] = 1 - out[split_column] = out[split_column].astype(int) - return out - - -def create_stratified_random_split( - df: pd.DataFrame, - split_column: str, - split_probabilities: list = [0.7, 0.1, 0.2], - random_state: int = 42, - label_column: Optional[str] = None, -) -> pd.DataFrame: - """Create a stratified random split when no split column exists.""" - out = df.copy() - - # initialize split column - out[split_column] = 0 - - if not label_column or label_column not in out.columns: - logger.warning( - "No label column found; using random split without stratification" - ) - # fall back to simple random assignment - indices = out.index.tolist() - np.random.seed(random_state) - np.random.shuffle(indices) - - n_total = len(indices) - n_train = int(n_total * split_probabilities[0]) - n_val = int(n_total * split_probabilities[1]) - - out.loc[indices[:n_train], split_column] = 0 - out.loc[indices[n_train:n_train + n_val], split_column] = 1 - out.loc[indices[n_train + n_val:], split_column] = 2 - - return out.astype({split_column: int}) - - # check if stratification is possible - label_counts = out[label_column].value_counts() - min_samples_per_class = label_counts.min() - - # ensure we have enough samples for stratification: - # Each class must have at least as many samples as the number of splits, - # so that each split can receive at least one sample per class. 
- min_samples_required = len(split_probabilities) - if min_samples_per_class < min_samples_required: - logger.warning( - f"Insufficient samples per class for stratification (min: {min_samples_per_class}, required: {min_samples_required}); using random split" - ) - # fall back to simple random assignment - indices = out.index.tolist() - np.random.seed(random_state) - np.random.shuffle(indices) - - n_total = len(indices) - n_train = int(n_total * split_probabilities[0]) - n_val = int(n_total * split_probabilities[1]) - - out.loc[indices[:n_train], split_column] = 0 - out.loc[indices[n_train:n_train + n_val], split_column] = 1 - out.loc[indices[n_train + n_val:], split_column] = 2 - - return out.astype({split_column: int}) - - logger.info("Using stratified random split for train/validation/test sets") - - # first split: separate test set - train_val_idx, test_idx = train_test_split( - out.index.tolist(), - test_size=split_probabilities[2], - random_state=random_state, - stratify=out[label_column], - ) - - # second split: separate training and validation from remaining data - val_size_adjusted = split_probabilities[1] / ( - split_probabilities[0] + split_probabilities[1] - ) - train_idx, val_idx = train_test_split( - train_val_idx, - test_size=val_size_adjusted, - random_state=random_state, - stratify=out.loc[train_val_idx, label_column] if label_column and label_column in out.columns else None, - ) - - # assign split values - out.loc[train_idx, split_column] = 0 - out.loc[val_idx, split_column] = 1 - out.loc[test_idx, split_column] = 2 - - logger.info("Successfully applied stratified random split") - logger.info( - f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}" - ) - return out.astype({split_column: int}) - - -class Backend(Protocol): - """Interface for a machine learning backend.""" - - def prepare_config( - self, - config_params: Dict[str, Any], - split_config: Dict[str, Any], - ) -> str: - ... - - def run_experiment( - self, - dataset_path: Path, - config_path: Path, - output_dir: Path, - random_seed: int, - ) -> None: - ... - - def generate_plots(self, output_dir: Path) -> None: - ... - - def generate_html_report( - self, - title: str, - output_dir: str, - config: Dict[str, Any], - split_info: str, - ) -> Path: - ... 
- - -class LudwigDirectBackend: - """Backend for running Ludwig experiments directly via the internal experiment_cli function.""" - - def _detect_image_dimensions(self, image_zip_path: str) -> Tuple[int, int]: - """Detect image dimensions from the first image in the dataset.""" - try: - import zipfile - from PIL import Image - import io - - # Check if image_zip is provided - if not image_zip_path: - logger.warning("No image zip provided, using default 224x224") - return 224, 224 - - # Extract first image to detect dimensions - with zipfile.ZipFile(image_zip_path, 'r') as z: - image_files = [f for f in z.namelist() if f.lower().endswith(('.png', '.jpg', '.jpeg'))] - if not image_files: - logger.warning("No image files found in zip, using default 224x224") - return 224, 224 - - # Check first image - with z.open(image_files[0]) as f: - img = Image.open(io.BytesIO(f.read())) - width, height = img.size - logger.info(f"Detected image dimensions: {width}x{height}") - return height, width # Return as (height, width) to match encoder config - - except Exception as e: - logger.warning(f"Error detecting image dimensions: {e}, using default 224x224") - return 224, 224 - - def prepare_config( - self, - config_params: Dict[str, Any], - split_config: Dict[str, Any], - ) -> str: - logger.info("LudwigDirectBackend: Preparing YAML configuration.") - - model_name = config_params.get("model_name", "resnet18") - use_pretrained = config_params.get("use_pretrained", False) - fine_tune = config_params.get("fine_tune", False) - if use_pretrained: - trainable = bool(fine_tune) - else: - trainable = True - epochs = config_params.get("epochs", 10) - batch_size = config_params.get("batch_size") - num_processes = config_params.get("preprocessing_num_processes", 1) - early_stop = config_params.get("early_stop", None) - learning_rate = config_params.get("learning_rate") - learning_rate = "auto" if learning_rate is None else float(learning_rate) - raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name) - - # --- MetaFormer detection and config logic --- - def _is_metaformer(name: str) -> bool: - return isinstance(name, str) and name.startswith( - ( - "identityformer_", - "randformer_", - "poolformerv2_", - "convformer_", - "caformer_", - ) - ) - - # Check if this is a MetaFormer model (either direct name or in custom_model) - is_metaformer = ( - _is_metaformer(model_name) - or (isinstance(raw_encoder, dict) and "custom_model" in raw_encoder and _is_metaformer(raw_encoder["custom_model"])) - ) - - metaformer_resize: Optional[Tuple[int, int]] = None - metaformer_channels = 3 - - if is_metaformer: - # Handle MetaFormer models - custom_model = None - if isinstance(raw_encoder, dict) and "custom_model" in raw_encoder: - custom_model = raw_encoder["custom_model"] - else: - custom_model = model_name - - logger.info(f"DETECTED MetaFormer model: {custom_model}") - cfg_channels, cfg_height, cfg_width = 3, 224, 224 - if META_DEFAULT_CFGS: - model_cfg = META_DEFAULT_CFGS.get(custom_model, {}) - input_size = model_cfg.get("input_size") - if isinstance(input_size, (list, tuple)) and len(input_size) == 3: - cfg_channels, cfg_height, cfg_width = ( - int(input_size[0]), - int(input_size[1]), - int(input_size[2]), - ) - - target_height, target_width = cfg_height, cfg_width - resize_value = config_params.get("image_resize") - if resize_value and resize_value != "original": - try: - dimensions = resize_value.split("x") - if len(dimensions) == 2: - target_height, target_width = int(dimensions[0]), int(dimensions[1]) - if 
target_height <= 0 or target_width <= 0: - raise ValueError( - f"Image resize must be positive integers, received {resize_value}." - ) - logger.info(f"MetaFormer explicit resize: {target_height}x{target_width}") - else: - raise ValueError(resize_value) - except (ValueError, IndexError): - logger.warning( - "Invalid image resize format '%s'; falling back to model default %sx%s", - resize_value, - cfg_height, - cfg_width, - ) - target_height, target_width = cfg_height, cfg_width - else: - image_zip_path = config_params.get("image_zip", "") - detected_height, detected_width = self._detect_image_dimensions(image_zip_path) - if use_pretrained: - if (detected_height, detected_width) != (cfg_height, cfg_width): - logger.info( - "MetaFormer pretrained weights expect %sx%s; resizing from detected %sx%s", - cfg_height, - cfg_width, - detected_height, - detected_width, - ) - else: - target_height, target_width = detected_height, detected_width - if target_height <= 0 or target_width <= 0: - raise ValueError( - f"Invalid detected image dimensions for MetaFormer: {target_height}x{target_width}." - ) - - metaformer_channels = cfg_channels - metaformer_resize = (target_height, target_width) - - encoder_config = { - "type": "stacked_cnn", - "height": target_height, - "width": target_width, - "num_channels": metaformer_channels, - "output_size": 128, - "use_pretrained": use_pretrained, - "trainable": trainable, - "custom_model": custom_model, - } - - elif isinstance(raw_encoder, dict): - # Handle image resize for regular encoders - # Note: Standard encoders like ResNet don't support height/width parameters - # Resize will be handled at the preprocessing level by Ludwig - if config_params.get("image_resize") and config_params["image_resize"] != "original": - logger.info(f"Resize requested: {config_params['image_resize']} for standard encoder. 
Resize will be handled at preprocessing level.") - - encoder_config = { - **raw_encoder, - "use_pretrained": use_pretrained, - "trainable": trainable, - } - else: - encoder_config = {"type": raw_encoder} - - batch_size_cfg = batch_size or "auto" - - label_column_path = config_params.get("label_column_data_path") - label_series = None - if label_column_path is not None and Path(label_column_path).exists(): - try: - label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME] - except Exception as e: - logger.warning(f"Could not read label column for task detection: {e}") - - if ( - label_series is not None - and ptypes.is_numeric_dtype(label_series.dtype) - and label_series.nunique() > 10 - ): - task_type = "regression" - else: - task_type = "classification" - - config_params["task_type"] = task_type - - image_feat: Dict[str, Any] = { - "name": IMAGE_PATH_COLUMN_NAME, - "type": "image", - } - # Set preprocessing dimensions FIRST for MetaFormer models - if is_metaformer: - if metaformer_resize is None: - metaformer_resize = (224, 224) - height, width = metaformer_resize - - # CRITICAL: Set preprocessing dimensions FIRST for MetaFormer models - # This is essential for MetaFormer models to work properly - if "preprocessing" not in image_feat: - image_feat["preprocessing"] = {} - image_feat["preprocessing"]["height"] = height - image_feat["preprocessing"]["width"] = width - # Use infer_image_dimensions=True to allow Ludwig to read images for validation - # but set explicit max dimensions to control the output size - image_feat["preprocessing"]["infer_image_dimensions"] = True - image_feat["preprocessing"]["infer_image_max_height"] = height - image_feat["preprocessing"]["infer_image_max_width"] = width - image_feat["preprocessing"]["num_channels"] = metaformer_channels - image_feat["preprocessing"]["resize_method"] = "interpolate" # Use interpolation for better quality - image_feat["preprocessing"]["standardize_image"] = "imagenet1k" # Use ImageNet standardization - # Force Ludwig to respect our dimensions by setting additional parameters - image_feat["preprocessing"]["requires_equal_dimensions"] = False - logger.info(f"Set preprocessing dimensions for MetaFormer: {height}x{width} (infer_dimensions=True with max dimensions to allow validation)") - # Now set the encoder configuration - image_feat["encoder"] = encoder_config - - if config_params.get("augmentation") is not None: - image_feat["augmentation"] = config_params["augmentation"] - - # Add resize configuration for standard encoders (ResNet, etc.) - # FIXED: MetaFormer models now respect user dimensions completely - # Previously there was a double resize issue where MetaFormer would force 224x224 - # Now both MetaFormer and standard encoders respect user's resize choice - if (not is_metaformer) and config_params.get("image_resize") and config_params["image_resize"] != "original": - try: - dimensions = config_params["image_resize"].split("x") - if len(dimensions) == 2: - height, width = int(dimensions[0]), int(dimensions[1]) - if height <= 0 or width <= 0: - raise ValueError( - f"Image resize must be positive integers, received {config_params['image_resize']}." 
- ) - - # Add resize to preprocessing for standard encoders - if "preprocessing" not in image_feat: - image_feat["preprocessing"] = {} - image_feat["preprocessing"]["height"] = height - image_feat["preprocessing"]["width"] = width - # Use infer_image_dimensions=True to allow Ludwig to read images for validation - # but set explicit max dimensions to control the output size - image_feat["preprocessing"]["infer_image_dimensions"] = True - image_feat["preprocessing"]["infer_image_max_height"] = height - image_feat["preprocessing"]["infer_image_max_width"] = width - logger.info(f"Added resize preprocessing: {height}x{width} for standard encoder with infer_image_dimensions=True and max dimensions") - except (ValueError, IndexError): - logger.warning(f"Invalid image resize format: {config_params['image_resize']}, skipping resize preprocessing") - if task_type == "regression": - output_feat = { - "name": LABEL_COLUMN_NAME, - "type": "number", - "decoder": {"type": "regressor", "input_size": 1}, - "loss": {"type": "mean_squared_error"}, - "evaluation": { - "metrics": [ - "mean_squared_error", - "mean_absolute_error", - "r2", - ] - }, - } - val_metric = config_params.get("validation_metric", "mean_squared_error") - - else: - num_unique_labels = ( - label_series.nunique() if label_series is not None else 2 - ) - output_type = "binary" if num_unique_labels == 2 else "category" - # Determine if this is regression or classification based on label type - is_regression = ( - label_series is not None - and ptypes.is_numeric_dtype(label_series.dtype) - and label_series.nunique() > 10 - ) - - if is_regression: - output_feat = { - "name": LABEL_COLUMN_NAME, - "type": "number", - "decoder": {"type": "regressor", "input_size": 1}, - "loss": {"type": "mean_squared_error"}, - } - else: - if num_unique_labels == 2: - output_feat = { - "name": LABEL_COLUMN_NAME, - "type": "binary", - "decoder": {"type": "classifier", "input_size": 1}, - "loss": {"type": "softmax_cross_entropy"}, - } - else: - output_feat = { - "name": LABEL_COLUMN_NAME, - "type": "category", - "decoder": {"type": "classifier", "input_size": num_unique_labels}, - "loss": {"type": "softmax_cross_entropy"}, - } - if output_type == "binary" and config_params.get("threshold") is not None: - output_feat["threshold"] = float(config_params["threshold"]) - val_metric = None - - conf: Dict[str, Any] = { - "model_type": "ecd", - "input_features": [image_feat], - "output_features": [output_feat], - "combiner": {"type": "concat"}, - "trainer": { - "epochs": epochs, - "early_stop": early_stop, - "batch_size": batch_size_cfg, - "learning_rate": learning_rate, - # only set validation_metric for regression - **({"validation_metric": val_metric} if val_metric else {}), - }, - "preprocessing": { - "split": split_config, - "num_processes": num_processes, - "in_memory": False, - }, - } - - logger.debug("LudwigDirectBackend: Config dict built.") - try: - yaml_str = yaml.dump(conf, sort_keys=False, indent=2) - logger.info("LudwigDirectBackend: YAML config generated.") - return yaml_str - except Exception: - logger.error( - "LudwigDirectBackend: Failed to serialize YAML.", - exc_info=True, - ) - raise - - def run_experiment( - self, - dataset_path: Path, - config_path: Path, - output_dir: Path, - random_seed: int = 42, - ) -> None: - """Invoke Ludwig's internal experiment_cli function to run the experiment.""" - logger.info("LudwigDirectBackend: Starting experiment execution.") - - try: - from ludwig.experiment import experiment_cli - except ImportError as e: - 
logger.error( - "LudwigDirectBackend: Could not import experiment_cli.", - exc_info=True, - ) - raise RuntimeError("Ludwig import failed.") from e - - output_dir.mkdir(parents=True, exist_ok=True) - - try: - experiment_cli( - dataset=str(dataset_path), - config=str(config_path), - output_directory=str(output_dir), - random_seed=random_seed, - skip_preprocessing=True, - ) - logger.info( - f"LudwigDirectBackend: Experiment completed. Results in {output_dir}" - ) - except TypeError as e: - logger.error( - "LudwigDirectBackend: Argument mismatch in experiment_cli call.", - exc_info=True, - ) - raise RuntimeError("Ludwig argument error.") from e - except Exception: - logger.error( - "LudwigDirectBackend: Experiment execution error.", - exc_info=True, - ) - raise - - def get_training_process(self, output_dir) -> Optional[Dict[str, Any]]: - """Retrieve the learning rate used in the most recent Ludwig run.""" - output_dir = Path(output_dir) - exp_dirs = sorted( - output_dir.glob("experiment_run*"), - key=lambda p: p.stat().st_mtime, - ) - - if not exp_dirs: - logger.warning(f"No experiment run directories found in {output_dir}") - return None - - progress_file = exp_dirs[-1] / "model" / "training_progress.json" - if not progress_file.exists(): - logger.warning(f"No training_progress.json found in {progress_file}") - return None - - try: - with progress_file.open("r", encoding="utf-8") as f: - data = json.load(f) - return { - "learning_rate": data.get("learning_rate"), - "batch_size": data.get("batch_size"), - "epoch": data.get("epoch"), - } - except Exception as e: - logger.warning(f"Failed to read training progress info: {e}") - return {} - - def convert_parquet_to_csv(self, output_dir: Path): - """Convert the predictions Parquet file to CSV.""" - output_dir = Path(output_dir) - exp_dirs = sorted( - output_dir.glob("experiment_run*"), - key=lambda p: p.stat().st_mtime, - ) - if not exp_dirs: - logger.warning(f"No experiment run dirs found in {output_dir}") - return - exp_dir = exp_dirs[-1] - parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME - csv_path = exp_dir / "predictions.csv" - - # Check if parquet file exists before trying to convert - if not parquet_path.exists(): - logger.info(f"Predictions parquet file not found at {parquet_path}, skipping conversion") - return - - try: - df = pd.read_parquet(parquet_path) - df.to_csv(csv_path, index=False) - logger.info(f"Converted Parquet to CSV: {csv_path}") - except Exception as e: - logger.error(f"Error converting Parquet to CSV: {e}") - - def generate_plots(self, output_dir: Path) -> None: - """Generate all registered Ludwig visualizations for the latest experiment run.""" - logger.info("Generating all Ludwig visualizations…") - - test_plots = { - "compare_performance", - "compare_classifiers_performance_from_prob", - "compare_classifiers_performance_from_pred", - "compare_classifiers_performance_changing_k", - "compare_classifiers_multiclass_multimetric", - "compare_classifiers_predictions", - "confidence_thresholding_2thresholds_2d", - "confidence_thresholding_2thresholds_3d", - "confidence_thresholding", - "confidence_thresholding_data_vs_acc", - "binary_threshold_vs_metric", - "roc_curves", - "roc_curves_from_test_statistics", - "calibration_1_vs_all", - "calibration_multiclass", - "confusion_matrix", - "frequency_vs_f1", - } - train_plots = { - "learning_curves", - "compare_classifiers_performance_subset", - } - - output_dir = Path(output_dir) - exp_dirs = sorted( - output_dir.glob("experiment_run*"), - key=lambda p: p.stat().st_mtime, - 
) - if not exp_dirs: - logger.warning(f"No experiment run dirs found in {output_dir}") - return - exp_dir = exp_dirs[-1] - - viz_dir = exp_dir / "visualizations" - viz_dir.mkdir(exist_ok=True) - train_viz = viz_dir / "train" - test_viz = viz_dir / "test" - train_viz.mkdir(parents=True, exist_ok=True) - test_viz.mkdir(parents=True, exist_ok=True) - - def _check(p: Path) -> Optional[str]: - return str(p) if p.exists() else None - - training_stats = _check(exp_dir / "training_statistics.json") - test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME) - probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME) - gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME) - - dataset_path = None - split_file = None - desc = exp_dir / DESCRIPTION_FILE_NAME - if desc.exists(): - with open(desc, "r") as f: - cfg = json.load(f) - dataset_path = _check(Path(cfg.get("dataset", ""))) - split_file = _check(Path(get_split_path(cfg.get("dataset", "")))) - - output_feature = "" - if desc.exists(): - try: - output_feature = cfg["config"]["output_features"][0]["name"] - except Exception: - pass - if not output_feature and test_stats: - with open(test_stats, "r") as f: - stats = json.load(f) - output_feature = next(iter(stats.keys()), "") - - viz_registry = get_visualizations_registry() - for viz_name, viz_func in viz_registry.items(): - if viz_name in train_plots: - viz_dir_plot = train_viz - elif viz_name in test_plots: - viz_dir_plot = test_viz - else: - continue - - try: - viz_func( - training_statistics=[training_stats] if training_stats else [], - test_statistics=[test_stats] if test_stats else [], - probabilities=[probs_path] if probs_path else [], - output_feature_name=output_feature, - ground_truth_split=2, - top_n_classes=[0], - top_k=3, - ground_truth_metadata=gt_metadata, - ground_truth=dataset_path, - split_file=split_file, - output_directory=str(viz_dir_plot), - normalize=False, - file_format="png", - ) - logger.info(f"✔ Generated {viz_name}") - except Exception as e: - logger.warning(f"✘ Skipped {viz_name}: {e}") - - logger.info(f"All visualizations written to {viz_dir}") - - def generate_html_report( - self, - title: str, - output_dir: str, - config: dict, - split_info: str, - ) -> Path: - """Assemble an HTML report from visualizations under train_val/ and test/ folders.""" - cwd = Path.cwd() - report_name = title.lower().replace(" ", "_") + "_report.html" - report_path = cwd / report_name - output_dir = Path(output_dir) - output_type = None - - exp_dirs = sorted( - output_dir.glob("experiment_run*"), - key=lambda p: p.stat().st_mtime, - ) - if not exp_dirs: - raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}") - exp_dir = exp_dirs[-1] - - base_viz_dir = exp_dir / "visualizations" - train_viz_dir = base_viz_dir / "train" - test_viz_dir = base_viz_dir / "test" - - html = get_html_template() - - # Extra CSS & JS: center Plotly and enable CSV download for predictions table - html += """ -<style> - /* Center Plotly figures (both wrapper and native classes) */ - .plotly-center { display: flex; justify-content: center; } - .plotly-center .plotly-graph-div, .plotly-center .js-plotly-plot { margin: 0 auto !important; } - .js-plotly-plot, .plotly-graph-div { margin-left: auto !important; margin-right: auto !important; } - - /* Download button for predictions table */ - .download-btn { - padding: 8px 12px; - border: 1px solid #4CAF50; - background: #4CAF50; - color: white; - border-radius: 6px; - cursor: pointer; - } - .download-btn:hover { filter: brightness(0.95); } - 
.preds-controls { - display: flex; - justify-content: flex-end; - gap: 8px; - margin: 8px 0; - } -</style> -<script> - function tableToCSV(table){ - const rows = Array.from(table.querySelectorAll('tr')); - return rows.map(row => - Array.from(row.querySelectorAll('th,td')).map(cell => { - let text = cell.innerText.replace(/\\r?\\n|\\r/g,' ').trim(); - if (text.includes('"') || text.includes(',')) { - text = '"' + text.replace(/"/g,'""') + '"'; - } - return text; - }).join(',') - ).join('\\n'); - } - document.addEventListener('DOMContentLoaded', function(){ - const btn = document.getElementById('downloadPredsCsv'); - if(btn){ - btn.addEventListener('click', function(){ - const tbl = document.querySelector('.predictions-table'); - if(!tbl){ alert('Predictions table not found.'); return; } - const csv = tableToCSV(tbl); - const blob = new Blob([csv], {type: 'text/csv;charset=utf-8;'}); - const url = URL.createObjectURL(blob); - const a = document.createElement('a'); - a.href = url; - a.download = 'ground_truth_vs_predictions.csv'; - document.body.appendChild(a); - a.click(); - document.body.removeChild(a); - URL.revokeObjectURL(url); - }); - } - }); -</script> -""" - html += f"<h1>{title}</h1>" - - metrics_html = "" - train_val_metrics_html = "" - test_metrics_html = "" - try: - train_stats_path = exp_dir / "training_statistics.json" - test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME - if train_stats_path.exists() and test_stats_path.exists(): - with open(train_stats_path) as f: - train_stats = json.load(f) - with open(test_stats_path) as f: - test_stats = json.load(f) - output_type = detect_output_type(test_stats) - metrics_html = format_stats_table_html(train_stats, test_stats, output_type) - train_val_metrics_html = format_train_val_stats_table_html( - train_stats, test_stats - ) - test_metrics_html = format_test_merged_stats_table_html( - extract_metrics_from_json(train_stats, test_stats, output_type)[ - "test" - ], output_type - ) - except Exception as e: - logger.warning( - f"Could not load stats for HTML report: {type(e).__name__}: {e}" - ) - - config_html = "" - training_progress = self.get_training_process(output_dir) - try: - config_html = format_config_table_html( - config, split_info, training_progress, output_type - ) - except Exception as e: - logger.warning(f"Could not load config for HTML report: {e}") - - # ---------- image rendering with exclusions ---------- - def render_img_section( - title: str, - dir_path: Path, - output_type: str = None, - exclude_names: Optional[set] = None, - ) -> str: - if not dir_path.exists(): - return f"<h2>{title}</h2><p><em>Directory not found.</em></p>" - - exclude_names = exclude_names or set() - - imgs = list(dir_path.glob("*.png")) - - # Exclude ROC curves and standard confusion matrices (keep only entropy version) - default_exclude = { - # "roc_curves.png", # Remove ROC curves from test tab - "confusion_matrix__label_top5.png", # Remove standard confusion matrix - "confusion_matrix__label_top10.png", # Remove duplicate - "confusion_matrix__label_top6.png", # Remove duplicate - "confusion_matrix_entropy__label_top10.png", # Keep only top5 - "confusion_matrix_entropy__label_top6.png", # Keep only top5 - } - - imgs = [ - img - for img in imgs - if img.name not in default_exclude - and img.name not in exclude_names - ] - - if not imgs: - return f"<h2>{title}</h2><p><em>No plots found.</em></p>" - - # Sort images by name for consistent ordering (works with string and numeric labels) - imgs = sorted(imgs, key=lambda x: x.name) - - 
html_section = "" - for img in imgs: - b64 = encode_image_to_base64(str(img)) - img_title = img.stem.replace("_", " ").title() - html_section += ( - f"<h2 style='text-align: center;'>{img_title}</h2>" - f'<div class="plot" style="margin-bottom:20px;text-align:center;">' - f'<img src="data:image/png;base64,{b64}" ' - f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' - f"</div>" - ) - return html_section - - tab1_content = config_html + metrics_html - - tab2_content = train_val_metrics_html + render_img_section( - "Training and Validation Visualizations", - train_viz_dir, - output_type, - exclude_names={ - "compare_classifiers_performance_from_prob.png", - "roc_curves_from_prediction_statistics.png", - "precision_recall_curves_from_prediction_statistics.png", - "precision_recall_curve.png", - }, - ) - - # --- Predictions vs Ground Truth table (REGRESSION ONLY) --- - preds_section = "" - parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME - if output_type == "regression" and parquet_path.exists(): - try: - # 1) load predictions from Parquet - df_preds = pd.read_parquet(parquet_path).reset_index(drop=True) - # assume the column containing your model's prediction is named "prediction" - # or contains that substring: - pred_col = next( - (c for c in df_preds.columns if "prediction" in c.lower()), - None, - ) - if pred_col is None: - raise ValueError("No prediction column found in Parquet output") - df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"}) - - # 2) load ground truth for the test split from prepared CSV - df_all = pd.read_csv(config["label_column_data_path"]) - df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][ - LABEL_COLUMN_NAME - ].reset_index(drop=True) - # 3) concatenate side-by-side - df_table = pd.concat([df_gt, df_pred], axis=1) - df_table.columns = [LABEL_COLUMN_NAME, "prediction"] - - # 4) render as HTML - preds_html = df_table.to_html(index=False, classes="predictions-table") - preds_section = ( - "<h2 style='text-align: center;'>Ground Truth vs. 
Predictions</h2>" - "<div class='preds-controls'>" - "<button id='downloadPredsCsv' class='download-btn'>Download CSV</button>" - "</div>" - "<div class='scroll-rows-30' style='overflow-x:auto; overflow-y:auto; max-height:900px; margin-bottom:20px;'>" - + preds_html - + "</div>" - ) - except Exception as e: - logger.warning(f"Could not build Predictions vs GT table: {e}") - - tab3_content = test_metrics_html + preds_section - - if output_type in ("binary", "category") and test_stats_path.exists(): - try: - interactive_plots = build_classification_plots( - str(test_stats_path), - str(train_stats_path) if train_stats_path.exists() else None, - ) - for plot in interactive_plots: - tab3_content += ( - f"<h2 style='text-align: center;'>{plot['title']}</h2>" - f"<div class='plotly-center'>{plot['html']}</div>" - ) - logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots") - except Exception as e: - logger.warning(f"Could not generate Plotly plots: {e}") - - # Add static TEST PNGs (with default dedupe/exclusions) - tab3_content += render_img_section( - "Test Visualizations", test_viz_dir, output_type - ) - - tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) - modal_html = get_metrics_help_modal() - html += tabbed_html + modal_html + get_html_closing() - - try: - with open(report_path, "w") as f: - f.write(html) - logger.info(f"HTML report generated at: {report_path}") - except Exception as e: - logger.error(f"Failed to write HTML report: {e}") - raise - - return report_path - - -class WorkflowOrchestrator: - """Manages the image-classification workflow.""" - - def __init__(self, args: argparse.Namespace, backend: Backend): - self.args = args - self.backend = backend - self.temp_dir: Optional[Path] = None - self.image_extract_dir: Optional[Path] = None - logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}") - - def run(self) -> None: - """Execute the full workflow end-to-end.""" - # Delegate to the backend's run_experiment method - self.backend.run_experiment() - - -class ImageLearnerCLI: - """Manages the image-classification workflow.""" - - def __init__(self, args: argparse.Namespace, backend: Backend): - self.args = args - self.backend = backend - self.temp_dir: Optional[Path] = None - self.image_extract_dir: Optional[Path] = None - logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}") - - def _create_temp_dirs(self) -> None: - """Create temporary output and image extraction directories.""" - try: - self.temp_dir = Path( - tempfile.mkdtemp(dir=self.args.output_dir, prefix=TEMP_DIR_PREFIX) - ) - self.image_extract_dir = self.temp_dir / "images" - self.image_extract_dir.mkdir() - logger.info(f"Created temp directory: {self.temp_dir}") - except Exception: - logger.error("Failed to create temporary directories", exc_info=True) - raise - - def _extract_images(self) -> None: - """Extract images into the temp image directory. 
- - If a ZIP file is provided, extract it - - If a directory is provided, copy its contents - """ - if self.image_extract_dir is None: - raise RuntimeError("Temp image directory not initialized.") - src = Path(self.args.image_zip) - logger.info(f"Preparing images from {src} → {self.image_extract_dir}") - try: - if src.is_dir(): - # copy directory tree - for root, dirs, files in os.walk(src): - rel = Path(root).relative_to(src) - target_root = self.image_extract_dir / rel - target_root.mkdir(parents=True, exist_ok=True) - for fn in files: - shutil.copy2(Path(root) / fn, target_root / fn) - logger.info("Image directory copied.") - else: - with zipfile.ZipFile(src, "r") as z: - z.extractall(self.image_extract_dir) - logger.info("Image extraction complete.") - except Exception: - logger.error("Error preparing images", exc_info=True) - raise - - def _process_fixed_split( - self, df: pd.DataFrame - ) -> Tuple[pd.DataFrame, Dict[str, Any], str]: - """Process datasets that already have a split column.""" - unique = set(df[SPLIT_COLUMN_NAME].unique()) - if unique == {0, 2}: - # Split 0/2 detected, create validation set - df = split_data_0_2( - df=df, - split_column=SPLIT_COLUMN_NAME, - validation_size=self.args.validation_size, - random_state=self.args.random_seed, - label_column=LABEL_COLUMN_NAME, - ) - split_config = {"type": "fixed", "column": SPLIT_COLUMN_NAME} - split_info = ( - "Detected a split column (with values 0 and 2) in the input CSV. " - f"Used this column as a base and reassigned " - f"{self.args.validation_size * 100:.1f}% " - "of the training set (originally labeled 0) to validation (labeled 1) using stratified sampling." - ) - logger.info("Applied custom 0/2 split.") - elif unique.issubset({0, 1, 2}): - # Standard 0/1/2 split - split_config = {"type": "fixed", "column": SPLIT_COLUMN_NAME} - split_info = ( - "Detected a split column with train(0)/validation(1)/test(2) " - "values in the input CSV. Used this column as-is." - ) - logger.info("Fixed split column detected.") - else: - raise ValueError( - f"Split column contains unexpected values: {unique}. 
" - "Expected: {{0,1,2}} or {{0,2}}" - ) - - return df, split_config, split_info - - def _prepare_data(self) -> Tuple[Path, Dict[str, Any], str]: - """Load CSV, update image paths, handle splits, and write prepared CSV.""" - if not self.temp_dir or not self.image_extract_dir: - raise RuntimeError("Temp dirs not initialized before data prep.") - - try: - df = pd.read_csv(self.args.csv_file) - logger.info(f"Loaded CSV: {self.args.csv_file}") - except Exception: - logger.error("Error loading CSV file", exc_info=True) - raise - - required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME} - missing = required - set(df.columns) - if missing: - raise ValueError(f"Missing CSV columns: {', '.join(missing)}") - - try: - # Use relative paths that Ludwig can resolve from its internal working directory - df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply( - lambda p: str(Path("images") / p) - ) - except Exception: - logger.error("Error updating image paths", exc_info=True) - raise - - if SPLIT_COLUMN_NAME in df.columns: - df, split_config, split_info = self._process_fixed_split(df) - else: - logger.info("No split column; creating stratified random split") - df = create_stratified_random_split( - df=df, - split_column=SPLIT_COLUMN_NAME, - split_probabilities=self.args.split_probabilities, - random_state=self.args.random_seed, - label_column=LABEL_COLUMN_NAME, - ) - split_config = { - "type": "fixed", - "column": SPLIT_COLUMN_NAME, - } - split_info = ( - f"No split column in CSV. Created stratified random split: " - f"{[int(p * 100) for p in self.args.split_probabilities]}% " - f"for train/val/test with balanced label distribution." - ) - - final_csv = self.temp_dir / TEMP_CSV_FILENAME - - try: - - df.to_csv(final_csv, index=False) - logger.info(f"Saved prepared data to {final_csv}") - except Exception: - logger.error("Error saving prepared CSV", exc_info=True) - raise - - return final_csv, split_config, split_info - -# Removed duplicate method - - def _detect_image_dimensions(self) -> Tuple[int, int]: - """Detect image dimensions from the first image in the dataset.""" - try: - import zipfile - from PIL import Image - import io - - # Check if image_zip is provided - if not self.args.image_zip: - logger.warning("No image zip provided, using default 224x224") - return 224, 224 - - # Extract first image to detect dimensions - with zipfile.ZipFile(self.args.image_zip, 'r') as z: - image_files = [f for f in z.namelist() if f.lower().endswith(('.png', '.jpg', '.jpeg'))] - if not image_files: - logger.warning("No image files found in zip, using default 224x224") - return 224, 224 - - # Check first image - with z.open(image_files[0]) as f: - img = Image.open(io.BytesIO(f.read())) - width, height = img.size - logger.info(f"Detected image dimensions: {width}x{height}") - return height, width # Return as (height, width) to match encoder config - - except Exception as e: - logger.warning(f"Error detecting image dimensions: {e}, using default 224x224") - return 224, 224 - - def _cleanup_temp_dirs(self) -> None: - if self.temp_dir and self.temp_dir.exists(): - logger.info(f"Cleaning up temp directory: {self.temp_dir}") - # Don't clean up for debugging - shutil.rmtree(self.temp_dir, ignore_errors=True) - self.temp_dir = None - self.image_extract_dir = None - - def run(self) -> None: - """Execute the full workflow end-to-end.""" - logger.info("Starting workflow...") - self.args.output_dir.mkdir(parents=True, exist_ok=True) - - try: - self._create_temp_dirs() - self._extract_images() - csv_path, split_cfg, 
split_info = self._prepare_data() - - use_pretrained = self.args.use_pretrained or self.args.fine_tune - - backend_args = { - "model_name": self.args.model_name, - "fine_tune": self.args.fine_tune, - "use_pretrained": use_pretrained, - "epochs": self.args.epochs, - "batch_size": self.args.batch_size, - "preprocessing_num_processes": self.args.preprocessing_num_processes, - "split_probabilities": self.args.split_probabilities, - "learning_rate": self.args.learning_rate, - "random_seed": self.args.random_seed, - "early_stop": self.args.early_stop, - "label_column_data_path": csv_path, - "augmentation": self.args.augmentation, - "image_resize": self.args.image_resize, - "image_zip": self.args.image_zip, - "threshold": self.args.threshold, - } - yaml_str = self.backend.prepare_config(backend_args, split_cfg) - - config_file = self.temp_dir / TEMP_CONFIG_FILENAME - config_file.write_text(yaml_str) - logger.info(f"Wrote backend config: {config_file}") - - ran_ok = True - try: - # Run Ludwig experiment with absolute paths to avoid working directory issues - self.backend.run_experiment( - csv_path, - config_file, - self.args.output_dir, - self.args.random_seed, - ) - except Exception: - logger.error("Workflow execution failed", exc_info=True) - ran_ok = False - - if ran_ok: - logger.info("Workflow completed successfully.") - # Generate a very small set of plots to conserve disk space - self.backend.generate_plots(self.args.output_dir) - # Build HTML report (robust to missing metrics) - report_file = self.backend.generate_html_report( - "Image Classification Results", - self.args.output_dir, - backend_args, - split_info, - ) - logger.info(f"HTML report generated at: {report_file}") - # Convert predictions parquet → csv - self.backend.convert_parquet_to_csv(self.args.output_dir) - logger.info("Converted Parquet to CSV.") - # Post-process cleanup to reduce disk footprint for subsequent tests - try: - self._postprocess_cleanup(self.args.output_dir) - except Exception as cleanup_err: - logger.warning(f"Cleanup step failed: {cleanup_err}") - else: - # Fallback: create minimal outputs so downstream steps can proceed - logger.warning("Falling back to minimal outputs due to runtime failure.") - try: - self._create_minimal_outputs(self.args.output_dir, csv_path) - # Even in fallback, produce an HTML shell so tests find required text - report_file = self.backend.generate_html_report( - "Image Classification Results", - self.args.output_dir, - backend_args, - split_info, - ) - logger.info(f"HTML report (fallback) generated at: {report_file}") - except Exception as fb_err: - logger.error(f"Failed to build fallback outputs: {fb_err}") - raise - - except Exception: - logger.error("Workflow execution failed", exc_info=True) - raise - finally: - self._cleanup_temp_dirs() - - def _postprocess_cleanup(self, output_dir: Path) -> None: - """Remove large intermediates and caches to conserve disk space across tests.""" - output_dir = Path(output_dir) - exp_dirs = sorted( - output_dir.glob("experiment_run*"), - key=lambda p: p.stat().st_mtime, - ) - if exp_dirs: - exp_dir = exp_dirs[-1] - # Remove training checkpoints directory if present - ckpt_dir = exp_dir / "model" / "training_checkpoints" - if ckpt_dir.exists(): - shutil.rmtree(ckpt_dir, ignore_errors=True) - # Remove predictions parquet once CSV is generated - parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME - if parquet_path.exists(): - try: - parquet_path.unlink() - except Exception: - pass - - # Clear torch hub cache under the job-scoped home, if present - 
-        job_home_torch_hub = Path.cwd() / "home" / ".cache" / "torch" / "hub"
-        if job_home_torch_hub.exists():
-            shutil.rmtree(job_home_torch_hub, ignore_errors=True)
-
-        # Also try the default user cache as a best-effort (may not exist in job sandbox)
-        user_home_torch_hub = Path.home() / ".cache" / "torch" / "hub"
-        if user_home_torch_hub.exists():
-            shutil.rmtree(user_home_torch_hub, ignore_errors=True)
-
-        # Clear huggingface cache if present in the job sandbox
-        job_home_hf = Path.cwd() / "home" / ".cache" / "huggingface"
-        if job_home_hf.exists():
-            shutil.rmtree(job_home_hf, ignore_errors=True)
-
-    def _create_minimal_outputs(self, output_dir: Path, prepared_csv_path: Path) -> None:
-        """Create a minimal set of outputs so Galaxy can collect expected artifacts.
-
-        - experiment_run/
-          - predictions.csv (1 column)
-          - visualizations/train/ (empty)
-          - visualizations/test/ (empty)
-          - model/
-            - model_weights/ (empty)
-            - model_hyperparameters.json (stub)
-        """
-        output_dir = Path(output_dir)
-        exp_dir = output_dir / "experiment_run"
-        (exp_dir / "visualizations" / "train").mkdir(parents=True, exist_ok=True)
-        (exp_dir / "visualizations" / "test").mkdir(parents=True, exist_ok=True)
-        model_dir = exp_dir / "model"
-        (model_dir / "model_weights").mkdir(parents=True, exist_ok=True)
-
-        # Stub JSON so the tool's copy step succeeds
-        try:
-            (model_dir / "model_hyperparameters.json").write_text("{}\n")
-        except Exception:
-            pass
-
-        # Create a small predictions.csv with exactly 1 column
-        try:
-            df_all = pd.read_csv(prepared_csv_path)
-            from constants import SPLIT_COLUMN_NAME  # local import to avoid cycle at top
-            num_rows = int((df_all[SPLIT_COLUMN_NAME] == 2).sum()) if SPLIT_COLUMN_NAME in df_all.columns else 1
-        except Exception:
-            num_rows = 1
-        num_rows = max(1, num_rows)
-        pd.DataFrame({"prediction": [0] * num_rows}).to_csv(exp_dir / "predictions.csv", index=False)
-
-
-def parse_learning_rate(s):
-    try:
-        return float(s)
-    except (TypeError, ValueError):
-        return None
-
-
-def aug_parse(aug_string: str):
-    """
-    Parse comma-separated augmentation keys into Ludwig augmentation dicts.
-    Raises ValueError on unknown key.
-    """
-    mapping = {
-        "random_horizontal_flip": {"type": "random_horizontal_flip"},
-        "random_vertical_flip": {"type": "random_vertical_flip"},
-        "random_rotate": {"type": "random_rotate", "degree": 10},
-        "random_blur": {"type": "random_blur", "kernel_size": 3},
-        "random_brightness": {"type": "random_brightness", "min": 0.5, "max": 2.0},
-        "random_contrast": {"type": "random_contrast", "min": 0.5, "max": 2.0},
-    }
-    aug_list = []
-    for tok in aug_string.split(","):
-        key = tok.strip()
-        if not key:
-            continue
-        if key not in mapping:
-            valid = ", ".join(mapping.keys())
-            raise ValueError(f"Unknown augmentation '{key}'. Valid choices: {valid}")
-        aug_list.append(mapping[key])
-    return aug_list
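The aug_parse helper above is self-contained, so its behavior follows directly from its mapping table; a minimal sketch of the success and failure paths (expected values taken from the function exactly as defined above):

    # Valid keys map to Ludwig augmentation dicts, in input order;
    # whitespace around tokens is stripped.
    augs = aug_parse("random_horizontal_flip, random_rotate")
    # augs == [{"type": "random_horizontal_flip"},
    #          {"type": "random_rotate", "degree": 10}]

    # Unknown keys raise ValueError listing the valid choices.
    aug_parse("random_sharpen")
    # ValueError: Unknown augmentation 'random_sharpen'. Valid choices:
    # random_horizontal_flip, random_vertical_flip, random_rotate, ...

In the refactored layout this validation presumably runs inside utils.argument_checker, which replaces the inline aug_parse call in main() below.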
-
-
-class SplitProbAction(argparse.Action):
-    def __call__(self, parser, namespace, values, option_string=None):
-        train, val, test = values
-        total = train + val + test
-        if abs(total - 1.0) > 1e-6:
-            parser.error(
-                f"--split-probabilities must sum to 1.0; "
-                f"got {train:.3f} + {val:.3f} + {test:.3f} = {total:.3f}"
-            )
-        setattr(namespace, self.dest, values)


def main():
    parser = argparse.ArgumentParser(
@@ -1893,7 +30,7 @@
        "--csv-file",
        required=True,
        type=Path,
-        help="Path to the input CSV",
+        help="Path to the input metadata file (CSV, TSV, etc.)",
    )
    parser.add_argument(
        "--image-zip",
@@ -2008,18 +145,7 @@

    args = parser.parse_args()

-    if not 0.0 <= args.validation_size <= 1.0:
-        parser.error("validation-size must be between 0.0 and 1.0")
-    if not args.csv_file.is_file():
-        parser.error(f"CSV not found: {args.csv_file}")
-    if not (args.image_zip.is_file() or args.image_zip.is_dir()):
-        parser.error(f"ZIP or directory not found: {args.image_zip}")
-    if args.augmentation is not None:
-        try:
-            augmentation_setup = aug_parse(args.augmentation)
-            setattr(args, "augmentation", augmentation_setup)
-        except ValueError as e:
-            parser.error(str(e))
+    argument_checker(args, parser)

    backend_instance = LudwigDirectBackend()
    orchestrator = ImageLearnerCLI(args, backend_instance)
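To see SplitProbAction in use, the sketch below wires it into a parser using the split_data import shown in the new header of this diff; declaring the flag with nargs=3 and type=float is an assumption, inferred from the three-way unpacking in __call__ above.

    import argparse

    from split_data import SplitProbAction  # import path from the new header

    p = argparse.ArgumentParser()
    p.add_argument(
        "--split-probabilities",
        type=float,
        nargs=3,  # assumed: __call__ unpacks exactly three values
        action=SplitProbAction,
        default=[0.7, 0.1, 0.2],
    )
    p.parse_args(["--split-probabilities", "0.7", "0.1", "0.2"])  # accepted
    # p.parse_args(["--split-probabilities", "0.5", "0.2", "0.2"])
    # would exit with: --split-probabilities must sum to 1.0;
    # got 0.500 + 0.200 + 0.200 = 0.900

The 1e-6 tolerance in the sum check avoids rejecting probability triples that fail exact equality only because of floating-point rounding.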

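parse_learning_rate, now imported from utils in the new header (and presumably unchanged from the removed version shown above), never raises: any value float() cannot parse becomes None, which the report code elsewhere in this tool renders as Ludwig's auto-selected learning rate. A quick sketch:

    from utils import parse_learning_rate  # import path from the new header

    parse_learning_rate("0.001")  # -> 0.001
    parse_learning_rate("auto")   # -> None (float("auto") raises ValueError)
    parse_learning_rate(None)     # -> None (float(None) raises TypeError)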