comparison: image_learner_cli.py @ 2:186424a7eca7 (draft)

planemo upload for repository https://github.com/goeckslab/gleam.git commit 91fa4aba245520fc0680088a07cead66bcfd4ed2
author goeckslab
date Thu, 03 Jul 2025 20:43:24 +0000
parents 39202fe5cf97
children 09904b1f61f5
comparing 1:39202fe5cf97 with 2:186424a7eca7
@@ -1 +1 @@
-#!/usr/bin/env python3
 import argparse
 import json
 import logging
 import os
 import shutil
@@ -9 +8 @@
 import zipfile
 from pathlib import Path
 from typing import Any, Dict, Optional, Protocol, Tuple

 import pandas as pd
+import pandas.api.types as ptypes
 import yaml
+from constants import (
+    IMAGE_PATH_COLUMN_NAME,
+    LABEL_COLUMN_NAME,
+    METRIC_DISPLAY_NAMES,
+    MODEL_ENCODER_TEMPLATES,
+    SPLIT_COLUMN_NAME,
+    TEMP_CONFIG_FILENAME,
+    TEMP_CSV_FILENAME,
+    TEMP_DIR_PREFIX
+)
 from ludwig.globals import (
     DESCRIPTION_FILE_NAME,
     PREDICTIONS_PARQUET_FILE_NAME,
     TEST_STATISTICS_FILE_NAME,
     TRAIN_SET_METADATA_FILE_NAME,
 )
 from ludwig.utils.data_utils import get_split_path
 from ludwig.visualize import get_visualizations_registry
 from sklearn.model_selection import train_test_split
-from utils import encode_image_to_base64, get_html_closing, get_html_template
-
-# --- Constants ---
-SPLIT_COLUMN_NAME = "split"
-LABEL_COLUMN_NAME = "label"
-IMAGE_PATH_COLUMN_NAME = "image_path"
-DEFAULT_SPLIT_PROBABILITIES = [0.7, 0.1, 0.2]
-TEMP_CSV_FILENAME = "processed_data_for_ludwig.csv"
-TEMP_CONFIG_FILENAME = "ludwig_config.yaml"
-TEMP_DIR_PREFIX = "ludwig_api_work_"
-MODEL_ENCODER_TEMPLATES: Dict[str, Any] = {
-    "stacked_cnn": "stacked_cnn",
-    "resnet18": {"type": "resnet", "model_variant": 18},
-    "resnet34": {"type": "resnet", "model_variant": 34},
-    "resnet50": {"type": "resnet", "model_variant": 50},
-    "resnet101": {"type": "resnet", "model_variant": 101},
-    "resnet152": {"type": "resnet", "model_variant": 152},
-    "resnext50_32x4d": {"type": "resnext", "model_variant": "50_32x4d"},
-    "resnext101_32x8d": {"type": "resnext", "model_variant": "101_32x8d"},
-    "resnext101_64x4d": {"type": "resnext", "model_variant": "101_64x4d"},
-    "resnext152_32x8d": {"type": "resnext", "model_variant": "152_32x8d"},
-    "wide_resnet50_2": {"type": "wide_resnet", "model_variant": "50_2"},
-    "wide_resnet101_2": {"type": "wide_resnet", "model_variant": "101_2"},
-    "wide_resnet103_2": {"type": "wide_resnet", "model_variant": "103_2"},
-    "efficientnet_b0": {"type": "efficientnet", "model_variant": "b0"},
-    "efficientnet_b1": {"type": "efficientnet", "model_variant": "b1"},
-    "efficientnet_b2": {"type": "efficientnet", "model_variant": "b2"},
-    "efficientnet_b3": {"type": "efficientnet", "model_variant": "b3"},
-    "efficientnet_b4": {"type": "efficientnet", "model_variant": "b4"},
-    "efficientnet_b5": {"type": "efficientnet", "model_variant": "b5"},
-    "efficientnet_b6": {"type": "efficientnet", "model_variant": "b6"},
-    "efficientnet_b7": {"type": "efficientnet", "model_variant": "b7"},
-    "efficientnet_v2_s": {"type": "efficientnet", "model_variant": "v2_s"},
-    "efficientnet_v2_m": {"type": "efficientnet", "model_variant": "v2_m"},
-    "efficientnet_v2_l": {"type": "efficientnet", "model_variant": "v2_l"},
-    "regnet_y_400mf": {"type": "regnet", "model_variant": "y_400mf"},
-    "regnet_y_800mf": {"type": "regnet", "model_variant": "y_800mf"},
-    "regnet_y_1_6gf": {"type": "regnet", "model_variant": "y_1_6gf"},
-    "regnet_y_3_2gf": {"type": "regnet", "model_variant": "y_3_2gf"},
-    "regnet_y_8gf": {"type": "regnet", "model_variant": "y_8gf"},
-    "regnet_y_16gf": {"type": "regnet", "model_variant": "y_16gf"},
-    "regnet_y_32gf": {"type": "regnet", "model_variant": "y_32gf"},
-    "regnet_y_128gf": {"type": "regnet", "model_variant": "y_128gf"},
-    "regnet_x_400mf": {"type": "regnet", "model_variant": "x_400mf"},
-    "regnet_x_800mf": {"type": "regnet", "model_variant": "x_800mf"},
-    "regnet_x_1_6gf": {"type": "regnet", "model_variant": "x_1_6gf"},
-    "regnet_x_3_2gf": {"type": "regnet", "model_variant": "x_3_2gf"},
-    "regnet_x_8gf": {"type": "regnet", "model_variant": "x_8gf"},
-    "regnet_x_16gf": {"type": "regnet", "model_variant": "x_16gf"},
-    "regnet_x_32gf": {"type": "regnet", "model_variant": "x_32gf"},
-    "vgg11": {"type": "vgg", "model_variant": 11},
-    "vgg11_bn": {"type": "vgg", "model_variant": "11_bn"},
-    "vgg13": {"type": "vgg", "model_variant": 13},
-    "vgg13_bn": {"type": "vgg", "model_variant": "13_bn"},
-    "vgg16": {"type": "vgg", "model_variant": 16},
-    "vgg16_bn": {"type": "vgg", "model_variant": "16_bn"},
-    "vgg19": {"type": "vgg", "model_variant": 19},
-    "vgg19_bn": {"type": "vgg", "model_variant": "19_bn"},
-    "shufflenet_v2_x0_5": {"type": "shufflenet_v2", "model_variant": "x0_5"},
-    "shufflenet_v2_x1_0": {"type": "shufflenet_v2", "model_variant": "x1_0"},
-    "shufflenet_v2_x1_5": {"type": "shufflenet_v2", "model_variant": "x1_5"},
-    "shufflenet_v2_x2_0": {"type": "shufflenet_v2", "model_variant": "x2_0"},
-    "squeezenet1_0": {"type": "squeezenet", "model_variant": "1_0"},
-    "squeezenet1_1": {"type": "squeezenet", "model_variant": "1_1"},
-    "swin_t": {"type": "swin_transformer", "model_variant": "t"},
-    "swin_s": {"type": "swin_transformer", "model_variant": "s"},
-    "swin_b": {"type": "swin_transformer", "model_variant": "b"},
-    "swin_v2_t": {"type": "swin_transformer", "model_variant": "v2_t"},
-    "swin_v2_s": {"type": "swin_transformer", "model_variant": "v2_s"},
-    "swin_v2_b": {"type": "swin_transformer", "model_variant": "v2_b"},
-    "vit_b_16": {"type": "vision_transformer", "model_variant": "b_16"},
-    "vit_b_32": {"type": "vision_transformer", "model_variant": "b_32"},
-    "vit_l_16": {"type": "vision_transformer", "model_variant": "l_16"},
-    "vit_l_32": {"type": "vision_transformer", "model_variant": "l_32"},
-    "vit_h_14": {"type": "vision_transformer", "model_variant": "h_14"},
-    "convnext_tiny": {"type": "convnext", "model_variant": "tiny"},
-    "convnext_small": {"type": "convnext", "model_variant": "small"},
-    "convnext_base": {"type": "convnext", "model_variant": "base"},
-    "convnext_large": {"type": "convnext", "model_variant": "large"},
-    "maxvit_t": {"type": "maxvit", "model_variant": "t"},
-    "alexnet": {"type": "alexnet"},
-    "googlenet": {"type": "googlenet"},
-    "inception_v3": {"type": "inception_v3"},
-    "mobilenet_v2": {"type": "mobilenet_v2"},
-    "mobilenet_v3_large": {"type": "mobilenet_v3_large"},
-    "mobilenet_v3_small": {"type": "mobilenet_v3_small"},
-}
-METRIC_DISPLAY_NAMES = {
-    "accuracy": "Accuracy",
-    "accuracy_micro": "Accuracy-Micro",
-    "loss": "Loss",
-    "roc_auc": "ROC-AUC",
-    "roc_auc_macro": "ROC-AUC-Macro",
-    "roc_auc_micro": "ROC-AUC-Micro",
-    "hits_at_k": "Hits at K",
-    "precision": "Precision",
-    "recall": "Recall",
-    "specificity": "Specificity",
-    "kappa_score": "Cohen's Kappa",
-    "token_accuracy": "Token Accuracy",
-    "avg_precision_macro": "Precision-Macro",
-    "avg_recall_macro": "Recall-Macro",
-    "avg_f1_score_macro": "F1-score-Macro",
-    "avg_precision_micro": "Precision-Micro",
-    "avg_recall_micro": "Recall-Micro",
-    "avg_f1_score_micro": "F1-score-Micro",
-    "avg_precision_weighted": "Precision-Weighted",
-    "avg_recall_weighted": "Recall-Weighted",
-    "avg_f1_score_weighted": "F1-score-Weighted",
-    "average_precision_macro": " Precision-Average-Macro",
-    "average_precision_micro": "Precision-Average-Micro",
-    "average_precision_samples": "Precision-Average-Samples",
-}
+from utils import (
+    build_tabbed_html,
+    encode_image_to_base64,
+    get_html_closing,
+    get_html_template,
+    get_metrics_help_modal
+)

 # --- Logging Setup ---
 logging.basicConfig(
     level=logging.INFO,
-    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
+    format='%(asctime)s %(levelname)s %(name)s: %(message)s',
 )
 logger = logging.getLogger("ImageLearner")
-
-
-def get_metrics_help_modal() -> str:
-    modal_html = """
-    <div id="metricsHelpModal" class="modal">
-      <div class="modal-content">
-        <span class="close">×</span>
-        <h2>Model Evaluation Metrics — Help Guide</h2>
-        <div class="metrics-guide">
-          <h3>1) General Metrics</h3>
-          <p><strong>Loss:</strong> Measures the difference between predicted and actual values. Lower is better. Often used for optimization during training.</p>
-          <p><strong>Accuracy:</strong> Proportion of correct predictions among all predictions. Simple but can be misleading for imbalanced datasets.</p>
-          <p><strong>Micro Accuracy:</strong> Calculates accuracy by summing up all individual true positives and true negatives across all classes, making it suitable for multiclass or multilabel problems.</p>
-          <p><strong>Token Accuracy:</strong> Measures how often the predicted tokens (e.g., in sequences) match the true tokens. Useful in sequence prediction tasks like NLP.</p>
-          <h3>2) Precision, Recall & Specificity</h3>
-          <p><strong>Precision:</strong> Out of all positive predictions, how many were correct. Precision = TP / (TP + FP). Helps when false positives are costly.</p>
-          <p><strong>Recall (Sensitivity):</strong> Out of all actual positives, how many were predicted correctly. Recall = TP / (TP + FN). Important when missing positives is risky.</p>
-          <p><strong>Specificity:</strong> True negative rate. Measures how well the model identifies negatives. Specificity = TN / (TN + FP). Useful in medical testing to avoid false alarms.</p>
-          <h3>3) Macro, Micro, and Weighted Averages</h3>
-          <p><strong>Macro Precision / Recall / F1:</strong> Averages the metric across all classes, treating each class equally, regardless of class frequency. Best when class sizes are balanced.</p>
-          <p><strong>Micro Precision / Recall / F1:</strong> Aggregates TP, FP, FN across all classes before computing the metric. Gives a global view and is ideal for class-imbalanced problems.</p>
-          <p><strong>Weighted Precision / Recall / F1:</strong> Averages each metric across classes, weighted by the number of true instances per class. Balances importance of classes based on frequency.</p>
-          <h3>4) Average Precision (PR-AUC Variants)</h3>
-          <p><strong>Average Precision Macro:</strong> Precision-Recall AUC averaged across all classes equally. Useful for balanced multi-class problems.</p>
-          <p><strong>Average Precision Micro:</strong> Global Precision-Recall AUC using all instances. Best for imbalanced data or multi-label classification.</p>
-          <p><strong>Average Precision Samples:</strong> Precision-Recall AUC averaged across individual samples (not classes). Ideal for multi-label problems where each sample can belong to multiple classes.</p>
-          <h3>5) ROC-AUC Variants</h3>
-          <p><strong>ROC-AUC:</strong> Measures model's ability to distinguish between classes. AUC = 1 is perfect; 0.5 is random guessing. Use for binary classification.</p>
-          <p><strong>Macro ROC-AUC:</strong> Averages the AUC across all classes equally. Suitable when classes are balanced and of equal importance.</p>
-          <p><strong>Micro ROC-AUC:</strong> Computes AUC from aggregated predictions across all classes. Useful in multiclass or multilabel settings with imbalance.</p>
-          <h3>6) Ranking Metrics</h3>
-          <p><strong>Hits at K:</strong> Measures whether the true label is among the top-K predictions. Common in recommendation systems and retrieval tasks.</p>
-          <h3>7) Confusion Matrix Stats (Per Class)</h3>
-          <p><strong>True Positives / Negatives (TP / TN):</strong> Correct predictions for positives and negatives respectively.</p>
-          <p><strong>False Positives / Negatives (FP / FN):</strong> Incorrect predictions — false alarms and missed detections.</p>
-          <h3>8) Other Useful Metrics</h3>
-          <p><strong>Cohen's Kappa:</strong> Measures agreement between predicted and actual values adjusted for chance. Useful for multiclass classification with imbalanced labels.</p>
-          <p><strong>Matthews Correlation Coefficient (MCC):</strong> Balanced measure of prediction quality that takes into account TP, TN, FP, and FN. Particularly effective for imbalanced datasets.</p>
-          <h3>9) Metric Recommendations</h3>
-          <ul>
-            <li>Use <strong>Accuracy + F1</strong> for balanced data.</li>
-            <li>Use <strong>Precision, Recall, ROC-AUC</strong> for imbalanced datasets.</li>
-            <li>Use <strong>Average Precision Micro</strong> for multilabel or class-imbalanced problems.</li>
-            <li>Use <strong>Macro scores</strong> when all classes should be treated equally.</li>
-            <li>Use <strong>Weighted scores</strong> when class imbalance should be accounted for without ignoring small classes.</li>
-            <li>Use <strong>Confusion Matrix stats</strong> to analyze class-wise performance.</li>
-            <li>Use <strong>Hits at K</strong> for recommendation or ranking-based tasks.</li>
-          </ul>
-        </div>
-      </div>
-    </div>
-    """
-    modal_css = """
-    <style>
-    .modal {
-        display: none;
-        position: fixed;
-        z-index: 1;
-        left: 0;
-        top: 0;
-        width: 100%;
-        height: 100%;
-        overflow: auto;
-        background-color: rgba(0,0,0,0.4);
-    }
-    .modal-content {
-        background-color: #fefefe;
-        margin: 15% auto;
-        padding: 20px;
-        border: 1px solid #888;
-        width: 80%;
-        max-width: 800px;
-    }
-    .close {
-        color: #aaa;
-        float: right;
-        font-size: 28px;
-        font-weight: bold;
-    }
-    .close:hover,
-    .close:focus {
-        color: black;
-        text-decoration: none;
-        cursor: pointer;
-    }
-    .metrics-guide h3 {
-        margin-top: 20px;
-    }
-    .metrics-guide p {
-        margin: 5px 0;
-    }
-    .metrics-guide ul {
-        margin: 10px 0;
-        padding-left: 20px;
-    }
-    </style>
-    """
-    modal_js = """
-    <script>
-    document.addEventListener("DOMContentLoaded", function() {
-        var modal = document.getElementById("metricsHelpModal");
-        var closeBtn = document.getElementsByClassName("close")[0];
-
-        document.querySelectorAll(".openMetricsHelp").forEach(btn => {
-            btn.onclick = function() {
-                modal.style.display = "block";
-            };
-        });
-
-        if (closeBtn) {
-            closeBtn.onclick = function() {
-                modal.style.display = "none";
-            };
-        }
-
-        window.onclick = function(event) {
-            if (event.target == modal) {
-                modal.style.display = "none";
-            }
-        }
-    });
-    </script>
-    """
-    return modal_css + modal_html + modal_js


 def format_config_table_html(
     config: dict,
     split_info: Optional[str] = None,
     training_progress: dict = None,
 ) -> str:
     display_keys = [
+        "task_type",
         "model_name",
         "epochs",
         "batch_size",
         "fine_tune",
         "use_pretrained",
@@ -285 +66 @@

     rows = []

     for key in display_keys:
         val = config.get(key, "N/A")
+        if key == "task_type":
+            val = val.title() if isinstance(val, str) else val
         if key == "batch_size":
             if val is not None:
                 val = int(val)
             else:
                 if training_progress:
@@ -346 +129 @@
             f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>"
             f"{val}</td>"
             f"</tr>"
         )

+    aug_cfg = config.get("augmentation")
+    if aug_cfg:
+        types = [str(a.get("type", "")) for a in aug_cfg]
+        aug_val = ", ".join(types)
+        rows.append(
+            "<tr>"
+            "<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Augmentation</td>"
+            "<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>"
+            f"{aug_val}</td>"
+            "</tr>"
+        )
+
     if split_info:
         rows.append(
             f"<tr>"
             f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>"
             f"Data Split</td>"
@@ -369 +164 @@
         "Value</th>"
         "</tr></thead><tbody>" + "".join(rows) + "</tbody></table></div><br>"
         "<p style='text-align: center; font-size: 0.9em;'>"
         "Model trained using Ludwig.<br>"
         "If want to learn more about Ludwig default settings,"
-        "please check the their <a href='https://ludwig.ai' target='_blank'>"
+        "please check their <a href='https://ludwig.ai' target='_blank'>"
         "website(ludwig.ai)</a>."
         "</p><hr>"
     )


 def detect_output_type(test_stats):
     """Detects if the output type is 'binary' or 'category' based on test statistics."""
     label_stats = test_stats.get("label", {})
+    if "mean_squared_error" in label_stats:
+        return "regression"
     per_class = label_stats.get("per_class_stats", {})
     if len(per_class) == 2:
         return "binary"
     return "category"

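The new early return keys regression detection solely on the presence of a mean_squared_error entry under the label feature. A self-contained sketch of the dispatch, restating the function from the hunk above and feeding it hypothetical stats dicts in place of Ludwig's test_statistics.json:

def detect_output_type(test_stats):
    """Regression if MSE is reported; otherwise binary vs. category by class count."""
    label_stats = test_stats.get("label", {})
    if "mean_squared_error" in label_stats:
        return "regression"
    per_class = label_stats.get("per_class_stats", {})
    if len(per_class) == 2:
        return "binary"
    return "category"

# Hypothetical stats dicts illustrating each outcome:
assert detect_output_type({"label": {"mean_squared_error": 0.42}}) == "regression"
assert detect_output_type({"label": {"per_class_stats": {"0": {}, "1": {}}}}) == "binary"
assert detect_output_type({"label": {"per_class_stats": {"a": {}, "b": {}, "c": {}}}}) == "category"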
@@ -418 +215 @@
             "precision": get_last_value(label_stats, "precision"),
             "recall": get_last_value(label_stats, "recall"),
             "specificity": get_last_value(label_stats, "specificity"),
             "roc_auc": get_last_value(label_stats, "roc_auc"),
         }
+    elif output_type == "regression":
+        metrics[split] = {
+            "loss": get_last_value(label_stats, "loss"),
+            "mean_absolute_error": get_last_value(
+                label_stats, "mean_absolute_error"
+            ),
+            "mean_absolute_percentage_error": get_last_value(
+                label_stats, "mean_absolute_percentage_error"
+            ),
+            "mean_squared_error": get_last_value(label_stats, "mean_squared_error"),
+            "root_mean_squared_error": get_last_value(
+                label_stats, "root_mean_squared_error"
+            ),
+            "root_mean_squared_percentage_error": get_last_value(
+                label_stats, "root_mean_squared_percentage_error"
+            ),
+            "r2": get_last_value(label_stats, "r2"),
+        }
     else:
         metrics[split] = {
             "accuracy": get_last_value(label_stats, "accuracy"),
             "accuracy_micro": get_last_value(label_stats, "accuracy_micro"),
             "loss": get_last_value(label_stats, "loss"),
@@ -563 +378 @@
     )
     html += "</tbody></table></div><br>"
     return html


-def format_test_merged_stats_table_html(test_metrics: Dict[str, Optional[float]]) -> str:
+def format_test_merged_stats_table_html(
+    test_metrics: Dict[str, Optional[float]],
+) -> str:
     """Formats an HTML table for test metrics."""
     rows = []
     for key in sorted(test_metrics.keys()):
         display_name = METRIC_DISPLAY_NAMES.get(key, key.replace("_", " ").title())
         value = test_metrics[key]
@@ -594 +411 @@
             "padding: 10px; border: 1px solid #ccc; text-align: center; "
             "white-space: nowrap;",
         )
     html += "</tbody></table></div><br>"
     return html
-
-
-def build_tabbed_html(metrics_html: str, train_val_html: str, test_html: str) -> str:
-    return f"""
-<style>
-.tabs {{
-  display: flex;
-  border-bottom: 2px solid #ccc;
-  margin-bottom: 1rem;
-}}
-.tab {{
-  padding: 10px 20px;
-  cursor: pointer;
-  border: 1px solid #ccc;
-  border-bottom: none;
-  background: #f9f9f9;
-  margin-right: 5px;
-  border-top-left-radius: 8px;
-  border-top-right-radius: 8px;
-}}
-.tab.active {{
-  background: white;
-  font-weight: bold;
-}}
-.tab-content {{
-  display: none;
-  padding: 20px;
-  border: 1px solid #ccc;
-  border-top: none;
-}}
-.tab-content.active {{
-  display: block;
-}}
-</style>
-<div class="tabs">
-  <div class="tab active" onclick="showTab('metrics')"> Config & Results Summary</div>
-  <div class="tab" onclick="showTab('trainval')"> Train/Validation Results</div>
-  <div class="tab" onclick="showTab('test')"> Test Results</div>
-</div>
-<div id="metrics" class="tab-content active">
-  {metrics_html}
-</div>
-<div id="trainval" class="tab-content">
-  {train_val_html}
-</div>
-<div id="test" class="tab-content">
-  {test_html}
-</div>
-<script>
-function showTab(id) {{
-  document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active'));
-  document.querySelectorAll('.tab').forEach(el => el.classList.remove('active'));
-  document.getElementById(id).classList.add('active');
-  document.querySelector(`.tab[onclick*="${{id}}"]`).classList.add('active');
-}}
-</script>
-"""


 def split_data_0_2(
     df: pd.DataFrame,
     split_column: str,
@@ -725 +485 @@
         output_dir: Path,
         random_seed: int,
     ) -> None:
         ...

-    def generate_plots(
-        self,
-        output_dir: Path
-    ) -> None:
+    def generate_plots(self, output_dir: Path) -> None:
         ...

     def generate_html_report(
         self,
         title: str,
-        output_dir: str
+        output_dir: str,
+        config: Dict[str, Any],
+        split_info: str,
     ) -> Path:
         ...


 class LudwigDirectBackend:
@@ -747 +506 @@
     def prepare_config(
         self,
         config_params: Dict[str, Any],
         split_config: Dict[str, Any],
     ) -> str:
-        """Build and serialize the Ludwig YAML configuration."""
         logger.info("LudwigDirectBackend: Preparing YAML configuration.")

         model_name = config_params.get("model_name", "resnet18")
         use_pretrained = config_params.get("use_pretrained", False)
         fine_tune = config_params.get("fine_tune", False)
+        if use_pretrained:
+            trainable = bool(fine_tune)
+        else:
+            trainable = True
         epochs = config_params.get("epochs", 10)
         batch_size = config_params.get("batch_size")
         num_processes = config_params.get("preprocessing_num_processes", 1)
         early_stop = config_params.get("early_stop", None)
         learning_rate = config_params.get("learning_rate")
         learning_rate = "auto" if learning_rate is None else float(learning_rate)
-        trainable = fine_tune or (not use_pretrained)
-        if not use_pretrained and not trainable:
-            logger.warning("trainable=False; use_pretrained=False is ignored.")
-            logger.warning("Setting trainable=True to train the model from scratch.")
-            trainable = True
         raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name)
         if isinstance(raw_encoder, dict):
             encoder_config = {
                 **raw_encoder,
                 "use_pretrained": use_pretrained,
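The replacement above computes the same trainable flag as the old expression (fine_tune when a pretrained encoder is used, otherwise always True), and the deleted warning branch could never fire, since the old expression is always True whenever use_pretrained is False. A quick exhaustive check of the equivalence (a sketch, not part of the tool):

# Compare old and new computations of `trainable` over all flag combinations.
for use_pretrained in (True, False):
    for fine_tune in (True, False):
        old_trainable = fine_tune or (not use_pretrained)
        new_trainable = bool(fine_tune) if use_pretrained else True
        assert old_trainable == new_trainable
        # The deleted warning required use_pretrained=False and trainable=False,
        # which the old expression never produces:
        assert not (not use_pretrained and not old_trainable)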
@@ -777 +534 @@
             encoder_config = {"type": raw_encoder}

         batch_size_cfg = batch_size or "auto"

         label_column_path = config_params.get("label_column_data_path")
+        label_series = None
         if label_column_path is not None and Path(label_column_path).exists():
             try:
                 label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME]
-                num_unique_labels = label_series.nunique()
             except Exception as e:
-                logger.warning(
-                    f"Could not determine label cardinality, defaulting to 'binary': {e}"
-                )
-                num_unique_labels = 2
+                logger.warning(f"Could not read label column for task detection: {e}")
+
+        if (
+            label_series is not None
+            and ptypes.is_numeric_dtype(label_series.dtype)
+            and label_series.nunique() > 10
+        ):
+            task_type = "regression"
         else:
-            logger.warning(
-                "label_column_data_path not provided, defaulting to 'binary'"
-            )
-            num_unique_labels = 2
-
-        output_type = "binary" if num_unique_labels == 2 else "category"
+            task_type = "classification"
+
+        config_params["task_type"] = task_type
+
+        image_feat: Dict[str, Any] = {
+            "name": IMAGE_PATH_COLUMN_NAME,
+            "type": "image",
+            "encoder": encoder_config,
+        }
+        if config_params.get("augmentation") is not None:
+            image_feat["augmentation"] = config_params["augmentation"]
+
+        if task_type == "regression":
+            output_feat = {
+                "name": LABEL_COLUMN_NAME,
+                "type": "number",
+                "decoder": {"type": "regressor"},
+                "loss": {"type": "mean_squared_error"},
+                "evaluation": {
+                    "metrics": [
+                        "mean_squared_error",
+                        "mean_absolute_error",
+                        "r2",
+                    ]
+                },
+            }
+            val_metric = config_params.get("validation_metric", "mean_squared_error")
+
+        else:
+            num_unique_labels = (
+                label_series.nunique() if label_series is not None else 2
+            )
+            output_type = "binary" if num_unique_labels == 2 else "category"
+            output_feat = {"name": LABEL_COLUMN_NAME, "type": output_type}
+            val_metric = None

         conf: Dict[str, Any] = {
             "model_type": "ecd",
-            "input_features": [
-                {
-                    "name": IMAGE_PATH_COLUMN_NAME,
-                    "type": "image",
-                    "encoder": encoder_config,
-                }
-            ],
-            "output_features": [{"name": LABEL_COLUMN_NAME, "type": output_type}],
+            "input_features": [image_feat],
+            "output_features": [output_feat],
             "combiner": {"type": "concat"},
             "trainer": {
                 "epochs": epochs,
                 "early_stop": early_stop,
                 "batch_size": batch_size_cfg,
                 "learning_rate": learning_rate,
+                # only set validation_metric for regression
+                **({"validation_metric": val_metric} if val_metric else {}),
             },
             "preprocessing": {
                 "split": split_config,
                 "num_processes": num_processes,
                 "in_memory": False,
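The task-type heuristic added in this hunk treats a numeric label column with more than 10 distinct values as a regression target and everything else as classification. A standalone sketch of the rule; infer_task_type is a hypothetical helper name, the real check is inline in prepare_config:

import pandas as pd
import pandas.api.types as ptypes


def infer_task_type(label_series) -> str:
    # Mirrors the inline check: numeric dtype AND more than 10 unique values.
    if (
        label_series is not None
        and ptypes.is_numeric_dtype(label_series.dtype)
        and label_series.nunique() > 10
    ):
        return "regression"
    return "classification"


assert infer_task_type(pd.Series([0.1 * i for i in range(50)])) == "regression"
assert infer_task_type(pd.Series([0, 1] * 25)) == "classification"      # numeric, few values
assert infer_task_type(pd.Series(list("abc") * 5)) == "classification"  # non-numeric
assert infer_task_type(None) == "classification"                        # no label file read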
@@ -874 +660 @@
                 "LudwigDirectBackend: Experiment execution error.",
                 exc_info=True,
             )
             raise

-    def get_training_process(self, output_dir) -> float:
+    def get_training_process(self, output_dir) -> Optional[Dict[str, Any]]:
         """Retrieve the learning rate used in the most recent Ludwig run."""
         output_dir = Path(output_dir)
         exp_dirs = sorted(
             output_dir.glob("experiment_run*"),
             key=lambda p: p.stat().st_mtime,
@@ -998 +784 @@
                 stats = json.load(f)
             output_feature = next(iter(stats.keys()), "")

         viz_registry = get_visualizations_registry()
         for viz_name, viz_func in viz_registry.items():
-            viz_dir_plot = None
             if viz_name in train_plots:
                 viz_dir_plot = train_viz
             elif viz_name in test_plots:
                 viz_dir_plot = test_viz
+            else:
+                continue

             try:
                 viz_func(
                     training_statistics=[training_stats] if training_stats else [],
                     test_statistics=[test_stats] if test_stats else [],
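The added else: continue changes which visualizations run at all: names registered in neither train_plots nor test_plots are now skipped, whereas previously viz_dir_plot stayed None and the call still went ahead. A small sketch of the new control flow, with hypothetical stand-ins for the registry sets:

train_plots = {"learning_curves"}    # hypothetical subset of the registry
test_plots = {"confusion_matrix"}    # hypothetical subset of the registry

for viz_name in ("learning_curves", "calibration_plot", "confusion_matrix"):
    if viz_name in train_plots:
        viz_dir_plot = "train_viz"
    elif viz_name in test_plots:
        viz_dir_plot = "test_viz"
    else:
        continue  # previously: viz_dir_plot stayed None and viz_func still ran
    print(f"{viz_name} -> {viz_dir_plot}")
# learning_curves -> train_viz
# confusion_matrix -> test_viz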
@@ -1038 +825 @@
         """Assemble an HTML report from visualizations under train_val/ and test/ folders."""
         cwd = Path.cwd()
         report_name = title.lower().replace(" ", "_") + "_report.html"
         report_path = cwd / report_name
         output_dir = Path(output_dir)
+        output_type = None

         exp_dirs = sorted(
             output_dir.glob("experiment_run*"),
             key=lambda p: p.stat().st_mtime,
         )
@@ -1057 +845 @@
         html += f"<h1>{title}</h1>"

         metrics_html = ""
         train_val_metrics_html = ""
         test_metrics_html = ""
-
         try:
             train_stats_path = exp_dir / "training_statistics.json"
             test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME
             if train_stats_path.exists() and test_stats_path.exists():
                 with open(train_stats_path) as f:
                     train_stats = json.load(f)
                 with open(test_stats_path) as f:
                     test_stats = json.load(f)
                 output_type = detect_output_type(test_stats)
-                all_metrics = extract_metrics_from_json(
-                    train_stats,
-                    test_stats,
-                    output_type,
-                )
                 metrics_html = format_stats_table_html(train_stats, test_stats)
                 train_val_metrics_html = format_train_val_stats_table_html(
-                    train_stats,
-                    test_stats,
+                    train_stats, test_stats
                 )
                 test_metrics_html = format_test_merged_stats_table_html(
-                    all_metrics["test"],
+                    extract_metrics_from_json(train_stats, test_stats, output_type)[
+                        "test"
+                    ]
                 )
         except Exception as e:
             logger.warning(
                 f"Could not load stats for HTML report: {type(e).__name__}: {e}"
             )

         config_html = ""
         training_progress = self.get_training_process(output_dir)
         try:
-            config_html = format_config_table_html(config, split_info, training_progress)
+            config_html = format_config_table_html(
+                config, split_info, training_progress
+            )
         except Exception as e:
             logger.warning(f"Could not load config for HTML report: {e}")

-        def render_img_section(title: str, dir_path: Path, output_type: str = None) -> str:
+        def render_img_section(
+            title: str, dir_path: Path, output_type: str = None
+        ) -> str:
             if not dir_path.exists():
                 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>"

             imgs = list(dir_path.glob("*.png"))
             if not imgs:
@@ -1139 +926 @@
             img_names = {img.name: img for img in imgs if img.name not in unwanted}
             ordered_imgs = [
                 img_names[fname] for fname in display_order if fname in img_names
             ]
             remaining = sorted(
-                [
-                    img
-                    for img in img_names.values()
-                    if img.name not in display_order
-                ]
+                [img for img in img_names.values() if img.name not in display_order]
             )
             imgs = ordered_imgs + remaining

         else:
             if output_type == "category":
@@ -1171 +954 @@
                     f"</div>"
                 )
             section_html += "</div>"
             return section_html

-        button_html = """
-        <button class="help-modal-btn openMetricsHelp">Model Evaluation Metrics — Help Guide</button>
-        <br><br>
-        <style>
-        .help-modal-btn {
-            background-color: #17623b;
-            color: #fff;
-            border: none;
-            border-radius: 24px;
-            padding: 10px 28px;
-            font-size: 1.1rem;
-            font-weight: bold;
-            letter-spacing: 0.03em;
-            cursor: pointer;
-            transition: background 0.2s, box-shadow 0.2s;
-            box-shadow: 0 2px 8px rgba(23,98,59,0.07);
-        }
-        .help-modal-btn:hover, .help-modal-btn:focus {
-            background-color: #21895e;
-            outline: none;
-            box-shadow: 0 4px 16px rgba(23,98,59,0.14);
-        }
-        </style>
-        """
-        tab1_content = button_html + config_html + metrics_html
-        tab2_content = (
-            button_html
-            + train_val_metrics_html
-            + render_img_section("Training & Validation Visualizations", train_viz_dir)
-        )
+        tab1_content = config_html + metrics_html
+
+        tab2_content = train_val_metrics_html + render_img_section(
+            "Training & Validation Visualizations", train_viz_dir
+        )
+
+        # --- Predictions vs Ground Truth table ---
+        preds_section = ""
+        parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
+        if parquet_path.exists():
+            try:
+                # 1) load predictions from Parquet
+                df_preds = pd.read_parquet(parquet_path).reset_index(drop=True)
+                # assume the column containing your model's prediction is named "prediction"
+                # or contains that substring:
+                pred_col = next(
+                    (c for c in df_preds.columns if "prediction" in c.lower()),
+                    None,
+                )
+                if pred_col is None:
+                    raise ValueError("No prediction column found in Parquet output")
+                df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"})
+
+                # 2) load ground truth for the test split from prepared CSV
+                df_all = pd.read_csv(config["label_column_data_path"])
+                df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][
+                    LABEL_COLUMN_NAME
+                ].reset_index(drop=True)
+
+                # 3) concatenate side-by-side
+                df_table = pd.concat([df_gt, df_pred], axis=1)
+                df_table.columns = [LABEL_COLUMN_NAME, "prediction"]
+
+                # 4) render as HTML
+                preds_html = df_table.to_html(index=False, classes="predictions-table")
+                preds_section = (
+                    "<h2 style='text-align: center;'>Predictions vs. Ground Truth</h2>"
+                    "<div style='overflow-x:auto; margin-bottom:20px;'>"
+                    + preds_html
+                    + "</div>"
+                )
+            except Exception as e:
+                logger.warning(f"Could not build Predictions vs GT table: {e}")
+        # Test tab = Metrics + Preds table + Visualizations
+
         tab3_content = (
-            button_html
-            + test_metrics_html
+            test_metrics_html
+            + preds_section
             + render_img_section("Test Visualizations", test_viz_dir, output_type)
         )

+        # assemble the tabs and help modal
         tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)
         modal_html = get_metrics_help_modal()
-        html += tabbed_html + modal_html
-        html += get_html_closing()
+        html += tabbed_html + modal_html + get_html_closing()

         try:
             with open(report_path, "w") as f:
                 f.write(html)
             logger.info(f"HTML report generated at: {report_path}")
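One caveat on the new predictions table: pd.concat pairs rows purely by position, so it assumes Ludwig wrote its Parquet predictions in the same order as the split == 2 rows of the prepared CSV. A toy sketch of that pairing step with made-up frames:

import pandas as pd

# Made-up stand-ins for the parquet predictions and the test-split labels.
df_pred = pd.DataFrame({"prediction": ["cat", "cat"]})
df_gt = pd.Series(["cat", "dog"], name="label").reset_index(drop=True)

df_table = pd.concat([df_gt, df_pred], axis=1)  # positional alignment only
df_table.columns = ["label", "prediction"]
print(df_table.to_html(index=False, classes="predictions-table"))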
@@ -1261 +1059 @@
             logger.info("Image extraction complete.")
         except Exception:
             logger.error("Error extracting zip file", exc_info=True)
             raise

-    def _prepare_data(self) -> Tuple[Path, Dict[str, Any]]:
+    def _prepare_data(self) -> Tuple[Path, Dict[str, Any], str]:
         """Load CSV, update image paths, handle splits, and write prepared CSV."""
         if not self.temp_dir or not self.image_extract_dir:
             raise RuntimeError("Temp dirs not initialized before data prep.")

         try:
@@ -1300 +1098 @@
                 f"No split column in CSV. Used random split: "
                 f"{[int(p * 100) for p in self.args.split_probabilities]}% "
                 f"for train/val/test."
             )

-        final_csv = TEMP_CSV_FILENAME
+        final_csv = self.temp_dir / TEMP_CSV_FILENAME
         try:
+
             df.to_csv(final_csv, index=False)
             logger.info(f"Saved prepared data to {final_csv}")
         except Exception:
             logger.error("Error saving prepared CSV", exc_info=True)
             raise

         return final_csv, split_config, split_info

-    def _process_fixed_split(self, df: pd.DataFrame) -> Dict[str, Any]:
+    def _process_fixed_split(
+        self, df: pd.DataFrame
+    ) -> Tuple[pd.DataFrame, Dict[str, Any], str]:
         """Process a fixed split column (0=train,1=val,2=test)."""
         logger.info(f"Fixed split column '{SPLIT_COLUMN_NAME}' detected.")
         try:
             col = df[SPLIT_COLUMN_NAME]
             df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype(
@@ -1382 +1183 @@
             "split_probabilities": self.args.split_probabilities,
             "learning_rate": self.args.learning_rate,
             "random_seed": self.args.random_seed,
             "early_stop": self.args.early_stop,
             "label_column_data_path": csv_path,
+            "augmentation": self.args.augmentation,
         }
         yaml_str = self.backend.prepare_config(backend_args, split_cfg)

         config_file = self.temp_dir / TEMP_CONFIG_FILENAME
         config_file.write_text(yaml_str)
@@ -1420 +1222 @@
             return float(s)
         except (TypeError, ValueError):
             return None


+def aug_parse(aug_string: str):
+    """
+    Parse comma-separated augmentation keys into Ludwig augmentation dicts.
+    Raises ValueError on unknown key.
+    """
+    mapping = {
+        "random_horizontal_flip": {"type": "random_horizontal_flip"},
+        "random_vertical_flip": {"type": "random_vertical_flip"},
+        "random_rotate": {"type": "random_rotate", "degree": 10},
+        "random_blur": {"type": "random_blur", "kernel_size": 3},
+        "random_brightness": {"type": "random_brightness", "min": 0.5, "max": 2.0},
+        "random_contrast": {"type": "random_contrast", "min": 0.5, "max": 2.0},
+    }
+    aug_list = []
+    for tok in aug_string.split(","):
+        key = tok.strip()
+        if key not in mapping:
+            valid = ", ".join(mapping.keys())
+            raise ValueError(f"Unknown augmentation '{key}'. Valid choices: {valid}")
+        aug_list.append(mapping[key])
+    return aug_list
+
+
 class SplitProbAction(argparse.Action):
     def __call__(self, parser, namespace, values, option_string=None):
         train, val, test = values
         total = train + val + test
         if abs(total - 1.0) > 1e-6:
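aug_parse preserves the user's ordering and expands each key into a full Ludwig augmentation dict with fixed default parameters. A usage sketch, assuming the aug_parse defined above:

augs = aug_parse("random_horizontal_flip, random_rotate")
assert augs == [
    {"type": "random_horizontal_flip"},
    {"type": "random_rotate", "degree": 10},
]

try:
    aug_parse("random_zoom")  # not in the mapping
except ValueError as e:
    print(e)  # Unknown augmentation 'random_zoom'. Valid choices: ...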
@@ -1506 +1331 @@
         type=float,
         nargs=3,
         metavar=("train", "val", "test"),
         action=SplitProbAction,
         default=[0.7, 0.1, 0.2],
-        help="Random split proportions (e.g., 0.7 0.1 0.2). Only used if no split column.",
+        help=(
+            "Random split proportions (e.g., 0.7 0.1 0.2)."
+            "Only used if no split column."
+        ),
     )
     parser.add_argument(
         "--random-seed",
         type=int,
         default=42,
@@ -1519 +1347 @@
     parser.add_argument(
         "--learning-rate",
         type=parse_learning_rate,
         default=None,
         help="Learning rate. If not provided, Ludwig will auto-select it.",
+    )
+    parser.add_argument(
+        "--augmentation",
+        type=str,
+        default=None,
+        help=(
+            "Comma-separated list (in order) of any of: "
+            "random_horizontal_flip, random_vertical_flip, random_rotate, "
+            "random_blur, random_brightness, random_contrast. "
+            "E.g. --augmentation random_horizontal_flip,random_rotate"
+        ),
     )

     args = parser.parse_args()

     if not 0.0 <= args.validation_size <= 1.0:
         parser.error("validation-size must be between 0.0 and 1.0")
     if not args.csv_file.is_file():
         parser.error(f"CSV not found: {args.csv_file}")
     if not args.image_zip.is_file():
         parser.error(f"ZIP not found: {args.image_zip}")
+    if args.augmentation is not None:
+        try:
+            augmentation_setup = aug_parse(args.augmentation)
+            setattr(args, "augmentation", augmentation_setup)
+        except ValueError as e:
+            parser.error(str(e))

     backend_instance = LudwigDirectBackend()
     orchestrator = WorkflowOrchestrator(args, backend_instance)

     exit_code = 0