Mercurial > repos > goeckslab > image_learner
comparison image_learner_cli.py @ 2:186424a7eca7 draft
planemo upload for repository https://github.com/goeckslab/gleam.git commit 91fa4aba245520fc0680088a07cead66bcfd4ed2
author | goeckslab |
---|---|
date | Thu, 03 Jul 2025 20:43:24 +0000 |
parents | 39202fe5cf97 |
children | 09904b1f61f5 |
comparison
equal
deleted
inserted
replaced
1:39202fe5cf97 | 2:186424a7eca7 |
---|---|
1 #!/usr/bin/env python3 | |
2 import argparse | 1 import argparse |
3 import json | 2 import json |
4 import logging | 3 import logging |
5 import os | 4 import os |
6 import shutil | 5 import shutil |
9 import zipfile | 8 import zipfile |
10 from pathlib import Path | 9 from pathlib import Path |
11 from typing import Any, Dict, Optional, Protocol, Tuple | 10 from typing import Any, Dict, Optional, Protocol, Tuple |
12 | 11 |
13 import pandas as pd | 12 import pandas as pd |
13 import pandas.api.types as ptypes | |
14 import yaml | 14 import yaml |
15 from constants import ( | |
16 IMAGE_PATH_COLUMN_NAME, | |
17 LABEL_COLUMN_NAME, | |
18 METRIC_DISPLAY_NAMES, | |
19 MODEL_ENCODER_TEMPLATES, | |
20 SPLIT_COLUMN_NAME, | |
21 TEMP_CONFIG_FILENAME, | |
22 TEMP_CSV_FILENAME, | |
23 TEMP_DIR_PREFIX | |
24 ) | |
15 from ludwig.globals import ( | 25 from ludwig.globals import ( |
16 DESCRIPTION_FILE_NAME, | 26 DESCRIPTION_FILE_NAME, |
17 PREDICTIONS_PARQUET_FILE_NAME, | 27 PREDICTIONS_PARQUET_FILE_NAME, |
18 TEST_STATISTICS_FILE_NAME, | 28 TEST_STATISTICS_FILE_NAME, |
19 TRAIN_SET_METADATA_FILE_NAME, | 29 TRAIN_SET_METADATA_FILE_NAME, |
20 ) | 30 ) |
21 from ludwig.utils.data_utils import get_split_path | 31 from ludwig.utils.data_utils import get_split_path |
22 from ludwig.visualize import get_visualizations_registry | 32 from ludwig.visualize import get_visualizations_registry |
23 from sklearn.model_selection import train_test_split | 33 from sklearn.model_selection import train_test_split |
24 from utils import encode_image_to_base64, get_html_closing, get_html_template | 34 from utils import ( |
25 | 35 build_tabbed_html, |
26 # --- Constants --- | 36 encode_image_to_base64, |
27 SPLIT_COLUMN_NAME = "split" | 37 get_html_closing, |
28 LABEL_COLUMN_NAME = "label" | 38 get_html_template, |
29 IMAGE_PATH_COLUMN_NAME = "image_path" | 39 get_metrics_help_modal |
30 DEFAULT_SPLIT_PROBABILITIES = [0.7, 0.1, 0.2] | 40 ) |
31 TEMP_CSV_FILENAME = "processed_data_for_ludwig.csv" | |
32 TEMP_CONFIG_FILENAME = "ludwig_config.yaml" | |
33 TEMP_DIR_PREFIX = "ludwig_api_work_" | |
34 MODEL_ENCODER_TEMPLATES: Dict[str, Any] = { | |
35 "stacked_cnn": "stacked_cnn", | |
36 "resnet18": {"type": "resnet", "model_variant": 18}, | |
37 "resnet34": {"type": "resnet", "model_variant": 34}, | |
38 "resnet50": {"type": "resnet", "model_variant": 50}, | |
39 "resnet101": {"type": "resnet", "model_variant": 101}, | |
40 "resnet152": {"type": "resnet", "model_variant": 152}, | |
41 "resnext50_32x4d": {"type": "resnext", "model_variant": "50_32x4d"}, | |
42 "resnext101_32x8d": {"type": "resnext", "model_variant": "101_32x8d"}, | |
43 "resnext101_64x4d": {"type": "resnext", "model_variant": "101_64x4d"}, | |
44 "resnext152_32x8d": {"type": "resnext", "model_variant": "152_32x8d"}, | |
45 "wide_resnet50_2": {"type": "wide_resnet", "model_variant": "50_2"}, | |
46 "wide_resnet101_2": {"type": "wide_resnet", "model_variant": "101_2"}, | |
47 "wide_resnet103_2": {"type": "wide_resnet", "model_variant": "103_2"}, | |
48 "efficientnet_b0": {"type": "efficientnet", "model_variant": "b0"}, | |
49 "efficientnet_b1": {"type": "efficientnet", "model_variant": "b1"}, | |
50 "efficientnet_b2": {"type": "efficientnet", "model_variant": "b2"}, | |
51 "efficientnet_b3": {"type": "efficientnet", "model_variant": "b3"}, | |
52 "efficientnet_b4": {"type": "efficientnet", "model_variant": "b4"}, | |
53 "efficientnet_b5": {"type": "efficientnet", "model_variant": "b5"}, | |
54 "efficientnet_b6": {"type": "efficientnet", "model_variant": "b6"}, | |
55 "efficientnet_b7": {"type": "efficientnet", "model_variant": "b7"}, | |
56 "efficientnet_v2_s": {"type": "efficientnet", "model_variant": "v2_s"}, | |
57 "efficientnet_v2_m": {"type": "efficientnet", "model_variant": "v2_m"}, | |
58 "efficientnet_v2_l": {"type": "efficientnet", "model_variant": "v2_l"}, | |
59 "regnet_y_400mf": {"type": "regnet", "model_variant": "y_400mf"}, | |
60 "regnet_y_800mf": {"type": "regnet", "model_variant": "y_800mf"}, | |
61 "regnet_y_1_6gf": {"type": "regnet", "model_variant": "y_1_6gf"}, | |
62 "regnet_y_3_2gf": {"type": "regnet", "model_variant": "y_3_2gf"}, | |
63 "regnet_y_8gf": {"type": "regnet", "model_variant": "y_8gf"}, | |
64 "regnet_y_16gf": {"type": "regnet", "model_variant": "y_16gf"}, | |
65 "regnet_y_32gf": {"type": "regnet", "model_variant": "y_32gf"}, | |
66 "regnet_y_128gf": {"type": "regnet", "model_variant": "y_128gf"}, | |
67 "regnet_x_400mf": {"type": "regnet", "model_variant": "x_400mf"}, | |
68 "regnet_x_800mf": {"type": "regnet", "model_variant": "x_800mf"}, | |
69 "regnet_x_1_6gf": {"type": "regnet", "model_variant": "x_1_6gf"}, | |
70 "regnet_x_3_2gf": {"type": "regnet", "model_variant": "x_3_2gf"}, | |
71 "regnet_x_8gf": {"type": "regnet", "model_variant": "x_8gf"}, | |
72 "regnet_x_16gf": {"type": "regnet", "model_variant": "x_16gf"}, | |
73 "regnet_x_32gf": {"type": "regnet", "model_variant": "x_32gf"}, | |
74 "vgg11": {"type": "vgg", "model_variant": 11}, | |
75 "vgg11_bn": {"type": "vgg", "model_variant": "11_bn"}, | |
76 "vgg13": {"type": "vgg", "model_variant": 13}, | |
77 "vgg13_bn": {"type": "vgg", "model_variant": "13_bn"}, | |
78 "vgg16": {"type": "vgg", "model_variant": 16}, | |
79 "vgg16_bn": {"type": "vgg", "model_variant": "16_bn"}, | |
80 "vgg19": {"type": "vgg", "model_variant": 19}, | |
81 "vgg19_bn": {"type": "vgg", "model_variant": "19_bn"}, | |
82 "shufflenet_v2_x0_5": {"type": "shufflenet_v2", "model_variant": "x0_5"}, | |
83 "shufflenet_v2_x1_0": {"type": "shufflenet_v2", "model_variant": "x1_0"}, | |
84 "shufflenet_v2_x1_5": {"type": "shufflenet_v2", "model_variant": "x1_5"}, | |
85 "shufflenet_v2_x2_0": {"type": "shufflenet_v2", "model_variant": "x2_0"}, | |
86 "squeezenet1_0": {"type": "squeezenet", "model_variant": "1_0"}, | |
87 "squeezenet1_1": {"type": "squeezenet", "model_variant": "1_1"}, | |
88 "swin_t": {"type": "swin_transformer", "model_variant": "t"}, | |
89 "swin_s": {"type": "swin_transformer", "model_variant": "s"}, | |
90 "swin_b": {"type": "swin_transformer", "model_variant": "b"}, | |
91 "swin_v2_t": {"type": "swin_transformer", "model_variant": "v2_t"}, | |
92 "swin_v2_s": {"type": "swin_transformer", "model_variant": "v2_s"}, | |
93 "swin_v2_b": {"type": "swin_transformer", "model_variant": "v2_b"}, | |
94 "vit_b_16": {"type": "vision_transformer", "model_variant": "b_16"}, | |
95 "vit_b_32": {"type": "vision_transformer", "model_variant": "b_32"}, | |
96 "vit_l_16": {"type": "vision_transformer", "model_variant": "l_16"}, | |
97 "vit_l_32": {"type": "vision_transformer", "model_variant": "l_32"}, | |
98 "vit_h_14": {"type": "vision_transformer", "model_variant": "h_14"}, | |
99 "convnext_tiny": {"type": "convnext", "model_variant": "tiny"}, | |
100 "convnext_small": {"type": "convnext", "model_variant": "small"}, | |
101 "convnext_base": {"type": "convnext", "model_variant": "base"}, | |
102 "convnext_large": {"type": "convnext", "model_variant": "large"}, | |
103 "maxvit_t": {"type": "maxvit", "model_variant": "t"}, | |
104 "alexnet": {"type": "alexnet"}, | |
105 "googlenet": {"type": "googlenet"}, | |
106 "inception_v3": {"type": "inception_v3"}, | |
107 "mobilenet_v2": {"type": "mobilenet_v2"}, | |
108 "mobilenet_v3_large": {"type": "mobilenet_v3_large"}, | |
109 "mobilenet_v3_small": {"type": "mobilenet_v3_small"}, | |
110 } | |
111 METRIC_DISPLAY_NAMES = { | |
112 "accuracy": "Accuracy", | |
113 "accuracy_micro": "Accuracy-Micro", | |
114 "loss": "Loss", | |
115 "roc_auc": "ROC-AUC", | |
116 "roc_auc_macro": "ROC-AUC-Macro", | |
117 "roc_auc_micro": "ROC-AUC-Micro", | |
118 "hits_at_k": "Hits at K", | |
119 "precision": "Precision", | |
120 "recall": "Recall", | |
121 "specificity": "Specificity", | |
122 "kappa_score": "Cohen's Kappa", | |
123 "token_accuracy": "Token Accuracy", | |
124 "avg_precision_macro": "Precision-Macro", | |
125 "avg_recall_macro": "Recall-Macro", | |
126 "avg_f1_score_macro": "F1-score-Macro", | |
127 "avg_precision_micro": "Precision-Micro", | |
128 "avg_recall_micro": "Recall-Micro", | |
129 "avg_f1_score_micro": "F1-score-Micro", | |
130 "avg_precision_weighted": "Precision-Weighted", | |
131 "avg_recall_weighted": "Recall-Weighted", | |
132 "avg_f1_score_weighted": "F1-score-Weighted", | |
133 "average_precision_macro": " Precision-Average-Macro", | |
134 "average_precision_micro": "Precision-Average-Micro", | |
135 "average_precision_samples": "Precision-Average-Samples", | |
136 } | |
137 | 41 |
138 # --- Logging Setup --- | 42 # --- Logging Setup --- |
139 logging.basicConfig( | 43 logging.basicConfig( |
140 level=logging.INFO, | 44 level=logging.INFO, |
141 format="%(asctime)s %(levelname)s %(name)s: %(message)s", | 45 format='%(asctime)s %(levelname)s %(name)s: %(message)s', |
142 ) | 46 ) |
143 logger = logging.getLogger("ImageLearner") | 47 logger = logging.getLogger("ImageLearner") |
144 | |
145 | |
146 def get_metrics_help_modal() -> str: | |
147 modal_html = """ | |
148 <div id="metricsHelpModal" class="modal"> | |
149 <div class="modal-content"> | |
150 <span class="close">×</span> | |
151 <h2>Model Evaluation Metrics — Help Guide</h2> | |
152 <div class="metrics-guide"> | |
153 <h3>1) General Metrics</h3> | |
154 <p><strong>Loss:</strong> Measures the difference between predicted and actual values. Lower is better. Often used for optimization during training.</p> | |
155 <p><strong>Accuracy:</strong> Proportion of correct predictions among all predictions. Simple but can be misleading for imbalanced datasets.</p> | |
156 <p><strong>Micro Accuracy:</strong> Calculates accuracy by summing up all individual true positives and true negatives across all classes, making it suitable for multiclass or multilabel problems.</p> | |
157 <p><strong>Token Accuracy:</strong> Measures how often the predicted tokens (e.g., in sequences) match the true tokens. Useful in sequence prediction tasks like NLP.</p> | |
158 <h3>2) Precision, Recall & Specificity</h3> | |
159 <p><strong>Precision:</strong> Out of all positive predictions, how many were correct. Precision = TP / (TP + FP). Helps when false positives are costly.</p> | |
160 <p><strong>Recall (Sensitivity):</strong> Out of all actual positives, how many were predicted correctly. Recall = TP / (TP + FN). Important when missing positives is risky.</p> | |
161 <p><strong>Specificity:</strong> True negative rate. Measures how well the model identifies negatives. Specificity = TN / (TN + FP). Useful in medical testing to avoid false alarms.</p> | |
162 <h3>3) Macro, Micro, and Weighted Averages</h3> | |
163 <p><strong>Macro Precision / Recall / F1:</strong> Averages the metric across all classes, treating each class equally, regardless of class frequency. Best when class sizes are balanced.</p> | |
164 <p><strong>Micro Precision / Recall / F1:</strong> Aggregates TP, FP, FN across all classes before computing the metric. Gives a global view and is ideal for class-imbalanced problems.</p> | |
165 <p><strong>Weighted Precision / Recall / F1:</strong> Averages each metric across classes, weighted by the number of true instances per class. Balances importance of classes based on frequency.</p> | |
166 <h3>4) Average Precision (PR-AUC Variants)</h3> | |
167 <p><strong>Average Precision Macro:</strong> Precision-Recall AUC averaged across all classes equally. Useful for balanced multi-class problems.</p> | |
168 <p><strong>Average Precision Micro:</strong> Global Precision-Recall AUC using all instances. Best for imbalanced data or multi-label classification.</p> | |
169 <p><strong>Average Precision Samples:</strong> Precision-Recall AUC averaged across individual samples (not classes). Ideal for multi-label problems where each sample can belong to multiple classes.</p> | |
170 <h3>5) ROC-AUC Variants</h3> | |
171 <p><strong>ROC-AUC:</strong> Measures model's ability to distinguish between classes. AUC = 1 is perfect; 0.5 is random guessing. Use for binary classification.</p> | |
172 <p><strong>Macro ROC-AUC:</strong> Averages the AUC across all classes equally. Suitable when classes are balanced and of equal importance.</p> | |
173 <p><strong>Micro ROC-AUC:</strong> Computes AUC from aggregated predictions across all classes. Useful in multiclass or multilabel settings with imbalance.</p> | |
174 <h3>6) Ranking Metrics</h3> | |
175 <p><strong>Hits at K:</strong> Measures whether the true label is among the top-K predictions. Common in recommendation systems and retrieval tasks.</p> | |
176 <h3>7) Confusion Matrix Stats (Per Class)</h3> | |
177 <p><strong>True Positives / Negatives (TP / TN):</strong> Correct predictions for positives and negatives respectively.</p> | |
178 <p><strong>False Positives / Negatives (FP / FN):</strong> Incorrect predictions — false alarms and missed detections.</p> | |
179 <h3>8) Other Useful Metrics</h3> | |
180 <p><strong>Cohen's Kappa:</strong> Measures agreement between predicted and actual values adjusted for chance. Useful for multiclass classification with imbalanced labels.</p> | |
181 <p><strong>Matthews Correlation Coefficient (MCC):</strong> Balanced measure of prediction quality that takes into account TP, TN, FP, and FN. Particularly effective for imbalanced datasets.</p> | |
182 <h3>9) Metric Recommendations</h3> | |
183 <ul> | |
184 <li>Use <strong>Accuracy + F1</strong> for balanced data.</li> | |
185 <li>Use <strong>Precision, Recall, ROC-AUC</strong> for imbalanced datasets.</li> | |
186 <li>Use <strong>Average Precision Micro</strong> for multilabel or class-imbalanced problems.</li> | |
187 <li>Use <strong>Macro scores</strong> when all classes should be treated equally.</li> | |
188 <li>Use <strong>Weighted scores</strong> when class imbalance should be accounted for without ignoring small classes.</li> | |
189 <li>Use <strong>Confusion Matrix stats</strong> to analyze class-wise performance.</li> | |
190 <li>Use <strong>Hits at K</strong> for recommendation or ranking-based tasks.</li> | |
191 </ul> | |
192 </div> | |
193 </div> | |
194 </div> | |
195 """ | |
196 modal_css = """ | |
197 <style> | |
198 .modal { | |
199 display: none; | |
200 position: fixed; | |
201 z-index: 1; | |
202 left: 0; | |
203 top: 0; | |
204 width: 100%; | |
205 height: 100%; | |
206 overflow: auto; | |
207 background-color: rgba(0,0,0,0.4); | |
208 } | |
209 .modal-content { | |
210 background-color: #fefefe; | |
211 margin: 15% auto; | |
212 padding: 20px; | |
213 border: 1px solid #888; | |
214 width: 80%; | |
215 max-width: 800px; | |
216 } | |
217 .close { | |
218 color: #aaa; | |
219 float: right; | |
220 font-size: 28px; | |
221 font-weight: bold; | |
222 } | |
223 .close:hover, | |
224 .close:focus { | |
225 color: black; | |
226 text-decoration: none; | |
227 cursor: pointer; | |
228 } | |
229 .metrics-guide h3 { | |
230 margin-top: 20px; | |
231 } | |
232 .metrics-guide p { | |
233 margin: 5px 0; | |
234 } | |
235 .metrics-guide ul { | |
236 margin: 10px 0; | |
237 padding-left: 20px; | |
238 } | |
239 </style> | |
240 """ | |
241 modal_js = """ | |
242 <script> | |
243 document.addEventListener("DOMContentLoaded", function() { | |
244 var modal = document.getElementById("metricsHelpModal"); | |
245 var closeBtn = document.getElementsByClassName("close")[0]; | |
246 | |
247 document.querySelectorAll(".openMetricsHelp").forEach(btn => { | |
248 btn.onclick = function() { | |
249 modal.style.display = "block"; | |
250 }; | |
251 }); | |
252 | |
253 if (closeBtn) { | |
254 closeBtn.onclick = function() { | |
255 modal.style.display = "none"; | |
256 }; | |
257 } | |
258 | |
259 window.onclick = function(event) { | |
260 if (event.target == modal) { | |
261 modal.style.display = "none"; | |
262 } | |
263 } | |
264 }); | |
265 </script> | |
266 """ | |
267 return modal_css + modal_html + modal_js | |
268 | 48 |
269 | 49 |
270 def format_config_table_html( | 50 def format_config_table_html( |
271 config: dict, | 51 config: dict, |
272 split_info: Optional[str] = None, | 52 split_info: Optional[str] = None, |
273 training_progress: dict = None, | 53 training_progress: dict = None, |
274 ) -> str: | 54 ) -> str: |
275 display_keys = [ | 55 display_keys = [ |
56 "task_type", | |
276 "model_name", | 57 "model_name", |
277 "epochs", | 58 "epochs", |
278 "batch_size", | 59 "batch_size", |
279 "fine_tune", | 60 "fine_tune", |
280 "use_pretrained", | 61 "use_pretrained", |
285 | 66 |
286 rows = [] | 67 rows = [] |
287 | 68 |
288 for key in display_keys: | 69 for key in display_keys: |
289 val = config.get(key, "N/A") | 70 val = config.get(key, "N/A") |
71 if key == "task_type": | |
72 val = val.title() if isinstance(val, str) else val | |
290 if key == "batch_size": | 73 if key == "batch_size": |
291 if val is not None: | 74 if val is not None: |
292 val = int(val) | 75 val = int(val) |
293 else: | 76 else: |
294 if training_progress: | 77 if training_progress: |
346 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>" | 129 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>" |
347 f"{val}</td>" | 130 f"{val}</td>" |
348 f"</tr>" | 131 f"</tr>" |
349 ) | 132 ) |
350 | 133 |
134 aug_cfg = config.get("augmentation") | |
135 if aug_cfg: | |
136 types = [str(a.get("type", "")) for a in aug_cfg] | |
137 aug_val = ", ".join(types) | |
138 rows.append( | |
139 "<tr>" | |
140 "<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Augmentation</td>" | |
141 "<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>" | |
142 f"{aug_val}</td>" | |
143 "</tr>" | |
144 ) | |
145 | |
351 if split_info: | 146 if split_info: |
352 rows.append( | 147 rows.append( |
353 f"<tr>" | 148 f"<tr>" |
354 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" | 149 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" |
355 f"Data Split</td>" | 150 f"Data Split</td>" |
369 "Value</th>" | 164 "Value</th>" |
370 "</tr></thead><tbody>" + "".join(rows) + "</tbody></table></div><br>" | 165 "</tr></thead><tbody>" + "".join(rows) + "</tbody></table></div><br>" |
371 "<p style='text-align: center; font-size: 0.9em;'>" | 166 "<p style='text-align: center; font-size: 0.9em;'>" |
372 "Model trained using Ludwig.<br>" | 167 "Model trained using Ludwig.<br>" |
373 "If want to learn more about Ludwig default settings," | 168 "If want to learn more about Ludwig default settings," |
374 "please check the their <a href='https://ludwig.ai' target='_blank'>" | 169 "please check their <a href='https://ludwig.ai' target='_blank'>" |
375 "website(ludwig.ai)</a>." | 170 "website(ludwig.ai)</a>." |
376 "</p><hr>" | 171 "</p><hr>" |
377 ) | 172 ) |
378 | 173 |
379 | 174 |
380 def detect_output_type(test_stats): | 175 def detect_output_type(test_stats): |
381 """Detects if the output type is 'binary' or 'category' based on test statistics.""" | 176 """Detects if the output type is 'binary' or 'category' based on test statistics.""" |
382 label_stats = test_stats.get("label", {}) | 177 label_stats = test_stats.get("label", {}) |
178 if "mean_squared_error" in label_stats: | |
179 return "regression" | |
383 per_class = label_stats.get("per_class_stats", {}) | 180 per_class = label_stats.get("per_class_stats", {}) |
384 if len(per_class) == 2: | 181 if len(per_class) == 2: |
385 return "binary" | 182 return "binary" |
386 return "category" | 183 return "category" |
387 | 184 |
418 "precision": get_last_value(label_stats, "precision"), | 215 "precision": get_last_value(label_stats, "precision"), |
419 "recall": get_last_value(label_stats, "recall"), | 216 "recall": get_last_value(label_stats, "recall"), |
420 "specificity": get_last_value(label_stats, "specificity"), | 217 "specificity": get_last_value(label_stats, "specificity"), |
421 "roc_auc": get_last_value(label_stats, "roc_auc"), | 218 "roc_auc": get_last_value(label_stats, "roc_auc"), |
422 } | 219 } |
220 elif output_type == "regression": | |
221 metrics[split] = { | |
222 "loss": get_last_value(label_stats, "loss"), | |
223 "mean_absolute_error": get_last_value( | |
224 label_stats, "mean_absolute_error" | |
225 ), | |
226 "mean_absolute_percentage_error": get_last_value( | |
227 label_stats, "mean_absolute_percentage_error" | |
228 ), | |
229 "mean_squared_error": get_last_value(label_stats, "mean_squared_error"), | |
230 "root_mean_squared_error": get_last_value( | |
231 label_stats, "root_mean_squared_error" | |
232 ), | |
233 "root_mean_squared_percentage_error": get_last_value( | |
234 label_stats, "root_mean_squared_percentage_error" | |
235 ), | |
236 "r2": get_last_value(label_stats, "r2"), | |
237 } | |
423 else: | 238 else: |
424 metrics[split] = { | 239 metrics[split] = { |
425 "accuracy": get_last_value(label_stats, "accuracy"), | 240 "accuracy": get_last_value(label_stats, "accuracy"), |
426 "accuracy_micro": get_last_value(label_stats, "accuracy_micro"), | 241 "accuracy_micro": get_last_value(label_stats, "accuracy_micro"), |
427 "loss": get_last_value(label_stats, "loss"), | 242 "loss": get_last_value(label_stats, "loss"), |
563 ) | 378 ) |
564 html += "</tbody></table></div><br>" | 379 html += "</tbody></table></div><br>" |
565 return html | 380 return html |
566 | 381 |
567 | 382 |
568 def format_test_merged_stats_table_html(test_metrics: Dict[str, Optional[float]]) -> str: | 383 def format_test_merged_stats_table_html( |
384 test_metrics: Dict[str, Optional[float]], | |
385 ) -> str: | |
569 """Formats an HTML table for test metrics.""" | 386 """Formats an HTML table for test metrics.""" |
570 rows = [] | 387 rows = [] |
571 for key in sorted(test_metrics.keys()): | 388 for key in sorted(test_metrics.keys()): |
572 display_name = METRIC_DISPLAY_NAMES.get(key, key.replace("_", " ").title()) | 389 display_name = METRIC_DISPLAY_NAMES.get(key, key.replace("_", " ").title()) |
573 value = test_metrics[key] | 390 value = test_metrics[key] |
594 "padding: 10px; border: 1px solid #ccc; text-align: center; " | 411 "padding: 10px; border: 1px solid #ccc; text-align: center; " |
595 "white-space: nowrap;", | 412 "white-space: nowrap;", |
596 ) | 413 ) |
597 html += "</tbody></table></div><br>" | 414 html += "</tbody></table></div><br>" |
598 return html | 415 return html |
599 | |
600 | |
601 def build_tabbed_html(metrics_html: str, train_val_html: str, test_html: str) -> str: | |
602 return f""" | |
603 <style> | |
604 .tabs {{ | |
605 display: flex; | |
606 border-bottom: 2px solid #ccc; | |
607 margin-bottom: 1rem; | |
608 }} | |
609 .tab {{ | |
610 padding: 10px 20px; | |
611 cursor: pointer; | |
612 border: 1px solid #ccc; | |
613 border-bottom: none; | |
614 background: #f9f9f9; | |
615 margin-right: 5px; | |
616 border-top-left-radius: 8px; | |
617 border-top-right-radius: 8px; | |
618 }} | |
619 .tab.active {{ | |
620 background: white; | |
621 font-weight: bold; | |
622 }} | |
623 .tab-content {{ | |
624 display: none; | |
625 padding: 20px; | |
626 border: 1px solid #ccc; | |
627 border-top: none; | |
628 }} | |
629 .tab-content.active {{ | |
630 display: block; | |
631 }} | |
632 </style> | |
633 <div class="tabs"> | |
634 <div class="tab active" onclick="showTab('metrics')"> Config & Results Summary</div> | |
635 <div class="tab" onclick="showTab('trainval')"> Train/Validation Results</div> | |
636 <div class="tab" onclick="showTab('test')"> Test Results</div> | |
637 </div> | |
638 <div id="metrics" class="tab-content active"> | |
639 {metrics_html} | |
640 </div> | |
641 <div id="trainval" class="tab-content"> | |
642 {train_val_html} | |
643 </div> | |
644 <div id="test" class="tab-content"> | |
645 {test_html} | |
646 </div> | |
647 <script> | |
648 function showTab(id) {{ | |
649 document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active')); | |
650 document.querySelectorAll('.tab').forEach(el => el.classList.remove('active')); | |
651 document.getElementById(id).classList.add('active'); | |
652 document.querySelector(`.tab[onclick*="${{id}}"]`).classList.add('active'); | |
653 }} | |
654 </script> | |
655 """ | |
656 | 416 |
657 | 417 |
658 def split_data_0_2( | 418 def split_data_0_2( |
659 df: pd.DataFrame, | 419 df: pd.DataFrame, |
660 split_column: str, | 420 split_column: str, |
725 output_dir: Path, | 485 output_dir: Path, |
726 random_seed: int, | 486 random_seed: int, |
727 ) -> None: | 487 ) -> None: |
728 ... | 488 ... |
729 | 489 |
730 def generate_plots( | 490 def generate_plots(self, output_dir: Path) -> None: |
731 self, | |
732 output_dir: Path | |
733 ) -> None: | |
734 ... | 491 ... |
735 | 492 |
736 def generate_html_report( | 493 def generate_html_report( |
737 self, | 494 self, |
738 title: str, | 495 title: str, |
739 output_dir: str | 496 output_dir: str, |
497 config: Dict[str, Any], | |
498 split_info: str, | |
740 ) -> Path: | 499 ) -> Path: |
741 ... | 500 ... |
742 | 501 |
743 | 502 |
744 class LudwigDirectBackend: | 503 class LudwigDirectBackend: |
747 def prepare_config( | 506 def prepare_config( |
748 self, | 507 self, |
749 config_params: Dict[str, Any], | 508 config_params: Dict[str, Any], |
750 split_config: Dict[str, Any], | 509 split_config: Dict[str, Any], |
751 ) -> str: | 510 ) -> str: |
752 """Build and serialize the Ludwig YAML configuration.""" | |
753 logger.info("LudwigDirectBackend: Preparing YAML configuration.") | 511 logger.info("LudwigDirectBackend: Preparing YAML configuration.") |
754 | 512 |
755 model_name = config_params.get("model_name", "resnet18") | 513 model_name = config_params.get("model_name", "resnet18") |
756 use_pretrained = config_params.get("use_pretrained", False) | 514 use_pretrained = config_params.get("use_pretrained", False) |
757 fine_tune = config_params.get("fine_tune", False) | 515 fine_tune = config_params.get("fine_tune", False) |
516 if use_pretrained: | |
517 trainable = bool(fine_tune) | |
518 else: | |
519 trainable = True | |
758 epochs = config_params.get("epochs", 10) | 520 epochs = config_params.get("epochs", 10) |
759 batch_size = config_params.get("batch_size") | 521 batch_size = config_params.get("batch_size") |
760 num_processes = config_params.get("preprocessing_num_processes", 1) | 522 num_processes = config_params.get("preprocessing_num_processes", 1) |
761 early_stop = config_params.get("early_stop", None) | 523 early_stop = config_params.get("early_stop", None) |
762 learning_rate = config_params.get("learning_rate") | 524 learning_rate = config_params.get("learning_rate") |
763 learning_rate = "auto" if learning_rate is None else float(learning_rate) | 525 learning_rate = "auto" if learning_rate is None else float(learning_rate) |
764 trainable = fine_tune or (not use_pretrained) | |
765 if not use_pretrained and not trainable: | |
766 logger.warning("trainable=False; use_pretrained=False is ignored.") | |
767 logger.warning("Setting trainable=True to train the model from scratch.") | |
768 trainable = True | |
769 raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name) | 526 raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name) |
770 if isinstance(raw_encoder, dict): | 527 if isinstance(raw_encoder, dict): |
771 encoder_config = { | 528 encoder_config = { |
772 **raw_encoder, | 529 **raw_encoder, |
773 "use_pretrained": use_pretrained, | 530 "use_pretrained": use_pretrained, |
777 encoder_config = {"type": raw_encoder} | 534 encoder_config = {"type": raw_encoder} |
778 | 535 |
779 batch_size_cfg = batch_size or "auto" | 536 batch_size_cfg = batch_size or "auto" |
780 | 537 |
781 label_column_path = config_params.get("label_column_data_path") | 538 label_column_path = config_params.get("label_column_data_path") |
539 label_series = None | |
782 if label_column_path is not None and Path(label_column_path).exists(): | 540 if label_column_path is not None and Path(label_column_path).exists(): |
783 try: | 541 try: |
784 label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME] | 542 label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME] |
785 num_unique_labels = label_series.nunique() | |
786 except Exception as e: | 543 except Exception as e: |
787 logger.warning( | 544 logger.warning(f"Could not read label column for task detection: {e}") |
788 f"Could not determine label cardinality, defaulting to 'binary': {e}" | 545 |
789 ) | 546 if ( |
790 num_unique_labels = 2 | 547 label_series is not None |
548 and ptypes.is_numeric_dtype(label_series.dtype) | |
549 and label_series.nunique() > 10 | |
550 ): | |
551 task_type = "regression" | |
791 else: | 552 else: |
792 logger.warning( | 553 task_type = "classification" |
793 "label_column_data_path not provided, defaulting to 'binary'" | 554 |
794 ) | 555 config_params["task_type"] = task_type |
795 num_unique_labels = 2 | 556 |
796 | 557 image_feat: Dict[str, Any] = { |
797 output_type = "binary" if num_unique_labels == 2 else "category" | 558 "name": IMAGE_PATH_COLUMN_NAME, |
559 "type": "image", | |
560 "encoder": encoder_config, | |
561 } | |
562 if config_params.get("augmentation") is not None: | |
563 image_feat["augmentation"] = config_params["augmentation"] | |
564 | |
565 if task_type == "regression": | |
566 output_feat = { | |
567 "name": LABEL_COLUMN_NAME, | |
568 "type": "number", | |
569 "decoder": {"type": "regressor"}, | |
570 "loss": {"type": "mean_squared_error"}, | |
571 "evaluation": { | |
572 "metrics": [ | |
573 "mean_squared_error", | |
574 "mean_absolute_error", | |
575 "r2", | |
576 ] | |
577 }, | |
578 } | |
579 val_metric = config_params.get("validation_metric", "mean_squared_error") | |
580 | |
581 else: | |
582 num_unique_labels = ( | |
583 label_series.nunique() if label_series is not None else 2 | |
584 ) | |
585 output_type = "binary" if num_unique_labels == 2 else "category" | |
586 output_feat = {"name": LABEL_COLUMN_NAME, "type": output_type} | |
587 val_metric = None | |
798 | 588 |
799 conf: Dict[str, Any] = { | 589 conf: Dict[str, Any] = { |
800 "model_type": "ecd", | 590 "model_type": "ecd", |
801 "input_features": [ | 591 "input_features": [image_feat], |
802 { | 592 "output_features": [output_feat], |
803 "name": IMAGE_PATH_COLUMN_NAME, | |
804 "type": "image", | |
805 "encoder": encoder_config, | |
806 } | |
807 ], | |
808 "output_features": [{"name": LABEL_COLUMN_NAME, "type": output_type}], | |
809 "combiner": {"type": "concat"}, | 593 "combiner": {"type": "concat"}, |
810 "trainer": { | 594 "trainer": { |
811 "epochs": epochs, | 595 "epochs": epochs, |
812 "early_stop": early_stop, | 596 "early_stop": early_stop, |
813 "batch_size": batch_size_cfg, | 597 "batch_size": batch_size_cfg, |
814 "learning_rate": learning_rate, | 598 "learning_rate": learning_rate, |
599 # only set validation_metric for regression | |
600 **({"validation_metric": val_metric} if val_metric else {}), | |
815 }, | 601 }, |
816 "preprocessing": { | 602 "preprocessing": { |
817 "split": split_config, | 603 "split": split_config, |
818 "num_processes": num_processes, | 604 "num_processes": num_processes, |
819 "in_memory": False, | 605 "in_memory": False, |
874 "LudwigDirectBackend: Experiment execution error.", | 660 "LudwigDirectBackend: Experiment execution error.", |
875 exc_info=True, | 661 exc_info=True, |
876 ) | 662 ) |
877 raise | 663 raise |
878 | 664 |
879 def get_training_process(self, output_dir) -> float: | 665 def get_training_process(self, output_dir) -> Optional[Dict[str, Any]]: |
880 """Retrieve the learning rate used in the most recent Ludwig run.""" | 666 """Retrieve the learning rate used in the most recent Ludwig run.""" |
881 output_dir = Path(output_dir) | 667 output_dir = Path(output_dir) |
882 exp_dirs = sorted( | 668 exp_dirs = sorted( |
883 output_dir.glob("experiment_run*"), | 669 output_dir.glob("experiment_run*"), |
884 key=lambda p: p.stat().st_mtime, | 670 key=lambda p: p.stat().st_mtime, |
998 stats = json.load(f) | 784 stats = json.load(f) |
999 output_feature = next(iter(stats.keys()), "") | 785 output_feature = next(iter(stats.keys()), "") |
1000 | 786 |
1001 viz_registry = get_visualizations_registry() | 787 viz_registry = get_visualizations_registry() |
1002 for viz_name, viz_func in viz_registry.items(): | 788 for viz_name, viz_func in viz_registry.items(): |
1003 viz_dir_plot = None | |
1004 if viz_name in train_plots: | 789 if viz_name in train_plots: |
1005 viz_dir_plot = train_viz | 790 viz_dir_plot = train_viz |
1006 elif viz_name in test_plots: | 791 elif viz_name in test_plots: |
1007 viz_dir_plot = test_viz | 792 viz_dir_plot = test_viz |
793 else: | |
794 continue | |
1008 | 795 |
1009 try: | 796 try: |
1010 viz_func( | 797 viz_func( |
1011 training_statistics=[training_stats] if training_stats else [], | 798 training_statistics=[training_stats] if training_stats else [], |
1012 test_statistics=[test_stats] if test_stats else [], | 799 test_statistics=[test_stats] if test_stats else [], |
1038 """Assemble an HTML report from visualizations under train_val/ and test/ folders.""" | 825 """Assemble an HTML report from visualizations under train_val/ and test/ folders.""" |
1039 cwd = Path.cwd() | 826 cwd = Path.cwd() |
1040 report_name = title.lower().replace(" ", "_") + "_report.html" | 827 report_name = title.lower().replace(" ", "_") + "_report.html" |
1041 report_path = cwd / report_name | 828 report_path = cwd / report_name |
1042 output_dir = Path(output_dir) | 829 output_dir = Path(output_dir) |
830 output_type = None | |
1043 | 831 |
1044 exp_dirs = sorted( | 832 exp_dirs = sorted( |
1045 output_dir.glob("experiment_run*"), | 833 output_dir.glob("experiment_run*"), |
1046 key=lambda p: p.stat().st_mtime, | 834 key=lambda p: p.stat().st_mtime, |
1047 ) | 835 ) |
1057 html += f"<h1>{title}</h1>" | 845 html += f"<h1>{title}</h1>" |
1058 | 846 |
1059 metrics_html = "" | 847 metrics_html = "" |
1060 train_val_metrics_html = "" | 848 train_val_metrics_html = "" |
1061 test_metrics_html = "" | 849 test_metrics_html = "" |
1062 | |
1063 try: | 850 try: |
1064 train_stats_path = exp_dir / "training_statistics.json" | 851 train_stats_path = exp_dir / "training_statistics.json" |
1065 test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME | 852 test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME |
1066 if train_stats_path.exists() and test_stats_path.exists(): | 853 if train_stats_path.exists() and test_stats_path.exists(): |
1067 with open(train_stats_path) as f: | 854 with open(train_stats_path) as f: |
1068 train_stats = json.load(f) | 855 train_stats = json.load(f) |
1069 with open(test_stats_path) as f: | 856 with open(test_stats_path) as f: |
1070 test_stats = json.load(f) | 857 test_stats = json.load(f) |
1071 output_type = detect_output_type(test_stats) | 858 output_type = detect_output_type(test_stats) |
1072 all_metrics = extract_metrics_from_json( | |
1073 train_stats, | |
1074 test_stats, | |
1075 output_type, | |
1076 ) | |
1077 metrics_html = format_stats_table_html(train_stats, test_stats) | 859 metrics_html = format_stats_table_html(train_stats, test_stats) |
1078 train_val_metrics_html = format_train_val_stats_table_html( | 860 train_val_metrics_html = format_train_val_stats_table_html( |
1079 train_stats, | 861 train_stats, test_stats |
1080 test_stats, | |
1081 ) | 862 ) |
1082 test_metrics_html = format_test_merged_stats_table_html( | 863 test_metrics_html = format_test_merged_stats_table_html( |
1083 all_metrics["test"], | 864 extract_metrics_from_json(train_stats, test_stats, output_type)[ |
865 "test" | |
866 ] | |
1084 ) | 867 ) |
1085 except Exception as e: | 868 except Exception as e: |
1086 logger.warning( | 869 logger.warning( |
1087 f"Could not load stats for HTML report: {type(e).__name__}: {e}" | 870 f"Could not load stats for HTML report: {type(e).__name__}: {e}" |
1088 ) | 871 ) |
1089 | 872 |
1090 config_html = "" | 873 config_html = "" |
1091 training_progress = self.get_training_process(output_dir) | 874 training_progress = self.get_training_process(output_dir) |
1092 try: | 875 try: |
1093 config_html = format_config_table_html(config, split_info, training_progress) | 876 config_html = format_config_table_html( |
877 config, split_info, training_progress | |
878 ) | |
1094 except Exception as e: | 879 except Exception as e: |
1095 logger.warning(f"Could not load config for HTML report: {e}") | 880 logger.warning(f"Could not load config for HTML report: {e}") |
1096 | 881 |
1097 def render_img_section(title: str, dir_path: Path, output_type: str = None) -> str: | 882 def render_img_section( |
883 title: str, dir_path: Path, output_type: str = None | |
884 ) -> str: | |
1098 if not dir_path.exists(): | 885 if not dir_path.exists(): |
1099 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>" | 886 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>" |
1100 | 887 |
1101 imgs = list(dir_path.glob("*.png")) | 888 imgs = list(dir_path.glob("*.png")) |
1102 if not imgs: | 889 if not imgs: |
1139 img_names = {img.name: img for img in imgs if img.name not in unwanted} | 926 img_names = {img.name: img for img in imgs if img.name not in unwanted} |
1140 ordered_imgs = [ | 927 ordered_imgs = [ |
1141 img_names[fname] for fname in display_order if fname in img_names | 928 img_names[fname] for fname in display_order if fname in img_names |
1142 ] | 929 ] |
1143 remaining = sorted( | 930 remaining = sorted( |
1144 [ | 931 [img for img in img_names.values() if img.name not in display_order] |
1145 img | |
1146 for img in img_names.values() | |
1147 if img.name not in display_order | |
1148 ] | |
1149 ) | 932 ) |
1150 imgs = ordered_imgs + remaining | 933 imgs = ordered_imgs + remaining |
1151 | 934 |
1152 else: | 935 else: |
1153 if output_type == "category": | 936 if output_type == "category": |
1171 f"</div>" | 954 f"</div>" |
1172 ) | 955 ) |
1173 section_html += "</div>" | 956 section_html += "</div>" |
1174 return section_html | 957 return section_html |
1175 | 958 |
1176 button_html = """ | 959 tab1_content = config_html + metrics_html |
1177 <button class="help-modal-btn openMetricsHelp">Model Evaluation Metrics — Help Guide</button> | 960 |
1178 <br><br> | 961 tab2_content = train_val_metrics_html + render_img_section( |
1179 <style> | 962 "Training & Validation Visualizations", train_viz_dir |
1180 .help-modal-btn { | 963 ) |
1181 background-color: #17623b; | 964 |
1182 color: #fff; | 965 # --- Predictions vs Ground Truth table --- |
1183 border: none; | 966 preds_section = "" |
1184 border-radius: 24px; | 967 parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME |
1185 padding: 10px 28px; | 968 if parquet_path.exists(): |
1186 font-size: 1.1rem; | 969 try: |
1187 font-weight: bold; | 970 # 1) load predictions from Parquet |
1188 letter-spacing: 0.03em; | 971 df_preds = pd.read_parquet(parquet_path).reset_index(drop=True) |
1189 cursor: pointer; | 972 # assume the column containing your model's prediction is named "prediction" |
1190 transition: background 0.2s, box-shadow 0.2s; | 973 # or contains that substring: |
1191 box-shadow: 0 2px 8px rgba(23,98,59,0.07); | 974 pred_col = next( |
1192 } | 975 (c for c in df_preds.columns if "prediction" in c.lower()), |
1193 .help-modal-btn:hover, .help-modal-btn:focus { | 976 None, |
1194 background-color: #21895e; | 977 ) |
1195 outline: none; | 978 if pred_col is None: |
1196 box-shadow: 0 4px 16px rgba(23,98,59,0.14); | 979 raise ValueError("No prediction column found in Parquet output") |
1197 } | 980 df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"}) |
1198 </style> | 981 |
1199 """ | 982 # 2) load ground truth for the test split from prepared CSV |
1200 tab1_content = button_html + config_html + metrics_html | 983 df_all = pd.read_csv(config["label_column_data_path"]) |
1201 tab2_content = ( | 984 df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][ |
1202 button_html | 985 LABEL_COLUMN_NAME |
1203 + train_val_metrics_html | 986 ].reset_index(drop=True) |
1204 + render_img_section("Training & Validation Visualizations", train_viz_dir) | 987 |
1205 ) | 988 # 3) concatenate side‐by‐side |
989 df_table = pd.concat([df_gt, df_pred], axis=1) | |
990 df_table.columns = [LABEL_COLUMN_NAME, "prediction"] | |
991 | |
992 # 4) render as HTML | |
993 preds_html = df_table.to_html(index=False, classes="predictions-table") | |
994 preds_section = ( | |
995 "<h2 style='text-align: center;'>Predictions vs. Ground Truth</h2>" | |
996 "<div style='overflow-x:auto; margin-bottom:20px;'>" | |
997 + preds_html | |
998 + "</div>" | |
999 ) | |
1000 except Exception as e: | |
1001 logger.warning(f"Could not build Predictions vs GT table: {e}") | |
1002 # Test tab = Metrics + Preds table + Visualizations | |
1003 | |
1206 tab3_content = ( | 1004 tab3_content = ( |
1207 button_html | 1005 test_metrics_html |
1208 + test_metrics_html | 1006 + preds_section |
1209 + render_img_section("Test Visualizations", test_viz_dir, output_type) | 1007 + render_img_section("Test Visualizations", test_viz_dir, output_type) |
1210 ) | 1008 ) |
1211 | 1009 |
1010 # assemble the tabs and help modal | |
1212 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) | 1011 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) |
1213 modal_html = get_metrics_help_modal() | 1012 modal_html = get_metrics_help_modal() |
1214 html += tabbed_html + modal_html | 1013 html += tabbed_html + modal_html + get_html_closing() |
1215 html += get_html_closing() | |
1216 | 1014 |
1217 try: | 1015 try: |
1218 with open(report_path, "w") as f: | 1016 with open(report_path, "w") as f: |
1219 f.write(html) | 1017 f.write(html) |
1220 logger.info(f"HTML report generated at: {report_path}") | 1018 logger.info(f"HTML report generated at: {report_path}") |
1261 logger.info("Image extraction complete.") | 1059 logger.info("Image extraction complete.") |
1262 except Exception: | 1060 except Exception: |
1263 logger.error("Error extracting zip file", exc_info=True) | 1061 logger.error("Error extracting zip file", exc_info=True) |
1264 raise | 1062 raise |
1265 | 1063 |
1266 def _prepare_data(self) -> Tuple[Path, Dict[str, Any]]: | 1064 def _prepare_data(self) -> Tuple[Path, Dict[str, Any], str]: |
1267 """Load CSV, update image paths, handle splits, and write prepared CSV.""" | 1065 """Load CSV, update image paths, handle splits, and write prepared CSV.""" |
1268 if not self.temp_dir or not self.image_extract_dir: | 1066 if not self.temp_dir or not self.image_extract_dir: |
1269 raise RuntimeError("Temp dirs not initialized before data prep.") | 1067 raise RuntimeError("Temp dirs not initialized before data prep.") |
1270 | 1068 |
1271 try: | 1069 try: |
1300 f"No split column in CSV. Used random split: " | 1098 f"No split column in CSV. Used random split: " |
1301 f"{[int(p * 100) for p in self.args.split_probabilities]}% " | 1099 f"{[int(p * 100) for p in self.args.split_probabilities]}% " |
1302 f"for train/val/test." | 1100 f"for train/val/test." |
1303 ) | 1101 ) |
1304 | 1102 |
1305 final_csv = TEMP_CSV_FILENAME | 1103 final_csv = self.temp_dir / TEMP_CSV_FILENAME |
1306 try: | 1104 try: |
1105 | |
1307 df.to_csv(final_csv, index=False) | 1106 df.to_csv(final_csv, index=False) |
1308 logger.info(f"Saved prepared data to {final_csv}") | 1107 logger.info(f"Saved prepared data to {final_csv}") |
1309 except Exception: | 1108 except Exception: |
1310 logger.error("Error saving prepared CSV", exc_info=True) | 1109 logger.error("Error saving prepared CSV", exc_info=True) |
1311 raise | 1110 raise |
1312 | 1111 |
1313 return final_csv, split_config, split_info | 1112 return final_csv, split_config, split_info |
1314 | 1113 |
1315 def _process_fixed_split(self, df: pd.DataFrame) -> Dict[str, Any]: | 1114 def _process_fixed_split( |
1115 self, df: pd.DataFrame | |
1116 ) -> Tuple[pd.DataFrame, Dict[str, Any], str]: | |
1316 """Process a fixed split column (0=train,1=val,2=test).""" | 1117 """Process a fixed split column (0=train,1=val,2=test).""" |
1317 logger.info(f"Fixed split column '{SPLIT_COLUMN_NAME}' detected.") | 1118 logger.info(f"Fixed split column '{SPLIT_COLUMN_NAME}' detected.") |
1318 try: | 1119 try: |
1319 col = df[SPLIT_COLUMN_NAME] | 1120 col = df[SPLIT_COLUMN_NAME] |
1320 df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype( | 1121 df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype( |
1382 "split_probabilities": self.args.split_probabilities, | 1183 "split_probabilities": self.args.split_probabilities, |
1383 "learning_rate": self.args.learning_rate, | 1184 "learning_rate": self.args.learning_rate, |
1384 "random_seed": self.args.random_seed, | 1185 "random_seed": self.args.random_seed, |
1385 "early_stop": self.args.early_stop, | 1186 "early_stop": self.args.early_stop, |
1386 "label_column_data_path": csv_path, | 1187 "label_column_data_path": csv_path, |
1188 "augmentation": self.args.augmentation, | |
1387 } | 1189 } |
1388 yaml_str = self.backend.prepare_config(backend_args, split_cfg) | 1190 yaml_str = self.backend.prepare_config(backend_args, split_cfg) |
1389 | 1191 |
1390 config_file = self.temp_dir / TEMP_CONFIG_FILENAME | 1192 config_file = self.temp_dir / TEMP_CONFIG_FILENAME |
1391 config_file.write_text(yaml_str) | 1193 config_file.write_text(yaml_str) |
1420 return float(s) | 1222 return float(s) |
1421 except (TypeError, ValueError): | 1223 except (TypeError, ValueError): |
1422 return None | 1224 return None |
1423 | 1225 |
1424 | 1226 |
1227 def aug_parse(aug_string: str): | |
1228 """ | |
1229 Parse comma-separated augmentation keys into Ludwig augmentation dicts. | |
1230 Raises ValueError on unknown key. | |
1231 """ | |
1232 mapping = { | |
1233 "random_horizontal_flip": {"type": "random_horizontal_flip"}, | |
1234 "random_vertical_flip": {"type": "random_vertical_flip"}, | |
1235 "random_rotate": {"type": "random_rotate", "degree": 10}, | |
1236 "random_blur": {"type": "random_blur", "kernel_size": 3}, | |
1237 "random_brightness": {"type": "random_brightness", "min": 0.5, "max": 2.0}, | |
1238 "random_contrast": {"type": "random_contrast", "min": 0.5, "max": 2.0}, | |
1239 } | |
1240 aug_list = [] | |
1241 for tok in aug_string.split(","): | |
1242 key = tok.strip() | |
1243 if key not in mapping: | |
1244 valid = ", ".join(mapping.keys()) | |
1245 raise ValueError(f"Unknown augmentation '{key}'. Valid choices: {valid}") | |
1246 aug_list.append(mapping[key]) | |
1247 return aug_list | |
1248 | |
1249 | |
1425 class SplitProbAction(argparse.Action): | 1250 class SplitProbAction(argparse.Action): |
1426 def __call__(self, parser, namespace, values, option_string=None): | 1251 def __call__(self, parser, namespace, values, option_string=None): |
1427 train, val, test = values | 1252 train, val, test = values |
1428 total = train + val + test | 1253 total = train + val + test |
1429 if abs(total - 1.0) > 1e-6: | 1254 if abs(total - 1.0) > 1e-6: |
1506 type=float, | 1331 type=float, |
1507 nargs=3, | 1332 nargs=3, |
1508 metavar=("train", "val", "test"), | 1333 metavar=("train", "val", "test"), |
1509 action=SplitProbAction, | 1334 action=SplitProbAction, |
1510 default=[0.7, 0.1, 0.2], | 1335 default=[0.7, 0.1, 0.2], |
1511 help="Random split proportions (e.g., 0.7 0.1 0.2). Only used if no split column.", | 1336 help=( |
1337 "Random split proportions (e.g., 0.7 0.1 0.2)." | |
1338 "Only used if no split column." | |
1339 ), | |
1512 ) | 1340 ) |
1513 parser.add_argument( | 1341 parser.add_argument( |
1514 "--random-seed", | 1342 "--random-seed", |
1515 type=int, | 1343 type=int, |
1516 default=42, | 1344 default=42, |
1519 parser.add_argument( | 1347 parser.add_argument( |
1520 "--learning-rate", | 1348 "--learning-rate", |
1521 type=parse_learning_rate, | 1349 type=parse_learning_rate, |
1522 default=None, | 1350 default=None, |
1523 help="Learning rate. If not provided, Ludwig will auto-select it.", | 1351 help="Learning rate. If not provided, Ludwig will auto-select it.", |
1352 ) | |
1353 parser.add_argument( | |
1354 "--augmentation", | |
1355 type=str, | |
1356 default=None, | |
1357 help=( | |
1358 "Comma-separated list (in order) of any of: " | |
1359 "random_horizontal_flip, random_vertical_flip, random_rotate, " | |
1360 "random_blur, random_brightness, random_contrast. " | |
1361 "E.g. --augmentation random_horizontal_flip,random_rotate" | |
1362 ), | |
1524 ) | 1363 ) |
1525 | 1364 |
1526 args = parser.parse_args() | 1365 args = parser.parse_args() |
1527 | 1366 |
1528 if not 0.0 <= args.validation_size <= 1.0: | 1367 if not 0.0 <= args.validation_size <= 1.0: |
1529 parser.error("validation-size must be between 0.0 and 1.0") | 1368 parser.error("validation-size must be between 0.0 and 1.0") |
1530 if not args.csv_file.is_file(): | 1369 if not args.csv_file.is_file(): |
1531 parser.error(f"CSV not found: {args.csv_file}") | 1370 parser.error(f"CSV not found: {args.csv_file}") |
1532 if not args.image_zip.is_file(): | 1371 if not args.image_zip.is_file(): |
1533 parser.error(f"ZIP not found: {args.image_zip}") | 1372 parser.error(f"ZIP not found: {args.image_zip}") |
1373 if args.augmentation is not None: | |
1374 try: | |
1375 augmentation_setup = aug_parse(args.augmentation) | |
1376 setattr(args, "augmentation", augmentation_setup) | |
1377 except ValueError as e: | |
1378 parser.error(str(e)) | |
1534 | 1379 |
1535 backend_instance = LudwigDirectBackend() | 1380 backend_instance = LudwigDirectBackend() |
1536 orchestrator = WorkflowOrchestrator(args, backend_instance) | 1381 orchestrator = WorkflowOrchestrator(args, backend_instance) |
1537 | 1382 |
1538 exit_code = 0 | 1383 exit_code = 0 |