comparison: image_learner_cli.py @ 9:9e912fce264c (draft, default, tip)

planemo upload for repository https://github.com/goeckslab/gleam.git
commit eace0d7c2b2939029c052991d238a54947d2e191

author    goeckslab
date      Wed, 27 Aug 2025 21:02:48 +0000
parents   85e6f4b2ad18
children  (none)
--- a/image_learner_cli.py (8:85e6f4b2ad18)
+++ b/image_learner_cli.py (9:9e912fce264c)
@@ -19,11 +19,11 @@
     METRIC_DISPLAY_NAMES,
     MODEL_ENCODER_TEMPLATES,
     SPLIT_COLUMN_NAME,
     TEMP_CONFIG_FILENAME,
     TEMP_CSV_FILENAME,
-    TEMP_DIR_PREFIX
+    TEMP_DIR_PREFIX,
 )
 from ludwig.globals import (
     DESCRIPTION_FILE_NAME,
     PREDICTIONS_PARQUET_FILE_NAME,
     TEST_STATISTICS_FILE_NAME,
@@ -36,17 +36,17 @@
 from utils import (
     build_tabbed_html,
     encode_image_to_base64,
     get_html_closing,
     get_html_template,
-    get_metrics_help_modal
+    get_metrics_help_modal,
 )
 
 # --- Logging Setup ---
 logging.basicConfig(
     level=logging.INFO,
-    format='%(asctime)s %(levelname)s %(name)s: %(message)s',
+    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
 )
 logger = logging.getLogger("ImageLearner")
 
 
 def format_config_table_html(
65 "learning_rate", 65 "learning_rate",
66 "random_seed", 66 "random_seed",
67 "early_stop", 67 "early_stop",
68 "threshold", 68 "threshold",
69 ] 69 ]
70
70 rows = [] 71 rows = []
72
71 for key in display_keys: 73 for key in display_keys:
72 val = config.get(key, None) 74 val = config.get(key, None)
73 if key == "threshold": 75 if key == "threshold":
74 if output_type != "binary": 76 if output_type != "binary":
75 continue 77 continue
@@ -132,11 +134,13 @@
                 )
             else:
                 val_str = val
         else:
             val_str = val if val is not None else "N/A"
