Mercurial > repos > goeckslab > image_learner
comparison image_learner_cli.py @ 0:54b871dfc51e draft
planemo upload for repository https://github.com/goeckslab/gleam.git commit b7411ff35b6228ccdfd36cd4ebd946c03ac7f7e9
| author | goeckslab |
|---|---|
| date | Tue, 03 Jun 2025 21:22:11 +0000 |
| parents | |
| children | 39202fe5cf97 |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:54b871dfc51e |
|---|---|
| 1 #!/usr/bin/env python3 | |
| 2 import argparse | |
| 3 import json | |
| 4 import logging | |
| 5 import os | |
| 6 import shutil | |
| 7 import sys | |
| 8 import tempfile | |
| 9 import zipfile | |
| 10 from pathlib import Path | |
| 11 from typing import Any, Dict, Optional, Protocol, Tuple | |
| 12 | |
| 13 import pandas as pd | |
| 14 import yaml | |
| 15 from ludwig.globals import ( | |
| 16 DESCRIPTION_FILE_NAME, | |
| 17 PREDICTIONS_PARQUET_FILE_NAME, | |
| 18 TEST_STATISTICS_FILE_NAME, | |
| 19 TRAIN_SET_METADATA_FILE_NAME, | |
| 20 ) | |
| 21 from ludwig.utils.data_utils import get_split_path | |
| 22 from ludwig.visualize import get_visualizations_registry | |
| 23 from sklearn.model_selection import train_test_split | |
| 24 from utils import encode_image_to_base64, get_html_closing, get_html_template | |
| 25 | |
# --- Constants ---
# Dataset column names. The image-path and label names are used as the Ludwig
# input/output feature names in LudwigDirectBackend.prepare_config.
SPLIT_COLUMN_NAME = 'split'
LABEL_COLUMN_NAME = 'label'
IMAGE_PATH_COLUMN_NAME = 'image_path'
# Presumably the train/validation/test fractions used when no split is
# supplied -- not referenced in this part of the file; confirm at the caller.
DEFAULT_SPLIT_PROBABILITIES = [0.7, 0.1, 0.2]
# File and directory names for temporary working artifacts (judging by the
# names; their use is not visible in this part of the file).
TEMP_CSV_FILENAME = "processed_data_for_ludwig.csv"
TEMP_CONFIG_FILENAME = "ludwig_config.yaml"
TEMP_DIR_PREFIX = "ludwig_api_work_"
# Maps a user-facing model name to the Ludwig image-encoder settings.
# prepare_config merges dict entries with "use_pretrained"/"trainable";
# a plain string entry is used as the encoder "type" only.
# NOTE(review): 'resnext152_32x8d' and 'wide_resnet103_2' do not match any
# torchvision model name I am aware of -- confirm Ludwig accepts these variants.
MODEL_ENCODER_TEMPLATES: Dict[str, Any] = {
    'stacked_cnn': 'stacked_cnn',
    'resnet18': {'type': 'resnet', 'model_variant': 18},
    'resnet34': {'type': 'resnet', 'model_variant': 34},
    'resnet50': {'type': 'resnet', 'model_variant': 50},
    'resnet101': {'type': 'resnet', 'model_variant': 101},
    'resnet152': {'type': 'resnet', 'model_variant': 152},
    'resnext50_32x4d': {'type': 'resnext', 'model_variant': '50_32x4d'},
    'resnext101_32x8d': {'type': 'resnext', 'model_variant': '101_32x8d'},
    'resnext101_64x4d': {'type': 'resnext', 'model_variant': '101_64x4d'},
    'resnext152_32x8d': {'type': 'resnext', 'model_variant': '152_32x8d'},
    'wide_resnet50_2': {'type': 'wide_resnet', 'model_variant': '50_2'},
    'wide_resnet101_2': {'type': 'wide_resnet', 'model_variant': '101_2'},
    'wide_resnet103_2': {'type': 'wide_resnet', 'model_variant': '103_2'},
    'efficientnet_b0': {'type': 'efficientnet', 'model_variant': 'b0'},
    'efficientnet_b1': {'type': 'efficientnet', 'model_variant': 'b1'},
    'efficientnet_b2': {'type': 'efficientnet', 'model_variant': 'b2'},
    'efficientnet_b3': {'type': 'efficientnet', 'model_variant': 'b3'},
    'efficientnet_b4': {'type': 'efficientnet', 'model_variant': 'b4'},
    'efficientnet_b5': {'type': 'efficientnet', 'model_variant': 'b5'},
    'efficientnet_b6': {'type': 'efficientnet', 'model_variant': 'b6'},
    'efficientnet_b7': {'type': 'efficientnet', 'model_variant': 'b7'},
    'efficientnet_v2_s': {'type': 'efficientnet', 'model_variant': 'v2_s'},
    'efficientnet_v2_m': {'type': 'efficientnet', 'model_variant': 'v2_m'},
    'efficientnet_v2_l': {'type': 'efficientnet', 'model_variant': 'v2_l'},
    'regnet_y_400mf': {'type': 'regnet', 'model_variant': 'y_400mf'},
    'regnet_y_800mf': {'type': 'regnet', 'model_variant': 'y_800mf'},
    'regnet_y_1_6gf': {'type': 'regnet', 'model_variant': 'y_1_6gf'},
    'regnet_y_3_2gf': {'type': 'regnet', 'model_variant': 'y_3_2gf'},
    'regnet_y_8gf': {'type': 'regnet', 'model_variant': 'y_8gf'},
    'regnet_y_16gf': {'type': 'regnet', 'model_variant': 'y_16gf'},
    'regnet_y_32gf': {'type': 'regnet', 'model_variant': 'y_32gf'},
    'regnet_y_128gf': {'type': 'regnet', 'model_variant': 'y_128gf'},
    'regnet_x_400mf': {'type': 'regnet', 'model_variant': 'x_400mf'},
    'regnet_x_800mf': {'type': 'regnet', 'model_variant': 'x_800mf'},
    'regnet_x_1_6gf': {'type': 'regnet', 'model_variant': 'x_1_6gf'},
    'regnet_x_3_2gf': {'type': 'regnet', 'model_variant': 'x_3_2gf'},
    'regnet_x_8gf': {'type': 'regnet', 'model_variant': 'x_8gf'},
    'regnet_x_16gf': {'type': 'regnet', 'model_variant': 'x_16gf'},
    'regnet_x_32gf': {'type': 'regnet', 'model_variant': 'x_32gf'},
    'vgg11': {'type': 'vgg', 'model_variant': 11},
    'vgg11_bn': {'type': 'vgg', 'model_variant': '11_bn'},
    'vgg13': {'type': 'vgg', 'model_variant': 13},
    'vgg13_bn': {'type': 'vgg', 'model_variant': '13_bn'},
    'vgg16': {'type': 'vgg', 'model_variant': 16},
    'vgg16_bn': {'type': 'vgg', 'model_variant': '16_bn'},
    'vgg19': {'type': 'vgg', 'model_variant': 19},
    'vgg19_bn': {'type': 'vgg', 'model_variant': '19_bn'},
    'shufflenet_v2_x0_5': {'type': 'shufflenet_v2', 'model_variant': 'x0_5'},
    'shufflenet_v2_x1_0': {'type': 'shufflenet_v2', 'model_variant': 'x1_0'},
    'shufflenet_v2_x1_5': {'type': 'shufflenet_v2', 'model_variant': 'x1_5'},
    'shufflenet_v2_x2_0': {'type': 'shufflenet_v2', 'model_variant': 'x2_0'},
    'squeezenet1_0': {'type': 'squeezenet', 'model_variant': '1_0'},
    'squeezenet1_1': {'type': 'squeezenet', 'model_variant': '1_1'},
    'swin_t': {'type': 'swin_transformer', 'model_variant': 't'},
    'swin_s': {'type': 'swin_transformer', 'model_variant': 's'},
    'swin_b': {'type': 'swin_transformer', 'model_variant': 'b'},
    'swin_v2_t': {'type': 'swin_transformer', 'model_variant': 'v2_t'},
    'swin_v2_s': {'type': 'swin_transformer', 'model_variant': 'v2_s'},
    'swin_v2_b': {'type': 'swin_transformer', 'model_variant': 'v2_b'},
    'vit_b_16': {'type': 'vision_transformer', 'model_variant': 'b_16'},
    'vit_b_32': {'type': 'vision_transformer', 'model_variant': 'b_32'},
    'vit_l_16': {'type': 'vision_transformer', 'model_variant': 'l_16'},
    'vit_l_32': {'type': 'vision_transformer', 'model_variant': 'l_32'},
    'vit_h_14': {'type': 'vision_transformer', 'model_variant': 'h_14'},
    'convnext_tiny': {'type': 'convnext', 'model_variant': 'tiny'},
    'convnext_small': {'type': 'convnext', 'model_variant': 'small'},
    'convnext_base': {'type': 'convnext', 'model_variant': 'base'},
    'convnext_large': {'type': 'convnext', 'model_variant': 'large'},
    'maxvit_t': {'type': 'maxvit', 'model_variant': 't'},
    'alexnet': {'type': 'alexnet'},
    'googlenet': {'type': 'googlenet'},
    'inception_v3': {'type': 'inception_v3'},
    'mobilenet_v2': {'type': 'mobilenet_v2'},
    'mobilenet_v3_large': {'type': 'mobilenet_v3_large'},
    'mobilenet_v3_small': {'type': 'mobilenet_v3_small'},
}
| 111 | |
# --- Logging Setup ---
# Runs at import time: configures the root logger at INFO level with
# timestamped "time level name: message" records.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(name)s: %(message)s'
)
# Module-level logger shared by the functions and classes in this file.
logger = logging.getLogger("ImageLearner")
| 118 | |
| 119 | |
def format_config_table_html(
        config: dict,
        split_info: Optional[str] = None,
        training_progress: Optional[dict] = None) -> str:
    """Render the "Training Setup" section of the HTML report.

    Args:
        config: Run settings. Only the keys listed in ``display_keys`` are
            shown. A missing key is displayed as ``N/A``. A key whose value
            is ``None`` is omitted, except ``batch_size`` and
            ``learning_rate``, which are reported as auto-selected.
        split_info: Optional text for an extra "Data Split" row.
        training_progress: Optional mapping holding the values resolved
            during training (``batch_size``, ``learning_rate``, ``epoch``).

    Returns:
        An HTML fragment with a heading, the parameter table and a footer.
    """
    display_keys = [
        "model_name",
        "epochs",
        "batch_size",
        "fine_tune",
        "use_pretrained",
        "learning_rate",
        "random_seed",
        "early_stop",
    ]

    rows = []

    for key in display_keys:
        val = config.get(key, "N/A")

        if key == "batch_size":
            if val is None:
                if training_progress:
                    resolved_val = training_progress.get("batch_size")
                    val = (
                        "Auto-selected batch size by Ludwig:<br>"
                        f"<span style='font-size: 0.85em;'>{resolved_val}</span><br>"
                    )
                else:
                    val = "auto"
            else:
                try:
                    val = int(val)
                except (TypeError, ValueError):
                    # Placeholder such as "N/A": display it unchanged
                    # instead of failing the whole table.
                    pass
        elif key == "learning_rate":
            if val is None or val == "auto":
                if training_progress:
                    resolved_val = training_progress.get("learning_rate")
                    val = (
                        "Auto-selected learning rate by Ludwig:<br>"
                        f"<span style='font-size: 0.85em;'>{resolved_val if resolved_val else val}</span><br>"
                        "<span style='font-size: 0.85em;'>"
                        "Based on model architecture and training setup (e.g., fine-tuning).<br>"
                        "See <a href='https://ludwig.ai/latest/configuration/trainer/#trainer-parameters' "
                        "target='_blank'>Ludwig Trainer Parameters</a> for details."
                        "</span>"
                    )
                else:
                    val = (
                        "Auto-selected by Ludwig<br>"
                        "<span style='font-size: 0.85em;'>"
                        "Automatically tuned based on architecture and dataset.<br>"
                        "See <a href='https://ludwig.ai/latest/configuration/trainer/#trainer-parameters' "
                        "target='_blank'>Ludwig Trainer Parameters</a> for details."
                        "</span>"
                    )
            else:
                try:
                    val = f"{float(val):.6f}"
                except (TypeError, ValueError):
                    # Not a number (e.g. the "N/A" placeholder): show as-is.
                    pass
        elif key == "epochs":
            stopped_at = training_progress.get("epoch") if training_progress else None
            # Compare only real numbers; either side may be a placeholder
            # or None when the information is unavailable.
            if (
                isinstance(val, (int, float))
                and isinstance(stopped_at, (int, float))
                and val > stopped_at
            ):
                val = (
                    "Because of early stopping: the training "
                    f"stopped at epoch {stopped_at}"
                )

        if val is None:
            continue
        rows.append(
            f"<tr>"
            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>"
            f"{key.replace('_', ' ').title()}</td>"
            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{val}</td>"
            f"</tr>"
        )

    if split_info:
        rows.append(
            f"<tr>"
            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Data Split</td>"
            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{split_info}</td>"
            f"</tr>"
        )

    return (
        "<h2 style='text-align: center;'>Training Setup</h2>"
        "<div style='display: flex; justify-content: center;'>"
        "<table style='border-collapse: collapse; width: 60%; table-layout: auto;'>"
        "<thead><tr>"
        "<th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>Parameter</th>"
        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Value</th>"
        "</tr></thead><tbody>" + "".join(rows) + "</tbody></table></div><br>"
        "<p style='text-align: center; font-size: 0.9em;'>"
        "Model trained using Ludwig.<br>"
        "If you want to learn more about Ludwig default settings, "
        "please check their <a href='https://ludwig.ai' target='_blank'>website(ludwig.ai)</a>."
        "</p><hr>"
    )
