image_learner_cli.py @ 0:54b871dfc51e (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit b7411ff35b6228ccdfd36cd4ebd946c03ac7f7e9
author: goeckslab
date: Tue, 03 Jun 2025 21:22:11 +0000
#!/usr/bin/env python3
import argparse
import json
import logging
import os
import shutil
import sys
import tempfile
import zipfile
from pathlib import Path
from typing import Any, Dict, Optional, Protocol, Tuple

import pandas as pd
import yaml
from ludwig.globals import (
    DESCRIPTION_FILE_NAME,
    PREDICTIONS_PARQUET_FILE_NAME,
    TEST_STATISTICS_FILE_NAME,
    TRAIN_SET_METADATA_FILE_NAME,
)
from ludwig.utils.data_utils import get_split_path
from ludwig.visualize import get_visualizations_registry
from sklearn.model_selection import train_test_split
from utils import encode_image_to_base64, get_html_closing, get_html_template

# --- Constants ---
SPLIT_COLUMN_NAME = 'split'
LABEL_COLUMN_NAME = 'label'
IMAGE_PATH_COLUMN_NAME = 'image_path'
DEFAULT_SPLIT_PROBABILITIES = [0.7, 0.1, 0.2]
TEMP_CSV_FILENAME = "processed_data_for_ludwig.csv"
TEMP_CONFIG_FILENAME = "ludwig_config.yaml"
TEMP_DIR_PREFIX = "ludwig_api_work_"
MODEL_ENCODER_TEMPLATES: Dict[str, Any] = {
    'stacked_cnn': 'stacked_cnn',
    'resnet18': {'type': 'resnet', 'model_variant': 18},
    'resnet34': {'type': 'resnet', 'model_variant': 34},
    'resnet50': {'type': 'resnet', 'model_variant': 50},
    'resnet101': {'type': 'resnet', 'model_variant': 101},
    'resnet152': {'type': 'resnet', 'model_variant': 152},
    'resnext50_32x4d': {'type': 'resnext', 'model_variant': '50_32x4d'},
    'resnext101_32x8d': {'type': 'resnext', 'model_variant': '101_32x8d'},
    'resnext101_64x4d': {'type': 'resnext', 'model_variant': '101_64x4d'},
    'resnext152_32x8d': {'type': 'resnext', 'model_variant': '152_32x8d'},
    'wide_resnet50_2': {'type': 'wide_resnet', 'model_variant': '50_2'},
    'wide_resnet101_2': {'type': 'wide_resnet', 'model_variant': '101_2'},
    'wide_resnet103_2': {'type': 'wide_resnet', 'model_variant': '103_2'},
    'efficientnet_b0': {'type': 'efficientnet', 'model_variant': 'b0'},
    'efficientnet_b1': {'type': 'efficientnet', 'model_variant': 'b1'},
    'efficientnet_b2': {'type': 'efficientnet', 'model_variant': 'b2'},
    'efficientnet_b3': {'type': 'efficientnet', 'model_variant': 'b3'},
    'efficientnet_b4': {'type': 'efficientnet', 'model_variant': 'b4'},
    'efficientnet_b5': {'type': 'efficientnet', 'model_variant': 'b5'},
    'efficientnet_b6': {'type': 'efficientnet', 'model_variant': 'b6'},
    'efficientnet_b7': {'type': 'efficientnet', 'model_variant': 'b7'},
    'efficientnet_v2_s': {'type': 'efficientnet', 'model_variant': 'v2_s'},
    'efficientnet_v2_m': {'type': 'efficientnet', 'model_variant': 'v2_m'},
    'efficientnet_v2_l': {'type': 'efficientnet', 'model_variant': 'v2_l'},
    'regnet_y_400mf': {'type': 'regnet', 'model_variant': 'y_400mf'},
    'regnet_y_800mf': {'type': 'regnet', 'model_variant': 'y_800mf'},
    'regnet_y_1_6gf': {'type': 'regnet', 'model_variant': 'y_1_6gf'},
    'regnet_y_3_2gf': {'type': 'regnet', 'model_variant': 'y_3_2gf'},
    'regnet_y_8gf': {'type': 'regnet', 'model_variant': 'y_8gf'},
    'regnet_y_16gf': {'type': 'regnet', 'model_variant': 'y_16gf'},
    'regnet_y_32gf': {'type': 'regnet', 'model_variant': 'y_32gf'},
    'regnet_y_128gf': {'type': 'regnet', 'model_variant': 'y_128gf'},
    'regnet_x_400mf': {'type': 'regnet', 'model_variant': 'x_400mf'},
    'regnet_x_800mf': {'type': 'regnet', 'model_variant': 'x_800mf'},
    'regnet_x_1_6gf': {'type': 'regnet', 'model_variant': 'x_1_6gf'},
    'regnet_x_3_2gf': {'type': 'regnet', 'model_variant': 'x_3_2gf'},
    'regnet_x_8gf': {'type': 'regnet', 'model_variant': 'x_8gf'},
    'regnet_x_16gf': {'type': 'regnet', 'model_variant': 'x_16gf'},
    'regnet_x_32gf': {'type': 'regnet', 'model_variant': 'x_32gf'},
    'vgg11': {'type': 'vgg', 'model_variant': 11},
    'vgg11_bn': {'type': 'vgg', 'model_variant': '11_bn'},
    'vgg13': {'type': 'vgg', 'model_variant': 13},
    'vgg13_bn': {'type': 'vgg', 'model_variant': '13_bn'},
    'vgg16': {'type': 'vgg', 'model_variant': 16},
    'vgg16_bn': {'type': 'vgg', 'model_variant': '16_bn'},
    'vgg19': {'type': 'vgg', 'model_variant': 19},
    'vgg19_bn': {'type': 'vgg', 'model_variant': '19_bn'},
    'shufflenet_v2_x0_5': {'type': 'shufflenet_v2', 'model_variant': 'x0_5'},
    'shufflenet_v2_x1_0': {'type': 'shufflenet_v2', 'model_variant': 'x1_0'},
    'shufflenet_v2_x1_5': {'type': 'shufflenet_v2', 'model_variant': 'x1_5'},
    'shufflenet_v2_x2_0': {'type': 'shufflenet_v2', 'model_variant': 'x2_0'},
    'squeezenet1_0': {'type': 'squeezenet', 'model_variant': '1_0'},
    'squeezenet1_1': {'type': 'squeezenet', 'model_variant': '1_1'},
    'swin_t': {'type': 'swin_transformer', 'model_variant': 't'},
    'swin_s': {'type': 'swin_transformer', 'model_variant': 's'},
    'swin_b': {'type': 'swin_transformer', 'model_variant': 'b'},
    'swin_v2_t': {'type': 'swin_transformer', 'model_variant': 'v2_t'},
    'swin_v2_s': {'type': 'swin_transformer', 'model_variant': 'v2_s'},
    'swin_v2_b': {'type': 'swin_transformer', 'model_variant': 'v2_b'},
    'vit_b_16': {'type': 'vision_transformer', 'model_variant': 'b_16'},
    'vit_b_32': {'type': 'vision_transformer', 'model_variant': 'b_32'},
    'vit_l_16': {'type': 'vision_transformer', 'model_variant': 'l_16'},
    'vit_l_32': {'type': 'vision_transformer', 'model_variant': 'l_32'},
    'vit_h_14': {'type': 'vision_transformer', 'model_variant': 'h_14'},
    'convnext_tiny': {'type': 'convnext', 'model_variant': 'tiny'},
    'convnext_small': {'type': 'convnext', 'model_variant': 'small'},
    'convnext_base': {'type': 'convnext', 'model_variant': 'base'},
    'convnext_large': {'type': 'convnext', 'model_variant': 'large'},
    'maxvit_t': {'type': 'maxvit', 'model_variant': 't'},
    'alexnet': {'type': 'alexnet'},
    'googlenet': {'type': 'googlenet'},
    'inception_v3': {'type': 'inception_v3'},
    'mobilenet_v2': {'type': 'mobilenet_v2'},
    'mobilenet_v3_large': {'type': 'mobilenet_v3_large'},
    'mobilenet_v3_small': {'type': 'mobilenet_v3_small'},
}
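
# Illustrative note (not executed): prepare_config() below resolves one of these template
# entries into a Ludwig encoder config by merging in the CLI-driven flags. The flag values
# shown here are hypothetical examples, not defaults.
#
#   raw_encoder = MODEL_ENCODER_TEMPLATES["resnet18"]   # {'type': 'resnet', 'model_variant': 18}
#   encoder_config = {**raw_encoder, "use_pretrained": True, "trainable": False}
#   # -> {'type': 'resnet', 'model_variant': 18, 'use_pretrained': True, 'trainable': False}
#
# String-valued templates such as 'stacked_cnn' are wrapped as {"type": "stacked_cnn"} instead.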

# --- Logging Setup ---
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(name)s: %(message)s'
)
logger = logging.getLogger("ImageLearner")


def format_config_table_html(
        config: dict,
        split_info: Optional[str] = None,
        training_progress: Optional[dict] = None) -> str:
    display_keys = [
        "model_name",
        "epochs",
        "batch_size",
        "fine_tune",
        "use_pretrained",
        "learning_rate",
        "random_seed",
        "early_stop",
    ]

    rows = []

    for key in display_keys:
        val = config.get(key, "N/A")
        if key == "batch_size":
            if val is not None:
                val = int(val)
            else:
                if training_progress:
                    val = "Auto-selected batch size by Ludwig:<br>"
                    resolved_val = training_progress.get("batch_size")
                    val += (
                        f"<span style='font-size: 0.85em;'>{resolved_val}</span><br>"
                    )
                else:
                    val = "auto"
        if key == "learning_rate":
            resolved_val = None
            if val is None or val == "auto":
                if training_progress:
                    resolved_val = training_progress.get("learning_rate")
                    val = (
                        "Auto-selected learning rate by Ludwig:<br>"
                        f"<span style='font-size: 0.85em;'>{resolved_val if resolved_val else val}</span><br>"
                        "<span style='font-size: 0.85em;'>"
                        "Based on model architecture and training setup (e.g., fine-tuning).<br>"
                        "See <a href='https://ludwig.ai/latest/configuration/trainer/#trainer-parameters' "
                        "target='_blank'>Ludwig Trainer Parameters</a> for details."
                        "</span>"
                    )
                else:
                    val = (
                        "Auto-selected by Ludwig<br>"
                        "<span style='font-size: 0.85em;'>"
                        "Automatically tuned based on architecture and dataset.<br>"
                        "See <a href='https://ludwig.ai/latest/configuration/trainer/#trainer-parameters' "
                        "target='_blank'>Ludwig Trainer Parameters</a> for details."
                        "</span>"
                    )
            else:
                val = f"{val:.6f}"
        if key == "epochs":
            if training_progress and "epoch" in training_progress and val > training_progress["epoch"]:
                val = (
179 f"Because of early stopping: the training" | |
180 f"stopped at epoch {training_progress['epoch']}" | |
                )

        if val is None:
            continue
        rows.append(
            f"<tr>"
            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>"
            f"{key.replace('_', ' ').title()}</td>"
            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{val}</td>"
            f"</tr>"
        )

    if split_info:
        rows.append(
            f"<tr>"
            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Data Split</td>"
            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{split_info}</td>"
            f"</tr>"
        )

    return (
        "<h2 style='text-align: center;'>Training Setup</h2>"
        "<div style='display: flex; justify-content: center;'>"
        "<table style='border-collapse: collapse; width: 60%; table-layout: auto;'>"
        "<thead><tr>"
        "<th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>Parameter</th>"
        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Value</th>"
        "</tr></thead><tbody>" + "".join(rows) + "</tbody></table></div><br>"
        "<p style='text-align: center; font-size: 0.9em;'>"
        "Model trained using Ludwig.<br>"
211 "If want to learn more about Ludwig default settings," | |
212 "please check the their <a href='https://ludwig.ai' target='_blank'>website(ludwig.ai)</a>." | |
213 "</p><hr>" | |
214 ) | |
215 | |
216 | |
217 def format_stats_table_html(training_stats: dict, test_stats: dict) -> str: | |
218 train_metrics = training_stats.get("training", {}).get("label", {}) | |
219 val_metrics = training_stats.get("validation", {}).get("label", {}) | |
220 test_metrics = test_stats.get("label", {}) | |
221 | |
222 all_metrics = set(train_metrics) | set(val_metrics) | set(test_metrics) | |
223 | |
224 def get_last_value(stats, key): | |
225 val = stats.get(key) | |
226 if isinstance(val, list) and val: | |
227 return val[-1] | |
228 elif isinstance(val, (int, float)): | |
229 return val | |
230 return None | |
231 | |
232 rows = [] | |
233 for metric in sorted(all_metrics): | |
234 t = get_last_value(train_metrics, metric) | |
235 v = get_last_value(val_metrics, metric) | |
236 te = get_last_value(test_metrics, metric) | |
237 if all(x is not None for x in [t, v, te]): | |
238 row = ( | |
239 f"<tr>" | |
240 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>{metric}</td>" | |
241 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{t:.4f}</td>" | |
242 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{v:.4f}</td>" | |
243 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{te:.4f}</td>" | |
244 f"</tr>" | |
245 ) | |
246 rows.append(row) | |
247 | |
248 if not rows: | |
249 return "<p><em>No metric values found.</em></p>" | |
250 | |
251 return ( | |
252 "<h2 style='text-align: center;'>Model Performance Summary</h2>" | |
253 "<div style='display: flex; justify-content: center;'>" | |
254 "<table style='border-collapse: collapse; width: 80%; table-layout: fixed;'>" | |
255 "<colgroup>" | |
256 "<col style='width: 40%;'>" | |
257 "<col style='width: 20%;'>" | |
258 "<col style='width: 20%;'>" | |
259 "<col style='width: 20%;'>" | |
260 "</colgroup>" | |
261 "<thead><tr>" | |
262 "<th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>Metric</th>" | |
263 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Train</th>" | |
264 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Validation</th>" | |
265 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Test</th>" | |
266 "</tr></thead><tbody>" + "".join(rows) + "</tbody></table></div><br>" | |
267 ) | |
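
# Minimal sketch of the statistics shape format_stats_table_html() expects (values are made
# up): per-epoch lists under training/validation and final values under the test statistics,
# keyed by the output feature name ("label" here).
#
#   training_stats = {
#       "training":   {"label": {"accuracy": [0.61, 0.74, 0.80], "loss": [1.20, 0.85, 0.62]}},
#       "validation": {"label": {"accuracy": [0.58, 0.70, 0.76], "loss": [1.31, 0.95, 0.71]}},
#   }
#   test_stats = {"label": {"accuracy": 0.75, "loss": 0.70}}
#   html = format_stats_table_html(training_stats, test_stats)  # reports the last epoch's values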


def build_tabbed_html(
        metrics_html: str,
        train_viz_html: str,
        test_viz_html: str) -> str:
    return f"""
<style>
.tabs {{
  display: flex;
  border-bottom: 2px solid #ccc;
  margin-bottom: 1rem;
}}
.tab {{
  padding: 10px 20px;
  cursor: pointer;
  border: 1px solid #ccc;
  border-bottom: none;
  background: #f9f9f9;
  margin-right: 5px;
  border-top-left-radius: 8px;
  border-top-right-radius: 8px;
}}
.tab.active {{
  background: white;
  font-weight: bold;
}}
.tab-content {{
  display: none;
  padding: 20px;
  border: 1px solid #ccc;
  border-top: none;
}}
.tab-content.active {{
  display: block;
}}
</style>

<div class="tabs">
  <div class="tab active" onclick="showTab('metrics')"> Config & Metrics</div>
  <div class="tab" onclick="showTab('trainval')"> Train/Validation Plots</div>
  <div class="tab" onclick="showTab('test')"> Test Plots</div>
</div>

<div id="metrics" class="tab-content active">
  {metrics_html}
</div>
<div id="trainval" class="tab-content">
  {train_viz_html}
</div>
<div id="test" class="tab-content">
  {test_viz_html}
</div>

<script>
function showTab(id) {{
  document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active'));
  document.querySelectorAll('.tab').forEach(el => el.classList.remove('active'));
  document.getElementById(id).classList.add('active');
  document.querySelector(`.tab[onclick*="${{id}}"]`).classList.add('active');
}}
</script>
"""


def split_data_0_2(
    df: pd.DataFrame,
    split_column: str,
    validation_size: float = 0.15,
    random_state: int = 42,
    label_column: Optional[str] = None,
) -> pd.DataFrame:
    """
    Given a DataFrame whose split_column only contains {0,2}, re-assign
    a portion of the 0s to become 1s (validation). Returns a fresh DataFrame.
    """
    # Work on a copy
    out = df.copy()
    # Ensure split col is integer dtype
    out[split_column] = pd.to_numeric(out[split_column], errors="coerce").astype(int)

    idx_train = out.index[out[split_column] == 0].tolist()

    if not idx_train:
        logger.info("No rows with split=0; nothing to do.")
        return out

    # Determine stratify array if possible
    stratify_arr = None
    if label_column and label_column in out.columns:
        # Only stratify if at least two classes and enough samples
        label_counts = out.loc[idx_train, label_column].value_counts()
        if label_counts.size > 1 and (label_counts.min() * validation_size) >= 1:
            stratify_arr = out.loc[idx_train, label_column]
        else:
            logger.warning("Cannot stratify (too few labels); splitting without stratify.")

    # Edge cases
    if validation_size <= 0:
        logger.info("validation_size <= 0; keeping all as train.")
        return out
    if validation_size >= 1:
        logger.info("validation_size >= 1; moving all train → validation.")
        out.loc[idx_train, split_column] = 1
        return out

    # Do the split
    try:
        train_idx, val_idx = train_test_split(
            idx_train,
            test_size=validation_size,
            random_state=random_state,
            stratify=stratify_arr
        )
    except ValueError as e:
        logger.warning(f"Stratified split failed ({e}); retrying without stratify.")
        train_idx, val_idx = train_test_split(
            idx_train,
            test_size=validation_size,
            random_state=random_state,
            stratify=None
        )

    # Assign new splits
    out.loc[train_idx, split_column] = 0
    out.loc[val_idx, split_column] = 1
    # idx_test stays at 2

    # Cast back to a clean integer type
    out[split_column] = out[split_column].astype(int)
    # print(out)
    return out

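# Hedged usage sketch for split_data_0_2() above (toy data, not executed at import time):
#
#   demo = pd.DataFrame({
#       "image_path": ["a.png", "b.png", "c.png", "d.png", "e.png"],
#       "label":      ["cat",   "dog",   "cat",   "dog",   "cat"],
#       "split":      [0,       0,       0,       0,       2],
#   })
#   demo = split_data_0_2(demo, "split", validation_size=0.25, label_column="label")
#   # Roughly 25% of the rows that were 0 become 1 (validation); rows marked 2 are untouched.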

class Backend(Protocol):
    """Interface for a machine learning backend."""
    def prepare_config(
        self,
        config_params: Dict[str, Any],
        split_config: Dict[str, Any]
    ) -> str:
        ...

    def run_experiment(
        self,
        dataset_path: Path,
        config_path: Path,
        output_dir: Path,
        random_seed: int,
    ) -> None:
        ...

    def generate_plots(
        self,
        output_dir: Path
    ) -> None:
        ...

    def generate_html_report(
        self,
        title: str,
        output_dir: str,
        config: dict,
        split_info: str
    ) -> Path:
        ...


class LudwigDirectBackend:
    """
    Backend for running Ludwig experiments directly via the internal experiment_cli function.
    """

    def prepare_config(
        self,
        config_params: Dict[str, Any],
        split_config: Dict[str, Any],
    ) -> str:
        """
        Build and serialize the Ludwig YAML configuration.
        """
        logger.info("LudwigDirectBackend: Preparing YAML configuration.")

        model_name = config_params.get("model_name", "resnet18")
        use_pretrained = config_params.get("use_pretrained", False)
        fine_tune = config_params.get("fine_tune", False)
        epochs = config_params.get("epochs", 10)
        batch_size = config_params.get("batch_size")
        num_processes = config_params.get("preprocessing_num_processes", 1)
        early_stop = config_params.get("early_stop", None)
        learning_rate = config_params.get("learning_rate")
        learning_rate = "auto" if learning_rate is None else float(learning_rate)
        trainable = fine_tune or (not use_pretrained)
        if not use_pretrained and not trainable:
            logger.warning("trainable=False; use_pretrained=False is ignored.")
            logger.warning("Setting trainable=True to train the model from scratch.")
            trainable = True

        # Encoder setup
        raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name)
        if isinstance(raw_encoder, dict):
            encoder_config = {
                **raw_encoder,
                "use_pretrained": use_pretrained,
                "trainable": trainable,
            }
        else:
            encoder_config = {"type": raw_encoder}

        # Trainer & optimizer
        # optimizer = {"type": "adam", "learning_rate": 5e-5} if fine_tune else {"type": "adam"}
        batch_size_cfg = batch_size or "auto"

        conf: Dict[str, Any] = {
            "model_type": "ecd",
            "input_features": [
                {
                    "name": IMAGE_PATH_COLUMN_NAME,
                    "type": "image",
                    "encoder": encoder_config,
                }
            ],
            "output_features": [
                {"name": LABEL_COLUMN_NAME, "type": "category"}
            ],
            "combiner": {"type": "concat"},
            "trainer": {
                "epochs": epochs,
                "early_stop": early_stop,
                "batch_size": batch_size_cfg,
                "learning_rate": learning_rate,
            },
            "preprocessing": {
                "split": split_config,
                "num_processes": num_processes,
                "in_memory": False,
            },
        }

        logger.debug("LudwigDirectBackend: Config dict built.")
        try:
            yaml_str = yaml.dump(conf, sort_keys=False, indent=2)
            logger.info("LudwigDirectBackend: YAML config generated.")
            return yaml_str
        except Exception:
            logger.error("LudwigDirectBackend: Failed to serialize YAML.", exc_info=True)
            raise

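    # Sketch of the YAML prepare_config() emits for, e.g., model_name="resnet18" (abridged and
    # illustrative; the actual values come from the CLI arguments):
    #
    #   model_type: ecd
    #   input_features:
    #     - name: image_path
    #       type: image
    #       encoder: {type: resnet, model_variant: 18, use_pretrained: false, trainable: true}
    #   output_features:
    #     - name: label
    #       type: category
    #   combiner: {type: concat}
    #   trainer: {epochs: 10, early_stop: 5, batch_size: auto, learning_rate: auto}
    #   preprocessing: {split: {...}, num_processes: 1, in_memory: false}
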
    def run_experiment(
        self,
        dataset_path: Path,
        config_path: Path,
        output_dir: Path,
        random_seed: int = 42,
    ) -> None:
        """
        Invoke Ludwig's internal experiment_cli function to run the experiment.
        """
        logger.info("LudwigDirectBackend: Starting experiment execution.")

        try:
            from ludwig.experiment import experiment_cli
        except ImportError as e:
            logger.error(
                "LudwigDirectBackend: Could not import experiment_cli.",
                exc_info=True
            )
            raise RuntimeError("Ludwig import failed.") from e

        output_dir.mkdir(parents=True, exist_ok=True)

        try:
            experiment_cli(
                dataset=str(dataset_path),
                config=str(config_path),
                output_directory=str(output_dir),
                random_seed=random_seed,
            )
            logger.info(f"LudwigDirectBackend: Experiment completed. Results in {output_dir}")
        except TypeError as e:
            logger.error(
                "LudwigDirectBackend: Argument mismatch in experiment_cli call.",
                exc_info=True
            )
            raise RuntimeError("Ludwig argument error.") from e
        except Exception:
            logger.error(
                "LudwigDirectBackend: Experiment execution error.",
                exc_info=True
            )
            raise

    def get_training_process(self, output_dir) -> Optional[Dict[str, Any]]:
        """
        Retrieve the resolved learning rate, batch size, and epoch from the most recent Ludwig run.
        Returns:
            dict with "learning_rate", "batch_size", and "epoch", or None if no progress file is found
        """
        output_dir = Path(output_dir)
        exp_dirs = sorted(
            output_dir.glob("experiment_run*"),
            key=lambda p: p.stat().st_mtime
        )

        if not exp_dirs:
            logger.warning(f"No experiment run directories found in {output_dir}")
            return None

        progress_file = exp_dirs[-1] / "model" / "training_progress.json"
        if not progress_file.exists():
            logger.warning(f"No training_progress.json found in {progress_file}")
            return None

        try:
            with progress_file.open("r", encoding="utf-8") as f:
                data = json.load(f)
            return {
                "learning_rate": data.get("learning_rate"),
                "batch_size": data.get("batch_size"),
                "epoch": data.get("epoch"),
            }
        except Exception as e:
            logger.warning(f"Failed to read training progress info: {e}")
            return {}
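    # Assumed shape of training_progress.json as read by get_training_process() (only these
    # three keys are used; real files contain more fields and contents vary by Ludwig version):
    #
    #   {"learning_rate": 0.0001, "batch_size": 16, "epoch": 7, ...}
    #
    # The returned dict feeds format_config_table_html() when batch size / learning rate were "auto".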

    def convert_parquet_to_csv(self, output_dir: Path):
        """Convert the predictions Parquet file to CSV."""
        output_dir = Path(output_dir)
        exp_dirs = sorted(
            output_dir.glob("experiment_run*"),
            key=lambda p: p.stat().st_mtime
        )
        if not exp_dirs:
            logger.warning(f"No experiment run dirs found in {output_dir}")
            return
        exp_dir = exp_dirs[-1]
        parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
        csv_path = exp_dir / "predictions.csv"
        try:
            df = pd.read_parquet(parquet_path)
            df.to_csv(csv_path, index=False)
            logger.info(f"Converted Parquet to CSV: {csv_path}")
        except Exception as e:
            logger.error(f"Error converting Parquet to CSV: {e}")

    def generate_plots(self, output_dir: Path) -> None:
        """
        Generate _all_ registered Ludwig visualizations for the latest experiment run.
        """
        logger.info("Generating all Ludwig visualizations…")

        test_plots = {
            'compare_performance',
            'compare_classifiers_performance_from_prob',
            'compare_classifiers_performance_from_pred',
            'compare_classifiers_performance_changing_k',
            'compare_classifiers_multiclass_multimetric',
            'compare_classifiers_predictions',
            'confidence_thresholding_2thresholds_2d',
            'confidence_thresholding_2thresholds_3d',
            'confidence_thresholding',
            'confidence_thresholding_data_vs_acc',
            'binary_threshold_vs_metric',
            'roc_curves',
            'roc_curves_from_test_statistics',
            'calibration_1_vs_all',
            'calibration_multiclass',
            'confusion_matrix',
            'frequency_vs_f1',
        }
        train_plots = {
            'learning_curves',
            'compare_classifiers_performance_subset',
        }

        # 1) find the most recent experiment directory
        output_dir = Path(output_dir)
        exp_dirs = sorted(
            output_dir.glob("experiment_run*"),
            key=lambda p: p.stat().st_mtime
        )
        if not exp_dirs:
            logger.warning(f"No experiment run dirs found in {output_dir}")
            return
        exp_dir = exp_dirs[-1]

        # 2) ensure viz output subfolder exists
        viz_dir = exp_dir / "visualizations"
        viz_dir.mkdir(exist_ok=True)
        train_viz = viz_dir / "train"
        test_viz = viz_dir / "test"
        train_viz.mkdir(parents=True, exist_ok=True)
        test_viz.mkdir(parents=True, exist_ok=True)

        # 3) helper to check file existence
        def _check(p: Path) -> Optional[str]:
            return str(p) if p.exists() else None

        # 4) gather standard Ludwig output files
        training_stats = _check(exp_dir / "training_statistics.json")
        test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME)
        probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME)
        gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME)

        # 5) try to read original dataset & split file from description.json
        dataset_path = None
        split_file = None
        desc = exp_dir / DESCRIPTION_FILE_NAME
        if desc.exists():
            with open(desc, "r") as f:
                cfg = json.load(f)
            dataset_path = _check(Path(cfg.get("dataset", "")))
            split_file = _check(Path(get_split_path(cfg.get("dataset", ""))))

        # 6) infer output feature name
        output_feature = ""
        if desc.exists():
            try:
                output_feature = cfg["config"]["output_features"][0]["name"]
            except Exception:
                pass
        if not output_feature and test_stats:
            with open(test_stats, "r") as f:
                stats = json.load(f)
            output_feature = next(iter(stats.keys()), "")

        # 7) loop through every registered viz
        viz_registry = get_visualizations_registry()
        for viz_name, viz_func in viz_registry.items():
            viz_dir_plot = None
            if viz_name in train_plots:
                viz_dir_plot = train_viz
            elif viz_name in test_plots:
                viz_dir_plot = test_viz

            try:
                viz_func(
                    training_statistics=[training_stats] if training_stats else [],
                    test_statistics=[test_stats] if test_stats else [],
                    probabilities=[probs_path] if probs_path else [],
                    output_feature_name=output_feature,
                    ground_truth_split=2,
                    top_n_classes=[0],
                    top_k=3,
                    ground_truth_metadata=gt_metadata,
                    ground_truth=dataset_path,
                    split_file=split_file,
                    output_directory=str(viz_dir_plot),
                    normalize=False,
                    file_format="png",
                )
                logger.info(f"✔ Generated {viz_name}")
            except Exception as e:
                logger.warning(f"✘ Skipped {viz_name}: {e}")

        logger.info(f"All visualizations written to {viz_dir}")

    def generate_html_report(
            self,
            title: str,
            output_dir: str,
            config: dict,
            split_info: str) -> Path:
        """
        Assemble an HTML report from visualizations under train_val/ and test/ folders.
        """
        cwd = Path.cwd()
        report_name = title.lower().replace(" ", "_") + "_report.html"
        report_path = cwd / report_name
        output_dir = Path(output_dir)

        # Find latest experiment dir
        exp_dirs = sorted(output_dir.glob("experiment_run*"), key=lambda p: p.stat().st_mtime)
        if not exp_dirs:
            raise RuntimeError(f"No 'experiment_run*' directories found in {output_dir}")
        exp_dir = exp_dirs[-1]

        base_viz_dir = exp_dir / "visualizations"
        train_viz_dir = base_viz_dir / "train"
        test_viz_dir = base_viz_dir / "test"

        html = get_html_template()
        html += f"<h1>{title}</h1>"

        metrics_html = ""

        # Load and embed metrics table (training/val/test stats)
        try:
            train_stats_path = exp_dir / "training_statistics.json"
            test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME
            if train_stats_path.exists() and test_stats_path.exists():
                with open(train_stats_path) as f:
                    train_stats = json.load(f)
                with open(test_stats_path) as f:
                    test_stats = json.load(f)
                output_feature = next(iter(train_stats.keys()), "")
                if output_feature:
                    metrics_html += format_stats_table_html(train_stats, test_stats)
        except Exception as e:
            logger.warning(f"Could not load stats for HTML report: {e}")

        config_html = ""
        training_progress = self.get_training_process(output_dir)
        try:
            config_html = format_config_table_html(config, split_info, training_progress)
        except Exception as e:
            logger.warning(f"Could not load config for HTML report: {e}")

        def render_img_section(title: str, dir_path: Path) -> str:
            if not dir_path.exists():
                return f"<h2>{title}</h2><p><em>Directory not found.</em></p>"
            imgs = sorted(dir_path.glob("*.png"))
            if not imgs:
                return f"<h2>{title}</h2><p><em>No plots found.</em></p>"

            section_html = f"<h2 style='text-align: center;'>{title}</h2><div>"
            for img in imgs:
                b64 = encode_image_to_base64(str(img))
                section_html += (
                    f'<div class="plot" style="margin-bottom:20px;text-align:center;">'
                    f"<h3>{img.stem.replace('_',' ').title()}</h3>"
                    f'<img src="data:image/png;base64,{b64}" '
                    'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />'
                    "</div>"
                )
            section_html += "</div>"
            return section_html

        train_plots_html = render_img_section("Training & Validation Visualizations", train_viz_dir)
        test_plots_html = render_img_section("Test Visualizations", test_viz_dir)
        html += build_tabbed_html(config_html + metrics_html, train_plots_html, test_plots_html)
        html += get_html_closing()

        try:
            with open(report_path, "w") as f:
                f.write(html)
            logger.info(f"HTML report generated at: {report_path}")
        except Exception as e:
            logger.error(f"Failed to write HTML report: {e}")
            raise

        return report_path


class WorkflowOrchestrator:
    """
    Manages the image-classification workflow:
      1. Creates temp dirs
      2. Extracts images
      3. Prepares data (CSV + splits)
      4. Renders a backend config
      5. Runs the experiment
      6. Cleans up
    """

    def __init__(self, args: argparse.Namespace, backend: Backend):
        self.args = args
        self.backend = backend
        self.temp_dir: Optional[Path] = None
        self.image_extract_dir: Optional[Path] = None
        logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}")

    def _create_temp_dirs(self) -> None:
        """Create temporary output and image extraction directories."""
        try:
            self.temp_dir = Path(tempfile.mkdtemp(
                dir=self.args.output_dir,
                prefix=TEMP_DIR_PREFIX
            ))
            self.image_extract_dir = self.temp_dir / "images"
            self.image_extract_dir.mkdir()
            logger.info(f"Created temp directory: {self.temp_dir}")
        except Exception:
            logger.error("Failed to create temporary directories", exc_info=True)
            raise

    def _extract_images(self) -> None:
        """Extract images from ZIP into the temp image directory."""
        if self.image_extract_dir is None:
            raise RuntimeError("Temp image directory not initialized.")
        logger.info(f"Extracting images from {self.args.image_zip} → {self.image_extract_dir}")
        try:
            with zipfile.ZipFile(self.args.image_zip, "r") as z:
                z.extractall(self.image_extract_dir)
            logger.info("Image extraction complete.")
        except Exception:
            logger.error("Error extracting zip file", exc_info=True)
            raise

    def _prepare_data(self) -> Tuple[str, Dict[str, Any], str]:
        """
        Load CSV, update image paths, handle splits, and write prepared CSV.
        Returns:
            final_csv: path of the prepared CSV file
            split_config: dict with backend split settings
            split_info: human-readable description of how the split was derived
        """
        if not self.temp_dir or not self.image_extract_dir:
            raise RuntimeError("Temp dirs not initialized before data prep.")

        # 1) Load
        try:
            df = pd.read_csv(self.args.csv_file)
            logger.info(f"Loaded CSV: {self.args.csv_file}")
        except Exception:
            logger.error("Error loading CSV file", exc_info=True)
            raise

        # 2) Validate columns
        required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME}
        missing = required - set(df.columns)
        if missing:
            raise ValueError(f"Missing CSV columns: {', '.join(missing)}")

        # 3) Update image paths
        try:
            df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply(
                lambda p: str((self.image_extract_dir / p).resolve())
            )
        except Exception:
            logger.error("Error updating image paths", exc_info=True)
            raise

        # 4) Handle splits
        if SPLIT_COLUMN_NAME in df.columns:
            df, split_config, split_info = self._process_fixed_split(df)
        else:
            logger.info("No split column; using random split")
            split_config = {
                "type": "random",
                "probabilities": self.args.split_probabilities
            }
            split_info = (
                f"No split column in CSV. Used random split: "
                f"{[int(p*100) for p in self.args.split_probabilities]}% for train/val/test."
            )

        # 5) Write out prepared CSV
        final_csv = TEMP_CSV_FILENAME
        try:
            df.to_csv(final_csv, index=False)
            logger.info(f"Saved prepared data to {final_csv}")
        except Exception:
            logger.error("Error saving prepared CSV", exc_info=True)
            raise

        return final_csv, split_config, split_info

    def _process_fixed_split(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, Any], str]:
914 """Process a fixed split column (0=train,1=val,2=test).""" | |
915 logger.info(f"Fixed split column '{SPLIT_COLUMN_NAME}' detected.") | |
916 try: | |
917 col = df[SPLIT_COLUMN_NAME] | |
918 df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype(pd.Int64Dtype()) | |
919 if df[SPLIT_COLUMN_NAME].isna().any(): | |
920 logger.warning("Split column contains non-numeric/missing values.") | |
921 | |
922 unique = set(df[SPLIT_COLUMN_NAME].dropna().unique()) | |
923 logger.info(f"Unique split values: {unique}") | |
924 | |
925 if unique == {0, 2}: | |
926 df = split_data_0_2( | |
927 df, SPLIT_COLUMN_NAME, | |
928 validation_size=self.args.validation_size, | |
929 label_column=LABEL_COLUMN_NAME, | |
930 random_state=self.args.random_seed | |
931 ) | |
932 split_info = ( | |
933 "Detected a split column (with values 0 and 2) in the input CSV. " | |
934 f"Used this column as a base and" | |
935 f"reassigned {self.args.validation_size * 100:.1f}% " | |
936 "of the training set (originally labeled 0) to validation (labeled 1)." | |
937 ) | |
938 | |
939 logger.info("Applied custom 0/2 split.") | |
940 elif unique.issubset({0, 1, 2}): | |
941 split_info = "Used user-defined split column from CSV." | |
942 logger.info("Using fixed split as-is.") | |
943 else: | |
944 raise ValueError(f"Unexpected split values: {unique}") | |
945 | |
946 return df, {"type": "fixed", "column": SPLIT_COLUMN_NAME}, split_info | |
947 | |
948 except Exception: | |
949 logger.error("Error processing fixed split", exc_info=True) | |
950 raise | |
951 | |
952 def _cleanup_temp_dirs(self) -> None: | |
953 """Remove any temporary directories.""" | |
954 if self.temp_dir and self.temp_dir.exists(): | |
955 logger.info(f"Cleaning up temp directory: {self.temp_dir}") | |
956 shutil.rmtree(self.temp_dir, ignore_errors=True) | |
957 self.temp_dir = None | |
958 self.image_extract_dir = None | |
959 | |
960 def run(self) -> None: | |
961 """Execute the full workflow end-to-end.""" | |
962 logger.info("Starting workflow...") | |
963 self.args.output_dir.mkdir(parents=True, exist_ok=True) | |
964 | |
965 try: | |
966 self._create_temp_dirs() | |
967 self._extract_images() | |
968 csv_path, split_cfg, split_info = self._prepare_data() | |
969 | |
970 use_pretrained = self.args.use_pretrained or self.args.fine_tune | |
971 | |
972 backend_args = { | |
973 "model_name": self.args.model_name, | |
974 "fine_tune": self.args.fine_tune, | |
975 "use_pretrained": use_pretrained, | |
976 "epochs": self.args.epochs, | |
977 "batch_size": self.args.batch_size, | |
978 "preprocessing_num_processes": self.args.preprocessing_num_processes, | |
979 "split_probabilities": self.args.split_probabilities, | |
980 "learning_rate": self.args.learning_rate, | |
981 "random_seed": self.args.random_seed, | |
982 "early_stop": self.args.early_stop, | |
983 } | |
984 yaml_str = self.backend.prepare_config(backend_args, split_cfg) | |
985 | |
986 config_file = self.temp_dir / TEMP_CONFIG_FILENAME | |
987 config_file.write_text(yaml_str) | |
988 logger.info(f"Wrote backend config: {config_file}") | |
989 | |
990 self.backend.run_experiment( | |
991 csv_path, | |
992 config_file, | |
993 self.args.output_dir, | |
994 self.args.random_seed | |
995 ) | |
996 logger.info("Workflow completed successfully.") | |
997 self.backend.generate_plots(self.args.output_dir) | |
998 report_file = self.backend.generate_html_report( | |
999 "Image Classification Results", | |
1000 self.args.output_dir, | |
1001 backend_args, | |
1002 split_info | |
1003 ) | |
1004 logger.info(f"HTML report generated at: {report_file}") | |
1005 self.backend.convert_parquet_to_csv(self.args.output_dir) | |
1006 logger.info("Converted Parquet to CSV.") | |
1007 except Exception: | |
1008 logger.error("Workflow execution failed", exc_info=True) | |
1009 raise | |
1010 | |
1011 finally: | |
1012 self._cleanup_temp_dirs() | |
1013 | |
1014 | |
1015 def parse_learning_rate(s): | |
1016 try: | |
1017 return float(s) | |
1018 except (TypeError, ValueError): | |
1019 return None | |
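
# Hedged examples of parse_learning_rate(), used as the argparse `type` for --learning-rate:
#   parse_learning_rate("0.001") -> 0.001
#   parse_learning_rate("auto")  -> None   (Ludwig then auto-selects the learning rate)
#   parse_learning_rate(None)    -> None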


class SplitProbAction(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        # values is a list of three floats
        train, val, test = values
        total = train + val + test
        if abs(total - 1.0) > 1e-6:
            parser.error(
                f"--split-probabilities must sum to 1.0; "
                f"got {train:.3f} + {val:.3f} + {test:.3f} = {total:.3f}"
            )
        setattr(namespace, self.dest, values)


def main():

    parser = argparse.ArgumentParser(
        description="Image Classification Learner with Pluggable Backends"
    )
    parser.add_argument(
        "--csv-file", required=True, type=Path,
        help="Path to the input CSV"
    )
    parser.add_argument(
        "--image-zip", required=True, type=Path,
        help="Path to the images ZIP"
    )
    parser.add_argument(
        "--model-name", required=True,
        choices=MODEL_ENCODER_TEMPLATES.keys(),
        help="Which model template to use"
    )
    parser.add_argument(
        "--use-pretrained", action="store_true",
        help="Use pretrained weights for the model"
    )
    parser.add_argument(
        "--fine-tune", action="store_true",
        help="Enable fine-tuning"
    )
    parser.add_argument(
        "--epochs", type=int, default=10,
        help="Number of training epochs"
    )
    parser.add_argument(
        "--early-stop", type=int, default=5,
        help="Early stopping patience"
    )
    parser.add_argument(
        "--batch-size", type=int,
        help="Batch size (None = auto)"
    )
    parser.add_argument(
        "--output-dir", type=Path, default=Path("learner_output"),
        help="Where to write outputs"
    )
    parser.add_argument(
        "--validation-size", type=float, default=0.15,
        help="Fraction for validation (0.0–1.0)"
    )
    parser.add_argument(
        "--preprocessing-num-processes", type=int,
        default=max(1, (os.cpu_count() or 1) // 2),
1084 help="CPU processes for data prep" | |
1085 ) | |
1086 parser.add_argument( | |
1087 "--split-probabilities", type=float, nargs=3, | |
1088 metavar=("train", "val", "test"), | |
1089 action=SplitProbAction, | |
1090 default=[0.7, 0.1, 0.2], | |
1091 help="Random split proportions (e.g., 0.7 0.1 0.2). Only used if no split column is present." | |
1092 ) | |
1093 parser.add_argument( | |
1094 "--random-seed", type=int, default=42, | |
1095 help="Random seed used for dataset splitting (default: 42)" | |
1096 ) | |
1097 parser.add_argument( | |
1098 "--learning-rate", type=parse_learning_rate, default=None, | |
1099 help="Learning rate. If not provided, Ludwig will auto-select it." | |
1100 ) | |
1101 | |
1102 args = parser.parse_args() | |
1103 | |
1104 # -- Validation -- | |
1105 if not 0.0 <= args.validation_size <= 1.0: | |
1106 parser.error("validation-size must be between 0.0 and 1.0") | |
1107 if not args.csv_file.is_file(): | |
1108 parser.error(f"CSV not found: {args.csv_file}") | |
1109 if not args.image_zip.is_file(): | |
1110 parser.error(f"ZIP not found: {args.image_zip}") | |
1111 | |
1112 # --- Instantiate Backend and Orchestrator --- | |
1113 # Use the new LudwigDirectBackend | |
1114 backend_instance = LudwigDirectBackend() | |
1115 orchestrator = WorkflowOrchestrator(args, backend_instance) | |
1116 | |
1117 # --- Run Workflow --- | |
1118 exit_code = 0 | |
1119 try: | |
1120 orchestrator.run() | |
1121 logger.info("Main script finished successfully.") | |
1122 except Exception as e: | |
        logger.error(f"Main script failed: {e}")
        exit_code = 1
    finally:
        sys.exit(exit_code)


if __name__ == '__main__':
    try:
        import ludwig
        logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}")
    except ImportError:
        logger.error("Ludwig library not found. Please ensure Ludwig is installed ('pip install ludwig[image]')")
        sys.exit(1)

    main()
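
# Example invocation (hypothetical paths; adjust to your data):
#
#   python image_learner_cli.py \
#       --csv-file metadata.csv \
#       --image-zip images.zip \
#       --model-name resnet18 \
#       --use-pretrained --fine-tune \
#       --epochs 20 --batch-size 16 \
#       --output-dir learner_output
#
# The CSV must contain `image_path` and `label` columns, plus an optional `split` column
# (0 = train, 1 = validation, 2 = test).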