-        if val_str == "N/A" and key not in ["task_type"]:  # Skip if N/A for non-essential
+        if val_str == "N/A" and key not in [
+            "task_type"
+        ]:  # Skip if N/A for non-essential
             continue
         rows.append(
             f"<tr>"
             f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>"
             f"{key.replace('_', ' ').title()}</td>"
@@ -164,11 +168,11 @@
     <thead><tr>
     <th style="padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Parameter</th>
     <th style="padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Value</th>
     </tr></thead>
     <tbody>
-    {''.join(rows)}
+    {"".join(rows)}
     </tbody>
     </table>
     </div><br>
     <p style="text-align: center; font-size: 0.9em;">
     Model trained using <a href="https://ludwig.ai/" target="_blank" rel="noopener noreferrer">Ludwig</a>.
249 "accuracy_micro": get_last_value(label_stats, "accuracy_micro"), 253 "accuracy_micro": get_last_value(label_stats, "accuracy_micro"),
250 "loss": get_last_value(label_stats, "loss"), 254 "loss": get_last_value(label_stats, "loss"),
251 "roc_auc": get_last_value(label_stats, "roc_auc"), 255 "roc_auc": get_last_value(label_stats, "roc_auc"),
252 "hits_at_k": get_last_value(label_stats, "hits_at_k"), 256 "hits_at_k": get_last_value(label_stats, "hits_at_k"),
253 } 257 }
258
254 # Test metrics: dynamic extraction according to exclusions 259 # Test metrics: dynamic extraction according to exclusions
255 test_label_stats = test_stats.get("label", {}) 260 test_label_stats = test_stats.get("label", {})
256 if not test_label_stats: 261 if not test_label_stats:
257 logging.warning("No label statistics found for test split") 262 logging.warning("No label statistics found for test split")
258 else: 263 else:
259 combined_stats = test_stats.get("combined", {}) 264 combined_stats = test_stats.get("combined", {})
260 overall_stats = test_label_stats.get("overall_stats", {}) 265 overall_stats = test_label_stats.get("overall_stats", {})
266
261 # Define exclusions 267 # Define exclusions
262 if output_type == "binary": 268 if output_type == "binary":
263 exclude = {"per_class_stats", "precision_recall_curve", "roc_curve"} 269 exclude = {"per_class_stats", "precision_recall_curve", "roc_curve"}
264 else: 270 else:
265 exclude = {"per_class_stats", "confusion_matrix"} 271 exclude = {"per_class_stats", "confusion_matrix"}
272
266 # 1. Get all scalar test_label_stats not excluded 273 # 1. Get all scalar test_label_stats not excluded
267 test_metrics = {} 274 test_metrics = {}
268 for k, v in test_label_stats.items(): 275 for k, v in test_label_stats.items():
269 if k in exclude: 276 if k in exclude:
270 continue 277 continue
271 if k == "overall_stats": 278 if k == "overall_stats":
272 continue 279 continue
273 if isinstance(v, (int, float, str, bool)): 280 if isinstance(v, (int, float, str, bool)):
274 test_metrics[k] = v 281 test_metrics[k] = v
282
275 # 2. Add overall_stats (flattened) 283 # 2. Add overall_stats (flattened)
276 for k, v in overall_stats.items(): 284 for k, v in overall_stats.items():
277 test_metrics[k] = v 285 test_metrics[k] = v
286
278 # 3. Optionally include combined/loss if present and not already 287 # 3. Optionally include combined/loss if present and not already
279 if "loss" in combined_stats and "loss" not in test_metrics: 288 if "loss" in combined_stats and "loss" not in test_metrics:
280 test_metrics["loss"] = combined_stats["loss"] 289 test_metrics["loss"] = combined_stats["loss"]
281 metrics["test"] = test_metrics 290 metrics["test"] = test_metrics
282 return metrics 291 return metrics
@@ -313,12 +322,14 @@
         t = all_metrics["training"].get(metric_key)
         v = all_metrics["validation"].get(metric_key)
         te = all_metrics["test"].get(metric_key)
         if all(x is not None for x in [t, v, te]):
             rows.append([display_name, f"{t:.4f}", f"{v:.4f}", f"{te:.4f}"])
+
     if not rows:
         return "<table><tr><td>No metric values found.</td></tr></table>"
+
     html = (
         "<h2 style='text-align: center;'>Model Performance Summary</h2>"
         "<div style='display: flex; justify-content: center;'>"
         "<table class='performance-summary' style='border-collapse: collapse;'>"
         "<thead><tr>"
329 "</tr></thead><tbody>" 340 "</tr></thead><tbody>"
330 ) 341 )
331 for row in rows: 342 for row in rows:
332 html += generate_table_row( 343 html += generate_table_row(
333 row, 344 row,
334 "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;" 345 "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;",
335 ) 346 )
336 html += "</tbody></table></div><br>" 347 html += "</tbody></table></div><br>"
337 return html 348 return html
338 349
339 350
@@ -355,12 +366,14 @@
         )
         t = all_metrics["training"].get(metric_key)
         v = all_metrics["validation"].get(metric_key)
         if t is not None and v is not None:
             rows.append([display_name, f"{t:.4f}", f"{v:.4f}"])
+
     if not rows:
         return "<table><tr><td>No metric values found for Train/Validation.</td></tr></table>"
+
     html = (
         "<h2 style='text-align: center;'>Train/Validation Performance Summary</h2>"
         "<div style='display: flex; justify-content: center;'>"
         "<table class='performance-summary' style='border-collapse: collapse;'>"
         "<thead><tr>"
370 "</tr></thead><tbody>" 383 "</tr></thead><tbody>"
371 ) 384 )
372 for row in rows: 385 for row in rows:
373 html += generate_table_row( 386 html += generate_table_row(
374 row, 387 row,
375 "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;" 388 "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;",
376 ) 389 )
377 html += "</tbody></table></div><br>" 390 html += "</tbody></table></div><br>"
378 return html 391 return html
379 392
380 393
@@ -391,12 +404,14 @@
     for key in sorted(test_metrics.keys()):
         display_name = METRIC_DISPLAY_NAMES.get(key, key.replace("_", " ").title())
         value = test_metrics[key]
         if value is not None:
             rows.append([display_name, f"{value:.4f}"])