| 215 | |
| 216 | |
def format_stats_table_html(training_stats: dict, test_stats: dict) -> str:
    """Build the "Model Performance Summary" table for the HTML report.

    Metrics are read from the ``label`` entry of the training, validation
    and test statistics. A metric is listed only when all three splits
    provide a value; for list-valued metrics the last element is shown.
    Returns a short placeholder paragraph when no metric qualifies.
    """
    label_train = training_stats.get("training", {}).get("label", {})
    label_val = training_stats.get("validation", {}).get("label", {})
    label_test = test_stats.get("label", {})

    def _latest(metrics, name):
        # Final entry of a non-empty list, a bare number, otherwise None.
        value = metrics.get(name)
        if isinstance(value, list) and value:
            return value[-1]
        if isinstance(value, (int, float)):
            return value
        return None

    td_left = "<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>"
    td_center = "<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>"

    body_rows = []
    for name in sorted(set(label_train) | set(label_val) | set(label_test)):
        triple = [_latest(source, name) for source in (label_train, label_val, label_test)]
        if any(value is None for value in triple):
            continue
        value_cells = "".join(f"{td_center}{value:.4f}</td>" for value in triple)
        body_rows.append(f"<tr>{td_left}{name}</td>{value_cells}</tr>")

    if not body_rows:
        return "<p><em>No metric values found.</em></p>"

    table_head = (
        "<h2 style='text-align: center;'>Model Performance Summary</h2>"
        "<div style='display: flex; justify-content: center;'>"
        "<table style='border-collapse: collapse; width: 80%; table-layout: fixed;'>"
        "<colgroup>"
        "<col style='width: 40%;'>"
        "<col style='width: 20%;'>"
        "<col style='width: 20%;'>"
        "<col style='width: 20%;'>"
        "</colgroup>"
        "<thead><tr>"
        "<th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>Metric</th>"
        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Train</th>"
        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Validation</th>"
        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Test</th>"
        "</tr></thead><tbody>"
    )
    table_tail = "</tbody></table></div><br>"
    return table_head + "".join(body_rows) + table_tail
