image_learner: changeset 1:39202fe5cf97 (draft)
planemo upload for repository https://github.com/goeckslab/gleam.git commit 06c0da44ac93256dfb616a6b40276b5485a71e8e
author   | goeckslab
date     | Wed, 02 Jul 2025 18:59:10 +0000
parents  | 54b871dfc51e
children | 186424a7eca7
files    | image_learner.xml image_learner_cli.py utils.py
diffstat | 3 files changed, 756 insertions(+), 329 deletions(-)
--- a/image_learner.xml	Tue Jun 03 21:22:11 2025 +0000
+++ b/image_learner.xml	Wed Jul 02 18:59:10 2025 +0000
@@ -222,9 +222,11 @@
         <param name="input_csv" value="mnist_subset.csv" ftype="csv" />
         <param name="image_zip" value="mnist_subset.zip" ftype="zip" />
         <param name="model_name" value="resnet18" />
-        <output name="output_report" file="image_classification_results_report_mnist.html" compare="sim_size" delta="20000" >
+        <output name="output_report">
             <assert_contents>
-                <has_text text="Epochs" />
+                <has_text text="Results Summary" />
+                <has_text text="Train/Validation Results" />
+                <has_text text="Test Results" />
             </assert_contents>
         </output>
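Note on the test change above: the output check moves from a size comparison (compare="sim_size") to content assertions, which stay stable even when the size of the embedded base64 plots varies between runs. A minimal Galaxy test block in this style would look like the sketch below; the <test> wrapper is assumed from the tool's existing test section, and only the params and assertions shown in the diff are used.

    <test>
        <param name="input_csv" value="mnist_subset.csv" ftype="csv" />
        <param name="image_zip" value="mnist_subset.zip" ftype="zip" />
        <param name="model_name" value="resnet18" />
        <output name="output_report">
            <assert_contents>
                <has_text text="Results Summary" />
            </assert_contents>
        </output>
    </test>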
--- a/image_learner_cli.py	Tue Jun 03 21:22:11 2025 +0000
+++ b/image_learner_cli.py	Wed Jul 02 18:59:10 2025 +0000
@@ -24,103 +24,254 @@
 from utils import encode_image_to_base64, get_html_closing, get_html_template

 # --- Constants ---
-SPLIT_COLUMN_NAME = 'split'
-LABEL_COLUMN_NAME = 'label'
-IMAGE_PATH_COLUMN_NAME = 'image_path'
+SPLIT_COLUMN_NAME = "split"
+LABEL_COLUMN_NAME = "label"
+IMAGE_PATH_COLUMN_NAME = "image_path"
 DEFAULT_SPLIT_PROBABILITIES = [0.7, 0.1, 0.2]
 TEMP_CSV_FILENAME = "processed_data_for_ludwig.csv"
 TEMP_CONFIG_FILENAME = "ludwig_config.yaml"
 TEMP_DIR_PREFIX = "ludwig_api_work_"
 MODEL_ENCODER_TEMPLATES: Dict[str, Any] = {
-    'stacked_cnn': 'stacked_cnn',
-    'resnet18': {'type': 'resnet', 'model_variant': 18},
-    'resnet34': {'type': 'resnet', 'model_variant': 34},
-    'resnet50': {'type': 'resnet', 'model_variant': 50},
-    'resnet101': {'type': 'resnet', 'model_variant': 101},
-    'resnet152': {'type': 'resnet', 'model_variant': 152},
-    'resnext50_32x4d': {'type': 'resnext', 'model_variant': '50_32x4d'},
-    'resnext101_32x8d': {'type': 'resnext', 'model_variant': '101_32x8d'},
-    'resnext101_64x4d': {'type': 'resnext', 'model_variant': '101_64x4d'},
-    'resnext152_32x8d': {'type': 'resnext', 'model_variant': '152_32x8d'},
-    'wide_resnet50_2': {'type': 'wide_resnet', 'model_variant': '50_2'},
-    'wide_resnet101_2': {'type': 'wide_resnet', 'model_variant': '101_2'},
-    'wide_resnet103_2': {'type': 'wide_resnet', 'model_variant': '103_2'},
-    'efficientnet_b0': {'type': 'efficientnet', 'model_variant': 'b0'},
-    'efficientnet_b1': {'type': 'efficientnet', 'model_variant': 'b1'},
-    'efficientnet_b2': {'type': 'efficientnet', 'model_variant': 'b2'},
-    'efficientnet_b3': {'type': 'efficientnet', 'model_variant': 'b3'},
-    'efficientnet_b4': {'type': 'efficientnet', 'model_variant': 'b4'},
-    'efficientnet_b5': {'type': 'efficientnet', 'model_variant': 'b5'},
-    'efficientnet_b6': {'type': 'efficientnet', 'model_variant': 'b6'},
-    'efficientnet_b7': {'type': 'efficientnet', 'model_variant': 'b7'},
-    'efficientnet_v2_s': {'type': 'efficientnet', 'model_variant': 'v2_s'},
-    'efficientnet_v2_m': {'type': 'efficientnet', 'model_variant': 'v2_m'},
-    'efficientnet_v2_l': {'type': 'efficientnet', 'model_variant': 'v2_l'},
-    'regnet_y_400mf': {'type': 'regnet', 'model_variant': 'y_400mf'},
-    'regnet_y_800mf': {'type': 'regnet', 'model_variant': 'y_800mf'},
-    'regnet_y_1_6gf': {'type': 'regnet', 'model_variant': 'y_1_6gf'},
-    'regnet_y_3_2gf': {'type': 'regnet', 'model_variant': 'y_3_2gf'},
-    'regnet_y_8gf': {'type': 'regnet', 'model_variant': 'y_8gf'},
-    'regnet_y_16gf': {'type': 'regnet', 'model_variant': 'y_16gf'},
-    'regnet_y_32gf': {'type': 'regnet', 'model_variant': 'y_32gf'},
-    'regnet_y_128gf': {'type': 'regnet', 'model_variant': 'y_128gf'},
-    'regnet_x_400mf': {'type': 'regnet', 'model_variant': 'x_400mf'},
-    'regnet_x_800mf': {'type': 'regnet', 'model_variant': 'x_800mf'},
-    'regnet_x_1_6gf': {'type': 'regnet', 'model_variant': 'x_1_6gf'},
-    'regnet_x_3_2gf': {'type': 'regnet', 'model_variant': 'x_3_2gf'},
-    'regnet_x_8gf': {'type': 'regnet', 'model_variant': 'x_8gf'},
-    'regnet_x_16gf': {'type': 'regnet', 'model_variant': 'x_16gf'},
-    'regnet_x_32gf': {'type': 'regnet', 'model_variant': 'x_32gf'},
-    'vgg11': {'type': 'vgg', 'model_variant': 11},
-    'vgg11_bn': {'type': 'vgg', 'model_variant': '11_bn'},
-    'vgg13': {'type': 'vgg', 'model_variant': 13},
-    'vgg13_bn': {'type': 'vgg', 'model_variant': '13_bn'},
-    'vgg16': {'type': 'vgg', 'model_variant': 16},
-    'vgg16_bn': {'type': 'vgg', 'model_variant': '16_bn'},
-    'vgg19': {'type': 'vgg', 'model_variant': 19},
-    'vgg19_bn': {'type': 'vgg', 'model_variant': '19_bn'},
-    'shufflenet_v2_x0_5': {'type': 'shufflenet_v2', 'model_variant': 'x0_5'},
-    'shufflenet_v2_x1_0': {'type': 'shufflenet_v2', 'model_variant': 'x1_0'},
-    'shufflenet_v2_x1_5': {'type': 'shufflenet_v2', 'model_variant': 'x1_5'},
-    'shufflenet_v2_x2_0': {'type': 'shufflenet_v2', 'model_variant': 'x2_0'},
-    'squeezenet1_0': {'type': 'squeezenet', 'model_variant': '1_0'},
-    'squeezenet1_1': {'type': 'squeezenet', 'model_variant': '1_1'},
-    'swin_t': {'type': 'swin_transformer', 'model_variant': 't'},
-    'swin_s': {'type': 'swin_transformer', 'model_variant': 's'},
-    'swin_b': {'type': 'swin_transformer', 'model_variant': 'b'},
-    'swin_v2_t': {'type': 'swin_transformer', 'model_variant': 'v2_t'},
-    'swin_v2_s': {'type': 'swin_transformer', 'model_variant': 'v2_s'},
-    'swin_v2_b': {'type': 'swin_transformer', 'model_variant': 'v2_b'},
-    'vit_b_16': {'type': 'vision_transformer', 'model_variant': 'b_16'},
-    'vit_b_32': {'type': 'vision_transformer', 'model_variant': 'b_32'},
-    'vit_l_16': {'type': 'vision_transformer', 'model_variant': 'l_16'},
-    'vit_l_32': {'type': 'vision_transformer', 'model_variant': 'l_32'},
-    'vit_h_14': {'type': 'vision_transformer', 'model_variant': 'h_14'},
-    'convnext_tiny': {'type': 'convnext', 'model_variant': 'tiny'},
-    'convnext_small': {'type': 'convnext', 'model_variant': 'small'},
-    'convnext_base': {'type': 'convnext', 'model_variant': 'base'},
-    'convnext_large': {'type': 'convnext', 'model_variant': 'large'},
-    'maxvit_t': {'type': 'maxvit', 'model_variant': 't'},
-    'alexnet': {'type': 'alexnet'},
-    'googlenet': {'type': 'googlenet'},
-    'inception_v3': {'type': 'inception_v3'},
-    'mobilenet_v2': {'type': 'mobilenet_v2'},
-    'mobilenet_v3_large': {'type': 'mobilenet_v3_large'},
-    'mobilenet_v3_small': {'type': 'mobilenet_v3_small'},
+    "stacked_cnn": "stacked_cnn",
+    "resnet18": {"type": "resnet", "model_variant": 18},
+    "resnet34": {"type": "resnet", "model_variant": 34},
+    "resnet50": {"type": "resnet", "model_variant": 50},
+    "resnet101": {"type": "resnet", "model_variant": 101},
+    "resnet152": {"type": "resnet", "model_variant": 152},
+    "resnext50_32x4d": {"type": "resnext", "model_variant": "50_32x4d"},
+    "resnext101_32x8d": {"type": "resnext", "model_variant": "101_32x8d"},
+    "resnext101_64x4d": {"type": "resnext", "model_variant": "101_64x4d"},
+    "resnext152_32x8d": {"type": "resnext", "model_variant": "152_32x8d"},
+    "wide_resnet50_2": {"type": "wide_resnet", "model_variant": "50_2"},
+    "wide_resnet101_2": {"type": "wide_resnet", "model_variant": "101_2"},
+    "wide_resnet103_2": {"type": "wide_resnet", "model_variant": "103_2"},
+    "efficientnet_b0": {"type": "efficientnet", "model_variant": "b0"},
+    "efficientnet_b1": {"type": "efficientnet", "model_variant": "b1"},
+    "efficientnet_b2": {"type": "efficientnet", "model_variant": "b2"},
+    "efficientnet_b3": {"type": "efficientnet", "model_variant": "b3"},
+    "efficientnet_b4": {"type": "efficientnet", "model_variant": "b4"},
+    "efficientnet_b5": {"type": "efficientnet", "model_variant": "b5"},
+    "efficientnet_b6": {"type": "efficientnet", "model_variant": "b6"},
+    "efficientnet_b7": {"type": "efficientnet", "model_variant": "b7"},
+    "efficientnet_v2_s": {"type": "efficientnet", "model_variant": "v2_s"},
+    "efficientnet_v2_m": {"type": "efficientnet", "model_variant": "v2_m"},
+    "efficientnet_v2_l": {"type": "efficientnet", "model_variant": "v2_l"},
+    "regnet_y_400mf": {"type": "regnet", "model_variant": "y_400mf"},
+    "regnet_y_800mf": {"type": "regnet", "model_variant": "y_800mf"},
+    "regnet_y_1_6gf": {"type": "regnet", "model_variant": "y_1_6gf"},
+    "regnet_y_3_2gf": {"type": "regnet", "model_variant": "y_3_2gf"},
+    "regnet_y_8gf": {"type": "regnet", "model_variant": "y_8gf"},
+    "regnet_y_16gf": {"type": "regnet", "model_variant": "y_16gf"},
+    "regnet_y_32gf": {"type": "regnet", "model_variant": "y_32gf"},
+    "regnet_y_128gf": {"type": "regnet", "model_variant": "y_128gf"},
+    "regnet_x_400mf": {"type": "regnet", "model_variant": "x_400mf"},
+    "regnet_x_800mf": {"type": "regnet", "model_variant": "x_800mf"},
+    "regnet_x_1_6gf": {"type": "regnet", "model_variant": "x_1_6gf"},
+    "regnet_x_3_2gf": {"type": "regnet", "model_variant": "x_3_2gf"},
+    "regnet_x_8gf": {"type": "regnet", "model_variant": "x_8gf"},
+    "regnet_x_16gf": {"type": "regnet", "model_variant": "x_16gf"},
+    "regnet_x_32gf": {"type": "regnet", "model_variant": "x_32gf"},
+    "vgg11": {"type": "vgg", "model_variant": 11},
+    "vgg11_bn": {"type": "vgg", "model_variant": "11_bn"},
+    "vgg13": {"type": "vgg", "model_variant": 13},
+    "vgg13_bn": {"type": "vgg", "model_variant": "13_bn"},
+    "vgg16": {"type": "vgg", "model_variant": 16},
+    "vgg16_bn": {"type": "vgg", "model_variant": "16_bn"},
+    "vgg19": {"type": "vgg", "model_variant": 19},
+    "vgg19_bn": {"type": "vgg", "model_variant": "19_bn"},
+    "shufflenet_v2_x0_5": {"type": "shufflenet_v2", "model_variant": "x0_5"},
+    "shufflenet_v2_x1_0": {"type": "shufflenet_v2", "model_variant": "x1_0"},
+    "shufflenet_v2_x1_5": {"type": "shufflenet_v2", "model_variant": "x1_5"},
+    "shufflenet_v2_x2_0": {"type": "shufflenet_v2", "model_variant": "x2_0"},
+    "squeezenet1_0": {"type": "squeezenet", "model_variant": "1_0"},
+    "squeezenet1_1": {"type": "squeezenet", "model_variant": "1_1"},
+    "swin_t": {"type": "swin_transformer", "model_variant": "t"},
+    "swin_s": {"type": "swin_transformer", "model_variant": "s"},
+    "swin_b": {"type": "swin_transformer", "model_variant": "b"},
+    "swin_v2_t": {"type": "swin_transformer", "model_variant": "v2_t"},
+    "swin_v2_s": {"type": "swin_transformer", "model_variant": "v2_s"},
+    "swin_v2_b": {"type": "swin_transformer", "model_variant": "v2_b"},
+    "vit_b_16": {"type": "vision_transformer", "model_variant": "b_16"},
+    "vit_b_32": {"type": "vision_transformer", "model_variant": "b_32"},
+    "vit_l_16": {"type": "vision_transformer", "model_variant": "l_16"},
+    "vit_l_32": {"type": "vision_transformer", "model_variant": "l_32"},
+    "vit_h_14": {"type": "vision_transformer", "model_variant": "h_14"},
+    "convnext_tiny": {"type": "convnext", "model_variant": "tiny"},
+    "convnext_small": {"type": "convnext", "model_variant": "small"},
+    "convnext_base": {"type": "convnext", "model_variant": "base"},
+    "convnext_large": {"type": "convnext", "model_variant": "large"},
+    "maxvit_t": {"type": "maxvit", "model_variant": "t"},
+    "alexnet": {"type": "alexnet"},
+    "googlenet": {"type": "googlenet"},
+    "inception_v3": {"type": "inception_v3"},
+    "mobilenet_v2": {"type": "mobilenet_v2"},
+    "mobilenet_v3_large": {"type": "mobilenet_v3_large"},
+    "mobilenet_v3_small": {"type": "mobilenet_v3_small"},
+}
+METRIC_DISPLAY_NAMES = {
+    "accuracy": "Accuracy",
+    "accuracy_micro": "Accuracy-Micro",
+    "loss": "Loss",
+    "roc_auc": "ROC-AUC",
+    "roc_auc_macro": "ROC-AUC-Macro",
+    "roc_auc_micro": "ROC-AUC-Micro",
+    "hits_at_k": "Hits at K",
+    "precision": "Precision",
+    "recall": "Recall",
+    "specificity": "Specificity",
+    "kappa_score": "Cohen's Kappa",
+    "token_accuracy": "Token Accuracy",
+    "avg_precision_macro": "Precision-Macro",
+    "avg_recall_macro": "Recall-Macro",
+    "avg_f1_score_macro": "F1-score-Macro",
+    "avg_precision_micro": "Precision-Micro",
+    "avg_recall_micro": "Recall-Micro",
+    "avg_f1_score_micro": "F1-score-Micro",
+    "avg_precision_weighted": "Precision-Weighted",
+    "avg_recall_weighted": "Recall-Weighted",
+    "avg_f1_score_weighted": "F1-score-Weighted",
+    "average_precision_macro": " Precision-Average-Macro",
+    "average_precision_micro": "Precision-Average-Micro",
+    "average_precision_samples": "Precision-Average-Samples",
 }

 # --- Logging Setup ---
 logging.basicConfig(
     level=logging.INFO,
-    format='%(asctime)s %(levelname)s %(name)s: %(message)s'
+    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
 )
 logger = logging.getLogger("ImageLearner")


+def get_metrics_help_modal() -> str:
+    modal_html = """
+<div id="metricsHelpModal" class="modal">
+  <div class="modal-content">
+    <span class="close">&times;</span>
+    <h2>Model Evaluation Metrics — Help Guide</h2>
+    <div class="metrics-guide">
+      <h3>1) General Metrics</h3>
+      <p><strong>Loss:</strong> Measures the difference between predicted and actual values. Lower is better. Often used for optimization during training.</p>
+      <p><strong>Accuracy:</strong> Proportion of correct predictions among all predictions. Simple but can be misleading for imbalanced datasets.</p>
+      <p><strong>Micro Accuracy:</strong> Calculates accuracy by summing up all individual true positives and true negatives across all classes, making it suitable for multiclass or multilabel problems.</p>
+      <p><strong>Token Accuracy:</strong> Measures how often the predicted tokens (e.g., in sequences) match the true tokens. Useful in sequence prediction tasks like NLP.</p>
+      <h3>2) Precision, Recall &amp; Specificity</h3>
+      <p><strong>Precision:</strong> Out of all positive predictions, how many were correct. Precision = TP / (TP + FP). Helps when false positives are costly.</p>
+      <p><strong>Recall (Sensitivity):</strong> Out of all actual positives, how many were predicted correctly. Recall = TP / (TP + FN). Important when missing positives is risky.</p>
+      <p><strong>Specificity:</strong> True negative rate. Measures how well the model identifies negatives. Specificity = TN / (TN + FP). Useful in medical testing to avoid false alarms.</p>
+      <h3>3) Macro, Micro, and Weighted Averages</h3>
+      <p><strong>Macro Precision / Recall / F1:</strong> Averages the metric across all classes, treating each class equally, regardless of class frequency. Best when class sizes are balanced.</p>
+      <p><strong>Micro Precision / Recall / F1:</strong> Aggregates TP, FP, FN across all classes before computing the metric. Gives a global view and is ideal for class-imbalanced problems.</p>
+      <p><strong>Weighted Precision / Recall / F1:</strong> Averages each metric across classes, weighted by the number of true instances per class. Balances importance of classes based on frequency.</p>
+      <h3>4) Average Precision (PR-AUC Variants)</h3>
+      <p><strong>Average Precision Macro:</strong> Precision-Recall AUC averaged across all classes equally. Useful for balanced multi-class problems.</p>
+      <p><strong>Average Precision Micro:</strong> Global Precision-Recall AUC using all instances. Best for imbalanced data or multi-label classification.</p>
+      <p><strong>Average Precision Samples:</strong> Precision-Recall AUC averaged across individual samples (not classes). Ideal for multi-label problems where each sample can belong to multiple classes.</p>
+      <h3>5) ROC-AUC Variants</h3>
+      <p><strong>ROC-AUC:</strong> Measures model's ability to distinguish between classes. AUC = 1 is perfect; 0.5 is random guessing. Use for binary classification.</p>
+      <p><strong>Macro ROC-AUC:</strong> Averages the AUC across all classes equally. Suitable when classes are balanced and of equal importance.</p>
+      <p><strong>Micro ROC-AUC:</strong> Computes AUC from aggregated predictions across all classes. Useful in multiclass or multilabel settings with imbalance.</p>
+      <h3>6) Ranking Metrics</h3>
+      <p><strong>Hits at K:</strong> Measures whether the true label is among the top-K predictions. Common in recommendation systems and retrieval tasks.</p>
+      <h3>7) Confusion Matrix Stats (Per Class)</h3>
+      <p><strong>True Positives / Negatives (TP / TN):</strong> Correct predictions for positives and negatives respectively.</p>
+      <p><strong>False Positives / Negatives (FP / FN):</strong> Incorrect predictions — false alarms and missed detections.</p>
+      <h3>8) Other Useful Metrics</h3>
+      <p><strong>Cohen's Kappa:</strong> Measures agreement between predicted and actual values adjusted for chance. Useful for multiclass classification with imbalanced labels.</p>
+      <p><strong>Matthews Correlation Coefficient (MCC):</strong> Balanced measure of prediction quality that takes into account TP, TN, FP, and FN. Particularly effective for imbalanced datasets.</p>
+      <h3>9) Metric Recommendations</h3>
+      <ul>
+        <li>Use <strong>Accuracy + F1</strong> for balanced data.</li>
+        <li>Use <strong>Precision, Recall, ROC-AUC</strong> for imbalanced datasets.</li>
+        <li>Use <strong>Average Precision Micro</strong> for multilabel or class-imbalanced problems.</li>
+        <li>Use <strong>Macro scores</strong> when all classes should be treated equally.</li>
+        <li>Use <strong>Weighted scores</strong> when class imbalance should be accounted for without ignoring small classes.</li>
+        <li>Use <strong>Confusion Matrix stats</strong> to analyze class-wise performance.</li>
+        <li>Use <strong>Hits at K</strong> for recommendation or ranking-based tasks.</li>
+      </ul>
+    </div>
+  </div>
+</div>
+"""
+    modal_css = """
+<style>
+.modal {
+    display: none;
+    position: fixed;
+    z-index: 1;
+    left: 0;
+    top: 0;
+    width: 100%;
+    height: 100%;
+    overflow: auto;
+    background-color: rgba(0,0,0,0.4);
+}
+.modal-content {
+    background-color: #fefefe;
+    margin: 15% auto;
+    padding: 20px;
+    border: 1px solid #888;
+    width: 80%;
+    max-width: 800px;
+}
+.close {
+    color: #aaa;
+    float: right;
+    font-size: 28px;
+    font-weight: bold;
+}
+.close:hover,
+.close:focus {
+    color: black;
+    text-decoration: none;
+    cursor: pointer;
+}
+.metrics-guide h3 {
+    margin-top: 20px;
+}
+.metrics-guide p {
+    margin: 5px 0;
+}
+.metrics-guide ul {
+    margin: 10px 0;
+    padding-left: 20px;
+}
+</style>
+"""
+    modal_js = """
+<script>
+document.addEventListener("DOMContentLoaded", function() {
+    var modal = document.getElementById("metricsHelpModal");
+    var closeBtn = document.getElementsByClassName("close")[0];
+
+    document.querySelectorAll(".openMetricsHelp").forEach(btn => {
+        btn.onclick = function() {
+            modal.style.display = "block";
+        };
+    });
+
+    if (closeBtn) {
+        closeBtn.onclick = function() {
+            modal.style.display = "none";
+        };
+    }
+
+    window.onclick = function(event) {
+        if (event.target == modal) {
+            modal.style.display = "none";
+        }
+    }
+});
+</script>
+"""
+    return modal_css + modal_html + modal_js


 def format_config_table_html(
-        config: dict,
-        split_info: Optional[str] = None,
-        training_progress: dict = None) -> str:
+    config: dict,
+    split_info: Optional[str] = None,
+    training_progress: dict = None,
+) -> str:
     display_keys = [
         "model_name",
         "epochs",
@@ -143,9 +294,7 @@
                 if training_progress:
                     val = "Auto-selected batch size by Ludwig:<br>"
                     resolved_val = training_progress.get("batch_size")
-                    val += (
-                        f"<span style='font-size: 0.85em;'>{resolved_val}</span><br>"
-                    )
+                    val += f"<span style='font-size: 0.85em;'>{resolved_val}</span><br>"
                 else:
                     val = "auto"
             if key == "learning_rate":
@@ -155,11 +304,14 @@
                         resolved_val = training_progress.get("learning_rate")
                         val = (
                             "Auto-selected learning rate by Ludwig:<br>"
-                            f"<span style='font-size: 0.85em;'>{resolved_val if resolved_val else val}</span><br>"
+                            f"<span style='font-size: 0.85em;'>"
+                            f"{resolved_val if resolved_val else val}</span><br>"
                             "<span style='font-size: 0.85em;'>"
-                            "Based on model architecture and training setup (e.g., fine-tuning).<br>"
-                            "See <a href='https://ludwig.ai/latest/configuration/trainer/#trainer-parameters' "
-                            "target='_blank'>Ludwig Trainer Parameters</a> for details."
+                            "Based on model architecture and training setup "
+                            "(e.g., fine-tuning).<br>"
+                            "See <a href='https://ludwig.ai/latest/configuration/trainer/"
+                            "#trainer-parameters' target='_blank'>"
+                            "Ludwig Trainer Parameters</a> for details."
                             "</span>"
                         )
                     else:
                         val = (
@@ -167,16 +319,21 @@
                             "Auto-selected by Ludwig<br>"
                             "<span style='font-size: 0.85em;'>"
                             "Automatically tuned based on architecture and dataset.<br>"
-                            "See <a href='https://ludwig.ai/latest/configuration/trainer/#trainer-parameters' "
-                            "target='_blank'>Ludwig Trainer Parameters</a> for details."
+                            "See <a href='https://ludwig.ai/latest/configuration/trainer/"
+                            "#trainer-parameters' target='_blank'>"
+                            "Ludwig Trainer Parameters</a> for details."
                             "</span>"
                         )
                 else:
                     val = f"{val:.6f}"
             if key == "epochs":
-                if training_progress and "epoch" in training_progress and val > training_progress["epoch"]:
+                if (
+                    training_progress
+                    and "epoch" in training_progress
+                    and val > training_progress["epoch"]
+                ):
                     val = (
-                        f"Because of early stopping: the training"
+                        f"Because of early stopping: the training "
                        f"stopped at epoch {training_progress['epoch']}"
                     )

@@ -186,15 +343,18 @@
             f"<tr>"
             f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>"
             f"{key.replace('_', ' ').title()}</td>"
-            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{val}</td>"
+            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>"
+            f"{val}</td>"
             f"</tr>"
         )

     if split_info:
         rows.append(
             f"<tr>"
-            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Data Split</td>"
-            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{split_info}</td>"
+            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>"
+            f"Data Split</td>"
+            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>"
+            f"{split_info}</td>"
             f"</tr>"
         )

@@ -203,23 +363,36 @@
         "<div style='display: flex; justify-content: center;'>"
         "<table style='border-collapse: collapse; width: 60%; table-layout: auto;'>"
         "<thead><tr>"
-        "<th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>Parameter</th>"
-        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Value</th>"
+        "<th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>"
+        "Parameter</th>"
+        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>"
+        "Value</th>"
         "</tr></thead><tbody>" + "".join(rows) + "</tbody></table></div><br>"
         "<p style='text-align: center; font-size: 0.9em;'>"
         "Model trained using Ludwig.<br>"
         "If want to learn more about Ludwig default settings,"
-        "please check the their <a href='https://ludwig.ai' target='_blank'>website(ludwig.ai)</a>."
+        "please check the their <a href='https://ludwig.ai' target='_blank'>"
+        "website(ludwig.ai)</a>."
         "</p><hr>"
     )


-def format_stats_table_html(training_stats: dict, test_stats: dict) -> str:
-    train_metrics = training_stats.get("training", {}).get("label", {})
-    val_metrics = training_stats.get("validation", {}).get("label", {})
-    test_metrics = test_stats.get("label", {})
+def detect_output_type(test_stats):
+    """Detects if the output type is 'binary' or 'category' based on test statistics."""
+    label_stats = test_stats.get("label", {})
+    per_class = label_stats.get("per_class_stats", {})
+    if len(per_class) == 2:
+        return "binary"
+    return "category"

-    all_metrics = set(train_metrics) | set(val_metrics) | set(test_metrics)
+
+def extract_metrics_from_json(
+    train_stats: dict,
+    test_stats: dict,
+    output_type: str,
+) -> dict:
+    """Extracts relevant metrics from training and test statistics based on the output type."""
+    metrics = {"training": {}, "validation": {}, "test": {}}

     def get_last_value(stats, key):
         val = stats.get(key)
@@ -229,48 +402,203 @@
             return val
         return None

-    rows = []
-    for metric in sorted(all_metrics):
-        t = get_last_value(train_metrics, metric)
-        v = get_last_value(val_metrics, metric)
-        te = get_last_value(test_metrics, metric)
-        if all(x is not None for x in [t, v, te]):
-            row = (
-                f"<tr>"
-                f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>{metric}</td>"
-                f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{t:.4f}</td>"
-                f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{v:.4f}</td>"
-                f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{te:.4f}</td>"
-                f"</tr>"
-            )
-            rows.append(row)
+    for split in ["training", "validation"]:
+        split_stats = train_stats.get(split, {})
+        if not split_stats:
+            logging.warning(f"No statistics found for {split} split")
+            continue
+        label_stats = split_stats.get("label", {})
+        if not label_stats:
+            logging.warning(f"No label statistics found for {split} split")
+            continue
+        if output_type == "binary":
+            metrics[split] = {
+                "accuracy": get_last_value(label_stats, "accuracy"),
+                "loss": get_last_value(label_stats, "loss"),
+                "precision": get_last_value(label_stats, "precision"),
+                "recall": get_last_value(label_stats, "recall"),
+                "specificity": get_last_value(label_stats, "specificity"),
+                "roc_auc": get_last_value(label_stats, "roc_auc"),
+            }
+        else:
+            metrics[split] = {
+                "accuracy": get_last_value(label_stats, "accuracy"),
+                "accuracy_micro": get_last_value(label_stats, "accuracy_micro"),
+                "loss": get_last_value(label_stats, "loss"),
+                "roc_auc": get_last_value(label_stats, "roc_auc"),
+                "hits_at_k": get_last_value(label_stats, "hits_at_k"),
+            }
+
+    # Test metrics: dynamic extraction according to exclusions
+    test_label_stats = test_stats.get("label", {})
+    if not test_label_stats:
+        logging.warning("No label statistics found for test split")
+    else:
+        combined_stats = test_stats.get("combined", {})
+        overall_stats = test_label_stats.get("overall_stats", {})

-    if not rows:
-        return "<p><em>No metric values found.</em></p>"
+        # Define exclusions
+        if output_type == "binary":
+            exclude = {"per_class_stats", "precision_recall_curve", "roc_curve"}
+        else:
+            exclude = {"per_class_stats", "confusion_matrix"}
+
+        # 1. Get all scalar test_label_stats not excluded
+        test_metrics = {}
+        for k, v in test_label_stats.items():
+            if k in exclude:
+                continue
+            if k == "overall_stats":
+                continue
+            if isinstance(v, (int, float, str, bool)):
+                test_metrics[k] = v
+
+        # 2. Add overall_stats (flattened)
+        for k, v in overall_stats.items():
+            test_metrics[k] = v
+
+        # 3. Optionally include combined/loss if present and not already
+        if "loss" in combined_stats and "loss" not in test_metrics:
+            test_metrics["loss"] = combined_stats["loss"]
+
+        metrics["test"] = test_metrics
+
+    return metrics
+
+
+def generate_table_row(cells, styles):
+    """Helper function to generate an HTML table row."""
     return (
-        "<h2 style='text-align: center;'>Model Performance Summary</h2>"
-        "<div style='display: flex; justify-content: center;'>"
-        "<table style='border-collapse: collapse; width: 80%; table-layout: fixed;'>"
-        "<colgroup>"
-        "<col style='width: 40%;'>"
-        "<col style='width: 20%;'>"
-        "<col style='width: 20%;'>"
-        "<col style='width: 20%;'>"
-        "</colgroup>"
-        "<thead><tr>"
-        "<th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>Metric</th>"
-        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Train</th>"
-        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Validation</th>"
-        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Test</th>"
-        "</tr></thead><tbody>" + "".join(rows) + "</tbody></table></div><br>"
+        "<tr>"
+        + "".join(f"<td style='{styles}'>{cell}</td>" for cell in cells)
+        + "</tr>"
     )


-def build_tabbed_html(
-        metrics_html: str,
-        train_viz_html: str,
-        test_viz_html: str) -> str:
+def format_stats_table_html(train_stats: dict, test_stats: dict) -> str:
+    """Formats a combined HTML table for training, validation, and test metrics."""
+    output_type = detect_output_type(test_stats)
+    all_metrics = extract_metrics_from_json(train_stats, test_stats, output_type)
+    rows = []
+    for metric_key in sorted(all_metrics["training"].keys()):
+        if (
+            metric_key in all_metrics["validation"]
+            and metric_key in all_metrics["test"]
+        ):
+            display_name = METRIC_DISPLAY_NAMES.get(
+                metric_key,
+                metric_key.replace("_", " ").title(),
+            )
+            t = all_metrics["training"].get(metric_key)
+            v = all_metrics["validation"].get(metric_key)
+            te = all_metrics["test"].get(metric_key)
+            if all(x is not None for x in [t, v, te]):
+                rows.append([display_name, f"{t:.4f}", f"{v:.4f}", f"{te:.4f}"])
+
+    if not rows:
+        return "<table><tr><td>No metric values found.</td></tr></table>"
+
+    html = (
+        "<h2 style='text-align: center;'>Model Performance Summary</h2>"
+        "<div style='display: flex; justify-content: center;'>"
+        "<table style='border-collapse: collapse; table-layout: auto;'>"
+        "<thead><tr>"
+        "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; "
+        "white-space: nowrap;'>Metric</th>"
+        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
+        "white-space: nowrap;'>Train</th>"
+        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
+        "white-space: nowrap;'>Validation</th>"
+        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
+        "white-space: nowrap;'>Test</th>"
+        "</tr></thead><tbody>"
+    )
+    for row in rows:
+        html += generate_table_row(
+            row,
+            "padding: 10px; border: 1px solid #ccc; text-align: center; "
+            "white-space: nowrap;",
+        )
+    html += "</tbody></table></div><br>"
+    return html
+
+
+def format_train_val_stats_table_html(train_stats: dict, test_stats: dict) -> str:
+    """Formats an HTML table for training and validation metrics."""
+    output_type = detect_output_type(test_stats)
+    all_metrics = extract_metrics_from_json(train_stats, test_stats, output_type)
+    rows = []
+    for metric_key in sorted(all_metrics["training"].keys()):
+        if metric_key in all_metrics["validation"]:
+            display_name = METRIC_DISPLAY_NAMES.get(
+                metric_key,
+                metric_key.replace("_", " ").title(),
+            )
+            t = all_metrics["training"].get(metric_key)
+            v = all_metrics["validation"].get(metric_key)
+            if t is not None and v is not None:
+                rows.append([display_name, f"{t:.4f}", f"{v:.4f}"])
+
+    if not rows:
+        return "<table><tr><td>No metric values found for Train/Validation.</td></tr></table>"
+
+    html = (
+        "<h2 style='text-align: center;'>Train/Validation Performance Summary</h2>"
+        "<div style='display: flex; justify-content: center;'>"
+        "<table style='border-collapse: collapse; table-layout: auto;'>"
+        "<thead><tr>"
+        "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; "
+        "white-space: nowrap;'>Metric</th>"
+        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
+        "white-space: nowrap;'>Train</th>"
+        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
+        "white-space: nowrap;'>Validation</th>"
+        "</tr></thead><tbody>"
+    )
+    for row in rows:
+        html += generate_table_row(
+            row,
+            "padding: 10px; border: 1px solid #ccc; text-align: center; "
+            "white-space: nowrap;",
+        )
+    html += "</tbody></table></div><br>"
+    return html
+
+
+def format_test_merged_stats_table_html(test_metrics: Dict[str, Optional[float]]) -> str:
+    """Formats an HTML table for test metrics."""
+    rows = []
+    for key in sorted(test_metrics.keys()):
+        display_name = METRIC_DISPLAY_NAMES.get(key, key.replace("_", " ").title())
+        value = test_metrics[key]
+        if value is not None:
+            rows.append([display_name, f"{value:.4f}"])
+
+    if not rows:
+        return "<table><tr><td>No test metric values found.</td></tr></table>"
+
+    html = (
+        "<h2 style='text-align: center;'>Test Performance Summary</h2>"
+        "<div style='display: flex; justify-content: center;'>"
+        "<table style='border-collapse: collapse; table-layout: auto;'>"
+        "<thead><tr>"
+        "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; "
+        "white-space: nowrap;'>Metric</th>"
+        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
+        "white-space: nowrap;'>Test</th>"
+        "</tr></thead><tbody>"
    )
+    for row in rows:
+        html += generate_table_row(
+            row,
+            "padding: 10px; border: 1px solid #ccc; text-align: center; "
+            "white-space: nowrap;",
+        )
+    html += "</tbody></table></div><br>"
+    return html
+
+
+def build_tabbed_html(metrics_html: str, train_val_html: str, test_html: str) -> str:
     return f"""
 <style>
 .tabs {{
@@ -302,23 +630,20 @@
     display: block;
 }}
 </style>
-
 <div class="tabs">
-  <div class="tab active" onclick="showTab('metrics')"> Config & Metrics</div>
-  <div class="tab" onclick="showTab('trainval')"> Train/Validation Plots</div>
-  <div class="tab" onclick="showTab('test')"> Test Plots</div>
+  <div class="tab active" onclick="showTab('metrics')"> Config & Results Summary</div>
+  <div class="tab" onclick="showTab('trainval')"> Train/Validation Results</div>
+  <div class="tab" onclick="showTab('test')"> Test Results</div>
 </div>
-
 <div id="metrics" class="tab-content active">
   {metrics_html}
 </div>
 <div id="trainval" class="tab-content">
-  {train_viz_html}
+  {train_val_html}
 </div>
 <div id="test" class="tab-content">
-  {test_viz_html}
+  {test_html}
 </div>
-
 <script>
 function showTab(id) {{
   document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active'));
@@ -337,13 +662,8 @@
     random_state: int = 42,
     label_column: Optional[str] = None,
 ) -> pd.DataFrame:
-    """
-    Given a DataFrame whose split_column only contains {0,2}, re-assign
-    a portion of the 0s to become 1s (validation). Returns a fresh DataFrame.
-    """
-    # Work on a copy
+    """Given a DataFrame whose split_column only contains {0,2}, re-assign a portion of the 0s to become 1s (validation)."""
     out = df.copy()
-    # Ensure split col is integer dtype
     out[split_column] = pd.to_numeric(out[split_column], errors="coerce").astype(int)

     idx_train = out.index[out[split_column] == 0].tolist()
@@ -351,18 +671,15 @@
     if not idx_train:
         logger.info("No rows with split=0; nothing to do.")
         return out
-
-    # Determine stratify array if possible
     stratify_arr = None
     if label_column and label_column in out.columns:
-        # Only stratify if at least two classes and enough samples
         label_counts = out.loc[idx_train, label_column].value_counts()
         if label_counts.size > 1 and (label_counts.min() * validation_size) >= 1:
             stratify_arr = out.loc[idx_train, label_column]
         else:
-            logger.warning("Cannot stratify (too few labels); splitting without stratify.")
-
-    # Edge cases
+            logger.warning(
+                "Cannot stratify (too few labels); splitting without stratify."
+            )
     if validation_size <= 0:
         logger.info("validation_size <= 0; keeping all as train.")
         return out
@@ -370,14 +687,12 @@
         logger.info("validation_size >= 1; moving all train → validation.")
         out.loc[idx_train, split_column] = 1
         return out
-
-    # Do the split
     try:
         train_idx, val_idx = train_test_split(
             idx_train,
             test_size=validation_size,
             random_state=random_state,
-            stratify=stratify_arr
+            stratify=stratify_arr,
         )
     except ValueError as e:
         logger.warning(f"Stratified split failed ({e}); retrying without stratify.")
@@ -385,26 +700,21 @@
             idx_train,
             test_size=validation_size,
             random_state=random_state,
-            stratify=None
+            stratify=None,
         )
-
-    # Assign new splits
     out.loc[train_idx, split_column] = 0
     out.loc[val_idx, split_column] = 1
-    # idx_test stays at 2
-
-    # Cast back to a clean integer type
     out[split_column] = out[split_column].astype(int)
-    # print(out)
     return out


 class Backend(Protocol):
     """Interface for a machine learning backend."""
+
     def prepare_config(
         self,
         config_params: Dict[str, Any],
-        split_config: Dict[str, Any]
+        split_config: Dict[str, Any],
     ) -> str:
         ...
@@ -432,18 +742,14 @@

 class LudwigDirectBackend:
-    """
-    Backend for running Ludwig experiments directly via the internal experiment_cli function.
-    """
+    """Backend for running Ludwig experiments directly via the internal experiment_cli function."""

     def prepare_config(
         self,
         config_params: Dict[str, Any],
         split_config: Dict[str, Any],
     ) -> str:
-        """
-        Build and serialize the Ludwig YAML configuration.
-        """
+        """Build and serialize the Ludwig YAML configuration."""
         logger.info("LudwigDirectBackend: Preparing YAML configuration.")

         model_name = config_params.get("model_name", "resnet18")
@@ -460,8 +766,6 @@
             logger.warning("trainable=False; use_pretrained=False is ignored.")
             logger.warning("Setting trainable=True to train the model from scratch.")
             trainable = True
-
-        # Encoder setup
         raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name)
         if isinstance(raw_encoder, dict):
             encoder_config = {
@@ -472,10 +776,26 @@
         else:
             encoder_config = {"type": raw_encoder}

-        # Trainer & optimizer
-        # optimizer = {"type": "adam", "learning_rate": 5e-5} if fine_tune else {"type": "adam"}
         batch_size_cfg = batch_size or "auto"

+        label_column_path = config_params.get("label_column_data_path")
+        if label_column_path is not None and Path(label_column_path).exists():
+            try:
+                label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME]
+                num_unique_labels = label_series.nunique()
+            except Exception as e:
+                logger.warning(
+                    f"Could not determine label cardinality, defaulting to 'binary': {e}"
+                )
+                num_unique_labels = 2
+        else:
+            logger.warning(
+                "label_column_data_path not provided, defaulting to 'binary'"
+            )
+            num_unique_labels = 2
+
+        output_type = "binary" if num_unique_labels == 2 else "category"
+
         conf: Dict[str, Any] = {
             "model_type": "ecd",
             "input_features": [
@@ -485,9 +805,7 @@
                     "encoder": encoder_config,
                 }
             ],
-            "output_features": [
-                {"name": LABEL_COLUMN_NAME, "type": "category"}
-            ],
+            "output_features": [{"name": LABEL_COLUMN_NAME, "type": output_type}],
             "combiner": {"type": "concat"},
             "trainer": {
                 "epochs": epochs,
@@ -508,7 +826,10 @@
             logger.info("LudwigDirectBackend: YAML config generated.")
             return yaml_str
         except Exception:
-            logger.error("LudwigDirectBackend: Failed to serialize YAML.", exc_info=True)
+            logger.error(
+                "LudwigDirectBackend: Failed to serialize YAML.",
+                exc_info=True,
+            )
             raise

     def run_experiment(
@@ -518,9 +839,7 @@
         output_dir: Path,
         random_seed: int = 42,
     ) -> None:
-        """
-        Invoke Ludwig's internal experiment_cli function to run the experiment.
-        """
+        """Invoke Ludwig's internal experiment_cli function to run the experiment."""
         logger.info("LudwigDirectBackend: Starting experiment execution.")

         try:
@@ -528,7 +847,7 @@
         except ImportError as e:
             logger.error(
                 "LudwigDirectBackend: Could not import experiment_cli.",
-                exc_info=True
+                exc_info=True,
             )
             raise RuntimeError("Ludwig import failed.") from e

@@ -541,30 +860,28 @@
                 output_directory=str(output_dir),
                 random_seed=random_seed,
             )
-            logger.info(f"LudwigDirectBackend: Experiment completed. Results in {output_dir}")
+            logger.info(
+                f"LudwigDirectBackend: Experiment completed. Results in {output_dir}"
+            )
         except TypeError as e:
             logger.error(
                 "LudwigDirectBackend: Argument mismatch in experiment_cli call.",
-                exc_info=True
+                exc_info=True,
             )
             raise RuntimeError("Ludwig argument error.") from e
         except Exception:
             logger.error(
                 "LudwigDirectBackend: Experiment execution error.",
-                exc_info=True
+                exc_info=True,
             )
             raise

     def get_training_process(self, output_dir) -> float:
-        """
-        Retrieve the learning rate used in the most recent Ludwig run.
-        Returns:
-            float: learning rate (or None if not found)
-        """
+        """Retrieve the learning rate used in the most recent Ludwig run."""
         output_dir = Path(output_dir)
         exp_dirs = sorted(
             output_dir.glob("experiment_run*"),
-            key=lambda p: p.stat().st_mtime
+            key=lambda p: p.stat().st_mtime,
         )

         if not exp_dirs:
@@ -585,7 +902,7 @@
                 "epoch": data.get("epoch"),
             }
         except Exception as e:
-            self.logger.warning(f"Failed to read training progress info: {e}")
+            logger.warning(f"Failed to read training progress info: {e}")
             return {}

     def convert_parquet_to_csv(self, output_dir: Path):
@@ -593,7 +910,7 @@
         output_dir = Path(output_dir)
         exp_dirs = sorted(
             output_dir.glob("experiment_run*"),
-            key=lambda p: p.stat().st_mtime
+            key=lambda p: p.stat().st_mtime,
         )
         if not exp_dirs:
             logger.warning(f"No experiment run dirs found in {output_dir}")
@@ -609,47 +926,43 @@
             logger.error(f"Error converting Parquet to CSV: {e}")

     def generate_plots(self, output_dir: Path) -> None:
-        """
-        Generate _all_ registered Ludwig visualizations for the latest experiment run.
-        """
+        """Generate all registered Ludwig visualizations for the latest experiment run."""
         logger.info("Generating all Ludwig visualizations…")

         test_plots = {
-            'compare_performance',
-            'compare_classifiers_performance_from_prob',
-            'compare_classifiers_performance_from_pred',
-            'compare_classifiers_performance_changing_k',
-            'compare_classifiers_multiclass_multimetric',
-            'compare_classifiers_predictions',
-            'confidence_thresholding_2thresholds_2d',
-            'confidence_thresholding_2thresholds_3d',
-            'confidence_thresholding',
-            'confidence_thresholding_data_vs_acc',
-            'binary_threshold_vs_metric',
-            'roc_curves',
-            'roc_curves_from_test_statistics',
-            'calibration_1_vs_all',
-            'calibration_multiclass',
-            'confusion_matrix',
-            'frequency_vs_f1',
+            "compare_performance",
+            "compare_classifiers_performance_from_prob",
+            "compare_classifiers_performance_from_pred",
+            "compare_classifiers_performance_changing_k",
+            "compare_classifiers_multiclass_multimetric",
+            "compare_classifiers_predictions",
+            "confidence_thresholding_2thresholds_2d",
+            "confidence_thresholding_2thresholds_3d",
+            "confidence_thresholding",
+            "confidence_thresholding_data_vs_acc",
+            "binary_threshold_vs_metric",
+            "roc_curves",
+            "roc_curves_from_test_statistics",
+            "calibration_1_vs_all",
+            "calibration_multiclass",
+            "confusion_matrix",
+            "frequency_vs_f1",
         }
         train_plots = {
-            'learning_curves',
-            'compare_classifiers_performance_subset',
+            "learning_curves",
+            "compare_classifiers_performance_subset",
         }

-        # 1) find the most recent experiment directory
         output_dir = Path(output_dir)
         exp_dirs = sorted(
             output_dir.glob("experiment_run*"),
-            key=lambda p: p.stat().st_mtime
+            key=lambda p: p.stat().st_mtime,
         )
         if not exp_dirs:
             logger.warning(f"No experiment run dirs found in {output_dir}")
             return
         exp_dir = exp_dirs[-1]

-        # 2) ensure viz output subfolder exists
         viz_dir = exp_dir / "visualizations"
         viz_dir.mkdir(exist_ok=True)
         train_viz = viz_dir / "train"
@@ -657,17 +970,14 @@
         train_viz.mkdir(parents=True, exist_ok=True)
         test_viz.mkdir(parents=True, exist_ok=True)

-        # 3) helper to check file existence
         def _check(p: Path) -> Optional[str]:
             return str(p) if p.exists() else None

-        # 4) gather standard Ludwig output files
         training_stats = _check(exp_dir / "training_statistics.json")
         test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME)
         probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME)
         gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME)

-        # 5) try to read original dataset & split file from description.json
         dataset_path = None
         split_file = None
         desc = exp_dir / DESCRIPTION_FILE_NAME
@@ -677,7 +987,6 @@
             dataset_path = _check(Path(cfg.get("dataset", "")))
             split_file = _check(Path(get_split_path(cfg.get("dataset", ""))))

-        # 6) infer output feature name
         output_feature = ""
         if desc.exists():
             try:
@@ -689,7 +998,6 @@
                 stats = json.load(f)
                 output_feature = next(iter(stats.keys()), "")

-        # 7) loop through every registered viz
         viz_registry = get_visualizations_registry()
         for viz_name, viz_func in viz_registry.items():
             viz_dir_plot = None
@@ -721,21 +1029,22 @@
         logger.info(f"All visualizations written to {viz_dir}")

     def generate_html_report(
-            self,
-            title: str,
-            output_dir: str,
-            config: dict,
-            split_info: str) -> Path:
-        """
-        Assemble an HTML report from visualizations under train_val/ and test/ folders.
-        """
+        self,
+        title: str,
+        output_dir: str,
+        config: dict,
+        split_info: str,
+    ) -> Path:
+        """Assemble an HTML report from visualizations under train_val/ and test/ folders."""
         cwd = Path.cwd()
         report_name = title.lower().replace(" ", "_") + "_report.html"
         report_path = cwd / report_name
         output_dir = Path(output_dir)

-        # Find latest experiment dir
-        exp_dirs = sorted(output_dir.glob("experiment_run*"), key=lambda p: p.stat().st_mtime)
+        exp_dirs = sorted(
+            output_dir.glob("experiment_run*"),
+            key=lambda p: p.stat().st_mtime,
+        )
         if not exp_dirs:
             raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}")
         exp_dir = exp_dirs[-1]
@@ -748,8 +1057,9 @@
         html += f"<h1>{title}</h1>"

         metrics_html = ""
+        train_val_metrics_html = ""
+        test_metrics_html = ""

-        # Load and embed metrics table (training/val/test stats)
         try:
             train_stats_path = exp_dir / "training_statistics.json"
             test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME
@@ -758,11 +1068,24 @@
                     train_stats = json.load(f)
                 with open(test_stats_path) as f:
                     test_stats = json.load(f)
-                output_feature = next(iter(train_stats.keys()), "")
-                if output_feature:
-                    metrics_html += format_stats_table_html(train_stats, test_stats)
+                output_type = detect_output_type(test_stats)
+                all_metrics = extract_metrics_from_json(
+                    train_stats,
+                    test_stats,
+                    output_type,
+                )
+                metrics_html = format_stats_table_html(train_stats, test_stats)
+                train_val_metrics_html = format_train_val_stats_table_html(
+                    train_stats,
+                    test_stats,
+                )
+                test_metrics_html = format_test_merged_stats_table_html(
+                    all_metrics["test"],
+                )
         except Exception as e:
-            logger.warning(f"Could not load stats for HTML report: {e}")
+            logger.warning(
+                f"Could not load stats for HTML report: {type(e).__name__}: {e}"
+            )

         config_html = ""
         training_progress = self.get_training_process(output_dir)
@@ -771,29 +1094,124 @@
         except Exception as e:
             logger.warning(f"Could not load config for HTML report: {e}")

-        def render_img_section(title: str, dir_path: Path) -> str:
+        def render_img_section(title: str, dir_path: Path, output_type: str = None) -> str:
             if not dir_path.exists():
                 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>"
-            imgs = sorted(dir_path.glob("*.png"))
+
+            imgs = list(dir_path.glob("*.png"))
             if not imgs:
                 return f"<h2>{title}</h2><p><em>No plots found.</em></p>"

+            if title == "Test Visualizations" and output_type == "binary":
+                order = [
+                    "confusion_matrix__label_top2.png",
+                    "roc_curves_from_prediction_statistics.png",
+                    "compare_performance_label.png",
+                    "confusion_matrix_entropy__label_top2.png",
+                ]
+                img_names = {img.name: img for img in imgs}
+                ordered_imgs = [
+                    img_names[fname] for fname in order if fname in img_names
+                ]
+                remaining = sorted(
+                    [
+                        img
+                        for img in imgs
+                        if img.name not in order and img.name != "roc_curves.png"
+                    ]
+                )
+                imgs = ordered_imgs + remaining
+
+            elif title == "Test Visualizations" and output_type == "category":
+                unwanted = {
+                    "compare_classifiers_multiclass_multimetric__label_best10.png",
+                    "compare_classifiers_multiclass_multimetric__label_top10.png",
+                    "compare_classifiers_multiclass_multimetric__label_worst10.png",
+                }
+                display_order = [
+                    "confusion_matrix__label_top10.png",
+                    "roc_curves.png",
+                    "compare_performance_label.png",
+                    "compare_classifiers_performance_from_prob.png",
+                    "compare_classifiers_multiclass_multimetric__label_sorted.png",
+                    "confusion_matrix_entropy__label_top10.png",
+                ]
+                img_names = {img.name: img for img in imgs if img.name not in unwanted}
+                ordered_imgs = [
+                    img_names[fname] for fname in display_order if fname in img_names
+                ]
+                remaining = sorted(
+                    [
+                        img
+                        for img in img_names.values()
+                        if img.name not in display_order
+                    ]
+                )
+                imgs = ordered_imgs + remaining
+
+            else:
+                if output_type == "category":
+                    unwanted = {
+                        "compare_classifiers_multiclass_multimetric__label_best10.png",
+                        "compare_classifiers_multiclass_multimetric__label_top10.png",
+                        "compare_classifiers_multiclass_multimetric__label_worst10.png",
+                    }
+                    imgs = sorted([img for img in imgs if img.name not in unwanted])
+                else:
+                    imgs = sorted(imgs)
+
             section_html = f"<h2 style='text-align: center;'>{title}</h2><div>"
             for img in imgs:
                 b64 = encode_image_to_base64(str(img))
                 section_html += (
                     f'<div class="plot" style="margin-bottom:20px;text-align:center;">'
-                    f"<h3>{img.stem.replace('_',' ').title()}</h3>"
+                    f"<h3>{img.stem.replace('_', ' ').title()}</h3>"
                     f'<img src="data:image/png;base64,{b64}" '
-                    'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />'
-                    "</div>"
+                    f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />'
+                    f"</div>"
                 )
             section_html += "</div>"
             return section_html

-        train_plots_html = render_img_section("Training & Validation Visualizations", train_viz_dir)
-        test_plots_html = render_img_section("Test Visualizations", test_viz_dir)
-        html += build_tabbed_html(config_html + metrics_html, train_plots_html, test_plots_html)
+        button_html = """
+        <button class="help-modal-btn openMetricsHelp">Model Evaluation Metrics — Help Guide</button>
+        <br><br>
+        <style>
+        .help-modal-btn {
+            background-color: #17623b;
+            color: #fff;
+            border: none;
+            border-radius: 24px;
+            padding: 10px 28px;
+            font-size: 1.1rem;
+            font-weight: bold;
+            letter-spacing: 0.03em;
+            cursor: pointer;
+            transition: background 0.2s, box-shadow 0.2s;
+            box-shadow: 0 2px 8px rgba(23,98,59,0.07);
+        }
+        .help-modal-btn:hover, .help-modal-btn:focus {
+            background-color: #21895e;
+            outline: none;
+            box-shadow: 0 4px 16px rgba(23,98,59,0.14);
+        }
+        </style>
+        """
+        tab1_content = button_html + config_html + metrics_html
+        tab2_content = (
+            button_html
+            + train_val_metrics_html
+            + render_img_section("Training & Validation Visualizations", train_viz_dir)
+        )
+        tab3_content = (
+            button_html
+            + test_metrics_html
+            + render_img_section("Test Visualizations", test_viz_dir, output_type)
+        )
+
+        tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)
+        modal_html = get_metrics_help_modal()
+        html += tabbed_html + modal_html
         html += get_html_closing()

         try:
@@ -808,15 +1226,7 @@

 class WorkflowOrchestrator:
-    """
-    Manages the image-classification workflow:
-      1. Creates temp dirs
-      2. Extracts images
-      3. Prepares data (CSV + splits)
-      4. Renders a backend config
-      5. Runs the experiment
-      6. Cleans up
-    """
+    """Manages the image-classification workflow."""

     def __init__(self, args: argparse.Namespace, backend: Backend):
         self.args = args
@@ -828,10 +1238,9 @@
     def _create_temp_dirs(self) -> None:
         """Create temporary output and image extraction directories."""
         try:
-            self.temp_dir = Path(tempfile.mkdtemp(
-                dir=self.args.output_dir,
-                prefix=TEMP_DIR_PREFIX
-            ))
+            self.temp_dir = Path(
+                tempfile.mkdtemp(dir=self.args.output_dir, prefix=TEMP_DIR_PREFIX)
+            )
             self.image_extract_dir = self.temp_dir / "images"
             self.image_extract_dir.mkdir()
             logger.info(f"Created temp directory: {self.temp_dir}")
@@ -843,7 +1252,9 @@
         """Extract images from ZIP into the temp image directory."""
         if self.image_extract_dir is None:
             raise RuntimeError("Temp image directory not initialized.")
-        logger.info(f"Extracting images from {self.args.image_zip} → {self.image_extract_dir}")
+        logger.info(
+            f"Extracting images from {self.args.image_zip} → {self.image_extract_dir}"
+        )
         try:
             with zipfile.ZipFile(self.args.image_zip, "r") as z:
                 z.extractall(self.image_extract_dir)
@@ -853,16 +1264,10 @@
             raise

     def _prepare_data(self) -> Tuple[Path, Dict[str, Any]]:
-        """
-        Load CSV, update image paths, handle splits, and write prepared CSV.
-        Returns:
-            final_csv_path: Path to the prepared CSV
-            split_config: Dict for backend split settings
-        """
+        """Load CSV, update image paths, handle splits, and write prepared CSV."""
         if not self.temp_dir or not self.image_extract_dir:
             raise RuntimeError("Temp dirs not initialized before data prep.")

-        # 1) Load
         try:
             df = pd.read_csv(self.args.csv_file)
             logger.info(f"Loaded CSV: {self.args.csv_file}")
@@ -870,13 +1275,11 @@
             logger.error("Error loading CSV file", exc_info=True)
             raise

-        # 2) Validate columns
         required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME}
         missing = required - set(df.columns)
         if missing:
             raise ValueError(f"Missing CSV columns: {', '.join(missing)}")

-        # 3) Update image paths
         try:
             df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply(
                 lambda p: str((self.image_extract_dir / p).resolve())
@@ -885,21 +1288,20 @@
             logger.error("Error updating image paths", exc_info=True)
             raise

-        # 4) Handle splits
         if SPLIT_COLUMN_NAME in df.columns:
             df, split_config, split_info = self._process_fixed_split(df)
         else:
             logger.info("No split column; using random split")
             split_config = {
                 "type": "random",
-                "probabilities": self.args.split_probabilities
+                "probabilities": self.args.split_probabilities,
             }
             split_info = (
                 f"No split column in CSV. Used random split: "
-                f"{[int(p*100) for p in self.args.split_probabilities]}% for train/val/test."
+                f"{[int(p * 100) for p in self.args.split_probabilities]}% "
+                f"for train/val/test."
             )

-        # 5) Write out prepared CSV
         final_csv = TEMP_CSV_FILENAME
         try:
             df.to_csv(final_csv, index=False)
@@ -915,7 +1317,9 @@
         logger.info(f"Fixed split column '{SPLIT_COLUMN_NAME}' detected.")
         try:
             col = df[SPLIT_COLUMN_NAME]
-            df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype(pd.Int64Dtype())
+            df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype(
+                pd.Int64Dtype()
+            )
             if df[SPLIT_COLUMN_NAME].isna().any():
                 logger.warning("Split column contains non-numeric/missing values.")

@@ -924,18 +1328,18 @@

             if unique == {0, 2}:
                 df = split_data_0_2(
-                    df, SPLIT_COLUMN_NAME,
+                    df,
+                    SPLIT_COLUMN_NAME,
                     validation_size=self.args.validation_size,
                     label_column=LABEL_COLUMN_NAME,
-                    random_state=self.args.random_seed
+                    random_state=self.args.random_seed,
                 )
                 split_info = (
                     "Detected a split column (with values 0 and 2) in the input CSV. "
-                    f"Used this column as a base and"
-                    f"reassigned {self.args.validation_size * 100:.1f}% "
+                    f"Used this column as a base and reassigned "
+                    f"{self.args.validation_size * 100:.1f}% "
                     "of the training set (originally labeled 0) to validation (labeled 1)."
                 )
-
                 logger.info("Applied custom 0/2 split.")
             elif unique.issubset({0, 1, 2}):
                 split_info = "Used user-defined split column from CSV."
@@ -950,7 +1354,6 @@
             raise

     def _cleanup_temp_dirs(self) -> None:
-        """Remove any temporary directories."""
         if self.temp_dir and self.temp_dir.exists():
             logger.info(f"Cleaning up temp directory: {self.temp_dir}")
             shutil.rmtree(self.temp_dir, ignore_errors=True)
@@ -980,6 +1383,7 @@
             "learning_rate": self.args.learning_rate,
             "random_seed": self.args.random_seed,
             "early_stop": self.args.early_stop,
+            "label_column_data_path": csv_path,
         }
         yaml_str = self.backend.prepare_config(backend_args, split_cfg)

@@ -991,7 +1395,7 @@
             csv_path,
             config_file,
             self.args.output_dir,
-            self.args.random_seed
+            self.args.random_seed,
         )
         logger.info("Workflow completed successfully.")
         self.backend.generate_plots(self.args.output_dir)
@@ -999,7 +1403,7 @@
             "Image Classification Results",
             self.args.output_dir,
             backend_args,
-            split_info
+            split_info,
         )
         logger.info(f"HTML report generated at: {report_file}")
         self.backend.convert_parquet_to_csv(self.args.output_dir)
@@ -1007,7 +1411,6 @@
         except Exception:
             logger.error("Workflow execution failed", exc_info=True)
             raise
-
         finally:
             self._cleanup_temp_dirs()

@@ -1021,7 +1424,6 @@

 class SplitProbAction(argparse.Action):
     def __call__(self, parser, namespace, values, option_string=None):
-        # values is a list of three floats
        train, val, test = values
        total = train + val + test
        if abs(total - 1.0) > 1e-6:
@@ -1033,75 +1435,96 @@

 def main():
-
     parser = argparse.ArgumentParser(
-        description="Image Classification Learner with Pluggable Backends"
+        description="Image Classification Learner with Pluggable Backends",
     )
     parser.add_argument(
-        "--csv-file", required=True, type=Path,
-        help="Path to the input CSV"
+        "--csv-file",
+        required=True,
+        type=Path,
+        help="Path to the input CSV",
     )
     parser.add_argument(
-        "--image-zip", required=True, type=Path,
-        help="Path to the images ZIP"
+        "--image-zip",
+        required=True,
+        type=Path,
+        help="Path to the images ZIP",
     )
     parser.add_argument(
-        "--model-name", required=True,
+        "--model-name",
+        required=True,
         choices=MODEL_ENCODER_TEMPLATES.keys(),
-        help="Which model template to use"
+        help="Which model template to use",
     )
     parser.add_argument(
-        "--use-pretrained", action="store_true",
-        help="Use pretrained weights for the model"
+        "--use-pretrained",
+        action="store_true",
+        help="Use pretrained weights for the model",
    )
     parser.add_argument(
-        "--fine-tune", action="store_true",
-        help="Enable fine-tuning"
+        "--fine-tune",
+        action="store_true",
+        help="Enable fine-tuning",
     )
     parser.add_argument(
-        "--epochs", type=int, default=10,
-        help="Number of training epochs"
+        "--epochs",
+        type=int,
+        default=10,
+        help="Number of training epochs",
     )
     parser.add_argument(
-        "--early-stop", type=int, default=5,
-        help="Early stopping patience"
+        "--early-stop",
+        type=int,
+        default=5,
+        help="Early stopping patience",
     )
     parser.add_argument(
-        "--batch-size", type=int,
-        help="Batch size (None = auto)"
+        "--batch-size",
+        type=int,
+        help="Batch size (None = auto)",
     )
     parser.add_argument(
-        "--output-dir", type=Path, default=Path("learner_output"),
-        help="Where to write outputs"
+        "--output-dir",
+        type=Path,
+        default=Path("learner_output"),
+        help="Where to write outputs",
     )
     parser.add_argument(
-        "--validation-size", type=float, default=0.15,
-        help="Fraction for validation (0.0–1.0)"
+        "--validation-size",
+        type=float,
+        default=0.15,
+        help="Fraction for validation (0.0–1.0)",
     )
     parser.add_argument(
-        "--preprocessing-num-processes", type=int,
+        "--preprocessing-num-processes",
+        type=int,
         default=max(1, os.cpu_count() // 2),
-        help="CPU processes for data prep"
+        help="CPU processes for data prep",
     )
     parser.add_argument(
-        "--split-probabilities", type=float, nargs=3,
+        "--split-probabilities",
+        type=float,
+        nargs=3,
         metavar=("train", "val", "test"),
         action=SplitProbAction,
         default=[0.7, 0.1, 0.2],
-        help="Random split proportions (e.g., 0.7 0.1 0.2). Only used if no split column is present."
+        help="Random split proportions (e.g., 0.7 0.1 0.2). Only used if no split column.",
     )
     parser.add_argument(
-        "--random-seed", type=int, default=42,
-        help="Random seed used for dataset splitting (default: 42)"
+        "--random-seed",
+        type=int,
+        default=42,
+        help="Random seed used for dataset splitting (default: 42)",
    )
     parser.add_argument(
-        "--learning-rate", type=parse_learning_rate, default=None,
-        help="Learning rate. If not provided, Ludwig will auto-select it."
+        "--learning-rate",
+        type=parse_learning_rate,
+        default=None,
+        help="Learning rate. If not provided, Ludwig will auto-select it.",
    )

     args = parser.parse_args()

-    # -- Validation --
     if not 0.0 <= args.validation_size <= 1.0:
         parser.error("validation-size must be between 0.0 and 1.0")
     if not args.csv_file.is_file():
@@ -1109,12 +1532,9 @@
     if not args.image_zip.is_file():
         parser.error(f"ZIP not found: {args.image_zip}")

-    # --- Instantiate Backend and Orchestrator ---
-    # Use the new LudwigDirectBackend
     backend_instance = LudwigDirectBackend()
     orchestrator = WorkflowOrchestrator(args, backend_instance)

-    # --- Run Workflow ---
     exit_code = 0
     try:
         orchestrator.run()
@@ -1126,12 +1546,16 @@
     sys.exit(exit_code)


-if __name__ == '__main__':
+if __name__ == "__main__":
     try:
         import ludwig
+
         logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}")
     except ImportError:
-        logger.error("Ludwig library not found. Please ensure Ludwig is installed ('pip install ludwig[image]')")
+        logger.error(
+            "Ludwig library not found. Please ensure Ludwig is installed "
+            "('pip install ludwig[image]')"
+        )
         sys.exit(1)

     main()