+
     if not rows:
         return "<table><tr><td>No test metric values found.</td></tr></table>"
+
     html = (
         "<h2 style='text-align: center;'>Test Performance Summary</h2>"
         "<div style='display: flex; justify-content: center;'>"
         "<table class='performance-summary' style='border-collapse: collapse;'>"
         "<thead><tr>"
405 "</tr></thead><tbody>" 420 "</tr></thead><tbody>"
406 ) 421 )
407 for row in rows: 422 for row in rows:
408 html += generate_table_row( 423 html += generate_table_row(
409 row, 424 row,
410 "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;" 425 "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;",
411 ) 426 )
412 html += "</tbody></table></div><br>" 427 html += "</tbody></table></div><br>"
413 return html 428 return html
414 429
415 430
@@ -434,14 +449,18 @@
     if label_counts.size > 1:
         # Force stratify even with fewer samples - adjust validation_size if needed
         min_samples_per_class = label_counts.min()
         if min_samples_per_class * validation_size < 1:
             # Adjust validation_size to ensure at least 1 sample per class, but do not exceed original validation_size
-            adjusted_validation_size = min(validation_size, 1.0 / min_samples_per_class)
+            adjusted_validation_size = min(
+                validation_size, 1.0 / min_samples_per_class
+            )
             if adjusted_validation_size != validation_size:
                 validation_size = adjusted_validation_size
-                logger.info(f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation")
+                logger.info(
+                    f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation"
+                )
         stratify_arr = out.loc[idx_train, label_column]
         logger.info("Using stratified split for validation set")
     else:
         logger.warning("Only one label class found; cannot stratify")
     if validation_size <= 0:
484 """Create a stratified random split when no split column exists.""" 503 """Create a stratified random split when no split column exists."""
485 out = df.copy() 504 out = df.copy()
486 # initialize split column 505 # initialize split column
487 out[split_column] = 0 506 out[split_column] = 0
488 if not label_column or label_column not in out.columns: 507 if not label_column or label_column not in out.columns:
489 logger.warning("No label column found; using random split without stratification") 508 logger.warning(
509 "No label column found; using random split without stratification"
510 )
490 # fall back to simple random assignment 511 # fall back to simple random assignment
491 indices = out.index.tolist() 512 indices = out.index.tolist()
492 np.random.seed(random_state) 513 np.random.seed(random_state)
493 np.random.shuffle(indices) 514 np.random.shuffle(indices)
494 n_total = len(indices) 515 n_total = len(indices)
@@ -527,11 +548,13 @@
         test_size=split_probabilities[2],
         random_state=random_state,
         stratify=out[label_column],
     )
     # second split: separate training and validation from remaining data
-    val_size_adjusted = split_probabilities[1] / (split_probabilities[0] + split_probabilities[1])
+    val_size_adjusted = split_probabilities[1] / (
+        split_probabilities[0] + split_probabilities[1]
+    )
     train_idx, val_idx = train_test_split(
         train_val_idx,
         test_size=val_size_adjusted,
         random_state=random_state,
         stratify=out.loc[train_val_idx, label_column],
@@ -539,9 +562,11 @@
     # assign split values
     out.loc[train_idx, split_column] = 0
     out.loc[val_idx, split_column] = 1
     out.loc[test_idx, split_column] = 2
     logger.info("Successfully applied stratified random split")
-    logger.info(f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}")
+    logger.info(
+        f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}"
+    )
     return out.astype({split_column: int})
 
 
@@ -548,7 +573,8 @@
 class Backend(Protocol):
     """Interface for a machine learning backend."""
+
     def prepare_config(
         self,
         config_params: Dict[str, Any],
         split_config: Dict[str, Any],
     ) -> str:
