# HG changeset patch
# User goeckslab
# Date 1748985731 0
# Node ID 54b871dfc51ec8ad8e8213a639bfcd36c676fb1a
planemo upload for repository https://github.com/goeckslab/gleam.git commit b7411ff35b6228ccdfd36cd4ebd946c03ac7f7e9

diff -r 000000000000 -r 54b871dfc51e image_learner.xml
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/image_learner.xml	Tue Jun 03 21:22:11 2025 +0000
@@ -0,0 +1,270 @@
    trains and evaluates an image classification model
    quay.io/goeckslab/galaxy-ludwig-gpu:0.10.1
    @misc{https://doi.org/10.48550/arxiv.1909.07930,
        doi = {10.48550/ARXIV.1909.07930},
        url = {https://arxiv.org/abs/1909.07930},
        author = {Molino, Piero and Dudin, Yaroslav and Miryala, Sai Sumanth},
        title = {Ludwig: a type-based declarative deep learning toolbox},
        publisher = {arXiv},
        year = {2019},
        copyright = {arXiv.org perpetual, non-exclusive license}
    }

diff -r 000000000000 -r 54b871dfc51e image_learner_cli.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/image_learner_cli.py	Tue Jun 03 21:22:11 2025 +0000
@@ -0,0 +1,1137 @@
#!/usr/bin/env python3
import argparse
import json
import logging
import os
import shutil
import sys
import tempfile
import zipfile
from pathlib import Path
from typing import Any, Dict, Optional, Protocol, Tuple

import pandas as pd
import yaml
from ludwig.globals import (
    DESCRIPTION_FILE_NAME,
    PREDICTIONS_PARQUET_FILE_NAME,
    TEST_STATISTICS_FILE_NAME,
    TRAIN_SET_METADATA_FILE_NAME,
)
from ludwig.utils.data_utils import get_split_path
from ludwig.visualize import get_visualizations_registry
from sklearn.model_selection import train_test_split
from utils import encode_image_to_base64, get_html_closing, get_html_template

# --- Constants ---
SPLIT_COLUMN_NAME = 'split'
LABEL_COLUMN_NAME = 'label'
IMAGE_PATH_COLUMN_NAME = 'image_path'
DEFAULT_SPLIT_PROBABILITIES = [0.7, 0.1, 0.2]
TEMP_CSV_FILENAME = "processed_data_for_ludwig.csv"
TEMP_CONFIG_FILENAME = "ludwig_config.yaml"
TEMP_DIR_PREFIX = "ludwig_api_work_"
MODEL_ENCODER_TEMPLATES: Dict[str, Any] = {
    'stacked_cnn': 'stacked_cnn',
    'resnet18': {'type': 'resnet', 'model_variant': 18},
    'resnet34': {'type': 'resnet', 'model_variant': 34},
    'resnet50': {'type': 'resnet', 'model_variant': 50},
    'resnet101': {'type': 'resnet', 'model_variant': 101},
    'resnet152': {'type': 'resnet', 'model_variant': 152},
    'resnext50_32x4d': {'type': 'resnext', 'model_variant': '50_32x4d'},
    'resnext101_32x8d': {'type': 'resnext', 'model_variant': '101_32x8d'},
    'resnext101_64x4d': {'type': 'resnext', 'model_variant': '101_64x4d'},
    'resnext152_32x8d': {'type': 'resnext', 'model_variant': '152_32x8d'},
    'wide_resnet50_2': {'type': 'wide_resnet', 'model_variant': '50_2'},
    'wide_resnet101_2': {'type': 'wide_resnet', 'model_variant': '101_2'},
    'wide_resnet103_2': {'type': 'wide_resnet', 'model_variant': '103_2'},
    'efficientnet_b0': {'type': 'efficientnet', 'model_variant': 'b0'},
    'efficientnet_b1': {'type': 'efficientnet', 'model_variant': 'b1'},
    'efficientnet_b2': {'type': 'efficientnet', 'model_variant': 'b2'},
    'efficientnet_b3': {'type': 'efficientnet', 'model_variant': 'b3'},
    'efficientnet_b4': {'type': 'efficientnet', 'model_variant': 'b4'},
    'efficientnet_b5': {'type': 'efficientnet', 'model_variant': 'b5'},
    'efficientnet_b6': {'type': 'efficientnet', 'model_variant': 'b6'},
    'efficientnet_b7': {'type': 'efficientnet', 'model_variant': 'b7'},
    'efficientnet_v2_s': {'type': 'efficientnet', 'model_variant': 'v2_s'},
    'efficientnet_v2_m': {'type': 'efficientnet', 'model_variant': 'v2_m'},
    'efficientnet_v2_l': {'type': 'efficientnet', 'model_variant': 'v2_l'},
    'regnet_y_400mf': {'type': 'regnet', 'model_variant': 'y_400mf'},
    'regnet_y_800mf': {'type': 'regnet', 'model_variant': 'y_800mf'},
    'regnet_y_1_6gf': {'type': 'regnet', 'model_variant': 'y_1_6gf'},
    'regnet_y_3_2gf': {'type': 'regnet', 'model_variant': 'y_3_2gf'},
    'regnet_y_8gf': {'type': 'regnet', 'model_variant': 'y_8gf'},
    'regnet_y_16gf': {'type': 'regnet', 'model_variant': 'y_16gf'},
    'regnet_y_32gf': {'type': 'regnet', 'model_variant': 'y_32gf'},
    'regnet_y_128gf': {'type': 'regnet', 'model_variant': 'y_128gf'},
    'regnet_x_400mf': {'type': 'regnet', 'model_variant': 'x_400mf'},
    'regnet_x_800mf': {'type': 'regnet', 'model_variant': 'x_800mf'},
    'regnet_x_1_6gf': {'type': 'regnet', 'model_variant': 'x_1_6gf'},
    'regnet_x_3_2gf': {'type': 'regnet', 'model_variant': 'x_3_2gf'},
    'regnet_x_8gf': {'type': 'regnet', 'model_variant': 'x_8gf'},
    'regnet_x_16gf': {'type': 'regnet', 'model_variant': 'x_16gf'},
    'regnet_x_32gf': {'type': 'regnet', 'model_variant': 'x_32gf'},
    'vgg11': {'type': 'vgg', 'model_variant': 11},
    'vgg11_bn': {'type': 'vgg', 'model_variant': '11_bn'},
    'vgg13': {'type': 'vgg', 'model_variant': 13},
    'vgg13_bn': {'type': 'vgg', 'model_variant': '13_bn'},
    'vgg16': {'type': 'vgg', 'model_variant': 16},
    'vgg16_bn': {'type': 'vgg', 'model_variant': '16_bn'},
    'vgg19': {'type': 'vgg', 'model_variant': 19},
    'vgg19_bn': {'type': 'vgg', 'model_variant': '19_bn'},
    'shufflenet_v2_x0_5': {'type': 'shufflenet_v2', 'model_variant': 'x0_5'},
    'shufflenet_v2_x1_0': {'type': 'shufflenet_v2', 'model_variant': 'x1_0'},
    'shufflenet_v2_x1_5': {'type': 'shufflenet_v2', 'model_variant': 'x1_5'},
    'shufflenet_v2_x2_0': {'type': 'shufflenet_v2', 'model_variant': 'x2_0'},
    'squeezenet1_0': {'type': 'squeezenet', 'model_variant': '1_0'},
    'squeezenet1_1': {'type': 'squeezenet', 'model_variant': '1_1'},
    'swin_t': {'type': 'swin_transformer', 'model_variant': 't'},
    'swin_s': {'type': 'swin_transformer', 'model_variant': 's'},
    'swin_b': {'type': 'swin_transformer', 'model_variant': 'b'},
    'swin_v2_t': {'type': 'swin_transformer', 'model_variant': 'v2_t'},
    'swin_v2_s': {'type': 'swin_transformer', 'model_variant': 'v2_s'},
    'swin_v2_b': {'type': 'swin_transformer', 'model_variant': 'v2_b'},
    'vit_b_16': {'type': 'vision_transformer', 'model_variant': 'b_16'},
    'vit_b_32': {'type': 'vision_transformer', 'model_variant': 'b_32'},
    'vit_l_16': {'type': 'vision_transformer', 'model_variant': 'l_16'},
    'vit_l_32': {'type': 'vision_transformer', 'model_variant': 'l_32'},
    'vit_h_14': {'type': 'vision_transformer', 'model_variant': 'h_14'},
    'convnext_tiny': {'type': 'convnext', 'model_variant': 'tiny'},
    'convnext_small': {'type': 'convnext', 'model_variant': 'small'},
    'convnext_base': {'type': 'convnext', 'model_variant': 'base'},
    'convnext_large': {'type': 'convnext', 'model_variant': 'large'},
    'maxvit_t': {'type': 'maxvit', 'model_variant': 't'},
    'alexnet': {'type': 'alexnet'},
    'googlenet': {'type': 'googlenet'},
    'inception_v3': {'type': 'inception_v3'},
    'mobilenet_v2': {'type': 'mobilenet_v2'},
    'mobilenet_v3_large': {'type': 'mobilenet_v3_large'},
    'mobilenet_v3_small': {'type': 'mobilenet_v3_small'},
}
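Each template maps a CLI model name either to a bare encoder type string or to a partial encoder dict; prepare_config() (further down) merges the dict form with the use_pretrained/trainable flags. A minimal sketch of that expansion, mirroring the merge in prepare_config:

# Illustration of how a template expands (mirrors the merge in prepare_config).
raw_encoder = MODEL_ENCODER_TEMPLATES["resnet18"]   # {'type': 'resnet', 'model_variant': 18}
encoder_config = {**raw_encoder, "use_pretrained": True, "trainable": False}
# -> {'type': 'resnet', 'model_variant': 18, 'use_pretrained': True, 'trainable': False}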
# --- Logging Setup ---
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(name)s: %(message)s'
)
logger = logging.getLogger("ImageLearner")


def format_config_table_html(
        config: dict,
        split_info: Optional[str] = None,
        training_progress: Optional[dict] = None) -> str:
    display_keys = [
        "model_name",
        "epochs",
        "batch_size",
        "fine_tune",
        "use_pretrained",
        "learning_rate",
        "random_seed",
        "early_stop",
    ]

    rows = []

    for key in display_keys:
        val = config.get(key, "N/A")
        if key == "batch_size":
            if val is not None:
                val = int(val)
            else:
                if training_progress:
                    val = "Auto-selected batch size by Ludwig:<br>"
                    resolved_val = training_progress.get("batch_size")
                    val += f"<span>{resolved_val}</span><br>"
                else:
                    val = "auto"
        if key == "learning_rate":
            resolved_val = None
            if val is None or val == "auto":
                if training_progress:
                    resolved_val = training_progress.get("learning_rate")
                    val = (
                        "Auto-selected learning rate by Ludwig:<br>"
                        f"<span>{resolved_val if resolved_val else val}</span><br>"
                        "<span>"
                        "Based on model architecture and training setup (e.g., fine-tuning).<br>"
                        "See Ludwig Trainer Parameters for details."
                        "</span>"
                    )
                else:
                    val = (
                        "Auto-selected by Ludwig<br>"
                        "<span>"
                        "Automatically tuned based on architecture and dataset.<br>"
                        "See Ludwig Trainer Parameters for details."
                        "</span>"
                    )
            else:
                val = f"{val:.6f}"
        if key == "epochs":
            if (training_progress and "epoch" in training_progress
                    and val > training_progress["epoch"]):
                val = (
                    f"Because of early stopping, the training "
                    f"stopped at epoch {training_progress['epoch']}"
                )

        if val is None:
            continue
        rows.append(
            f"<tr>"
            f"<td>{key.replace('_', ' ').title()}</td>"
            f"<td>{val}</td>"
            f"</tr>"
        )

    if split_info:
        rows.append(
            f"<tr>"
            f"<td>Data Split</td>"
            f"<td>{split_info}</td>"
            f"</tr>"
        )

    return (
        "<h2>Training Setup</h2>"
        "<table>"
        "<thead><tr><th>Parameter</th><th>Value</th></tr></thead>"
        "<tbody>"
        + "".join(rows) +
        "</tbody></table>"
        "<p>"
        "Model trained using Ludwig.<br>"
        "If you want to learn more about Ludwig's default settings, "
        "please check their website (ludwig.ai)."
        "</p>"
    )

def format_stats_table_html(training_stats: dict, test_stats: dict) -> str:
    train_metrics = training_stats.get("training", {}).get("label", {})
    val_metrics = training_stats.get("validation", {}).get("label", {})
    test_metrics = test_stats.get("label", {})

    all_metrics = set(train_metrics) | set(val_metrics) | set(test_metrics)

    def get_last_value(stats, key):
        val = stats.get(key)
        if isinstance(val, list) and val:
            return val[-1]
        elif isinstance(val, (int, float)):
            return val
        return None

    rows = []
    for metric in sorted(all_metrics):
        t = get_last_value(train_metrics, metric)
        v = get_last_value(val_metrics, metric)
        te = get_last_value(test_metrics, metric)
        if all(x is not None for x in [t, v, te]):
            rows.append(
                f"<tr>"
                f"<td>{metric}</td>"
                f"<td>{t:.4f}</td>"
                f"<td>{v:.4f}</td>"
                f"<td>{te:.4f}</td>"
                f"</tr>"
            )

    if not rows:
        return "<p><em>No metric values found.</em></p>"

    return (
        "<h2>Model Performance Summary</h2>"
        "<table>"
        "<thead><tr>"
        "<th>Metric</th><th>Train</th><th>Validation</th><th>Test</th>"
        "</tr></thead>"
        "<tbody>"
        + "".join(rows) +
        "</tbody></table>"
    )
def build_tabbed_html(
        metrics_html: str,
        train_viz_html: str,
        test_viz_html: str) -> str:
    return f"""
<style>
.tabs {{ display: flex; border-bottom: 2px solid #ccc; margin-bottom: 1rem; }}
.tab {{ padding: 10px 20px; cursor: pointer; }}
.tab.active {{ border-bottom: 3px solid #1f77b4; font-weight: bold; }}
.tab-content {{ display: none; }}
.tab-content.active {{ display: block; }}
</style>

<div class="tabs">
  <div class="tab active" onclick="showTab('metrics')">Config & Metrics</div>
  <div class="tab" onclick="showTab('trainval')">Train/Validation Plots</div>
  <div class="tab" onclick="showTab('test')">Test Plots</div>
</div>

<div id="metrics" class="tab-content active">
  {metrics_html}
</div>
<div id="trainval" class="tab-content">
  {train_viz_html}
</div>
<div id="test" class="tab-content">
  {test_viz_html}
</div>

<script>
function showTab(id) {{
    document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active'));
    document.querySelectorAll('.tab').forEach(el => el.classList.remove('active'));
    document.getElementById(id).classList.add('active');
    event.currentTarget.classList.add('active');
}}
</script>
"""


def split_data_0_2(
    df: pd.DataFrame,
    split_column: str,
    validation_size: float = 0.15,
    random_state: int = 42,
    label_column: Optional[str] = None,
) -> pd.DataFrame:
    """
    Given a DataFrame whose split_column only contains {0, 2}, re-assign
    a portion of the 0s to become 1s (validation). Returns a fresh DataFrame.
    """
    # Work on a copy
    out = df.copy()
    # Ensure split col is integer dtype
    out[split_column] = pd.to_numeric(out[split_column], errors="coerce").astype(int)

    idx_train = out.index[out[split_column] == 0].tolist()

    if not idx_train:
        logger.info("No rows with split=0; nothing to do.")
        return out

    # Determine stratify array if possible
    stratify_arr = None
    if label_column and label_column in out.columns:
        # Only stratify if at least two classes and enough samples
        label_counts = out.loc[idx_train, label_column].value_counts()
        if label_counts.size > 1 and (label_counts.min() * validation_size) >= 1:
            stratify_arr = out.loc[idx_train, label_column]
        else:
            logger.warning("Cannot stratify (too few labels); splitting without stratify.")

    # Edge cases
    if validation_size <= 0:
        logger.info("validation_size <= 0; keeping all as train.")
        return out
    if validation_size >= 1:
        logger.info("validation_size >= 1; moving all train → validation.")
        out.loc[idx_train, split_column] = 1
        return out

    # Do the split
    try:
        train_idx, val_idx = train_test_split(
            idx_train,
            test_size=validation_size,
            random_state=random_state,
            stratify=stratify_arr
        )
    except ValueError as e:
        logger.warning(f"Stratified split failed ({e}); retrying without stratify.")
        train_idx, val_idx = train_test_split(
            idx_train,
            test_size=validation_size,
            random_state=random_state,
            stratify=None
        )

    # Assign new splits
    out.loc[train_idx, split_column] = 0
    out.loc[val_idx, split_column] = 1
    # idx_test stays at 2

    # Cast back to a clean integer type
    out[split_column] = out[split_column].astype(int)
    return out
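A quick smoke test of the 0/2 re-splitting logic; the data below is hypothetical, but the column names match the constants above:

# Hypothetical smoke test for split_data_0_2 (not part of the tool).
demo = pd.DataFrame({
    "image_path": [f"img_{i}.jpg" for i in range(10)],
    "label": [0, 1] * 5,
    "split": [0] * 8 + [2] * 2,      # train-only rows plus a held-out test pair
})
resplit = split_data_0_2(demo, "split", validation_size=0.25, label_column="label")
print(resplit["split"].value_counts())  # expect 6 train / 2 val / 2 test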
+ """ + logger.info("LudwigDirectBackend: Preparing YAML configuration.") + + model_name = config_params.get("model_name", "resnet18") + use_pretrained = config_params.get("use_pretrained", False) + fine_tune = config_params.get("fine_tune", False) + epochs = config_params.get("epochs", 10) + batch_size = config_params.get("batch_size") + num_processes = config_params.get("preprocessing_num_processes", 1) + early_stop = config_params.get("early_stop", None) + learning_rate = config_params.get("learning_rate") + learning_rate = "auto" if learning_rate is None else float(learning_rate) + trainable = fine_tune or (not use_pretrained) + if not use_pretrained and not trainable: + logger.warning("trainable=False; use_pretrained=False is ignored.") + logger.warning("Setting trainable=True to train the model from scratch.") + trainable = True + + # Encoder setup + raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name) + if isinstance(raw_encoder, dict): + encoder_config = { + **raw_encoder, + "use_pretrained": use_pretrained, + "trainable": trainable, + } + else: + encoder_config = {"type": raw_encoder} + + # Trainer & optimizer + # optimizer = {"type": "adam", "learning_rate": 5e-5} if fine_tune else {"type": "adam"} + batch_size_cfg = batch_size or "auto" + + conf: Dict[str, Any] = { + "model_type": "ecd", + "input_features": [ + { + "name": IMAGE_PATH_COLUMN_NAME, + "type": "image", + "encoder": encoder_config, + } + ], + "output_features": [ + {"name": LABEL_COLUMN_NAME, "type": "category"} + ], + "combiner": {"type": "concat"}, + "trainer": { + "epochs": epochs, + "early_stop": early_stop, + "batch_size": batch_size_cfg, + "learning_rate": learning_rate, + }, + "preprocessing": { + "split": split_config, + "num_processes": num_processes, + "in_memory": False, + }, + } + + logger.debug("LudwigDirectBackend: Config dict built.") + try: + yaml_str = yaml.dump(conf, sort_keys=False, indent=2) + logger.info("LudwigDirectBackend: YAML config generated.") + return yaml_str + except Exception: + logger.error("LudwigDirectBackend: Failed to serialize YAML.", exc_info=True) + raise + + def run_experiment( + self, + dataset_path: Path, + config_path: Path, + output_dir: Path, + random_seed: int = 42, + ) -> None: + """ + Invoke Ludwig's internal experiment_cli function to run the experiment. + """ + logger.info("LudwigDirectBackend: Starting experiment execution.") + + try: + from ludwig.experiment import experiment_cli + except ImportError as e: + logger.error( + "LudwigDirectBackend: Could not import experiment_cli.", + exc_info=True + ) + raise RuntimeError("Ludwig import failed.") from e + + output_dir.mkdir(parents=True, exist_ok=True) + + try: + experiment_cli( + dataset=str(dataset_path), + config=str(config_path), + output_directory=str(output_dir), + random_seed=random_seed, + ) + logger.info(f"LudwigDirectBackend: Experiment completed. Results in {output_dir}") + except TypeError as e: + logger.error( + "LudwigDirectBackend: Argument mismatch in experiment_cli call.", + exc_info=True + ) + raise RuntimeError("Ludwig argument error.") from e + except Exception: + logger.error( + "LudwigDirectBackend: Experiment execution error.", + exc_info=True + ) + raise + + def get_training_process(self, output_dir) -> float: + """ + Retrieve the learning rate used in the most recent Ludwig run. 
+ Returns: + float: learning rate (or None if not found) + """ + output_dir = Path(output_dir) + exp_dirs = sorted( + output_dir.glob("experiment_run*"), + key=lambda p: p.stat().st_mtime + ) + + if not exp_dirs: + logger.warning(f"No experiment run directories found in {output_dir}") + return None + + progress_file = exp_dirs[-1] / "model" / "training_progress.json" + if not progress_file.exists(): + logger.warning(f"No training_progress.json found in {progress_file}") + return None + + try: + with progress_file.open("r", encoding="utf-8") as f: + data = json.load(f) + return { + "learning_rate": data.get("learning_rate"), + "batch_size": data.get("batch_size"), + "epoch": data.get("epoch"), + } + except Exception as e: + self.logger.warning(f"Failed to read training progress info: {e}") + return {} + + def convert_parquet_to_csv(self, output_dir: Path): + """Convert the predictions Parquet file to CSV.""" + output_dir = Path(output_dir) + exp_dirs = sorted( + output_dir.glob("experiment_run*"), + key=lambda p: p.stat().st_mtime + ) + if not exp_dirs: + logger.warning(f"No experiment run dirs found in {output_dir}") + return + exp_dir = exp_dirs[-1] + parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME + csv_path = exp_dir / "predictions.csv" + try: + df = pd.read_parquet(parquet_path) + df.to_csv(csv_path, index=False) + logger.info(f"Converted Parquet to CSV: {csv_path}") + except Exception as e: + logger.error(f"Error converting Parquet to CSV: {e}") + + def generate_plots(self, output_dir: Path) -> None: + """ + Generate _all_ registered Ludwig visualizations for the latest experiment run. + """ + logger.info("Generating all Ludwig visualizations…") + + test_plots = { + 'compare_performance', + 'compare_classifiers_performance_from_prob', + 'compare_classifiers_performance_from_pred', + 'compare_classifiers_performance_changing_k', + 'compare_classifiers_multiclass_multimetric', + 'compare_classifiers_predictions', + 'confidence_thresholding_2thresholds_2d', + 'confidence_thresholding_2thresholds_3d', + 'confidence_thresholding', + 'confidence_thresholding_data_vs_acc', + 'binary_threshold_vs_metric', + 'roc_curves', + 'roc_curves_from_test_statistics', + 'calibration_1_vs_all', + 'calibration_multiclass', + 'confusion_matrix', + 'frequency_vs_f1', + } + train_plots = { + 'learning_curves', + 'compare_classifiers_performance_subset', + } + + # 1) find the most recent experiment directory + output_dir = Path(output_dir) + exp_dirs = sorted( + output_dir.glob("experiment_run*"), + key=lambda p: p.stat().st_mtime + ) + if not exp_dirs: + logger.warning(f"No experiment run dirs found in {output_dir}") + return + exp_dir = exp_dirs[-1] + + # 2) ensure viz output subfolder exists + viz_dir = exp_dir / "visualizations" + viz_dir.mkdir(exist_ok=True) + train_viz = viz_dir / "train" + test_viz = viz_dir / "test" + train_viz.mkdir(parents=True, exist_ok=True) + test_viz.mkdir(parents=True, exist_ok=True) + + # 3) helper to check file existence + def _check(p: Path) -> Optional[str]: + return str(p) if p.exists() else None + + # 4) gather standard Ludwig output files + training_stats = _check(exp_dir / "training_statistics.json") + test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME) + probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME) + gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME) + + # 5) try to read original dataset & split file from description.json + dataset_path = None + split_file = None + desc = exp_dir / DESCRIPTION_FILE_NAME + if 
desc.exists(): + with open(desc, "r") as f: + cfg = json.load(f) + dataset_path = _check(Path(cfg.get("dataset", ""))) + split_file = _check(Path(get_split_path(cfg.get("dataset", "")))) + + # 6) infer output feature name + output_feature = "" + if desc.exists(): + try: + output_feature = cfg["config"]["output_features"][0]["name"] + except Exception: + pass + if not output_feature and test_stats: + with open(test_stats, "r") as f: + stats = json.load(f) + output_feature = next(iter(stats.keys()), "") + + # 7) loop through every registered viz + viz_registry = get_visualizations_registry() + for viz_name, viz_func in viz_registry.items(): + viz_dir_plot = None + if viz_name in train_plots: + viz_dir_plot = train_viz + elif viz_name in test_plots: + viz_dir_plot = test_viz + + try: + viz_func( + training_statistics=[training_stats] if training_stats else [], + test_statistics=[test_stats] if test_stats else [], + probabilities=[probs_path] if probs_path else [], + output_feature_name=output_feature, + ground_truth_split=2, + top_n_classes=[0], + top_k=3, + ground_truth_metadata=gt_metadata, + ground_truth=dataset_path, + split_file=split_file, + output_directory=str(viz_dir_plot), + normalize=False, + file_format="png", + ) + logger.info(f"✔ Generated {viz_name}") + except Exception as e: + logger.warning(f"✘ Skipped {viz_name}: {e}") + + logger.info(f"All visualizations written to {viz_dir}") + + def generate_html_report( + self, + title: str, + output_dir: str, + config: dict, + split_info: str) -> Path: + """ + Assemble an HTML report from visualizations under train_val/ and test/ folders. + """ + cwd = Path.cwd() + report_name = title.lower().replace(" ", "_") + "_report.html" + report_path = cwd / report_name + output_dir = Path(output_dir) + + # Find latest experiment dir + exp_dirs = sorted(output_dir.glob("experiment_run*"), key=lambda p: p.stat().st_mtime) + if not exp_dirs: + raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}") + exp_dir = exp_dirs[-1] + + base_viz_dir = exp_dir / "visualizations" + train_viz_dir = base_viz_dir / "train" + test_viz_dir = base_viz_dir / "test" + + html = get_html_template() + html += f"
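get_training_process, convert_parquet_to_csv, generate_plots, and generate_html_report all navigate the same on-disk layout. For orientation, a sketch of an output directory as these methods expect to find it; the run-directory naming follows the experiment_run* pattern globbed above, and exact contents depend on the Ludwig version:

# learner_output/
# └── experiment_run/                      # re-runs get numeric suffixes
#     ├── description.json                 # DESCRIPTION_FILE_NAME
#     ├── training_statistics.json
#     ├── test_statistics.json             # TEST_STATISTICS_FILE_NAME
#     ├── predictions.parquet              # PREDICTIONS_PARQUET_FILE_NAME
#     ├── model/
#     │   ├── training_progress.json       # read by get_training_process()
#     │   └── training_set_metadata.json   # TRAIN_SET_METADATA_FILE_NAME
#     └── visualizations/                  # created by generate_plots()
#         ├── train/*.png
#         └── test/*.png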

    def generate_html_report(
        self,
        title: str,
        output_dir: str,
        config: dict,
        split_info: str) -> Path:
        """
        Assemble an HTML report from visualizations under the train/ and test/
        subfolders of the latest experiment run.
        """
        cwd = Path.cwd()
        report_name = title.lower().replace(" ", "_") + "_report.html"
        report_path = cwd / report_name
        output_dir = Path(output_dir)

        # Find latest experiment dir
        exp_dirs = sorted(output_dir.glob("experiment_run*"), key=lambda p: p.stat().st_mtime)
        if not exp_dirs:
            raise RuntimeError(f"No 'experiment_run*' dirs found in {output_dir}")
        exp_dir = exp_dirs[-1]

        base_viz_dir = exp_dir / "visualizations"
        train_viz_dir = base_viz_dir / "train"
        test_viz_dir = base_viz_dir / "test"

        html = get_html_template()
        html += f"<h1>{title}</h1>"

        metrics_html = ""

        # Load and embed metrics table (training/val/test stats)
        try:
            train_stats_path = exp_dir / "training_statistics.json"
            test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME
            if train_stats_path.exists() and test_stats_path.exists():
                with open(train_stats_path) as f:
                    train_stats = json.load(f)
                with open(test_stats_path) as f:
                    test_stats = json.load(f)
                output_feature = next(iter(train_stats.keys()), "")
                if output_feature:
                    metrics_html += format_stats_table_html(train_stats, test_stats)
        except Exception as e:
            logger.warning(f"Could not load stats for HTML report: {e}")

        config_html = ""
        training_progress = self.get_training_process(output_dir)
        try:
            config_html = format_config_table_html(config, split_info, training_progress)
        except Exception as e:
            logger.warning(f"Could not load config for HTML report: {e}")

        def render_img_section(title: str, dir_path: Path) -> str:
            if not dir_path.exists():
                return f"<h2>{title}</h2><p><em>Directory not found.</em></p>"
            imgs = sorted(dir_path.glob("*.png"))
            if not imgs:
                return f"<h2>{title}</h2><p><em>No plots found.</em></p>"

            section_html = f"<h2>{title}</h2>"
            for img in imgs:
                b64 = encode_image_to_base64(str(img))
                section_html += (
                    f"<div class='plot'>"
                    f"<h3>{img.stem.replace('_', ' ').title()}</h3>"
                    f"<img src='data:image/png;base64,{b64}' alt='{img.stem}' />"
                    f"</div>"
                )
            return section_html

        train_plots_html = render_img_section("Training & Validation Visualizations", train_viz_dir)
        test_plots_html = render_img_section("Test Visualizations", test_viz_dir)
        html += build_tabbed_html(config_html + metrics_html, train_plots_html, test_plots_html)
        html += get_html_closing()

        try:
            with open(report_path, "w") as f:
                f.write(html)
            logger.info(f"HTML report generated at: {report_path}")
        except Exception as e:
            logger.error(f"Failed to write HTML report: {e}")
            raise

        return report_path


class WorkflowOrchestrator:
    """
    Manages the image-classification workflow:
      1. Creates temp dirs
      2. Extracts images
      3. Prepares data (CSV + splits)
      4. Renders a backend config
      5. Runs the experiment
      6. Cleans up
    """

    def __init__(self, args: argparse.Namespace, backend: Backend):
        self.args = args
        self.backend = backend
        self.temp_dir: Optional[Path] = None
        self.image_extract_dir: Optional[Path] = None
        logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}")

    def _create_temp_dirs(self) -> None:
        """Create temporary output and image extraction directories."""
        try:
            self.temp_dir = Path(tempfile.mkdtemp(
                dir=self.args.output_dir,
                prefix=TEMP_DIR_PREFIX
            ))
            self.image_extract_dir = self.temp_dir / "images"
            self.image_extract_dir.mkdir()
            logger.info(f"Created temp directory: {self.temp_dir}")
        except Exception:
            logger.error("Failed to create temporary directories", exc_info=True)
            raise

    def _extract_images(self) -> None:
        """Extract images from ZIP into the temp image directory."""
        if self.image_extract_dir is None:
            raise RuntimeError("Temp image directory not initialized.")
        logger.info(f"Extracting images from {self.args.image_zip} → {self.image_extract_dir}")
        try:
            with zipfile.ZipFile(self.args.image_zip, "r") as z:
                z.extractall(self.image_extract_dir)
            logger.info("Image extraction complete.")
        except Exception:
            logger.error("Error extracting zip file", exc_info=True)
            raise
    def _prepare_data(self) -> Tuple[Path, Dict[str, Any], str]:
        """
        Load CSV, update image paths, handle splits, and write prepared CSV.
        Returns:
            final_csv: Path to the prepared CSV
            split_config: Dict for backend split settings
            split_info: Human-readable description of the split
        """
        if not self.temp_dir or not self.image_extract_dir:
            raise RuntimeError("Temp dirs not initialized before data prep.")

        # 1) Load
        try:
            df = pd.read_csv(self.args.csv_file)
            logger.info(f"Loaded CSV: {self.args.csv_file}")
        except Exception:
            logger.error("Error loading CSV file", exc_info=True)
            raise

        # 2) Validate columns
        required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME}
        missing = required - set(df.columns)
        if missing:
            raise ValueError(f"Missing CSV columns: {', '.join(missing)}")

        # 3) Update image paths
        try:
            df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply(
                lambda p: str((self.image_extract_dir / p).resolve())
            )
        except Exception:
            logger.error("Error updating image paths", exc_info=True)
            raise

        # 4) Handle splits
        if SPLIT_COLUMN_NAME in df.columns:
            df, split_config, split_info = self._process_fixed_split(df)
        else:
            logger.info("No split column; using random split")
            split_config = {
                "type": "random",
                "probabilities": self.args.split_probabilities
            }
            split_info = (
                f"No split column in CSV. Used random split: "
                f"{[int(p * 100) for p in self.args.split_probabilities]}% "
                "for train/val/test."
            )

        # 5) Write out prepared CSV
        final_csv = Path(TEMP_CSV_FILENAME)
        try:
            df.to_csv(final_csv, index=False)
            logger.info(f"Saved prepared data to {final_csv}")
        except Exception:
            logger.error("Error saving prepared CSV", exc_info=True)
            raise

        return final_csv, split_config, split_info

    def _process_fixed_split(
            self, df: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, Any], str]:
        """Process a fixed split column (0=train, 1=val, 2=test)."""
        logger.info(f"Fixed split column '{SPLIT_COLUMN_NAME}' detected.")
        try:
            col = df[SPLIT_COLUMN_NAME]
            df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype(pd.Int64Dtype())
            if df[SPLIT_COLUMN_NAME].isna().any():
                logger.warning("Split column contains non-numeric/missing values.")

            unique = set(df[SPLIT_COLUMN_NAME].dropna().unique())
            logger.info(f"Unique split values: {unique}")

            if unique == {0, 2}:
                df = split_data_0_2(
                    df, SPLIT_COLUMN_NAME,
                    validation_size=self.args.validation_size,
                    label_column=LABEL_COLUMN_NAME,
                    random_state=self.args.random_seed
                )
                split_info = (
                    "Detected a split column (with values 0 and 2) in the input CSV. "
                    "Used this column as a base and reassigned "
                    f"{self.args.validation_size * 100:.1f}% "
                    "of the training set (originally labeled 0) to validation (labeled 1)."
                )
                logger.info("Applied custom 0/2 split.")
            elif unique.issubset({0, 1, 2}):
                split_info = "Used user-defined split column from CSV."
                logger.info("Using fixed split as-is.")
            else:
                raise ValueError(f"Unexpected split values: {unique}")

            return df, {"type": "fixed", "column": SPLIT_COLUMN_NAME}, split_info

        except Exception:
            logger.error("Error processing fixed split", exc_info=True)
            raise

    def _cleanup_temp_dirs(self) -> None:
        """Remove any temporary directories."""
        if self.temp_dir and self.temp_dir.exists():
            logger.info(f"Cleaning up temp directory: {self.temp_dir}")
            shutil.rmtree(self.temp_dir, ignore_errors=True)
        self.temp_dir = None
        self.image_extract_dir = None

    def run(self) -> None:
        """Execute the full workflow end-to-end."""
        logger.info("Starting workflow...")
        self.args.output_dir.mkdir(parents=True, exist_ok=True)

        try:
            self._create_temp_dirs()
            self._extract_images()
            csv_path, split_cfg, split_info = self._prepare_data()

            use_pretrained = self.args.use_pretrained or self.args.fine_tune

            backend_args = {
                "model_name": self.args.model_name,
                "fine_tune": self.args.fine_tune,
                "use_pretrained": use_pretrained,
                "epochs": self.args.epochs,
                "batch_size": self.args.batch_size,
                "preprocessing_num_processes": self.args.preprocessing_num_processes,
                "split_probabilities": self.args.split_probabilities,
                "learning_rate": self.args.learning_rate,
                "random_seed": self.args.random_seed,
                "early_stop": self.args.early_stop,
            }
            yaml_str = self.backend.prepare_config(backend_args, split_cfg)

            config_file = self.temp_dir / TEMP_CONFIG_FILENAME
            config_file.write_text(yaml_str)
            logger.info(f"Wrote backend config: {config_file}")

            self.backend.run_experiment(
                csv_path,
                config_file,
                self.args.output_dir,
                self.args.random_seed
            )
            logger.info("Workflow completed successfully.")
            self.backend.generate_plots(self.args.output_dir)
            report_file = self.backend.generate_html_report(
                "Image Classification Results",
                self.args.output_dir,
                backend_args,
                split_info
            )
            logger.info(f"HTML report generated at: {report_file}")
            self.backend.convert_parquet_to_csv(self.args.output_dir)
            logger.info("Converted Parquet to CSV.")
        except Exception:
            logger.error("Workflow execution failed", exc_info=True)
            raise
        finally:
            self._cleanup_temp_dirs()


def parse_learning_rate(s):
    try:
        return float(s)
    except (TypeError, ValueError):
        return None


class SplitProbAction(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        # values is a list of three floats
        train, val, test = values
        total = train + val + test
        if abs(total - 1.0) > 1e-6:
            parser.error(
                f"--split-probabilities must sum to 1.0; "
                f"got {train:.3f} + {val:.3f} + {test:.3f} = {total:.3f}"
            )
        setattr(namespace, self.dest, values)


def main():
    parser = argparse.ArgumentParser(
        description="Image Classification Learner with Pluggable Backends"
    )
    parser.add_argument(
        "--csv-file", required=True, type=Path,
        help="Path to the input CSV"
    )
    parser.add_argument(
        "--image-zip", required=True, type=Path,
        help="Path to the images ZIP"
    )
    parser.add_argument(
        "--model-name", required=True,
        choices=MODEL_ENCODER_TEMPLATES.keys(),
        help="Which model template to use"
    )
    parser.add_argument(
        "--use-pretrained", action="store_true",
        help="Use pretrained weights for the model"
    )
    parser.add_argument(
        "--fine-tune", action="store_true",
        help="Enable fine-tuning"
    )
    parser.add_argument(
        "--epochs", type=int, default=10,
        help="Number of training epochs"
    )
    parser.add_argument(
        "--early-stop", type=int, default=5,
        help="Early stopping patience"
    )
    parser.add_argument(
        "--batch-size", type=int,
        help="Batch size (None = auto)"
    )
    parser.add_argument(
        "--output-dir", type=Path, default=Path("learner_output"),
        help="Where to write outputs"
    )
    parser.add_argument(
        "--validation-size", type=float, default=0.15,
        help="Fraction for validation (0.0–1.0)"
    )
    parser.add_argument(
        "--preprocessing-num-processes", type=int,
        default=max(1, os.cpu_count() // 2),
        help="CPU processes for data prep"
    )
    parser.add_argument(
        "--split-probabilities", type=float, nargs=3,
        metavar=("train", "val", "test"),
        action=SplitProbAction,
        default=[0.7, 0.1, 0.2],
        help=(
            "Random split proportions (e.g., 0.7 0.1 0.2). "
            "Only used if no split column is present."
        )
    )
    parser.add_argument(
        "--random-seed", type=int, default=42,
        help="Random seed used for dataset splitting (default: 42)"
    )
    parser.add_argument(
        "--learning-rate", type=parse_learning_rate, default=None,
        help="Learning rate. If not provided, Ludwig will auto-select it."
    )

    args = parser.parse_args()

    # --- Validation ---
    if not 0.0 <= args.validation_size <= 1.0:
        parser.error("validation-size must be between 0.0 and 1.0")
    if not args.csv_file.is_file():
        parser.error(f"CSV not found: {args.csv_file}")
    if not args.image_zip.is_file():
        parser.error(f"ZIP not found: {args.image_zip}")

    # --- Instantiate backend and orchestrator ---
    backend_instance = LudwigDirectBackend()
    orchestrator = WorkflowOrchestrator(args, backend_instance)

    # --- Run workflow ---
    exit_code = 0
    try:
        orchestrator.run()
        logger.info("Main script finished successfully.")
    except Exception as e:
        logger.error(f"Main script failed: {e}")
        exit_code = 1
    finally:
        sys.exit(exit_code)


if __name__ == '__main__':
    try:
        import ludwig
        logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}")
    except ImportError:
        logger.error(
            "Ludwig library not found. "
            "Please ensure Ludwig is installed ('pip install ludwig[image]')"
        )
        sys.exit(1)

    main()
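For reference, an illustrative invocation against the bundled test data; the flags are the ones defined in main() above, and the paths assume the repository root as the working directory:

python image_learner_cli.py \
    --csv-file test-data/mnist_subset.csv \
    --image-zip test-data/mnist_subset.zip \
    --model-name resnet18 \
    --use-pretrained \
    --fine-tune \
    --epochs 10 \
    --early-stop 5 \
    --validation-size 0.15 \
    --output-dir learner_output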
diff -r 000000000000 -r 54b871dfc51e test-data/image_classification_results_report_mnist.html
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test-data/image_classification_results_report_mnist.html	Tue Jun 03 21:22:11 2025 +0000
@@ -0,0 +1,129 @@
Galaxy-Ludwig Report
Image Classification Results

Tabs: Config & Metrics | Train/Validation Plots | Test Plots

Training Setup
  Parameter       Value
  Model Name      resnet18
  Epochs          Because of early stopping, the training stopped at epoch 7
  Batch Size      Auto-selected batch size by Ludwig: 16
  Fine Tune       True
  Use Pretrained  True
  Learning Rate   Auto-selected learning rate by Ludwig: 1e-05
                  (Based on model architecture and training setup, e.g., fine-tuning.
                  See Ludwig Trainer Parameters for details.)
  Random Seed
  Early Stop      N/A
  Data Split      Detected a split column (with values 0 and 2) in the input CSV.
                  Used this column as a base and reassigned 15.0% of the training
                  set (originally labeled 0) to validation (labeled 1).

  Model trained using Ludwig. If you want to learn more about Ludwig's default
  settings, please check their website (ludwig.ai).

Model Performance Summary
  Metric          Train    Validation  Test
  accuracy        0.8417   0.1500      0.2000
  accuracy_micro  0.8471   0.2000      0.2000
  hits_at_k       0.9250   0.4500      0.3000
  loss            0.6749   2.7907      2.8261
  roc_auc         0.9998   0.7824      0.6917

Training & Validation Visualizations
  Learning Curves Label Accuracy
  Learning Curves Label Hits At K
  Learning Curves Label Loss

Test Visualizations
  Compare Classifiers Multiclass Multimetric Label Best10
  Compare Classifiers Multiclass Multimetric Label Sorted
  Compare Classifiers Multiclass Multimetric Label Top10
  Compare Classifiers Multiclass Multimetric Label Worst10
  Compare Classifiers Performance From Prob
  Compare Performance Label
  Confusion Matrix Label Top10
  Confusion Matrix Entropy Label Top10
  Frequency Vs F1 Label
  Roc Curves
+ + + \ No newline at end of file diff -r 000000000000 -r 54b871dfc51e test-data/mnist_subset.csv --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/test-data/mnist_subset.csv Tue Jun 03 21:22:11 2025 +0000 @@ -0,0 +1,121 @@ +image_path,label,split +training/0/5680.jpg,0,0 +training/0/5699.jpg,0,0 +training/0/5766.jpg,0,0 +training/0/5524.jpg,0,0 +training/0/5003.jpg,0,0 +training/0/5527.jpg,0,0 +training/0/5359.jpg,0,0 +training/0/5452.jpg,0,0 +training/0/5010.jpg,0,0 +training/0/5405.jpg,0,0 +training/1/6100.jpg,1,0 +training/1/6015.jpg,1,0 +training/1/5754.jpg,1,0 +training/1/6275.jpg,1,0 +training/1/6247.jpg,1,0 +training/1/6552.jpg,1,0 +training/1/6129.jpg,1,0 +training/1/6733.jpg,1,0 +training/1/6590.jpg,1,0 +training/1/6727.jpg,1,0 +training/2/5585.jpg,2,0 +training/2/5865.jpg,2,0 +training/2/4984.jpg,2,0 +training/2/4992.jpg,2,0 +training/2/5008.jpg,2,0 +training/2/5325.jpg,2,0 +training/2/5438.jpg,2,0 +training/2/5807.jpg,2,0 +training/2/5323.jpg,2,0 +training/2/5407.jpg,2,0 +training/3/5869.jpg,3,0 +training/3/5333.jpg,3,0 +training/3/5813.jpg,3,0 +training/3/6093.jpg,3,0 +training/3/5714.jpg,3,0 +training/3/5519.jpg,3,0 +training/3/5586.jpg,3,0 +training/3/5410.jpg,3,0 +training/3/5577.jpg,3,0 +training/3/5710.jpg,3,0 +training/4/5092.jpg,4,0 +training/4/5793.jpg,4,0 +training/4/5610.jpg,4,0 +training/4/5123.jpg,4,0 +training/4/5685.jpg,4,0 +training/4/4972.jpg,4,0 +training/4/4887.jpg,4,0 +training/4/5052.jpg,4,0 +training/4/5348.jpg,4,0 +training/4/5368.jpg,4,0 +training/5/5100.jpg,5,0 +training/5/4442.jpg,5,0 +training/5/4745.jpg,5,0 +training/5/4592.jpg,5,0 +training/5/4707.jpg,5,0 +training/5/5305.jpg,5,0 +training/5/4506.jpg,5,0 +training/5/5118.jpg,5,0 +training/5/4888.jpg,5,0 +training/5/5282.jpg,5,0 +training/6/5553.jpg,6,0 +training/6/5260.jpg,6,0 +training/6/5899.jpg,6,0 +training/6/5231.jpg,6,0 +training/6/5743.jpg,6,0 +training/6/5567.jpg,6,0 +training/6/5823.jpg,6,0 +training/6/5849.jpg,6,0 +training/6/5076.jpg,6,0 +training/6/5435.jpg,6,0 +training/7/6036.jpg,7,0 +training/7/5488.jpg,7,0 +training/7/5506.jpg,7,0 +training/7/6194.jpg,7,0 +training/7/5934.jpg,7,0 +training/7/5634.jpg,7,0 +training/7/5834.jpg,7,0 +training/7/5721.jpg,7,0 +training/7/6204.jpg,7,0 +training/7/5481.jpg,7,0 +training/8/5844.jpg,8,0 +training/8/5001.jpg,8,0 +training/8/5785.jpg,8,0 +training/8/5462.jpg,8,0 +training/8/4938.jpg,8,0 +training/8/4933.jpg,8,0 +training/8/5341.jpg,8,0 +training/8/5057.jpg,8,0 +training/8/4880.jpg,8,0 +training/8/5039.jpg,8,0 +training/9/5193.jpg,9,0 +training/9/5870.jpg,9,0 +training/9/5756.jpg,9,0 +training/9/5186.jpg,9,0 +training/9/5688.jpg,9,0 +training/9/5579.jpg,9,0 +training/9/5444.jpg,9,0 +training/9/5931.jpg,9,0 +training/9/5541.jpg,9,0 +training/9/5786.jpg,9,0 +test/0/833.jpg,0,2 +test/0/855.jpg,0,2 +test/1/1110.jpg,1,2 +test/1/969.jpg,1,2 +test/2/961.jpg,2,2 +test/2/971.jpg,2,2 +test/3/895.jpg,3,2 +test/3/1005.jpg,3,2 +test/4/940.jpg,4,2 +test/4/975.jpg,4,2 +test/5/780.jpg,5,2 +test/5/834.jpg,5,2 +test/6/932.jpg,6,2 +test/6/796.jpg,6,2 +test/7/835.jpg,7,2 +test/7/863.jpg,7,2 +test/8/899.jpg,8,2 +test/8/898.jpg,8,2 +test/9/1007.jpg,9,2 +test/9/954.jpg,9,2 diff -r 000000000000 -r 54b871dfc51e test-data/mnist_subset.zip Binary file test-data/mnist_subset.zip has changed diff -r 000000000000 -r 54b871dfc51e utils.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/utils.py Tue Jun 03 21:22:11 2025 +0000 @@ -0,0 +1,156 @@ +import base64 +import json + + +def get_html_template(): + return """ + + + Galaxy-Ludwig Report + + + +
+ """ + + +def get_html_closing(): + return """ +
+ + + """ + + +def encode_image_to_base64(image_path): + """Convert an image file to a base64 encoded string.""" + with open(image_path, "rb") as img_file: + return base64.b64encode(img_file.read()).decode("utf-8") + + +def json_to_nested_html_table(json_data, depth=0): + """ + Convert JSON object to an HTML nested table. + + Parameters: + json_data (dict or list): The JSON data to convert. + depth (int): Current depth level for indentation. + + Returns: + str: HTML string for the nested table. + """ + # Base case: if JSON is a simple key-value pair dictionary + if isinstance(json_data, dict) and all( + not isinstance(v, (dict, list)) for v in json_data.values() + ): + # Render a flat table + rows = [ + f"{key}{value}" + for key, value in json_data.items() + ] + return f"{''.join(rows)}
" + + # Base case: if JSON is a list of simple values + if isinstance(json_data, list) and all( + not isinstance(v, (dict, list)) for v in json_data + ): + rows = [ + f"Index {i}{value}" + for i, value in enumerate(json_data) + ] + return f"{''.join(rows)}
" + + # Recursive case: if JSON contains nested structures + if isinstance(json_data, dict): + rows = [ + f"{key}" + f"{json_to_nested_html_table(value, depth + 1)}" + for key, value in json_data.items() + ] + return f"{''.join(rows)}
" + + if isinstance(json_data, list): + rows = [ + f"[{i}]" + f"{json_to_nested_html_table(value, depth + 1)}" + for i, value in enumerate(json_data) + ] + return f"{''.join(rows)}
" + + # Base case: simple value + return f"{json_data}" + + +def json_to_html_table(json_data): + """ + Convert JSON to a vertically oriented HTML table. + + Parameters: + json_data (str or dict): JSON string or dictionary. + + Returns: + str: HTML table representation. + """ + if isinstance(json_data, str): + json_data = json.loads(json_data) + return json_to_nested_html_table(json_data)