comparison: image_learner_cli.py @ 9:9e912fce264c (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit eace0d7c2b2939029c052991d238a54947d2e191
| author | goeckslab |
|---|---|
| date | Wed, 27 Aug 2025 21:02:48 +0000 |
| parents | 85e6f4b2ad18 |
| children |
| 8:85e6f4b2ad18 (before) | 9:9e912fce264c (after) |
|---|---|
| 19 METRIC_DISPLAY_NAMES, | 19 METRIC_DISPLAY_NAMES, |
| 20 MODEL_ENCODER_TEMPLATES, | 20 MODEL_ENCODER_TEMPLATES, |
| 21 SPLIT_COLUMN_NAME, | 21 SPLIT_COLUMN_NAME, |
| 22 TEMP_CONFIG_FILENAME, | 22 TEMP_CONFIG_FILENAME, |
| 23 TEMP_CSV_FILENAME, | 23 TEMP_CSV_FILENAME, |
| 24 TEMP_DIR_PREFIX | 24 TEMP_DIR_PREFIX, |
| 25 ) | 25 ) |
| 26 from ludwig.globals import ( | 26 from ludwig.globals import ( |
| 27 DESCRIPTION_FILE_NAME, | 27 DESCRIPTION_FILE_NAME, |
| 28 PREDICTIONS_PARQUET_FILE_NAME, | 28 PREDICTIONS_PARQUET_FILE_NAME, |
| 29 TEST_STATISTICS_FILE_NAME, | 29 TEST_STATISTICS_FILE_NAME, |
| 36 from utils import ( | 36 from utils import ( |
| 37 build_tabbed_html, | 37 build_tabbed_html, |
| 38 encode_image_to_base64, | 38 encode_image_to_base64, |
| 39 get_html_closing, | 39 get_html_closing, |
| 40 get_html_template, | 40 get_html_template, |
| 41 get_metrics_help_modal | 41 get_metrics_help_modal, |
| 42 ) | 42 ) |
| 43 | 43 |
| 44 # --- Logging Setup --- | 44 # --- Logging Setup --- |
| 45 logging.basicConfig( | 45 logging.basicConfig( |
| 46 level=logging.INFO, | 46 level=logging.INFO, |
| 47 format='%(asctime)s %(levelname)s %(name)s: %(message)s', | 47 format="%(asctime)s %(levelname)s %(name)s: %(message)s", |
| 48 ) | 48 ) |
| 49 logger = logging.getLogger("ImageLearner") | 49 logger = logging.getLogger("ImageLearner") |
| 50 | 50 |
| 51 | 51 |
| 52 def format_config_table_html( | 52 def format_config_table_html( |
| 65 "learning_rate", | 65 "learning_rate", |
| 66 "random_seed", | 66 "random_seed", |
| 67 "early_stop", | 67 "early_stop", |
| 68 "threshold", | 68 "threshold", |
| 69 ] | 69 ] |
| | 70 |
| 70 rows = [] | 71 rows = [] |
| | 72 |
| 71 for key in display_keys: | 73 for key in display_keys: |
| 72 val = config.get(key, None) | 74 val = config.get(key, None) |
| 73 if key == "threshold": | 75 if key == "threshold": |
| 74 if output_type != "binary": | 76 if output_type != "binary": |
| 75 continue | 77 continue |
| 132 ) | 134 ) |
| 133 else: | 135 else: |
| 134 val_str = val | 136 val_str = val |
| 135 else: | 137 else: |
| 136 val_str = val if val is not None else "N/A" | 138 val_str = val if val is not None else "N/A" |
| 137 if val_str == "N/A" and key not in ["task_type"]: # Skip if N/A for non-essential | 139 if val_str == "N/A" and key not in [ |
| | 140 "task_type" |
| | 141 ]: # Skip if N/A for non-essential |
| 138 continue | 142 continue |
| 139 rows.append( | 143 rows.append( |
| 140 f"<tr>" | 144 f"<tr>" |
| 141 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" | 145 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" |
| 142 f"{key.replace('_', ' ').title()}</td>" | 146 f"{key.replace('_', ' ').title()}</td>" |
| 164 <thead><tr> | 168 <thead><tr> |
| 165 <th style="padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Parameter</th> | 169 <th style="padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Parameter</th> |
| 166 <th style="padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Value</th> | 170 <th style="padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Value</th> |
| 167 </tr></thead> | 171 </tr></thead> |
| 168 <tbody> | 172 <tbody> |
| 169 {''.join(rows)} | 173 {"".join(rows)} |
| 170 </tbody> | 174 </tbody> |
| 171 </table> | 175 </table> |
| 172 </div><br> | 176 </div><br> |
| 173 <p style="text-align: center; font-size: 0.9em;"> | 177 <p style="text-align: center; font-size: 0.9em;"> |
| 174 Model trained using <a href="https://ludwig.ai/" target="_blank" rel="noopener noreferrer">Ludwig</a>. | 178 Model trained using <a href="https://ludwig.ai/" target="_blank" rel="noopener noreferrer">Ludwig</a>. |
| 249 "accuracy_micro": get_last_value(label_stats, "accuracy_micro"), | 253 "accuracy_micro": get_last_value(label_stats, "accuracy_micro"), |
| 250 "loss": get_last_value(label_stats, "loss"), | 254 "loss": get_last_value(label_stats, "loss"), |
| 251 "roc_auc": get_last_value(label_stats, "roc_auc"), | 255 "roc_auc": get_last_value(label_stats, "roc_auc"), |
| 252 "hits_at_k": get_last_value(label_stats, "hits_at_k"), | 256 "hits_at_k": get_last_value(label_stats, "hits_at_k"), |
| 253 } | 257 } |
| | 258 |
| 254 # Test metrics: dynamic extraction according to exclusions | 259 # Test metrics: dynamic extraction according to exclusions |
| 255 test_label_stats = test_stats.get("label", {}) | 260 test_label_stats = test_stats.get("label", {}) |
| 256 if not test_label_stats: | 261 if not test_label_stats: |
| 257 logging.warning("No label statistics found for test split") | 262 logging.warning("No label statistics found for test split") |
| 258 else: | 263 else: |
| 259 combined_stats = test_stats.get("combined", {}) | 264 combined_stats = test_stats.get("combined", {}) |
| 260 overall_stats = test_label_stats.get("overall_stats", {}) | 265 overall_stats = test_label_stats.get("overall_stats", {}) |
| | 266 |
| 261 # Define exclusions | 267 # Define exclusions |
| 262 if output_type == "binary": | 268 if output_type == "binary": |
| 263 exclude = {"per_class_stats", "precision_recall_curve", "roc_curve"} | 269 exclude = {"per_class_stats", "precision_recall_curve", "roc_curve"} |
| 264 else: | 270 else: |
| 265 exclude = {"per_class_stats", "confusion_matrix"} | 271 exclude = {"per_class_stats", "confusion_matrix"} |
| | 272 |
| 266 # 1. Get all scalar test_label_stats not excluded | 273 # 1. Get all scalar test_label_stats not excluded |
| 267 test_metrics = {} | 274 test_metrics = {} |
| 268 for k, v in test_label_stats.items(): | 275 for k, v in test_label_stats.items(): |
| 269 if k in exclude: | 276 if k in exclude: |
| 270 continue | 277 continue |
| 271 if k == "overall_stats": | 278 if k == "overall_stats": |
| 272 continue | 279 continue |
| 273 if isinstance(v, (int, float, str, bool)): | 280 if isinstance(v, (int, float, str, bool)): |
| 274 test_metrics[k] = v | 281 test_metrics[k] = v |
| | 282 |
| 275 # 2. Add overall_stats (flattened) | 283 # 2. Add overall_stats (flattened) |
| 276 for k, v in overall_stats.items(): | 284 for k, v in overall_stats.items(): |
| 277 test_metrics[k] = v | 285 test_metrics[k] = v |
| | 286 |
| 278 # 3. Optionally include combined/loss if present and not already | 287 # 3. Optionally include combined/loss if present and not already |
| 279 if "loss" in combined_stats and "loss" not in test_metrics: | 288 if "loss" in combined_stats and "loss" not in test_metrics: |
| 280 test_metrics["loss"] = combined_stats["loss"] | 289 test_metrics["loss"] = combined_stats["loss"] |
| 281 metrics["test"] = test_metrics | 290 metrics["test"] = test_metrics |
| 282 return metrics | 291 return metrics |
| 313 t = all_metrics["training"].get(metric_key) | 322 t = all_metrics["training"].get(metric_key) |
| 314 v = all_metrics["validation"].get(metric_key) | 323 v = all_metrics["validation"].get(metric_key) |
| 315 te = all_metrics["test"].get(metric_key) | 324 te = all_metrics["test"].get(metric_key) |
| 316 if all(x is not None for x in [t, v, te]): | 325 if all(x is not None for x in [t, v, te]): |
| 317 rows.append([display_name, f"{t:.4f}", f"{v:.4f}", f"{te:.4f}"]) | 326 rows.append([display_name, f"{t:.4f}", f"{v:.4f}", f"{te:.4f}"]) |
| | 327 |
| 318 if not rows: | 328 if not rows: |
| 319 return "<table><tr><td>No metric values found.</td></tr></table>" | 329 return "<table><tr><td>No metric values found.</td></tr></table>" |
| | 330 |
| 320 html = ( | 331 html = ( |
| 321 "<h2 style='text-align: center;'>Model Performance Summary</h2>" | 332 "<h2 style='text-align: center;'>Model Performance Summary</h2>" |
| 322 "<div style='display: flex; justify-content: center;'>" | 333 "<div style='display: flex; justify-content: center;'>" |
| 323 "<table class='performance-summary' style='border-collapse: collapse;'>" | 334 "<table class='performance-summary' style='border-collapse: collapse;'>" |
| 324 "<thead><tr>" | 335 "<thead><tr>" |
| 329 "</tr></thead><tbody>" | 340 "</tr></thead><tbody>" |
| 330 ) | 341 ) |
| 331 for row in rows: | 342 for row in rows: |
| 332 html += generate_table_row( | 343 html += generate_table_row( |
| 333 row, | 344 row, |
| 334 "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;" | 345 "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;", |
| 335 ) | 346 ) |
| 336 html += "</tbody></table></div><br>" | 347 html += "</tbody></table></div><br>" |
| 337 return html | 348 return html |
| 338 | 349 |
| 339 | 350 |
| 355 ) | 366 ) |
| 356 t = all_metrics["training"].get(metric_key) | 367 t = all_metrics["training"].get(metric_key) |
| 357 v = all_metrics["validation"].get(metric_key) | 368 v = all_metrics["validation"].get(metric_key) |
| 358 if t is not None and v is not None: | 369 if t is not None and v is not None: |
| 359 rows.append([display_name, f"{t:.4f}", f"{v:.4f}"]) | 370 rows.append([display_name, f"{t:.4f}", f"{v:.4f}"]) |
| | 371 |
| 360 if not rows: | 372 if not rows: |
| 361 return "<table><tr><td>No metric values found for Train/Validation.</td></tr></table>" | 373 return "<table><tr><td>No metric values found for Train/Validation.</td></tr></table>" |
| | 374 |
| 362 html = ( | 375 html = ( |
| 363 "<h2 style='text-align: center;'>Train/Validation Performance Summary</h2>" | 376 "<h2 style='text-align: center;'>Train/Validation Performance Summary</h2>" |
| 364 "<div style='display: flex; justify-content: center;'>" | 377 "<div style='display: flex; justify-content: center;'>" |
| 365 "<table class='performance-summary' style='border-collapse: collapse;'>" | 378 "<table class='performance-summary' style='border-collapse: collapse;'>" |
| 366 "<thead><tr>" | 379 "<thead><tr>" |
| 370 "</tr></thead><tbody>" | 383 "</tr></thead><tbody>" |
| 371 ) | 384 ) |
| 372 for row in rows: | 385 for row in rows: |
| 373 html += generate_table_row( | 386 html += generate_table_row( |
| 374 row, | 387 row, |
| 375 "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;" | 388 "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;", |
| 376 ) | 389 ) |
| 377 html += "</tbody></table></div><br>" | 390 html += "</tbody></table></div><br>" |
| 378 return html | 391 return html |
| 379 | 392 |
| 380 | 393 |
| 391 for key in sorted(test_metrics.keys()): | 404 for key in sorted(test_metrics.keys()): |
| 392 display_name = METRIC_DISPLAY_NAMES.get(key, key.replace("_", " ").title()) | 405 display_name = METRIC_DISPLAY_NAMES.get(key, key.replace("_", " ").title()) |
| 393 value = test_metrics[key] | 406 value = test_metrics[key] |
| 394 if value is not None: | 407 if value is not None: |
| 395 rows.append([display_name, f"{value:.4f}"]) | 408 rows.append([display_name, f"{value:.4f}"]) |
| | 409 |
| 396 if not rows: | 410 if not rows: |
| 397 return "<table><tr><td>No test metric values found.</td></tr></table>" | 411 return "<table><tr><td>No test metric values found.</td></tr></table>" |
| | 412 |
| 398 html = ( | 413 html = ( |
| 399 "<h2 style='text-align: center;'>Test Performance Summary</h2>" | 414 "<h2 style='text-align: center;'>Test Performance Summary</h2>" |
| 400 "<div style='display: flex; justify-content: center;'>" | 415 "<div style='display: flex; justify-content: center;'>" |
| 401 "<table class='performance-summary' style='border-collapse: collapse;'>" | 416 "<table class='performance-summary' style='border-collapse: collapse;'>" |
| 402 "<thead><tr>" | 417 "<thead><tr>" |
| 405 "</tr></thead><tbody>" | 420 "</tr></thead><tbody>" |
| 406 ) | 421 ) |
| 407 for row in rows: | 422 for row in rows: |
| 408 html += generate_table_row( | 423 html += generate_table_row( |
| 409 row, | 424 row, |
| 410 "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;" | 425 "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;", |
| 411 ) | 426 ) |
| 412 html += "</tbody></table></div><br>" | 427 html += "</tbody></table></div><br>" |
| 413 return html | 428 return html |
| 414 | 429 |
| 415 | 430 |
| 434 if label_counts.size > 1: | 449 if label_counts.size > 1: |
| 435 # Force stratify even with fewer samples - adjust validation_size if needed | 450 # Force stratify even with fewer samples - adjust validation_size if needed |
| 436 min_samples_per_class = label_counts.min() | 451 min_samples_per_class = label_counts.min() |
| 437 if min_samples_per_class * validation_size < 1: | 452 if min_samples_per_class * validation_size < 1: |
| 438 # Adjust validation_size to ensure at least 1 sample per class, but do not exceed original validation_size | 453 # Adjust validation_size to ensure at least 1 sample per class, but do not exceed original validation_size |
| 439 adjusted_validation_size = min(validation_size, 1.0 / min_samples_per_class) | 454 adjusted_validation_size = min( |
| | 455 validation_size, 1.0 / min_samples_per_class |
| | 456 ) |
| 440 if adjusted_validation_size != validation_size: | 457 if adjusted_validation_size != validation_size: |
| 441 validation_size = adjusted_validation_size | 458 validation_size = adjusted_validation_size |
| 442 logger.info(f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation") | 459 logger.info( |
| | 460 f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation" |
| | 461 ) |
| 443 stratify_arr = out.loc[idx_train, label_column] | 462 stratify_arr = out.loc[idx_train, label_column] |
| 444 logger.info("Using stratified split for validation set") | 463 logger.info("Using stratified split for validation set") |
| 445 else: | 464 else: |
| 446 logger.warning("Only one label class found; cannot stratify") | 465 logger.warning("Only one label class found; cannot stratify") |
| 447 if validation_size <= 0: | 466 if validation_size <= 0: |
| 484 """Create a stratified random split when no split column exists.""" | 503 """Create a stratified random split when no split column exists.""" |
| 485 out = df.copy() | 504 out = df.copy() |
| 486 # initialize split column | 505 # initialize split column |
| 487 out[split_column] = 0 | 506 out[split_column] = 0 |
| 488 if not label_column or label_column not in out.columns: | 507 if not label_column or label_column not in out.columns: |
| 489 logger.warning("No label column found; using random split without stratification") | 508 logger.warning( |
| | 509 "No label column found; using random split without stratification" |
| | 510 ) |
| 490 # fall back to simple random assignment | 511 # fall back to simple random assignment |
| 491 indices = out.index.tolist() | 512 indices = out.index.tolist() |
| 492 np.random.seed(random_state) | 513 np.random.seed(random_state) |
| 493 np.random.shuffle(indices) | 514 np.random.shuffle(indices) |
| 494 n_total = len(indices) | 515 n_total = len(indices) |
| 527 test_size=split_probabilities[2], | 548 test_size=split_probabilities[2], |
| 528 random_state=random_state, | 549 random_state=random_state, |
| 529 stratify=out[label_column], | 550 stratify=out[label_column], |
| 530 ) | 551 ) |
| 531 # second split: separate training and validation from remaining data | 552 # second split: separate training and validation from remaining data |
| 532 val_size_adjusted = split_probabilities[1] / (split_probabilities[0] + split_probabilities[1]) | 553 val_size_adjusted = split_probabilities[1] / ( |
| | 554 split_probabilities[0] + split_probabilities[1] |
| | 555 ) |
| 533 train_idx, val_idx = train_test_split( | 556 train_idx, val_idx = train_test_split( |
| 534 train_val_idx, | 557 train_val_idx, |
| 535 test_size=val_size_adjusted, | 558 test_size=val_size_adjusted, |
| 536 random_state=random_state, | 559 random_state=random_state, |
| 537 stratify=out.loc[train_val_idx, label_column], | 560 stratify=out.loc[train_val_idx, label_column], |
| 539 # assign split values | 562 # assign split values |
| 540 out.loc[train_idx, split_column] = 0 | 563 out.loc[train_idx, split_column] = 0 |
| 541 out.loc[val_idx, split_column] = 1 | 564 out.loc[val_idx, split_column] = 1 |
| 542 out.loc[test_idx, split_column] = 2 | 565 out.loc[test_idx, split_column] = 2 |
| 543 logger.info("Successfully applied stratified random split") | 566 logger.info("Successfully applied stratified random split") |
| 544 logger.info(f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}") | 567 logger.info( |
| | 568 f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}" |
| | 569 ) |
| 545 return out.astype({split_column: int}) | 570 return out.astype({split_column: int}) |
| 546 | 571 |
| 547 | 572 |
| 548 class Backend(Protocol): | 573 class Backend(Protocol): |
| 549 """Interface for a machine learning backend.""" | 574 """Interface for a machine learning backend.""" |
| | 575 |
| 550 def prepare_config( | 576 def prepare_config( |
| 551 self, | 577 self, |
| 552 config_params: Dict[str, Any], | 578 config_params: Dict[str, Any], |
| 553 split_config: Dict[str, Any], | 579 split_config: Dict[str, Any], |
| 554 ) -> str: | 580 ) -> str: |
| 576 ... | 602 ... |
| 577 | 603 |
| 578 | 604 |
| 579 class LudwigDirectBackend: | 605 class LudwigDirectBackend: |
| 580 """Backend for running Ludwig experiments directly via the internal experiment_cli function.""" | 606 """Backend for running Ludwig experiments directly via the internal experiment_cli function.""" |
| | 607 |
| 581 def prepare_config( | 608 def prepare_config( |
| 582 self, | 609 self, |
| 583 config_params: Dict[str, Any], | 610 config_params: Dict[str, Any], |
| 584 split_config: Dict[str, Any], | 611 split_config: Dict[str, Any], |
| 585 ) -> str: | 612 ) -> str: |
| 586 logger.info("LudwigDirectBackend: Preparing YAML configuration.") | 613 logger.info("LudwigDirectBackend: Preparing YAML configuration.") |
| | 614 |
| 587 model_name = config_params.get("model_name", "resnet18") | 615 model_name = config_params.get("model_name", "resnet18") |
| 588 use_pretrained = config_params.get("use_pretrained", False) | 616 use_pretrained = config_params.get("use_pretrained", False) |
| 589 fine_tune = config_params.get("fine_tune", False) | 617 fine_tune = config_params.get("fine_tune", False) |
| 590 if use_pretrained: | 618 if use_pretrained: |
| 591 trainable = bool(fine_tune) | 619 trainable = bool(fine_tune) |
| 604 "use_pretrained": use_pretrained, | 632 "use_pretrained": use_pretrained, |
| 605 "trainable": trainable, | 633 "trainable": trainable, |
| 606 } | 634 } |
| 607 else: | 635 else: |
| 608 encoder_config = {"type": raw_encoder} | 636 encoder_config = {"type": raw_encoder} |
| | 637 |
| 609 batch_size_cfg = batch_size or "auto" | 638 batch_size_cfg = batch_size or "auto" |
| | 639 |
| 610 label_column_path = config_params.get("label_column_data_path") | 640 label_column_path = config_params.get("label_column_data_path") |
| 611 label_series = None | 641 label_series = None |
| 612 if label_column_path is not None and Path(label_column_path).exists(): | 642 if label_column_path is not None and Path(label_column_path).exists(): |
| 613 try: | 643 try: |
| 614 label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME] | 644 label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME] |
| 615 except Exception as e: | 645 except Exception as e: |
| 616 logger.warning(f"Could not read label column for task detection: {e}") | 646 logger.warning(f"Could not read label column for task detection: {e}") |
| | 647 |
| 617 if ( | 648 if ( |
| 618 label_series is not None | 649 label_series is not None |
| 619 and ptypes.is_numeric_dtype(label_series.dtype) | 650 and ptypes.is_numeric_dtype(label_series.dtype) |
| 620 and label_series.nunique() > 10 | 651 and label_series.nunique() > 10 |
| 621 ): | 652 ): |
| 622 task_type = "regression" | 653 task_type = "regression" |
| 623 else: | 654 else: |
| 624 task_type = "classification" | 655 task_type = "classification" |
| | 656 |
| 625 config_params["task_type"] = task_type | 657 config_params["task_type"] = task_type |
| | 658 |
| 626 image_feat: Dict[str, Any] = { | 659 image_feat: Dict[str, Any] = { |
| 627 "name": IMAGE_PATH_COLUMN_NAME, | 660 "name": IMAGE_PATH_COLUMN_NAME, |
| 628 "type": "image", | 661 "type": "image", |
| 629 "encoder": encoder_config, | 662 "encoder": encoder_config, |
| 630 } | 663 } |
| 631 if config_params.get("augmentation") is not None: | 664 if config_params.get("augmentation") is not None: |
| 632 image_feat["augmentation"] = config_params["augmentation"] | 665 image_feat["augmentation"] = config_params["augmentation"] |
| | 666 |
| 633 if task_type == "regression": | 667 if task_type == "regression": |
| 634 output_feat = { | 668 output_feat = { |
| 635 "name": LABEL_COLUMN_NAME, | 669 "name": LABEL_COLUMN_NAME, |
| 636 "type": "number", | 670 "type": "number", |
| 637 "decoder": {"type": "regressor"}, | 671 "decoder": {"type": "regressor"}, |
| 643 "r2", | 677 "r2", |
| 644 ] | 678 ] |
| 645 }, | 679 }, |
| 646 } | 680 } |
| 647 val_metric = config_params.get("validation_metric", "mean_squared_error") | 681 val_metric = config_params.get("validation_metric", "mean_squared_error") |
| | 682 |
| 648 else: | 683 else: |
| 649 num_unique_labels = ( | 684 num_unique_labels = ( |
| 650 label_series.nunique() if label_series is not None else 2 | 685 label_series.nunique() if label_series is not None else 2 |
| 651 ) | 686 ) |
| 652 output_type = "binary" if num_unique_labels == 2 else "category" | 687 output_type = "binary" if num_unique_labels == 2 else "category" |
| 653 output_feat = {"name": LABEL_COLUMN_NAME, "type": output_type} | 688 output_feat = {"name": LABEL_COLUMN_NAME, "type": output_type} |
| 654 if output_type == "binary" and config_params.get("threshold") is not None: | 689 if output_type == "binary" and config_params.get("threshold") is not None: |
| 655 output_feat["threshold"] = float(config_params["threshold"]) | 690 output_feat["threshold"] = float(config_params["threshold"]) |
| 656 val_metric = None | 691 val_metric = None |
| | 692 |
| 657 conf: Dict[str, Any] = { | 693 conf: Dict[str, Any] = { |
| 658 "model_type": "ecd", | 694 "model_type": "ecd", |
| 659 "input_features": [image_feat], | 695 "input_features": [image_feat], |
| 660 "output_features": [output_feat], | 696 "output_features": [output_feat], |
| 661 "combiner": {"type": "concat"}, | 697 "combiner": {"type": "concat"}, |
| 671 "split": split_config, | 707 "split": split_config, |
| 672 "num_processes": num_processes, | 708 "num_processes": num_processes, |
| 673 "in_memory": False, | 709 "in_memory": False, |
| 674 }, | 710 }, |
| 675 } | 711 } |
| | 712 |
| 676 logger.debug("LudwigDirectBackend: Config dict built.") | 713 logger.debug("LudwigDirectBackend: Config dict built.") |
| 677 try: | 714 try: |
| 678 yaml_str = yaml.dump(conf, sort_keys=False, indent=2) | 715 yaml_str = yaml.dump(conf, sort_keys=False, indent=2) |
| 679 logger.info("LudwigDirectBackend: YAML config generated.") | 716 logger.info("LudwigDirectBackend: YAML config generated.") |
| 680 return yaml_str | 717 return yaml_str |
| 692 output_dir: Path, | 729 output_dir: Path, |
| 693 random_seed: int = 42, | 730 random_seed: int = 42, |
| 694 ) -> None: | 731 ) -> None: |
| 695 """Invoke Ludwig's internal experiment_cli function to run the experiment.""" | 732 """Invoke Ludwig's internal experiment_cli function to run the experiment.""" |
| 696 logger.info("LudwigDirectBackend: Starting experiment execution.") | 733 logger.info("LudwigDirectBackend: Starting experiment execution.") |
| | 734 |
| 697 try: | 735 try: |
| 698 from ludwig.experiment import experiment_cli | 736 from ludwig.experiment import experiment_cli |
| 699 except ImportError as e: | 737 except ImportError as e: |
| 700 logger.error( | 738 logger.error( |
| 701 "LudwigDirectBackend: Could not import experiment_cli.", | 739 "LudwigDirectBackend: Could not import experiment_cli.", |
| 702 exc_info=True, | 740 exc_info=True, |
| 703 ) | 741 ) |
| 704 raise RuntimeError("Ludwig import failed.") from e | 742 raise RuntimeError("Ludwig import failed.") from e |
| | 743 |
| 705 output_dir.mkdir(parents=True, exist_ok=True) | 744 output_dir.mkdir(parents=True, exist_ok=True) |
| | 745 |
| 706 try: | 746 try: |
| 707 experiment_cli( | 747 experiment_cli( |
| 708 dataset=str(dataset_path), | 748 dataset=str(dataset_path), |
| 709 config=str(config_path), | 749 config=str(config_path), |
| 710 output_directory=str(output_dir), | 750 output_directory=str(output_dir), |
| 731 output_dir = Path(output_dir) | 771 output_dir = Path(output_dir) |
| 732 exp_dirs = sorted( | 772 exp_dirs = sorted( |
| 733 output_dir.glob("experiment_run*"), | 773 output_dir.glob("experiment_run*"), |
| 734 key=lambda p: p.stat().st_mtime, | 774 key=lambda p: p.stat().st_mtime, |
| 735 ) | 775 ) |
| | 776 |
| 736 if not exp_dirs: | 777 if not exp_dirs: |
| 737 logger.warning(f"No experiment run directories found in {output_dir}") | 778 logger.warning(f"No experiment run directories found in {output_dir}") |
| 738 return None | 779 return None |
| | 780 |
| 739 progress_file = exp_dirs[-1] / "model" / "training_progress.json" | 781 progress_file = exp_dirs[-1] / "model" / "training_progress.json" |
| 740 if not progress_file.exists(): | 782 if not progress_file.exists(): |
| 741 logger.warning(f"No training_progress.json found in {progress_file}") | 783 logger.warning(f"No training_progress.json found in {progress_file}") |
| 742 return None | 784 return None |
| | 785 |
| 743 try: | 786 try: |
| 744 with progress_file.open("r", encoding="utf-8") as f: | 787 with progress_file.open("r", encoding="utf-8") as f: |
| 745 data = json.load(f) | 788 data = json.load(f) |
| 746 return { | 789 return { |
| 747 "learning_rate": data.get("learning_rate"), | 790 "learning_rate": data.get("learning_rate"), |
| 773 logger.error(f"Error converting Parquet to CSV: {e}") | 816 logger.error(f"Error converting Parquet to CSV: {e}") |
| 774 | 817 |
| 775 def generate_plots(self, output_dir: Path) -> None: | 818 def generate_plots(self, output_dir: Path) -> None: |
| 776 """Generate all registered Ludwig visualizations for the latest experiment run.""" | 819 """Generate all registered Ludwig visualizations for the latest experiment run.""" |
| 777 logger.info("Generating all Ludwig visualizations…") | 820 logger.info("Generating all Ludwig visualizations…") |
| | 821 |
| 778 test_plots = { | 822 test_plots = { |
| 779 "compare_performance", | 823 "compare_performance", |
| 780 "compare_classifiers_performance_from_prob", | 824 "compare_classifiers_performance_from_prob", |
| 781 "compare_classifiers_performance_from_pred", | 825 "compare_classifiers_performance_from_pred", |
| 782 "compare_classifiers_performance_changing_k", | 826 "compare_classifiers_performance_changing_k", |
| 796 } | 840 } |
| 797 train_plots = { | 841 train_plots = { |
| 798 "learning_curves", | 842 "learning_curves", |
| 799 "compare_classifiers_performance_subset", | 843 "compare_classifiers_performance_subset", |
| 800 } | 844 } |
| | 845 |
| 801 output_dir = Path(output_dir) | 846 output_dir = Path(output_dir) |
| 802 exp_dirs = sorted( | 847 exp_dirs = sorted( |
| 803 output_dir.glob("experiment_run*"), | 848 output_dir.glob("experiment_run*"), |
| 804 key=lambda p: p.stat().st_mtime, | 849 key=lambda p: p.stat().st_mtime, |
| 805 ) | 850 ) |
| 806 if not exp_dirs: | 851 if not exp_dirs: |
| 807 logger.warning(f"No experiment run dirs found in {output_dir}") | 852 logger.warning(f"No experiment run dirs found in {output_dir}") |
| 808 return | 853 return |
| 809 exp_dir = exp_dirs[-1] | 854 exp_dir = exp_dirs[-1] |
| | 855 |
| 810 viz_dir = exp_dir / "visualizations" | 856 viz_dir = exp_dir / "visualizations" |
| 811 viz_dir.mkdir(exist_ok=True) | 857 viz_dir.mkdir(exist_ok=True) |
| 812 train_viz = viz_dir / "train" | 858 train_viz = viz_dir / "train" |
| 813 test_viz = viz_dir / "test" | 859 test_viz = viz_dir / "test" |
| 814 train_viz.mkdir(parents=True, exist_ok=True) | 860 train_viz.mkdir(parents=True, exist_ok=True) |
| 819 | 865 |
| 820 training_stats = _check(exp_dir / "training_statistics.json") | 866 training_stats = _check(exp_dir / "training_statistics.json") |
| 821 test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME) | 867 test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME) |
| 822 probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME) | 868 probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME) |
| 823 gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME) | 869 gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME) |
| | 870 |
| 824 dataset_path = None | 871 dataset_path = None |
| 825 split_file = None | 872 split_file = None |
| 826 desc = exp_dir / DESCRIPTION_FILE_NAME | 873 desc = exp_dir / DESCRIPTION_FILE_NAME |
| 827 if desc.exists(): | 874 if desc.exists(): |
| 828 with open(desc, "r") as f: | 875 with open(desc, "r") as f: |
| 829 cfg = json.load(f) | 876 cfg = json.load(f) |
| 830 dataset_path = _check(Path(cfg.get("dataset", ""))) | 877 dataset_path = _check(Path(cfg.get("dataset", ""))) |
| 831 split_file = _check(Path(get_split_path(cfg.get("dataset", "")))) | 878 split_file = _check(Path(get_split_path(cfg.get("dataset", "")))) |
| | 879 |
| 832 output_feature = "" | 880 output_feature = "" |
| 833 if desc.exists(): | 881 if desc.exists(): |
| 834 try: | 882 try: |
| 835 output_feature = cfg["config"]["output_features"][0]["name"] | 883 output_feature = cfg["config"]["output_features"][0]["name"] |
| 836 except Exception: | 884 except Exception: |
| 837 pass | 885 pass |
| 838 if not output_feature and test_stats: | 886 if not output_feature and test_stats: |
| 839 with open(test_stats, "r") as f: | 887 with open(test_stats, "r") as f: |
| 840 stats = json.load(f) | 888 stats = json.load(f) |
| 841 output_feature = next(iter(stats.keys()), "") | 889 output_feature = next(iter(stats.keys()), "") |
| | 890 |
| 842 viz_registry = get_visualizations_registry() | 891 viz_registry = get_visualizations_registry() |
| 843 for viz_name, viz_func in viz_registry.items(): | 892 for viz_name, viz_func in viz_registry.items(): |
| 844 if viz_name in train_plots: | 893 if viz_name in train_plots: |
| 845 viz_dir_plot = train_viz | 894 viz_dir_plot = train_viz |
| 846 elif viz_name in test_plots: | 895 elif viz_name in test_plots: |
| 847 viz_dir_plot = test_viz | 896 viz_dir_plot = test_viz |
| 848 else: | 897 else: |
| 849 continue | 898 continue |
| | 899 |
| 850 try: | 900 try: |
| 851 viz_func( | 901 viz_func( |
| 852 training_statistics=[training_stats] if training_stats else [], | 902 training_statistics=[training_stats] if training_stats else [], |
| 853 test_statistics=[test_stats] if test_stats else [], | 903 test_statistics=[test_stats] if test_stats else [], |
| 854 probabilities=[probs_path] if probs_path else [], | 904 probabilities=[probs_path] if probs_path else [], |
| 864 file_format="png", | 914 file_format="png", |
| 865 ) | 915 ) |
| 866 logger.info(f"✔ Generated {viz_name}") | 916 logger.info(f"✔ Generated {viz_name}") |
| 867 except Exception as e: | 917 except Exception as e: |
| 868 logger.warning(f"✘ Skipped {viz_name}: {e}") | 918 logger.warning(f"✘ Skipped {viz_name}: {e}") |
| | 919 |
| 869 logger.info(f"All visualizations written to {viz_dir}") | 920 logger.info(f"All visualizations written to {viz_dir}") |
| 870 | 921 |
| 871 def generate_html_report( | 922 def generate_html_report( |
| 872 self, | 923 self, |
| 873 title: str, | 924 title: str, |
| 879 cwd = Path.cwd() | 930 cwd = Path.cwd() |
| 880 report_name = title.lower().replace(" ", "_") + "_report.html" | 931 report_name = title.lower().replace(" ", "_") + "_report.html" |
| 881 report_path = cwd / report_name | 932 report_path = cwd / report_name |
| 882 output_dir = Path(output_dir) | 933 output_dir = Path(output_dir) |
| 883 output_type = None | 934 output_type = None |
| | 935 |
| 884 exp_dirs = sorted( | 936 exp_dirs = sorted( |
| 885 output_dir.glob("experiment_run*"), | 937 output_dir.glob("experiment_run*"), |
| 886 key=lambda p: p.stat().st_mtime, | 938 key=lambda p: p.stat().st_mtime, |
| 887 ) | 939 ) |
| 888 if not exp_dirs: | 940 if not exp_dirs: |
| 889 raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}") | 941 raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}") |
| 890 exp_dir = exp_dirs[-1] | 942 exp_dir = exp_dirs[-1] |
| | 943 |
| 891 base_viz_dir = exp_dir / "visualizations" | 944 base_viz_dir = exp_dir / "visualizations" |
| 892 train_viz_dir = base_viz_dir / "train" | 945 train_viz_dir = base_viz_dir / "train" |
| 893 test_viz_dir = base_viz_dir / "test" | 946 test_viz_dir = base_viz_dir / "test" |
| | 947 |
| 894 html = get_html_template() | 948 html = get_html_template() |
| 895 html += f"<h1>{title}</h1>" | 949 html += f"<h1>{title}</h1>" |
| | 950 |
| 896 metrics_html = "" | 951 metrics_html = "" |
| 897 train_val_metrics_html = "" | 952 train_val_metrics_html = "" |
| 898 test_metrics_html = "" | 953 test_metrics_html = "" |
| 899 try: | 954 try: |
| 900 train_stats_path = exp_dir / "training_statistics.json" | 955 train_stats_path = exp_dir / "training_statistics.json" |
| 916 ) | 971 ) |
| 917 except Exception as e: | 972 except Exception as e: |
| 918 logger.warning( | 973 logger.warning( |
| 919 f"Could not load stats for HTML report: {type(e).__name__}: {e}" | 974 f"Could not load stats for HTML report: {type(e).__name__}: {e}" |
| 920 ) | 975 ) |
| | 976 |
| 921 config_html = "" | 977 config_html = "" |
| 922 training_progress = self.get_training_process(output_dir) | 978 training_progress = self.get_training_process(output_dir) |
| 923 try: | 979 try: |
| 924 config_html = format_config_table_html( | 980 config_html = format_config_table_html( |
| 925 config, split_info, training_progress | 981 config, split_info, training_progress, output_type |
| 926 ) | 982 ) |
| 927 except Exception as e: | 983 except Exception as e: |
| 928 logger.warning(f"Could not load config for HTML report: {e}") | 984 logger.warning(f"Could not load config for HTML report: {e}") |
| 929 | 985 |
| 930 def render_img_section( | 986 def render_img_section( |
| 934 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>" | 990 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>" |
| 935 # collect every PNG | 991 # collect every PNG |
| 936 imgs = list(dir_path.glob("*.png")) | 992 imgs = list(dir_path.glob("*.png")) |
| 937 # --- EXCLUDE Ludwig's base confusion matrix and any top-N confusion_matrix files --- | 993 # --- EXCLUDE Ludwig's base confusion matrix and any top-N confusion_matrix files --- |
| 938 imgs = [ | 994 imgs = [ |
| 939 img for img in imgs | 995 img |
| | 996 for img in imgs |
| 940 if not ( | 997 if not ( |
| 941 img.name == "confusion_matrix.png" | 998 img.name == "confusion_matrix.png" |
| 942 or img.name.startswith("confusion_matrix__label_top") | 999 or img.name.startswith("confusion_matrix__label_top") |
| 943 or img.name == "roc_curves.png" | 1000 or img.name == "roc_curves.png" |
| 944 ) | 1001 ) |
| 970 ] | 1027 ] |
| 971 # filter and order | 1028 # filter and order |
| 972 valid_imgs = [img for img in imgs if img.name not in unwanted] | 1029 valid_imgs = [img for img in imgs if img.name not in unwanted] |
| 973 img_map = {img.name: img for img in valid_imgs} | 1030 img_map = {img.name: img for img in valid_imgs} |
| 974 ordered = [img_map[n] for n in display_order if n in img_map] | 1031 ordered = [img_map[n] for n in display_order if n in img_map] |
| 975 others = sorted(img for img in valid_imgs if img.name not in display_order) | 1032 others = sorted( |
| | 1033 img for img in valid_imgs if img.name not in display_order |
| | 1034 ) |
| 976 imgs = ordered + others | 1035 imgs = ordered + others |
| 977 else: | 1036 else: |
| 978 # regression: just sort whatever's left | 1037 # regression: just sort whatever's left |
| 979 imgs = sorted(imgs) | 1038 imgs = sorted(imgs) |
| 980 # render each remaining PNG | 1039 # render each remaining PNG |
| 1010 if pred_col is None: | 1069 if pred_col is None: |
| 1011 raise ValueError("No prediction column found in Parquet output") | 1070 raise ValueError("No prediction column found in Parquet output") |
| 1012 df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"}) | 1071 df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"}) |
| 1013 # 2) load ground truth for the test split from prepared CSV | 1072 # 2) load ground truth for the test split from prepared CSV |
| 1014 df_all = pd.read_csv(config["label_column_data_path"]) | 1073 df_all = pd.read_csv(config["label_column_data_path"]) |
| 1015 df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][LABEL_COLUMN_NAME].reset_index(drop=True) | 1074 df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][ |
| | 1075 LABEL_COLUMN_NAME |
| | 1076 ].reset_index(drop=True) |
| 1016 # 3) concatenate side-by-side | 1077 # 3) concatenate side-by-side |
| 1017 df_table = pd.concat([df_gt, df_pred], axis=1) | 1078 df_table = pd.concat([df_gt, df_pred], axis=1) |
| 1018 df_table.columns = [LABEL_COLUMN_NAME, "prediction"] | 1079 df_table.columns = [LABEL_COLUMN_NAME, "prediction"] |
| 1019 # 4) render as HTML | 1080 # 4) render as HTML |
| 1020 preds_html = df_table.to_html(index=False, classes="predictions-table") | 1081 preds_html = df_table.to_html(index=False, classes="predictions-table") |
| 1034 str(training_stats_path), | 1095 str(training_stats_path), |
| 1035 ) | 1096 ) |
| 1036 for plot in interactive_plots: | 1097 for plot in interactive_plots: |
| 1037 # 2) inject the static "roc_curves_from_prediction_statistics.png" | 1098 # 2) inject the static "roc_curves_from_prediction_statistics.png" |
| 1038 if plot["title"] == "ROC-AUC": | 1099 if plot["title"] == "ROC-AUC": |
| 1039 static_img = test_viz_dir / "roc_curves_from_prediction_statistics.png" | 1100 static_img = ( |
| | 1101 test_viz_dir / "roc_curves_from_prediction_statistics.png" |
| | 1102 ) |
| 1040 if static_img.exists(): | 1103 if static_img.exists(): |
| 1041 b64 = encode_image_to_base64(str(static_img)) | 1104 b64 = encode_image_to_base64(str(static_img)) |
| 1042 tab3_content += ( | 1105 tab3_content += ( |
| 1043 "<h2 style='text-align: center;'>" | 1106 "<h2 style='text-align: center;'>" |
| 1044 "Roc Curves From Prediction Statistics" | 1107 "Roc Curves From Prediction Statistics" |
| 1052 tab3_content += ( | 1115 tab3_content += ( |
| 1053 f"<h2 style='text-align: center;'>{plot['title']}</h2>" | 1116 f"<h2 style='text-align: center;'>{plot['title']}</h2>" |
| 1054 + plot["html"] | 1117 + plot["html"] |
| 1055 ) | 1118 ) |
| 1056 tab3_content += render_img_section( | 1119 tab3_content += render_img_section( |
| 1057 "Test Visualizations", | 1120 "Test Visualizations", test_viz_dir, output_type |
| 1058 test_viz_dir, | |
| 1059 output_type | |
| 1060 ) | 1121 ) |
| 1061 # assemble the tabs and help modal | 1122 # assemble the tabs and help modal |
| 1062 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) | 1123 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) |
| 1063 modal_html = get_metrics_help_modal() | 1124 modal_html = get_metrics_help_modal() |
| 1064 html += tabbed_html + modal_html + get_html_closing() | 1125 html += tabbed_html + modal_html + get_html_closing() |
| | 1126 |
| 1065 try: | 1127 try: |
| 1066 with open(report_path, "w") as f: | 1128 with open(report_path, "w") as f: |
| 1067 f.write(html) | 1129 f.write(html) |
| 1068 logger.info(f"HTML report generated at: {report_path}") | 1130 logger.info(f"HTML report generated at: {report_path}") |
| 1069 except Exception as e: | 1131 except Exception as e: |
| 1070 logger.error(f"Failed to write HTML report: {e}") | 1132 logger.error(f"Failed to write HTML report: {e}") |
| 1071 raise | 1133 raise |
| | 1134 |
| 1072 return report_path | 1135 return report_path |
| 1073 | 1136 |
| 1074 | 1137 |
| 1075 class WorkflowOrchestrator: | 1138 class WorkflowOrchestrator: |
| 1076 """Manages the image-classification workflow.""" | 1139 """Manages the image-classification workflow.""" |
| | 1140 |
| 1077 def __init__(self, args: argparse.Namespace, backend: Backend): | 1141 def __init__(self, args: argparse.Namespace, backend: Backend): |
| 1078 self.args = args | 1142 self.args = args |
| 1079 self.backend = backend | 1143 self.backend = backend |
| 1080 self.temp_dir: Optional[Path] = None | 1144 self.temp_dir: Optional[Path] = None |
| 1081 self.image_extract_dir: Optional[Path] = None | 1145 self.image_extract_dir: Optional[Path] = None |
| 1111 | 1175 |
| 1112 def _prepare_data(self) -> Tuple[Path, Dict[str, Any], str]: | 1176 def _prepare_data(self) -> Tuple[Path, Dict[str, Any], str]: |
| 1113 """Load CSV, update image paths, handle splits, and write prepared CSV.""" | 1177 """Load CSV, update image paths, handle splits, and write prepared CSV.""" |
| 1114 if not self.temp_dir or not self.image_extract_dir: | 1178 if not self.temp_dir or not self.image_extract_dir: |
| 1115 raise RuntimeError("Temp dirs not initialized before data prep.") | 1179 raise RuntimeError("Temp dirs not initialized before data prep.") |
| | 1180 |
| 1116 try: | 1181 try: |
| 1117 df = pd.read_csv(self.args.csv_file) | 1182 df = pd.read_csv(self.args.csv_file) |
| 1118 logger.info(f"Loaded CSV: {self.args.csv_file}") | 1183 logger.info(f"Loaded CSV: {self.args.csv_file}") |
| 1119 except Exception: | 1184 except Exception: |
| 1120 logger.error("Error loading CSV file", exc_info=True) | 1185 logger.error("Error loading CSV file", exc_info=True) |
| 1121 raise | 1186 raise |
| | 1187 |
| 1122 required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME} | 1188 required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME} |
| 1123 missing = required - set(df.columns) | 1189 missing = required - set(df.columns) |
| 1124 if missing: | 1190 if missing: |
| 1125 raise ValueError(f"Missing CSV columns: {', '.join(missing)}") | 1191 raise ValueError(f"Missing CSV columns: {', '.join(missing)}") |
| | 1192 |
| 1126 try: | 1193 try: |
| 1127 df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply( | 1194 df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply( |
| 1128 lambda p: str((self.image_extract_dir / p).resolve()) | 1195 lambda p: str((self.image_extract_dir / p).resolve()) |
| 1129 ) | 1196 ) |
| 1130 except Exception: | 1197 except Exception: |
| 1148 split_info = ( | 1215 split_info = ( |
| 1149 f"No split column in CSV. Created stratified random split: " | 1216 f"No split column in CSV. Created stratified random split: " |
| 1150 f"{[int(p * 100) for p in self.args.split_probabilities]}% " | 1217 f"{[int(p * 100) for p in self.args.split_probabilities]}% " |
| 1151 f"for train/val/test with balanced label distribution." | 1218 f"for train/val/test with balanced label distribution." |
| 1152 ) | 1219 ) |
| | 1220 |
| 1153 final_csv = self.temp_dir / TEMP_CSV_FILENAME | 1221 final_csv = self.temp_dir / TEMP_CSV_FILENAME |
| | 1222 |
| 1154 try: | 1223 try: |
| 1155 df.to_csv(final_csv, index=False) | 1224 df.to_csv(final_csv, index=False) |
| 1156 logger.info(f"Saved prepared data to {final_csv}") | 1225 logger.info(f"Saved prepared data to {final_csv}") |
| 1157 except Exception: | 1226 except Exception: |
| 1158 logger.error("Error saving prepared CSV", exc_info=True) | 1227 logger.error("Error saving prepared CSV", exc_info=True) |
| 1159 raise | 1228 raise |
| | 1229 |
| 1160 return final_csv, split_config, split_info | 1230 return final_csv, split_config, split_info |
| 1161 | 1231 |
| 1162 def _process_fixed_split( | 1232 def _process_fixed_split( |
| 1163 self, df: pd.DataFrame | 1233 self, df: pd.DataFrame |
| 1164 ) -> Tuple[pd.DataFrame, Dict[str, Any], str]: | 1234 ) -> Tuple[pd.DataFrame, Dict[str, Any], str]: |
| 1169 df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype( | 1239 df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype( |
| 1170 pd.Int64Dtype() | 1240 pd.Int64Dtype() |
| 1171 ) | 1241 ) |
| 1172 if df[SPLIT_COLUMN_NAME].isna().any(): | 1242 if df[SPLIT_COLUMN_NAME].isna().any(): |
| 1173 logger.warning("Split column contains non-numeric/missing values.") | 1243 logger.warning("Split column contains non-numeric/missing values.") |
| | 1244 |
| 1174 unique = set(df[SPLIT_COLUMN_NAME].dropna().unique()) | 1245 unique = set(df[SPLIT_COLUMN_NAME].dropna().unique()) |
| 1175 logger.info(f"Unique split values: {unique}") | 1246 logger.info(f"Unique split values: {unique}") |
| 1176 if unique == {0, 2}: | 1247 if unique == {0, 2}: |
| 1177 df = split_data_0_2( | 1248 df = split_data_0_2( |
| 1178 df, | 1249 df, |
| 1191 elif unique.issubset({0, 1, 2}): | 1262 elif unique.issubset({0, 1, 2}): |
| 1192 split_info = "Used user-defined split column from CSV." | 1263 split_info = "Used user-defined split column from CSV." |
| 1193 logger.info("Using fixed split as-is.") | 1264 logger.info("Using fixed split as-is.") |
| 1194 else: | 1265 else: |
| 1195 raise ValueError(f"Unexpected split values: {unique}") | 1266 raise ValueError(f"Unexpected split values: {unique}") |
| | 1267 |
| 1196 return df, {"type": "fixed", "column": SPLIT_COLUMN_NAME}, split_info | 1268 return df, {"type": "fixed", "column": SPLIT_COLUMN_NAME}, split_info |
| | 1269 |
| 1197 except Exception: | 1270 except Exception: |
| 1198 logger.error("Error processing fixed split", exc_info=True) | 1271 logger.error("Error processing fixed split", exc_info=True) |
| 1199 raise | 1272 raise |
| 1200 | 1273 |
| 1201 def _cleanup_temp_dirs(self) -> None: | 1274 def _cleanup_temp_dirs(self) -> None: |
| 1207 | 1280 |
| 1208 def run(self) -> None: | 1281 def run(self) -> None: |
| 1209 """Execute the full workflow end-to-end.""" | 1282 """Execute the full workflow end-to-end.""" |
| 1210 logger.info("Starting workflow...") | 1283 logger.info("Starting workflow...") |
| 1211 self.args.output_dir.mkdir(parents=True, exist_ok=True) | 1284 self.args.output_dir.mkdir(parents=True, exist_ok=True) |
| | 1285 |
| 1212 try: | 1286 try: |
| 1213 self._create_temp_dirs() | 1287 self._create_temp_dirs() |
| 1214 self._extract_images() | 1288 self._extract_images() |
| 1215 csv_path, split_cfg, split_info = self._prepare_data() | 1289 csv_path, split_cfg, split_info = self._prepare_data() |
| | 1290 |
| 1216 use_pretrained = self.args.use_pretrained or self.args.fine_tune | 1291 use_pretrained = self.args.use_pretrained or self.args.fine_tune |
| | 1292 |
| 1217 backend_args = { | 1293 backend_args = { |
| 1218 "model_name": self.args.model_name, | 1294 "model_name": self.args.model_name, |
| 1219 "fine_tune": self.args.fine_tune, | 1295 "fine_tune": self.args.fine_tune, |
| 1220 "use_pretrained": use_pretrained, | 1296 "use_pretrained": use_pretrained, |
| 1221 "epochs": self.args.epochs, | 1297 "epochs": self.args.epochs, |
| 1228 "label_column_data_path": csv_path, | 1304 "label_column_data_path": csv_path, |
| 1229 "augmentation": self.args.augmentation, | 1305 "augmentation": self.args.augmentation, |
| 1230 "threshold": self.args.threshold, | 1306 "threshold": self.args.threshold, |
| 1231 } | 1307 } |
| 1232 yaml_str = self.backend.prepare_config(backend_args, split_cfg) | 1308 yaml_str = self.backend.prepare_config(backend_args, split_cfg) |
| | 1309 |
| 1233 config_file = self.temp_dir / TEMP_CONFIG_FILENAME | 1310 config_file = self.temp_dir / TEMP_CONFIG_FILENAME |
| 1234 config_file.write_text(yaml_str) | 1311 config_file.write_text(yaml_str) |
| 1235 logger.info(f"Wrote backend config: {config_file}") | 1312 logger.info(f"Wrote backend config: {config_file}") |
| | 1313 |
| 1236 self.backend.run_experiment( | 1314 self.backend.run_experiment( |
| 1237 csv_path, | 1315 csv_path, |
| 1238 config_file, | 1316 config_file, |
| 1239 self.args.output_dir, | 1317 self.args.output_dir, |
| 1240 self.args.random_seed, | 1318 self.args.random_seed, |
| 1372 nargs=3, | 1450 nargs=3, |
| 1373 metavar=("train", "val", "test"), | 1451 metavar=("train", "val", "test"), |
| 1374 action=SplitProbAction, | 1452 action=SplitProbAction, |
| 1375 default=[0.7, 0.1, 0.2], | 1453 default=[0.7, 0.1, 0.2], |
| 1376 help=( | 1454 help=( |
| 1377 "Random split proportions (e.g., 0.7 0.1 0.2)." | 1455 "Random split proportions (e.g., 0.7 0.1 0.2).Only used if no split column." |
| 1378 "Only used if no split column." | |
| 1379 ), | 1456 ), |
| 1380 ) | 1457 ) |
| 1381 parser.add_argument( | 1458 parser.add_argument( |
| 1382 "--random-seed", | 1459 "--random-seed", |
| 1383 type=int, | 1460 type=int, |
| 1406 type=float, | 1483 type=float, |
| 1407 default=None, | 1484 default=None, |
| 1408 help=( | 1485 help=( |
| 1409 "Decision threshold for binary classification (0.0–1.0)." | 1486 "Decision threshold for binary classification (0.0–1.0)." |
| 1410 "Overrides default 0.5." | 1487 "Overrides default 0.5." |
| 1411 ) | 1488 ), |
| 1412 ) | 1489 ) |
| 1413 args = parser.parse_args() | 1490 args = parser.parse_args() |
| | 1491 |
| 1414 if not 0.0 <= args.validation_size <= 1.0: | 1492 if not 0.0 <= args.validation_size <= 1.0: |
| 1415 parser.error("validation-size must be between 0.0 and 1.0") | 1493 parser.error("validation-size must be between 0.0 and 1.0") |
| 1416 if not args.csv_file.is_file(): | 1494 if not args.csv_file.is_file(): |
| 1417 parser.error(f"CSV not found: {args.csv_file}") | 1495 parser.error(f"CSV not found: {args.csv_file}") |
| 1418 if not args.image_zip.is_file(): | 1496 if not args.image_zip.is_file(): |
| 1421 try: | 1499 try: |
| 1422 augmentation_setup = aug_parse(args.augmentation) | 1500 augmentation_setup = aug_parse(args.augmentation) |
| 1423 setattr(args, "augmentation", augmentation_setup) | 1501 setattr(args, "augmentation", augmentation_setup) |
| 1424 except ValueError as e: | 1502 except ValueError as e: |
| 1425 parser.error(str(e)) | 1503 parser.error(str(e)) |
| | 1504 |
| 1426 backend_instance = LudwigDirectBackend() | 1505 backend_instance = LudwigDirectBackend() |
| 1427 orchestrator = WorkflowOrchestrator(args, backend_instance) | 1506 orchestrator = WorkflowOrchestrator(args, backend_instance) |
| | 1507 |
| 1428 exit_code = 0 | 1508 exit_code = 0 |
| 1429 try: | 1509 try: |
| 1430 orchestrator.run() | 1510 orchestrator.run() |
| 1431 logger.info("Main script finished successfully.") | 1511 logger.info("Main script finished successfully.") |
| 1432 except Exception as e: | 1512 except Exception as e: |
| 1437 | 1517 |
| 1438 | 1518 |
| 1439 if __name__ == "__main__": | 1519 if __name__ == "__main__": |
| 1440 try: | 1520 try: |
| 1441 import ludwig | 1521 import ludwig |
| | 1522 |
| 1442 logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}") | 1523 logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}") |
| 1443 except ImportError: | 1524 except ImportError: |
| 1444 logger.error( | 1525 logger.error( |
| 1445 "Ludwig library not found. Please ensure Ludwig is installed " | 1526 "Ludwig library not found. Please ensure Ludwig is installed " |
| 1446 "('pip install ludwig[image]')" | 1527 "('pip install ludwig[image]')" |
| 1447 ) | 1528 ) |
| 1448 sys.exit(1) | 1529 sys.exit(1) |
| | 1530 |
| 1449 main() | 1531 main() |
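
For readers tracing the split-handling hunks above (the chained `train_test_split` calls and the `val_size_adjusted` arithmetic around new lines 548–570), here is a minimal standalone sketch of the same three-way stratified split technique. It assumes pandas and scikit-learn, which the file itself imports; the function name, the `"split"` column name, and the probabilities are illustrative stand-ins, not the tool's own constants.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

def stratified_three_way_split(df: pd.DataFrame, label_col: str,
                               probs=(0.7, 0.1, 0.2), seed=42) -> pd.DataFrame:
    out = df.copy()
    out["split"] = 0  # 0 = train, 1 = validation, 2 = test
    # First split: carve the test rows off the whole dataset, stratified on the label.
    train_val_idx, test_idx = train_test_split(
        out.index,
        test_size=probs[2],
        random_state=seed,
        stratify=out[label_col],
    )
    # Second split: the validation share must be rescaled, because it now
    # applies to the remaining train+val pool rather than the full dataset.
    val_size_adjusted = probs[1] / (probs[0] + probs[1])
    train_idx, val_idx = train_test_split(
        train_val_idx,
        test_size=val_size_adjusted,
        random_state=seed,
        stratify=out.loc[train_val_idx, label_col],
    )
    out.loc[val_idx, "split"] = 1
    out.loc[test_idx, "split"] = 2
    return out.astype({"split": int})
```

The rescaling step is the detail worth noting: with probabilities (0.7, 0.1, 0.2), the second call must request 0.1 / (0.7 + 0.1) = 0.125 of the remaining pool so that the final proportions come out as 70/10/20.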

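The `prepare_config` hunks also encode the task-detection heuristic: numeric labels with more than 10 distinct values switch the config to regression, anything else is treated as classification. Below is a compact sketch of that rule, assuming the same `pandas.api.types` alias (`ptypes`) the file imports; wrapping it as a helper function is my own framing, not code from the revision.

```python
from typing import Optional

import pandas as pd
import pandas.api.types as ptypes

def detect_task_type(label_series: Optional[pd.Series]) -> str:
    # Numeric dtype with more than 10 unique values -> regression target;
    # everything else (few levels, strings, missing series) -> classification.
    if (
        label_series is not None
        and ptypes.is_numeric_dtype(label_series.dtype)
        and label_series.nunique() > 10
    ):
        return "regression"
    return "classification"
```

Under this heuristic a column of float house prices becomes a regression target, while an integer column holding only the labels 0–9 still counts as classification.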