@@ -576,16 +602,18 @@
         ...
 
 
 class LudwigDirectBackend:
     """Backend for running Ludwig experiments directly via the internal experiment_cli function."""
+
     def prepare_config(
         self,
         config_params: Dict[str, Any],
         split_config: Dict[str, Any],
     ) -> str:
         logger.info("LudwigDirectBackend: Preparing YAML configuration.")
+
         model_name = config_params.get("model_name", "resnet18")
         use_pretrained = config_params.get("use_pretrained", False)
         fine_tune = config_params.get("fine_tune", False)
         if use_pretrained:
             trainable = bool(fine_tune)
604 "use_pretrained": use_pretrained, 632 "use_pretrained": use_pretrained,
605 "trainable": trainable, 633 "trainable": trainable,
606 } 634 }
607 else: 635 else:
608 encoder_config = {"type": raw_encoder} 636 encoder_config = {"type": raw_encoder}
637
609 batch_size_cfg = batch_size or "auto" 638 batch_size_cfg = batch_size or "auto"
639
610 label_column_path = config_params.get("label_column_data_path") 640 label_column_path = config_params.get("label_column_data_path")
611 label_series = None 641 label_series = None
612 if label_column_path is not None and Path(label_column_path).exists(): 642 if label_column_path is not None and Path(label_column_path).exists():
613 try: 643 try:
614 label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME] 644 label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME]
615 except Exception as e: 645 except Exception as e:
616 logger.warning(f"Could not read label column for task detection: {e}") 646 logger.warning(f"Could not read label column for task detection: {e}")
647
617 if ( 648 if (
618 label_series is not None 649 label_series is not None
619 and ptypes.is_numeric_dtype(label_series.dtype) 650 and ptypes.is_numeric_dtype(label_series.dtype)
620 and label_series.nunique() > 10 651 and label_series.nunique() > 10
621 ): 652 ):
622 task_type = "regression" 653 task_type = "regression"
623 else: 654 else:
624 task_type = "classification" 655 task_type = "classification"
656
625 config_params["task_type"] = task_type 657 config_params["task_type"] = task_type
658
626 image_feat: Dict[str, Any] = { 659 image_feat: Dict[str, Any] = {
627 "name": IMAGE_PATH_COLUMN_NAME, 660 "name": IMAGE_PATH_COLUMN_NAME,
628 "type": "image", 661 "type": "image",
629 "encoder": encoder_config, 662 "encoder": encoder_config,
630 } 663 }
631 if config_params.get("augmentation") is not None: 664 if config_params.get("augmentation") is not None:
632 image_feat["augmentation"] = config_params["augmentation"] 665 image_feat["augmentation"] = config_params["augmentation"]
666
633 if task_type == "regression": 667 if task_type == "regression":
634 output_feat = { 668 output_feat = {
635 "name": LABEL_COLUMN_NAME, 669 "name": LABEL_COLUMN_NAME,
636 "type": "number", 670 "type": "number",
637 "decoder": {"type": "regressor"}, 671 "decoder": {"type": "regressor"},
643 "r2", 677 "r2",
644 ] 678 ]
645 }, 679 },
646 } 680 }
647 val_metric = config_params.get("validation_metric", "mean_squared_error") 681 val_metric = config_params.get("validation_metric", "mean_squared_error")
682
648 else: 683 else:
649 num_unique_labels = ( 684 num_unique_labels = (
650 label_series.nunique() if label_series is not None else 2 685 label_series.nunique() if label_series is not None else 2
651 ) 686 )
652 output_type = "binary" if num_unique_labels == 2 else "category" 687 output_type = "binary" if num_unique_labels == 2 else "category"
653 output_feat = {"name": LABEL_COLUMN_NAME, "type": output_type} 688 output_feat = {"name": LABEL_COLUMN_NAME, "type": output_type}
654 if output_type == "binary" and config_params.get("threshold") is not None: 689 if output_type == "binary" and config_params.get("threshold") is not None:
655 output_feat["threshold"] = float(config_params["threshold"]) 690 output_feat["threshold"] = float(config_params["threshold"])
656 val_metric = None 691 val_metric = None
692
657 conf: Dict[str, Any] = { 693 conf: Dict[str, Any] = {
658 "model_type": "ecd", 694 "model_type": "ecd",
659 "input_features": [image_feat], 695 "input_features": [image_feat],
660 "output_features": [output_feat], 696 "output_features": [output_feat],
661 "combiner": {"type": "concat"}, 697 "combiner": {"type": "concat"},
671 "split": split_config, 707 "split": split_config,
672 "num_processes": num_processes, 708 "num_processes": num_processes,
673 "in_memory": False, 709 "in_memory": False,
674 }, 710 },
675 } 711 }
712
676 logger.debug("LudwigDirectBackend: Config dict built.") 713 logger.debug("LudwigDirectBackend: Config dict built.")
677 try: 714 try:
678 yaml_str = yaml.dump(conf, sort_keys=False, indent=2) 715 yaml_str = yaml.dump(conf, sort_keys=False, indent=2)
679 logger.info("LudwigDirectBackend: YAML config generated.") 716 logger.info("LudwigDirectBackend: YAML config generated.")
680 return yaml_str 717 return yaml_str
@@ -692,19 +729,22 @@
         output_dir: Path,
         random_seed: int = 42,
     ) -> None:
         """Invoke Ludwig's internal experiment_cli function to run the experiment."""
         logger.info("LudwigDirectBackend: Starting experiment execution.")