| 268 | |
| 269 | |
def build_tabbed_html(
        metrics_html: str,
        train_viz_html: str,
        test_viz_html: str) -> str:
    """Wrap three HTML fragments in a three-tab layout.

    The returned markup bundles the tab CSS, the tab headers, one content
    pane per fragment and a small ``showTab`` script. The "Config & Metrics"
    pane is marked active initially; the other two panes are hidden until
    their tab is clicked. The fragments are inserted verbatim (no escaping).

    Args:
        metrics_html: Content of the "Config & Metrics" pane.
        train_viz_html: Content of the "Train/Validation Plots" pane.
        test_viz_html: Content of the "Test Plots" pane.

    Returns:
        The combined HTML/CSS/JavaScript fragment.
    """
    return f"""
<style>
.tabs {{
  display: flex;
  border-bottom: 2px solid #ccc;
  margin-bottom: 1rem;
}}
.tab {{
  padding: 10px 20px;
  cursor: pointer;
  border: 1px solid #ccc;
  border-bottom: none;
  background: #f9f9f9;
  margin-right: 5px;
  border-top-left-radius: 8px;
  border-top-right-radius: 8px;
}}
.tab.active {{
  background: white;
  font-weight: bold;
}}
.tab-content {{
  display: none;
  padding: 20px;
  border: 1px solid #ccc;
  border-top: none;
}}
.tab-content.active {{
  display: block;
}}
</style>

<div class="tabs">
  <div class="tab active" onclick="showTab('metrics')"> Config & Metrics</div>
  <div class="tab" onclick="showTab('trainval')"> Train/Validation Plots</div>
  <div class="tab" onclick="showTab('test')"> Test Plots</div>
</div>

<div id="metrics" class="tab-content active">
  {metrics_html}
</div>
<div id="trainval" class="tab-content">
  {train_viz_html}
</div>
<div id="test" class="tab-content">
  {test_viz_html}
</div>

<script>
function showTab(id) {{
  document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active'));
  document.querySelectorAll('.tab').forEach(el => el.classList.remove('active'));
  document.getElementById(id).classList.add('active');
  document.querySelector(`.tab[onclick*="${{id}}"]`).classList.add('active');
}}
</script>
"""
| 331 | |
| 332 | |
def split_data_0_2(
    df: pd.DataFrame,
    split_column: str,
    validation_size: float = 0.15,
    random_state: int = 42,
    label_column: Optional[str] = None,
) -> pd.DataFrame:
    """Carve a validation set (split=1) out of the training rows (split=0).

    Intended for a DataFrame whose ``split_column`` only contains {0, 2}.
    Rows with any other split value are left untouched.

    Args:
        df: Input data; it is not modified.
        split_column: Name of the column holding the split codes.
        validation_size: Fraction of the split=0 rows to relabel as 1.
            Values <= 0 keep every row as train; values >= 1 move every
            train row to validation.
        random_state: Seed passed to ``train_test_split``.
        label_column: Optional column to stratify on. Stratification is
            used only when there are at least two classes and the rarest
            class is expected to contribute at least one validation row.

    Returns:
        A new DataFrame with an integer ``split_column``.

    Raises:
        ValueError: If ``split_column`` holds missing or non-numeric values,
            or if the (non-stratified) split itself cannot be performed.
    """
    # Work on a copy
    out = df.copy()

    # Coerce to numbers, but fail with a message naming the column rather
    # than letting the integer cast trip over the NaNs produced by coercion.
    numeric_split = pd.to_numeric(out[split_column], errors="coerce")
    if numeric_split.isna().any():
        raise ValueError(
            f"Column '{split_column}' contains missing or non-numeric values; "
            "expected integer split codes."
        )
    out[split_column] = numeric_split.astype(int)

    idx_train = out.index[out[split_column] == 0].tolist()

    if not idx_train:
        logger.info("No rows with split=0; nothing to do.")
        return out

    # Trivial cases first, so no stratification work (or warning) happens
    # when no real split is going to be made.
    if validation_size <= 0:
        logger.info("validation_size <= 0; keeping all as train.")
        return out
    if validation_size >= 1:
        logger.info("validation_size >= 1; moving all train → validation.")
        out.loc[idx_train, split_column] = 1
        return out

    # Determine stratify array if possible
    stratify_arr = None
    if label_column and label_column in out.columns:
        train_labels = out.loc[idx_train, label_column]
        label_counts = train_labels.value_counts()
        # Only stratify if at least two classes and enough samples
        if label_counts.size > 1 and (label_counts.min() * validation_size) >= 1:
            stratify_arr = train_labels
        else:
            logger.warning("Cannot stratify (too few labels); splitting without stratify.")

    # Do the split
    try:
        _, val_idx = train_test_split(
            idx_train,
            test_size=validation_size,
            random_state=random_state,
            stratify=stratify_arr,
        )
    except ValueError as e:
        if stratify_arr is None:
            # Nothing to fall back to: repeating the same call would fail
            # the same way, so surface the original error.
            raise
        logger.warning(f"Stratified split failed ({e}); retrying without stratify.")
        _, val_idx = train_test_split(
            idx_train,
            test_size=validation_size,
            random_state=random_state,
            stratify=None,
        )

    # Training rows already carry 0 and test rows stay at 2; only the
    # validation rows need relabelling.
    out.loc[val_idx, split_column] = 1

    # Cast back to a clean integer type
    out[split_column] = out[split_column].astype(int)
    return out
