goeckslab / image_learner: comparison of image_learner_cli.py @ 9:9e912fce264c (draft default tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit eace0d7c2b2939029c052991d238a54947d2e191
author: goeckslab
date: Wed, 27 Aug 2025 21:02:48 +0000
parents: 85e6f4b2ad18
children: (none)
8:85e6f4b2ad18 | 9:9e912fce264c |
19 METRIC_DISPLAY_NAMES, | 19 METRIC_DISPLAY_NAMES, |
20 MODEL_ENCODER_TEMPLATES, | 20 MODEL_ENCODER_TEMPLATES, |
21 SPLIT_COLUMN_NAME, | 21 SPLIT_COLUMN_NAME, |
22 TEMP_CONFIG_FILENAME, | 22 TEMP_CONFIG_FILENAME, |
23 TEMP_CSV_FILENAME, | 23 TEMP_CSV_FILENAME, |
24 TEMP_DIR_PREFIX | 24 TEMP_DIR_PREFIX, |
25 ) | 25 ) |
26 from ludwig.globals import ( | 26 from ludwig.globals import ( |
27 DESCRIPTION_FILE_NAME, | 27 DESCRIPTION_FILE_NAME, |
28 PREDICTIONS_PARQUET_FILE_NAME, | 28 PREDICTIONS_PARQUET_FILE_NAME, |
29 TEST_STATISTICS_FILE_NAME, | 29 TEST_STATISTICS_FILE_NAME, |
36 from utils import ( | 36 from utils import ( |
37 build_tabbed_html, | 37 build_tabbed_html, |
38 encode_image_to_base64, | 38 encode_image_to_base64, |
39 get_html_closing, | 39 get_html_closing, |
40 get_html_template, | 40 get_html_template, |
41 get_metrics_help_modal | 41 get_metrics_help_modal, |
42 ) | 42 ) |
43 | 43 |
44 # --- Logging Setup --- | 44 # --- Logging Setup --- |
45 logging.basicConfig( | 45 logging.basicConfig( |
46 level=logging.INFO, | 46 level=logging.INFO, |
47 format='%(asctime)s %(levelname)s %(name)s: %(message)s', | 47 format="%(asctime)s %(levelname)s %(name)s: %(message)s", |
48 ) | 48 ) |
49 logger = logging.getLogger("ImageLearner") | 49 logger = logging.getLogger("ImageLearner") |
50 | 50 |
51 | 51 |
52 def format_config_table_html( | 52 def format_config_table_html( |
65 "learning_rate", | 65 "learning_rate", |
66 "random_seed", | 66 "random_seed", |
67 "early_stop", | 67 "early_stop", |
68 "threshold", | 68 "threshold", |
69 ] | 69 ] |
70 | |
70 rows = [] | 71 rows = [] |
72 | |
71 for key in display_keys: | 73 for key in display_keys: |
72 val = config.get(key, None) | 74 val = config.get(key, None) |
73 if key == "threshold": | 75 if key == "threshold": |
74 if output_type != "binary": | 76 if output_type != "binary": |
75 continue | 77 continue |
132 ) | 134 ) |
133 else: | 135 else: |
134 val_str = val | 136 val_str = val |
135 else: | 137 else: |
136 val_str = val if val is not None else "N/A" | 138 val_str = val if val is not None else "N/A" |
137 if val_str == "N/A" and key not in ["task_type"]: # Skip if N/A for non-essential | 139 if val_str == "N/A" and key not in [ |
140 "task_type" | |
141 ]: # Skip if N/A for non-essential | |
138 continue | 142 continue |
139 rows.append( | 143 rows.append( |
140 f"<tr>" | 144 f"<tr>" |
141 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" | 145 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" |
142 f"{key.replace('_', ' ').title()}</td>" | 146 f"{key.replace('_', ' ').title()}</td>" |
164 <thead><tr> | 168 <thead><tr> |
165 <th style="padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Parameter</th> | 169 <th style="padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Parameter</th> |
166 <th style="padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Value</th> | 170 <th style="padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Value</th> |
167 </tr></thead> | 171 </tr></thead> |
168 <tbody> | 172 <tbody> |
169 {''.join(rows)} | 173 {"".join(rows)} |
170 </tbody> | 174 </tbody> |
171 </table> | 175 </table> |
172 </div><br> | 176 </div><br> |
173 <p style="text-align: center; font-size: 0.9em;"> | 177 <p style="text-align: center; font-size: 0.9em;"> |
174 Model trained using <a href="https://ludwig.ai/" target="_blank" rel="noopener noreferrer">Ludwig</a>. | 178 Model trained using <a href="https://ludwig.ai/" target="_blank" rel="noopener noreferrer">Ludwig</a>. |
249 "accuracy_micro": get_last_value(label_stats, "accuracy_micro"), | 253 "accuracy_micro": get_last_value(label_stats, "accuracy_micro"), |
250 "loss": get_last_value(label_stats, "loss"), | 254 "loss": get_last_value(label_stats, "loss"), |
251 "roc_auc": get_last_value(label_stats, "roc_auc"), | 255 "roc_auc": get_last_value(label_stats, "roc_auc"), |
252 "hits_at_k": get_last_value(label_stats, "hits_at_k"), | 256 "hits_at_k": get_last_value(label_stats, "hits_at_k"), |
253 } | 257 } |
258 | |
254 # Test metrics: dynamic extraction according to exclusions | 259 # Test metrics: dynamic extraction according to exclusions |
255 test_label_stats = test_stats.get("label", {}) | 260 test_label_stats = test_stats.get("label", {}) |
256 if not test_label_stats: | 261 if not test_label_stats: |
257 logging.warning("No label statistics found for test split") | 262 logging.warning("No label statistics found for test split") |
258 else: | 263 else: |
259 combined_stats = test_stats.get("combined", {}) | 264 combined_stats = test_stats.get("combined", {}) |
260 overall_stats = test_label_stats.get("overall_stats", {}) | 265 overall_stats = test_label_stats.get("overall_stats", {}) |
266 | |
261 # Define exclusions | 267 # Define exclusions |
262 if output_type == "binary": | 268 if output_type == "binary": |
263 exclude = {"per_class_stats", "precision_recall_curve", "roc_curve"} | 269 exclude = {"per_class_stats", "precision_recall_curve", "roc_curve"} |
264 else: | 270 else: |
265 exclude = {"per_class_stats", "confusion_matrix"} | 271 exclude = {"per_class_stats", "confusion_matrix"} |
272 | |
266 # 1. Get all scalar test_label_stats not excluded | 273 # 1. Get all scalar test_label_stats not excluded |
267 test_metrics = {} | 274 test_metrics = {} |
268 for k, v in test_label_stats.items(): | 275 for k, v in test_label_stats.items(): |
269 if k in exclude: | 276 if k in exclude: |
270 continue | 277 continue |
271 if k == "overall_stats": | 278 if k == "overall_stats": |
272 continue | 279 continue |
273 if isinstance(v, (int, float, str, bool)): | 280 if isinstance(v, (int, float, str, bool)): |
274 test_metrics[k] = v | 281 test_metrics[k] = v |
282 | |
275 # 2. Add overall_stats (flattened) | 283 # 2. Add overall_stats (flattened) |
276 for k, v in overall_stats.items(): | 284 for k, v in overall_stats.items(): |
277 test_metrics[k] = v | 285 test_metrics[k] = v |
286 | |
278 # 3. Optionally include combined/loss if present and not already | 287 # 3. Optionally include combined/loss if present and not already |
279 if "loss" in combined_stats and "loss" not in test_metrics: | 288 if "loss" in combined_stats and "loss" not in test_metrics: |
280 test_metrics["loss"] = combined_stats["loss"] | 289 test_metrics["loss"] = combined_stats["loss"] |
281 metrics["test"] = test_metrics | 290 metrics["test"] = test_metrics |
282 return metrics | 291 return metrics |
313 t = all_metrics["training"].get(metric_key) | 322 t = all_metrics["training"].get(metric_key) |
314 v = all_metrics["validation"].get(metric_key) | 323 v = all_metrics["validation"].get(metric_key) |
315 te = all_metrics["test"].get(metric_key) | 324 te = all_metrics["test"].get(metric_key) |
316 if all(x is not None for x in [t, v, te]): | 325 if all(x is not None for x in [t, v, te]): |
317 rows.append([display_name, f"{t:.4f}", f"{v:.4f}", f"{te:.4f}"]) | 326 rows.append([display_name, f"{t:.4f}", f"{v:.4f}", f"{te:.4f}"]) |
327 | |
318 if not rows: | 328 if not rows: |
319 return "<table><tr><td>No metric values found.</td></tr></table>" | 329 return "<table><tr><td>No metric values found.</td></tr></table>" |
330 | |
320 html = ( | 331 html = ( |
321 "<h2 style='text-align: center;'>Model Performance Summary</h2>" | 332 "<h2 style='text-align: center;'>Model Performance Summary</h2>" |
322 "<div style='display: flex; justify-content: center;'>" | 333 "<div style='display: flex; justify-content: center;'>" |
323 "<table class='performance-summary' style='border-collapse: collapse;'>" | 334 "<table class='performance-summary' style='border-collapse: collapse;'>" |
324 "<thead><tr>" | 335 "<thead><tr>" |
329 "</tr></thead><tbody>" | 340 "</tr></thead><tbody>" |
330 ) | 341 ) |
331 for row in rows: | 342 for row in rows: |
332 html += generate_table_row( | 343 html += generate_table_row( |
333 row, | 344 row, |
334 "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;" | 345 "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;", |
335 ) | 346 ) |
336 html += "</tbody></table></div><br>" | 347 html += "</tbody></table></div><br>" |
337 return html | 348 return html |
338 | 349 |
339 | 350 |
355 ) | 366 ) |
356 t = all_metrics["training"].get(metric_key) | 367 t = all_metrics["training"].get(metric_key) |
357 v = all_metrics["validation"].get(metric_key) | 368 v = all_metrics["validation"].get(metric_key) |
358 if t is not None and v is not None: | 369 if t is not None and v is not None: |
359 rows.append([display_name, f"{t:.4f}", f"{v:.4f}"]) | 370 rows.append([display_name, f"{t:.4f}", f"{v:.4f}"]) |
371 | |
360 if not rows: | 372 if not rows: |
361 return "<table><tr><td>No metric values found for Train/Validation.</td></tr></table>" | 373 return "<table><tr><td>No metric values found for Train/Validation.</td></tr></table>" |
374 | |
362 html = ( | 375 html = ( |
363 "<h2 style='text-align: center;'>Train/Validation Performance Summary</h2>" | 376 "<h2 style='text-align: center;'>Train/Validation Performance Summary</h2>" |
364 "<div style='display: flex; justify-content: center;'>" | 377 "<div style='display: flex; justify-content: center;'>" |
365 "<table class='performance-summary' style='border-collapse: collapse;'>" | 378 "<table class='performance-summary' style='border-collapse: collapse;'>" |
366 "<thead><tr>" | 379 "<thead><tr>" |
370 "</tr></thead><tbody>" | 383 "</tr></thead><tbody>" |
371 ) | 384 ) |
372 for row in rows: | 385 for row in rows: |
373 html += generate_table_row( | 386 html += generate_table_row( |
374 row, | 387 row, |
375 "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;" | 388 "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;", |
376 ) | 389 ) |
377 html += "</tbody></table></div><br>" | 390 html += "</tbody></table></div><br>" |
378 return html | 391 return html |
379 | 392 |
380 | 393 |
391 for key in sorted(test_metrics.keys()): | 404 for key in sorted(test_metrics.keys()): |
392 display_name = METRIC_DISPLAY_NAMES.get(key, key.replace("_", " ").title()) | 405 display_name = METRIC_DISPLAY_NAMES.get(key, key.replace("_", " ").title()) |
393 value = test_metrics[key] | 406 value = test_metrics[key] |
394 if value is not None: | 407 if value is not None: |
395 rows.append([display_name, f"{value:.4f}"]) | 408 rows.append([display_name, f"{value:.4f}"]) |
409 | |
396 if not rows: | 410 if not rows: |
397 return "<table><tr><td>No test metric values found.</td></tr></table>" | 411 return "<table><tr><td>No test metric values found.</td></tr></table>" |
412 | |
398 html = ( | 413 html = ( |
399 "<h2 style='text-align: center;'>Test Performance Summary</h2>" | 414 "<h2 style='text-align: center;'>Test Performance Summary</h2>" |
400 "<div style='display: flex; justify-content: center;'>" | 415 "<div style='display: flex; justify-content: center;'>" |
401 "<table class='performance-summary' style='border-collapse: collapse;'>" | 416 "<table class='performance-summary' style='border-collapse: collapse;'>" |
402 "<thead><tr>" | 417 "<thead><tr>" |
405 "</tr></thead><tbody>" | 420 "</tr></thead><tbody>" |
406 ) | 421 ) |
407 for row in rows: | 422 for row in rows: |
408 html += generate_table_row( | 423 html += generate_table_row( |
409 row, | 424 row, |
410 "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;" | 425 "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;", |
411 ) | 426 ) |
412 html += "</tbody></table></div><br>" | 427 html += "</tbody></table></div><br>" |
413 return html | 428 return html |
414 | 429 |
415 | 430 |
434 if label_counts.size > 1: | 449 if label_counts.size > 1: |
435 # Force stratify even with fewer samples - adjust validation_size if needed | 450 # Force stratify even with fewer samples - adjust validation_size if needed |
436 min_samples_per_class = label_counts.min() | 451 min_samples_per_class = label_counts.min() |
437 if min_samples_per_class * validation_size < 1: | 452 if min_samples_per_class * validation_size < 1: |
438 # Adjust validation_size to ensure at least 1 sample per class, but do not exceed original validation_size | 453 # Adjust validation_size to ensure at least 1 sample per class, but do not exceed original validation_size |
439 adjusted_validation_size = min(validation_size, 1.0 / min_samples_per_class) | 454 adjusted_validation_size = min( |
455 validation_size, 1.0 / min_samples_per_class | |
456 ) | |
440 if adjusted_validation_size != validation_size: | 457 if adjusted_validation_size != validation_size: |
441 validation_size = adjusted_validation_size | 458 validation_size = adjusted_validation_size |
442 logger.info(f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation") | 459 logger.info( |
460 f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation" | |
461 ) | |
443 stratify_arr = out.loc[idx_train, label_column] | 462 stratify_arr = out.loc[idx_train, label_column] |
444 logger.info("Using stratified split for validation set") | 463 logger.info("Using stratified split for validation set") |
445 else: | 464 else: |
446 logger.warning("Only one label class found; cannot stratify") | 465 logger.warning("Only one label class found; cannot stratify") |
447 if validation_size <= 0: | 466 if validation_size <= 0: |
484 """Create a stratified random split when no split column exists.""" | 503 """Create a stratified random split when no split column exists.""" |
485 out = df.copy() | 504 out = df.copy() |
486 # initialize split column | 505 # initialize split column |
487 out[split_column] = 0 | 506 out[split_column] = 0 |
488 if not label_column or label_column not in out.columns: | 507 if not label_column or label_column not in out.columns: |
489 logger.warning("No label column found; using random split without stratification") | 508 logger.warning( |
509 "No label column found; using random split without stratification" | |
510 ) | |
490 # fall back to simple random assignment | 511 # fall back to simple random assignment |
491 indices = out.index.tolist() | 512 indices = out.index.tolist() |
492 np.random.seed(random_state) | 513 np.random.seed(random_state) |
493 np.random.shuffle(indices) | 514 np.random.shuffle(indices) |
494 n_total = len(indices) | 515 n_total = len(indices) |
527 test_size=split_probabilities[2], | 548 test_size=split_probabilities[2], |
528 random_state=random_state, | 549 random_state=random_state, |
529 stratify=out[label_column], | 550 stratify=out[label_column], |
530 ) | 551 ) |
531 # second split: separate training and validation from remaining data | 552 # second split: separate training and validation from remaining data |
532 val_size_adjusted = split_probabilities[1] / (split_probabilities[0] + split_probabilities[1]) | 553 val_size_adjusted = split_probabilities[1] / ( |
554 split_probabilities[0] + split_probabilities[1] | |
555 ) | |
533 train_idx, val_idx = train_test_split( | 556 train_idx, val_idx = train_test_split( |
534 train_val_idx, | 557 train_val_idx, |
535 test_size=val_size_adjusted, | 558 test_size=val_size_adjusted, |
536 random_state=random_state, | 559 random_state=random_state, |
537 stratify=out.loc[train_val_idx, label_column], | 560 stratify=out.loc[train_val_idx, label_column], |
539 # assign split values | 562 # assign split values |
540 out.loc[train_idx, split_column] = 0 | 563 out.loc[train_idx, split_column] = 0 |
541 out.loc[val_idx, split_column] = 1 | 564 out.loc[val_idx, split_column] = 1 |
542 out.loc[test_idx, split_column] = 2 | 565 out.loc[test_idx, split_column] = 2 |
543 logger.info("Successfully applied stratified random split") | 566 logger.info("Successfully applied stratified random split") |
544 logger.info(f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}") | 567 logger.info( |
568 f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}" | |
569 ) | |
545 return out.astype({split_column: int}) | 570 return out.astype({split_column: int}) |
546 | 571 |
547 | 572 |
548 class Backend(Protocol): | 573 class Backend(Protocol): |
549 """Interface for a machine learning backend.""" | 574 """Interface for a machine learning backend.""" |
575 | |
550 def prepare_config( | 576 def prepare_config( |
551 self, | 577 self, |
552 config_params: Dict[str, Any], | 578 config_params: Dict[str, Any], |
553 split_config: Dict[str, Any], | 579 split_config: Dict[str, Any], |
554 ) -> str: | 580 ) -> str: |
576 ... | 602 ... |
577 | 603 |
578 | 604 |
579 class LudwigDirectBackend: | 605 class LudwigDirectBackend: |
580 """Backend for running Ludwig experiments directly via the internal experiment_cli function.""" | 606 """Backend for running Ludwig experiments directly via the internal experiment_cli function.""" |
607 | |
581 def prepare_config( | 608 def prepare_config( |
582 self, | 609 self, |
583 config_params: Dict[str, Any], | 610 config_params: Dict[str, Any], |
584 split_config: Dict[str, Any], | 611 split_config: Dict[str, Any], |
585 ) -> str: | 612 ) -> str: |
586 logger.info("LudwigDirectBackend: Preparing YAML configuration.") | 613 logger.info("LudwigDirectBackend: Preparing YAML configuration.") |
614 | |
587 model_name = config_params.get("model_name", "resnet18") | 615 model_name = config_params.get("model_name", "resnet18") |
588 use_pretrained = config_params.get("use_pretrained", False) | 616 use_pretrained = config_params.get("use_pretrained", False) |
589 fine_tune = config_params.get("fine_tune", False) | 617 fine_tune = config_params.get("fine_tune", False) |
590 if use_pretrained: | 618 if use_pretrained: |
591 trainable = bool(fine_tune) | 619 trainable = bool(fine_tune) |
604 "use_pretrained": use_pretrained, | 632 "use_pretrained": use_pretrained, |
605 "trainable": trainable, | 633 "trainable": trainable, |
606 } | 634 } |
607 else: | 635 else: |
608 encoder_config = {"type": raw_encoder} | 636 encoder_config = {"type": raw_encoder} |
637 | |
609 batch_size_cfg = batch_size or "auto" | 638 batch_size_cfg = batch_size or "auto" |
639 | |
610 label_column_path = config_params.get("label_column_data_path") | 640 label_column_path = config_params.get("label_column_data_path") |
611 label_series = None | 641 label_series = None |
612 if label_column_path is not None and Path(label_column_path).exists(): | 642 if label_column_path is not None and Path(label_column_path).exists(): |
613 try: | 643 try: |
614 label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME] | 644 label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME] |
615 except Exception as e: | 645 except Exception as e: |
616 logger.warning(f"Could not read label column for task detection: {e}") | 646 logger.warning(f"Could not read label column for task detection: {e}") |
647 | |
617 if ( | 648 if ( |
618 label_series is not None | 649 label_series is not None |
619 and ptypes.is_numeric_dtype(label_series.dtype) | 650 and ptypes.is_numeric_dtype(label_series.dtype) |
620 and label_series.nunique() > 10 | 651 and label_series.nunique() > 10 |
621 ): | 652 ): |
622 task_type = "regression" | 653 task_type = "regression" |
623 else: | 654 else: |
624 task_type = "classification" | 655 task_type = "classification" |
656 | |
625 config_params["task_type"] = task_type | 657 config_params["task_type"] = task_type |
658 | |
626 image_feat: Dict[str, Any] = { | 659 image_feat: Dict[str, Any] = { |
627 "name": IMAGE_PATH_COLUMN_NAME, | 660 "name": IMAGE_PATH_COLUMN_NAME, |
628 "type": "image", | 661 "type": "image", |
629 "encoder": encoder_config, | 662 "encoder": encoder_config, |
630 } | 663 } |
631 if config_params.get("augmentation") is not None: | 664 if config_params.get("augmentation") is not None: |
632 image_feat["augmentation"] = config_params["augmentation"] | 665 image_feat["augmentation"] = config_params["augmentation"] |
666 | |
633 if task_type == "regression": | 667 if task_type == "regression": |
634 output_feat = { | 668 output_feat = { |
635 "name": LABEL_COLUMN_NAME, | 669 "name": LABEL_COLUMN_NAME, |
636 "type": "number", | 670 "type": "number", |
637 "decoder": {"type": "regressor"}, | 671 "decoder": {"type": "regressor"}, |
643 "r2", | 677 "r2", |
644 ] | 678 ] |
645 }, | 679 }, |
646 } | 680 } |
647 val_metric = config_params.get("validation_metric", "mean_squared_error") | 681 val_metric = config_params.get("validation_metric", "mean_squared_error") |
682 | |
648 else: | 683 else: |
649 num_unique_labels = ( | 684 num_unique_labels = ( |
650 label_series.nunique() if label_series is not None else 2 | 685 label_series.nunique() if label_series is not None else 2 |
651 ) | 686 ) |
652 output_type = "binary" if num_unique_labels == 2 else "category" | 687 output_type = "binary" if num_unique_labels == 2 else "category" |
653 output_feat = {"name": LABEL_COLUMN_NAME, "type": output_type} | 688 output_feat = {"name": LABEL_COLUMN_NAME, "type": output_type} |
654 if output_type == "binary" and config_params.get("threshold") is not None: | 689 if output_type == "binary" and config_params.get("threshold") is not None: |
655 output_feat["threshold"] = float(config_params["threshold"]) | 690 output_feat["threshold"] = float(config_params["threshold"]) |
656 val_metric = None | 691 val_metric = None |
692 | |
657 conf: Dict[str, Any] = { | 693 conf: Dict[str, Any] = { |
658 "model_type": "ecd", | 694 "model_type": "ecd", |
659 "input_features": [image_feat], | 695 "input_features": [image_feat], |
660 "output_features": [output_feat], | 696 "output_features": [output_feat], |
661 "combiner": {"type": "concat"}, | 697 "combiner": {"type": "concat"}, |
671 "split": split_config, | 707 "split": split_config, |
672 "num_processes": num_processes, | 708 "num_processes": num_processes, |
673 "in_memory": False, | 709 "in_memory": False, |
674 }, | 710 }, |
675 } | 711 } |
712 | |
676 logger.debug("LudwigDirectBackend: Config dict built.") | 713 logger.debug("LudwigDirectBackend: Config dict built.") |
677 try: | 714 try: |
678 yaml_str = yaml.dump(conf, sort_keys=False, indent=2) | 715 yaml_str = yaml.dump(conf, sort_keys=False, indent=2) |
679 logger.info("LudwigDirectBackend: YAML config generated.") | 716 logger.info("LudwigDirectBackend: YAML config generated.") |
680 return yaml_str | 717 return yaml_str |
692 output_dir: Path, | 729 output_dir: Path, |
693 random_seed: int = 42, | 730 random_seed: int = 42, |
694 ) -> None: | 731 ) -> None: |
695 """Invoke Ludwig's internal experiment_cli function to run the experiment.""" | 732 """Invoke Ludwig's internal experiment_cli function to run the experiment.""" |
696 logger.info("LudwigDirectBackend: Starting experiment execution.") | 733 logger.info("LudwigDirectBackend: Starting experiment execution.") |
734 | |
697 try: | 735 try: |
698 from ludwig.experiment import experiment_cli | 736 from ludwig.experiment import experiment_cli |
699 except ImportError as e: | 737 except ImportError as e: |
700 logger.error( | 738 logger.error( |
701 "LudwigDirectBackend: Could not import experiment_cli.", | 739 "LudwigDirectBackend: Could not import experiment_cli.", |
702 exc_info=True, | 740 exc_info=True, |
703 ) | 741 ) |
704 raise RuntimeError("Ludwig import failed.") from e | 742 raise RuntimeError("Ludwig import failed.") from e |
743 | |
705 output_dir.mkdir(parents=True, exist_ok=True) | 744 output_dir.mkdir(parents=True, exist_ok=True) |
745 | |
706 try: | 746 try: |
707 experiment_cli( | 747 experiment_cli( |
708 dataset=str(dataset_path), | 748 dataset=str(dataset_path), |
709 config=str(config_path), | 749 config=str(config_path), |
710 output_directory=str(output_dir), | 750 output_directory=str(output_dir), |
731 output_dir = Path(output_dir) | 771 output_dir = Path(output_dir) |
732 exp_dirs = sorted( | 772 exp_dirs = sorted( |
733 output_dir.glob("experiment_run*"), | 773 output_dir.glob("experiment_run*"), |
734 key=lambda p: p.stat().st_mtime, | 774 key=lambda p: p.stat().st_mtime, |
735 ) | 775 ) |
776 | |
736 if not exp_dirs: | 777 if not exp_dirs: |
737 logger.warning(f"No experiment run directories found in {output_dir}") | 778 logger.warning(f"No experiment run directories found in {output_dir}") |
738 return None | 779 return None |
780 | |
739 progress_file = exp_dirs[-1] / "model" / "training_progress.json" | 781 progress_file = exp_dirs[-1] / "model" / "training_progress.json" |
740 if not progress_file.exists(): | 782 if not progress_file.exists(): |
741 logger.warning(f"No training_progress.json found in {progress_file}") | 783 logger.warning(f"No training_progress.json found in {progress_file}") |
742 return None | 784 return None |
785 | |
743 try: | 786 try: |
744 with progress_file.open("r", encoding="utf-8") as f: | 787 with progress_file.open("r", encoding="utf-8") as f: |
745 data = json.load(f) | 788 data = json.load(f) |
746 return { | 789 return { |
747 "learning_rate": data.get("learning_rate"), | 790 "learning_rate": data.get("learning_rate"), |
773 logger.error(f"Error converting Parquet to CSV: {e}") | 816 logger.error(f"Error converting Parquet to CSV: {e}") |
774 | 817 |
775 def generate_plots(self, output_dir: Path) -> None: | 818 def generate_plots(self, output_dir: Path) -> None: |
776 """Generate all registered Ludwig visualizations for the latest experiment run.""" | 819 """Generate all registered Ludwig visualizations for the latest experiment run.""" |
777 logger.info("Generating all Ludwig visualizations…") | 820 logger.info("Generating all Ludwig visualizations…") |
821 | |
778 test_plots = { | 822 test_plots = { |
779 "compare_performance", | 823 "compare_performance", |
780 "compare_classifiers_performance_from_prob", | 824 "compare_classifiers_performance_from_prob", |
781 "compare_classifiers_performance_from_pred", | 825 "compare_classifiers_performance_from_pred", |
782 "compare_classifiers_performance_changing_k", | 826 "compare_classifiers_performance_changing_k", |
796 } | 840 } |
797 train_plots = { | 841 train_plots = { |
798 "learning_curves", | 842 "learning_curves", |
799 "compare_classifiers_performance_subset", | 843 "compare_classifiers_performance_subset", |
800 } | 844 } |
845 | |
801 output_dir = Path(output_dir) | 846 output_dir = Path(output_dir) |
802 exp_dirs = sorted( | 847 exp_dirs = sorted( |
803 output_dir.glob("experiment_run*"), | 848 output_dir.glob("experiment_run*"), |
804 key=lambda p: p.stat().st_mtime, | 849 key=lambda p: p.stat().st_mtime, |
805 ) | 850 ) |
806 if not exp_dirs: | 851 if not exp_dirs: |
807 logger.warning(f"No experiment run dirs found in {output_dir}") | 852 logger.warning(f"No experiment run dirs found in {output_dir}") |
808 return | 853 return |
809 exp_dir = exp_dirs[-1] | 854 exp_dir = exp_dirs[-1] |
855 | |
810 viz_dir = exp_dir / "visualizations" | 856 viz_dir = exp_dir / "visualizations" |
811 viz_dir.mkdir(exist_ok=True) | 857 viz_dir.mkdir(exist_ok=True) |
812 train_viz = viz_dir / "train" | 858 train_viz = viz_dir / "train" |
813 test_viz = viz_dir / "test" | 859 test_viz = viz_dir / "test" |
814 train_viz.mkdir(parents=True, exist_ok=True) | 860 train_viz.mkdir(parents=True, exist_ok=True) |
819 | 865 |
820 training_stats = _check(exp_dir / "training_statistics.json") | 866 training_stats = _check(exp_dir / "training_statistics.json") |
821 test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME) | 867 test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME) |
822 probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME) | 868 probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME) |
823 gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME) | 869 gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME) |
870 | |
824 dataset_path = None | 871 dataset_path = None |
825 split_file = None | 872 split_file = None |
826 desc = exp_dir / DESCRIPTION_FILE_NAME | 873 desc = exp_dir / DESCRIPTION_FILE_NAME |
827 if desc.exists(): | 874 if desc.exists(): |
828 with open(desc, "r") as f: | 875 with open(desc, "r") as f: |
829 cfg = json.load(f) | 876 cfg = json.load(f) |
830 dataset_path = _check(Path(cfg.get("dataset", ""))) | 877 dataset_path = _check(Path(cfg.get("dataset", ""))) |
831 split_file = _check(Path(get_split_path(cfg.get("dataset", "")))) | 878 split_file = _check(Path(get_split_path(cfg.get("dataset", "")))) |
879 | |
832 output_feature = "" | 880 output_feature = "" |
833 if desc.exists(): | 881 if desc.exists(): |
834 try: | 882 try: |
835 output_feature = cfg["config"]["output_features"][0]["name"] | 883 output_feature = cfg["config"]["output_features"][0]["name"] |
836 except Exception: | 884 except Exception: |
837 pass | 885 pass |
838 if not output_feature and test_stats: | 886 if not output_feature and test_stats: |
839 with open(test_stats, "r") as f: | 887 with open(test_stats, "r") as f: |
840 stats = json.load(f) | 888 stats = json.load(f) |
841 output_feature = next(iter(stats.keys()), "") | 889 output_feature = next(iter(stats.keys()), "") |
890 | |
842 viz_registry = get_visualizations_registry() | 891 viz_registry = get_visualizations_registry() |
843 for viz_name, viz_func in viz_registry.items(): | 892 for viz_name, viz_func in viz_registry.items(): |
844 if viz_name in train_plots: | 893 if viz_name in train_plots: |
845 viz_dir_plot = train_viz | 894 viz_dir_plot = train_viz |
846 elif viz_name in test_plots: | 895 elif viz_name in test_plots: |
847 viz_dir_plot = test_viz | 896 viz_dir_plot = test_viz |
848 else: | 897 else: |
849 continue | 898 continue |
899 | |
850 try: | 900 try: |
851 viz_func( | 901 viz_func( |
852 training_statistics=[training_stats] if training_stats else [], | 902 training_statistics=[training_stats] if training_stats else [], |
853 test_statistics=[test_stats] if test_stats else [], | 903 test_statistics=[test_stats] if test_stats else [], |
854 probabilities=[probs_path] if probs_path else [], | 904 probabilities=[probs_path] if probs_path else [], |
864 file_format="png", | 914 file_format="png", |
865 ) | 915 ) |
866 logger.info(f"✔ Generated {viz_name}") | 916 logger.info(f"✔ Generated {viz_name}") |
867 except Exception as e: | 917 except Exception as e: |
868 logger.warning(f"✘ Skipped {viz_name}: {e}") | 918 logger.warning(f"✘ Skipped {viz_name}: {e}") |
919 | |
869 logger.info(f"All visualizations written to {viz_dir}") | 920 logger.info(f"All visualizations written to {viz_dir}") |
870 | 921 |
871 def generate_html_report( | 922 def generate_html_report( |
872 self, | 923 self, |
873 title: str, | 924 title: str, |
879 cwd = Path.cwd() | 930 cwd = Path.cwd() |
880 report_name = title.lower().replace(" ", "_") + "_report.html" | 931 report_name = title.lower().replace(" ", "_") + "_report.html" |
881 report_path = cwd / report_name | 932 report_path = cwd / report_name |
882 output_dir = Path(output_dir) | 933 output_dir = Path(output_dir) |
883 output_type = None | 934 output_type = None |
935 | |
884 exp_dirs = sorted( | 936 exp_dirs = sorted( |
885 output_dir.glob("experiment_run*"), | 937 output_dir.glob("experiment_run*"), |
886 key=lambda p: p.stat().st_mtime, | 938 key=lambda p: p.stat().st_mtime, |
887 ) | 939 ) |
888 if not exp_dirs: | 940 if not exp_dirs: |
889 raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}") | 941 raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}") |
890 exp_dir = exp_dirs[-1] | 942 exp_dir = exp_dirs[-1] |
943 | |
891 base_viz_dir = exp_dir / "visualizations" | 944 base_viz_dir = exp_dir / "visualizations" |
892 train_viz_dir = base_viz_dir / "train" | 945 train_viz_dir = base_viz_dir / "train" |
893 test_viz_dir = base_viz_dir / "test" | 946 test_viz_dir = base_viz_dir / "test" |
947 | |
894 html = get_html_template() | 948 html = get_html_template() |
895 html += f"<h1>{title}</h1>" | 949 html += f"<h1>{title}</h1>" |
950 | |
896 metrics_html = "" | 951 metrics_html = "" |
897 train_val_metrics_html = "" | 952 train_val_metrics_html = "" |
898 test_metrics_html = "" | 953 test_metrics_html = "" |
899 try: | 954 try: |
900 train_stats_path = exp_dir / "training_statistics.json" | 955 train_stats_path = exp_dir / "training_statistics.json" |
916 ) | 971 ) |
917 except Exception as e: | 972 except Exception as e: |
918 logger.warning( | 973 logger.warning( |
919 f"Could not load stats for HTML report: {type(e).__name__}: {e}" | 974 f"Could not load stats for HTML report: {type(e).__name__}: {e}" |
920 ) | 975 ) |
976 | |
921 config_html = "" | 977 config_html = "" |
922 training_progress = self.get_training_process(output_dir) | 978 training_progress = self.get_training_process(output_dir) |
923 try: | 979 try: |
924 config_html = format_config_table_html( | 980 config_html = format_config_table_html( |
925 config, split_info, training_progress | 981 config, split_info, training_progress, output_type |
926 ) | 982 ) |
927 except Exception as e: | 983 except Exception as e: |
928 logger.warning(f"Could not load config for HTML report: {e}") | 984 logger.warning(f"Could not load config for HTML report: {e}") |
929 | 985 |
930 def render_img_section( | 986 def render_img_section( |
934 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>" | 990 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>" |
935 # collect every PNG | 991 # collect every PNG |
936 imgs = list(dir_path.glob("*.png")) | 992 imgs = list(dir_path.glob("*.png")) |
937 # --- EXCLUDE Ludwig's base confusion matrix and any top-N confusion_matrix files --- | 993 # --- EXCLUDE Ludwig's base confusion matrix and any top-N confusion_matrix files --- |
938 imgs = [ | 994 imgs = [ |
939 img for img in imgs | 995 img |
996 for img in imgs | |
940 if not ( | 997 if not ( |
941 img.name == "confusion_matrix.png" | 998 img.name == "confusion_matrix.png" |
942 or img.name.startswith("confusion_matrix__label_top") | 999 or img.name.startswith("confusion_matrix__label_top") |
943 or img.name == "roc_curves.png" | 1000 or img.name == "roc_curves.png" |
944 ) | 1001 ) |
970 ] | 1027 ] |
971 # filter and order | 1028 # filter and order |
972 valid_imgs = [img for img in imgs if img.name not in unwanted] | 1029 valid_imgs = [img for img in imgs if img.name not in unwanted] |
973 img_map = {img.name: img for img in valid_imgs} | 1030 img_map = {img.name: img for img in valid_imgs} |
974 ordered = [img_map[n] for n in display_order if n in img_map] | 1031 ordered = [img_map[n] for n in display_order if n in img_map] |
975 others = sorted(img for img in valid_imgs if img.name not in display_order) | 1032 others = sorted( |
1033 img for img in valid_imgs if img.name not in display_order | |
1034 ) | |
976 imgs = ordered + others | 1035 imgs = ordered + others |
977 else: | 1036 else: |
978 # regression: just sort whatever's left | 1037 # regression: just sort whatever's left |
979 imgs = sorted(imgs) | 1038 imgs = sorted(imgs) |
980 # render each remaining PNG | 1039 # render each remaining PNG |
1010 if pred_col is None: | 1069 if pred_col is None: |
1011 raise ValueError("No prediction column found in Parquet output") | 1070 raise ValueError("No prediction column found in Parquet output") |
1012 df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"}) | 1071 df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"}) |
1013 # 2) load ground truth for the test split from prepared CSV | 1072 # 2) load ground truth for the test split from prepared CSV |
1014 df_all = pd.read_csv(config["label_column_data_path"]) | 1073 df_all = pd.read_csv(config["label_column_data_path"]) |
1015 df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][LABEL_COLUMN_NAME].reset_index(drop=True) | 1074 df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][ |
1075 LABEL_COLUMN_NAME | |
1076 ].reset_index(drop=True) | |
1016 # 3) concatenate side-by-side | 1077 # 3) concatenate side-by-side |
1017 df_table = pd.concat([df_gt, df_pred], axis=1) | 1078 df_table = pd.concat([df_gt, df_pred], axis=1) |
1018 df_table.columns = [LABEL_COLUMN_NAME, "prediction"] | 1079 df_table.columns = [LABEL_COLUMN_NAME, "prediction"] |
1019 # 4) render as HTML | 1080 # 4) render as HTML |
1020 preds_html = df_table.to_html(index=False, classes="predictions-table") | 1081 preds_html = df_table.to_html(index=False, classes="predictions-table") |
1034 str(training_stats_path), | 1095 str(training_stats_path), |
1035 ) | 1096 ) |
1036 for plot in interactive_plots: | 1097 for plot in interactive_plots: |
1037 # 2) inject the static "roc_curves_from_prediction_statistics.png" | 1098 # 2) inject the static "roc_curves_from_prediction_statistics.png" |
1038 if plot["title"] == "ROC-AUC": | 1099 if plot["title"] == "ROC-AUC": |
1039 static_img = test_viz_dir / "roc_curves_from_prediction_statistics.png" | 1100 static_img = ( |
1101 test_viz_dir / "roc_curves_from_prediction_statistics.png" | |
1102 ) | |
1040 if static_img.exists(): | 1103 if static_img.exists(): |
1041 b64 = encode_image_to_base64(str(static_img)) | 1104 b64 = encode_image_to_base64(str(static_img)) |
1042 tab3_content += ( | 1105 tab3_content += ( |
1043 "<h2 style='text-align: center;'>" | 1106 "<h2 style='text-align: center;'>" |
1044 "Roc Curves From Prediction Statistics" | 1107 "Roc Curves From Prediction Statistics" |
1052 tab3_content += ( | 1115 tab3_content += ( |
1053 f"<h2 style='text-align: center;'>{plot['title']}</h2>" | 1116 f"<h2 style='text-align: center;'>{plot['title']}</h2>" |
1054 + plot["html"] | 1117 + plot["html"] |
1055 ) | 1118 ) |
1056 tab3_content += render_img_section( | 1119 tab3_content += render_img_section( |
1057 "Test Visualizations", | 1120 "Test Visualizations", test_viz_dir, output_type |
1058 test_viz_dir, | |
1059 output_type | |
1060 ) | 1121 ) |
1061 # assemble the tabs and help modal | 1122 # assemble the tabs and help modal |
1062 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) | 1123 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) |
1063 modal_html = get_metrics_help_modal() | 1124 modal_html = get_metrics_help_modal() |
1064 html += tabbed_html + modal_html + get_html_closing() | 1125 html += tabbed_html + modal_html + get_html_closing() |
1126 | |
1065 try: | 1127 try: |
1066 with open(report_path, "w") as f: | 1128 with open(report_path, "w") as f: |
1067 f.write(html) | 1129 f.write(html) |
1068 logger.info(f"HTML report generated at: {report_path}") | 1130 logger.info(f"HTML report generated at: {report_path}") |
1069 except Exception as e: | 1131 except Exception as e: |
1070 logger.error(f"Failed to write HTML report: {e}") | 1132 logger.error(f"Failed to write HTML report: {e}") |
1071 raise | 1133 raise |
1134 | |
1072 return report_path | 1135 return report_path |
1073 | 1136 |
1074 | 1137 |
1075 class WorkflowOrchestrator: | 1138 class WorkflowOrchestrator: |
1076 """Manages the image-classification workflow.""" | 1139 """Manages the image-classification workflow.""" |
1140 | |
1077 def __init__(self, args: argparse.Namespace, backend: Backend): | 1141 def __init__(self, args: argparse.Namespace, backend: Backend): |
1078 self.args = args | 1142 self.args = args |
1079 self.backend = backend | 1143 self.backend = backend |
1080 self.temp_dir: Optional[Path] = None | 1144 self.temp_dir: Optional[Path] = None |
1081 self.image_extract_dir: Optional[Path] = None | 1145 self.image_extract_dir: Optional[Path] = None |
1111 | 1175 |
1112 def _prepare_data(self) -> Tuple[Path, Dict[str, Any], str]: | 1176 def _prepare_data(self) -> Tuple[Path, Dict[str, Any], str]: |
1113 """Load CSV, update image paths, handle splits, and write prepared CSV.""" | 1177 """Load CSV, update image paths, handle splits, and write prepared CSV.""" |
1114 if not self.temp_dir or not self.image_extract_dir: | 1178 if not self.temp_dir or not self.image_extract_dir: |
1115 raise RuntimeError("Temp dirs not initialized before data prep.") | 1179 raise RuntimeError("Temp dirs not initialized before data prep.") |
1180 | |
1116 try: | 1181 try: |
1117 df = pd.read_csv(self.args.csv_file) | 1182 df = pd.read_csv(self.args.csv_file) |
1118 logger.info(f"Loaded CSV: {self.args.csv_file}") | 1183 logger.info(f"Loaded CSV: {self.args.csv_file}") |
1119 except Exception: | 1184 except Exception: |
1120 logger.error("Error loading CSV file", exc_info=True) | 1185 logger.error("Error loading CSV file", exc_info=True) |
1121 raise | 1186 raise |
1187 | |
1122 required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME} | 1188 required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME} |
1123 missing = required - set(df.columns) | 1189 missing = required - set(df.columns) |
1124 if missing: | 1190 if missing: |
1125 raise ValueError(f"Missing CSV columns: {', '.join(missing)}") | 1191 raise ValueError(f"Missing CSV columns: {', '.join(missing)}") |
1192 | |
1126 try: | 1193 try: |
1127 df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply( | 1194 df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply( |
1128 lambda p: str((self.image_extract_dir / p).resolve()) | 1195 lambda p: str((self.image_extract_dir / p).resolve()) |
1129 ) | 1196 ) |
1130 except Exception: | 1197 except Exception: |
1148 split_info = ( | 1215 split_info = ( |
1149 f"No split column in CSV. Created stratified random split: " | 1216 f"No split column in CSV. Created stratified random split: " |
1150 f"{[int(p * 100) for p in self.args.split_probabilities]}% " | 1217 f"{[int(p * 100) for p in self.args.split_probabilities]}% " |
1151 f"for train/val/test with balanced label distribution." | 1218 f"for train/val/test with balanced label distribution." |
1152 ) | 1219 ) |
1220 | |
1153 final_csv = self.temp_dir / TEMP_CSV_FILENAME | 1221 final_csv = self.temp_dir / TEMP_CSV_FILENAME |
1222 | |
1154 try: | 1223 try: |
1155 df.to_csv(final_csv, index=False) | 1224 df.to_csv(final_csv, index=False) |
1156 logger.info(f"Saved prepared data to {final_csv}") | 1225 logger.info(f"Saved prepared data to {final_csv}") |
1157 except Exception: | 1226 except Exception: |
1158 logger.error("Error saving prepared CSV", exc_info=True) | 1227 logger.error("Error saving prepared CSV", exc_info=True) |
1159 raise | 1228 raise |
1229 | |
1160 return final_csv, split_config, split_info | 1230 return final_csv, split_config, split_info |
1161 | 1231 |
1162 def _process_fixed_split( | 1232 def _process_fixed_split( |
1163 self, df: pd.DataFrame | 1233 self, df: pd.DataFrame |
1164 ) -> Tuple[pd.DataFrame, Dict[str, Any], str]: | 1234 ) -> Tuple[pd.DataFrame, Dict[str, Any], str]: |
1169 df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype( | 1239 df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype( |
1170 pd.Int64Dtype() | 1240 pd.Int64Dtype() |
1171 ) | 1241 ) |
1172 if df[SPLIT_COLUMN_NAME].isna().any(): | 1242 if df[SPLIT_COLUMN_NAME].isna().any(): |
1173 logger.warning("Split column contains non-numeric/missing values.") | 1243 logger.warning("Split column contains non-numeric/missing values.") |
1244 | |
1174 unique = set(df[SPLIT_COLUMN_NAME].dropna().unique()) | 1245 unique = set(df[SPLIT_COLUMN_NAME].dropna().unique()) |
1175 logger.info(f"Unique split values: {unique}") | 1246 logger.info(f"Unique split values: {unique}") |
1176 if unique == {0, 2}: | 1247 if unique == {0, 2}: |
1177 df = split_data_0_2( | 1248 df = split_data_0_2( |
1178 df, | 1249 df, |
1191 elif unique.issubset({0, 1, 2}): | 1262 elif unique.issubset({0, 1, 2}): |
1192 split_info = "Used user-defined split column from CSV." | 1263 split_info = "Used user-defined split column from CSV." |
1193 logger.info("Using fixed split as-is.") | 1264 logger.info("Using fixed split as-is.") |
1194 else: | 1265 else: |
1195 raise ValueError(f"Unexpected split values: {unique}") | 1266 raise ValueError(f"Unexpected split values: {unique}") |
1267 | |
1196 return df, {"type": "fixed", "column": SPLIT_COLUMN_NAME}, split_info | 1268 return df, {"type": "fixed", "column": SPLIT_COLUMN_NAME}, split_info |
1269 | |
1197 except Exception: | 1270 except Exception: |
1198 logger.error("Error processing fixed split", exc_info=True) | 1271 logger.error("Error processing fixed split", exc_info=True) |
1199 raise | 1272 raise |
1200 | 1273 |
1201 def _cleanup_temp_dirs(self) -> None: | 1274 def _cleanup_temp_dirs(self) -> None: |
1207 | 1280 |
1208 def run(self) -> None: | 1281 def run(self) -> None: |
1209 """Execute the full workflow end-to-end.""" | 1282 """Execute the full workflow end-to-end.""" |
1210 logger.info("Starting workflow...") | 1283 logger.info("Starting workflow...") |
1211 self.args.output_dir.mkdir(parents=True, exist_ok=True) | 1284 self.args.output_dir.mkdir(parents=True, exist_ok=True) |
1285 | |
1212 try: | 1286 try: |
1213 self._create_temp_dirs() | 1287 self._create_temp_dirs() |
1214 self._extract_images() | 1288 self._extract_images() |
1215 csv_path, split_cfg, split_info = self._prepare_data() | 1289 csv_path, split_cfg, split_info = self._prepare_data() |
1290 | |
1216 use_pretrained = self.args.use_pretrained or self.args.fine_tune | 1291 use_pretrained = self.args.use_pretrained or self.args.fine_tune |
1292 | |
1217 backend_args = { | 1293 backend_args = { |
1218 "model_name": self.args.model_name, | 1294 "model_name": self.args.model_name, |
1219 "fine_tune": self.args.fine_tune, | 1295 "fine_tune": self.args.fine_tune, |
1220 "use_pretrained": use_pretrained, | 1296 "use_pretrained": use_pretrained, |
1221 "epochs": self.args.epochs, | 1297 "epochs": self.args.epochs, |
1228 "label_column_data_path": csv_path, | 1304 "label_column_data_path": csv_path, |
1229 "augmentation": self.args.augmentation, | 1305 "augmentation": self.args.augmentation, |
1230 "threshold": self.args.threshold, | 1306 "threshold": self.args.threshold, |
1231 } | 1307 } |
1232 yaml_str = self.backend.prepare_config(backend_args, split_cfg) | 1308 yaml_str = self.backend.prepare_config(backend_args, split_cfg) |
1309 | |
1233 config_file = self.temp_dir / TEMP_CONFIG_FILENAME | 1310 config_file = self.temp_dir / TEMP_CONFIG_FILENAME |
1234 config_file.write_text(yaml_str) | 1311 config_file.write_text(yaml_str) |
1235 logger.info(f"Wrote backend config: {config_file}") | 1312 logger.info(f"Wrote backend config: {config_file}") |
1313 | |
1236 self.backend.run_experiment( | 1314 self.backend.run_experiment( |
1237 csv_path, | 1315 csv_path, |
1238 config_file, | 1316 config_file, |
1239 self.args.output_dir, | 1317 self.args.output_dir, |
1240 self.args.random_seed, | 1318 self.args.random_seed, |
1372 nargs=3, | 1450 nargs=3, |
1373 metavar=("train", "val", "test"), | 1451 metavar=("train", "val", "test"), |
1374 action=SplitProbAction, | 1452 action=SplitProbAction, |
1375 default=[0.7, 0.1, 0.2], | 1453 default=[0.7, 0.1, 0.2], |
1376 help=( | 1454 help=( |
1377 "Random split proportions (e.g., 0.7 0.1 0.2)." | 1455 "Random split proportions (e.g., 0.7 0.1 0.2).Only used if no split column." |
1378 "Only used if no split column." | |
1379 ), | 1456 ), |
1380 ) | 1457 ) |
1381 parser.add_argument( | 1458 parser.add_argument( |
1382 "--random-seed", | 1459 "--random-seed", |
1383 type=int, | 1460 type=int, |
1406 type=float, | 1483 type=float, |
1407 default=None, | 1484 default=None, |
1408 help=( | 1485 help=( |
1409 "Decision threshold for binary classification (0.0–1.0)." | 1486 "Decision threshold for binary classification (0.0–1.0)." |
1410 "Overrides default 0.5." | 1487 "Overrides default 0.5." |
1411 ) | 1488 ), |
1412 ) | 1489 ) |
1413 args = parser.parse_args() | 1490 args = parser.parse_args() |
1491 | |
1414 if not 0.0 <= args.validation_size <= 1.0: | 1492 if not 0.0 <= args.validation_size <= 1.0: |
1415 parser.error("validation-size must be between 0.0 and 1.0") | 1493 parser.error("validation-size must be between 0.0 and 1.0") |
1416 if not args.csv_file.is_file(): | 1494 if not args.csv_file.is_file(): |
1417 parser.error(f"CSV not found: {args.csv_file}") | 1495 parser.error(f"CSV not found: {args.csv_file}") |
1418 if not args.image_zip.is_file(): | 1496 if not args.image_zip.is_file(): |
1421 try: | 1499 try: |
1422 augmentation_setup = aug_parse(args.augmentation) | 1500 augmentation_setup = aug_parse(args.augmentation) |
1423 setattr(args, "augmentation", augmentation_setup) | 1501 setattr(args, "augmentation", augmentation_setup) |
1424 except ValueError as e: | 1502 except ValueError as e: |
1425 parser.error(str(e)) | 1503 parser.error(str(e)) |
1504 | |
1426 backend_instance = LudwigDirectBackend() | 1505 backend_instance = LudwigDirectBackend() |
1427 orchestrator = WorkflowOrchestrator(args, backend_instance) | 1506 orchestrator = WorkflowOrchestrator(args, backend_instance) |
1507 | |
1428 exit_code = 0 | 1508 exit_code = 0 |
1429 try: | 1509 try: |
1430 orchestrator.run() | 1510 orchestrator.run() |
1431 logger.info("Main script finished successfully.") | 1511 logger.info("Main script finished successfully.") |
1432 except Exception as e: | 1512 except Exception as e: |
1437 | 1517 |
1438 | 1518 |
1439 if __name__ == "__main__": | 1519 if __name__ == "__main__": |
1440 try: | 1520 try: |
1441 import ludwig | 1521 import ludwig |
1522 | |
1442 logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}") | 1523 logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}") |
1443 except ImportError: | 1524 except ImportError: |
1444 logger.error( | 1525 logger.error( |
1445 "Ludwig library not found. Please ensure Ludwig is installed " | 1526 "Ludwig library not found. Please ensure Ludwig is installed " |
1446 "('pip install ludwig[image]')" | 1527 "('pip install ludwig[image]')" |
1447 ) | 1528 ) |
1448 sys.exit(1) | 1529 sys.exit(1) |
1530 | |
1449 main() | 1531 main() |
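
The clearest functional change in this comparison, as opposed to formatting, is that the report generator now passes output_type into format_config_table_html, matching the guard near the top of that function: the threshold row is emitted only for binary classification outputs. A minimal sketch of that guard, with a hypothetical helper name, assuming only what the hunks above show:

```python
# Sketch of the threshold-row guard in format_config_table_html. The helper
# name is illustrative; only the condition itself is taken from the diff.
from typing import Optional

def include_config_row(key: str, output_type: Optional[str]) -> bool:
    # "threshold" only applies to binary classification outputs.
    if key == "threshold":
        return output_type == "binary"
    return True

# include_config_row("threshold", "category") -> False
# include_config_row("learning_rate", None) -> True
```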
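
For the stratified-split hunks, the two-stage arithmetic is easiest to see in isolation: the test set is carved out first, then the validation fraction is rescaled to the remaining data, val_size_adjusted = p_val / (p_train + p_val). A minimal sketch, assuming scikit-learn and illustrative "label"/"split" column names rather than the tool's constants:

```python
# Two-stage stratified split, as used when the CSV has no split column.
import pandas as pd
from sklearn.model_selection import train_test_split

def stratified_random_split(df: pd.DataFrame, label_column: str = "label",
                            probs=(0.7, 0.1, 0.2), random_state: int = 42) -> pd.DataFrame:
    out = df.copy()
    out["split"] = 0
    # Stage 1: carve out the test set, stratified on the label.
    train_val_idx, test_idx = train_test_split(
        out.index.to_numpy(),
        test_size=probs[2],
        random_state=random_state,
        stratify=out[label_column],
    )
    # Stage 2: split the remainder into train/validation; the validation
    # fraction is rescaled to the remaining data: p_val / (p_train + p_val).
    val_size_adjusted = probs[1] / (probs[0] + probs[1])
    train_idx, val_idx = train_test_split(
        train_val_idx,
        test_size=val_size_adjusted,
        random_state=random_state,
        stratify=out.loc[train_val_idx, label_column],
    )
    out.loc[train_idx, "split"] = 0
    out.loc[val_idx, "split"] = 1
    out.loc[test_idx, "split"] = 2
    return out.astype({"split": int})
```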
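
The prepare_config hunks infer the task type from the label column: a numeric label with more than 10 distinct values is treated as regression, anything else as classification, with binary versus category decided by the number of classes. A self-contained sketch of that decision, using an illustrative "label" column name and helper rather than the tool's own constants:

```python
# Task-type inference and output-feature block, as shown in prepare_config.
from typing import Any, Dict, Optional
import pandas as pd
import pandas.api.types as ptypes

def infer_output_feature(label_series: Optional[pd.Series],
                         threshold: Optional[float] = None) -> Dict[str, Any]:
    # Regression: numeric label with many distinct values.
    if (label_series is not None
            and ptypes.is_numeric_dtype(label_series.dtype)
            and label_series.nunique() > 10):
        return {"name": "label", "type": "number", "decoder": {"type": "regressor"}}
    # Classification: binary for exactly two classes, category otherwise.
    num_unique = label_series.nunique() if label_series is not None else 2
    output_type = "binary" if num_unique == 2 else "category"
    feat: Dict[str, Any] = {"name": "label", "type": output_type}
    if output_type == "binary" and threshold is not None:
        feat["threshold"] = float(threshold)
    return feat
```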