Mercurial > repos > goeckslab > image_learner
comparison image_learner_cli.py @ 2:186424a7eca7 draft
planemo upload for repository https://github.com/goeckslab/gleam.git commit 91fa4aba245520fc0680088a07cead66bcfd4ed2
| author | goeckslab |
|---|---|
| date | Thu, 03 Jul 2025 20:43:24 +0000 |
| parents | 39202fe5cf97 |
| children | 09904b1f61f5 |
comparison
equal
deleted
inserted
replaced
| 1:39202fe5cf97 | 2:186424a7eca7 |
|---|---|
| 1 #!/usr/bin/env python3 | |
| 2 import argparse | 1 import argparse |
| 3 import json | 2 import json |
| 4 import logging | 3 import logging |
| 5 import os | 4 import os |
| 6 import shutil | 5 import shutil |
| 9 import zipfile | 8 import zipfile |
| 10 from pathlib import Path | 9 from pathlib import Path |
| 11 from typing import Any, Dict, Optional, Protocol, Tuple | 10 from typing import Any, Dict, Optional, Protocol, Tuple |
| 12 | 11 |
| 13 import pandas as pd | 12 import pandas as pd |
| 13 import pandas.api.types as ptypes | |
| 14 import yaml | 14 import yaml |
| 15 from constants import ( | |
| 16 IMAGE_PATH_COLUMN_NAME, | |
| 17 LABEL_COLUMN_NAME, | |
| 18 METRIC_DISPLAY_NAMES, | |
| 19 MODEL_ENCODER_TEMPLATES, | |
| 20 SPLIT_COLUMN_NAME, | |
| 21 TEMP_CONFIG_FILENAME, | |
| 22 TEMP_CSV_FILENAME, | |
| 23 TEMP_DIR_PREFIX | |
| 24 ) | |
| 15 from ludwig.globals import ( | 25 from ludwig.globals import ( |
| 16 DESCRIPTION_FILE_NAME, | 26 DESCRIPTION_FILE_NAME, |
| 17 PREDICTIONS_PARQUET_FILE_NAME, | 27 PREDICTIONS_PARQUET_FILE_NAME, |
| 18 TEST_STATISTICS_FILE_NAME, | 28 TEST_STATISTICS_FILE_NAME, |
| 19 TRAIN_SET_METADATA_FILE_NAME, | 29 TRAIN_SET_METADATA_FILE_NAME, |
| 20 ) | 30 ) |
| 21 from ludwig.utils.data_utils import get_split_path | 31 from ludwig.utils.data_utils import get_split_path |
| 22 from ludwig.visualize import get_visualizations_registry | 32 from ludwig.visualize import get_visualizations_registry |
| 23 from sklearn.model_selection import train_test_split | 33 from sklearn.model_selection import train_test_split |
| 24 from utils import encode_image_to_base64, get_html_closing, get_html_template | 34 from utils import ( |
| 25 | 35 build_tabbed_html, |
| 26 # --- Constants --- | 36 encode_image_to_base64, |
| 27 SPLIT_COLUMN_NAME = "split" | 37 get_html_closing, |
| 28 LABEL_COLUMN_NAME = "label" | 38 get_html_template, |
| 29 IMAGE_PATH_COLUMN_NAME = "image_path" | 39 get_metrics_help_modal |
| 30 DEFAULT_SPLIT_PROBABILITIES = [0.7, 0.1, 0.2] | 40 ) |
| 31 TEMP_CSV_FILENAME = "processed_data_for_ludwig.csv" | |
| 32 TEMP_CONFIG_FILENAME = "ludwig_config.yaml" | |
| 33 TEMP_DIR_PREFIX = "ludwig_api_work_" | |
| 34 MODEL_ENCODER_TEMPLATES: Dict[str, Any] = { | |
| 35 "stacked_cnn": "stacked_cnn", | |
| 36 "resnet18": {"type": "resnet", "model_variant": 18}, | |
| 37 "resnet34": {"type": "resnet", "model_variant": 34}, | |
| 38 "resnet50": {"type": "resnet", "model_variant": 50}, | |
| 39 "resnet101": {"type": "resnet", "model_variant": 101}, | |
| 40 "resnet152": {"type": "resnet", "model_variant": 152}, | |
| 41 "resnext50_32x4d": {"type": "resnext", "model_variant": "50_32x4d"}, | |
| 42 "resnext101_32x8d": {"type": "resnext", "model_variant": "101_32x8d"}, | |
| 43 "resnext101_64x4d": {"type": "resnext", "model_variant": "101_64x4d"}, | |
| 44 "resnext152_32x8d": {"type": "resnext", "model_variant": "152_32x8d"}, | |
| 45 "wide_resnet50_2": {"type": "wide_resnet", "model_variant": "50_2"}, | |
| 46 "wide_resnet101_2": {"type": "wide_resnet", "model_variant": "101_2"}, | |
| 47 "wide_resnet103_2": {"type": "wide_resnet", "model_variant": "103_2"}, | |
| 48 "efficientnet_b0": {"type": "efficientnet", "model_variant": "b0"}, | |
| 49 "efficientnet_b1": {"type": "efficientnet", "model_variant": "b1"}, | |
| 50 "efficientnet_b2": {"type": "efficientnet", "model_variant": "b2"}, | |
| 51 "efficientnet_b3": {"type": "efficientnet", "model_variant": "b3"}, | |
| 52 "efficientnet_b4": {"type": "efficientnet", "model_variant": "b4"}, | |
| 53 "efficientnet_b5": {"type": "efficientnet", "model_variant": "b5"}, | |
| 54 "efficientnet_b6": {"type": "efficientnet", "model_variant": "b6"}, | |
| 55 "efficientnet_b7": {"type": "efficientnet", "model_variant": "b7"}, | |
| 56 "efficientnet_v2_s": {"type": "efficientnet", "model_variant": "v2_s"}, | |
| 57 "efficientnet_v2_m": {"type": "efficientnet", "model_variant": "v2_m"}, | |
| 58 "efficientnet_v2_l": {"type": "efficientnet", "model_variant": "v2_l"}, | |
| 59 "regnet_y_400mf": {"type": "regnet", "model_variant": "y_400mf"}, | |
| 60 "regnet_y_800mf": {"type": "regnet", "model_variant": "y_800mf"}, | |
| 61 "regnet_y_1_6gf": {"type": "regnet", "model_variant": "y_1_6gf"}, | |
| 62 "regnet_y_3_2gf": {"type": "regnet", "model_variant": "y_3_2gf"}, | |
| 63 "regnet_y_8gf": {"type": "regnet", "model_variant": "y_8gf"}, | |
| 64 "regnet_y_16gf": {"type": "regnet", "model_variant": "y_16gf"}, | |
| 65 "regnet_y_32gf": {"type": "regnet", "model_variant": "y_32gf"}, | |
| 66 "regnet_y_128gf": {"type": "regnet", "model_variant": "y_128gf"}, | |
| 67 "regnet_x_400mf": {"type": "regnet", "model_variant": "x_400mf"}, | |
| 68 "regnet_x_800mf": {"type": "regnet", "model_variant": "x_800mf"}, | |
| 69 "regnet_x_1_6gf": {"type": "regnet", "model_variant": "x_1_6gf"}, | |
| 70 "regnet_x_3_2gf": {"type": "regnet", "model_variant": "x_3_2gf"}, | |
| 71 "regnet_x_8gf": {"type": "regnet", "model_variant": "x_8gf"}, | |
| 72 "regnet_x_16gf": {"type": "regnet", "model_variant": "x_16gf"}, | |
| 73 "regnet_x_32gf": {"type": "regnet", "model_variant": "x_32gf"}, | |
| 74 "vgg11": {"type": "vgg", "model_variant": 11}, | |
| 75 "vgg11_bn": {"type": "vgg", "model_variant": "11_bn"}, | |
| 76 "vgg13": {"type": "vgg", "model_variant": 13}, | |
| 77 "vgg13_bn": {"type": "vgg", "model_variant": "13_bn"}, | |
| 78 "vgg16": {"type": "vgg", "model_variant": 16}, | |
| 79 "vgg16_bn": {"type": "vgg", "model_variant": "16_bn"}, | |
| 80 "vgg19": {"type": "vgg", "model_variant": 19}, | |
| 81 "vgg19_bn": {"type": "vgg", "model_variant": "19_bn"}, | |
| 82 "shufflenet_v2_x0_5": {"type": "shufflenet_v2", "model_variant": "x0_5"}, | |
| 83 "shufflenet_v2_x1_0": {"type": "shufflenet_v2", "model_variant": "x1_0"}, | |
| 84 "shufflenet_v2_x1_5": {"type": "shufflenet_v2", "model_variant": "x1_5"}, | |
| 85 "shufflenet_v2_x2_0": {"type": "shufflenet_v2", "model_variant": "x2_0"}, | |
| 86 "squeezenet1_0": {"type": "squeezenet", "model_variant": "1_0"}, | |
| 87 "squeezenet1_1": {"type": "squeezenet", "model_variant": "1_1"}, | |
| 88 "swin_t": {"type": "swin_transformer", "model_variant": "t"}, | |
| 89 "swin_s": {"type": "swin_transformer", "model_variant": "s"}, | |
| 90 "swin_b": {"type": "swin_transformer", "model_variant": "b"}, | |
| 91 "swin_v2_t": {"type": "swin_transformer", "model_variant": "v2_t"}, | |
| 92 "swin_v2_s": {"type": "swin_transformer", "model_variant": "v2_s"}, | |
| 93 "swin_v2_b": {"type": "swin_transformer", "model_variant": "v2_b"}, | |
| 94 "vit_b_16": {"type": "vision_transformer", "model_variant": "b_16"}, | |
| 95 "vit_b_32": {"type": "vision_transformer", "model_variant": "b_32"}, | |
| 96 "vit_l_16": {"type": "vision_transformer", "model_variant": "l_16"}, | |
| 97 "vit_l_32": {"type": "vision_transformer", "model_variant": "l_32"}, | |
| 98 "vit_h_14": {"type": "vision_transformer", "model_variant": "h_14"}, | |
| 99 "convnext_tiny": {"type": "convnext", "model_variant": "tiny"}, | |
| 100 "convnext_small": {"type": "convnext", "model_variant": "small"}, | |
| 101 "convnext_base": {"type": "convnext", "model_variant": "base"}, | |
| 102 "convnext_large": {"type": "convnext", "model_variant": "large"}, | |
| 103 "maxvit_t": {"type": "maxvit", "model_variant": "t"}, | |
| 104 "alexnet": {"type": "alexnet"}, | |
| 105 "googlenet": {"type": "googlenet"}, | |
| 106 "inception_v3": {"type": "inception_v3"}, | |
| 107 "mobilenet_v2": {"type": "mobilenet_v2"}, | |
| 108 "mobilenet_v3_large": {"type": "mobilenet_v3_large"}, | |
| 109 "mobilenet_v3_small": {"type": "mobilenet_v3_small"}, | |
| 110 } | |
| 111 METRIC_DISPLAY_NAMES = { | |
| 112 "accuracy": "Accuracy", | |
| 113 "accuracy_micro": "Accuracy-Micro", | |
| 114 "loss": "Loss", | |
| 115 "roc_auc": "ROC-AUC", | |
| 116 "roc_auc_macro": "ROC-AUC-Macro", | |
| 117 "roc_auc_micro": "ROC-AUC-Micro", | |
| 118 "hits_at_k": "Hits at K", | |
| 119 "precision": "Precision", | |
| 120 "recall": "Recall", | |
| 121 "specificity": "Specificity", | |
| 122 "kappa_score": "Cohen's Kappa", | |
| 123 "token_accuracy": "Token Accuracy", | |
| 124 "avg_precision_macro": "Precision-Macro", | |
| 125 "avg_recall_macro": "Recall-Macro", | |
| 126 "avg_f1_score_macro": "F1-score-Macro", | |
| 127 "avg_precision_micro": "Precision-Micro", | |
| 128 "avg_recall_micro": "Recall-Micro", | |
| 129 "avg_f1_score_micro": "F1-score-Micro", | |
| 130 "avg_precision_weighted": "Precision-Weighted", | |
| 131 "avg_recall_weighted": "Recall-Weighted", | |
| 132 "avg_f1_score_weighted": "F1-score-Weighted", | |
| 133 "average_precision_macro": " Precision-Average-Macro", | |
| 134 "average_precision_micro": "Precision-Average-Micro", | |
| 135 "average_precision_samples": "Precision-Average-Samples", | |
| 136 } | |
| 137 | 41 |
| 138 # --- Logging Setup --- | 42 # --- Logging Setup --- |
| 139 logging.basicConfig( | 43 logging.basicConfig( |
| 140 level=logging.INFO, | 44 level=logging.INFO, |
| 141 format="%(asctime)s %(levelname)s %(name)s: %(message)s", | 45 format='%(asctime)s %(levelname)s %(name)s: %(message)s', |
| 142 ) | 46 ) |
| 143 logger = logging.getLogger("ImageLearner") | 47 logger = logging.getLogger("ImageLearner") |
| 144 | |
| 145 | |
| 146 def get_metrics_help_modal() -> str: | |
| 147 modal_html = """ | |
| 148 <div id="metricsHelpModal" class="modal"> | |
| 149 <div class="modal-content"> | |
| 150 <span class="close">×</span> | |
| 151 <h2>Model Evaluation Metrics — Help Guide</h2> | |
| 152 <div class="metrics-guide"> | |
| 153 <h3>1) General Metrics</h3> | |
| 154 <p><strong>Loss:</strong> Measures the difference between predicted and actual values. Lower is better. Often used for optimization during training.</p> | |
| 155 <p><strong>Accuracy:</strong> Proportion of correct predictions among all predictions. Simple but can be misleading for imbalanced datasets.</p> | |
| 156 <p><strong>Micro Accuracy:</strong> Calculates accuracy by summing up all individual true positives and true negatives across all classes, making it suitable for multiclass or multilabel problems.</p> | |
| 157 <p><strong>Token Accuracy:</strong> Measures how often the predicted tokens (e.g., in sequences) match the true tokens. Useful in sequence prediction tasks like NLP.</p> | |
| 158 <h3>2) Precision, Recall & Specificity</h3> | |
| 159 <p><strong>Precision:</strong> Out of all positive predictions, how many were correct. Precision = TP / (TP + FP). Helps when false positives are costly.</p> | |
| 160 <p><strong>Recall (Sensitivity):</strong> Out of all actual positives, how many were predicted correctly. Recall = TP / (TP + FN). Important when missing positives is risky.</p> | |
| 161 <p><strong>Specificity:</strong> True negative rate. Measures how well the model identifies negatives. Specificity = TN / (TN + FP). Useful in medical testing to avoid false alarms.</p> | |
| 162 <h3>3) Macro, Micro, and Weighted Averages</h3> | |
| 163 <p><strong>Macro Precision / Recall / F1:</strong> Averages the metric across all classes, treating each class equally, regardless of class frequency. Best when class sizes are balanced.</p> | |
| 164 <p><strong>Micro Precision / Recall / F1:</strong> Aggregates TP, FP, FN across all classes before computing the metric. Gives a global view and is ideal for class-imbalanced problems.</p> | |
| 165 <p><strong>Weighted Precision / Recall / F1:</strong> Averages each metric across classes, weighted by the number of true instances per class. Balances importance of classes based on frequency.</p> | |
| 166 <h3>4) Average Precision (PR-AUC Variants)</h3> | |
| 167 <p><strong>Average Precision Macro:</strong> Precision-Recall AUC averaged across all classes equally. Useful for balanced multi-class problems.</p> | |
| 168 <p><strong>Average Precision Micro:</strong> Global Precision-Recall AUC using all instances. Best for imbalanced data or multi-label classification.</p> | |
| 169 <p><strong>Average Precision Samples:</strong> Precision-Recall AUC averaged across individual samples (not classes). Ideal for multi-label problems where each sample can belong to multiple classes.</p> | |
| 170 <h3>5) ROC-AUC Variants</h3> | |
| 171 <p><strong>ROC-AUC:</strong> Measures model's ability to distinguish between classes. AUC = 1 is perfect; 0.5 is random guessing. Use for binary classification.</p> | |
| 172 <p><strong>Macro ROC-AUC:</strong> Averages the AUC across all classes equally. Suitable when classes are balanced and of equal importance.</p> | |
| 173 <p><strong>Micro ROC-AUC:</strong> Computes AUC from aggregated predictions across all classes. Useful in multiclass or multilabel settings with imbalance.</p> | |
| 174 <h3>6) Ranking Metrics</h3> | |
| 175 <p><strong>Hits at K:</strong> Measures whether the true label is among the top-K predictions. Common in recommendation systems and retrieval tasks.</p> | |
| 176 <h3>7) Confusion Matrix Stats (Per Class)</h3> | |
| 177 <p><strong>True Positives / Negatives (TP / TN):</strong> Correct predictions for positives and negatives respectively.</p> | |
| 178 <p><strong>False Positives / Negatives (FP / FN):</strong> Incorrect predictions — false alarms and missed detections.</p> | |
| 179 <h3>8) Other Useful Metrics</h3> | |
| 180 <p><strong>Cohen's Kappa:</strong> Measures agreement between predicted and actual values adjusted for chance. Useful for multiclass classification with imbalanced labels.</p> | |
| 181 <p><strong>Matthews Correlation Coefficient (MCC):</strong> Balanced measure of prediction quality that takes into account TP, TN, FP, and FN. Particularly effective for imbalanced datasets.</p> | |
| 182 <h3>9) Metric Recommendations</h3> | |
| 183 <ul> | |
| 184 <li>Use <strong>Accuracy + F1</strong> for balanced data.</li> | |
| 185 <li>Use <strong>Precision, Recall, ROC-AUC</strong> for imbalanced datasets.</li> | |
| 186 <li>Use <strong>Average Precision Micro</strong> for multilabel or class-imbalanced problems.</li> | |
| 187 <li>Use <strong>Macro scores</strong> when all classes should be treated equally.</li> | |
| 188 <li>Use <strong>Weighted scores</strong> when class imbalance should be accounted for without ignoring small classes.</li> | |
| 189 <li>Use <strong>Confusion Matrix stats</strong> to analyze class-wise performance.</li> | |
| 190 <li>Use <strong>Hits at K</strong> for recommendation or ranking-based tasks.</li> | |
| 191 </ul> | |
| 192 </div> | |
| 193 </div> | |
| 194 </div> | |
| 195 """ | |
| 196 modal_css = """ | |
| 197 <style> | |
| 198 .modal { | |
| 199 display: none; | |
| 200 position: fixed; | |
| 201 z-index: 1; | |
| 202 left: 0; | |
| 203 top: 0; | |
| 204 width: 100%; | |
| 205 height: 100%; | |
| 206 overflow: auto; | |
| 207 background-color: rgba(0,0,0,0.4); | |
| 208 } | |
| 209 .modal-content { | |
| 210 background-color: #fefefe; | |
| 211 margin: 15% auto; | |
| 212 padding: 20px; | |
| 213 border: 1px solid #888; | |
| 214 width: 80%; | |
| 215 max-width: 800px; | |
| 216 } | |
| 217 .close { | |
| 218 color: #aaa; | |
| 219 float: right; | |
| 220 font-size: 28px; | |
| 221 font-weight: bold; | |
| 222 } | |
| 223 .close:hover, | |
| 224 .close:focus { | |
| 225 color: black; | |
| 226 text-decoration: none; | |
| 227 cursor: pointer; | |
| 228 } | |
| 229 .metrics-guide h3 { | |
| 230 margin-top: 20px; | |
| 231 } | |
| 232 .metrics-guide p { | |
| 233 margin: 5px 0; | |
| 234 } | |
| 235 .metrics-guide ul { | |
| 236 margin: 10px 0; | |
| 237 padding-left: 20px; | |
| 238 } | |
| 239 </style> | |
| 240 """ | |
| 241 modal_js = """ | |
| 242 <script> | |
| 243 document.addEventListener("DOMContentLoaded", function() { | |
| 244 var modal = document.getElementById("metricsHelpModal"); | |
| 245 var closeBtn = document.getElementsByClassName("close")[0]; | |
| 246 | |
| 247 document.querySelectorAll(".openMetricsHelp").forEach(btn => { | |
| 248 btn.onclick = function() { | |
| 249 modal.style.display = "block"; | |
| 250 }; | |
| 251 }); | |
| 252 | |
| 253 if (closeBtn) { | |
| 254 closeBtn.onclick = function() { | |
| 255 modal.style.display = "none"; | |
| 256 }; | |
| 257 } | |
| 258 | |
| 259 window.onclick = function(event) { | |
| 260 if (event.target == modal) { | |
| 261 modal.style.display = "none"; | |
| 262 } | |
| 263 } | |
| 264 }); | |
| 265 </script> | |
| 266 """ | |
| 267 return modal_css + modal_html + modal_js | |
| 268 | 48 |
| 269 | 49 |
| 270 def format_config_table_html( | 50 def format_config_table_html( |
| 271 config: dict, | 51 config: dict, |
| 272 split_info: Optional[str] = None, | 52 split_info: Optional[str] = None, |
| 273 training_progress: dict = None, | 53 training_progress: dict = None, |
| 274 ) -> str: | 54 ) -> str: |
| 275 display_keys = [ | 55 display_keys = [ |
| 56 "task_type", | |
| 276 "model_name", | 57 "model_name", |
| 277 "epochs", | 58 "epochs", |
| 278 "batch_size", | 59 "batch_size", |
| 279 "fine_tune", | 60 "fine_tune", |
| 280 "use_pretrained", | 61 "use_pretrained", |
| 285 | 66 |
| 286 rows = [] | 67 rows = [] |
| 287 | 68 |
| 288 for key in display_keys: | 69 for key in display_keys: |
| 289 val = config.get(key, "N/A") | 70 val = config.get(key, "N/A") |
| 71 if key == "task_type": | |
| 72 val = val.title() if isinstance(val, str) else val | |
| 290 if key == "batch_size": | 73 if key == "batch_size": |
| 291 if val is not None: | 74 if val is not None: |
| 292 val = int(val) | 75 val = int(val) |
| 293 else: | 76 else: |
| 294 if training_progress: | 77 if training_progress: |
| 346 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>" | 129 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>" |
| 347 f"{val}</td>" | 130 f"{val}</td>" |
| 348 f"</tr>" | 131 f"</tr>" |
| 349 ) | 132 ) |
| 350 | 133 |
| 134 aug_cfg = config.get("augmentation") | |
| 135 if aug_cfg: | |
| 136 types = [str(a.get("type", "")) for a in aug_cfg] | |
| 137 aug_val = ", ".join(types) | |
| 138 rows.append( | |
| 139 "<tr>" | |
| 140 "<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Augmentation</td>" | |
| 141 "<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>" | |
| 142 f"{aug_val}</td>" | |
| 143 "</tr>" | |
| 144 ) | |
| 145 | |
| 351 if split_info: | 146 if split_info: |
| 352 rows.append( | 147 rows.append( |
| 353 f"<tr>" | 148 f"<tr>" |
| 354 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" | 149 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" |
| 355 f"Data Split</td>" | 150 f"Data Split</td>" |
| 369 "Value</th>" | 164 "Value</th>" |
| 370 "</tr></thead><tbody>" + "".join(rows) + "</tbody></table></div><br>" | 165 "</tr></thead><tbody>" + "".join(rows) + "</tbody></table></div><br>" |
| 371 "<p style='text-align: center; font-size: 0.9em;'>" | 166 "<p style='text-align: center; font-size: 0.9em;'>" |
| 372 "Model trained using Ludwig.<br>" | 167 "Model trained using Ludwig.<br>" |
| 373 "If want to learn more about Ludwig default settings," | 168 "If want to learn more about Ludwig default settings," |
| 374 "please check the their <a href='https://ludwig.ai' target='_blank'>" | 169 "please check their <a href='https://ludwig.ai' target='_blank'>" |
| 375 "website(ludwig.ai)</a>." | 170 "website(ludwig.ai)</a>." |
| 376 "</p><hr>" | 171 "</p><hr>" |
| 377 ) | 172 ) |
| 378 | 173 |
| 379 | 174 |
| 380 def detect_output_type(test_stats): | 175 def detect_output_type(test_stats): |
| 381 """Detects if the output type is 'binary' or 'category' based on test statistics.""" | 176 """Detects if the output type is 'binary' or 'category' based on test statistics.""" |
| 382 label_stats = test_stats.get("label", {}) | 177 label_stats = test_stats.get("label", {}) |
| 178 if "mean_squared_error" in label_stats: | |
| 179 return "regression" | |
| 383 per_class = label_stats.get("per_class_stats", {}) | 180 per_class = label_stats.get("per_class_stats", {}) |
| 384 if len(per_class) == 2: | 181 if len(per_class) == 2: |
| 385 return "binary" | 182 return "binary" |
| 386 return "category" | 183 return "category" |
| 387 | 184 |
| 418 "precision": get_last_value(label_stats, "precision"), | 215 "precision": get_last_value(label_stats, "precision"), |
| 419 "recall": get_last_value(label_stats, "recall"), | 216 "recall": get_last_value(label_stats, "recall"), |
| 420 "specificity": get_last_value(label_stats, "specificity"), | 217 "specificity": get_last_value(label_stats, "specificity"), |
| 421 "roc_auc": get_last_value(label_stats, "roc_auc"), | 218 "roc_auc": get_last_value(label_stats, "roc_auc"), |
| 422 } | 219 } |
| 220 elif output_type == "regression": | |
| 221 metrics[split] = { | |
| 222 "loss": get_last_value(label_stats, "loss"), | |
| 223 "mean_absolute_error": get_last_value( | |
| 224 label_stats, "mean_absolute_error" | |
| 225 ), | |
| 226 "mean_absolute_percentage_error": get_last_value( | |
| 227 label_stats, "mean_absolute_percentage_error" | |
| 228 ), | |
| 229 "mean_squared_error": get_last_value(label_stats, "mean_squared_error"), | |
| 230 "root_mean_squared_error": get_last_value( | |
| 231 label_stats, "root_mean_squared_error" | |
| 232 ), | |
| 233 "root_mean_squared_percentage_error": get_last_value( | |
| 234 label_stats, "root_mean_squared_percentage_error" | |
| 235 ), | |
| 236 "r2": get_last_value(label_stats, "r2"), | |
| 237 } | |
| 423 else: | 238 else: |
| 424 metrics[split] = { | 239 metrics[split] = { |
| 425 "accuracy": get_last_value(label_stats, "accuracy"), | 240 "accuracy": get_last_value(label_stats, "accuracy"), |
| 426 "accuracy_micro": get_last_value(label_stats, "accuracy_micro"), | 241 "accuracy_micro": get_last_value(label_stats, "accuracy_micro"), |
| 427 "loss": get_last_value(label_stats, "loss"), | 242 "loss": get_last_value(label_stats, "loss"), |
| 563 ) | 378 ) |
| 564 html += "</tbody></table></div><br>" | 379 html += "</tbody></table></div><br>" |
| 565 return html | 380 return html |
| 566 | 381 |
| 567 | 382 |
| 568 def format_test_merged_stats_table_html(test_metrics: Dict[str, Optional[float]]) -> str: | 383 def format_test_merged_stats_table_html( |
| 384 test_metrics: Dict[str, Optional[float]], | |
| 385 ) -> str: | |
| 569 """Formats an HTML table for test metrics.""" | 386 """Formats an HTML table for test metrics.""" |
| 570 rows = [] | 387 rows = [] |
| 571 for key in sorted(test_metrics.keys()): | 388 for key in sorted(test_metrics.keys()): |
| 572 display_name = METRIC_DISPLAY_NAMES.get(key, key.replace("_", " ").title()) | 389 display_name = METRIC_DISPLAY_NAMES.get(key, key.replace("_", " ").title()) |
| 573 value = test_metrics[key] | 390 value = test_metrics[key] |
| 594 "padding: 10px; border: 1px solid #ccc; text-align: center; " | 411 "padding: 10px; border: 1px solid #ccc; text-align: center; " |
| 595 "white-space: nowrap;", | 412 "white-space: nowrap;", |
| 596 ) | 413 ) |
| 597 html += "</tbody></table></div><br>" | 414 html += "</tbody></table></div><br>" |
| 598 return html | 415 return html |
| 599 | |
| 600 | |
| 601 def build_tabbed_html(metrics_html: str, train_val_html: str, test_html: str) -> str: | |
| 602 return f""" | |
| 603 <style> | |
| 604 .tabs {{ | |
| 605 display: flex; | |
| 606 border-bottom: 2px solid #ccc; | |
| 607 margin-bottom: 1rem; | |
| 608 }} | |
| 609 .tab {{ | |
| 610 padding: 10px 20px; | |
| 611 cursor: pointer; | |
| 612 border: 1px solid #ccc; | |
| 613 border-bottom: none; | |
| 614 background: #f9f9f9; | |
| 615 margin-right: 5px; | |
| 616 border-top-left-radius: 8px; | |
| 617 border-top-right-radius: 8px; | |
| 618 }} | |
| 619 .tab.active {{ | |
| 620 background: white; | |
| 621 font-weight: bold; | |
| 622 }} | |
| 623 .tab-content {{ | |
| 624 display: none; | |
| 625 padding: 20px; | |
| 626 border: 1px solid #ccc; | |
| 627 border-top: none; | |
| 628 }} | |
| 629 .tab-content.active {{ | |
| 630 display: block; | |
| 631 }} | |
| 632 </style> | |
| 633 <div class="tabs"> | |
| 634 <div class="tab active" onclick="showTab('metrics')"> Config & Results Summary</div> | |
| 635 <div class="tab" onclick="showTab('trainval')"> Train/Validation Results</div> | |
| 636 <div class="tab" onclick="showTab('test')"> Test Results</div> | |
| 637 </div> | |
| 638 <div id="metrics" class="tab-content active"> | |
| 639 {metrics_html} | |
| 640 </div> | |
| 641 <div id="trainval" class="tab-content"> | |
| 642 {train_val_html} | |
| 643 </div> | |
| 644 <div id="test" class="tab-content"> | |
| 645 {test_html} | |
| 646 </div> | |
| 647 <script> | |
| 648 function showTab(id) {{ | |
| 649 document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active')); | |
| 650 document.querySelectorAll('.tab').forEach(el => el.classList.remove('active')); | |
| 651 document.getElementById(id).classList.add('active'); | |
| 652 document.querySelector(`.tab[onclick*="${{id}}"]`).classList.add('active'); | |
| 653 }} | |
| 654 </script> | |
| 655 """ | |
| 656 | 416 |
| 657 | 417 |
| 658 def split_data_0_2( | 418 def split_data_0_2( |
| 659 df: pd.DataFrame, | 419 df: pd.DataFrame, |
| 660 split_column: str, | 420 split_column: str, |
| 725 output_dir: Path, | 485 output_dir: Path, |
| 726 random_seed: int, | 486 random_seed: int, |
| 727 ) -> None: | 487 ) -> None: |
| 728 ... | 488 ... |
| 729 | 489 |
| 730 def generate_plots( | 490 def generate_plots(self, output_dir: Path) -> None: |
| 731 self, | |
| 732 output_dir: Path | |
| 733 ) -> None: | |
| 734 ... | 491 ... |
| 735 | 492 |
| 736 def generate_html_report( | 493 def generate_html_report( |
| 737 self, | 494 self, |
| 738 title: str, | 495 title: str, |
| 739 output_dir: str | 496 output_dir: str, |
| 497 config: Dict[str, Any], | |
| 498 split_info: str, | |
| 740 ) -> Path: | 499 ) -> Path: |
| 741 ... | 500 ... |
| 742 | 501 |
| 743 | 502 |
| 744 class LudwigDirectBackend: | 503 class LudwigDirectBackend: |
| 747 def prepare_config( | 506 def prepare_config( |
| 748 self, | 507 self, |
| 749 config_params: Dict[str, Any], | 508 config_params: Dict[str, Any], |
| 750 split_config: Dict[str, Any], | 509 split_config: Dict[str, Any], |
| 751 ) -> str: | 510 ) -> str: |
| 752 """Build and serialize the Ludwig YAML configuration.""" | |
| 753 logger.info("LudwigDirectBackend: Preparing YAML configuration.") | 511 logger.info("LudwigDirectBackend: Preparing YAML configuration.") |
| 754 | 512 |
| 755 model_name = config_params.get("model_name", "resnet18") | 513 model_name = config_params.get("model_name", "resnet18") |
| 756 use_pretrained = config_params.get("use_pretrained", False) | 514 use_pretrained = config_params.get("use_pretrained", False) |
| 757 fine_tune = config_params.get("fine_tune", False) | 515 fine_tune = config_params.get("fine_tune", False) |
| 516 if use_pretrained: | |
| 517 trainable = bool(fine_tune) | |
| 518 else: | |
| 519 trainable = True | |
| 758 epochs = config_params.get("epochs", 10) | 520 epochs = config_params.get("epochs", 10) |
| 759 batch_size = config_params.get("batch_size") | 521 batch_size = config_params.get("batch_size") |
| 760 num_processes = config_params.get("preprocessing_num_processes", 1) | 522 num_processes = config_params.get("preprocessing_num_processes", 1) |
| 761 early_stop = config_params.get("early_stop", None) | 523 early_stop = config_params.get("early_stop", None) |
| 762 learning_rate = config_params.get("learning_rate") | 524 learning_rate = config_params.get("learning_rate") |
| 763 learning_rate = "auto" if learning_rate is None else float(learning_rate) | 525 learning_rate = "auto" if learning_rate is None else float(learning_rate) |
| 764 trainable = fine_tune or (not use_pretrained) | |
| 765 if not use_pretrained and not trainable: | |
| 766 logger.warning("trainable=False; use_pretrained=False is ignored.") | |
| 767 logger.warning("Setting trainable=True to train the model from scratch.") | |
| 768 trainable = True | |
| 769 raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name) | 526 raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name) |
| 770 if isinstance(raw_encoder, dict): | 527 if isinstance(raw_encoder, dict): |
| 771 encoder_config = { | 528 encoder_config = { |
| 772 **raw_encoder, | 529 **raw_encoder, |
| 773 "use_pretrained": use_pretrained, | 530 "use_pretrained": use_pretrained, |
| 777 encoder_config = {"type": raw_encoder} | 534 encoder_config = {"type": raw_encoder} |
| 778 | 535 |
| 779 batch_size_cfg = batch_size or "auto" | 536 batch_size_cfg = batch_size or "auto" |
| 780 | 537 |
| 781 label_column_path = config_params.get("label_column_data_path") | 538 label_column_path = config_params.get("label_column_data_path") |
| 539 label_series = None | |
| 782 if label_column_path is not None and Path(label_column_path).exists(): | 540 if label_column_path is not None and Path(label_column_path).exists(): |
| 783 try: | 541 try: |
| 784 label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME] | 542 label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME] |
| 785 num_unique_labels = label_series.nunique() | |
| 786 except Exception as e: | 543 except Exception as e: |
| 787 logger.warning( | 544 logger.warning(f"Could not read label column for task detection: {e}") |
| 788 f"Could not determine label cardinality, defaulting to 'binary': {e}" | 545 |
| 789 ) | 546 if ( |
| 790 num_unique_labels = 2 | 547 label_series is not None |
| 548 and ptypes.is_numeric_dtype(label_series.dtype) | |
| 549 and label_series.nunique() > 10 | |
| 550 ): | |
| 551 task_type = "regression" | |
| 791 else: | 552 else: |
| 792 logger.warning( | 553 task_type = "classification" |
| 793 "label_column_data_path not provided, defaulting to 'binary'" | 554 |
| 794 ) | 555 config_params["task_type"] = task_type |
| 795 num_unique_labels = 2 | 556 |
| 796 | 557 image_feat: Dict[str, Any] = { |
| 797 output_type = "binary" if num_unique_labels == 2 else "category" | 558 "name": IMAGE_PATH_COLUMN_NAME, |
| 559 "type": "image", | |
| 560 "encoder": encoder_config, | |
| 561 } | |
| 562 if config_params.get("augmentation") is not None: | |
| 563 image_feat["augmentation"] = config_params["augmentation"] | |
| 564 | |
| 565 if task_type == "regression": | |
| 566 output_feat = { | |
| 567 "name": LABEL_COLUMN_NAME, | |
| 568 "type": "number", | |
| 569 "decoder": {"type": "regressor"}, | |
| 570 "loss": {"type": "mean_squared_error"}, | |
| 571 "evaluation": { | |
| 572 "metrics": [ | |
| 573 "mean_squared_error", | |
| 574 "mean_absolute_error", | |
| 575 "r2", | |
| 576 ] | |
| 577 }, | |
| 578 } | |
| 579 val_metric = config_params.get("validation_metric", "mean_squared_error") | |
| 580 | |
| 581 else: | |
| 582 num_unique_labels = ( | |
| 583 label_series.nunique() if label_series is not None else 2 | |
| 584 ) | |
| 585 output_type = "binary" if num_unique_labels == 2 else "category" | |
| 586 output_feat = {"name": LABEL_COLUMN_NAME, "type": output_type} | |
| 587 val_metric = None | |
| 798 | 588 |
| 799 conf: Dict[str, Any] = { | 589 conf: Dict[str, Any] = { |
| 800 "model_type": "ecd", | 590 "model_type": "ecd", |
| 801 "input_features": [ | 591 "input_features": [image_feat], |
| 802 { | 592 "output_features": [output_feat], |
| 803 "name": IMAGE_PATH_COLUMN_NAME, | |
| 804 "type": "image", | |
| 805 "encoder": encoder_config, | |
| 806 } | |
| 807 ], | |
| 808 "output_features": [{"name": LABEL_COLUMN_NAME, "type": output_type}], | |
| 809 "combiner": {"type": "concat"}, | 593 "combiner": {"type": "concat"}, |
| 810 "trainer": { | 594 "trainer": { |
| 811 "epochs": epochs, | 595 "epochs": epochs, |
| 812 "early_stop": early_stop, | 596 "early_stop": early_stop, |
| 813 "batch_size": batch_size_cfg, | 597 "batch_size": batch_size_cfg, |
| 814 "learning_rate": learning_rate, | 598 "learning_rate": learning_rate, |
| 599 # only set validation_metric for regression | |
| 600 **({"validation_metric": val_metric} if val_metric else {}), | |
| 815 }, | 601 }, |
| 816 "preprocessing": { | 602 "preprocessing": { |
| 817 "split": split_config, | 603 "split": split_config, |
| 818 "num_processes": num_processes, | 604 "num_processes": num_processes, |
| 819 "in_memory": False, | 605 "in_memory": False, |
| 874 "LudwigDirectBackend: Experiment execution error.", | 660 "LudwigDirectBackend: Experiment execution error.", |
| 875 exc_info=True, | 661 exc_info=True, |
| 876 ) | 662 ) |
| 877 raise | 663 raise |
| 878 | 664 |
| 879 def get_training_process(self, output_dir) -> float: | 665 def get_training_process(self, output_dir) -> Optional[Dict[str, Any]]: |
| 880 """Retrieve the learning rate used in the most recent Ludwig run.""" | 666 """Retrieve the learning rate used in the most recent Ludwig run.""" |
| 881 output_dir = Path(output_dir) | 667 output_dir = Path(output_dir) |
| 882 exp_dirs = sorted( | 668 exp_dirs = sorted( |
| 883 output_dir.glob("experiment_run*"), | 669 output_dir.glob("experiment_run*"), |
| 884 key=lambda p: p.stat().st_mtime, | 670 key=lambda p: p.stat().st_mtime, |
| 998 stats = json.load(f) | 784 stats = json.load(f) |
| 999 output_feature = next(iter(stats.keys()), "") | 785 output_feature = next(iter(stats.keys()), "") |
| 1000 | 786 |
| 1001 viz_registry = get_visualizations_registry() | 787 viz_registry = get_visualizations_registry() |
| 1002 for viz_name, viz_func in viz_registry.items(): | 788 for viz_name, viz_func in viz_registry.items(): |
| 1003 viz_dir_plot = None | |
| 1004 if viz_name in train_plots: | 789 if viz_name in train_plots: |
| 1005 viz_dir_plot = train_viz | 790 viz_dir_plot = train_viz |
| 1006 elif viz_name in test_plots: | 791 elif viz_name in test_plots: |
| 1007 viz_dir_plot = test_viz | 792 viz_dir_plot = test_viz |
| 793 else: | |
| 794 continue | |
| 1008 | 795 |
| 1009 try: | 796 try: |
| 1010 viz_func( | 797 viz_func( |
| 1011 training_statistics=[training_stats] if training_stats else [], | 798 training_statistics=[training_stats] if training_stats else [], |
| 1012 test_statistics=[test_stats] if test_stats else [], | 799 test_statistics=[test_stats] if test_stats else [], |
| 1038 """Assemble an HTML report from visualizations under train_val/ and test/ folders.""" | 825 """Assemble an HTML report from visualizations under train_val/ and test/ folders.""" |
| 1039 cwd = Path.cwd() | 826 cwd = Path.cwd() |
| 1040 report_name = title.lower().replace(" ", "_") + "_report.html" | 827 report_name = title.lower().replace(" ", "_") + "_report.html" |
| 1041 report_path = cwd / report_name | 828 report_path = cwd / report_name |
| 1042 output_dir = Path(output_dir) | 829 output_dir = Path(output_dir) |
| 830 output_type = None | |
| 1043 | 831 |
| 1044 exp_dirs = sorted( | 832 exp_dirs = sorted( |
| 1045 output_dir.glob("experiment_run*"), | 833 output_dir.glob("experiment_run*"), |
| 1046 key=lambda p: p.stat().st_mtime, | 834 key=lambda p: p.stat().st_mtime, |
| 1047 ) | 835 ) |
| 1057 html += f"<h1>{title}</h1>" | 845 html += f"<h1>{title}</h1>" |
| 1058 | 846 |
| 1059 metrics_html = "" | 847 metrics_html = "" |
| 1060 train_val_metrics_html = "" | 848 train_val_metrics_html = "" |
| 1061 test_metrics_html = "" | 849 test_metrics_html = "" |
| 1062 | |
| 1063 try: | 850 try: |
| 1064 train_stats_path = exp_dir / "training_statistics.json" | 851 train_stats_path = exp_dir / "training_statistics.json" |
| 1065 test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME | 852 test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME |
| 1066 if train_stats_path.exists() and test_stats_path.exists(): | 853 if train_stats_path.exists() and test_stats_path.exists(): |
| 1067 with open(train_stats_path) as f: | 854 with open(train_stats_path) as f: |
| 1068 train_stats = json.load(f) | 855 train_stats = json.load(f) |
| 1069 with open(test_stats_path) as f: | 856 with open(test_stats_path) as f: |
| 1070 test_stats = json.load(f) | 857 test_stats = json.load(f) |
| 1071 output_type = detect_output_type(test_stats) | 858 output_type = detect_output_type(test_stats) |
| 1072 all_metrics = extract_metrics_from_json( | |
| 1073 train_stats, | |
| 1074 test_stats, | |
| 1075 output_type, | |
| 1076 ) | |
| 1077 metrics_html = format_stats_table_html(train_stats, test_stats) | 859 metrics_html = format_stats_table_html(train_stats, test_stats) |
| 1078 train_val_metrics_html = format_train_val_stats_table_html( | 860 train_val_metrics_html = format_train_val_stats_table_html( |
| 1079 train_stats, | 861 train_stats, test_stats |
| 1080 test_stats, | |
| 1081 ) | 862 ) |
| 1082 test_metrics_html = format_test_merged_stats_table_html( | 863 test_metrics_html = format_test_merged_stats_table_html( |
| 1083 all_metrics["test"], | 864 extract_metrics_from_json(train_stats, test_stats, output_type)[ |
| 865 "test" | |
| 866 ] | |
| 1084 ) | 867 ) |
| 1085 except Exception as e: | 868 except Exception as e: |
| 1086 logger.warning( | 869 logger.warning( |
| 1087 f"Could not load stats for HTML report: {type(e).__name__}: {e}" | 870 f"Could not load stats for HTML report: {type(e).__name__}: {e}" |
| 1088 ) | 871 ) |
| 1089 | 872 |
| 1090 config_html = "" | 873 config_html = "" |
| 1091 training_progress = self.get_training_process(output_dir) | 874 training_progress = self.get_training_process(output_dir) |
| 1092 try: | 875 try: |
| 1093 config_html = format_config_table_html(config, split_info, training_progress) | 876 config_html = format_config_table_html( |
| 877 config, split_info, training_progress | |
| 878 ) | |
| 1094 except Exception as e: | 879 except Exception as e: |
| 1095 logger.warning(f"Could not load config for HTML report: {e}") | 880 logger.warning(f"Could not load config for HTML report: {e}") |
| 1096 | 881 |
| 1097 def render_img_section(title: str, dir_path: Path, output_type: str = None) -> str: | 882 def render_img_section( |
| 883 title: str, dir_path: Path, output_type: str = None | |
| 884 ) -> str: | |
| 1098 if not dir_path.exists(): | 885 if not dir_path.exists(): |
| 1099 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>" | 886 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>" |
| 1100 | 887 |
| 1101 imgs = list(dir_path.glob("*.png")) | 888 imgs = list(dir_path.glob("*.png")) |
| 1102 if not imgs: | 889 if not imgs: |
| 1139 img_names = {img.name: img for img in imgs if img.name not in unwanted} | 926 img_names = {img.name: img for img in imgs if img.name not in unwanted} |
| 1140 ordered_imgs = [ | 927 ordered_imgs = [ |
| 1141 img_names[fname] for fname in display_order if fname in img_names | 928 img_names[fname] for fname in display_order if fname in img_names |
| 1142 ] | 929 ] |
| 1143 remaining = sorted( | 930 remaining = sorted( |
| 1144 [ | 931 [img for img in img_names.values() if img.name not in display_order] |
| 1145 img | |
| 1146 for img in img_names.values() | |
| 1147 if img.name not in display_order | |
| 1148 ] | |
| 1149 ) | 932 ) |
| 1150 imgs = ordered_imgs + remaining | 933 imgs = ordered_imgs + remaining |
| 1151 | 934 |
| 1152 else: | 935 else: |
| 1153 if output_type == "category": | 936 if output_type == "category": |
| 1171 f"</div>" | 954 f"</div>" |
| 1172 ) | 955 ) |
| 1173 section_html += "</div>" | 956 section_html += "</div>" |
| 1174 return section_html | 957 return section_html |
| 1175 | 958 |
| 1176 button_html = """ | 959 tab1_content = config_html + metrics_html |
| 1177 <button class="help-modal-btn openMetricsHelp">Model Evaluation Metrics — Help Guide</button> | 960 |
| 1178 <br><br> | 961 tab2_content = train_val_metrics_html + render_img_section( |
| 1179 <style> | 962 "Training & Validation Visualizations", train_viz_dir |
| 1180 .help-modal-btn { | 963 ) |
| 1181 background-color: #17623b; | 964 |
| 1182 color: #fff; | 965 # --- Predictions vs Ground Truth table --- |
| 1183 border: none; | 966 preds_section = "" |
| 1184 border-radius: 24px; | 967 parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME |
| 1185 padding: 10px 28px; | 968 if parquet_path.exists(): |
| 1186 font-size: 1.1rem; | 969 try: |
| 1187 font-weight: bold; | 970 # 1) load predictions from Parquet |
| 1188 letter-spacing: 0.03em; | 971 df_preds = pd.read_parquet(parquet_path).reset_index(drop=True) |
| 1189 cursor: pointer; | 972 # assume the column containing your model's prediction is named "prediction" |
| 1190 transition: background 0.2s, box-shadow 0.2s; | 973 # or contains that substring: |
| 1191 box-shadow: 0 2px 8px rgba(23,98,59,0.07); | 974 pred_col = next( |
| 1192 } | 975 (c for c in df_preds.columns if "prediction" in c.lower()), |
| 1193 .help-modal-btn:hover, .help-modal-btn:focus { | 976 None, |
| 1194 background-color: #21895e; | 977 ) |
| 1195 outline: none; | 978 if pred_col is None: |
| 1196 box-shadow: 0 4px 16px rgba(23,98,59,0.14); | 979 raise ValueError("No prediction column found in Parquet output") |
| 1197 } | 980 df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"}) |
| 1198 </style> | 981 |
| 1199 """ | 982 # 2) load ground truth for the test split from prepared CSV |
| 1200 tab1_content = button_html + config_html + metrics_html | 983 df_all = pd.read_csv(config["label_column_data_path"]) |
| 1201 tab2_content = ( | 984 df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][ |
| 1202 button_html | 985 LABEL_COLUMN_NAME |
| 1203 + train_val_metrics_html | 986 ].reset_index(drop=True) |
| 1204 + render_img_section("Training & Validation Visualizations", train_viz_dir) | 987 |
| 1205 ) | 988 # 3) concatenate side‐by‐side |
| 989 df_table = pd.concat([df_gt, df_pred], axis=1) | |
| 990 df_table.columns = [LABEL_COLUMN_NAME, "prediction"] | |
| 991 | |
| 992 # 4) render as HTML | |
| 993 preds_html = df_table.to_html(index=False, classes="predictions-table") | |
| 994 preds_section = ( | |
| 995 "<h2 style='text-align: center;'>Predictions vs. Ground Truth</h2>" | |
| 996 "<div style='overflow-x:auto; margin-bottom:20px;'>" | |
| 997 + preds_html | |
| 998 + "</div>" | |
| 999 ) | |
| 1000 except Exception as e: | |
| 1001 logger.warning(f"Could not build Predictions vs GT table: {e}") | |
| 1002 # Test tab = Metrics + Preds table + Visualizations | |
| 1003 | |
| 1206 tab3_content = ( | 1004 tab3_content = ( |
| 1207 button_html | 1005 test_metrics_html |
| 1208 + test_metrics_html | 1006 + preds_section |
| 1209 + render_img_section("Test Visualizations", test_viz_dir, output_type) | 1007 + render_img_section("Test Visualizations", test_viz_dir, output_type) |
| 1210 ) | 1008 ) |
| 1211 | 1009 |
| 1010 # assemble the tabs and help modal | |
| 1212 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) | 1011 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) |
| 1213 modal_html = get_metrics_help_modal() | 1012 modal_html = get_metrics_help_modal() |
| 1214 html += tabbed_html + modal_html | 1013 html += tabbed_html + modal_html + get_html_closing() |
| 1215 html += get_html_closing() | |
| 1216 | 1014 |
| 1217 try: | 1015 try: |
| 1218 with open(report_path, "w") as f: | 1016 with open(report_path, "w") as f: |
| 1219 f.write(html) | 1017 f.write(html) |
| 1220 logger.info(f"HTML report generated at: {report_path}") | 1018 logger.info(f"HTML report generated at: {report_path}") |
| 1261 logger.info("Image extraction complete.") | 1059 logger.info("Image extraction complete.") |
| 1262 except Exception: | 1060 except Exception: |
| 1263 logger.error("Error extracting zip file", exc_info=True) | 1061 logger.error("Error extracting zip file", exc_info=True) |
| 1264 raise | 1062 raise |
| 1265 | 1063 |
| 1266 def _prepare_data(self) -> Tuple[Path, Dict[str, Any]]: | 1064 def _prepare_data(self) -> Tuple[Path, Dict[str, Any], str]: |
| 1267 """Load CSV, update image paths, handle splits, and write prepared CSV.""" | 1065 """Load CSV, update image paths, handle splits, and write prepared CSV.""" |
| 1268 if not self.temp_dir or not self.image_extract_dir: | 1066 if not self.temp_dir or not self.image_extract_dir: |
| 1269 raise RuntimeError("Temp dirs not initialized before data prep.") | 1067 raise RuntimeError("Temp dirs not initialized before data prep.") |
| 1270 | 1068 |
| 1271 try: | 1069 try: |
| 1300 f"No split column in CSV. Used random split: " | 1098 f"No split column in CSV. Used random split: " |
| 1301 f"{[int(p * 100) for p in self.args.split_probabilities]}% " | 1099 f"{[int(p * 100) for p in self.args.split_probabilities]}% " |
| 1302 f"for train/val/test." | 1100 f"for train/val/test." |
| 1303 ) | 1101 ) |
| 1304 | 1102 |
| 1305 final_csv = TEMP_CSV_FILENAME | 1103 final_csv = self.temp_dir / TEMP_CSV_FILENAME |
| 1306 try: | 1104 try: |
| 1105 | |
| 1307 df.to_csv(final_csv, index=False) | 1106 df.to_csv(final_csv, index=False) |
| 1308 logger.info(f"Saved prepared data to {final_csv}") | 1107 logger.info(f"Saved prepared data to {final_csv}") |
| 1309 except Exception: | 1108 except Exception: |
| 1310 logger.error("Error saving prepared CSV", exc_info=True) | 1109 logger.error("Error saving prepared CSV", exc_info=True) |
| 1311 raise | 1110 raise |
| 1312 | 1111 |
| 1313 return final_csv, split_config, split_info | 1112 return final_csv, split_config, split_info |
| 1314 | 1113 |
| 1315 def _process_fixed_split(self, df: pd.DataFrame) -> Dict[str, Any]: | 1114 def _process_fixed_split( |
| 1115 self, df: pd.DataFrame | |
| 1116 ) -> Tuple[pd.DataFrame, Dict[str, Any], str]: | |
| 1316 """Process a fixed split column (0=train,1=val,2=test).""" | 1117 """Process a fixed split column (0=train,1=val,2=test).""" |
| 1317 logger.info(f"Fixed split column '{SPLIT_COLUMN_NAME}' detected.") | 1118 logger.info(f"Fixed split column '{SPLIT_COLUMN_NAME}' detected.") |
| 1318 try: | 1119 try: |
| 1319 col = df[SPLIT_COLUMN_NAME] | 1120 col = df[SPLIT_COLUMN_NAME] |
| 1320 df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype( | 1121 df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype( |
| 1382 "split_probabilities": self.args.split_probabilities, | 1183 "split_probabilities": self.args.split_probabilities, |
| 1383 "learning_rate": self.args.learning_rate, | 1184 "learning_rate": self.args.learning_rate, |
| 1384 "random_seed": self.args.random_seed, | 1185 "random_seed": self.args.random_seed, |
| 1385 "early_stop": self.args.early_stop, | 1186 "early_stop": self.args.early_stop, |
| 1386 "label_column_data_path": csv_path, | 1187 "label_column_data_path": csv_path, |
| 1188 "augmentation": self.args.augmentation, | |
| 1387 } | 1189 } |
| 1388 yaml_str = self.backend.prepare_config(backend_args, split_cfg) | 1190 yaml_str = self.backend.prepare_config(backend_args, split_cfg) |
| 1389 | 1191 |
| 1390 config_file = self.temp_dir / TEMP_CONFIG_FILENAME | 1192 config_file = self.temp_dir / TEMP_CONFIG_FILENAME |
| 1391 config_file.write_text(yaml_str) | 1193 config_file.write_text(yaml_str) |
| 1420 return float(s) | 1222 return float(s) |
| 1421 except (TypeError, ValueError): | 1223 except (TypeError, ValueError): |
| 1422 return None | 1224 return None |
| 1423 | 1225 |
| 1424 | 1226 |
| 1227 def aug_parse(aug_string: str): | |
| 1228 """ | |
| 1229 Parse comma-separated augmentation keys into Ludwig augmentation dicts. | |
| 1230 Raises ValueError on unknown key. | |
| 1231 """ | |
| 1232 mapping = { | |
| 1233 "random_horizontal_flip": {"type": "random_horizontal_flip"}, | |
| 1234 "random_vertical_flip": {"type": "random_vertical_flip"}, | |
| 1235 "random_rotate": {"type": "random_rotate", "degree": 10}, | |
| 1236 "random_blur": {"type": "random_blur", "kernel_size": 3}, | |
| 1237 "random_brightness": {"type": "random_brightness", "min": 0.5, "max": 2.0}, | |
| 1238 "random_contrast": {"type": "random_contrast", "min": 0.5, "max": 2.0}, | |
| 1239 } | |
| 1240 aug_list = [] | |
| 1241 for tok in aug_string.split(","): | |
| 1242 key = tok.strip() | |
| 1243 if key not in mapping: | |
| 1244 valid = ", ".join(mapping.keys()) | |
| 1245 raise ValueError(f"Unknown augmentation '{key}'. Valid choices: {valid}") | |
| 1246 aug_list.append(mapping[key]) | |
| 1247 return aug_list | |
| 1248 | |
| 1249 | |
| 1425 class SplitProbAction(argparse.Action): | 1250 class SplitProbAction(argparse.Action): |
| 1426 def __call__(self, parser, namespace, values, option_string=None): | 1251 def __call__(self, parser, namespace, values, option_string=None): |
| 1427 train, val, test = values | 1252 train, val, test = values |
| 1428 total = train + val + test | 1253 total = train + val + test |
| 1429 if abs(total - 1.0) > 1e-6: | 1254 if abs(total - 1.0) > 1e-6: |
| 1506 type=float, | 1331 type=float, |
| 1507 nargs=3, | 1332 nargs=3, |
| 1508 metavar=("train", "val", "test"), | 1333 metavar=("train", "val", "test"), |
| 1509 action=SplitProbAction, | 1334 action=SplitProbAction, |
| 1510 default=[0.7, 0.1, 0.2], | 1335 default=[0.7, 0.1, 0.2], |
| 1511 help="Random split proportions (e.g., 0.7 0.1 0.2). Only used if no split column.", | 1336 help=( |
| 1337 "Random split proportions (e.g., 0.7 0.1 0.2)." | |
| 1338 "Only used if no split column." | |
| 1339 ), | |
| 1512 ) | 1340 ) |
| 1513 parser.add_argument( | 1341 parser.add_argument( |
| 1514 "--random-seed", | 1342 "--random-seed", |
| 1515 type=int, | 1343 type=int, |
| 1516 default=42, | 1344 default=42, |
| 1519 parser.add_argument( | 1347 parser.add_argument( |
| 1520 "--learning-rate", | 1348 "--learning-rate", |
| 1521 type=parse_learning_rate, | 1349 type=parse_learning_rate, |
| 1522 default=None, | 1350 default=None, |
| 1523 help="Learning rate. If not provided, Ludwig will auto-select it.", | 1351 help="Learning rate. If not provided, Ludwig will auto-select it.", |
| 1352 ) | |
| 1353 parser.add_argument( | |
| 1354 "--augmentation", | |
| 1355 type=str, | |
| 1356 default=None, | |
| 1357 help=( | |
| 1358 "Comma-separated list (in order) of any of: " | |
| 1359 "random_horizontal_flip, random_vertical_flip, random_rotate, " | |
| 1360 "random_blur, random_brightness, random_contrast. " | |
| 1361 "E.g. --augmentation random_horizontal_flip,random_rotate" | |
| 1362 ), | |
| 1524 ) | 1363 ) |
| 1525 | 1364 |
| 1526 args = parser.parse_args() | 1365 args = parser.parse_args() |
| 1527 | 1366 |
| 1528 if not 0.0 <= args.validation_size <= 1.0: | 1367 if not 0.0 <= args.validation_size <= 1.0: |
| 1529 parser.error("validation-size must be between 0.0 and 1.0") | 1368 parser.error("validation-size must be between 0.0 and 1.0") |
| 1530 if not args.csv_file.is_file(): | 1369 if not args.csv_file.is_file(): |
| 1531 parser.error(f"CSV not found: {args.csv_file}") | 1370 parser.error(f"CSV not found: {args.csv_file}") |
| 1532 if not args.image_zip.is_file(): | 1371 if not args.image_zip.is_file(): |
| 1533 parser.error(f"ZIP not found: {args.image_zip}") | 1372 parser.error(f"ZIP not found: {args.image_zip}") |
| 1373 if args.augmentation is not None: | |
| 1374 try: | |
| 1375 augmentation_setup = aug_parse(args.augmentation) | |
| 1376 setattr(args, "augmentation", augmentation_setup) | |
| 1377 except ValueError as e: | |
| 1378 parser.error(str(e)) | |
| 1534 | 1379 |
| 1535 backend_instance = LudwigDirectBackend() | 1380 backend_instance = LudwigDirectBackend() |
| 1536 orchestrator = WorkflowOrchestrator(args, backend_instance) | 1381 orchestrator = WorkflowOrchestrator(args, backend_instance) |
| 1537 | 1382 |
| 1538 exit_code = 0 | 1383 exit_code = 0 |