| 400 | |
| 401 | |
class Backend(Protocol):
    """Interface for a machine learning backend.

    The method signatures mirror LudwigDirectBackend, the implementation
    defined in this file.
    """

    def prepare_config(
        self,
        config_params: Dict[str, Any],
        split_config: Dict[str, Any]
    ) -> str:
        """Return the serialized backend configuration as a string."""
        ...

    def run_experiment(
        self,
        dataset_path: Path,
        config_path: Path,
        output_dir: Path,
        random_seed: int,
    ) -> None:
        """Run the experiment on ``dataset_path`` using the config at ``config_path``."""
        ...

    def generate_plots(
        self,
        output_dir: Path
    ) -> None:
        """Produce the visualizations for the results under ``output_dir``."""
        ...

    def generate_html_report(
        self,
        title: str,
        output_dir: str,
        config: dict,
        split_info: str
    ) -> Path:
        """Write the HTML report and return its path.

        ``config`` and ``split_info`` are declared here because the concrete
        backend in this file requires them as positional parameters.
        """
        ...
| 432 | |
| 433 | |
class LudwigDirectBackend:
    """
    Backend for running Ludwig experiments directly via the internal experiment_cli function.
    """

    @staticmethod
    def _latest_experiment_dir(output_dir) -> Optional[Path]:
        """Return the most recently modified ``experiment_run*`` entry under
        ``output_dir``, or None when there is none."""
        exp_dirs = sorted(
            Path(output_dir).glob("experiment_run*"),
            key=lambda p: p.stat().st_mtime
        )
        return exp_dirs[-1] if exp_dirs else None

    def prepare_config(
        self,
        config_params: Dict[str, Any],
        split_config: Dict[str, Any],
    ) -> str:
        """
        Build and serialize the Ludwig YAML configuration.

        Args:
            config_params: Run settings. Keys read: ``model_name``,
                ``use_pretrained``, ``fine_tune``, ``epochs``, ``batch_size``,
                ``preprocessing_num_processes``, ``early_stop`` and
                ``learning_rate``.
            split_config: Placed unchanged under ``preprocessing.split``.

        Returns:
            The configuration as a YAML string.
        """
        logger.info("LudwigDirectBackend: Preparing YAML configuration.")

        model_name = config_params.get("model_name", "resnet18")
        use_pretrained = config_params.get("use_pretrained", False)
        fine_tune = config_params.get("fine_tune", False)
        epochs = config_params.get("epochs", 10)
        batch_size = config_params.get("batch_size")
        num_processes = config_params.get("preprocessing_num_processes", 1)
        early_stop = config_params.get("early_stop", None)
        learning_rate = config_params.get("learning_rate")
        learning_rate = "auto" if learning_rate is None else float(learning_rate)
        # A model without pretrained weights is always trained; a pretrained
        # one is only updated when fine-tuning was requested.
        trainable = fine_tune or (not use_pretrained)

        # Encoder setup: unknown model names are passed through as the
        # encoder type.
        raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name)
        if isinstance(raw_encoder, dict):
            encoder_config = {
                **raw_encoder,
                "use_pretrained": use_pretrained,
                "trainable": trainable,
            }
        else:
            encoder_config = {"type": raw_encoder}

        # Trainer settings: an unset (or zero) batch size is left to Ludwig.
        batch_size_cfg = batch_size or "auto"

        conf: Dict[str, Any] = {
            "model_type": "ecd",
            "input_features": [
                {
                    "name": IMAGE_PATH_COLUMN_NAME,
                    "type": "image",
                    "encoder": encoder_config,
                }
            ],
            "output_features": [
                {"name": LABEL_COLUMN_NAME, "type": "category"}
            ],
            "combiner": {"type": "concat"},
            "trainer": {
                "epochs": epochs,
                "early_stop": early_stop,
                "batch_size": batch_size_cfg,
                "learning_rate": learning_rate,
            },
            "preprocessing": {
                "split": split_config,
                "num_processes": num_processes,
                "in_memory": False,
            },
        }

        logger.debug("LudwigDirectBackend: Config dict built.")
        try:
            yaml_str = yaml.dump(conf, sort_keys=False, indent=2)
            logger.info("LudwigDirectBackend: YAML config generated.")
            return yaml_str
        except Exception:
            logger.error("LudwigDirectBackend: Failed to serialize YAML.", exc_info=True)
            raise

    def run_experiment(
        self,
        dataset_path: Path,
        config_path: Path,
        output_dir: Path,
        random_seed: int = 42,
    ) -> None:
        """
        Invoke Ludwig's internal experiment_cli function to run the experiment.

        Raises:
            RuntimeError: If Ludwig cannot be imported or rejects the
                arguments (``TypeError``). Any other exception raised by the
                experiment is logged and re-raised unchanged.
        """
        logger.info("LudwigDirectBackend: Starting experiment execution.")

        # Imported lazily so an import failure is reported at run time.
        try:
            from ludwig.experiment import experiment_cli
        except ImportError as e:
            logger.error(
                "LudwigDirectBackend: Could not import experiment_cli.",
                exc_info=True
            )
            raise RuntimeError("Ludwig import failed.") from e

        output_dir.mkdir(parents=True, exist_ok=True)

        try:
            experiment_cli(
                dataset=str(dataset_path),
                config=str(config_path),
                output_directory=str(output_dir),
                random_seed=random_seed,
            )
            logger.info(f"LudwigDirectBackend: Experiment completed. Results in {output_dir}")
        except TypeError as e:
            logger.error(
                "LudwigDirectBackend: Argument mismatch in experiment_cli call.",
                exc_info=True
            )
            raise RuntimeError("Ludwig argument error.") from e
        except Exception:
            logger.error(
                "LudwigDirectBackend: Experiment execution error.",
                exc_info=True
            )
            raise

    def get_training_process(self, output_dir) -> Optional[Dict[str, Any]]:
        """
        Read the resolved training values of the most recent Ludwig run.

        Returns:
            A dict with ``learning_rate``, ``batch_size`` and ``epoch`` taken
            from ``model/training_progress.json`` (a value is None when the
            file lacks that key); None when no run directory or progress
            file exists; an empty dict when the file cannot be read.
        """
        exp_dir = self._latest_experiment_dir(output_dir)
        if exp_dir is None:
            logger.warning(f"No experiment run directories found in {output_dir}")
            return None

        progress_file = exp_dir / "model" / "training_progress.json"
        if not progress_file.exists():
            logger.warning(f"No training_progress.json found in {progress_file}")
            return None

        try:
            with progress_file.open("r", encoding="utf-8") as f:
                data = json.load(f)
            return {
                "learning_rate": data.get("learning_rate"),
                "batch_size": data.get("batch_size"),
                "epoch": data.get("epoch"),
            }
        except Exception as e:
            # Best effort: the report can be built without these values.
            logger.warning(f"Failed to read training progress info: {e}")
            return {}

    def convert_parquet_to_csv(self, output_dir: Path):
        """Convert the predictions Parquet file of the latest run to
        ``predictions.csv`` in the same directory. Failures are logged,
        not raised."""
        exp_dir = self._latest_experiment_dir(output_dir)
        if exp_dir is None:
            logger.warning(f"No experiment run dirs found in {output_dir}")
            return
        parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
        csv_path = exp_dir / "predictions.csv"
        try:
            df = pd.read_parquet(parquet_path)
            df.to_csv(csv_path, index=False)
            logger.info(f"Converted Parquet to CSV: {csv_path}")
        except Exception as e:
            logger.error(f"Error converting Parquet to CSV: {e}")

    def generate_plots(self, output_dir: Path) -> None:
        """
        Generate _all_ registered Ludwig visualizations for the latest experiment run.

        Plots listed in ``train_plots`` go to ``visualizations/train``, those
        in ``test_plots`` to ``visualizations/test``; any other registered
        visualization is written to ``visualizations`` itself. A
        visualization that raises is logged and skipped.
        """
        logger.info("Generating all Ludwig visualizations…")

        test_plots = {
            'compare_performance',
            'compare_classifiers_performance_from_prob',
            'compare_classifiers_performance_from_pred',
            'compare_classifiers_performance_changing_k',
            'compare_classifiers_multiclass_multimetric',
            'compare_classifiers_predictions',
            'confidence_thresholding_2thresholds_2d',
            'confidence_thresholding_2thresholds_3d',
            'confidence_thresholding',
            'confidence_thresholding_data_vs_acc',
            'binary_threshold_vs_metric',
            'roc_curves',
            'roc_curves_from_test_statistics',
            'calibration_1_vs_all',
            'calibration_multiclass',
            'confusion_matrix',
            'frequency_vs_f1',
        }
        train_plots = {
            'learning_curves',
            'compare_classifiers_performance_subset',
        }

        # 1) find the most recent experiment directory
        exp_dir = self._latest_experiment_dir(output_dir)
        if exp_dir is None:
            logger.warning(f"No experiment run dirs found in {output_dir}")
            return

        # 2) ensure viz output subfolders exist
        viz_dir = exp_dir / "visualizations"
        train_viz = viz_dir / "train"
        test_viz = viz_dir / "test"
        train_viz.mkdir(parents=True, exist_ok=True)
        test_viz.mkdir(parents=True, exist_ok=True)

        # 3) helper to check file existence
        def _check(p: Path) -> Optional[str]:
            return str(p) if p.exists() else None

        # 4) gather standard Ludwig output files
        training_stats = _check(exp_dir / "training_statistics.json")
        test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME)
        probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME)
        gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME)

        # 5) try to read original dataset & split file from description.json
        cfg: Dict[str, Any] = {}
        dataset_path = None
        split_file = None
        desc = exp_dir / DESCRIPTION_FILE_NAME
        if desc.exists():
            with open(desc, "r", encoding="utf-8") as f:
                cfg = json.load(f)
            dataset = cfg.get("dataset")
            # Skip when the entry is absent: Path("") resolves to the current
            # directory, which would pass the existence check.
            if dataset:
                dataset_path = _check(Path(dataset))
                split_file = _check(Path(get_split_path(dataset)))

        # 6) infer output feature name
        output_feature = ""
        try:
            output_feature = cfg["config"]["output_features"][0]["name"]
        except (KeyError, IndexError, TypeError):
            # Description missing or shaped differently; fall back below.
            pass
        if not output_feature and test_stats:
            with open(test_stats, "r", encoding="utf-8") as f:
                stats = json.load(f)
            output_feature = next(iter(stats.keys()), "")

        # 7) loop through every registered viz
        viz_registry = get_visualizations_registry()
        for viz_name, viz_func in viz_registry.items():
            if viz_name in train_plots:
                viz_dir_plot = train_viz
            elif viz_name in test_plots:
                viz_dir_plot = test_viz
            else:
                # Uncategorized visualization: keep it under the run's viz
                # folder instead of passing the literal string "None".
                viz_dir_plot = viz_dir

            try:
                viz_func(
                    training_statistics=[training_stats] if training_stats else [],
                    test_statistics=[test_stats] if test_stats else [],
                    probabilities=[probs_path] if probs_path else [],
                    output_feature_name=output_feature,
                    ground_truth_split=2,
                    top_n_classes=[0],
                    top_k=3,
                    ground_truth_metadata=gt_metadata,
                    ground_truth=dataset_path,
                    split_file=split_file,
                    output_directory=str(viz_dir_plot),
                    normalize=False,
                    file_format="png",
                )
                logger.info(f"✔ Generated {viz_name}")
            except Exception as e:
                # Best effort: not every visualization applies to every run.
                logger.warning(f"✘ Skipped {viz_name}: {e}")

        logger.info(f"All visualizations written to {viz_dir}")

    def generate_html_report(
        self,
        title: str,
        output_dir: str,
        config: dict,
        split_info: str) -> Path:
        """
        Assemble an HTML report from the latest run's statistics and the
        plots under its visualizations/train and visualizations/test folders.

        The report is written to the current working directory; its file
        name is derived from ``title``.

        Returns:
            Path of the written report.

        Raises:
            RuntimeError: If ``output_dir`` holds no ``experiment_run*`` entry.
        """
        cwd = Path.cwd()
        report_name = title.lower().replace(" ", "_") + "_report.html"
        report_path = cwd / report_name
        output_dir = Path(output_dir)

        # Find latest experiment dir
        exp_dir = self._latest_experiment_dir(output_dir)
        if exp_dir is None:
            raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}")

        base_viz_dir = exp_dir / "visualizations"
        train_viz_dir = base_viz_dir / "train"
        test_viz_dir = base_viz_dir / "test"

        html = get_html_template()
        html += f"<h1>{title}</h1>"

        metrics_html = ""

        # Load and embed metrics table (training/val/test stats)
        try:
            train_stats_path = exp_dir / "training_statistics.json"
            test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME
            if train_stats_path.exists() and test_stats_path.exists():
                with open(train_stats_path, encoding="utf-8") as f:
                    train_stats = json.load(f)
                with open(test_stats_path, encoding="utf-8") as f:
                    test_stats = json.load(f)
                output_feature = next(iter(train_stats.keys()), "")
                if output_feature:
                    metrics_html += format_stats_table_html(train_stats, test_stats)
        except Exception as e:
            logger.warning(f"Could not load stats for HTML report: {e}")

        config_html = ""
        training_progress = self.get_training_process(output_dir)
        try:
            config_html = format_config_table_html(config, split_info, training_progress)
        except Exception as e:
            logger.warning(f"Could not load config for HTML report: {e}")

        def render_img_section(title: str, dir_path: Path) -> str:
            # Embed every PNG of dir_path as a base64 data URI.
            if not dir_path.exists():
                return f"<h2>{title}</h2><p><em>Directory not found.</em></p>"
            imgs = sorted(dir_path.glob("*.png"))
            if not imgs:
                return f"<h2>{title}</h2><p><em>No plots found.</em></p>"

            section_html = f"<h2 style='text-align: center;'>{title}</h2><div>"
            for img in imgs:
                b64 = encode_image_to_base64(str(img))
                section_html += (
                    f'<div class="plot" style="margin-bottom:20px;text-align:center;">'
                    f"<h3>{img.stem.replace('_', ' ').title()}</h3>"
                    f'<img src="data:image/png;base64,{b64}" '
                    'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />'
                    "</div>"
                )
            section_html += "</div>"
            return section_html

        train_plots_html = render_img_section("Training & Validation Visualizations", train_viz_dir)
        test_plots_html = render_img_section("Test Visualizations", test_viz_dir)
        html += build_tabbed_html(config_html + metrics_html, train_plots_html, test_plots_html)
        html += get_html_closing()

        try:
            # NOTE(review): UTF-8 is assumed to match the charset declared by
            # get_html_template() -- confirm in utils.
            with open(report_path, "w", encoding="utf-8") as f:
                f.write(html)
            logger.info(f"HTML report generated at: {report_path}")
        except Exception as e:
            logger.error(f"Failed to write HTML report: {e}")
            raise

        return report_path
