image_learner_cli.py @ 0:54b871dfc51e (draft, default, tip)

planemo upload for repository https://github.com/goeckslab/gleam.git commit b7411ff35b6228ccdfd36cd4ebd946c03ac7f7e9
author: goeckslab
date: Tue, 03 Jun 2025 21:22:11 +0000
#!/usr/bin/env python3
import argparse
import json
import logging
import os
import shutil
import sys
import tempfile
import zipfile
from pathlib import Path
from typing import Any, Dict, Optional, Protocol, Tuple

import pandas as pd
import yaml
from ludwig.globals import (
    DESCRIPTION_FILE_NAME,
    PREDICTIONS_PARQUET_FILE_NAME,
    TEST_STATISTICS_FILE_NAME,
    TRAIN_SET_METADATA_FILE_NAME,
)
from ludwig.utils.data_utils import get_split_path
from ludwig.visualize import get_visualizations_registry
from sklearn.model_selection import train_test_split
from utils import encode_image_to_base64, get_html_closing, get_html_template

# --- Constants ---
SPLIT_COLUMN_NAME = 'split'
LABEL_COLUMN_NAME = 'label'
IMAGE_PATH_COLUMN_NAME = 'image_path'
DEFAULT_SPLIT_PROBABILITIES = [0.7, 0.1, 0.2]
TEMP_CSV_FILENAME = "processed_data_for_ludwig.csv"
TEMP_CONFIG_FILENAME = "ludwig_config.yaml"
TEMP_DIR_PREFIX = "ludwig_api_work_"
MODEL_ENCODER_TEMPLATES: Dict[str, Any] = {
    'stacked_cnn': 'stacked_cnn',
    'resnet18': {'type': 'resnet', 'model_variant': 18},
    'resnet34': {'type': 'resnet', 'model_variant': 34},
    'resnet50': {'type': 'resnet', 'model_variant': 50},
    'resnet101': {'type': 'resnet', 'model_variant': 101},
    'resnet152': {'type': 'resnet', 'model_variant': 152},
    'resnext50_32x4d': {'type': 'resnext', 'model_variant': '50_32x4d'},
    'resnext101_32x8d': {'type': 'resnext', 'model_variant': '101_32x8d'},
    'resnext101_64x4d': {'type': 'resnext', 'model_variant': '101_64x4d'},
    'resnext152_32x8d': {'type': 'resnext', 'model_variant': '152_32x8d'},
    'wide_resnet50_2': {'type': 'wide_resnet', 'model_variant': '50_2'},
    'wide_resnet101_2': {'type': 'wide_resnet', 'model_variant': '101_2'},
    'wide_resnet103_2': {'type': 'wide_resnet', 'model_variant': '103_2'},
    'efficientnet_b0': {'type': 'efficientnet', 'model_variant': 'b0'},
    'efficientnet_b1': {'type': 'efficientnet', 'model_variant': 'b1'},
    'efficientnet_b2': {'type': 'efficientnet', 'model_variant': 'b2'},
    'efficientnet_b3': {'type': 'efficientnet', 'model_variant': 'b3'},
    'efficientnet_b4': {'type': 'efficientnet', 'model_variant': 'b4'},
    'efficientnet_b5': {'type': 'efficientnet', 'model_variant': 'b5'},
    'efficientnet_b6': {'type': 'efficientnet', 'model_variant': 'b6'},
    'efficientnet_b7': {'type': 'efficientnet', 'model_variant': 'b7'},
    'efficientnet_v2_s': {'type': 'efficientnet', 'model_variant': 'v2_s'},
    'efficientnet_v2_m': {'type': 'efficientnet', 'model_variant': 'v2_m'},
    'efficientnet_v2_l': {'type': 'efficientnet', 'model_variant': 'v2_l'},
    'regnet_y_400mf': {'type': 'regnet', 'model_variant': 'y_400mf'},
    'regnet_y_800mf': {'type': 'regnet', 'model_variant': 'y_800mf'},
    'regnet_y_1_6gf': {'type': 'regnet', 'model_variant': 'y_1_6gf'},
    'regnet_y_3_2gf': {'type': 'regnet', 'model_variant': 'y_3_2gf'},
    'regnet_y_8gf': {'type': 'regnet', 'model_variant': 'y_8gf'},
    'regnet_y_16gf': {'type': 'regnet', 'model_variant': 'y_16gf'},
    'regnet_y_32gf': {'type': 'regnet', 'model_variant': 'y_32gf'},
    'regnet_y_128gf': {'type': 'regnet', 'model_variant': 'y_128gf'},
    'regnet_x_400mf': {'type': 'regnet', 'model_variant': 'x_400mf'},
    'regnet_x_800mf': {'type': 'regnet', 'model_variant': 'x_800mf'},
    'regnet_x_1_6gf': {'type': 'regnet', 'model_variant': 'x_1_6gf'},
    'regnet_x_3_2gf': {'type': 'regnet', 'model_variant': 'x_3_2gf'},
    'regnet_x_8gf': {'type': 'regnet', 'model_variant': 'x_8gf'},
    'regnet_x_16gf': {'type': 'regnet', 'model_variant': 'x_16gf'},
    'regnet_x_32gf': {'type': 'regnet', 'model_variant': 'x_32gf'},
    'vgg11': {'type': 'vgg', 'model_variant': 11},
    'vgg11_bn': {'type': 'vgg', 'model_variant': '11_bn'},
    'vgg13': {'type': 'vgg', 'model_variant': 13},
    'vgg13_bn': {'type': 'vgg', 'model_variant': '13_bn'},
    'vgg16': {'type': 'vgg', 'model_variant': 16},
    'vgg16_bn': {'type': 'vgg', 'model_variant': '16_bn'},
    'vgg19': {'type': 'vgg', 'model_variant': 19},
    'vgg19_bn': {'type': 'vgg', 'model_variant': '19_bn'},
    'shufflenet_v2_x0_5': {'type': 'shufflenet_v2', 'model_variant': 'x0_5'},
    'shufflenet_v2_x1_0': {'type': 'shufflenet_v2', 'model_variant': 'x1_0'},
    'shufflenet_v2_x1_5': {'type': 'shufflenet_v2', 'model_variant': 'x1_5'},
    'shufflenet_v2_x2_0': {'type': 'shufflenet_v2', 'model_variant': 'x2_0'},
    'squeezenet1_0': {'type': 'squeezenet', 'model_variant': '1_0'},
    'squeezenet1_1': {'type': 'squeezenet', 'model_variant': '1_1'},
    'swin_t': {'type': 'swin_transformer', 'model_variant': 't'},
    'swin_s': {'type': 'swin_transformer', 'model_variant': 's'},
    'swin_b': {'type': 'swin_transformer', 'model_variant': 'b'},
    'swin_v2_t': {'type': 'swin_transformer', 'model_variant': 'v2_t'},
    'swin_v2_s': {'type': 'swin_transformer', 'model_variant': 'v2_s'},
    'swin_v2_b': {'type': 'swin_transformer', 'model_variant': 'v2_b'},
    'vit_b_16': {'type': 'vision_transformer', 'model_variant': 'b_16'},
    'vit_b_32': {'type': 'vision_transformer', 'model_variant': 'b_32'},
    'vit_l_16': {'type': 'vision_transformer', 'model_variant': 'l_16'},
    'vit_l_32': {'type': 'vision_transformer', 'model_variant': 'l_32'},
    'vit_h_14': {'type': 'vision_transformer', 'model_variant': 'h_14'},
    'convnext_tiny': {'type': 'convnext', 'model_variant': 'tiny'},
    'convnext_small': {'type': 'convnext', 'model_variant': 'small'},
    'convnext_base': {'type': 'convnext', 'model_variant': 'base'},
    'convnext_large': {'type': 'convnext', 'model_variant': 'large'},
    'maxvit_t': {'type': 'maxvit', 'model_variant': 't'},
    'alexnet': {'type': 'alexnet'},
    'googlenet': {'type': 'googlenet'},
    'inception_v3': {'type': 'inception_v3'},
    'mobilenet_v2': {'type': 'mobilenet_v2'},
    'mobilenet_v3_large': {'type': 'mobilenet_v3_large'},
    'mobilenet_v3_small': {'type': 'mobilenet_v3_small'},
}
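
# Illustrative example (not executed): prepare_config() below resolves one of
# these templates into a Ludwig image-encoder block by merging in the
# pretrained/trainable flags. For "resnet50" with use_pretrained=True and
# fine_tune=False, the merge would look roughly like:
#
#   raw = MODEL_ENCODER_TEMPLATES["resnet50"]   # {'type': 'resnet', 'model_variant': 50}
#   encoder = {**raw, "use_pretrained": True, "trainable": False}
#   # -> {'type': 'resnet', 'model_variant': 50,
#   #     'use_pretrained': True, 'trainable': False}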

# --- Logging Setup ---
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(name)s: %(message)s'
)
logger = logging.getLogger("ImageLearner")


def format_config_table_html(
    config: dict,
    split_info: Optional[str] = None,
    training_progress: Optional[dict] = None,
) -> str:
    display_keys = [
        "model_name",
        "epochs",
        "batch_size",
        "fine_tune",
        "use_pretrained",
        "learning_rate",
        "random_seed",
        "early_stop",
    ]

    rows = []

    for key in display_keys:
        val = config.get(key, "N/A")
        if key == "batch_size":
            if val is not None:
                val = int(val)
            else:
                if training_progress:
                    val = "Auto-selected batch size by Ludwig:<br>"
                    resolved_val = training_progress.get("batch_size")
                    val += (
                        f"<span style='font-size: 0.85em;'>{resolved_val}</span><br>"
                    )
                else:
                    val = "auto"
        if key == "learning_rate":
            resolved_val = None
            if val is None or val == "auto":
                if training_progress:
                    resolved_val = training_progress.get("learning_rate")
                    val = (
                        "Auto-selected learning rate by Ludwig:<br>"
                        f"<span style='font-size: 0.85em;'>{resolved_val if resolved_val else val}</span><br>"
                        "<span style='font-size: 0.85em;'>"
                        "Based on model architecture and training setup (e.g., fine-tuning).<br>"
                        "See <a href='https://ludwig.ai/latest/configuration/trainer/#trainer-parameters' "
                        "target='_blank'>Ludwig Trainer Parameters</a> for details."
                        "</span>"
                    )
                else:
                    val = (
                        "Auto-selected by Ludwig<br>"
                        "<span style='font-size: 0.85em;'>"
                        "Automatically tuned based on architecture and dataset.<br>"
                        "See <a href='https://ludwig.ai/latest/configuration/trainer/#trainer-parameters' "
                        "target='_blank'>Ludwig Trainer Parameters</a> for details."
                        "</span>"
                    )
            else:
                val = f"{val:.6f}"
        if key == "epochs":
            if (training_progress and "epoch" in training_progress
                    and val > training_progress["epoch"]):
                val = (
                    f"Because of early stopping, training "
                    f"stopped at epoch {training_progress['epoch']}."
                )

        if val is None:
            continue
        rows.append(
            f"<tr>"
            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>"
            f"{key.replace('_', ' ').title()}</td>"
            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{val}</td>"
            f"</tr>"
        )

    if split_info:
        rows.append(
            f"<tr>"
            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Data Split</td>"
            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{split_info}</td>"
            f"</tr>"
        )

    return (
        "<h2 style='text-align: center;'>Training Setup</h2>"
        "<div style='display: flex; justify-content: center;'>"
        "<table style='border-collapse: collapse; width: 60%; table-layout: auto;'>"
        "<thead><tr>"
        "<th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>Parameter</th>"
        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Value</th>"
        "</tr></thead><tbody>" + "".join(rows) + "</tbody></table></div><br>"
        "<p style='text-align: center; font-size: 0.9em;'>"
        "Model trained using Ludwig.<br>"
        "If you want to learn more about Ludwig's default settings, "
        "please check their <a href='https://ludwig.ai' target='_blank'>website (ludwig.ai)</a>."
        "</p><hr>"
    )


def format_stats_table_html(training_stats: dict, test_stats: dict) -> str:
    train_metrics = training_stats.get("training", {}).get("label", {})
    val_metrics = training_stats.get("validation", {}).get("label", {})
    test_metrics = test_stats.get("label", {})

    all_metrics = set(train_metrics) | set(val_metrics) | set(test_metrics)

    def get_last_value(stats, key):
        val = stats.get(key)
        if isinstance(val, list) and val:
            return val[-1]
        elif isinstance(val, (int, float)):
            return val
        return None

    rows = []
    for metric in sorted(all_metrics):
        t = get_last_value(train_metrics, metric)
        v = get_last_value(val_metrics, metric)
        te = get_last_value(test_metrics, metric)
        if all(x is not None for x in [t, v, te]):
            row = (
                f"<tr>"
                f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>{metric}</td>"
                f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{t:.4f}</td>"
                f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{v:.4f}</td>"
                f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{te:.4f}</td>"
                f"</tr>"
            )
            rows.append(row)

    if not rows:
        return "<p><em>No metric values found.</em></p>"

    return (
        "<h2 style='text-align: center;'>Model Performance Summary</h2>"
        "<div style='display: flex; justify-content: center;'>"
        "<table style='border-collapse: collapse; width: 80%; table-layout: fixed;'>"
        "<colgroup>"
        "<col style='width: 40%;'>"
        "<col style='width: 20%;'>"
        "<col style='width: 20%;'>"
        "<col style='width: 20%;'>"
        "</colgroup>"
        "<thead><tr>"
        "<th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>Metric</th>"
        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Train</th>"
        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Validation</th>"
        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Test</th>"
        "</tr></thead><tbody>" + "".join(rows) + "</tbody></table></div><br>"
    )


def build_tabbed_html(
    metrics_html: str,
    train_viz_html: str,
    test_viz_html: str,
) -> str:
    return f"""
<style>
.tabs {{
  display: flex;
  border-bottom: 2px solid #ccc;
  margin-bottom: 1rem;
}}
.tab {{
  padding: 10px 20px;
  cursor: pointer;
  border: 1px solid #ccc;
  border-bottom: none;
  background: #f9f9f9;
  margin-right: 5px;
  border-top-left-radius: 8px;
  border-top-right-radius: 8px;
}}
.tab.active {{
  background: white;
  font-weight: bold;
}}
.tab-content {{
  display: none;
  padding: 20px;
  border: 1px solid #ccc;
  border-top: none;
}}
.tab-content.active {{
  display: block;
}}
</style>

<div class="tabs">
  <div class="tab active" onclick="showTab('metrics')">Config & Metrics</div>
  <div class="tab" onclick="showTab('trainval')">Train/Validation Plots</div>
  <div class="tab" onclick="showTab('test')">Test Plots</div>
</div>

<div id="metrics" class="tab-content active">
  {metrics_html}
</div>
<div id="trainval" class="tab-content">
  {train_viz_html}
</div>
<div id="test" class="tab-content">
  {test_viz_html}
</div>

<script>
function showTab(id) {{
  document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active'));
  document.querySelectorAll('.tab').forEach(el => el.classList.remove('active'));
  document.getElementById(id).classList.add('active');
  document.querySelector(`.tab[onclick*="${{id}}"]`).classList.add('active');
}}
</script>
"""


def split_data_0_2(
    df: pd.DataFrame,
    split_column: str,
    validation_size: float = 0.15,
    random_state: int = 42,
    label_column: Optional[str] = None,
) -> pd.DataFrame:
    """
    Given a DataFrame whose split_column only contains {0, 2}, reassign
    a portion of the 0s to become 1s (validation). Returns a fresh DataFrame.
    """
    # Work on a copy
    out = df.copy()
    # Ensure the split column has integer dtype
    out[split_column] = pd.to_numeric(out[split_column], errors="coerce").astype(int)

    idx_train = out.index[out[split_column] == 0].tolist()

    if not idx_train:
        logger.info("No rows with split=0; nothing to do.")
        return out

    # Determine stratify array if possible
    stratify_arr = None
    if label_column and label_column in out.columns:
        # Only stratify if there are at least two classes and enough samples
        label_counts = out.loc[idx_train, label_column].value_counts()
        if label_counts.size > 1 and (label_counts.min() * validation_size) >= 1:
            stratify_arr = out.loc[idx_train, label_column]
        else:
            logger.warning("Cannot stratify (too few labels); splitting without stratify.")

    # Edge cases
    if validation_size <= 0:
        logger.info("validation_size <= 0; keeping all rows as train.")
        return out
    if validation_size >= 1:
        logger.info("validation_size >= 1; moving all train → validation.")
        out.loc[idx_train, split_column] = 1
        return out

    # Do the split
    try:
        train_idx, val_idx = train_test_split(
            idx_train,
            test_size=validation_size,
            random_state=random_state,
            stratify=stratify_arr
        )
    except ValueError as e:
        logger.warning(f"Stratified split failed ({e}); retrying without stratify.")
        train_idx, val_idx = train_test_split(
            idx_train,
            test_size=validation_size,
            random_state=random_state,
            stratify=None
        )

    # Assign new splits; rows with split=2 (test) are left untouched
    out.loc[train_idx, split_column] = 0
    out.loc[val_idx, split_column] = 1

    # Cast back to a clean integer type
    out[split_column] = out[split_column].astype(int)
    return out
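
# Illustrative usage (a minimal sketch; "demo" is a hypothetical frame): with a
# split column holding only {0, 2} and validation_size=0.25, one of the four
# training rows is reassigned to validation (split=1) and the test row stays 2.
#
#   demo = pd.DataFrame({
#       "image_path": ["a.png", "b.png", "c.png", "d.png", "e.png"],
#       "label": ["cat", "dog", "cat", "dog", "cat"],
#       "split": [0, 0, 0, 0, 2],
#   })
#   demo = split_data_0_2(demo, "split", validation_size=0.25, label_column="label")
#   # demo["split"] now holds three 0s, one 1, and one 2.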


class Backend(Protocol):
    """Interface for a machine learning backend."""
    def prepare_config(
        self,
        config_params: Dict[str, Any],
        split_config: Dict[str, Any]
    ) -> str:
        ...

    def run_experiment(
        self,
        dataset_path: Path,
        config_path: Path,
        output_dir: Path,
        random_seed: int,
    ) -> None:
        ...

    def generate_plots(
        self,
        output_dir: Path
    ) -> None:
        ...

    def generate_html_report(
        self,
        title: str,
        output_dir: str,
        config: dict,
        split_info: str
    ) -> Path:
        ...

    def convert_parquet_to_csv(
        self,
        output_dir: Path
    ) -> None:
        ...


class LudwigDirectBackend:
    """
    Backend for running Ludwig experiments directly via the internal
    experiment_cli function.
    """

    def prepare_config(
        self,
        config_params: Dict[str, Any],
        split_config: Dict[str, Any],
    ) -> str:
        """
        Build and serialize the Ludwig YAML configuration.
        """
        logger.info("LudwigDirectBackend: Preparing YAML configuration.")

        model_name = config_params.get("model_name", "resnet18")
        use_pretrained = config_params.get("use_pretrained", False)
        fine_tune = config_params.get("fine_tune", False)
        epochs = config_params.get("epochs", 10)
        batch_size = config_params.get("batch_size")
        num_processes = config_params.get("preprocessing_num_processes", 1)
        early_stop = config_params.get("early_stop", None)
        learning_rate = config_params.get("learning_rate")
        learning_rate = "auto" if learning_rate is None else float(learning_rate)
        trainable = fine_tune or (not use_pretrained)
        if not use_pretrained and not trainable:
            logger.warning("trainable=False; use_pretrained=False is ignored.")
            logger.warning("Setting trainable=True to train the model from scratch.")
            trainable = True

        # Encoder setup
        raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name)
        if isinstance(raw_encoder, dict):
            encoder_config = {
                **raw_encoder,
                "use_pretrained": use_pretrained,
                "trainable": trainable,
            }
        else:
            encoder_config = {"type": raw_encoder}

        # Trainer settings
        batch_size_cfg = batch_size or "auto"

        conf: Dict[str, Any] = {
            "model_type": "ecd",
            "input_features": [
                {
                    "name": IMAGE_PATH_COLUMN_NAME,
                    "type": "image",
                    "encoder": encoder_config,
                }
            ],
            "output_features": [
                {"name": LABEL_COLUMN_NAME, "type": "category"}
            ],
            "combiner": {"type": "concat"},
            "trainer": {
                "epochs": epochs,
                "early_stop": early_stop,
                "batch_size": batch_size_cfg,
                "learning_rate": learning_rate,
            },
            "preprocessing": {
                "split": split_config,
                "num_processes": num_processes,
                "in_memory": False,
            },
        }

        logger.debug("LudwigDirectBackend: Config dict built.")
        try:
            yaml_str = yaml.dump(conf, sort_keys=False, indent=2)
            logger.info("LudwigDirectBackend: YAML config generated.")
            return yaml_str
        except Exception:
            logger.error("LudwigDirectBackend: Failed to serialize YAML.", exc_info=True)
            raise
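
    # Illustrative example (not executed): for model_name="resnet18",
    # use_pretrained=True, fine_tune=False, a random split, and defaults
    # elsewhere, the YAML produced above would look roughly like:
    #
    #   model_type: ecd
    #   input_features:
    #   - name: image_path
    #     type: image
    #     encoder:
    #       type: resnet
    #       model_variant: 18
    #       use_pretrained: true
    #       trainable: false
    #   output_features:
    #   - name: label
    #     type: category
    #   combiner:
    #     type: concat
    #   trainer:
    #     epochs: 10
    #     early_stop: 5
    #     batch_size: auto
    #     learning_rate: auto
    #   preprocessing:
    #     split:
    #       type: random
    #       probabilities: [0.7, 0.1, 0.2]
    #     num_processes: 1
    #     in_memory: false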

    def run_experiment(
        self,
        dataset_path: Path,
        config_path: Path,
        output_dir: Path,
        random_seed: int = 42,
    ) -> None:
        """
        Invoke Ludwig's internal experiment_cli function to run the experiment.
        """
        logger.info("LudwigDirectBackend: Starting experiment execution.")

        try:
            from ludwig.experiment import experiment_cli
        except ImportError as e:
            logger.error(
                "LudwigDirectBackend: Could not import experiment_cli.",
                exc_info=True
            )
            raise RuntimeError("Ludwig import failed.") from e

        output_dir.mkdir(parents=True, exist_ok=True)

        try:
            experiment_cli(
                dataset=str(dataset_path),
                config=str(config_path),
                output_directory=str(output_dir),
                random_seed=random_seed,
            )
            logger.info(f"LudwigDirectBackend: Experiment completed. Results in {output_dir}")
        except TypeError as e:
            logger.error(
                "LudwigDirectBackend: Argument mismatch in experiment_cli call.",
                exc_info=True
            )
            raise RuntimeError("Ludwig argument error.") from e
        except Exception:
            logger.error(
                "LudwigDirectBackend: Experiment execution error.",
                exc_info=True
            )
            raise

    def get_training_progress(self, output_dir) -> Optional[Dict[str, Any]]:
        """
        Retrieve the learning rate, batch size, and last epoch from the most
        recent Ludwig run.
        Returns:
            dict with "learning_rate", "batch_size", and "epoch"
            (or None if no run or progress file is found)
        """
        output_dir = Path(output_dir)
        exp_dirs = sorted(
            output_dir.glob("experiment_run*"),
            key=lambda p: p.stat().st_mtime
        )

        if not exp_dirs:
            logger.warning(f"No experiment run directories found in {output_dir}")
            return None

        progress_file = exp_dirs[-1] / "model" / "training_progress.json"
        if not progress_file.exists():
            logger.warning(f"No training_progress.json found at {progress_file}")
            return None

        try:
            with progress_file.open("r", encoding="utf-8") as f:
                data = json.load(f)
            return {
                "learning_rate": data.get("learning_rate"),
                "batch_size": data.get("batch_size"),
                "epoch": data.get("epoch"),
            }
        except Exception as e:
            logger.warning(f"Failed to read training progress info: {e}")
            return {}
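
    # Illustrative example (not executed): the slice of training_progress.json
    # consumed above. The file is written by Ludwig during training and may
    # contain more keys; only these three are read here:
    #
    #   {"learning_rate": 0.0001, "batch_size": 128, "epoch": 7}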

    def convert_parquet_to_csv(self, output_dir: Path):
        """Convert the predictions Parquet file to CSV."""
        output_dir = Path(output_dir)
        exp_dirs = sorted(
            output_dir.glob("experiment_run*"),
            key=lambda p: p.stat().st_mtime
        )
        if not exp_dirs:
            logger.warning(f"No experiment run dirs found in {output_dir}")
            return
        exp_dir = exp_dirs[-1]
        parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
        csv_path = exp_dir / "predictions.csv"
        try:
            df = pd.read_parquet(parquet_path)
            df.to_csv(csv_path, index=False)
            logger.info(f"Converted Parquet to CSV: {csv_path}")
        except Exception as e:
            logger.error(f"Error converting Parquet to CSV: {e}")

    def generate_plots(self, output_dir: Path) -> None:
        """
        Generate _all_ registered Ludwig visualizations for the latest
        experiment run.
        """
        logger.info("Generating all Ludwig visualizations…")

        test_plots = {
            'compare_performance',
            'compare_classifiers_performance_from_prob',
            'compare_classifiers_performance_from_pred',
            'compare_classifiers_performance_changing_k',
            'compare_classifiers_multiclass_multimetric',
            'compare_classifiers_predictions',
            'confidence_thresholding_2thresholds_2d',
            'confidence_thresholding_2thresholds_3d',
            'confidence_thresholding',
            'confidence_thresholding_data_vs_acc',
            'binary_threshold_vs_metric',
            'roc_curves',
            'roc_curves_from_test_statistics',
            'calibration_1_vs_all',
            'calibration_multiclass',
            'confusion_matrix',
            'frequency_vs_f1',
        }
        train_plots = {
            'learning_curves',
            'compare_classifiers_performance_subset',
        }

        # 1) find the most recent experiment directory
        output_dir = Path(output_dir)
        exp_dirs = sorted(
            output_dir.glob("experiment_run*"),
            key=lambda p: p.stat().st_mtime
        )
        if not exp_dirs:
            logger.warning(f"No experiment run dirs found in {output_dir}")
            return
        exp_dir = exp_dirs[-1]

        # 2) ensure the viz output subfolders exist
        viz_dir = exp_dir / "visualizations"
        viz_dir.mkdir(exist_ok=True)
        train_viz = viz_dir / "train"
        test_viz = viz_dir / "test"
        train_viz.mkdir(parents=True, exist_ok=True)
        test_viz.mkdir(parents=True, exist_ok=True)

        # 3) helper to check file existence
        def _check(p: Path) -> Optional[str]:
            return str(p) if p.exists() else None

        # 4) gather standard Ludwig output files
        training_stats = _check(exp_dir / "training_statistics.json")
        test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME)
        probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME)
        gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME)

        # 5) try to read the original dataset & split file from description.json
        dataset_path = None
        split_file = None
        desc = exp_dir / DESCRIPTION_FILE_NAME
        if desc.exists():
            with open(desc, "r") as f:
                cfg = json.load(f)
            dataset_path = _check(Path(cfg.get("dataset", "")))
            split_file = _check(Path(get_split_path(cfg.get("dataset", ""))))

        # 6) infer the output feature name
        output_feature = ""
        if desc.exists():
            try:
                output_feature = cfg["config"]["output_features"][0]["name"]
            except Exception:
                pass
        if not output_feature and test_stats:
            with open(test_stats, "r") as f:
                stats = json.load(f)
            output_feature = next(iter(stats.keys()), "")

        # 7) loop through every registered viz
        viz_registry = get_visualizations_registry()
        for viz_name, viz_func in viz_registry.items():
            if viz_name in train_plots:
                viz_dir_plot = train_viz
            elif viz_name in test_plots:
                viz_dir_plot = test_viz
            else:
                # Not classified as a train or test plot; skip rather than
                # write into a nonexistent output directory.
                continue

            try:
                viz_func(
                    training_statistics=[training_stats] if training_stats else [],
                    test_statistics=[test_stats] if test_stats else [],
                    probabilities=[probs_path] if probs_path else [],
                    output_feature_name=output_feature,
                    ground_truth_split=2,
                    top_n_classes=[0],
                    top_k=3,
                    ground_truth_metadata=gt_metadata,
                    ground_truth=dataset_path,
                    split_file=split_file,
                    output_directory=str(viz_dir_plot),
                    normalize=False,
                    file_format="png",
                )
                logger.info(f"✔ Generated {viz_name}")
            except Exception as e:
                logger.warning(f"✘ Skipped {viz_name}: {e}")

        logger.info(f"All visualizations written to {viz_dir}")

    def generate_html_report(
        self,
        title: str,
        output_dir: str,
        config: dict,
        split_info: str,
    ) -> Path:
        """
        Assemble an HTML report from visualizations under the train/ and test/ folders.
        """
        cwd = Path.cwd()
        report_name = title.lower().replace(" ", "_") + "_report.html"
        report_path = cwd / report_name
        output_dir = Path(output_dir)

        # Find the latest experiment dir
        exp_dirs = sorted(output_dir.glob("experiment_run*"), key=lambda p: p.stat().st_mtime)
        if not exp_dirs:
            raise RuntimeError(f"No 'experiment_run*' dirs found in {output_dir}")
        exp_dir = exp_dirs[-1]

        base_viz_dir = exp_dir / "visualizations"
        train_viz_dir = base_viz_dir / "train"
        test_viz_dir = base_viz_dir / "test"

        html = get_html_template()
        html += f"<h1>{title}</h1>"

        metrics_html = ""

        # Load and embed the metrics table (training/val/test stats)
        try:
            train_stats_path = exp_dir / "training_statistics.json"
            test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME
            if train_stats_path.exists() and test_stats_path.exists():
                with open(train_stats_path) as f:
                    train_stats = json.load(f)
                with open(test_stats_path) as f:
                    test_stats = json.load(f)
                output_feature = next(iter(train_stats.keys()), "")
                if output_feature:
                    metrics_html += format_stats_table_html(train_stats, test_stats)
        except Exception as e:
            logger.warning(f"Could not load stats for HTML report: {e}")

        config_html = ""
        training_progress = self.get_training_progress(output_dir)
        try:
            config_html = format_config_table_html(config, split_info, training_progress)
        except Exception as e:
            logger.warning(f"Could not load config for HTML report: {e}")

        def render_img_section(title: str, dir_path: Path) -> str:
            if not dir_path.exists():
                return f"<h2>{title}</h2><p><em>Directory not found.</em></p>"
            imgs = sorted(dir_path.glob("*.png"))
            if not imgs:
                return f"<h2>{title}</h2><p><em>No plots found.</em></p>"

            section_html = f"<h2 style='text-align: center;'>{title}</h2><div>"
            for img in imgs:
                b64 = encode_image_to_base64(str(img))
                section_html += (
                    f'<div class="plot" style="margin-bottom:20px;text-align:center;">'
                    f"<h3>{img.stem.replace('_', ' ').title()}</h3>"
                    f'<img src="data:image/png;base64,{b64}" '
                    'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />'
                    "</div>"
                )
            section_html += "</div>"
            return section_html

        train_plots_html = render_img_section("Training & Validation Visualizations", train_viz_dir)
        test_plots_html = render_img_section("Test Visualizations", test_viz_dir)
        html += build_tabbed_html(config_html + metrics_html, train_plots_html, test_plots_html)
        html += get_html_closing()

        try:
            with open(report_path, "w") as f:
                f.write(html)
            logger.info(f"HTML report generated at: {report_path}")
        except Exception as e:
            logger.error(f"Failed to write HTML report: {e}")
            raise

        return report_path


class WorkflowOrchestrator:
    """
    Manages the image-classification workflow:
      1. Creates temp dirs
      2. Extracts images
      3. Prepares data (CSV + splits)
      4. Renders a backend config
      5. Runs the experiment
      6. Cleans up
    """

    def __init__(self, args: argparse.Namespace, backend: Backend):
        self.args = args
        self.backend = backend
        self.temp_dir: Optional[Path] = None
        self.image_extract_dir: Optional[Path] = None
        logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}")

    def _create_temp_dirs(self) -> None:
        """Create temporary output and image extraction directories."""
        try:
            self.temp_dir = Path(tempfile.mkdtemp(
                dir=self.args.output_dir,
                prefix=TEMP_DIR_PREFIX
            ))
            self.image_extract_dir = self.temp_dir / "images"
            self.image_extract_dir.mkdir()
            logger.info(f"Created temp directory: {self.temp_dir}")
        except Exception:
            logger.error("Failed to create temporary directories", exc_info=True)
            raise

    def _extract_images(self) -> None:
        """Extract images from the ZIP into the temp image directory."""
        if self.image_extract_dir is None:
            raise RuntimeError("Temp image directory not initialized.")
        logger.info(f"Extracting images from {self.args.image_zip} → {self.image_extract_dir}")
        try:
            with zipfile.ZipFile(self.args.image_zip, "r") as z:
                z.extractall(self.image_extract_dir)
            logger.info("Image extraction complete.")
        except Exception:
            logger.error("Error extracting zip file", exc_info=True)
            raise

    def _prepare_data(self) -> Tuple[Path, Dict[str, Any], str]:
        """
        Load the CSV, update image paths, handle splits, and write the prepared CSV.
        Returns:
            final_csv_path: Path to the prepared CSV
            split_config: Dict for backend split settings
            split_info: Human-readable description of the split
        """
        if not self.temp_dir or not self.image_extract_dir:
            raise RuntimeError("Temp dirs not initialized before data prep.")

        # 1) Load
        try:
            df = pd.read_csv(self.args.csv_file)
            logger.info(f"Loaded CSV: {self.args.csv_file}")
        except Exception:
            logger.error("Error loading CSV file", exc_info=True)
            raise

        # 2) Validate columns
        required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME}
        missing = required - set(df.columns)
        if missing:
            raise ValueError(f"Missing CSV columns: {', '.join(missing)}")

        # 3) Update image paths
        try:
            df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply(
                lambda p: str((self.image_extract_dir / p).resolve())
            )
        except Exception:
            logger.error("Error updating image paths", exc_info=True)
            raise

        # 4) Handle splits
        if SPLIT_COLUMN_NAME in df.columns:
            df, split_config, split_info = self._process_fixed_split(df)
        else:
            logger.info("No split column; using random split")
            split_config = {
                "type": "random",
                "probabilities": self.args.split_probabilities
            }
            split_info = (
                f"No split column in CSV. Used random split: "
                f"{[int(p * 100) for p in self.args.split_probabilities]}% for train/val/test."
            )

        # 5) Write out the prepared CSV
        final_csv = Path(TEMP_CSV_FILENAME)
        try:
            df.to_csv(final_csv, index=False)
            logger.info(f"Saved prepared data to {final_csv}")
        except Exception:
            logger.error("Error saving prepared CSV", exc_info=True)
            raise

        return final_csv, split_config, split_info

    def _process_fixed_split(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, Any], str]:
        """Process a fixed split column (0=train, 1=val, 2=test)."""
        logger.info(f"Fixed split column '{SPLIT_COLUMN_NAME}' detected.")
        try:
            col = df[SPLIT_COLUMN_NAME]
            df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype(pd.Int64Dtype())
            if df[SPLIT_COLUMN_NAME].isna().any():
                logger.warning("Split column contains non-numeric/missing values.")

            unique = set(df[SPLIT_COLUMN_NAME].dropna().unique())
            logger.info(f"Unique split values: {unique}")

            if unique == {0, 2}:
                df = split_data_0_2(
                    df, SPLIT_COLUMN_NAME,
                    validation_size=self.args.validation_size,
                    label_column=LABEL_COLUMN_NAME,
                    random_state=self.args.random_seed
                )
                split_info = (
                    "Detected a split column (with values 0 and 2) in the input CSV. "
                    f"Used this column as a base and "
                    f"reassigned {self.args.validation_size * 100:.1f}% "
                    "of the training set (originally labeled 0) to validation (labeled 1)."
                )
                logger.info("Applied custom 0/2 split.")
            elif unique.issubset({0, 1, 2}):
                split_info = "Used user-defined split column from CSV."
                logger.info("Using fixed split as-is.")
            else:
                raise ValueError(f"Unexpected split values: {unique}")

            return df, {"type": "fixed", "column": SPLIT_COLUMN_NAME}, split_info

        except Exception:
            logger.error("Error processing fixed split", exc_info=True)
            raise

    def _cleanup_temp_dirs(self) -> None:
        """Remove any temporary directories."""
        if self.temp_dir and self.temp_dir.exists():
            logger.info(f"Cleaning up temp directory: {self.temp_dir}")
            shutil.rmtree(self.temp_dir, ignore_errors=True)
        self.temp_dir = None
        self.image_extract_dir = None

    def run(self) -> None:
        """Execute the full workflow end-to-end."""
        logger.info("Starting workflow...")
        self.args.output_dir.mkdir(parents=True, exist_ok=True)

        try:
            self._create_temp_dirs()
            self._extract_images()
            csv_path, split_cfg, split_info = self._prepare_data()

            use_pretrained = self.args.use_pretrained or self.args.fine_tune

            backend_args = {
                "model_name": self.args.model_name,
                "fine_tune": self.args.fine_tune,
                "use_pretrained": use_pretrained,
                "epochs": self.args.epochs,
                "batch_size": self.args.batch_size,
                "preprocessing_num_processes": self.args.preprocessing_num_processes,
                "split_probabilities": self.args.split_probabilities,
                "learning_rate": self.args.learning_rate,
                "random_seed": self.args.random_seed,
                "early_stop": self.args.early_stop,
            }
            yaml_str = self.backend.prepare_config(backend_args, split_cfg)

            config_file = self.temp_dir / TEMP_CONFIG_FILENAME
            config_file.write_text(yaml_str)
            logger.info(f"Wrote backend config: {config_file}")

            self.backend.run_experiment(
                csv_path,
                config_file,
                self.args.output_dir,
                self.args.random_seed
            )
            logger.info("Workflow completed successfully.")
            self.backend.generate_plots(self.args.output_dir)
            report_file = self.backend.generate_html_report(
                "Image Classification Results",
                self.args.output_dir,
                backend_args,
                split_info
            )
            logger.info(f"HTML report generated at: {report_file}")
            self.backend.convert_parquet_to_csv(self.args.output_dir)
            logger.info("Converted Parquet to CSV.")
        except Exception:
            logger.error("Workflow execution failed", exc_info=True)
            raise
        finally:
            self._cleanup_temp_dirs()


def parse_learning_rate(s):
    try:
        return float(s)
    except (TypeError, ValueError):
        return None


class SplitProbAction(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        # values is a list of three floats
        train, val, test = values
        total = train + val + test
        if abs(total - 1.0) > 1e-6:
            parser.error(
                f"--split-probabilities must sum to 1.0; "
                f"got {train:.3f} + {val:.3f} + {test:.3f} = {total:.3f}"
            )
        setattr(namespace, self.dest, values)


def main():
    parser = argparse.ArgumentParser(
        description="Image Classification Learner with Pluggable Backends"
    )
    parser.add_argument(
        "--csv-file", required=True, type=Path,
        help="Path to the input CSV"
    )
    parser.add_argument(
        "--image-zip", required=True, type=Path,
        help="Path to the images ZIP"
    )
    parser.add_argument(
        "--model-name", required=True,
        choices=MODEL_ENCODER_TEMPLATES.keys(),
        help="Which model template to use"
    )
    parser.add_argument(
        "--use-pretrained", action="store_true",
        help="Use pretrained weights for the model"
    )
    parser.add_argument(
        "--fine-tune", action="store_true",
        help="Enable fine-tuning"
    )
    parser.add_argument(
        "--epochs", type=int, default=10,
        help="Number of training epochs"
    )
    parser.add_argument(
        "--early-stop", type=int, default=5,
        help="Early stopping patience"
    )
    parser.add_argument(
        "--batch-size", type=int,
        help="Batch size (None = auto)"
    )
    parser.add_argument(
        "--output-dir", type=Path, default=Path("learner_output"),
        help="Where to write outputs"
    )
    parser.add_argument(
        "--validation-size", type=float, default=0.15,
        help="Fraction for validation (0.0–1.0)"
    )
    parser.add_argument(
        "--preprocessing-num-processes", type=int,
        default=max(1, (os.cpu_count() or 1) // 2),
        help="CPU processes for data prep"
    )
    parser.add_argument(
        "--split-probabilities", type=float, nargs=3,
        metavar=("train", "val", "test"),
        action=SplitProbAction,
        default=[0.7, 0.1, 0.2],
        help="Random split proportions (e.g., 0.7 0.1 0.2). Only used if no split column is present."
    )
    parser.add_argument(
        "--random-seed", type=int, default=42,
        help="Random seed used for dataset splitting (default: 42)"
    )
    parser.add_argument(
        "--learning-rate", type=parse_learning_rate, default=None,
        help="Learning rate. If not provided, Ludwig will auto-select it."
    )

    args = parser.parse_args()

    # --- Validation ---
    if not 0.0 <= args.validation_size <= 1.0:
        parser.error("validation-size must be between 0.0 and 1.0")
    if not args.csv_file.is_file():
        parser.error(f"CSV not found: {args.csv_file}")
    if not args.image_zip.is_file():
        parser.error(f"ZIP not found: {args.image_zip}")

    # --- Instantiate Backend and Orchestrator ---
    backend_instance = LudwigDirectBackend()
    orchestrator = WorkflowOrchestrator(args, backend_instance)

    # --- Run Workflow ---
    exit_code = 0
    try:
        orchestrator.run()
        logger.info("Main script finished successfully.")
    except Exception as e:
        logger.error(f"Main script failed: {e}")
        exit_code = 1
    finally:
        sys.exit(exit_code)


if __name__ == '__main__':
    try:
        import ludwig
        logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}")
    except ImportError:
        logger.error("Ludwig library not found. Please ensure Ludwig is installed ('pip install ludwig[image]')")
        sys.exit(1)

    main()
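
# Example invocation (a sketch; the file names and output directory are
# placeholders, not files shipped with this tool):
#
#   python image_learner_cli.py \
#       --csv-file metadata.csv \
#       --image-zip images.zip \
#       --model-name resnet50 \
#       --use-pretrained \
#       --epochs 20 \
#       --output-dir learner_output
#
# metadata.csv must contain image_path and label columns; an optional split
# column (0=train, 1=val, 2=test) overrides --split-probabilities.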