+
         try:
             from ludwig.experiment import experiment_cli
         except ImportError as e:
             logger.error(
                 "LudwigDirectBackend: Could not import experiment_cli.",
                 exc_info=True,
             )
             raise RuntimeError("Ludwig import failed.") from e
+
         output_dir.mkdir(parents=True, exist_ok=True)
+
         try:
             experiment_cli(
                 dataset=str(dataset_path),
                 config=str(config_path),
                 output_directory=str(output_dir),
@@ -731,17 +771,20 @@
         output_dir = Path(output_dir)
         exp_dirs = sorted(
             output_dir.glob("experiment_run*"),
             key=lambda p: p.stat().st_mtime,
         )
+
         if not exp_dirs:
             logger.warning(f"No experiment run directories found in {output_dir}")
             return None
+
         progress_file = exp_dirs[-1] / "model" / "training_progress.json"
         if not progress_file.exists():
             logger.warning(f"No training_progress.json found in {progress_file}")
             return None
+
         try:
             with progress_file.open("r", encoding="utf-8") as f:
                 data = json.load(f)
             return {
                 "learning_rate": data.get("learning_rate"),
@@ -773,10 +816,11 @@
             logger.error(f"Error converting Parquet to CSV: {e}")
 
     def generate_plots(self, output_dir: Path) -> None:
         """Generate all registered Ludwig visualizations for the latest experiment run."""
         logger.info("Generating all Ludwig visualizations…")
+
         test_plots = {
             "compare_performance",
             "compare_classifiers_performance_from_prob",
             "compare_classifiers_performance_from_pred",
             "compare_classifiers_performance_changing_k",
@@ -796,19 +840,21 @@
         }
         train_plots = {
             "learning_curves",
             "compare_classifiers_performance_subset",
         }
+
         output_dir = Path(output_dir)
         exp_dirs = sorted(
             output_dir.glob("experiment_run*"),
             key=lambda p: p.stat().st_mtime,
         )
         if not exp_dirs:
             logger.warning(f"No experiment run dirs found in {output_dir}")
             return
         exp_dir = exp_dirs[-1]
+
         viz_dir = exp_dir / "visualizations"
         viz_dir.mkdir(exist_ok=True)
         train_viz = viz_dir / "train"
         test_viz = viz_dir / "test"
         train_viz.mkdir(parents=True, exist_ok=True)
@@ -819,36 +865,40 @@
 
         training_stats = _check(exp_dir / "training_statistics.json")
         test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME)
         probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME)
         gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME)
+
         dataset_path = None
         split_file = None
         desc = exp_dir / DESCRIPTION_FILE_NAME
         if desc.exists():
             with open(desc, "r") as f:
                 cfg = json.load(f)
             dataset_path = _check(Path(cfg.get("dataset", "")))
             split_file = _check(Path(get_split_path(cfg.get("dataset", ""))))