| 808 | |
| 809 | |
class WorkflowOrchestrator:
    """
    Manages the image-classification workflow:
      1. Creates temp dirs
      2. Extracts images
      3. Prepares data (CSV + splits)
      4. Renders a backend config
      5. Runs the experiment
      6. Generates plots, the HTML report and CSV predictions
      7. Cleans up
    """

    def __init__(self, args: argparse.Namespace, backend: Backend):
        """Store the parsed CLI arguments and the backend that runs the experiment."""
        self.args = args
        self.backend = backend
        # Both are populated by _create_temp_dirs() and reset by _cleanup_temp_dirs().
        self.temp_dir: Optional[Path] = None
        self.image_extract_dir: Optional[Path] = None
        logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}")

    def _create_temp_dirs(self) -> None:
        """Create temporary output and image extraction directories."""
        try:
            # The working directory lives inside output_dir so that everything
            # produced by one run stays under a single root.
            self.temp_dir = Path(tempfile.mkdtemp(
                dir=self.args.output_dir,
                prefix=TEMP_DIR_PREFIX
            ))
            self.image_extract_dir = self.temp_dir / "images"
            self.image_extract_dir.mkdir()
            logger.info(f"Created temp directory: {self.temp_dir}")
        except Exception:
            logger.error("Failed to create temporary directories", exc_info=True)
            raise

    def _extract_images(self) -> None:
        """
        Extract images from ZIP into the temp image directory.

        Raises:
            RuntimeError: if the temp image directory has not been created yet.
        """
        if self.image_extract_dir is None:
            raise RuntimeError("Temp image directory not initialized.")
        logger.info(f"Extracting images from {self.args.image_zip} → {self.image_extract_dir}")
        try:
            with zipfile.ZipFile(self.args.image_zip, "r") as z:
                z.extractall(self.image_extract_dir)
            logger.info("Image extraction complete.")
        except Exception:
            logger.error("Error extracting zip file", exc_info=True)
            raise

    def _prepare_data(self) -> Tuple[Path, Dict[str, Any], str]:
        """
        Load CSV, update image paths, handle splits, and write prepared CSV.

        Returns:
            final_csv_path: Path to the prepared CSV (inside the temp dir)
            split_config: Dict for backend split settings
            split_info: Human-readable description of the split used

        Raises:
            RuntimeError: if the temp dirs have not been created yet.
            ValueError: if required columns are missing or the split column
                holds unexpected values.
        """
        if not self.temp_dir or not self.image_extract_dir:
            raise RuntimeError("Temp dirs not initialized before data prep.")

        # 1) Load
        try:
            df = pd.read_csv(self.args.csv_file)
            logger.info(f"Loaded CSV: {self.args.csv_file}")
        except Exception:
            logger.error("Error loading CSV file", exc_info=True)
            raise

        # 2) Validate columns
        required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME}
        missing = required - set(df.columns)
        if missing:
            raise ValueError(f"Missing CSV columns: {', '.join(missing)}")

        # 3) Update image paths so they point at the extracted files
        image_root = self.image_extract_dir
        try:
            df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply(
                lambda p: str((image_root / p).resolve())
            )
        except Exception:
            logger.error("Error updating image paths", exc_info=True)
            raise

        # 4) Handle splits
        if SPLIT_COLUMN_NAME in df.columns:
            df, split_config, split_info = self._process_fixed_split(df)
        else:
            logger.info("No split column; using random split")
            split_config = {
                "type": "random",
                "probabilities": self.args.split_probabilities
            }
            # round() rather than int(): 0.29 * 100 is 28.999..., which int()
            # would truncate to 28.
            percentages = [round(p * 100) for p in self.args.split_probabilities]
            split_info = (
                f"No split column in CSV. Used random split: "
                f"{percentages}% for train/val/test."
            )

        # 5) Write out prepared CSV inside the temp dir so it is removed by
        #    _cleanup_temp_dirs() instead of being left in the working directory.
        final_csv = self.temp_dir / TEMP_CSV_FILENAME
        try:
            df.to_csv(final_csv, index=False)
            logger.info(f"Saved prepared data to {final_csv}")
        except Exception:
            logger.error("Error saving prepared CSV", exc_info=True)
            raise

        return final_csv, split_config, split_info

    def _process_fixed_split(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, Any], str]:
        """
        Process a fixed split column (0=train,1=val,2=test).

        Returns:
            df: the (possibly re-split) DataFrame
            split_config: Dict telling the backend to use the fixed split column
            split_info: Human-readable description of the split used

        Raises:
            ValueError: if the column holds values other than 0, 1 or 2.
        """
        logger.info(f"Fixed split column '{SPLIT_COLUMN_NAME}' detected.")
        try:
            col = df[SPLIT_COLUMN_NAME]
            # Coerce to a nullable integer column; unparsable entries become NA.
            df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype(pd.Int64Dtype())
            if df[SPLIT_COLUMN_NAME].isna().any():
                logger.warning("Split column contains non-numeric/missing values.")

            unique = set(df[SPLIT_COLUMN_NAME].dropna().unique())
            logger.info(f"Unique split values: {unique}")

            if unique == {0, 2}:
                # Only train/test were given: carve a validation set out of train.
                df = split_data_0_2(
                    df, SPLIT_COLUMN_NAME,
                    validation_size=self.args.validation_size,
                    label_column=LABEL_COLUMN_NAME,
                    random_state=self.args.random_seed
                )
                split_info = (
                    "Detected a split column (with values 0 and 2) in the input CSV. "
                    f"Used this column as a base and "
                    f"reassigned {self.args.validation_size * 100:.1f}% "
                    "of the training set (originally labeled 0) to validation (labeled 1)."
                )

                logger.info("Applied custom 0/2 split.")
            elif unique.issubset({0, 1, 2}):
                split_info = "Used user-defined split column from CSV."
                logger.info("Using fixed split as-is.")
            else:
                raise ValueError(f"Unexpected split values: {unique}")

            return df, {"type": "fixed", "column": SPLIT_COLUMN_NAME}, split_info

        except Exception:
            logger.error("Error processing fixed split", exc_info=True)
            raise

    def _cleanup_temp_dirs(self) -> None:
        """Remove any temporary directories."""
        if self.temp_dir and self.temp_dir.exists():
            logger.info(f"Cleaning up temp directory: {self.temp_dir}")
            shutil.rmtree(self.temp_dir, ignore_errors=True)
        self.temp_dir = None
        self.image_extract_dir = None

    def run(self) -> None:
        """
        Execute the full workflow end-to-end.

        Any exception is logged and re-raised; the temp directory is removed
        whether or not the workflow succeeds.
        """
        logger.info("Starting workflow...")
        self.args.output_dir.mkdir(parents=True, exist_ok=True)

        try:
            self._create_temp_dirs()
            self._extract_images()
            csv_path, split_cfg, split_info = self._prepare_data()

            # Fine-tuning implies starting from pretrained weights.
            use_pretrained = self.args.use_pretrained or self.args.fine_tune

            backend_args = {
                "model_name": self.args.model_name,
                "fine_tune": self.args.fine_tune,
                "use_pretrained": use_pretrained,
                "epochs": self.args.epochs,
                "batch_size": self.args.batch_size,
                "preprocessing_num_processes": self.args.preprocessing_num_processes,
                "split_probabilities": self.args.split_probabilities,
                "learning_rate": self.args.learning_rate,
                "random_seed": self.args.random_seed,
                "early_stop": self.args.early_stop,
            }
            yaml_str = self.backend.prepare_config(backend_args, split_cfg)

            config_file = self.temp_dir / TEMP_CONFIG_FILENAME
            config_file.write_text(yaml_str)
            logger.info(f"Wrote backend config: {config_file}")

            self.backend.run_experiment(
                csv_path,
                config_file,
                self.args.output_dir,
                self.args.random_seed
            )
            self.backend.generate_plots(self.args.output_dir)
            report_file = self.backend.generate_html_report(
                "Image Classification Results",
                self.args.output_dir,
                backend_args,
                split_info
            )
            logger.info(f"HTML report generated at: {report_file}")
            self.backend.convert_parquet_to_csv(self.args.output_dir)
            logger.info("Converted Parquet to CSV.")
            # Logged last so that "success" covers the post-processing steps too.
            logger.info("Workflow completed successfully.")
        except Exception:
            logger.error("Workflow execution failed", exc_info=True)
            raise

        finally:
            self._cleanup_temp_dirs()
