diff image_learner_cli.py @ 2:186424a7eca7 draft
planemo upload for repository https://github.com/goeckslab/gleam.git commit 91fa4aba245520fc0680088a07cead66bcfd4ed2
author    goeckslab
date      Thu, 03 Jul 2025 20:43:24 +0000
parents   39202fe5cf97
children  09904b1f61f5
--- a/image_learner_cli.py	Wed Jul 02 18:59:10 2025 +0000
+++ b/image_learner_cli.py	Thu Jul 03 20:43:24 2025 +0000
@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 import argparse
 import json
 import logging
@@ -11,7 +10,18 @@
 from typing import Any, Dict, Optional, Protocol, Tuple

 import pandas as pd
+import pandas.api.types as ptypes
 import yaml
+from constants import (
+    IMAGE_PATH_COLUMN_NAME,
+    LABEL_COLUMN_NAME,
+    METRIC_DISPLAY_NAMES,
+    MODEL_ENCODER_TEMPLATES,
+    SPLIT_COLUMN_NAME,
+    TEMP_CONFIG_FILENAME,
+    TEMP_CSV_FILENAME,
+    TEMP_DIR_PREFIX
+)
 from ludwig.globals import (
     DESCRIPTION_FILE_NAME,
     PREDICTIONS_PARQUET_FILE_NAME,
@@ -21,258 +31,29 @@
 from ludwig.utils.data_utils import get_split_path
 from ludwig.visualize import get_visualizations_registry
 from sklearn.model_selection import train_test_split
-from utils import encode_image_to_base64, get_html_closing, get_html_template
-
-# --- Constants ---
-SPLIT_COLUMN_NAME = "split"
-LABEL_COLUMN_NAME = "label"
-IMAGE_PATH_COLUMN_NAME = "image_path"
-DEFAULT_SPLIT_PROBABILITIES = [0.7, 0.1, 0.2]
-TEMP_CSV_FILENAME = "processed_data_for_ludwig.csv"
-TEMP_CONFIG_FILENAME = "ludwig_config.yaml"
-TEMP_DIR_PREFIX = "ludwig_api_work_"
-MODEL_ENCODER_TEMPLATES: Dict[str, Any] = {
-    "stacked_cnn": "stacked_cnn",
-    "resnet18": {"type": "resnet", "model_variant": 18},
-    "resnet34": {"type": "resnet", "model_variant": 34},
-    "resnet50": {"type": "resnet", "model_variant": 50},
-    "resnet101": {"type": "resnet", "model_variant": 101},
-    "resnet152": {"type": "resnet", "model_variant": 152},
-    "resnext50_32x4d": {"type": "resnext", "model_variant": "50_32x4d"},
-    "resnext101_32x8d": {"type": "resnext", "model_variant": "101_32x8d"},
-    "resnext101_64x4d": {"type": "resnext", "model_variant": "101_64x4d"},
-    "resnext152_32x8d": {"type": "resnext", "model_variant": "152_32x8d"},
-    "wide_resnet50_2": {"type": "wide_resnet", "model_variant": "50_2"},
-    "wide_resnet101_2": {"type": "wide_resnet", "model_variant": "101_2"},
-    "wide_resnet103_2": {"type": "wide_resnet", "model_variant": "103_2"},
-    "efficientnet_b0": {"type": "efficientnet", "model_variant": "b0"},
-    "efficientnet_b1": {"type": "efficientnet", "model_variant": "b1"},
-    "efficientnet_b2": {"type": "efficientnet", "model_variant": "b2"},
-    "efficientnet_b3": {"type": "efficientnet", "model_variant": "b3"},
-    "efficientnet_b4": {"type": "efficientnet", "model_variant": "b4"},
-    "efficientnet_b5": {"type": "efficientnet", "model_variant": "b5"},
-    "efficientnet_b6": {"type": "efficientnet", "model_variant": "b6"},
-    "efficientnet_b7": {"type": "efficientnet", "model_variant": "b7"},
-    "efficientnet_v2_s": {"type": "efficientnet", "model_variant": "v2_s"},
-    "efficientnet_v2_m": {"type": "efficientnet", "model_variant": "v2_m"},
-    "efficientnet_v2_l": {"type": "efficientnet", "model_variant": "v2_l"},
-    "regnet_y_400mf": {"type": "regnet", "model_variant": "y_400mf"},
-    "regnet_y_800mf": {"type": "regnet", "model_variant": "y_800mf"},
-    "regnet_y_1_6gf": {"type": "regnet", "model_variant": "y_1_6gf"},
-    "regnet_y_3_2gf": {"type": "regnet", "model_variant": "y_3_2gf"},
-    "regnet_y_8gf": {"type": "regnet", "model_variant": "y_8gf"},
-    "regnet_y_16gf": {"type": "regnet", "model_variant": "y_16gf"},
-    "regnet_y_32gf": {"type": "regnet", "model_variant": "y_32gf"},
-    "regnet_y_128gf": {"type": "regnet", "model_variant": "y_128gf"},
-    "regnet_x_400mf": {"type": "regnet", "model_variant": "x_400mf"},
-    "regnet_x_800mf": {"type": "regnet", "model_variant": "x_800mf"},
-    "regnet_x_1_6gf": {"type": "regnet", "model_variant": "x_1_6gf"},
-    "regnet_x_3_2gf": {"type": "regnet", "model_variant": "x_3_2gf"},
-    "regnet_x_8gf": {"type": "regnet", "model_variant": "x_8gf"},
-    "regnet_x_16gf": {"type": "regnet", "model_variant": "x_16gf"},
-    "regnet_x_32gf": {"type": "regnet", "model_variant": "x_32gf"},
-    "vgg11": {"type": "vgg", "model_variant": 11},
-    "vgg11_bn": {"type": "vgg", "model_variant": "11_bn"},
-    "vgg13": {"type": "vgg", "model_variant": 13},
-    "vgg13_bn": {"type": "vgg", "model_variant": "13_bn"},
-    "vgg16": {"type": "vgg", "model_variant": 16},
-    "vgg16_bn": {"type": "vgg", "model_variant": "16_bn"},
-    "vgg19": {"type": "vgg", "model_variant": 19},
-    "vgg19_bn": {"type": "vgg", "model_variant": "19_bn"},
-    "shufflenet_v2_x0_5": {"type": "shufflenet_v2", "model_variant": "x0_5"},
-    "shufflenet_v2_x1_0": {"type": "shufflenet_v2", "model_variant": "x1_0"},
-    "shufflenet_v2_x1_5": {"type": "shufflenet_v2", "model_variant": "x1_5"},
-    "shufflenet_v2_x2_0": {"type": "shufflenet_v2", "model_variant": "x2_0"},
-    "squeezenet1_0": {"type": "squeezenet", "model_variant": "1_0"},
-    "squeezenet1_1": {"type": "squeezenet", "model_variant": "1_1"},
-    "swin_t": {"type": "swin_transformer", "model_variant": "t"},
-    "swin_s": {"type": "swin_transformer", "model_variant": "s"},
-    "swin_b": {"type": "swin_transformer", "model_variant": "b"},
-    "swin_v2_t": {"type": "swin_transformer", "model_variant": "v2_t"},
-    "swin_v2_s": {"type": "swin_transformer", "model_variant": "v2_s"},
-    "swin_v2_b": {"type": "swin_transformer", "model_variant": "v2_b"},
-    "vit_b_16": {"type": "vision_transformer", "model_variant": "b_16"},
-    "vit_b_32": {"type": "vision_transformer", "model_variant": "b_32"},
-    "vit_l_16": {"type": "vision_transformer", "model_variant": "l_16"},
-    "vit_l_32": {"type": "vision_transformer", "model_variant": "l_32"},
-    "vit_h_14": {"type": "vision_transformer", "model_variant": "h_14"},
-    "convnext_tiny": {"type": "convnext", "model_variant": "tiny"},
-    "convnext_small": {"type": "convnext", "model_variant": "small"},
-    "convnext_base": {"type": "convnext", "model_variant": "base"},
-    "convnext_large": {"type": "convnext", "model_variant": "large"},
-    "maxvit_t": {"type": "maxvit", "model_variant": "t"},
-    "alexnet": {"type": "alexnet"},
-    "googlenet": {"type": "googlenet"},
-    "inception_v3": {"type": "inception_v3"},
-    "mobilenet_v2": {"type": "mobilenet_v2"},
-    "mobilenet_v3_large": {"type": "mobilenet_v3_large"},
-    "mobilenet_v3_small": {"type": "mobilenet_v3_small"},
-}
-METRIC_DISPLAY_NAMES = {
-    "accuracy": "Accuracy",
-    "accuracy_micro": "Accuracy-Micro",
-    "loss": "Loss",
-    "roc_auc": "ROC-AUC",
-    "roc_auc_macro": "ROC-AUC-Macro",
-    "roc_auc_micro": "ROC-AUC-Micro",
-    "hits_at_k": "Hits at K",
-    "precision": "Precision",
-    "recall": "Recall",
-    "specificity": "Specificity",
-    "kappa_score": "Cohen's Kappa",
-    "token_accuracy": "Token Accuracy",
-    "avg_precision_macro": "Precision-Macro",
-    "avg_recall_macro": "Recall-Macro",
-    "avg_f1_score_macro": "F1-score-Macro",
-    "avg_precision_micro": "Precision-Micro",
-    "avg_recall_micro": "Recall-Micro",
-    "avg_f1_score_micro": "F1-score-Micro",
-    "avg_precision_weighted": "Precision-Weighted",
-    "avg_recall_weighted": "Recall-Weighted",
-    "avg_f1_score_weighted": "F1-score-Weighted",
-    "average_precision_macro": " Precision-Average-Macro",
-    "average_precision_micro": "Precision-Average-Micro",
-    "average_precision_samples": "Precision-Average-Samples",
-}
+from utils import (
+    build_tabbed_html,
+    encode_image_to_base64,
+    get_html_closing,
+    get_html_template,
+    get_metrics_help_modal
+)

 # --- Logging Setup ---
 logging.basicConfig(
     level=logging.INFO,
-    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
+    format='%(asctime)s %(levelname)s %(name)s: %(message)s',
 )
 logger = logging.getLogger("ImageLearner")


-def get_metrics_help_modal() -> str:
-    modal_html = """
-<div id="metricsHelpModal" class="modal">
-  <div class="modal-content">
-    <span class="close">×</span>
-    <h2>Model Evaluation Metrics — Help Guide</h2>
-    <div class="metrics-guide">
-      <h3>1) General Metrics</h3>
-      <p><strong>Loss:</strong> Measures the difference between predicted and actual values. Lower is better. Often used for optimization during training.</p>
-      <p><strong>Accuracy:</strong> Proportion of correct predictions among all predictions. Simple but can be misleading for imbalanced datasets.</p>
-      <p><strong>Micro Accuracy:</strong> Calculates accuracy by summing up all individual true positives and true negatives across all classes, making it suitable for multiclass or multilabel problems.</p>
-      <p><strong>Token Accuracy:</strong> Measures how often the predicted tokens (e.g., in sequences) match the true tokens. Useful in sequence prediction tasks like NLP.</p>
-      <h3>2) Precision, Recall & Specificity</h3>
-      <p><strong>Precision:</strong> Out of all positive predictions, how many were correct. Precision = TP / (TP + FP). Helps when false positives are costly.</p>
-      <p><strong>Recall (Sensitivity):</strong> Out of all actual positives, how many were predicted correctly. Recall = TP / (TP + FN). Important when missing positives is risky.</p>
-      <p><strong>Specificity:</strong> True negative rate. Measures how well the model identifies negatives. Specificity = TN / (TN + FP). Useful in medical testing to avoid false alarms.</p>
-      <h3>3) Macro, Micro, and Weighted Averages</h3>
-      <p><strong>Macro Precision / Recall / F1:</strong> Averages the metric across all classes, treating each class equally, regardless of class frequency. Best when class sizes are balanced.</p>
-      <p><strong>Micro Precision / Recall / F1:</strong> Aggregates TP, FP, FN across all classes before computing the metric. Gives a global view and is ideal for class-imbalanced problems.</p>
-      <p><strong>Weighted Precision / Recall / F1:</strong> Averages each metric across classes, weighted by the number of true instances per class. Balances importance of classes based on frequency.</p>
-      <h3>4) Average Precision (PR-AUC Variants)</h3>
-      <p><strong>Average Precision Macro:</strong> Precision-Recall AUC averaged across all classes equally. Useful for balanced multi-class problems.</p>
-      <p><strong>Average Precision Micro:</strong> Global Precision-Recall AUC using all instances. Best for imbalanced data or multi-label classification.</p>
-      <p><strong>Average Precision Samples:</strong> Precision-Recall AUC averaged across individual samples (not classes). Ideal for multi-label problems where each sample can belong to multiple classes.</p>
-      <h3>5) ROC-AUC Variants</h3>
-      <p><strong>ROC-AUC:</strong> Measures model's ability to distinguish between classes. AUC = 1 is perfect; 0.5 is random guessing. Use for binary classification.</p>
-      <p><strong>Macro ROC-AUC:</strong> Averages the AUC across all classes equally. Suitable when classes are balanced and of equal importance.</p>
-      <p><strong>Micro ROC-AUC:</strong> Computes AUC from aggregated predictions across all classes. Useful in multiclass or multilabel settings with imbalance.</p>
-      <h3>6) Ranking Metrics</h3>
-      <p><strong>Hits at K:</strong> Measures whether the true label is among the top-K predictions. Common in recommendation systems and retrieval tasks.</p>
-      <h3>7) Confusion Matrix Stats (Per Class)</h3>
-      <p><strong>True Positives / Negatives (TP / TN):</strong> Correct predictions for positives and negatives respectively.</p>
-      <p><strong>False Positives / Negatives (FP / FN):</strong> Incorrect predictions — false alarms and missed detections.</p>
-      <h3>8) Other Useful Metrics</h3>
-      <p><strong>Cohen's Kappa:</strong> Measures agreement between predicted and actual values adjusted for chance. Useful for multiclass classification with imbalanced labels.</p>
-      <p><strong>Matthews Correlation Coefficient (MCC):</strong> Balanced measure of prediction quality that takes into account TP, TN, FP, and FN. Particularly effective for imbalanced datasets.</p>
-      <h3>9) Metric Recommendations</h3>
-      <ul>
-        <li>Use <strong>Accuracy + F1</strong> for balanced data.</li>
-        <li>Use <strong>Precision, Recall, ROC-AUC</strong> for imbalanced datasets.</li>
-        <li>Use <strong>Average Precision Micro</strong> for multilabel or class-imbalanced problems.</li>
-        <li>Use <strong>Macro scores</strong> when all classes should be treated equally.</li>
-        <li>Use <strong>Weighted scores</strong> when class imbalance should be accounted for without ignoring small classes.</li>
-        <li>Use <strong>Confusion Matrix stats</strong> to analyze class-wise performance.</li>
-        <li>Use <strong>Hits at K</strong> for recommendation or ranking-based tasks.</li>
-      </ul>
-    </div>
-  </div>
-</div>
-"""
-    modal_css = """
-<style>
-.modal {
-    display: none;
-    position: fixed;
-    z-index: 1;
-    left: 0;
-    top: 0;
-    width: 100%;
-    height: 100%;
-    overflow: auto;
-    background-color: rgba(0,0,0,0.4);
-}
-.modal-content {
-    background-color: #fefefe;
-    margin: 15% auto;
-    padding: 20px;
-    border: 1px solid #888;
-    width: 80%;
-    max-width: 800px;
-}
-.close {
-    color: #aaa;
-    float: right;
-    font-size: 28px;
-    font-weight: bold;
-}
-.close:hover,
-.close:focus {
-    color: black;
-    text-decoration: none;
-    cursor: pointer;
-}
-.metrics-guide h3 {
-    margin-top: 20px;
-}
-.metrics-guide p {
-    margin: 5px 0;
-}
-.metrics-guide ul {
-    margin: 10px 0;
-    padding-left: 20px;
-}
-</style>
-"""
-    modal_js = """
-<script>
-document.addEventListener("DOMContentLoaded", function() {
-    var modal = document.getElementById("metricsHelpModal");
-    var closeBtn = document.getElementsByClassName("close")[0];
-
-    document.querySelectorAll(".openMetricsHelp").forEach(btn => {
-        btn.onclick = function() {
-            modal.style.display = "block";
-        };
-    });
-
-    if (closeBtn) {
-        closeBtn.onclick = function() {
-            modal.style.display = "none";
-        };
-    }
-
-    window.onclick = function(event) {
-        if (event.target == modal) {
-            modal.style.display = "none";
-        }
-    }
-});
-</script>
-"""
-    return modal_css + modal_html + modal_js
-
-
 def format_config_table_html(
     config: dict,
     split_info: Optional[str] = None,
     training_progress: dict = None,
 ) -> str:
     display_keys = [
+        "task_type",
         "model_name",
         "epochs",
         "batch_size",
@@ -287,6 +68,8 @@

     for key in display_keys:
         val = config.get(key, "N/A")
+        if key == "task_type":
+            val = val.title() if isinstance(val, str) else val
         if key == "batch_size":
             if val is not None:
                 val = int(val)
@@ -348,6 +131,18 @@
             f"</tr>"
         )

+    aug_cfg = config.get("augmentation")
+    if aug_cfg:
+        types = [str(a.get("type", "")) for a in aug_cfg]
+        aug_val = ", ".join(types)
+        rows.append(
+            "<tr>"
+            "<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Augmentation</td>"
+            "<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>"
+            f"{aug_val}</td>"
+            "</tr>"
+        )
+
     if split_info:
         rows.append(
             f"<tr>"
@@ -371,7 +166,7 @@
         "<p style='text-align: center; font-size: 0.9em;'>"
         "Model trained using Ludwig.<br>"
         "If want to learn more about Ludwig default settings,"
-        "please check the their <a href='https://ludwig.ai' target='_blank'>"
+        "please check their <a href='https://ludwig.ai' target='_blank'>"
        "website(ludwig.ai)</a>."
         "</p><hr>"
     )
@@ -380,6 +175,8 @@
 def detect_output_type(test_stats):
     """Detects if the output type is 'binary' or 'category' based on test statistics."""
     label_stats = test_stats.get("label", {})
+    if "mean_squared_error" in label_stats:
+        return "regression"
     per_class = label_stats.get("per_class_stats", {})
     if len(per_class) == 2:
         return "binary"
@@ -420,6 +217,24 @@
             "specificity": get_last_value(label_stats, "specificity"),
             "roc_auc": get_last_value(label_stats, "roc_auc"),
         }
+    elif output_type == "regression":
+        metrics[split] = {
+            "loss": get_last_value(label_stats, "loss"),
+            "mean_absolute_error": get_last_value(
+                label_stats, "mean_absolute_error"
+            ),
+            "mean_absolute_percentage_error": get_last_value(
+                label_stats, "mean_absolute_percentage_error"
+            ),
+            "mean_squared_error": get_last_value(label_stats, "mean_squared_error"),
+            "root_mean_squared_error": get_last_value(
+                label_stats, "root_mean_squared_error"
+            ),
+            "root_mean_squared_percentage_error": get_last_value(
+                label_stats, "root_mean_squared_percentage_error"
+            ),
+            "r2": get_last_value(label_stats, "r2"),
+        }
     else:
         metrics[split] = {
             "accuracy": get_last_value(label_stats, "accuracy"),
@@ -565,7 +380,9 @@
     return html


-def format_test_merged_stats_table_html(test_metrics: Dict[str, Optional[float]]) -> str:
+def format_test_merged_stats_table_html(
+    test_metrics: Dict[str, Optional[float]],
+) -> str:
     """Formats an HTML table for test metrics."""
     rows = []
     for key in sorted(test_metrics.keys()):
@@ -598,63 +415,6 @@
     return html


-def build_tabbed_html(metrics_html: str, train_val_html: str, test_html: str) -> str:
-    return f"""
-<style>
-.tabs {{
-    display: flex;
-    border-bottom: 2px solid #ccc;
-    margin-bottom: 1rem;
-}}
-.tab {{
-    padding: 10px 20px;
-    cursor: pointer;
-    border: 1px solid #ccc;
-    border-bottom: none;
-    background: #f9f9f9;
-    margin-right: 5px;
-    border-top-left-radius: 8px;
-    border-top-right-radius: 8px;
-}}
-.tab.active {{
-    background: white;
-    font-weight: bold;
-}}
-.tab-content {{
-    display: none;
-    padding: 20px;
-    border: 1px solid #ccc;
-    border-top: none;
-}}
-.tab-content.active {{
-    display: block;
-}}
-</style>
-<div class="tabs">
-    <div class="tab active" onclick="showTab('metrics')"> Config & Results Summary</div>
-    <div class="tab" onclick="showTab('trainval')"> Train/Validation Results</div>
-    <div class="tab" onclick="showTab('test')"> Test Results</div>
-</div>
-<div id="metrics" class="tab-content active">
-    {metrics_html}
-</div>
-<div id="trainval" class="tab-content">
-    {train_val_html}
-</div>
-<div id="test" class="tab-content">
-    {test_html}
-</div>
-<script>
-function showTab(id) {{
-    document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active'));
-    document.querySelectorAll('.tab').forEach(el => el.classList.remove('active'));
-    document.getElementById(id).classList.add('active');
-    document.querySelector(`.tab[onclick*="${{id}}"]`).classList.add('active');
-}}
-</script>
-"""
-
-
 def split_data_0_2(
     df: pd.DataFrame,
     split_column: str,
@@ -727,16 +487,15 @@
     ) -> None:
         ...

-    def generate_plots(
-        self,
-        output_dir: Path
-    ) -> None:
+    def generate_plots(self, output_dir: Path) -> None:
         ...

     def generate_html_report(
         self,
         title: str,
-        output_dir: str
+        output_dir: str,
+        config: Dict[str, Any],
+        split_info: str,
     ) -> Path:
         ...

@@ -749,23 +508,21 @@
         config_params: Dict[str, Any],
         split_config: Dict[str, Any],
     ) -> str:
-        """Build and serialize the Ludwig YAML configuration."""
         logger.info("LudwigDirectBackend: Preparing YAML configuration.")

         model_name = config_params.get("model_name", "resnet18")
         use_pretrained = config_params.get("use_pretrained", False)
         fine_tune = config_params.get("fine_tune", False)
+        if use_pretrained:
+            trainable = bool(fine_tune)
+        else:
+            trainable = True
         epochs = config_params.get("epochs", 10)
         batch_size = config_params.get("batch_size")
         num_processes = config_params.get("preprocessing_num_processes", 1)
         early_stop = config_params.get("early_stop", None)
         learning_rate = config_params.get("learning_rate")
         learning_rate = "auto" if learning_rate is None else float(learning_rate)
-        trainable = fine_tune or (not use_pretrained)
-        if not use_pretrained and not trainable:
-            logger.warning("trainable=False; use_pretrained=False is ignored.")
-            logger.warning("Setting trainable=True to train the model from scratch.")
-            trainable = True
         raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name)
         if isinstance(raw_encoder, dict):
             encoder_config = {
@@ -779,39 +536,68 @@
         batch_size_cfg = batch_size or "auto"

         label_column_path = config_params.get("label_column_data_path")
+        label_series = None
         if label_column_path is not None and Path(label_column_path).exists():
             try:
                 label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME]
-                num_unique_labels = label_series.nunique()
             except Exception as e:
-                logger.warning(
-                    f"Could not determine label cardinality, defaulting to 'binary': {e}"
-                )
-                num_unique_labels = 2
+                logger.warning(f"Could not read label column for task detection: {e}")
+
+        if (
+            label_series is not None
+            and ptypes.is_numeric_dtype(label_series.dtype)
+            and label_series.nunique() > 10
+        ):
+            task_type = "regression"
         else:
-            logger.warning(
-                "label_column_data_path not provided, defaulting to 'binary'"
+            task_type = "classification"
+
+        config_params["task_type"] = task_type
+
+        image_feat: Dict[str, Any] = {
+            "name": IMAGE_PATH_COLUMN_NAME,
+            "type": "image",
+            "encoder": encoder_config,
+        }
+        if config_params.get("augmentation") is not None:
+            image_feat["augmentation"] = config_params["augmentation"]
+
+        if task_type == "regression":
+            output_feat = {
+                "name": LABEL_COLUMN_NAME,
+                "type": "number",
+                "decoder": {"type": "regressor"},
+                "loss": {"type": "mean_squared_error"},
+                "evaluation": {
+                    "metrics": [
+                        "mean_squared_error",
+                        "mean_absolute_error",
+                        "r2",
+                    ]
+                },
+            }
+            val_metric = config_params.get("validation_metric", "mean_squared_error")
+
+        else:
+            num_unique_labels = (
+                label_series.nunique() if label_series is not None else 2
             )
-            num_unique_labels = 2
-
-        output_type = "binary" if num_unique_labels == 2 else "category"
+            output_type = "binary" if num_unique_labels == 2 else "category"
+            output_feat = {"name": LABEL_COLUMN_NAME, "type": output_type}
+            val_metric = None

         conf: Dict[str, Any] = {
             "model_type": "ecd",
-            "input_features": [
-                {
-                    "name": IMAGE_PATH_COLUMN_NAME,
-                    "type": "image",
-                    "encoder": encoder_config,
-                }
-            ],
-            "output_features": [{"name": LABEL_COLUMN_NAME, "type": output_type}],
+            "input_features": [image_feat],
+            "output_features": [output_feat],
             "combiner": {"type": "concat"},
             "trainer": {
                 "epochs": epochs,
                 "early_stop": early_stop,
                 "batch_size": batch_size_cfg,
                 "learning_rate": learning_rate,
+                # only set validation_metric for regression
+                **({"validation_metric": val_metric} if val_metric else {}),
             },
             "preprocessing": {
                 "split": split_config,
@@ -876,7 +662,7 @@
             )
             raise

-    def get_training_process(self, output_dir) -> float:
+    def get_training_process(self, output_dir) -> Optional[Dict[str, Any]]:
         """Retrieve the learning rate used in the most recent Ludwig run."""
         output_dir = Path(output_dir)
         exp_dirs = sorted(
@@ -1000,11 +786,12 @@

         viz_registry = get_visualizations_registry()
         for viz_name, viz_func in viz_registry.items():
-            viz_dir_plot = None
             if viz_name in train_plots:
                 viz_dir_plot = train_viz
             elif viz_name in test_plots:
                 viz_dir_plot = test_viz
+            else:
+                continue

             try:
                 viz_func(
@@ -1040,6 +827,7 @@
         report_name = title.lower().replace(" ", "_") + "_report.html"
         report_path = cwd / report_name
         output_dir = Path(output_dir)
+        output_type = None

         exp_dirs = sorted(
             output_dir.glob("experiment_run*"),
@@ -1059,7 +847,6 @@
         metrics_html = ""
         train_val_metrics_html = ""
         test_metrics_html = ""
-
         try:
             train_stats_path = exp_dir / "training_statistics.json"
             test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME
@@ -1069,18 +856,14 @@
                 with open(test_stats_path) as f:
                     test_stats = json.load(f)
                 output_type = detect_output_type(test_stats)
-                all_metrics = extract_metrics_from_json(
-                    train_stats,
-                    test_stats,
-                    output_type,
-                )
                 metrics_html = format_stats_table_html(train_stats, test_stats)
                 train_val_metrics_html = format_train_val_stats_table_html(
-                    train_stats,
-                    test_stats,
+                    train_stats, test_stats
                 )
                 test_metrics_html = format_test_merged_stats_table_html(
-                    all_metrics["test"],
+                    extract_metrics_from_json(train_stats, test_stats, output_type)[
+                        "test"
+                    ]
                 )
         except Exception as e:
             logger.warning(
@@ -1090,11 +873,15 @@

         config_html = ""
         training_progress = self.get_training_process(output_dir)
         try:
-            config_html = format_config_table_html(config, split_info, training_progress)
+            config_html = format_config_table_html(
+                config, split_info, training_progress
+            )
         except Exception as e:
             logger.warning(f"Could not load config for HTML report: {e}")

-        def render_img_section(title: str, dir_path: Path, output_type: str = None) -> str:
+        def render_img_section(
+            title: str, dir_path: Path, output_type: str = None
+        ) -> str:
             if not dir_path.exists():
                 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>"
@@ -1141,11 +928,7 @@
                 img_names[fname] for fname in display_order if fname in img_names
             ]
             remaining = sorted(
-                [
-                    img
-                    for img in img_names.values()
-                    if img.name not in display_order
-                ]
+                [img for img in img_names.values() if img.name not in display_order]
             )
             imgs = ordered_imgs + remaining

@@ -1173,46 +956,61 @@
             section_html += "</div>"
             return section_html

-        button_html = """
-        <button class="help-modal-btn openMetricsHelp">Model Evaluation Metrics — Help Guide</button>
-        <br><br>
-        <style>
-        .help-modal-btn {
-            background-color: #17623b;
-            color: #fff;
-            border: none;
-            border-radius: 24px;
-            padding: 10px 28px;
-            font-size: 1.1rem;
-            font-weight: bold;
-            letter-spacing: 0.03em;
-            cursor: pointer;
-            transition: background 0.2s, box-shadow 0.2s;
-            box-shadow: 0 2px 8px rgba(23,98,59,0.07);
-        }
-        .help-modal-btn:hover, .help-modal-btn:focus {
-            background-color: #21895e;
-            outline: none;
-            box-shadow: 0 4px 16px rgba(23,98,59,0.14);
-        }
-        </style>
-        """
-        tab1_content = button_html + config_html + metrics_html
-        tab2_content = (
-            button_html
-            + train_val_metrics_html
-            + render_img_section("Training & Validation Visualizations", train_viz_dir)
+        tab1_content = config_html + metrics_html
+
+        tab2_content = train_val_metrics_html + render_img_section(
+            "Training & Validation Visualizations", train_viz_dir
         )
+
+        # --- Predictions vs Ground Truth table ---
+        preds_section = ""
+        parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
+        if parquet_path.exists():
+            try:
+                # 1) load predictions from Parquet
+                df_preds = pd.read_parquet(parquet_path).reset_index(drop=True)
+                # assume the column containing your model's prediction is named "prediction"
+                # or contains that substring:
+                pred_col = next(
+                    (c for c in df_preds.columns if "prediction" in c.lower()),
+                    None,
+                )
+                if pred_col is None:
+                    raise ValueError("No prediction column found in Parquet output")
+                df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"})
+
+                # 2) load ground truth for the test split from prepared CSV
+                df_all = pd.read_csv(config["label_column_data_path"])
+                df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][
+                    LABEL_COLUMN_NAME
+                ].reset_index(drop=True)
+
+                # 3) concatenate side‐by‐side
+                df_table = pd.concat([df_gt, df_pred], axis=1)
+                df_table.columns = [LABEL_COLUMN_NAME, "prediction"]
+
+                # 4) render as HTML
+                preds_html = df_table.to_html(index=False, classes="predictions-table")
+                preds_section = (
+                    "<h2 style='text-align: center;'>Predictions vs. Ground Truth</h2>"
+                    "<div style='overflow-x:auto; margin-bottom:20px;'>"
+                    + preds_html
+                    + "</div>"
+                )
+            except Exception as e:
+                logger.warning(f"Could not build Predictions vs GT table: {e}")
+
+        # Test tab = Metrics + Preds table + Visualizations
         tab3_content = (
-            button_html
-            + test_metrics_html
+            test_metrics_html
+            + preds_section
             + render_img_section("Test Visualizations", test_viz_dir, output_type)
         )
+
+        # assemble the tabs and help modal
         tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)
         modal_html = get_metrics_help_modal()
-        html += tabbed_html + modal_html
-        html += get_html_closing()
+        html += tabbed_html + modal_html + get_html_closing()

         try:
             with open(report_path, "w") as f:
@@ -1263,7 +1061,7 @@
             logger.error("Error extracting zip file", exc_info=True)
             raise

-    def _prepare_data(self) -> Tuple[Path, Dict[str, Any]]:
+    def _prepare_data(self) -> Tuple[Path, Dict[str, Any], str]:
         """Load CSV, update image paths, handle splits, and write prepared CSV."""
         if not self.temp_dir or not self.image_extract_dir:
             raise RuntimeError("Temp dirs not initialized before data prep.")
@@ -1302,8 +1100,9 @@
                 f"for train/val/test."
             )

-        final_csv = TEMP_CSV_FILENAME
+        final_csv = self.temp_dir / TEMP_CSV_FILENAME
         try:
+
             df.to_csv(final_csv, index=False)
             logger.info(f"Saved prepared data to {final_csv}")
         except Exception:
@@ -1312,7 +1111,9 @@

         return final_csv, split_config, split_info

-    def _process_fixed_split(self, df: pd.DataFrame) -> Dict[str, Any]:
+    def _process_fixed_split(
+        self, df: pd.DataFrame
+    ) -> Tuple[pd.DataFrame, Dict[str, Any], str]:
         """Process a fixed split column (0=train,1=val,2=test)."""
         logger.info(f"Fixed split column '{SPLIT_COLUMN_NAME}' detected.")
         try:
@@ -1384,6 +1185,7 @@
             "random_seed": self.args.random_seed,
             "early_stop": self.args.early_stop,
             "label_column_data_path": csv_path,
+            "augmentation": self.args.augmentation,
         }

         yaml_str = self.backend.prepare_config(backend_args, split_cfg)
@@ -1422,6 +1224,29 @@
     return None


+def aug_parse(aug_string: str):
+    """
+    Parse comma-separated augmentation keys into Ludwig augmentation dicts.
+    Raises ValueError on unknown key.
+    """
+    mapping = {
+        "random_horizontal_flip": {"type": "random_horizontal_flip"},
+        "random_vertical_flip": {"type": "random_vertical_flip"},
+        "random_rotate": {"type": "random_rotate", "degree": 10},
+        "random_blur": {"type": "random_blur", "kernel_size": 3},
+        "random_brightness": {"type": "random_brightness", "min": 0.5, "max": 2.0},
+        "random_contrast": {"type": "random_contrast", "min": 0.5, "max": 2.0},
+    }
+    aug_list = []
+    for tok in aug_string.split(","):
+        key = tok.strip()
+        if key not in mapping:
+            valid = ", ".join(mapping.keys())
+            raise ValueError(f"Unknown augmentation '{key}'. Valid choices: {valid}")
+        aug_list.append(mapping[key])
+    return aug_list
+
+
 class SplitProbAction(argparse.Action):
     def __call__(self, parser, namespace, values, option_string=None):
         train, val, test = values
@@ -1508,7 +1333,10 @@
         metavar=("train", "val", "test"),
         action=SplitProbAction,
         default=[0.7, 0.1, 0.2],
-        help="Random split proportions (e.g., 0.7 0.1 0.2). Only used if no split column.",
+        help=(
+            "Random split proportions (e.g., 0.7 0.1 0.2)."
+            "Only used if no split column."
+        ),
     )
     parser.add_argument(
         "--random-seed",
@@ -1522,6 +1350,17 @@
         default=None,
         help="Learning rate. If not provided, Ludwig will auto-select it.",
     )
+    parser.add_argument(
+        "--augmentation",
+        type=str,
+        default=None,
+        help=(
+            "Comma-separated list (in order) of any of: "
+            "random_horizontal_flip, random_vertical_flip, random_rotate, "
+            "random_blur, random_brightness, random_contrast. "
+            "E.g. --augmentation random_horizontal_flip,random_rotate"
+        ),
+    )

     args = parser.parse_args()

@@ -1531,6 +1370,12 @@
         parser.error(f"CSV not found: {args.csv_file}")
     if not args.image_zip.is_file():
         parser.error(f"ZIP not found: {args.image_zip}")
+    if args.augmentation is not None:
+        try:
+            augmentation_setup = aug_parse(args.augmentation)
+            setattr(args, "augmentation", augmentation_setup)
+        except ValueError as e:
+            parser.error(str(e))

     backend_instance = LudwigDirectBackend()
    orchestrator = WorkflowOrchestrator(args, backend_instance)
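
The task-type detection added to prepare_config() is easy to miss in the hunk above. Below is a minimal standalone sketch of the rule, not part of the changeset; the helper name detect_task is ours: a numeric label column with more than 10 unique values is treated as regression, anything else as classification.

import pandas as pd
import pandas.api.types as ptypes

def detect_task(label_series: pd.Series) -> str:
    # Mirrors the rule in prepare_config(): numeric dtype + high cardinality
    # means regression; otherwise fall back to classification.
    if ptypes.is_numeric_dtype(label_series.dtype) and label_series.nunique() > 10:
        return "regression"
    return "classification"

print(detect_task(pd.Series([0, 1, 1, 0])))                  # classification
print(detect_task(pd.Series([0.1 * i for i in range(50)])))  # regression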
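The new "Predictions vs. Ground Truth" section pairs Ludwig's Parquet predictions with the test rows (split == 2) of the prepared CSV. Here is a toy, self-contained illustration of the same pandas steps; the column name label_predictions is hypothetical, since the report accepts any Parquet column whose name contains "prediction".

import pandas as pd

# Toy stand-ins for Ludwig's Parquet output and the prepared CSV.
df_preds = pd.DataFrame({"label_predictions": ["cat", "cat", "dog"]})
df_all = pd.DataFrame({"label": ["cat", "dog", "dog"], "split": [2, 2, 2]})

# Find the prediction column, normalize its name, and pull the test-split labels.
pred_col = next(c for c in df_preds.columns if "prediction" in c.lower())
df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"})
df_gt = df_all[df_all["split"] == 2]["label"].reset_index(drop=True)

# Concatenate side by side and render, as the report code does.
df_table = pd.concat([df_gt, df_pred], axis=1)
df_table.columns = ["label", "prediction"]
print(df_table.to_html(index=False, classes="predictions-table"))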
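Finally, what --augmentation actually feeds into the config: aug_parse() turns the comma-separated keys into Ludwig augmentation dicts, which prepare_config() attaches to the image input feature. A sketch using the mapping copied verbatim from the function above:

# Illustration only: the result of aug_parse("random_horizontal_flip,random_rotate").
mapping = {
    "random_horizontal_flip": {"type": "random_horizontal_flip"},
    "random_vertical_flip": {"type": "random_vertical_flip"},
    "random_rotate": {"type": "random_rotate", "degree": 10},
    "random_blur": {"type": "random_blur", "kernel_size": 3},
    "random_brightness": {"type": "random_brightness", "min": 0.5, "max": 2.0},
    "random_contrast": {"type": "random_contrast", "min": 0.5, "max": 2.0},
}
aug_list = [mapping[k.strip()] for k in "random_horizontal_flip,random_rotate".split(",")]
print(aug_list)
# [{'type': 'random_horizontal_flip'}, {'type': 'random_rotate', 'degree': 10}]
# prepare_config() then sets: image_feat["augmentation"] = aug_list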