+
         output_feature = ""
         if desc.exists():
             try:
                 output_feature = cfg["config"]["output_features"][0]["name"]
             except Exception:
                 pass
         if not output_feature and test_stats:
             with open(test_stats, "r") as f:
                 stats = json.load(f)
             output_feature = next(iter(stats.keys()), "")
+
         viz_registry = get_visualizations_registry()
         for viz_name, viz_func in viz_registry.items():
             if viz_name in train_plots:
                 viz_dir_plot = train_viz
             elif viz_name in test_plots:
                 viz_dir_plot = test_viz
             else:
                 continue
+
             try:
                 viz_func(
                     training_statistics=[training_stats] if training_stats else [],
                     test_statistics=[test_stats] if test_stats else [],
                     probabilities=[probs_path] if probs_path else [],
@@ -864,10 +914,11 @@
                     file_format="png",
                 )
                 logger.info(f"✔ Generated {viz_name}")
             except Exception as e:
                 logger.warning(f"✘ Skipped {viz_name}: {e}")
+
         logger.info(f"All visualizations written to {viz_dir}")
 
     def generate_html_report(
         self,
         title: str,
@@ -879,22 +930,26 @@
         cwd = Path.cwd()
         report_name = title.lower().replace(" ", "_") + "_report.html"
         report_path = cwd / report_name
         output_dir = Path(output_dir)
         output_type = None
+
         exp_dirs = sorted(
             output_dir.glob("experiment_run*"),
             key=lambda p: p.stat().st_mtime,
         )
         if not exp_dirs:
             raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}")
         exp_dir = exp_dirs[-1]
+
         base_viz_dir = exp_dir / "visualizations"
         train_viz_dir = base_viz_dir / "train"
         test_viz_dir = base_viz_dir / "test"
+
         html = get_html_template()
         html += f"<h1>{title}</h1>"
+
         metrics_html = ""
         train_val_metrics_html = ""
         test_metrics_html = ""
         try:
             train_stats_path = exp_dir / "training_statistics.json"
@@ -916,15 +971,16 @@
             )
         except Exception as e:
             logger.warning(
                 f"Could not load stats for HTML report: {type(e).__name__}: {e}"
             )
+
         config_html = ""
         training_progress = self.get_training_process(output_dir)
         try:
             config_html = format_config_table_html(
-                config, split_info, training_progress
+                config, split_info, training_progress, output_type
             )
         except Exception as e:
             logger.warning(f"Could not load config for HTML report: {e}")
 
         def render_img_section(
@@ -934,11 +990,12 @@
                 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>"
             # collect every PNG
             imgs = list(dir_path.glob("*.png"))
             # --- EXCLUDE Ludwig's base confusion matrix and any top-N confusion_matrix files ---
             imgs = [
-                img for img in imgs
+                img
+                for img in imgs
                 if not (
                     img.name == "confusion_matrix.png"
                     or img.name.startswith("confusion_matrix__label_top")
                     or img.name == "roc_curves.png"
                 )
@@ -970,11 +1027,13 @@
                 ]
                 # filter and order
                 valid_imgs = [img for img in imgs if img.name not in unwanted]
                 img_map = {img.name: img for img in valid_imgs}
                 ordered = [img_map[n] for n in display_order if n in img_map]
-                others = sorted(img for img in valid_imgs if img.name not in display_order)
+                others = sorted(
+                    img for img in valid_imgs if img.name not in display_order
+                )
                 imgs = ordered + others
             else:
                 # regression: just sort whatever's left
                 imgs = sorted(imgs)
             # render each remaining PNG
@@ -1010,11 +1069,13 @@
             if pred_col is None:
                 raise ValueError("No prediction column found in Parquet output")
             df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"})
             # 2) load ground truth for the test split from prepared CSV
             df_all = pd.read_csv(config["label_column_data_path"])
-            df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][LABEL_COLUMN_NAME].reset_index(drop=True)
+            df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][
+                LABEL_COLUMN_NAME
+            ].reset_index(drop=True)
             # 3) concatenate side-by-side
             df_table = pd.concat([df_gt, df_pred], axis=1)
             df_table.columns = [LABEL_COLUMN_NAME, "prediction"]
             # 4) render as HTML
             preds_html = df_table.to_html(index=False, classes="predictions-table")
