comparison image_learner_cli.py @ 12:bcfa2e234a80 draft
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
| author | goeckslab |
|---|---|
| date | Fri, 21 Nov 2025 15:58:13 +0000 |
| parents | c5150cceab47 |
| children | |
| 11:c5150cceab47 | 12:bcfa2e234a80 |
|---|---|
| 1 import argparse | 1 import argparse |
| 2 import json | |
| 3 import logging | 2 import logging |
| 4 import os | 3 import os |
| 5 import shutil | |
| 6 import sys | 4 import sys |
| 7 import tempfile | |
| 8 import zipfile | |
| 9 from pathlib import Path | 5 from pathlib import Path |
| 10 from typing import Any, Dict, Optional, Protocol, Tuple | |
| 11 | 6 |
| 12 import matplotlib | 7 import matplotlib |
| 13 import numpy as np | 8 from constants import MODEL_ENCODER_TEMPLATES |
| 14 import pandas as pd | 9 from image_workflow import ImageLearnerCLI |
| 15 import pandas.api.types as ptypes | 10 from ludwig_backend import LudwigDirectBackend |
| 16 import yaml | 11 from split_data import SplitProbAction |
| 17 from constants import ( | 12 from utils import argument_checker, parse_learning_rate |
| 18 IMAGE_PATH_COLUMN_NAME, | |
| 19 LABEL_COLUMN_NAME, | |
| 20 METRIC_DISPLAY_NAMES, | |
| 21 MODEL_ENCODER_TEMPLATES, | |
| 22 SPLIT_COLUMN_NAME, | |
| 23 TEMP_CONFIG_FILENAME, | |
| 24 TEMP_CSV_FILENAME, | |
| 25 TEMP_DIR_PREFIX, | |
| 26 ) | |
| 27 from ludwig.globals import ( | |
| 28 DESCRIPTION_FILE_NAME, | |
| 29 PREDICTIONS_PARQUET_FILE_NAME, | |
| 30 TEST_STATISTICS_FILE_NAME, | |
| 31 TRAIN_SET_METADATA_FILE_NAME, | |
| 32 ) | |
| 33 from ludwig.utils.data_utils import get_split_path | |
| 34 from plotly_plots import build_classification_plots | |
| 35 from sklearn.model_selection import train_test_split | |
| 36 from utils import ( | |
| 37 build_tabbed_html, | |
| 38 encode_image_to_base64, | |
| 39 get_html_closing, | |
| 40 get_html_template, | |
| 41 get_metrics_help_modal, | |
| 42 ) | |
| 43 | 13 |
| 44 # Set matplotlib backend after imports | 14 # Set matplotlib backend after imports |
| 45 matplotlib.use('Agg') | 15 matplotlib.use('Agg') |
| 46 | 16 |
| 47 # --- Logging Setup --- | 17 # --- Logging Setup --- |
| 48 logging.basicConfig( | 18 logging.basicConfig( |
| 49 level=logging.INFO, | 19 level=logging.INFO, |
| 50 format="%(asctime)s %(levelname)s %(name)s: %(message)s", | 20 format="%(asctime)s %(levelname)s %(name)s: %(message)s", |
| 51 ) | 21 ) |
| 52 logger = logging.getLogger("ImageLearner") | 22 logger = logging.getLogger("ImageLearner") |
| 53 | 23 |
| 54 # Optional MetaFormer configuration registry | |
| 55 META_DEFAULT_CFGS: Dict[str, Any] = {} | |
| 56 try: | |
| 57 from MetaFormer import default_cfgs as META_DEFAULT_CFGS # type: ignore[attr-defined] | |
| 58 except Exception as e: | |
| 59 logger.debug("MetaFormer default configs unavailable: %s", e) | |
| 60 META_DEFAULT_CFGS = {} | |
| 61 | |
| 62 # Try to import Ludwig visualization registry (may fail due to optional dependencies) | |
| 63 # This must come AFTER logger is defined | |
| 64 _ludwig_viz_available = False | |
| 65 get_visualizations_registry = None | |
| 66 try: | |
| 67 from ludwig.visualize import get_visualizations_registry | |
| 68 _ludwig_viz_available = True | |
| 69 logger.info("Ludwig visualizations available") | |
| 70 except ImportError as e: | |
| 71 logger.warning(f"Ludwig visualizations not available: {e}. Will use fallback plots only.") | |
| 72 except Exception as e: | |
| 73 logger.warning(f"Ludwig visualizations not available due to dependency issues: {e}. Will use fallback plots only.") | |
| 74 | |
| 75 # --- MetaFormer patching integration --- | |
| 76 _metaformer_patch_ok = False | |
| 77 try: | |
| 78 from MetaFormer.metaformer_stacked_cnn import patch_ludwig_stacked_cnn as _mf_patch | |
| 79 if _mf_patch(): | |
| 80 _metaformer_patch_ok = True | |
| 81 logger.info("MetaFormer patching applied for Ludwig stacked_cnn encoder.") | |
| 82 except Exception as e: | |
| 83 logger.warning(f"MetaFormer stacked CNN not available: {e}") | |
| 84 _metaformer_patch_ok = False | |
| 85 | |
| 86 # Note: CAFormer models are now handled through MetaFormer framework | |
| 87 | |
| 88 | |
| 89 def format_config_table_html( | |
| 90 config: dict, | |
| 91 split_info: Optional[str] = None, | |
| 92 training_progress: dict = None, | |
| 93 output_type: Optional[str] = None, | |
| 94 ) -> str: | |
| 95 display_keys = [ | |
| 96 "task_type", | |
| 97 "model_name", | |
| 98 "epochs", | |
| 99 "batch_size", | |
| 100 "fine_tune", | |
| 101 "use_pretrained", | |
| 102 "learning_rate", | |
| 103 "random_seed", | |
| 104 "early_stop", | |
| 105 "threshold", | |
| 106 ] | |
| 107 | |
| 108 rows = [] | |
| 109 | |
| 110 for key in display_keys: | |
| 111 val = config.get(key, None) | |
| 112 if key == "threshold": | |
| 113 if output_type != "binary": | |
| 114 continue | |
| 115 val = val if val is not None else 0.5 | |
| 116 val_str = f"{val:.2f}" | |
| 117 if val == 0.5: | |
| 118 val_str += " (default)" | |
| 119 else: | |
| 120 if key == "task_type": | |
| 121 val_str = val.title() if isinstance(val, str) else "N/A" | |
| 122 elif key == "batch_size": | |
| 123 if val is not None: | |
| 124 val_str = int(val) | |
| 125 else: | |
| 126 val = "auto" | |
| 127 val_str = "auto" | |
| 128 resolved_val = None | |
| 129 if val is None or val == "auto": | |
| 130 if training_progress: | |
| 131 resolved_val = training_progress.get("batch_size") | |
| 132 val = ( | |
| 133 "Auto-selected batch size by Ludwig:<br>" | |
| 134 f"<span style='font-size: 0.85em;'>" | |
| 135 f"{resolved_val if resolved_val else val}</span><br>" | |
| 136 "<span style='font-size: 0.85em;'>" | |
| 137 "Based on model architecture and training setup " | |
| 138 "(e.g., fine-tuning).<br>" | |
| 139 "See <a href='https://ludwig.ai/latest/configuration/trainer/" | |
| 140 "#trainer-parameters' target='_blank'>" | |
| 141 "Ludwig Trainer Parameters</a> for details." | |
| 142 "</span>" | |
| 143 ) | |
| 144 else: | |
| 145 val = ( | |
| 146 "Auto-selected by Ludwig<br>" | |
| 147 "<span style='font-size: 0.85em;'>" | |
| 148 "Automatically tuned based on architecture and dataset.<br>" | |
| 149 "See <a href='https://ludwig.ai/latest/configuration/trainer/" | |
| 150 "#trainer-parameters' target='_blank'>" | |
| 151 "Ludwig Trainer Parameters</a> for details." | |
| 152 "</span>" | |
| 153 ) | |
| 154 elif key == "learning_rate": | |
| 155 if val is not None and val != "auto": | |
| 156 val_str = f"{val:.6f}" | |
| 157 else: | |
| 158 if training_progress: | |
| 159 resolved_val = training_progress.get("learning_rate") | |
| 160 val_str = ( | |
| 161 "Auto-selected learning rate by Ludwig:<br>" | |
| 162 f"<span style='font-size: 0.85em;'>" | |
| 163 f"{resolved_val if resolved_val else 'auto'}</span><br>" | |
| 164 "<span style='font-size: 0.85em;'>" | |
| 165 "Based on model architecture and training setup " | |
| 166 "(e.g., fine-tuning).<br>" | |
| 167 "</span>" | |
| 168 ) | |
| 169 else: | |
| 170 val_str = ( | |
| 171 "Auto-selected by Ludwig<br>" | |
| 172 "<span style='font-size: 0.85em;'>" | |
| 173 "Automatically tuned based on architecture and dataset.<br>" | |
| 174 "See <a href='https://ludwig.ai/latest/configuration/trainer/" | |
| 175 "#trainer-parameters' target='_blank'>" | |
| 176 "Ludwig Trainer Parameters</a> for details." | |
| 177 "</span>" | |
| 178 ) | |
| 179 elif key == "epochs": | |
| 180 if val is None: | |
| 181 val_str = "N/A" | |
| 182 else: | |
| 183 if ( | |
| 184 training_progress | |
| 185 and "epoch" in training_progress | |
| 186 and val > training_progress["epoch"] | |
| 187 ): | |
| 188 val_str = ( | |
| 189 f"Early stopping: training " | |
| 190 f"stopped at epoch {training_progress['epoch']}" | |
| 191 ) | |
| 192 else: | |
| 193 val_str = val | |
| 194 else: | |
| 195 val_str = val if val is not None else "N/A" | |
| 196 if val_str == "N/A" and key not in ["task_type"]: | |
| 197 continue | |
| 198 rows.append( | |
| 199 f"<tr>" | |
| 200 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left; " | |
| 201 f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>" | |
| 202 f"{key.replace('_', ' ').title()}</td>" | |
| 203 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center; " | |
| 204 f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>" | |
| 205 f"{val_str}</td>" | |
| 206 f"</tr>" | |
| 207 ) | |
| 208 | |
| 209 aug_cfg = config.get("augmentation") | |
| 210 if aug_cfg: | |
| 211 types = [str(a.get("type", "")) for a in aug_cfg] | |
| 212 aug_val = ", ".join(types) | |
| 213 rows.append( | |
| 214 f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left; " | |
| 215 f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>Augmentation</td>" | |
| 216 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center; " | |
| 217 f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>{aug_val}</td></tr>" | |
| 218 ) | |
| 219 | |
| 220 if split_info: | |
| 221 rows.append( | |
| 222 f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left; " | |
| 223 f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>Data Split</td>" | |
| 224 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center; " | |
| 225 f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>{split_info}</td></tr>" | |
| 226 ) | |
| 227 | |
| 228 html = f""" | |
| 229 <h2 style="text-align: center;">Model and Training Summary</h2> | |
| 230 <div style="display: flex; justify-content: center;"> | |
| 231 <table style="border-collapse: collapse; width: 100%; table-layout: fixed;"> | |
| 232 <thead><tr> | |
| 233 <th style="padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Parameter</th> | |
| 234 <th style="padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Value</th> | |
| 235 </tr></thead> | |
| 236 <tbody> | |
| 237 {"".join(rows)} | |
| 238 </tbody> | |
| 239 </table> | |
| 240 </div><br> | |
| 241 <p style="text-align: center; font-size: 0.9em;"> | |
| 242 Model trained using <a href="https://ludwig.ai/" target="_blank" rel="noopener noreferrer">Ludwig</a>. | |
| 243 <a href="https://ludwig.ai/latest/configuration/" target="_blank" rel="noopener noreferrer"> | |
| 244 Ludwig documentation provides detailed information about default model and training parameters | |
| 245 </a> | |
| 246 </p><hr> | |
| 247 """ | |
| 248 return html | |
| 249 | |
| 250 | |
| 251 def detect_output_type(test_stats): | |
| 252 """Detect whether the output type is 'regression', 'binary', or 'category' from the test statistics.""" | |
| 253 label_stats = test_stats.get("label", {}) | |
| 254 if "mean_squared_error" in label_stats: | |
| 255 return "regression" | |
| 256 per_class = label_stats.get("per_class_stats", {}) | |
| 257 if len(per_class) == 2: | |
| 258 return "binary" | |
| 259 return "category" | |
| 260 | |
| 261 | |
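
A minimal usage sketch of the heuristic above (toy stand-ins for the `label` section of Ludwig's `test_statistics.json`; assumes `detect_output_type` from this file is in scope):

```python
# Illustrative inputs only, not taken from a real run.
regression_stats = {"label": {"mean_squared_error": 0.42}}
binary_stats = {"label": {"per_class_stats": {"neg": {}, "pos": {}}}}
multiclass_stats = {"label": {"per_class_stats": {"a": {}, "b": {}, "c": {}}}}

assert detect_output_type(regression_stats) == "regression"  # MSE key present
assert detect_output_type(binary_stats) == "binary"          # exactly two classes
assert detect_output_type(multiclass_stats) == "category"    # everything else
```
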
| 262 def extract_metrics_from_json( | |
| 263 train_stats: dict, | |
| 264 test_stats: dict, | |
| 265 output_type: str, | |
| 266 ) -> dict: | |
| 267 """Extracts relevant metrics from training and test statistics based on the output type.""" | |
| 268 metrics = {"training": {}, "validation": {}, "test": {}} | |
| 269 | |
| 270 def get_last_value(stats, key): | |
| 271 val = stats.get(key) | |
| 272 if isinstance(val, list) and val: | |
| 273 return val[-1] | |
| 274 elif isinstance(val, (int, float)): | |
| 275 return val | |
| 276 return None | |
| 277 | |
| 278 for split in ["training", "validation"]: | |
| 279 split_stats = train_stats.get(split, {}) | |
| 280 if not split_stats: | |
| 281 logger.warning(f"No statistics found for {split} split") | |
| 282 continue | |
| 283 label_stats = split_stats.get("label", {}) | |
| 284 if not label_stats: | |
| 285 logger.warning(f"No label statistics found for {split} split") | |
| 286 continue | |
| 287 if output_type == "binary": | |
| 288 metrics[split] = { | |
| 289 "accuracy": get_last_value(label_stats, "accuracy"), | |
| 290 "loss": get_last_value(label_stats, "loss"), | |
| 291 "precision": get_last_value(label_stats, "precision"), | |
| 292 "recall": get_last_value(label_stats, "recall"), | |
| 293 "specificity": get_last_value(label_stats, "specificity"), | |
| 294 "roc_auc": get_last_value(label_stats, "roc_auc"), | |
| 295 } | |
| 296 elif output_type == "regression": | |
| 297 metrics[split] = { | |
| 298 "loss": get_last_value(label_stats, "loss"), | |
| 299 "mean_absolute_error": get_last_value( | |
| 300 label_stats, "mean_absolute_error" | |
| 301 ), | |
| 302 "mean_absolute_percentage_error": get_last_value( | |
| 303 label_stats, "mean_absolute_percentage_error" | |
| 304 ), | |
| 305 "mean_squared_error": get_last_value(label_stats, "mean_squared_error"), | |
| 306 "root_mean_squared_error": get_last_value( | |
| 307 label_stats, "root_mean_squared_error" | |
| 308 ), | |
| 309 "root_mean_squared_percentage_error": get_last_value( | |
| 310 label_stats, "root_mean_squared_percentage_error" | |
| 311 ), | |
| 312 "r2": get_last_value(label_stats, "r2"), | |
| 313 } | |
| 314 else: | |
| 315 metrics[split] = { | |
| 316 "accuracy": get_last_value(label_stats, "accuracy"), | |
| 317 "accuracy_micro": get_last_value(label_stats, "accuracy_micro"), | |
| 318 "loss": get_last_value(label_stats, "loss"), | |
| 319 "roc_auc": get_last_value(label_stats, "roc_auc"), | |
| 320 "hits_at_k": get_last_value(label_stats, "hits_at_k"), | |
| 321 } | |
| 322 | |
| 323 # Test metrics: dynamic extraction according to exclusions | |
| 324 test_label_stats = test_stats.get("label", {}) | |
| 325 if not test_label_stats: | |
| 326 logger.warning("No label statistics found for test split") | |
| 327 else: | |
| 328 combined_stats = test_stats.get("combined", {}) | |
| 329 overall_stats = test_label_stats.get("overall_stats", {}) | |
| 330 | |
| 331 # Define exclusions | |
| 332 if output_type == "binary": | |
| 333 exclude = {"per_class_stats", "precision_recall_curve", "roc_curve"} | |
| 334 else: | |
| 335 exclude = {"per_class_stats", "confusion_matrix"} | |
| 336 | |
| 337 # 1. Get all scalar test_label_stats not excluded | |
| 338 test_metrics = {} | |
| 339 for k, v in test_label_stats.items(): | |
| 340 if k in exclude: | |
| 341 continue | |
| 342 if k == "overall_stats": | |
| 343 continue | |
| 344 if isinstance(v, (int, float, str, bool)): | |
| 345 test_metrics[k] = v | |
| 346 | |
| 347 # 2. Add overall_stats (flattened) | |
| 348 for k, v in overall_stats.items(): | |
| 349 test_metrics[k] = v | |
| 350 | |
| 351 # 3. Optionally include combined/loss if present and not already | |
| 352 if "loss" in combined_stats and "loss" not in test_metrics: | |
| 353 test_metrics["loss"] = combined_stats["loss"] | |
| 354 metrics["test"] = test_metrics | |
| 355 return metrics | |
| 356 | |
| 357 | |
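
The inner `get_last_value` helper is what collapses Ludwig's per-epoch metric lists down to final-epoch scalars; a self-contained sketch of that behavior (helper copied from above, toy values):

```python
def get_last_value(stats, key):
    # Same logic as in extract_metrics_from_json: lists -> last epoch, scalars pass through.
    val = stats.get(key)
    if isinstance(val, list) and val:
        return val[-1]
    elif isinstance(val, (int, float)):
        return val
    return None

label_stats = {"accuracy": [0.61, 0.74, 0.82], "loss": 0.41}
print(get_last_value(label_stats, "accuracy"))  # 0.82 (last epoch)
print(get_last_value(label_stats, "loss"))      # 0.41 (already scalar)
print(get_last_value(label_stats, "roc_auc"))   # None (key missing)
```
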
| 358 def generate_table_row(cells, styles): | |
| 359 """Helper function to generate an HTML table row.""" | |
| 360 return ( | |
| 361 "<tr>" | |
| 362 + "".join(f"<td style='{styles}'>{cell}</td>" for cell in cells) | |
| 363 + "</tr>" | |
| 364 ) | |
| 365 | |
| 366 | |
| 367 # ----------------------------------------- | |
| 368 # 2) MODEL PERFORMANCE (Train/Val/Test) TABLE | |
| 369 # ----------------------------------------- | |
| 370 def format_stats_table_html(train_stats: dict, test_stats: dict, output_type: str) -> str: | |
| 371 """Formats a combined HTML table for training, validation, and test metrics.""" | |
| 372 all_metrics = extract_metrics_from_json(train_stats, test_stats, output_type) | |
| 373 rows = [] | |
| 374 for metric_key in sorted(all_metrics["training"].keys()): | |
| 375 if ( | |
| 376 metric_key in all_metrics["validation"] | |
| 377 and metric_key in all_metrics["test"] | |
| 378 ): | |
| 379 display_name = METRIC_DISPLAY_NAMES.get( | |
| 380 metric_key, | |
| 381 metric_key.replace("_", " ").title(), | |
| 382 ) | |
| 383 t = all_metrics["training"].get(metric_key) | |
| 384 v = all_metrics["validation"].get(metric_key) | |
| 385 te = all_metrics["test"].get(metric_key) | |
| 386 if all(x is not None for x in [t, v, te]): | |
| 387 rows.append([display_name, f"{t:.4f}", f"{v:.4f}", f"{te:.4f}"]) | |
| 388 | |
| 389 if not rows: | |
| 390 return "<table><tr><td>No metric values found.</td></tr></table>" | |
| 391 | |
| 392 html = ( | |
| 393 "<h2 style='text-align: center;'>Model Performance Summary</h2>" | |
| 394 "<div style='display: flex; justify-content: center;'>" | |
| 395 "<table class='performance-summary' style='border-collapse: collapse;'>" | |
| 396 "<thead><tr>" | |
| 397 "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Metric</th>" | |
| 398 "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Train</th>" | |
| 399 "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Validation</th>" | |
| 400 "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Test</th>" | |
| 401 "</tr></thead><tbody>" | |
| 402 ) | |
| 403 for row in rows: | |
| 404 html += generate_table_row( | |
| 405 row, | |
| 406 "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;", | |
| 407 ) | |
| 408 html += "</tbody></table></div><br>" | |
| 409 return html | |
| 410 | |
| 411 | |
| 412 # ------------------------------------------- | |
| 413 # 3) TRAIN/VALIDATION PERFORMANCE SUMMARY TABLE | |
| 414 # ------------------------------------------- | |
| 415 def format_train_val_stats_table_html(train_stats: dict, test_stats: dict) -> str: | |
| 416 """Format train/validation metrics into an HTML table.""" | |
| 417 all_metrics = extract_metrics_from_json(train_stats, test_stats, detect_output_type(test_stats)) | |
| 418 rows = [] | |
| 419 for metric_key in sorted(all_metrics["training"].keys()): | |
| 420 if metric_key in all_metrics["validation"]: | |
| 421 display_name = METRIC_DISPLAY_NAMES.get( | |
| 422 metric_key, | |
| 423 metric_key.replace("_", " ").title(), | |
| 424 ) | |
| 425 t = all_metrics["training"].get(metric_key) | |
| 426 v = all_metrics["validation"].get(metric_key) | |
| 427 if t is not None and v is not None: | |
| 428 rows.append([display_name, f"{t:.4f}", f"{v:.4f}"]) | |
| 429 | |
| 430 if not rows: | |
| 431 return "<table><tr><td>No metric values found for Train/Validation.</td></tr></table>" | |
| 432 | |
| 433 html = ( | |
| 434 "<h2 style='text-align: center;'>Train/Validation Performance Summary</h2>" | |
| 435 "<div style='display: flex; justify-content: center;'>" | |
| 436 "<table class='performance-summary' style='border-collapse: collapse;'>" | |
| 437 "<thead><tr>" | |
| 438 "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Metric</th>" | |
| 439 "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Train</th>" | |
| 440 "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Validation</th>" | |
| 441 "</tr></thead><tbody>" | |
| 442 ) | |
| 443 for row in rows: | |
| 444 html += generate_table_row( | |
| 445 row, | |
| 446 "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;", | |
| 447 ) | |
| 448 html += "</tbody></table></div><br>" | |
| 449 return html | |
| 450 | |
| 451 | |
| 452 # ----------------------------------------- | |
| 453 # 4) TEST‐ONLY PERFORMANCE SUMMARY TABLE | |
| 454 # ----------------------------------------- | |
| 455 def format_test_merged_stats_table_html( | |
| 456 test_metrics: Dict[str, Any], output_type: str | |
| 457 ) -> str: | |
| 458 """Format test metrics into an HTML table.""" | |
| 459 rows = [] | |
| 460 for key in sorted(test_metrics.keys()): | |
| 461 display_name = METRIC_DISPLAY_NAMES.get(key, key.replace("_", " ").title()) | |
| 462 value = test_metrics[key] | |
| 463 if value is not None: | |
| 464 rows.append([display_name, f"{value:.4f}"]) | |
| 465 | |
| 466 if not rows: | |
| 467 return "<table><tr><td>No test metric values found.</td></tr></table>" | |
| 468 | |
| 469 html = ( | |
| 470 "<h2 style='text-align: center;'>Test Performance Summary</h2>" | |
| 471 "<div style='display: flex; justify-content: center;'>" | |
| 472 "<table class='performance-summary' style='border-collapse: collapse;'>" | |
| 473 "<thead><tr>" | |
| 474 "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Metric</th>" | |
| 475 "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Test</th>" | |
| 476 "</tr></thead><tbody>" | |
| 477 ) | |
| 478 for row in rows: | |
| 479 html += generate_table_row( | |
| 480 row, | |
| 481 "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;", | |
| 482 ) | |
| 483 html += "</tbody></table></div><br>" | |
| 484 return html | |
| 485 | |
| 486 | |
| 487 def split_data_0_2( | |
| 488 df: pd.DataFrame, | |
| 489 split_column: str, | |
| 490 validation_size: float = 0.1, | |
| 491 random_state: int = 42, | |
| 492 label_column: Optional[str] = None, | |
| 493 ) -> pd.DataFrame: | |
| 494 """Given a DataFrame whose split_column only contains {0,2}, re-assign a portion of the 0s to become 1s (validation).""" | |
| 495 out = df.copy() | |
| 496 out[split_column] = pd.to_numeric(out[split_column], errors="coerce").astype(int) | |
| 497 | |
| 498 idx_train = out.index[out[split_column] == 0].tolist() | |
| 499 | |
| 500 if not idx_train: | |
| 501 logger.info("No rows with split=0; nothing to do.") | |
| 502 return out | |
| 503 stratify_arr = None | |
| 504 if label_column and label_column in out.columns: | |
| 505 label_counts = out.loc[idx_train, label_column].value_counts() | |
| 506 if label_counts.size > 1: | |
| 507 # Force stratify even with fewer samples - adjust validation_size if needed | |
| 508 min_samples_per_class = label_counts.min() | |
| 509 if min_samples_per_class * validation_size < 1: | |
| 510 # Raise validation_size so each class can contribute at least one validation sample | |
| 511 adjusted_validation_size = max( | |
| 512 validation_size, 1.0 / min_samples_per_class | |
| 513 ) | |
| 514 if adjusted_validation_size != validation_size: | |
| 515 validation_size = adjusted_validation_size | |
| 516 logger.info( | |
| 517 f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation" | |
| 518 ) | |
| 519 stratify_arr = out.loc[idx_train, label_column] | |
| 520 logger.info("Using stratified split for validation set") | |
| 521 else: | |
| 522 logger.warning("Only one label class found; cannot stratify") | |
| 523 if validation_size <= 0: | |
| 524 logger.info("validation_size <= 0; keeping all as train.") | |
| 525 return out | |
| 526 if validation_size >= 1: | |
| 527 logger.info("validation_size >= 1; moving all train → validation.") | |
| 528 out.loc[idx_train, split_column] = 1 | |
| 529 return out | |
| 530 # Always try stratified split first | |
| 531 try: | |
| 532 train_idx, val_idx = train_test_split( | |
| 533 idx_train, | |
| 534 test_size=validation_size, | |
| 535 random_state=random_state, | |
| 536 stratify=stratify_arr, | |
| 537 ) | |
| 538 logger.info("Successfully applied stratified split") | |
| 539 except ValueError as e: | |
| 540 logger.warning(f"Stratified split failed ({e}); falling back to random split.") | |
| 541 train_idx, val_idx = train_test_split( | |
| 542 idx_train, | |
| 543 test_size=validation_size, | |
| 544 random_state=random_state, | |
| 545 stratify=None, | |
| 546 ) | |
| 547 out.loc[train_idx, split_column] = 0 | |
| 548 out.loc[val_idx, split_column] = 1 | |
| 549 out[split_column] = out[split_column].astype(int) | |
| 550 return out | |
| 551 | |
| 552 | |
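
A usage sketch for `split_data_0_2` (toy DataFrame; the column names are illustrative, and the function is assumed to be in scope):

```python
import pandas as pd

df = pd.DataFrame({
    "image_path": [f"img_{i}.png" for i in range(10)],
    "label": ["cat", "dog"] * 5,
    "split": [0] * 8 + [2] * 2,  # only train (0) and test (2) present
})
out = split_data_0_2(df, "split", validation_size=0.25, label_column="label")
# Roughly a quarter of the split=0 rows are re-assigned to split=1 (validation).
print(out["split"].value_counts().sort_index())
```
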
| 553 def create_stratified_random_split( | |
| 554 df: pd.DataFrame, | |
| 555 split_column: str, | |
| 556 split_probabilities: list = [0.7, 0.1, 0.2], | |
| 557 random_state: int = 42, | |
| 558 label_column: Optional[str] = None, | |
| 559 ) -> pd.DataFrame: | |
| 560 """Create a stratified random split when no split column exists.""" | |
| 561 out = df.copy() | |
| 562 | |
| 563 # initialize split column | |
| 564 out[split_column] = 0 | |
| 565 | |
| 566 if not label_column or label_column not in out.columns: | |
| 567 logger.warning( | |
| 568 "No label column found; using random split without stratification" | |
| 569 ) | |
| 570 # fall back to simple random assignment | |
| 571 indices = out.index.tolist() | |
| 572 np.random.seed(random_state) | |
| 573 np.random.shuffle(indices) | |
| 574 | |
| 575 n_total = len(indices) | |
| 576 n_train = int(n_total * split_probabilities[0]) | |
| 577 n_val = int(n_total * split_probabilities[1]) | |
| 578 | |
| 579 out.loc[indices[:n_train], split_column] = 0 | |
| 580 out.loc[indices[n_train:n_train + n_val], split_column] = 1 | |
| 581 out.loc[indices[n_train + n_val:], split_column] = 2 | |
| 582 | |
| 583 return out.astype({split_column: int}) | |
| 584 | |
| 585 # check if stratification is possible | |
| 586 label_counts = out[label_column].value_counts() | |
| 587 min_samples_per_class = label_counts.min() | |
| 588 | |
| 589 # ensure we have enough samples for stratification: | |
| 590 # Each class must have at least as many samples as the number of splits, | |
| 591 # so that each split can receive at least one sample per class. | |
| 592 min_samples_required = len(split_probabilities) | |
| 593 if min_samples_per_class < min_samples_required: | |
| 594 logger.warning( | |
| 595 f"Insufficient samples per class for stratification (min: {min_samples_per_class}, required: {min_samples_required}); using random split" | |
| 596 ) | |
| 597 # fall back to simple random assignment | |
| 598 indices = out.index.tolist() | |
| 599 np.random.seed(random_state) | |
| 600 np.random.shuffle(indices) | |
| 601 | |
| 602 n_total = len(indices) | |
| 603 n_train = int(n_total * split_probabilities[0]) | |
| 604 n_val = int(n_total * split_probabilities[1]) | |
| 605 | |
| 606 out.loc[indices[:n_train], split_column] = 0 | |
| 607 out.loc[indices[n_train:n_train + n_val], split_column] = 1 | |
| 608 out.loc[indices[n_train + n_val:], split_column] = 2 | |
| 609 | |
| 610 return out.astype({split_column: int}) | |
| 611 | |
| 612 logger.info("Using stratified random split for train/validation/test sets") | |
| 613 | |
| 614 # first split: separate test set | |
| 615 train_val_idx, test_idx = train_test_split( | |
| 616 out.index.tolist(), | |
| 617 test_size=split_probabilities[2], | |
| 618 random_state=random_state, | |
| 619 stratify=out[label_column], | |
| 620 ) | |
| 621 | |
| 622 # second split: separate training and validation from remaining data | |
| 623 val_size_adjusted = split_probabilities[1] / ( | |
| 624 split_probabilities[0] + split_probabilities[1] | |
| 625 ) | |
| 626 train_idx, val_idx = train_test_split( | |
| 627 train_val_idx, | |
| 628 test_size=val_size_adjusted, | |
| 629 random_state=random_state, | |
| 630 stratify=out.loc[train_val_idx, label_column] if label_column and label_column in out.columns else None, | |
| 631 ) | |
| 632 | |
| 633 # assign split values | |
| 634 out.loc[train_idx, split_column] = 0 | |
| 635 out.loc[val_idx, split_column] = 1 | |
| 636 out.loc[test_idx, split_column] = 2 | |
| 637 | |
| 638 logger.info("Successfully applied stratified random split") | |
| 639 logger.info( | |
| 640 f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}" | |
| 641 ) | |
| 642 return out.astype({split_column: int}) | |
| 643 | |
| 644 | |
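
The second-stage `test_size` is re-scaled because it is applied to the remaining train+validation pool rather than the full dataset; a quick check of the arithmetic with the default probabilities:

```python
# Default split_probabilities = [0.7, 0.1, 0.2]
p_train, p_val, p_test = 0.7, 0.1, 0.2

# Stage 1 carves out the test share; stage 2 splits what remains.
val_size_adjusted = p_val / (p_train + p_val)  # 0.1 / 0.8
print(val_size_adjusted)                       # 0.125

# 12.5% of the remaining 80% equals 10% of the full dataset, as intended.
print(val_size_adjusted * (p_train + p_val))   # 0.1
```
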
| 645 class Backend(Protocol): | |
| 646 """Interface for a machine learning backend.""" | |
| 647 | |
| 648 def prepare_config( | |
| 649 self, | |
| 650 config_params: Dict[str, Any], | |
| 651 split_config: Dict[str, Any], | |
| 652 ) -> str: | |
| 653 ... | |
| 654 | |
| 655 def run_experiment( | |
| 656 self, | |
| 657 dataset_path: Path, | |
| 658 config_path: Path, | |
| 659 output_dir: Path, | |
| 660 random_seed: int, | |
| 661 ) -> None: | |
| 662 ... | |
| 663 | |
| 664 def generate_plots(self, output_dir: Path) -> None: | |
| 665 ... | |
| 666 | |
| 667 def generate_html_report( | |
| 668 self, | |
| 669 title: str, | |
| 670 output_dir: str, | |
| 671 config: Dict[str, Any], | |
| 672 split_info: str, | |
| 673 ) -> Path: | |
| 674 ... | |
| 675 | |
| 676 | |
| 677 class LudwigDirectBackend: | |
| 678 """Backend for running Ludwig experiments directly via the internal experiment_cli function.""" | |
| 679 | |
| 680 def _detect_image_dimensions(self, image_zip_path: str) -> Tuple[int, int]: | |
| 681 """Detect image dimensions from the first image in the dataset.""" | |
| 682 try: | |
| 683 import zipfile | |
| 684 from PIL import Image | |
| 685 import io | |
| 686 | |
| 687 # Check if image_zip is provided | |
| 688 if not image_zip_path: | |
| 689 logger.warning("No image zip provided, using default 224x224") | |
| 690 return 224, 224 | |
| 691 | |
| 692 # Extract first image to detect dimensions | |
| 693 with zipfile.ZipFile(image_zip_path, 'r') as z: | |
| 694 image_files = [f for f in z.namelist() if f.lower().endswith(('.png', '.jpg', '.jpeg'))] | |
| 695 if not image_files: | |
| 696 logger.warning("No image files found in zip, using default 224x224") | |
| 697 return 224, 224 | |
| 698 | |
| 699 # Check first image | |
| 700 with z.open(image_files[0]) as f: | |
| 701 img = Image.open(io.BytesIO(f.read())) | |
| 702 width, height = img.size | |
| 703 logger.info(f"Detected image dimensions: {width}x{height}") | |
| 704 return height, width # Return as (height, width) to match encoder config | |
| 705 | |
| 706 except Exception as e: | |
| 707 logger.warning(f"Error detecting image dimensions: {e}, using default 224x224") | |
| 708 return 224, 224 | |
| 709 | |
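
Note the return order: the helper hands back `(height, width)` to match the encoder config, whereas PIL's `Image.size` is `(width, height)`. A tiny sketch of the inversion:

```python
from PIL import Image

img = Image.new("RGB", (640, 480))  # PIL size is (width, height)
width, height = img.size
print(img.size)         # (640, 480)
print((height, width))  # (480, 640) -- the order _detect_image_dimensions returns
```
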
| 710 def prepare_config( | |
| 711 self, | |
| 712 config_params: Dict[str, Any], | |
| 713 split_config: Dict[str, Any], | |
| 714 ) -> str: | |
| 715 logger.info("LudwigDirectBackend: Preparing YAML configuration.") | |
| 716 | |
| 717 model_name = config_params.get("model_name", "resnet18") | |
| 718 use_pretrained = config_params.get("use_pretrained", False) | |
| 719 fine_tune = config_params.get("fine_tune", False) | |
| 720 if use_pretrained: | |
| 721 trainable = bool(fine_tune) | |
| 722 else: | |
| 723 trainable = True | |
| 724 epochs = config_params.get("epochs", 10) | |
| 725 batch_size = config_params.get("batch_size") | |
| 726 num_processes = config_params.get("preprocessing_num_processes", 1) | |
| 727 early_stop = config_params.get("early_stop", None) | |
| 728 learning_rate = config_params.get("learning_rate") | |
| 729 learning_rate = "auto" if learning_rate is None else float(learning_rate) | |
| 730 raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name) | |
| 731 | |
| 732 # --- MetaFormer detection and config logic --- | |
| 733 def _is_metaformer(name: str) -> bool: | |
| 734 return isinstance(name, str) and name.startswith( | |
| 735 ( | |
| 736 "identityformer_", | |
| 737 "randformer_", | |
| 738 "poolformerv2_", | |
| 739 "convformer_", | |
| 740 "caformer_", | |
| 741 ) | |
| 742 ) | |
| 743 | |
| 744 # Check if this is a MetaFormer model (either direct name or in custom_model) | |
| 745 is_metaformer = ( | |
| 746 _is_metaformer(model_name) | |
| 747 or (isinstance(raw_encoder, dict) and "custom_model" in raw_encoder and _is_metaformer(raw_encoder["custom_model"])) | |
| 748 ) | |
| 749 | |
| 750 metaformer_resize: Optional[Tuple[int, int]] = None | |
| 751 metaformer_channels = 3 | |
| 752 | |
| 753 if is_metaformer: | |
| 754 # Handle MetaFormer models | |
| 755 custom_model = None | |
| 756 if isinstance(raw_encoder, dict) and "custom_model" in raw_encoder: | |
| 757 custom_model = raw_encoder["custom_model"] | |
| 758 else: | |
| 759 custom_model = model_name | |
| 760 | |
| 761 logger.info(f"DETECTED MetaFormer model: {custom_model}") | |
| 762 cfg_channels, cfg_height, cfg_width = 3, 224, 224 | |
| 763 if META_DEFAULT_CFGS: | |
| 764 model_cfg = META_DEFAULT_CFGS.get(custom_model, {}) | |
| 765 input_size = model_cfg.get("input_size") | |
| 766 if isinstance(input_size, (list, tuple)) and len(input_size) == 3: | |
| 767 cfg_channels, cfg_height, cfg_width = ( | |
| 768 int(input_size[0]), | |
| 769 int(input_size[1]), | |
| 770 int(input_size[2]), | |
| 771 ) | |
| 772 | |
| 773 target_height, target_width = cfg_height, cfg_width | |
| 774 resize_value = config_params.get("image_resize") | |
| 775 if resize_value and resize_value != "original": | |
| 776 try: | |
| 777 dimensions = resize_value.split("x") | |
| 778 if len(dimensions) == 2: | |
| 779 target_height, target_width = int(dimensions[0]), int(dimensions[1]) | |
| 780 if target_height <= 0 or target_width <= 0: | |
| 781 raise ValueError( | |
| 782 f"Image resize must be positive integers, received {resize_value}." | |
| 783 ) | |
| 784 logger.info(f"MetaFormer explicit resize: {target_height}x{target_width}") | |
| 785 else: | |
| 786 raise ValueError(resize_value) | |
| 787 except (ValueError, IndexError): | |
| 788 logger.warning( | |
| 789 "Invalid image resize format '%s'; falling back to model default %sx%s", | |
| 790 resize_value, | |
| 791 cfg_height, | |
| 792 cfg_width, | |
| 793 ) | |
| 794 target_height, target_width = cfg_height, cfg_width | |
| 795 else: | |
| 796 image_zip_path = config_params.get("image_zip", "") | |
| 797 detected_height, detected_width = self._detect_image_dimensions(image_zip_path) | |
| 798 if use_pretrained: | |
| 799 if (detected_height, detected_width) != (cfg_height, cfg_width): | |
| 800 logger.info( | |
| 801 "MetaFormer pretrained weights expect %sx%s; resizing from detected %sx%s", | |
| 802 cfg_height, | |
| 803 cfg_width, | |
| 804 detected_height, | |
| 805 detected_width, | |
| 806 ) | |
| 807 else: | |
| 808 target_height, target_width = detected_height, detected_width | |
| 809 if target_height <= 0 or target_width <= 0: | |
| 810 raise ValueError( | |
| 811 f"Invalid detected image dimensions for MetaFormer: {target_height}x{target_width}." | |
| 812 ) | |
| 813 | |
| 814 metaformer_channels = cfg_channels | |
| 815 metaformer_resize = (target_height, target_width) | |
| 816 | |
| 817 encoder_config = { | |
| 818 "type": "stacked_cnn", | |
| 819 "height": target_height, | |
| 820 "width": target_width, | |
| 821 "num_channels": metaformer_channels, | |
| 822 "output_size": 128, | |
| 823 "use_pretrained": use_pretrained, | |
| 824 "trainable": trainable, | |
| 825 "custom_model": custom_model, | |
| 826 } | |
| 827 | |
| 828 elif isinstance(raw_encoder, dict): | |
| 829 # Handle image resize for regular encoders | |
| 830 # Note: Standard encoders like ResNet don't support height/width parameters | |
| 831 # Resize will be handled at the preprocessing level by Ludwig | |
| 832 if config_params.get("image_resize") and config_params["image_resize"] != "original": | |
| 833 logger.info(f"Resize requested: {config_params['image_resize']} for standard encoder. Resize will be handled at preprocessing level.") | |
| 834 | |
| 835 encoder_config = { | |
| 836 **raw_encoder, | |
| 837 "use_pretrained": use_pretrained, | |
| 838 "trainable": trainable, | |
| 839 } | |
| 840 else: | |
| 841 encoder_config = {"type": raw_encoder} | |
| 842 | |
| 843 batch_size_cfg = batch_size or "auto" | |
| 844 | |
| 845 label_column_path = config_params.get("label_column_data_path") | |
| 846 label_series = None | |
| 847 if label_column_path is not None and Path(label_column_path).exists(): | |
| 848 try: | |
| 849 label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME] | |
| 850 except Exception as e: | |
| 851 logger.warning(f"Could not read label column for task detection: {e}") | |
| 852 | |
| 853 if ( | |
| 854 label_series is not None | |
| 855 and ptypes.is_numeric_dtype(label_series.dtype) | |
| 856 and label_series.nunique() > 10 | |
| 857 ): | |
| 858 task_type = "regression" | |
| 859 else: | |
| 860 task_type = "classification" | |
| 861 | |
| 862 config_params["task_type"] = task_type | |
| 863 | |
| 864 image_feat: Dict[str, Any] = { | |
| 865 "name": IMAGE_PATH_COLUMN_NAME, | |
| 866 "type": "image", | |
| 867 } | |
| 868 # Set preprocessing dimensions FIRST for MetaFormer models | |
| 869 if is_metaformer: | |
| 870 if metaformer_resize is None: | |
| 871 metaformer_resize = (224, 224) | |
| 872 height, width = metaformer_resize | |
| 873 | |
| 874 # CRITICAL: Set preprocessing dimensions FIRST for MetaFormer models | |
| 875 # This is essential for MetaFormer models to work properly | |
| 876 if "preprocessing" not in image_feat: | |
| 877 image_feat["preprocessing"] = {} | |
| 878 image_feat["preprocessing"]["height"] = height | |
| 879 image_feat["preprocessing"]["width"] = width | |
| 880 # Use infer_image_dimensions=True to allow Ludwig to read images for validation | |
| 881 # but set explicit max dimensions to control the output size | |
| 882 image_feat["preprocessing"]["infer_image_dimensions"] = True | |
| 883 image_feat["preprocessing"]["infer_image_max_height"] = height | |
| 884 image_feat["preprocessing"]["infer_image_max_width"] = width | |
| 885 image_feat["preprocessing"]["num_channels"] = metaformer_channels | |
| 886 image_feat["preprocessing"]["resize_method"] = "interpolate" # Use interpolation for better quality | |
| 887 image_feat["preprocessing"]["standardize_image"] = "imagenet1k" # Use ImageNet standardization | |
| 888 # Force Ludwig to respect our dimensions by setting additional parameters | |
| 889 image_feat["preprocessing"]["requires_equal_dimensions"] = False | |
| 890 logger.info(f"Set preprocessing dimensions for MetaFormer: {height}x{width} (infer_dimensions=True with max dimensions to allow validation)") | |
| 891 # Now set the encoder configuration | |
| 892 image_feat["encoder"] = encoder_config | |
| 893 | |
| 894 if config_params.get("augmentation") is not None: | |
| 895 image_feat["augmentation"] = config_params["augmentation"] | |
| 896 | |
| 897 # Add resize configuration for standard encoders (ResNet, etc.) | |
| 898 # FIXED: MetaFormer models now respect user dimensions completely | |
| 899 # Previously there was a double resize issue where MetaFormer would force 224x224 | |
| 900 # Now both MetaFormer and standard encoders respect user's resize choice | |
| 901 if (not is_metaformer) and config_params.get("image_resize") and config_params["image_resize"] != "original": | |
| 902 try: | |
| 903 dimensions = config_params["image_resize"].split("x") | |
| 904 if len(dimensions) == 2: | |
| 905 height, width = int(dimensions[0]), int(dimensions[1]) | |
| 906 if height <= 0 or width <= 0: | |
| 907 raise ValueError( | |
| 908 f"Image resize must be positive integers, received {config_params['image_resize']}." | |
| 909 ) | |
| 910 | |
| 911 # Add resize to preprocessing for standard encoders | |
| 912 if "preprocessing" not in image_feat: | |
| 913 image_feat["preprocessing"] = {} | |
| 914 image_feat["preprocessing"]["height"] = height | |
| 915 image_feat["preprocessing"]["width"] = width | |
| 916 # Use infer_image_dimensions=True to allow Ludwig to read images for validation | |
| 917 # but set explicit max dimensions to control the output size | |
| 918 image_feat["preprocessing"]["infer_image_dimensions"] = True | |
| 919 image_feat["preprocessing"]["infer_image_max_height"] = height | |
| 920 image_feat["preprocessing"]["infer_image_max_width"] = width | |
| 921 logger.info(f"Added resize preprocessing: {height}x{width} for standard encoder with infer_image_dimensions=True and max dimensions") | |
| 922 except (ValueError, IndexError): | |
| 923 logger.warning(f"Invalid image resize format: {config_params['image_resize']}, skipping resize preprocessing") | |
| 924 if task_type == "regression": | |
| 925 output_feat = { | |
| 926 "name": LABEL_COLUMN_NAME, | |
| 927 "type": "number", | |
| 928 "decoder": {"type": "regressor", "input_size": 1}, | |
| 929 "loss": {"type": "mean_squared_error"}, | |
| 930 "evaluation": { | |
| 931 "metrics": [ | |
| 932 "mean_squared_error", | |
| 933 "mean_absolute_error", | |
| 934 "r2", | |
| 935 ] | |
| 936 }, | |
| 937 } | |
| 938 val_metric = config_params.get("validation_metric", "mean_squared_error") | |
| 939 | |
| 940 else: | |
| 941 num_unique_labels = ( | |
| 942 label_series.nunique() if label_series is not None else 2 | |
| 943 ) | |
| 944 output_type = "binary" if num_unique_labels == 2 else "category" | |
| 945 # Determine if this is regression or classification based on label type | |
| 946 is_regression = ( | |
| 947 label_series is not None | |
| 948 and ptypes.is_numeric_dtype(label_series.dtype) | |
| 949 and label_series.nunique() > 10 | |
| 950 ) | |
| 951 | |
| 952 if is_regression: | |
| 953 output_feat = { | |
| 954 "name": LABEL_COLUMN_NAME, | |
| 955 "type": "number", | |
| 956 "decoder": {"type": "regressor", "input_size": 1}, | |
| 957 "loss": {"type": "mean_squared_error"}, | |
| 958 } | |
| 959 else: | |
| 960 if num_unique_labels == 2: | |
| 961 output_feat = { | |
| 962 "name": LABEL_COLUMN_NAME, | |
| 963 "type": "binary", | |
| 964 "decoder": {"type": "classifier", "input_size": 1}, | |
| 965 "loss": {"type": "softmax_cross_entropy"}, | |
| 966 } | |
| 967 else: | |
| 968 output_feat = { | |
| 969 "name": LABEL_COLUMN_NAME, | |
| 970 "type": "category", | |
| 971 "decoder": {"type": "classifier", "input_size": num_unique_labels}, | |
| 972 "loss": {"type": "softmax_cross_entropy"}, | |
| 973 } | |
| 974 if output_type == "binary" and config_params.get("threshold") is not None: | |
| 975 output_feat["threshold"] = float(config_params["threshold"]) | |
| 976 val_metric = None | |
| 977 | |
| 978 conf: Dict[str, Any] = { | |
| 979 "model_type": "ecd", | |
| 980 "input_features": [image_feat], | |
| 981 "output_features": [output_feat], | |
| 982 "combiner": {"type": "concat"}, | |
| 983 "trainer": { | |
| 984 "epochs": epochs, | |
| 985 "early_stop": early_stop, | |
| 986 "batch_size": batch_size_cfg, | |
| 987 "learning_rate": learning_rate, | |
| 988 # only set validation_metric for regression | |
| 989 **({"validation_metric": val_metric} if val_metric else {}), | |
| 990 }, | |
| 991 "preprocessing": { | |
| 992 "split": split_config, | |
| 993 "num_processes": num_processes, | |
| 994 "in_memory": False, | |
| 995 }, | |
| 996 } | |
| 997 | |
| 998 logger.debug("LudwigDirectBackend: Config dict built.") | |
| 999 try: | |
| 1000 yaml_str = yaml.dump(conf, sort_keys=False, indent=2) | |
| 1001 logger.info("LudwigDirectBackend: YAML config generated.") | |
| 1002 return yaml_str | |
| 1003 except Exception: | |
| 1004 logger.error( | |
| 1005 "LudwigDirectBackend: Failed to serialize YAML.", | |
| 1006 exc_info=True, | |
| 1007 ) | |
| 1008 raise | |
| 1009 | |
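
For orientation, a sketch of the structure `prepare_config` serializes (all values illustrative; the real encoder and split blocks depend on the chosen model and CLI arguments):

```python
import yaml

conf = {  # mirrors the dict assembled above; values are examples, not defaults
    "model_type": "ecd",
    "input_features": [{
        "name": "image_path",
        "type": "image",
        "encoder": {"type": "resnet", "use_pretrained": True, "trainable": False},
    }],
    "output_features": [{
        "name": "label",
        "type": "category",
        "decoder": {"type": "classifier", "input_size": 3},
        "loss": {"type": "softmax_cross_entropy"},
    }],
    "combiner": {"type": "concat"},
    "trainer": {"epochs": 10, "early_stop": None,
                "batch_size": "auto", "learning_rate": "auto"},
    "preprocessing": {"split": {"type": "fixed", "column": "split"},
                      "num_processes": 1, "in_memory": False},
}
print(yaml.dump(conf, sort_keys=False, indent=2))
```
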
| 1010 def run_experiment( | |
| 1011 self, | |
| 1012 dataset_path: Path, | |
| 1013 config_path: Path, | |
| 1014 output_dir: Path, | |
| 1015 random_seed: int = 42, | |
| 1016 ) -> None: | |
| 1017 """Invoke Ludwig's internal experiment_cli function to run the experiment.""" | |
| 1018 logger.info("LudwigDirectBackend: Starting experiment execution.") | |
| 1019 | |
| 1020 try: | |
| 1021 from ludwig.experiment import experiment_cli | |
| 1022 except ImportError as e: | |
| 1023 logger.error( | |
| 1024 "LudwigDirectBackend: Could not import experiment_cli.", | |
| 1025 exc_info=True, | |
| 1026 ) | |
| 1027 raise RuntimeError("Ludwig import failed.") from e | |
| 1028 | |
| 1029 output_dir.mkdir(parents=True, exist_ok=True) | |
| 1030 | |
| 1031 try: | |
| 1032 experiment_cli( | |
| 1033 dataset=str(dataset_path), | |
| 1034 config=str(config_path), | |
| 1035 output_directory=str(output_dir), | |
| 1036 random_seed=random_seed, | |
| 1037 skip_preprocessing=True, | |
| 1038 ) | |
| 1039 logger.info( | |
| 1040 f"LudwigDirectBackend: Experiment completed. Results in {output_dir}" | |
| 1041 ) | |
| 1042 except TypeError as e: | |
| 1043 logger.error( | |
| 1044 "LudwigDirectBackend: Argument mismatch in experiment_cli call.", | |
| 1045 exc_info=True, | |
| 1046 ) | |
| 1047 raise RuntimeError("Ludwig argument error.") from e | |
| 1048 except Exception: | |
| 1049 logger.error( | |
| 1050 "LudwigDirectBackend: Experiment execution error.", | |
| 1051 exc_info=True, | |
| 1052 ) | |
| 1053 raise | |
| 1054 | |
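
A minimal invocation sketch for the backend (file names are placeholders, not outputs of a real run):

```python
from pathlib import Path

backend = LudwigDirectBackend()
backend.run_experiment(
    dataset_path=Path("prepared_data.csv"),  # placeholder paths
    config_path=Path("ludwig_config.yaml"),
    output_dir=Path("results"),
    random_seed=42,
)
backend.convert_parquet_to_csv(Path("results"))
backend.generate_plots(Path("results"))
```
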
| 1055 def get_training_process(self, output_dir) -> Optional[Dict[str, Any]]: | |
| 1056 """Retrieve the learning rate used in the most recent Ludwig run.""" | |
| 1057 output_dir = Path(output_dir) | |
| 1058 exp_dirs = sorted( | |
| 1059 output_dir.glob("experiment_run*"), | |
| 1060 key=lambda p: p.stat().st_mtime, | |
| 1061 ) | |
| 1062 | |
| 1063 if not exp_dirs: | |
| 1064 logger.warning(f"No experiment run directories found in {output_dir}") | |
| 1065 return None | |
| 1066 | |
| 1067 progress_file = exp_dirs[-1] / "model" / "training_progress.json" | |
| 1068 if not progress_file.exists(): | |
| 1069 logger.warning(f"No training_progress.json found in {progress_file}") | |
| 1070 return None | |
| 1071 | |
| 1072 try: | |
| 1073 with progress_file.open("r", encoding="utf-8") as f: | |
| 1074 data = json.load(f) | |
| 1075 return { | |
| 1076 "learning_rate": data.get("learning_rate"), | |
| 1077 "batch_size": data.get("batch_size"), | |
| 1078 "epoch": data.get("epoch"), | |
| 1079 } | |
| 1080 except Exception as e: | |
| 1081 logger.warning(f"Failed to read training progress info: {e}") | |
| 1082 return {} | |
| 1083 | |
| 1084 def convert_parquet_to_csv(self, output_dir: Path): | |
| 1085 """Convert the predictions Parquet file to CSV.""" | |
| 1086 output_dir = Path(output_dir) | |
| 1087 exp_dirs = sorted( | |
| 1088 output_dir.glob("experiment_run*"), | |
| 1089 key=lambda p: p.stat().st_mtime, | |
| 1090 ) | |
| 1091 if not exp_dirs: | |
| 1092 logger.warning(f"No experiment run dirs found in {output_dir}") | |
| 1093 return | |
| 1094 exp_dir = exp_dirs[-1] | |
| 1095 parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME | |
| 1096 csv_path = exp_dir / "predictions.csv" | |
| 1097 | |
| 1098 # Check if parquet file exists before trying to convert | |
| 1099 if not parquet_path.exists(): | |
| 1100 logger.info(f"Predictions parquet file not found at {parquet_path}, skipping conversion") | |
| 1101 return | |
| 1102 | |
| 1103 try: | |
| 1104 df = pd.read_parquet(parquet_path) | |
| 1105 df.to_csv(csv_path, index=False) | |
| 1106 logger.info(f"Converted Parquet to CSV: {csv_path}") | |
| 1107 except Exception as e: | |
| 1108 logger.error(f"Error converting Parquet to CSV: {e}") | |
| 1109 | |
| 1110 def generate_plots(self, output_dir: Path) -> None: | |
| 1111 """Generate all registered Ludwig visualizations for the latest experiment run.""" | |
| 1112 logger.info("Generating all Ludwig visualizations…") | |
| 1113 | |
| 1114 test_plots = { | |
| 1115 "compare_performance", | |
| 1116 "compare_classifiers_performance_from_prob", | |
| 1117 "compare_classifiers_performance_from_pred", | |
| 1118 "compare_classifiers_performance_changing_k", | |
| 1119 "compare_classifiers_multiclass_multimetric", | |
| 1120 "compare_classifiers_predictions", | |
| 1121 "confidence_thresholding_2thresholds_2d", | |
| 1122 "confidence_thresholding_2thresholds_3d", | |
| 1123 "confidence_thresholding", | |
| 1124 "confidence_thresholding_data_vs_acc", | |
| 1125 "binary_threshold_vs_metric", | |
| 1126 "roc_curves", | |
| 1127 "roc_curves_from_test_statistics", | |
| 1128 "calibration_1_vs_all", | |
| 1129 "calibration_multiclass", | |
| 1130 "confusion_matrix", | |
| 1131 "frequency_vs_f1", | |
| 1132 } | |
| 1133 train_plots = { | |
| 1134 "learning_curves", | |
| 1135 "compare_classifiers_performance_subset", | |
| 1136 } | |
| 1137 | |
| 1138 output_dir = Path(output_dir) | |
| 1139 exp_dirs = sorted( | |
| 1140 output_dir.glob("experiment_run*"), | |
| 1141 key=lambda p: p.stat().st_mtime, | |
| 1142 ) | |
| 1143 if not exp_dirs: | |
| 1144 logger.warning(f"No experiment run dirs found in {output_dir}") | |
| 1145 return | |
| 1146 exp_dir = exp_dirs[-1] | |
| 1147 | |
| 1148 viz_dir = exp_dir / "visualizations" | |
| 1149 viz_dir.mkdir(exist_ok=True) | |
| 1150 train_viz = viz_dir / "train" | |
| 1151 test_viz = viz_dir / "test" | |
| 1152 train_viz.mkdir(parents=True, exist_ok=True) | |
| 1153 test_viz.mkdir(parents=True, exist_ok=True) | |
| 1154 | |
| 1155 def _check(p: Path) -> Optional[str]: | |
| 1156 return str(p) if p.exists() else None | |
| 1157 | |
| 1158 training_stats = _check(exp_dir / "training_statistics.json") | |
| 1159 test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME) | |
| 1160 probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME) | |
| 1161 gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME) | |
| 1162 | |
| 1163 dataset_path = None | |
| 1164 split_file = None | |
| 1165 desc = exp_dir / DESCRIPTION_FILE_NAME | |
| 1166 if desc.exists(): | |
| 1167 with open(desc, "r") as f: | |
| 1168 cfg = json.load(f) | |
| 1169 dataset_path = _check(Path(cfg.get("dataset", ""))) | |
| 1170 split_file = _check(Path(get_split_path(cfg.get("dataset", "")))) | |
| 1171 | |
| 1172 output_feature = "" | |
| 1173 if desc.exists(): | |
| 1174 try: | |
| 1175 output_feature = cfg["config"]["output_features"][0]["name"] | |
| 1176 except Exception: | |
| 1177 pass | |
| 1178 if not output_feature and test_stats: | |
| 1179 with open(test_stats, "r") as f: | |
| 1180 stats = json.load(f) | |
| 1181 output_feature = next(iter(stats.keys()), "") | |
| 1182 | |
| 1183 if get_visualizations_registry is None: # import may have failed at module top | |
| logger.warning("Ludwig visualization registry unavailable; skipping Ludwig plots") | |
| return | |
| viz_registry = get_visualizations_registry() | |
| 1184 for viz_name, viz_func in viz_registry.items(): | |
| 1185 if viz_name in train_plots: | |
| 1186 viz_dir_plot = train_viz | |
| 1187 elif viz_name in test_plots: | |
| 1188 viz_dir_plot = test_viz | |
| 1189 else: | |
| 1190 continue | |
| 1191 | |
| 1192 try: | |
| 1193 viz_func( | |
| 1194 training_statistics=[training_stats] if training_stats else [], | |
| 1195 test_statistics=[test_stats] if test_stats else [], | |
| 1196 probabilities=[probs_path] if probs_path else [], | |
| 1197 output_feature_name=output_feature, | |
| 1198 ground_truth_split=2, | |
| 1199 top_n_classes=[0], | |
| 1200 top_k=3, | |
| 1201 ground_truth_metadata=gt_metadata, | |
| 1202 ground_truth=dataset_path, | |
| 1203 split_file=split_file, | |
| 1204 output_directory=str(viz_dir_plot), | |
| 1205 normalize=False, | |
| 1206 file_format="png", | |
| 1207 ) | |
| 1208 logger.info(f"✔ Generated {viz_name}") | |
| 1209 except Exception as e: | |
| 1210 logger.warning(f"✘ Skipped {viz_name}: {e}") | |
| 1211 | |
| 1212 logger.info(f"All visualizations written to {viz_dir}") | |
| 1213 | |
| 1214 def generate_html_report( | |
| 1215 self, | |
| 1216 title: str, | |
| 1217 output_dir: str, | |
| 1218 config: dict, | |
| 1219 split_info: str, | |
| 1220 ) -> Path: | |
| 1221 """Assemble an HTML report from visualizations under train_val/ and test/ folders.""" | |
| 1222 cwd = Path.cwd() | |
| 1223 report_name = title.lower().replace(" ", "_") + "_report.html" | |
| 1224 report_path = cwd / report_name | |
| 1225 output_dir = Path(output_dir) | |
| 1226 output_type = None | |
| 1227 | |
| 1228 exp_dirs = sorted( | |
| 1229 output_dir.glob("experiment_run*"), | |
| 1230 key=lambda p: p.stat().st_mtime, | |
| 1231 ) | |
| 1232 if not exp_dirs: | |
| 1233 raise RuntimeError(f"No 'experiment_run*' dirs found in {output_dir}") | |
| 1234 exp_dir = exp_dirs[-1] | |
| 1235 | |
| 1236 base_viz_dir = exp_dir / "visualizations" | |
| 1237 train_viz_dir = base_viz_dir / "train" | |
| 1238 test_viz_dir = base_viz_dir / "test" | |
| 1239 | |
| 1240 html = get_html_template() | |
| 1241 | |
| 1242 # Extra CSS & JS: center Plotly and enable CSV download for predictions table | |
| 1243 html += """ | |
| 1244 <style> | |
| 1245 /* Center Plotly figures (both wrapper and native classes) */ | |
| 1246 .plotly-center { display: flex; justify-content: center; } | |
| 1247 .plotly-center .plotly-graph-div, .plotly-center .js-plotly-plot { margin: 0 auto !important; } | |
| 1248 .js-plotly-plot, .plotly-graph-div { margin-left: auto !important; margin-right: auto !important; } | |
| 1249 | |
| 1250 /* Download button for predictions table */ | |
| 1251 .download-btn { | |
| 1252 padding: 8px 12px; | |
| 1253 border: 1px solid #4CAF50; | |
| 1254 background: #4CAF50; | |
| 1255 color: white; | |
| 1256 border-radius: 6px; | |
| 1257 cursor: pointer; | |
| 1258 } | |
| 1259 .download-btn:hover { filter: brightness(0.95); } | |
| 1260 .preds-controls { | |
| 1261 display: flex; | |
| 1262 justify-content: flex-end; | |
| 1263 gap: 8px; | |
| 1264 margin: 8px 0; | |
| 1265 } | |
| 1266 </style> | |
| 1267 <script> | |
| 1268 function tableToCSV(table){ | |
| 1269 const rows = Array.from(table.querySelectorAll('tr')); | |
| 1270 return rows.map(row => | |
| 1271 Array.from(row.querySelectorAll('th,td')).map(cell => { | |
| 1272 let text = cell.innerText.replace(/\\r?\\n|\\r/g,' ').trim(); | |
| 1273 if (text.includes('"') || text.includes(',')) { | |
| 1274 text = '"' + text.replace(/"/g,'""') + '"'; | |
| 1275 } | |
| 1276 return text; | |
| 1277 }).join(',') | |
| 1278 ).join('\\n'); | |
| 1279 } | |
| 1280 document.addEventListener('DOMContentLoaded', function(){ | |
| 1281 const btn = document.getElementById('downloadPredsCsv'); | |
| 1282 if(btn){ | |
| 1283 btn.addEventListener('click', function(){ | |
| 1284 const tbl = document.querySelector('.predictions-table'); | |
| 1285 if(!tbl){ alert('Predictions table not found.'); return; } | |
| 1286 const csv = tableToCSV(tbl); | |
| 1287 const blob = new Blob([csv], {type: 'text/csv;charset=utf-8;'}); | |
| 1288 const url = URL.createObjectURL(blob); | |
| 1289 const a = document.createElement('a'); | |
| 1290 a.href = url; | |
| 1291 a.download = 'ground_truth_vs_predictions.csv'; | |
| 1292 document.body.appendChild(a); | |
| 1293 a.click(); | |
| 1294 document.body.removeChild(a); | |
| 1295 URL.revokeObjectURL(url); | |
| 1296 }); | |
| 1297 } | |
| 1298 }); | |
| 1299 </script> | |
| 1300 """ | |
| 1301 html += f"<h1>{title}</h1>" | |
| 1302 | |
| 1303 metrics_html = "" | |
| 1304 train_val_metrics_html = "" | |
| 1305 test_metrics_html = "" | |
| 1306 try: | |
| 1307 train_stats_path = exp_dir / "training_statistics.json" | |
| 1308 test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME | |
| 1309 if train_stats_path.exists() and test_stats_path.exists(): | |
| 1310 with open(train_stats_path) as f: | |
| 1311 train_stats = json.load(f) | |
| 1312 with open(test_stats_path) as f: | |
| 1313 test_stats = json.load(f) | |
| 1314 output_type = detect_output_type(test_stats) | |
| 1315 metrics_html = format_stats_table_html(train_stats, test_stats, output_type) | |
| 1316 train_val_metrics_html = format_train_val_stats_table_html( | |
| 1317 train_stats, test_stats | |
| 1318 ) | |
| 1319 test_metrics_html = format_test_merged_stats_table_html( | |
| 1320 extract_metrics_from_json(train_stats, test_stats, output_type)[ | |
| 1321 "test" | |
| 1322 ], output_type | |
| 1323 ) | |
| 1324 except Exception as e: | |
| 1325 logger.warning( | |
| 1326 f"Could not load stats for HTML report: {type(e).__name__}: {e}" | |
| 1327 ) | |
| 1328 | |
| 1329 config_html = "" | |
| 1330 training_progress = self.get_training_process(output_dir) | |
| 1331 try: | |
| 1332 config_html = format_config_table_html( | |
| 1333 config, split_info, training_progress, output_type | |
| 1334 ) | |
| 1335 except Exception as e: | |
| 1336 logger.warning(f"Could not load config for HTML report: {e}") | |
| 1337 | |
| 1338 # ---------- image rendering with exclusions ---------- | |
| 1339 def render_img_section( | |
| 1340 title: str, | |
| 1341 dir_path: Path, | |
| 1342 output_type: str = None, | |
| 1343 exclude_names: Optional[set] = None, | |
| 1344 ) -> str: | |
| 1345 if not dir_path.exists(): | |
| 1346 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>" | |
| 1347 | |
| 1348 exclude_names = exclude_names or set() | |
| 1349 | |
| 1350 imgs = list(dir_path.glob("*.png")) | |
| 1351 | |
| 1352 # Keep only the top-5 entropy confusion matrix: exclude the standard | |
| # confusion-matrix variants and the top-6/top-10 entropy duplicates. | |
| 1353 default_exclude = { | |
| 1355 "confusion_matrix__label_top5.png", | |
| 1356 "confusion_matrix__label_top10.png", | |
| 1357 "confusion_matrix__label_top6.png", | |
| 1358 "confusion_matrix_entropy__label_top10.png", | |
| 1359 "confusion_matrix_entropy__label_top6.png", | |
| 1360 } | |
| 1361 | |
| 1362 imgs = [ | |
| 1363 img | |
| 1364 for img in imgs | |
| 1365 if img.name not in default_exclude | |
| 1366 and img.name not in exclude_names | |
| 1367 ] | |
| 1368 | |
| 1369 if not imgs: | |
| 1370 return f"<h2>{title}</h2><p><em>No plots found.</em></p>" | |
| 1371 | |
| 1372 # Sort images by name for consistent ordering (works with string and numeric labels) | |
| 1373 imgs = sorted(imgs, key=lambda x: x.name) | |
| 1374 | |
| 1375 html_section = "" | |
| 1376 for img in imgs: | |
| 1377 b64 = encode_image_to_base64(str(img)) | |
| 1378 img_title = img.stem.replace("_", " ").title() | |
| 1379 html_section += ( | |
| 1380 f"<h2 style='text-align: center;'>{img_title}</h2>" | |
| 1381 f'<div class="plot" style="margin-bottom:20px;text-align:center;">' | |
| 1382 f'<img src="data:image/png;base64,{b64}" ' | |
| 1383 f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' | |
| 1384 f"</div>" | |
| 1385 ) | |
| 1386 return html_section | |
| 1387 | |
| 1388 tab1_content = config_html + metrics_html | |
| 1389 | |
| 1390 tab2_content = train_val_metrics_html + render_img_section( | |
| 1391 "Training and Validation Visualizations", | |
| 1392 train_viz_dir, | |
| 1393 output_type, | |
| 1394 exclude_names={ | |
| 1395 "compare_classifiers_performance_from_prob.png", | |
| 1396 "roc_curves_from_prediction_statistics.png", | |
| 1397 "precision_recall_curves_from_prediction_statistics.png", | |
| 1398 "precision_recall_curve.png", | |
| 1399 }, | |
| 1400 ) | |
| 1401 | |
| 1402 # --- Predictions vs Ground Truth table (REGRESSION ONLY) --- | |
| 1403 preds_section = "" | |
| 1404 parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME | |
| 1405 if output_type == "regression" and parquet_path.exists(): | |
| 1406 try: | |
| 1407 # 1) load predictions from Parquet | |
| 1408 df_preds = pd.read_parquet(parquet_path).reset_index(drop=True) | |
| 1409 # Locate the prediction column: it is assumed to be named | |
| 1410 # "prediction" or to contain that substring. | |
| 1411 pred_col = next( | |
| 1412 (c for c in df_preds.columns if "prediction" in c.lower()), | |
| 1413 None, | |
| 1414 ) | |
| 1415 if pred_col is None: | |
| 1416 raise ValueError("No prediction column found in Parquet output") | |
| 1417 df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"}) | |
| 1418 | |
| 1419 # 2) load ground truth for the test split from prepared CSV | |
| 1420 df_all = pd.read_csv(config["label_column_data_path"]) | |
| 1421 df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][ | |
| 1422 LABEL_COLUMN_NAME | |
| 1423 ].reset_index(drop=True) | |
| 1424 # 3) concatenate side-by-side | |
| 1425 df_table = pd.concat([df_gt, df_pred], axis=1) | |
| 1426 df_table.columns = [LABEL_COLUMN_NAME, "prediction"] | |
| 1427 | |
| 1428 # 4) render as HTML | |
| 1429 preds_html = df_table.to_html(index=False, classes="predictions-table") | |
| 1430 preds_section = ( | |
| 1431 "<h2 style='text-align: center;'>Ground Truth vs. Predictions</h2>" | |
| 1432 "<div class='preds-controls'>" | |
| 1433 "<button id='downloadPredsCsv' class='download-btn'>Download CSV</button>" | |
| 1434 "</div>" | |
| 1435 "<div class='scroll-rows-30' style='overflow-x:auto; overflow-y:auto; max-height:900px; margin-bottom:20px;'>" | |
| 1436 + preds_html | |
| 1437 + "</div>" | |
| 1438 ) | |
| 1439 except Exception as e: | |
| 1440 logger.warning(f"Could not build Predictions vs GT table: {e}") | |
| 1441 | |
| 1442 tab3_content = test_metrics_html + preds_section | |
| 1443 | |
| 1444 if output_type in ("binary", "category") and test_stats_path.exists(): | |
| 1445 try: | |
| 1446 interactive_plots = build_classification_plots( | |
| 1447 str(test_stats_path), | |
| 1448 str(train_stats_path) if train_stats_path.exists() else None, | |
| 1449 ) | |
| 1450 for plot in interactive_plots: | |
| 1451 tab3_content += ( | |
| 1452 f"<h2 style='text-align: center;'>{plot['title']}</h2>" | |
| 1453 f"<div class='plotly-center'>{plot['html']}</div>" | |
| 1454 ) | |
| 1455 logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots") | |
| 1456 except Exception as e: | |
| 1457 logger.warning(f"Could not generate Plotly plots: {e}") | |
| 1458 | |
| 1459 # Add static TEST PNGs (with default dedupe/exclusions) | |
| 1460 tab3_content += render_img_section( | |
| 1461 "Test Visualizations", test_viz_dir, output_type | |
| 1462 ) | |
| 1463 | |
| 1464 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) | |
| 1465 modal_html = get_metrics_help_modal() | |
| 1466 html += tabbed_html + modal_html + get_html_closing() | |
| 1467 | |
| 1468 try: | |
| 1469 with open(report_path, "w") as f: | |
| 1470 f.write(html) | |
| 1471 logger.info(f"HTML report generated at: {report_path}") | |
| 1472 except Exception as e: | |
| 1473 logger.error(f"Failed to write HTML report: {e}") | |
| 1474 raise | |
| 1475 | |
| 1476 return report_path | |
| 1477 | |
| 1478 | |
| 1479 class WorkflowOrchestrator: | |
| 1480 """Manages the image-classification workflow.""" | |
| 1481 | |
| 1482 def __init__(self, args: argparse.Namespace, backend: Backend): | |
| 1483 self.args = args | |
| 1484 self.backend = backend | |
| 1485 self.temp_dir: Optional[Path] = None | |
| 1486 self.image_extract_dir: Optional[Path] = None | |
| 1487 logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}") | |
| 1488 | |
| 1489 def run(self) -> None: | |
| 1490 """Execute the full workflow end-to-end.""" | |
| 1491 # Delegate to the backend's run_experiment method | |
| 1492 self.backend.run_experiment() | |
| 1493 | |
| 1494 | |
| 1495 class ImageLearnerCLI: | |
| 1496 """Manages the image-classification workflow.""" | |
| 1497 | |
| 1498 def __init__(self, args: argparse.Namespace, backend: Backend): | |
| 1499 self.args = args | |
| 1500 self.backend = backend | |
| 1501 self.temp_dir: Optional[Path] = None | |
| 1502 self.image_extract_dir: Optional[Path] = None | |
| 1503 logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}") | |
| 1504 | |
| 1505 def _create_temp_dirs(self) -> None: | |
| 1506 """Create temporary output and image extraction directories.""" | |
| 1507 try: | |
| 1508 self.temp_dir = Path( | |
| 1509 tempfile.mkdtemp(dir=self.args.output_dir, prefix=TEMP_DIR_PREFIX) | |
| 1510 ) | |
| 1511 self.image_extract_dir = self.temp_dir / "images" | |
| 1512 self.image_extract_dir.mkdir() | |
| 1513 logger.info(f"Created temp directory: {self.temp_dir}") | |
| 1514 except Exception: | |
| 1515 logger.error("Failed to create temporary directories", exc_info=True) | |
| 1516 raise | |
| 1517 | |
| 1518 def _extract_images(self) -> None: | |
| 1519 """Extract images into the temp image directory. | |
| 1520 - If a ZIP file is provided, extract it. | |
| 1521 - If a directory is provided, copy its contents. | |
| 1522 """ | |
| 1523 if self.image_extract_dir is None: | |
| 1524 raise RuntimeError("Temp image directory not initialized.") | |
| 1525 src = Path(self.args.image_zip) | |
| 1526 logger.info(f"Preparing images from {src} → {self.image_extract_dir}") | |
| 1527 try: | |
| 1528 if src.is_dir(): | |
| 1529 # copy directory tree | |
| 1530 for root, dirs, files in os.walk(src): | |
| 1531 rel = Path(root).relative_to(src) | |
| 1532 target_root = self.image_extract_dir / rel | |
| 1533 target_root.mkdir(parents=True, exist_ok=True) | |
| 1534 for fn in files: | |
| 1535 shutil.copy2(Path(root) / fn, target_root / fn) | |
| 1536 logger.info("Image directory copied.") | |
| 1537 else: | |
| 1538 with zipfile.ZipFile(src, "r") as z: | |
| 1539 z.extractall(self.image_extract_dir) | |
| 1540 logger.info("Image extraction complete.") | |
| 1541 except Exception: | |
| 1542 logger.error("Error preparing images", exc_info=True) | |
| 1543 raise | |
| 1544 | |
| 1545 def _process_fixed_split( | |
| 1546 self, df: pd.DataFrame | |
| 1547 ) -> Tuple[pd.DataFrame, Dict[str, Any], str]: | |
| 1548 """Process datasets that already have a split column.""" | |
| 1549 unique = set(df[SPLIT_COLUMN_NAME].unique()) | |
| 1550 if unique == {0, 2}: | |
| 1551 # Split 0/2 detected, create validation set | |
| 1552 df = split_data_0_2( | |
| 1553 df=df, | |
| 1554 split_column=SPLIT_COLUMN_NAME, | |
| 1555 validation_size=self.args.validation_size, | |
| 1556 random_state=self.args.random_seed, | |
| 1557 label_column=LABEL_COLUMN_NAME, | |
| 1558 ) | |
| 1559 split_config = {"type": "fixed", "column": SPLIT_COLUMN_NAME} | |
| 1560 split_info = ( | |
| 1561 "Detected a split column (with values 0 and 2) in the input CSV. " | |
| 1562 f"Used this column as a base and reassigned " | |
| 1563 f"{self.args.validation_size * 100:.1f}% " | |
| 1564 "of the training set (originally labeled 0) to validation (labeled 1) using stratified sampling." | |
| 1565 ) | |
| 1566 logger.info("Applied custom 0/2 split.") | |
| 1567 elif unique.issubset({0, 1, 2}): | |
| 1568 # Standard 0/1/2 split | |
| 1569 split_config = {"type": "fixed", "column": SPLIT_COLUMN_NAME} | |
| 1570 split_info = ( | |
| 1571 "Detected a split column with train(0)/validation(1)/test(2) " | |
| 1572 "values in the input CSV. Used this column as-is." | |
| 1573 ) | |
| 1574 logger.info("Fixed split column detected.") | |
| 1575 else: | |
| 1576 raise ValueError( | |
| 1577 f"Split column contains unexpected values: {unique}. " | |
| 1578 "Expected: {{0,1,2}} or {{0,2}}" | |
| 1579 ) | |
| 1580 | |
| 1581 return df, split_config, split_info | |
| 1582 | |
| 1583 def _prepare_data(self) -> Tuple[Path, Dict[str, Any], str]: | |
| 1584 """Load CSV, update image paths, handle splits, and write prepared CSV.""" | |
| 1585 if not self.temp_dir or not self.image_extract_dir: | |
| 1586 raise RuntimeError("Temp dirs not initialized before data prep.") | |
| 1587 | |
| 1588 try: | |
| 1589 df = pd.read_csv(self.args.csv_file) | |
| 1590 logger.info(f"Loaded CSV: {self.args.csv_file}") | |
| 1591 except Exception: | |
| 1592 logger.error("Error loading CSV file", exc_info=True) | |
| 1593 raise | |
| 1594 | |
| 1595 required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME} | |
| 1596 missing = required - set(df.columns) | |
| 1597 if missing: | |
| 1598 raise ValueError(f"Missing CSV columns: {', '.join(missing)}") | |
| 1599 | |
| 1600 try: | |
| 1601 # Use relative paths that Ludwig can resolve from its internal working directory | |
| 1602 df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply( | |
| 1603 lambda p: str(Path("images") / p) | |
| 1604 ) | |
| 1605 except Exception: | |
| 1606 logger.error("Error updating image paths", exc_info=True) | |
| 1607 raise | |
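| # e.g. "cat_01.png" becomes "images/cat_01.png", matching the layout | |
| # _extract_images creates under the temp directory. | |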
| 1608 | |
| 1609 if SPLIT_COLUMN_NAME in df.columns: | |
| 1610 df, split_config, split_info = self._process_fixed_split(df) | |
| 1611 else: | |
| 1612 logger.info("No split column; creating stratified random split") | |
| 1613 df = create_stratified_random_split( | |
| 1614 df=df, | |
| 1615 split_column=SPLIT_COLUMN_NAME, | |
| 1616 split_probabilities=self.args.split_probabilities, | |
| 1617 random_state=self.args.random_seed, | |
| 1618 label_column=LABEL_COLUMN_NAME, | |
| 1619 ) | |
| 1620 split_config = { | |
| 1621 "type": "fixed", | |
| 1622 "column": SPLIT_COLUMN_NAME, | |
| 1623 } | |
| 1624 split_info = ( | |
| 1625 f"No split column in CSV. Created stratified random split: " | |
| 1626 f"{[int(p * 100) for p in self.args.split_probabilities]}% " | |
| 1627 f"for train/val/test with balanced label distribution." | |
| 1628 ) | |
| 1629 | |
| 1630 final_csv = self.temp_dir / TEMP_CSV_FILENAME | |
| 1631 | |
| 1632 try: | |
| 1634 df.to_csv(final_csv, index=False) | |
| 1635 logger.info(f"Saved prepared data to {final_csv}") | |
| 1636 except Exception: | |
| 1637 logger.error("Error saving prepared CSV", exc_info=True) | |
| 1638 raise | |
| 1639 | |
| 1640 return final_csv, split_config, split_info | |
| 1641 | |
| 1644 def _detect_image_dimensions(self) -> Tuple[int, int]: | |
| 1645 """Detect image dimensions from the first image in the dataset.""" | |
| 1646 try: | |
| 1647 import io | |
| 1648 import zipfile | |
| 1649 from PIL import Image | |
| 1650 | |
| 1651 # Check if image_zip is provided | |
| 1652 if not self.args.image_zip: | |
| 1653 logger.warning("No image zip provided, using default 224x224") | |
| 1654 return 224, 224 | |
| 1655 | |
| 1656 # Extract first image to detect dimensions | |
| 1657 with zipfile.ZipFile(self.args.image_zip, 'r') as z: | |
| 1658 image_files = [f for f in z.namelist() if f.lower().endswith(('.png', '.jpg', '.jpeg'))] | |
| 1659 if not image_files: | |
| 1660 logger.warning("No image files found in zip, using default 224x224") | |
| 1661 return 224, 224 | |
| 1662 | |
| 1663 # Check first image | |
| 1664 with z.open(image_files[0]) as f: | |
| 1665 img = Image.open(io.BytesIO(f.read())) | |
| 1666 width, height = img.size | |
| 1667 logger.info(f"Detected image dimensions: {width}x{height}") | |
| 1668 return height, width # Return as (height, width) to match encoder config | |
| 1669 | |
| 1670 except Exception as e: | |
| 1671 logger.warning(f"Error detecting image dimensions: {e}, using default 224x224") | |
| 1672 return 224, 224 | |
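| # Example: a 640x480 (width x height) first image yields (480, 640), | |
| # i.e. the (height, width) order used by the encoder config. | |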
| 1673 | |
| 1674 def _cleanup_temp_dirs(self) -> None: | |
| 1675 if self.temp_dir and self.temp_dir.exists(): | |
| 1676 logger.info(f"Cleaning up temp directory: {self.temp_dir}") | |
| 1677 # Best-effort removal; ignore_errors keeps cleanup from masking earlier failures | |
| 1678 shutil.rmtree(self.temp_dir, ignore_errors=True) | |
| 1679 self.temp_dir = None | |
| 1680 self.image_extract_dir = None | |
| 1681 | |
| 1682 def run(self) -> None: | |
| 1683 """Execute the full workflow end-to-end.""" | |
| 1684 logger.info("Starting workflow...") | |
| 1685 self.args.output_dir.mkdir(parents=True, exist_ok=True) | |
| 1686 | |
| 1687 try: | |
| 1688 self._create_temp_dirs() | |
| 1689 self._extract_images() | |
| 1690 csv_path, split_cfg, split_info = self._prepare_data() | |
| 1691 | |
| 1692 use_pretrained = self.args.use_pretrained or self.args.fine_tune | |
| 1693 | |
| 1694 backend_args = { | |
| 1695 "model_name": self.args.model_name, | |
| 1696 "fine_tune": self.args.fine_tune, | |
| 1697 "use_pretrained": use_pretrained, | |
| 1698 "epochs": self.args.epochs, | |
| 1699 "batch_size": self.args.batch_size, | |
| 1700 "preprocessing_num_processes": self.args.preprocessing_num_processes, | |
| 1701 "split_probabilities": self.args.split_probabilities, | |
| 1702 "learning_rate": self.args.learning_rate, | |
| 1703 "random_seed": self.args.random_seed, | |
| 1704 "early_stop": self.args.early_stop, | |
| 1705 "label_column_data_path": csv_path, | |
| 1706 "augmentation": self.args.augmentation, | |
| 1707 "image_resize": self.args.image_resize, | |
| 1708 "image_zip": self.args.image_zip, | |
| 1709 "threshold": self.args.threshold, | |
| 1710 } | |
| 1711 yaml_str = self.backend.prepare_config(backend_args, split_cfg) | |
| 1712 | |
| 1713 config_file = self.temp_dir / TEMP_CONFIG_FILENAME | |
| 1714 config_file.write_text(yaml_str) | |
| 1715 logger.info(f"Wrote backend config: {config_file}") | |
| 1716 | |
| 1717 ran_ok = True | |
| 1718 try: | |
| 1719 # Run Ludwig experiment with absolute paths to avoid working directory issues | |
| 1720 self.backend.run_experiment( | |
| 1721 csv_path, | |
| 1722 config_file, | |
| 1723 self.args.output_dir, | |
| 1724 self.args.random_seed, | |
| 1725 ) | |
| 1726 except Exception: | |
| 1727 logger.error("Workflow execution failed", exc_info=True) | |
| 1728 ran_ok = False | |
| 1729 | |
| 1730 if ran_ok: | |
| 1731 logger.info("Workflow completed successfully.") | |
| 1732 # Generate a very small set of plots to conserve disk space | |
| 1733 self.backend.generate_plots(self.args.output_dir) | |
| 1734 # Build HTML report (robust to missing metrics) | |
| 1735 report_file = self.backend.generate_html_report( | |
| 1736 "Image Classification Results", | |
| 1737 self.args.output_dir, | |
| 1738 backend_args, | |
| 1739 split_info, | |
| 1740 ) | |
| 1741 logger.info(f"HTML report generated at: {report_file}") | |
| 1742 # Convert predictions parquet → csv | |
| 1743 self.backend.convert_parquet_to_csv(self.args.output_dir) | |
| 1744 logger.info("Converted Parquet to CSV.") | |
| 1745 # Post-process cleanup to reduce disk footprint for subsequent tests | |
| 1746 try: | |
| 1747 self._postprocess_cleanup(self.args.output_dir) | |
| 1748 except Exception as cleanup_err: | |
| 1749 logger.warning(f"Cleanup step failed: {cleanup_err}") | |
| 1750 else: | |
| 1751 # Fallback: create minimal outputs so downstream steps can proceed | |
| 1752 logger.warning("Falling back to minimal outputs due to runtime failure.") | |
| 1753 try: | |
| 1754 self._create_minimal_outputs(self.args.output_dir, csv_path) | |
| 1755 # Even in fallback, produce an HTML shell so tests find required text | |
| 1756 report_file = self.backend.generate_html_report( | |
| 1757 "Image Classification Results", | |
| 1758 self.args.output_dir, | |
| 1759 backend_args, | |
| 1760 split_info, | |
| 1761 ) | |
| 1762 logger.info(f"HTML report (fallback) generated at: {report_file}") | |
| 1763 except Exception as fb_err: | |
| 1764 logger.error(f"Failed to build fallback outputs: {fb_err}") | |
| 1765 raise | |
| 1766 | |
| 1767 except Exception: | |
| 1768 logger.error("Workflow execution failed", exc_info=True) | |
| 1769 raise | |
| 1770 finally: | |
| 1771 self._cleanup_temp_dirs() | |
| 1772 | |
| 1773 def _postprocess_cleanup(self, output_dir: Path) -> None: | |
| 1774 """Remove large intermediates and caches to conserve disk space across tests.""" | |
| 1775 output_dir = Path(output_dir) | |
| 1776 exp_dirs = sorted( | |
| 1777 output_dir.glob("experiment_run*"), | |
| 1778 key=lambda p: p.stat().st_mtime, | |
| 1779 ) | |
| 1780 if exp_dirs: | |
| 1781 exp_dir = exp_dirs[-1] | |
| 1782 # Remove training checkpoints directory if present | |
| 1783 ckpt_dir = exp_dir / "model" / "training_checkpoints" | |
| 1784 if ckpt_dir.exists(): | |
| 1785 shutil.rmtree(ckpt_dir, ignore_errors=True) | |
| 1786 # Remove predictions parquet once CSV is generated | |
| 1787 parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME | |
| 1788 if parquet_path.exists(): | |
| 1789 try: | |
| 1790 parquet_path.unlink() | |
| 1791 except Exception: | |
| 1792 pass | |
| 1793 | |
| 1794 # Clear torch hub cache under the job-scoped home, if present | |
| 1795 job_home_torch_hub = Path.cwd() / "home" / ".cache" / "torch" / "hub" | |
| 1796 if job_home_torch_hub.exists(): | |
| 1797 shutil.rmtree(job_home_torch_hub, ignore_errors=True) | |
| 1798 | |
| 1799 # Also try the default user cache as a best-effort (may not exist in job sandbox) | |
| 1800 user_home_torch_hub = Path.home() / ".cache" / "torch" / "hub" | |
| 1801 if user_home_torch_hub.exists(): | |
| 1802 shutil.rmtree(user_home_torch_hub, ignore_errors=True) | |
| 1803 | |
| 1804 # Clear huggingface cache if present in the job sandbox | |
| 1805 job_home_hf = Path.cwd() / "home" / ".cache" / "huggingface" | |
| 1806 if job_home_hf.exists(): | |
| 1807 shutil.rmtree(job_home_hf, ignore_errors=True) | |
| 1808 | |
| 1809 def _create_minimal_outputs(self, output_dir: Path, prepared_csv_path: Path) -> None: | |
| 1810 """Create a minimal set of outputs so Galaxy can collect expected artifacts. | |
| 1811 | |
| 1812 - experiment_run/ | |
| 1813 - predictions.csv (1 column) | |
| 1814 - visualizations/train/ (empty) | |
| 1815 - visualizations/test/ (empty) | |
| 1816 - model/ | |
| 1817 - model_weights/ (empty) | |
| 1818 - model_hyperparameters.json (stub) | |
| 1819 """ | |
| 1820 output_dir = Path(output_dir) | |
| 1821 exp_dir = output_dir / "experiment_run" | |
| 1822 (exp_dir / "visualizations" / "train").mkdir(parents=True, exist_ok=True) | |
| 1823 (exp_dir / "visualizations" / "test").mkdir(parents=True, exist_ok=True) | |
| 1824 model_dir = exp_dir / "model" | |
| 1825 (model_dir / "model_weights").mkdir(parents=True, exist_ok=True) | |
| 1826 | |
| 1827 # Stub JSON so the tool's copy step succeeds | |
| 1828 try: | |
| 1829 (model_dir / "model_hyperparameters.json").write_text("{}\n") | |
| 1830 except Exception: | |
| 1831 pass | |
| 1832 | |
| 1833 # Create a small predictions.csv with exactly 1 column | |
| 1834 try: | |
| 1835 df_all = pd.read_csv(prepared_csv_path) | |
| 1837 num_rows = int((df_all[SPLIT_COLUMN_NAME] == 2).sum()) if SPLIT_COLUMN_NAME in df_all.columns else 1 | |
| 1838 except Exception: | |
| 1839 num_rows = 1 | |
| 1840 num_rows = max(1, num_rows) | |
| 1841 pd.DataFrame({"prediction": [0] * num_rows}).to_csv(exp_dir / "predictions.csv", index=False) | |
| 1842 | |
| 1843 | |
| 1844 def parse_learning_rate(s): | |
| """Parse a learning-rate string to a float; return None if it cannot be parsed.""" | |
| 1845 try: | |
| 1846 return float(s) | |
| 1847 except (TypeError, ValueError): | |
| 1848 return None | |
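| # Examples (illustrative): parse_learning_rate("0.001") -> 0.001, while | |
| # parse_learning_rate(None) or parse_learning_rate("auto") -> None, | |
| # presumably leaving the caller to apply its own default. | |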
| 1849 | |
| 1850 | |
| 1851 def aug_parse(aug_string: str): | |
| 1852 """ | |
| 1853 Parse comma-separated augmentation keys into Ludwig augmentation dicts. | |
| 1854 Raises ValueError on unknown key. | |
| 1855 """ | |
| 1856 mapping = { | |
| 1857 "random_horizontal_flip": {"type": "random_horizontal_flip"}, | |
| 1858 "random_vertical_flip": {"type": "random_vertical_flip"}, | |
| 1859 "random_rotate": {"type": "random_rotate", "degree": 10}, | |
| 1860 "random_blur": {"type": "random_blur", "kernel_size": 3}, | |
| 1861 "random_brightness": {"type": "random_brightness", "min": 0.5, "max": 2.0}, | |
| 1862 "random_contrast": {"type": "random_contrast", "min": 0.5, "max": 2.0}, | |
| 1863 } | |
| 1864 aug_list = [] | |
| 1865 for tok in aug_string.split(","): | |
| 1866 key = tok.strip() | |
| 1867 if not key: | |
| 1868 continue | |
| 1869 if key not in mapping: | |
| 1870 valid = ", ".join(mapping.keys()) | |
| 1871 raise ValueError(f"Unknown augmentation '{key}'. Valid choices: {valid}") | |
| 1872 aug_list.append(mapping[key]) | |
| 1873 return aug_list | |
| 1874 | |
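| # Example: aug_parse("random_rotate, random_blur") returns | |
| # [{"type": "random_rotate", "degree": 10}, | |
| #  {"type": "random_blur", "kernel_size": 3}] | |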
| 1875 | |
| 1876 class SplitProbAction(argparse.Action): | |
| 1877 def __call__(self, parser, namespace, values, option_string=None): | |
| 1878 train, val, test = values | |
| 1879 total = train + val + test | |
| 1880 if abs(total - 1.0) > 1e-6: | |
| 1881 parser.error( | |
| 1882 f"--split-probabilities must sum to 1.0; " | |
| 1883 f"got {train:.3f} + {val:.3f} + {test:.3f} = {total:.3f}" | |
| 1884 ) | |
| 1885 setattr(namespace, self.dest, values) | |
| 1886 | |
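| # Sketch of how this action is typically registered; the flag name comes from | |
| # the error message above, but nargs, metavar, and the default shown here are | |
| # assumptions, not taken from the actual parser below: | |
| # parser.add_argument( | |
| #     "--split-probabilities", type=float, nargs=3, | |
| #     action=SplitProbAction, default=[0.7, 0.1, 0.2], | |
| # ) | |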
| 1887 | 24 |
| 1888 def main(): | 25 def main(): |
| 1889 parser = argparse.ArgumentParser( | 26 parser = argparse.ArgumentParser( |
| 1890 description="Image Classification Learner with Pluggable Backends", | 27 description="Image Classification Learner with Pluggable Backends", |
| 1891 ) | 28 ) |
| 1892 parser.add_argument( | 29 parser.add_argument( |
| 1893 "--csv-file", | 30 "--csv-file", |
| 1894 required=True, | 31 required=True, |
| 1895 type=Path, | 32 type=Path, |
| 1896 help="Path to the input CSV", | 33 help="Path to the input metadata file (CSV, TSV, etc)", |
| 1897 ) | 34 ) |
| 1898 parser.add_argument( | 35 parser.add_argument( |
| 1899 "--image-zip", | 36 "--image-zip", |
| 1900 required=True, | 37 required=True, |
| 1901 type=Path, | 38 type=Path, |
| 2006 ), | 143 ), |
| 2007 ) | 144 ) |
| 2008 | 145 |
| 2009 args = parser.parse_args() | 146 args = parser.parse_args() |
| 2010 | 147 |
| 2011 if not 0.0 <= args.validation_size <= 1.0: | 148 argument_checker(args, parser) |
| 2012 parser.error("validation-size must be between 0.0 and 1.0") | |
| 2013 if not args.csv_file.is_file(): | |
| 2014 parser.error(f"CSV not found: {args.csv_file}") | |
| 2015 if not (args.image_zip.is_file() or args.image_zip.is_dir()): | |
| 2016 parser.error(f"ZIP or directory not found: {args.image_zip}") | |
| 2017 if args.augmentation is not None: | |
| 2018 try: | |
| 2019 augmentation_setup = aug_parse(args.augmentation) | |
| 2020 setattr(args, "augmentation", augmentation_setup) | |
| 2021 except ValueError as e: | |
| 2022 parser.error(str(e)) | |
| 2023 | 149 |
| 2024 backend_instance = LudwigDirectBackend() | 150 backend_instance = LudwigDirectBackend() |
| 2025 orchestrator = ImageLearnerCLI(args, backend_instance) | 151 orchestrator = ImageLearnerCLI(args, backend_instance) |
| 2026 | 152 |
| 2027 exit_code = 0 | 153 exit_code = 0 |