| 1013 | |
| 1014 | |
def parse_learning_rate(s):
    """
    Convert a command-line value to a float learning rate.

    Returns None instead of raising when the value cannot be converted
    (for example None or a non-numeric string).
    """
    try:
        rate = float(s)
    except (TypeError, ValueError):
        rate = None
    return rate
| 1020 | |
| 1021 | |
class SplitProbAction(argparse.Action):
    """
    argparse action that validates the three ``--split-probabilities`` values.

    Each value must lie in [0.0, 1.0] and the three must sum to 1.0 (within
    1e-6). On failure the parser's ``error()`` is called, which prints the
    usage message and exits; on success the list is stored on the namespace.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # values is a list of three floats
        train, val, test = values
        # Written as "not (a <= p <= b)" so that NaN is rejected as well.
        if any(not (0.0 <= p <= 1.0) for p in values):
            parser.error(
                f"--split-probabilities values must each be between 0.0 and 1.0; "
                f"got {train:.3f}, {val:.3f}, {test:.3f}"
            )
        total = train + val + test
        if not abs(total - 1.0) <= 1e-6:
            parser.error(
                f"--split-probabilities must sum to 1.0; "
                f"got {train:.3f} + {val:.3f} + {test:.3f} = {total:.3f}"
            )
        setattr(namespace, self.dest, values)
| 1033 | |
| 1034 | |
def main():
    """
    Command-line entry point.

    Parses and validates the arguments, builds the Ludwig backend and the
    workflow orchestrator, runs the workflow, and exits the process with
    status 0 on success or 1 if the workflow raises an exception.
    """
    parser = argparse.ArgumentParser(
        description="Image Classification Learner with Pluggable Backends"
    )
    parser.add_argument(
        "--csv-file", required=True, type=Path,
        help="Path to the input CSV"
    )
    parser.add_argument(
        "--image-zip", required=True, type=Path,
        help="Path to the images ZIP"
    )
    parser.add_argument(
        "--model-name", required=True,
        choices=MODEL_ENCODER_TEMPLATES.keys(),
        help="Which model template to use"
    )
    parser.add_argument(
        "--use-pretrained", action="store_true",
        help="Use pretrained weights for the model"
    )
    parser.add_argument(
        "--fine-tune", action="store_true",
        help="Enable fine-tuning"
    )
    parser.add_argument(
        "--epochs", type=int, default=10,
        help="Number of training epochs"
    )
    parser.add_argument(
        "--early-stop", type=int, default=5,
        help="Early stopping patience"
    )
    parser.add_argument(
        "--batch-size", type=int,
        help="Batch size (None = auto)"
    )
    parser.add_argument(
        "--output-dir", type=Path, default=Path("learner_output"),
        help="Where to write outputs"
    )
    parser.add_argument(
        "--validation-size", type=float, default=0.15,
        help="Fraction for validation (0.0–1.0)"
    )
    parser.add_argument(
        "--preprocessing-num-processes", type=int,
        # os.cpu_count() returns None when the count cannot be determined.
        default=max(1, (os.cpu_count() or 1) // 2),
        help="CPU processes for data prep"
    )
    parser.add_argument(
        "--split-probabilities", type=float, nargs=3,
        metavar=("train", "val", "test"),
        action=SplitProbAction,
        default=[0.7, 0.1, 0.2],
        help="Random split proportions (e.g., 0.7 0.1 0.2). Only used if no split column is present."
    )
    parser.add_argument(
        "--random-seed", type=int, default=42,
        help="Random seed used for dataset splitting (default: 42)"
    )
    parser.add_argument(
        "--learning-rate", type=parse_learning_rate, default=None,
        help="Learning rate. If not provided, Ludwig will auto-select it."
    )

    args = parser.parse_args()

    # -- Validation --
    if not 0.0 <= args.validation_size <= 1.0:
        parser.error("validation-size must be between 0.0 and 1.0")
    if not args.csv_file.is_file():
        parser.error(f"CSV not found: {args.csv_file}")
    if not args.image_zip.is_file():
        parser.error(f"ZIP not found: {args.image_zip}")

    # --- Instantiate Backend and Orchestrator ---
    # Use the new LudwigDirectBackend
    backend_instance = LudwigDirectBackend()
    orchestrator = WorkflowOrchestrator(args, backend_instance)

    # --- Run Workflow ---
    # sys.exit() is deliberately NOT placed in a `finally` clause: exiting
    # from `finally` would replace KeyboardInterrupt/SystemExit with a
    # misleading exit status of 0.
    exit_code = 0
    try:
        orchestrator.run()
        logger.info("Main script finished successfully.")
    except Exception as e:
        logger.error(f"Main script failed: {e}")
        exit_code = 1
    sys.exit(exit_code)
| 1127 | |
| 1128 | |
# Script entry point: confirm Ludwig can be imported before starting the workflow.
if __name__ == "__main__":
    try:
        import ludwig
    except ImportError:
        logger.error("Ludwig library not found. Please ensure Ludwig is installed ('pip install ludwig[image]')")
        sys.exit(1)
    else:
        logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}")

    main()