@@ -1034,11 +1095,13 @@
                 str(training_stats_path),
             )
             for plot in interactive_plots:
                 # 2) inject the static "roc_curves_from_prediction_statistics.png"
                 if plot["title"] == "ROC-AUC":
-                    static_img = test_viz_dir / "roc_curves_from_prediction_statistics.png"
+                    static_img = (
+                        test_viz_dir / "roc_curves_from_prediction_statistics.png"
+                    )
                     if static_img.exists():
                         b64 = encode_image_to_base64(str(static_img))
                         tab3_content += (
                             "<h2 style='text-align: center;'>"
                             "Roc Curves From Prediction Statistics"
@@ -1052,23 +1115,23 @@
                     tab3_content += (
                         f"<h2 style='text-align: center;'>{plot['title']}</h2>"
                         + plot["html"]
                     )
         tab3_content += render_img_section(
-            "Test Visualizations",
-            test_viz_dir,
-            output_type
+            "Test Visualizations", test_viz_dir, output_type
         )
         # assemble the tabs and help modal
         tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)
         modal_html = get_metrics_help_modal()
         html += tabbed_html + modal_html + get_html_closing()
+
         try:
             with open(report_path, "w") as f:
                 f.write(html)
             logger.info(f"HTML report generated at: {report_path}")
         except Exception as e:
             logger.error(f"Failed to write HTML report: {e}")
             raise
+
         return report_path
 
 
@@ -1075,7 +1138,8 @@
 class WorkflowOrchestrator:
     """Manages the image-classification workflow."""
+
     def __init__(self, args: argparse.Namespace, backend: Backend):
         self.args = args
         self.backend = backend
         self.temp_dir: Optional[Path] = None
         self.image_extract_dir: Optional[Path] = None
@@ -1111,20 +1175,23 @@
 
     def _prepare_data(self) -> Tuple[Path, Dict[str, Any], str]:
         """Load CSV, update image paths, handle splits, and write prepared CSV."""
         if not self.temp_dir or not self.image_extract_dir:
             raise RuntimeError("Temp dirs not initialized before data prep.")
+
         try:
             df = pd.read_csv(self.args.csv_file)
             logger.info(f"Loaded CSV: {self.args.csv_file}")
         except Exception:
             logger.error("Error loading CSV file", exc_info=True)
             raise
+
         required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME}
         missing = required - set(df.columns)
         if missing:
             raise ValueError(f"Missing CSV columns: {', '.join(missing)}")
+
         try:
             df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply(
                 lambda p: str((self.image_extract_dir / p).resolve())
             )
         except Exception:
@@ -1148,17 +1215,20 @@
             split_info = (
                 f"No split column in CSV. Created stratified random split: "
                 f"{[int(p * 100) for p in self.args.split_probabilities]}% "
                 f"for train/val/test with balanced label distribution."
             )
+
         final_csv = self.temp_dir / TEMP_CSV_FILENAME
+
         try:
             df.to_csv(final_csv, index=False)
             logger.info(f"Saved prepared data to {final_csv}")
         except Exception:
             logger.error("Error saving prepared CSV", exc_info=True)
             raise
+
         return final_csv, split_config, split_info
 
     def _process_fixed_split(
         self, df: pd.DataFrame
     ) -> Tuple[pd.DataFrame, Dict[str, Any], str]:
@@ -1169,10 +1239,11 @@
             df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype(
                 pd.Int64Dtype()
             )
             if df[SPLIT_COLUMN_NAME].isna().any():
                 logger.warning("Split column contains non-numeric/missing values.")
+
             unique = set(df[SPLIT_COLUMN_NAME].dropna().unique())
             logger.info(f"Unique split values: {unique}")
             if unique == {0, 2}:
                 df = split_data_0_2(
                     df,
@@ -1191,11 +1262,13 @@
             elif unique.issubset({0, 1, 2}):
                 split_info = "Used user-defined split column from CSV."
                 logger.info("Using fixed split as-is.")
             else:
                 raise ValueError(f"Unexpected split values: {unique}")
+
             return df, {"type": "fixed", "column": SPLIT_COLUMN_NAME}, split_info
+
         except Exception:
             logger.error("Error processing fixed split", exc_info=True)
             raise
 
     def _cleanup_temp_dirs(self) -> None:
@@ -1207,15 +1280,18 @@
 
     def run(self) -> None:
         """Execute the full workflow end-to-end."""
        logger.info("Starting workflow...")
         self.args.output_dir.mkdir(parents=True, exist_ok=True)
+
         try:
             self._create_temp_dirs()
             self._extract_images()
             csv_path, split_cfg, split_info = self._prepare_data()
+
             use_pretrained = self.args.use_pretrained or self.args.fine_tune
+
             backend_args = {
                 "model_name": self.args.model_name,
                 "fine_tune": self.args.fine_tune,
                 "use_pretrained": use_pretrained,
                 "epochs": self.args.epochs,
1228 "label_column_data_path": csv_path, 1304 "label_column_data_path": csv_path,
1229 "augmentation": self.args.augmentation, 1305 "augmentation": self.args.augmentation,
1230 "threshold": self.args.threshold, 1306 "threshold": self.args.threshold,
1231 } 1307 }
1232 yaml_str = self.backend.prepare_config(backend_args, split_cfg) 1308 yaml_str = self.backend.prepare_config(backend_args, split_cfg)
1309
1233 config_file = self.temp_dir / TEMP_CONFIG_FILENAME 1310 config_file = self.temp_dir / TEMP_CONFIG_FILENAME
1234 config_file.write_text(yaml_str) 1311 config_file.write_text(yaml_str)
1235 logger.info(f"Wrote backend config: {config_file}") 1312 logger.info(f"Wrote backend config: {config_file}")
1313
1236 self.backend.run_experiment( 1314 self.backend.run_experiment(
1237 csv_path, 1315 csv_path,
1238 config_file, 1316 config_file,
1239 self.args.output_dir, 1317 self.args.output_dir,
1240 self.args.random_seed, 1318 self.args.random_seed,
@@ -1372,12 +1450,11 @@
         nargs=3,
         metavar=("train", "val", "test"),
         action=SplitProbAction,
         default=[0.7, 0.1, 0.2],
         help=(
-            "Random split proportions (e.g., 0.7 0.1 0.2)."
-            "Only used if no split column."
+            "Random split proportions (e.g., 0.7 0.1 0.2).Only used if no split column."
         ),
     )
     parser.add_argument(
         "--random-seed",
         type=int,
@@ -1406,13 +1483,14 @@
         type=float,
         default=None,
         help=(
             "Decision threshold for binary classification (0.0–1.0)."
             "Overrides default 0.5."
-        )
+        ),
     )
     args = parser.parse_args()
+
     if not 0.0 <= args.validation_size <= 1.0:
         parser.error("validation-size must be between 0.0 and 1.0")
     if not args.csv_file.is_file():
         parser.error(f"CSV not found: {args.csv_file}")
     if not args.image_zip.is_file():
@@ -1421,12 +1499,14 @@
     try:
         augmentation_setup = aug_parse(args.augmentation)
         setattr(args, "augmentation", augmentation_setup)
     except ValueError as e:
         parser.error(str(e))
+
     backend_instance = LudwigDirectBackend()
     orchestrator = WorkflowOrchestrator(args, backend_instance)
+
     exit_code = 0
     try:
         orchestrator.run()
         logger.info("Main script finished successfully.")
     except Exception as e:
@@ -1437,13 +1517,15 @@
 
 
 if __name__ == "__main__":
     try:
         import ludwig
+
         logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}")
     except ImportError:
         logger.error(
             "Ludwig library not found. Please ensure Ludwig is installed "
             "('pip install ludwig[image]')"
         )
         sys.exit(1)
+
     main()