comparison image_learner_cli.py @ 1:39202fe5cf97 draft

planemo upload for repository https://github.com/goeckslab/gleam.git commit 06c0da44ac93256dfb616a6b40276b5485a71e8e
author goeckslab
date Wed, 02 Jul 2025 18:59:10 +0000
parents 54b871dfc51e
children 186424a7eca7
comparison
equal deleted inserted replaced
0:54b871dfc51e 1:39202fe5cf97
22 from ludwig.visualize import get_visualizations_registry 22 from ludwig.visualize import get_visualizations_registry
23 from sklearn.model_selection import train_test_split 23 from sklearn.model_selection import train_test_split
24 from utils import encode_image_to_base64, get_html_closing, get_html_template 24 from utils import encode_image_to_base64, get_html_closing, get_html_template
25 25
# --- Constants ---
# Column names used in the processed dataframe written for Ludwig.
SPLIT_COLUMN_NAME = "split"
LABEL_COLUMN_NAME = "label"
IMAGE_PATH_COLUMN_NAME = "image_path"
# Default split fractions — presumably train/validation/test — TODO confirm order.
DEFAULT_SPLIT_PROBABILITIES = [0.7, 0.1, 0.2]
# Scratch-file names created inside the temporary working directory.
TEMP_CSV_FILENAME = "processed_data_for_ludwig.csv"
TEMP_CONFIG_FILENAME = "ludwig_config.yaml"
TEMP_DIR_PREFIX = "ludwig_api_work_"
# Maps a user-facing model name to the Ludwig image-encoder spec for it.
# A value is either a bare encoder-type string (e.g. "stacked_cnn") or a dict
# with "type" and, where the torchvision family has variants, "model_variant".
MODEL_ENCODER_TEMPLATES: Dict[str, Any] = {
    "stacked_cnn": "stacked_cnn",
    "resnet18": {"type": "resnet", "model_variant": 18},
    "resnet34": {"type": "resnet", "model_variant": 34},
    "resnet50": {"type": "resnet", "model_variant": 50},
    "resnet101": {"type": "resnet", "model_variant": 101},
    "resnet152": {"type": "resnet", "model_variant": 152},
    "resnext50_32x4d": {"type": "resnext", "model_variant": "50_32x4d"},
    "resnext101_32x8d": {"type": "resnext", "model_variant": "101_32x8d"},
    "resnext101_64x4d": {"type": "resnext", "model_variant": "101_64x4d"},
    "resnext152_32x8d": {"type": "resnext", "model_variant": "152_32x8d"},
    "wide_resnet50_2": {"type": "wide_resnet", "model_variant": "50_2"},
    "wide_resnet101_2": {"type": "wide_resnet", "model_variant": "101_2"},
    "wide_resnet103_2": {"type": "wide_resnet", "model_variant": "103_2"},
    "efficientnet_b0": {"type": "efficientnet", "model_variant": "b0"},
    "efficientnet_b1": {"type": "efficientnet", "model_variant": "b1"},
    "efficientnet_b2": {"type": "efficientnet", "model_variant": "b2"},
    "efficientnet_b3": {"type": "efficientnet", "model_variant": "b3"},
    "efficientnet_b4": {"type": "efficientnet", "model_variant": "b4"},
    "efficientnet_b5": {"type": "efficientnet", "model_variant": "b5"},
    "efficientnet_b6": {"type": "efficientnet", "model_variant": "b6"},
    "efficientnet_b7": {"type": "efficientnet", "model_variant": "b7"},
    "efficientnet_v2_s": {"type": "efficientnet", "model_variant": "v2_s"},
    "efficientnet_v2_m": {"type": "efficientnet", "model_variant": "v2_m"},
    "efficientnet_v2_l": {"type": "efficientnet", "model_variant": "v2_l"},
    "regnet_y_400mf": {"type": "regnet", "model_variant": "y_400mf"},
    "regnet_y_800mf": {"type": "regnet", "model_variant": "y_800mf"},
    "regnet_y_1_6gf": {"type": "regnet", "model_variant": "y_1_6gf"},
    "regnet_y_3_2gf": {"type": "regnet", "model_variant": "y_3_2gf"},
    "regnet_y_8gf": {"type": "regnet", "model_variant": "y_8gf"},
    "regnet_y_16gf": {"type": "regnet", "model_variant": "y_16gf"},
    "regnet_y_32gf": {"type": "regnet", "model_variant": "y_32gf"},
    "regnet_y_128gf": {"type": "regnet", "model_variant": "y_128gf"},
    "regnet_x_400mf": {"type": "regnet", "model_variant": "x_400mf"},
    "regnet_x_800mf": {"type": "regnet", "model_variant": "x_800mf"},
    "regnet_x_1_6gf": {"type": "regnet", "model_variant": "x_1_6gf"},
    "regnet_x_3_2gf": {"type": "regnet", "model_variant": "x_3_2gf"},
    "regnet_x_8gf": {"type": "regnet", "model_variant": "x_8gf"},
    "regnet_x_16gf": {"type": "regnet", "model_variant": "x_16gf"},
    "regnet_x_32gf": {"type": "regnet", "model_variant": "x_32gf"},
    # NOTE: vgg variants mix int (plain) and str ("NN_bn") model_variant values.
    "vgg11": {"type": "vgg", "model_variant": 11},
    "vgg11_bn": {"type": "vgg", "model_variant": "11_bn"},
    "vgg13": {"type": "vgg", "model_variant": 13},
    "vgg13_bn": {"type": "vgg", "model_variant": "13_bn"},
    "vgg16": {"type": "vgg", "model_variant": 16},
    "vgg16_bn": {"type": "vgg", "model_variant": "16_bn"},
    "vgg19": {"type": "vgg", "model_variant": 19},
    "vgg19_bn": {"type": "vgg", "model_variant": "19_bn"},
    "shufflenet_v2_x0_5": {"type": "shufflenet_v2", "model_variant": "x0_5"},
    "shufflenet_v2_x1_0": {"type": "shufflenet_v2", "model_variant": "x1_0"},
    "shufflenet_v2_x1_5": {"type": "shufflenet_v2", "model_variant": "x1_5"},
    "shufflenet_v2_x2_0": {"type": "shufflenet_v2", "model_variant": "x2_0"},
    "squeezenet1_0": {"type": "squeezenet", "model_variant": "1_0"},
    "squeezenet1_1": {"type": "squeezenet", "model_variant": "1_1"},
    "swin_t": {"type": "swin_transformer", "model_variant": "t"},
    "swin_s": {"type": "swin_transformer", "model_variant": "s"},
    "swin_b": {"type": "swin_transformer", "model_variant": "b"},
    "swin_v2_t": {"type": "swin_transformer", "model_variant": "v2_t"},
    "swin_v2_s": {"type": "swin_transformer", "model_variant": "v2_s"},
    "swin_v2_b": {"type": "swin_transformer", "model_variant": "v2_b"},
    "vit_b_16": {"type": "vision_transformer", "model_variant": "b_16"},
    "vit_b_32": {"type": "vision_transformer", "model_variant": "b_32"},
    "vit_l_16": {"type": "vision_transformer", "model_variant": "l_16"},
    "vit_l_32": {"type": "vision_transformer", "model_variant": "l_32"},
    "vit_h_14": {"type": "vision_transformer", "model_variant": "h_14"},
    "convnext_tiny": {"type": "convnext", "model_variant": "tiny"},
    "convnext_small": {"type": "convnext", "model_variant": "small"},
    "convnext_base": {"type": "convnext", "model_variant": "base"},
    "convnext_large": {"type": "convnext", "model_variant": "large"},
    "maxvit_t": {"type": "maxvit", "model_variant": "t"},
    # Architectures without variants carry only a "type".
    "alexnet": {"type": "alexnet"},
    "googlenet": {"type": "googlenet"},
    "inception_v3": {"type": "inception_v3"},
    "mobilenet_v2": {"type": "mobilenet_v2"},
    "mobilenet_v3_large": {"type": "mobilenet_v3_large"},
    "mobilenet_v3_small": {"type": "mobilenet_v3_small"},
}
# Maps raw metric keys (as emitted in Ludwig training/test statistics) to the
# human-readable labels shown in the HTML report tables.
METRIC_DISPLAY_NAMES = {
    "accuracy": "Accuracy",
    "accuracy_micro": "Accuracy-Micro",
    "loss": "Loss",
    "roc_auc": "ROC-AUC",
    "roc_auc_macro": "ROC-AUC-Macro",
    "roc_auc_micro": "ROC-AUC-Micro",
    "hits_at_k": "Hits at K",
    "precision": "Precision",
    "recall": "Recall",
    "specificity": "Specificity",
    "kappa_score": "Cohen's Kappa",
    "token_accuracy": "Token Accuracy",
    "avg_precision_macro": "Precision-Macro",
    "avg_recall_macro": "Recall-Macro",
    "avg_f1_score_macro": "F1-score-Macro",
    "avg_precision_micro": "Precision-Micro",
    "avg_recall_micro": "Recall-Micro",
    "avg_f1_score_micro": "F1-score-Micro",
    "avg_precision_weighted": "Precision-Weighted",
    "avg_recall_weighted": "Recall-Weighted",
    "avg_f1_score_weighted": "F1-score-Weighted",
    # Fix: label previously had a stray leading space (" Precision-Average-Macro"),
    # which rendered inconsistently in the report tables.
    "average_precision_macro": "Precision-Average-Macro",
    "average_precision_micro": "Precision-Average-Micro",
    "average_precision_samples": "Precision-Average-Samples",
}
111 137
# --- Logging Setup ---
# Configure the root logger once at import time; all module output goes
# through the named "ImageLearner" logger below.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger("ImageLearner")
118 144
119 145
def get_metrics_help_modal() -> str:
    """Return the CSS + HTML + JS markup for the metrics help modal.

    The modal explains every evaluation metric shown in the report tables.
    Any element with class ``openMetricsHelp`` opens it; the ``close`` span,
    or a click outside the dialog, hides it again. The three fragments are
    concatenated as css + html + js and embedded directly in the report page.
    """
    # Static dialog body: a guide describing each metric family.
    modal_html = """
<div id="metricsHelpModal" class="modal">
  <div class="modal-content">
    <span class="close">×</span>
    <h2>Model Evaluation Metrics — Help Guide</h2>
    <div class="metrics-guide">
      <h3>1) General Metrics</h3>
      <p><strong>Loss:</strong> Measures the difference between predicted and actual values. Lower is better. Often used for optimization during training.</p>
      <p><strong>Accuracy:</strong> Proportion of correct predictions among all predictions. Simple but can be misleading for imbalanced datasets.</p>
      <p><strong>Micro Accuracy:</strong> Calculates accuracy by summing up all individual true positives and true negatives across all classes, making it suitable for multiclass or multilabel problems.</p>
      <p><strong>Token Accuracy:</strong> Measures how often the predicted tokens (e.g., in sequences) match the true tokens. Useful in sequence prediction tasks like NLP.</p>
      <h3>2) Precision, Recall & Specificity</h3>
      <p><strong>Precision:</strong> Out of all positive predictions, how many were correct. Precision = TP / (TP + FP). Helps when false positives are costly.</p>
      <p><strong>Recall (Sensitivity):</strong> Out of all actual positives, how many were predicted correctly. Recall = TP / (TP + FN). Important when missing positives is risky.</p>
      <p><strong>Specificity:</strong> True negative rate. Measures how well the model identifies negatives. Specificity = TN / (TN + FP). Useful in medical testing to avoid false alarms.</p>
      <h3>3) Macro, Micro, and Weighted Averages</h3>
      <p><strong>Macro Precision / Recall / F1:</strong> Averages the metric across all classes, treating each class equally, regardless of class frequency. Best when class sizes are balanced.</p>
      <p><strong>Micro Precision / Recall / F1:</strong> Aggregates TP, FP, FN across all classes before computing the metric. Gives a global view and is ideal for class-imbalanced problems.</p>
      <p><strong>Weighted Precision / Recall / F1:</strong> Averages each metric across classes, weighted by the number of true instances per class. Balances importance of classes based on frequency.</p>
      <h3>4) Average Precision (PR-AUC Variants)</h3>
      <p><strong>Average Precision Macro:</strong> Precision-Recall AUC averaged across all classes equally. Useful for balanced multi-class problems.</p>
      <p><strong>Average Precision Micro:</strong> Global Precision-Recall AUC using all instances. Best for imbalanced data or multi-label classification.</p>
      <p><strong>Average Precision Samples:</strong> Precision-Recall AUC averaged across individual samples (not classes). Ideal for multi-label problems where each sample can belong to multiple classes.</p>
      <h3>5) ROC-AUC Variants</h3>
      <p><strong>ROC-AUC:</strong> Measures model's ability to distinguish between classes. AUC = 1 is perfect; 0.5 is random guessing. Use for binary classification.</p>
      <p><strong>Macro ROC-AUC:</strong> Averages the AUC across all classes equally. Suitable when classes are balanced and of equal importance.</p>
      <p><strong>Micro ROC-AUC:</strong> Computes AUC from aggregated predictions across all classes. Useful in multiclass or multilabel settings with imbalance.</p>
      <h3>6) Ranking Metrics</h3>
      <p><strong>Hits at K:</strong> Measures whether the true label is among the top-K predictions. Common in recommendation systems and retrieval tasks.</p>
      <h3>7) Confusion Matrix Stats (Per Class)</h3>
      <p><strong>True Positives / Negatives (TP / TN):</strong> Correct predictions for positives and negatives respectively.</p>
      <p><strong>False Positives / Negatives (FP / FN):</strong> Incorrect predictions — false alarms and missed detections.</p>
      <h3>8) Other Useful Metrics</h3>
      <p><strong>Cohen's Kappa:</strong> Measures agreement between predicted and actual values adjusted for chance. Useful for multiclass classification with imbalanced labels.</p>
      <p><strong>Matthews Correlation Coefficient (MCC):</strong> Balanced measure of prediction quality that takes into account TP, TN, FP, and FN. Particularly effective for imbalanced datasets.</p>
      <h3>9) Metric Recommendations</h3>
      <ul>
        <li>Use <strong>Accuracy + F1</strong> for balanced data.</li>
        <li>Use <strong>Precision, Recall, ROC-AUC</strong> for imbalanced datasets.</li>
        <li>Use <strong>Average Precision Micro</strong> for multilabel or class-imbalanced problems.</li>
        <li>Use <strong>Macro scores</strong> when all classes should be treated equally.</li>
        <li>Use <strong>Weighted scores</strong> when class imbalance should be accounted for without ignoring small classes.</li>
        <li>Use <strong>Confusion Matrix stats</strong> to analyze class-wise performance.</li>
        <li>Use <strong>Hits at K</strong> for recommendation or ranking-based tasks.</li>
      </ul>
    </div>
  </div>
</div>
"""
    # Styling for the modal overlay, dialog box, and close control.
    modal_css = """
<style>
.modal {
  display: none;
  position: fixed;
  z-index: 1;
  left: 0;
  top: 0;
  width: 100%;
  height: 100%;
  overflow: auto;
  background-color: rgba(0,0,0,0.4);
}
.modal-content {
  background-color: #fefefe;
  margin: 15% auto;
  padding: 20px;
  border: 1px solid #888;
  width: 80%;
  max-width: 800px;
}
.close {
  color: #aaa;
  float: right;
  font-size: 28px;
  font-weight: bold;
}
.close:hover,
.close:focus {
  color: black;
  text-decoration: none;
  cursor: pointer;
}
.metrics-guide h3 {
  margin-top: 20px;
}
.metrics-guide p {
  margin: 5px 0;
}
.metrics-guide ul {
  margin: 10px 0;
  padding-left: 20px;
}
</style>
"""
    # Open/close wiring: any ".openMetricsHelp" element shows the modal;
    # the close button or a click on the backdrop hides it.
    modal_js = """
<script>
document.addEventListener("DOMContentLoaded", function() {
  var modal = document.getElementById("metricsHelpModal");
  var closeBtn = document.getElementsByClassName("close")[0];

  document.querySelectorAll(".openMetricsHelp").forEach(btn => {
    btn.onclick = function() {
      modal.style.display = "block";
    };
  });

  if (closeBtn) {
    closeBtn.onclick = function() {
      modal.style.display = "none";
    };
  }

  window.onclick = function(event) {
    if (event.target == modal) {
      modal.style.display = "none";
    }
  }
});
</script>
"""
    return modal_css + modal_html + modal_js
268
269
120 def format_config_table_html( 270 def format_config_table_html(
121 config: dict, 271 config: dict,
122 split_info: Optional[str] = None, 272 split_info: Optional[str] = None,
123 training_progress: dict = None) -> str: 273 training_progress: dict = None,
274 ) -> str:
124 display_keys = [ 275 display_keys = [
125 "model_name", 276 "model_name",
126 "epochs", 277 "epochs",
127 "batch_size", 278 "batch_size",
128 "fine_tune", 279 "fine_tune",
141 val = int(val) 292 val = int(val)
142 else: 293 else:
143 if training_progress: 294 if training_progress:
144 val = "Auto-selected batch size by Ludwig:<br>" 295 val = "Auto-selected batch size by Ludwig:<br>"
145 resolved_val = training_progress.get("batch_size") 296 resolved_val = training_progress.get("batch_size")
146 val += ( 297 val += f"<span style='font-size: 0.85em;'>{resolved_val}</span><br>"
147 f"<span style='font-size: 0.85em;'>{resolved_val}</span><br>"
148 )
149 else: 298 else:
150 val = "auto" 299 val = "auto"
151 if key == "learning_rate": 300 if key == "learning_rate":
152 resolved_val = None 301 resolved_val = None
153 if val is None or val == "auto": 302 if val is None or val == "auto":
154 if training_progress: 303 if training_progress:
155 resolved_val = training_progress.get("learning_rate") 304 resolved_val = training_progress.get("learning_rate")
156 val = ( 305 val = (
157 "Auto-selected learning rate by Ludwig:<br>" 306 "Auto-selected learning rate by Ludwig:<br>"
158 f"<span style='font-size: 0.85em;'>{resolved_val if resolved_val else val}</span><br>" 307 f"<span style='font-size: 0.85em;'>"
308 f"{resolved_val if resolved_val else val}</span><br>"
159 "<span style='font-size: 0.85em;'>" 309 "<span style='font-size: 0.85em;'>"
160 "Based on model architecture and training setup (e.g., fine-tuning).<br>" 310 "Based on model architecture and training setup "
161 "See <a href='https://ludwig.ai/latest/configuration/trainer/#trainer-parameters' " 311 "(e.g., fine-tuning).<br>"
162 "target='_blank'>Ludwig Trainer Parameters</a> for details." 312 "See <a href='https://ludwig.ai/latest/configuration/trainer/"
313 "#trainer-parameters' target='_blank'>"
314 "Ludwig Trainer Parameters</a> for details."
163 "</span>" 315 "</span>"
164 ) 316 )
165 else: 317 else:
166 val = ( 318 val = (
167 "Auto-selected by Ludwig<br>" 319 "Auto-selected by Ludwig<br>"
168 "<span style='font-size: 0.85em;'>" 320 "<span style='font-size: 0.85em;'>"
169 "Automatically tuned based on architecture and dataset.<br>" 321 "Automatically tuned based on architecture and dataset.<br>"
170 "See <a href='https://ludwig.ai/latest/configuration/trainer/#trainer-parameters' " 322 "See <a href='https://ludwig.ai/latest/configuration/trainer/"
171 "target='_blank'>Ludwig Trainer Parameters</a> for details." 323 "#trainer-parameters' target='_blank'>"
324 "Ludwig Trainer Parameters</a> for details."
172 "</span>" 325 "</span>"
173 ) 326 )
174 else: 327 else:
175 val = f"{val:.6f}" 328 val = f"{val:.6f}"
176 if key == "epochs": 329 if key == "epochs":
177 if training_progress and "epoch" in training_progress and val > training_progress["epoch"]: 330 if (
331 training_progress
332 and "epoch" in training_progress
333 and val > training_progress["epoch"]
334 ):
178 val = ( 335 val = (
179 f"Because of early stopping: the training" 336 f"Because of early stopping: the training "
180 f"stopped at epoch {training_progress['epoch']}" 337 f"stopped at epoch {training_progress['epoch']}"
181 ) 338 )
182 339
183 if val is None: 340 if val is None:
184 continue 341 continue
185 rows.append( 342 rows.append(
186 f"<tr>" 343 f"<tr>"
187 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" 344 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>"
188 f"{key.replace('_', ' ').title()}</td>" 345 f"{key.replace('_', ' ').title()}</td>"
189 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{val}</td>" 346 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>"
347 f"{val}</td>"
190 f"</tr>" 348 f"</tr>"
191 ) 349 )
192 350
193 if split_info: 351 if split_info:
194 rows.append( 352 rows.append(
195 f"<tr>" 353 f"<tr>"
196 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Data Split</td>" 354 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>"
197 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{split_info}</td>" 355 f"Data Split</td>"
356 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>"
357 f"{split_info}</td>"
198 f"</tr>" 358 f"</tr>"
199 ) 359 )
200 360
201 return ( 361 return (
202 "<h2 style='text-align: center;'>Training Setup</h2>" 362 "<h2 style='text-align: center;'>Training Setup</h2>"
203 "<div style='display: flex; justify-content: center;'>" 363 "<div style='display: flex; justify-content: center;'>"
204 "<table style='border-collapse: collapse; width: 60%; table-layout: auto;'>" 364 "<table style='border-collapse: collapse; width: 60%; table-layout: auto;'>"
205 "<thead><tr>" 365 "<thead><tr>"
206 "<th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>Parameter</th>" 366 "<th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>"
207 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Value</th>" 367 "Parameter</th>"
368 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>"
369 "Value</th>"
208 "</tr></thead><tbody>" + "".join(rows) + "</tbody></table></div><br>" 370 "</tr></thead><tbody>" + "".join(rows) + "</tbody></table></div><br>"
209 "<p style='text-align: center; font-size: 0.9em;'>" 371 "<p style='text-align: center; font-size: 0.9em;'>"
210 "Model trained using Ludwig.<br>" 372 "Model trained using Ludwig.<br>"
211 "If want to learn more about Ludwig default settings," 373 "If want to learn more about Ludwig default settings,"
212 "please check the their <a href='https://ludwig.ai' target='_blank'>website(ludwig.ai)</a>." 374 "please check the their <a href='https://ludwig.ai' target='_blank'>"
375 "website(ludwig.ai)</a>."
213 "</p><hr>" 376 "</p><hr>"
214 ) 377 )
215 378
216 379
def detect_output_type(test_stats):
    """Detects if the output type is 'binary' or 'category' based on test statistics."""
    # Exactly two per-class entries means a binary label; anything else is
    # treated as a multi-class ("category") output.
    per_class = test_stats.get("label", {}).get("per_class_stats", {})
    return "binary" if len(per_class) == 2 else "category"
387
388
def extract_metrics_from_json(
    train_stats: dict,
    test_stats: dict,
    output_type: str,
) -> dict:
    """Extracts relevant metrics from training and test statistics based on the output type."""
    metrics = {"training": {}, "validation": {}, "test": {}}

    def last_scalar(stats, key):
        # Per-epoch metrics arrive as lists; the final element is the value at
        # the end of training. Plain numbers pass through unchanged.
        value = stats.get(key)
        if isinstance(value, list):
            return value[-1] if value else None
        if isinstance(value, (int, float)):
            return value
        return None

    # Which label metrics to pull for train/validation, by output type.
    if output_type == "binary":
        wanted_keys = (
            "accuracy",
            "loss",
            "precision",
            "recall",
            "specificity",
            "roc_auc",
        )
    else:
        wanted_keys = ("accuracy", "accuracy_micro", "loss", "roc_auc", "hits_at_k")

    for split in ("training", "validation"):
        split_stats = train_stats.get(split, {})
        if not split_stats:
            logging.warning(f"No statistics found for {split} split")
            continue
        label_stats = split_stats.get("label", {})
        if not label_stats:
            logging.warning(f"No label statistics found for {split} split")
            continue
        metrics[split] = {key: last_scalar(label_stats, key) for key in wanted_keys}

    # Test metrics: dynamic extraction according to exclusions.
    test_label_stats = test_stats.get("label", {})
    if not test_label_stats:
        logging.warning("No label statistics found for test split")
    else:
        combined_stats = test_stats.get("combined", {})
        overall_stats = test_label_stats.get("overall_stats", {})

        # Curve/per-class payloads are excluded — they are not scalar metrics.
        if output_type == "binary":
            skip = {"per_class_stats", "precision_recall_curve", "roc_curve"}
        else:
            skip = {"per_class_stats", "confusion_matrix"}

        # Scalars straight off the label stats, minus the excluded payloads.
        collected = {
            key: value
            for key, value in test_label_stats.items()
            if key != "overall_stats"
            and key not in skip
            and isinstance(value, (int, float, str, bool))
        }
        # Flatten overall_stats on top (overwrites duplicates, as before).
        collected.update(overall_stats)
        # Fall back to the combined loss when the label stats lack one.
        if "loss" in combined_stats and "loss" not in collected:
            collected["loss"] = combined_stats["loss"]

        metrics["test"] = collected

    return metrics
467
468
def generate_table_row(cells, styles):
    """Helper function to generate an HTML table row."""
    # One <td> per cell, all sharing the same inline style.
    tds = [f"<td style='{styles}'>{cell}</td>" for cell in cells]
    return "<tr>" + "".join(tds) + "</tr>"
476
477
def format_stats_table_html(train_stats: dict, test_stats: dict) -> str:
    """Formats a combined HTML table for training, validation, and test metrics.

    Only metrics present in all three splits (training, validation, test) with
    non-None values appear as rows; display labels come from
    METRIC_DISPLAY_NAMES, falling back to a title-cased form of the raw key.
    """
    output_type = detect_output_type(test_stats)
    all_metrics = extract_metrics_from_json(train_stats, test_stats, output_type)
    rows = []
    for metric_key in sorted(all_metrics["training"].keys()):
        if (
            metric_key in all_metrics["validation"]
            and metric_key in all_metrics["test"]
        ):
            display_name = METRIC_DISPLAY_NAMES.get(
                metric_key,
                metric_key.replace("_", " ").title(),
            )
            t = all_metrics["training"].get(metric_key)
            v = all_metrics["validation"].get(metric_key)
            te = all_metrics["test"].get(metric_key)
            # Skip any metric missing a value in one of the splits.
            if all(x is not None for x in [t, v, te]):
                rows.append([display_name, f"{t:.4f}", f"{v:.4f}", f"{te:.4f}"])

    if not rows:
        return "<table><tr><td>No metric values found.</td></tr></table>"

    # Header: Metric | Train | Validation | Test, centered on the page.
    html = (
        "<h2 style='text-align: center;'>Model Performance Summary</h2>"
        "<div style='display: flex; justify-content: center;'>"
        "<table style='border-collapse: collapse; table-layout: auto;'>"
        "<thead><tr>"
        "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; "
        "white-space: nowrap;'>Metric</th>"
        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
        "white-space: nowrap;'>Train</th>"
        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
        "white-space: nowrap;'>Validation</th>"
        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
        "white-space: nowrap;'>Test</th>"
        "</tr></thead><tbody>"
    )
    for row in rows:
        html += generate_table_row(
            row,
            "padding: 10px; border: 1px solid #ccc; text-align: center; "
            "white-space: nowrap;",
        )
    html += "</tbody></table></div><br>"
    return html
524
525
526 def format_train_val_stats_table_html(train_stats: dict, test_stats: dict) -> str:
527 """Formats an HTML table for training and validation metrics."""
528 output_type = detect_output_type(test_stats)
529 all_metrics = extract_metrics_from_json(train_stats, test_stats, output_type)
530 rows = []
531 for metric_key in sorted(all_metrics["training"].keys()):
532 if metric_key in all_metrics["validation"]:
533 display_name = METRIC_DISPLAY_NAMES.get(
534 metric_key,
535 metric_key.replace("_", " ").title(),
536 )
537 t = all_metrics["training"].get(metric_key)
538 v = all_metrics["validation"].get(metric_key)
539 if t is not None and v is not None:
540 rows.append([display_name, f"{t:.4f}", f"{v:.4f}"])
541
542 if not rows:
543 return "<table><tr><td>No metric values found for Train/Validation.</td></tr></table>"
544
545 html = (
546 "<h2 style='text-align: center;'>Train/Validation Performance Summary</h2>"
547 "<div style='display: flex; justify-content: center;'>"
548 "<table style='border-collapse: collapse; table-layout: auto;'>"
549 "<thead><tr>"
550 "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; "
551 "white-space: nowrap;'>Metric</th>"
552 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
553 "white-space: nowrap;'>Train</th>"
554 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
555 "white-space: nowrap;'>Validation</th>"
556 "</tr></thead><tbody>"
557 )
558 for row in rows:
559 html += generate_table_row(
560 row,
561 "padding: 10px; border: 1px solid #ccc; text-align: center; "
562 "white-space: nowrap;",
563 )
564 html += "</tbody></table></div><br>"
565 return html
566
567
568 def format_test_merged_stats_table_html(test_metrics: Dict[str, Optional[float]]) -> str:
569 """Formats an HTML table for test metrics."""
570 rows = []
571 for key in sorted(test_metrics.keys()):
572 display_name = METRIC_DISPLAY_NAMES.get(key, key.replace("_", " ").title())
573 value = test_metrics[key]
574 if value is not None:
575 rows.append([display_name, f"{value:.4f}"])
576
577 if not rows:
578 return "<table><tr><td>No test metric values found.</td></tr></table>"
579
580 html = (
581 "<h2 style='text-align: center;'>Test Performance Summary</h2>"
582 "<div style='display: flex; justify-content: center;'>"
583 "<table style='border-collapse: collapse; table-layout: auto;'>"
584 "<thead><tr>"
585 "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; "
586 "white-space: nowrap;'>Metric</th>"
587 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
588 "white-space: nowrap;'>Test</th>"
589 "</tr></thead><tbody>"
590 )
591 for row in rows:
592 html += generate_table_row(
593 row,
594 "padding: 10px; border: 1px solid #ccc; text-align: center; "
595 "white-space: nowrap;",
596 )
597 html += "</tbody></table></div><br>"
598 return html
599
600
601 def build_tabbed_html(metrics_html: str, train_val_html: str, test_html: str) -> str:
274 return f""" 602 return f"""
275 <style> 603 <style>
276 .tabs {{ 604 .tabs {{
277 display: flex; 605 display: flex;
278 border-bottom: 2px solid #ccc; 606 border-bottom: 2px solid #ccc;
300 }} 628 }}
301 .tab-content.active {{ 629 .tab-content.active {{
302 display: block; 630 display: block;
303 }} 631 }}
304 </style> 632 </style>
305
306 <div class="tabs"> 633 <div class="tabs">
307 <div class="tab active" onclick="showTab('metrics')"> Config & Metrics</div> 634 <div class="tab active" onclick="showTab('metrics')"> Config & Results Summary</div>
308 <div class="tab" onclick="showTab('trainval')"> Train/Validation Plots</div> 635 <div class="tab" onclick="showTab('trainval')"> Train/Validation Results</div>
309 <div class="tab" onclick="showTab('test')"> Test Plots</div> 636 <div class="tab" onclick="showTab('test')"> Test Results</div>
310 </div> 637 </div>
311
312 <div id="metrics" class="tab-content active"> 638 <div id="metrics" class="tab-content active">
313 {metrics_html} 639 {metrics_html}
314 </div> 640 </div>
315 <div id="trainval" class="tab-content"> 641 <div id="trainval" class="tab-content">
316 {train_viz_html} 642 {train_val_html}
317 </div> 643 </div>
318 <div id="test" class="tab-content"> 644 <div id="test" class="tab-content">
319 {test_viz_html} 645 {test_html}
320 </div> 646 </div>
321
322 <script> 647 <script>
323 function showTab(id) {{ 648 function showTab(id) {{
324 document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active')); 649 document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active'));
325 document.querySelectorAll('.tab').forEach(el => el.classList.remove('active')); 650 document.querySelectorAll('.tab').forEach(el => el.classList.remove('active'));
326 document.getElementById(id).classList.add('active'); 651 document.getElementById(id).classList.add('active');
335 split_column: str, 660 split_column: str,
336 validation_size: float = 0.15, 661 validation_size: float = 0.15,
337 random_state: int = 42, 662 random_state: int = 42,
338 label_column: Optional[str] = None, 663 label_column: Optional[str] = None,
339 ) -> pd.DataFrame: 664 ) -> pd.DataFrame:
340 """ 665 """Given a DataFrame whose split_column only contains {0,2}, re-assign a portion of the 0s to become 1s (validation)."""
341 Given a DataFrame whose split_column only contains {0,2}, re-assign
342 a portion of the 0s to become 1s (validation). Returns a fresh DataFrame.
343 """
344 # Work on a copy
345 out = df.copy() 666 out = df.copy()
346 # Ensure split col is integer dtype
347 out[split_column] = pd.to_numeric(out[split_column], errors="coerce").astype(int) 667 out[split_column] = pd.to_numeric(out[split_column], errors="coerce").astype(int)
348 668
349 idx_train = out.index[out[split_column] == 0].tolist() 669 idx_train = out.index[out[split_column] == 0].tolist()
350 670
351 if not idx_train: 671 if not idx_train:
352 logger.info("No rows with split=0; nothing to do.") 672 logger.info("No rows with split=0; nothing to do.")
353 return out 673 return out
354
355 # Determine stratify array if possible
356 stratify_arr = None 674 stratify_arr = None
357 if label_column and label_column in out.columns: 675 if label_column and label_column in out.columns:
358 # Only stratify if at least two classes and enough samples
359 label_counts = out.loc[idx_train, label_column].value_counts() 676 label_counts = out.loc[idx_train, label_column].value_counts()
360 if label_counts.size > 1 and (label_counts.min() * validation_size) >= 1: 677 if label_counts.size > 1 and (label_counts.min() * validation_size) >= 1:
361 stratify_arr = out.loc[idx_train, label_column] 678 stratify_arr = out.loc[idx_train, label_column]
362 else: 679 else:
363 logger.warning("Cannot stratify (too few labels); splitting without stratify.") 680 logger.warning(
364 681 "Cannot stratify (too few labels); splitting without stratify."
365 # Edge cases 682 )
366 if validation_size <= 0: 683 if validation_size <= 0:
367 logger.info("validation_size <= 0; keeping all as train.") 684 logger.info("validation_size <= 0; keeping all as train.")
368 return out 685 return out
369 if validation_size >= 1: 686 if validation_size >= 1:
370 logger.info("validation_size >= 1; moving all train → validation.") 687 logger.info("validation_size >= 1; moving all train → validation.")
371 out.loc[idx_train, split_column] = 1 688 out.loc[idx_train, split_column] = 1
372 return out 689 return out
373
374 # Do the split
375 try: 690 try:
376 train_idx, val_idx = train_test_split( 691 train_idx, val_idx = train_test_split(
377 idx_train, 692 idx_train,
378 test_size=validation_size, 693 test_size=validation_size,
379 random_state=random_state, 694 random_state=random_state,
380 stratify=stratify_arr 695 stratify=stratify_arr,
381 ) 696 )
382 except ValueError as e: 697 except ValueError as e:
383 logger.warning(f"Stratified split failed ({e}); retrying without stratify.") 698 logger.warning(f"Stratified split failed ({e}); retrying without stratify.")
384 train_idx, val_idx = train_test_split( 699 train_idx, val_idx = train_test_split(
385 idx_train, 700 idx_train,
386 test_size=validation_size, 701 test_size=validation_size,
387 random_state=random_state, 702 random_state=random_state,
388 stratify=None 703 stratify=None,
389 ) 704 )
390
391 # Assign new splits
392 out.loc[train_idx, split_column] = 0 705 out.loc[train_idx, split_column] = 0
393 out.loc[val_idx, split_column] = 1 706 out.loc[val_idx, split_column] = 1
394 # idx_test stays at 2
395
396 # Cast back to a clean integer type
397 out[split_column] = out[split_column].astype(int) 707 out[split_column] = out[split_column].astype(int)
398 # print(out)
399 return out 708 return out
400 709
401 710
402 class Backend(Protocol): 711 class Backend(Protocol):
403 """Interface for a machine learning backend.""" 712 """Interface for a machine learning backend."""
713
404 def prepare_config( 714 def prepare_config(
405 self, 715 self,
406 config_params: Dict[str, Any], 716 config_params: Dict[str, Any],
407 split_config: Dict[str, Any] 717 split_config: Dict[str, Any],
408 ) -> str: 718 ) -> str:
409 ... 719 ...
410 720
411 def run_experiment( 721 def run_experiment(
412 self, 722 self,
430 ) -> Path: 740 ) -> Path:
431 ... 741 ...
432 742
433 743
434 class LudwigDirectBackend: 744 class LudwigDirectBackend:
435 """ 745 """Backend for running Ludwig experiments directly via the internal experiment_cli function."""
436 Backend for running Ludwig experiments directly via the internal experiment_cli function.
437 """
438 746
439 def prepare_config( 747 def prepare_config(
440 self, 748 self,
441 config_params: Dict[str, Any], 749 config_params: Dict[str, Any],
442 split_config: Dict[str, Any], 750 split_config: Dict[str, Any],
443 ) -> str: 751 ) -> str:
444 """ 752 """Build and serialize the Ludwig YAML configuration."""
445 Build and serialize the Ludwig YAML configuration.
446 """
447 logger.info("LudwigDirectBackend: Preparing YAML configuration.") 753 logger.info("LudwigDirectBackend: Preparing YAML configuration.")
448 754
449 model_name = config_params.get("model_name", "resnet18") 755 model_name = config_params.get("model_name", "resnet18")
450 use_pretrained = config_params.get("use_pretrained", False) 756 use_pretrained = config_params.get("use_pretrained", False)
451 fine_tune = config_params.get("fine_tune", False) 757 fine_tune = config_params.get("fine_tune", False)
458 trainable = fine_tune or (not use_pretrained) 764 trainable = fine_tune or (not use_pretrained)
459 if not use_pretrained and not trainable: 765 if not use_pretrained and not trainable:
460 logger.warning("trainable=False; use_pretrained=False is ignored.") 766 logger.warning("trainable=False; use_pretrained=False is ignored.")
461 logger.warning("Setting trainable=True to train the model from scratch.") 767 logger.warning("Setting trainable=True to train the model from scratch.")
462 trainable = True 768 trainable = True
463
464 # Encoder setup
465 raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name) 769 raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name)
466 if isinstance(raw_encoder, dict): 770 if isinstance(raw_encoder, dict):
467 encoder_config = { 771 encoder_config = {
468 **raw_encoder, 772 **raw_encoder,
469 "use_pretrained": use_pretrained, 773 "use_pretrained": use_pretrained,
470 "trainable": trainable, 774 "trainable": trainable,
471 } 775 }
472 else: 776 else:
473 encoder_config = {"type": raw_encoder} 777 encoder_config = {"type": raw_encoder}
474 778
475 # Trainer & optimizer
476 # optimizer = {"type": "adam", "learning_rate": 5e-5} if fine_tune else {"type": "adam"}
477 batch_size_cfg = batch_size or "auto" 779 batch_size_cfg = batch_size or "auto"
780
781 label_column_path = config_params.get("label_column_data_path")
782 if label_column_path is not None and Path(label_column_path).exists():
783 try:
784 label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME]
785 num_unique_labels = label_series.nunique()
786 except Exception as e:
787 logger.warning(
788 f"Could not determine label cardinality, defaulting to 'binary': {e}"
789 )
790 num_unique_labels = 2
791 else:
792 logger.warning(
793 "label_column_data_path not provided, defaulting to 'binary'"
794 )
795 num_unique_labels = 2
796
797 output_type = "binary" if num_unique_labels == 2 else "category"
478 798
479 conf: Dict[str, Any] = { 799 conf: Dict[str, Any] = {
480 "model_type": "ecd", 800 "model_type": "ecd",
481 "input_features": [ 801 "input_features": [
482 { 802 {
483 "name": IMAGE_PATH_COLUMN_NAME, 803 "name": IMAGE_PATH_COLUMN_NAME,
484 "type": "image", 804 "type": "image",
485 "encoder": encoder_config, 805 "encoder": encoder_config,
486 } 806 }
487 ], 807 ],
488 "output_features": [ 808 "output_features": [{"name": LABEL_COLUMN_NAME, "type": output_type}],
489 {"name": LABEL_COLUMN_NAME, "type": "category"}
490 ],
491 "combiner": {"type": "concat"}, 809 "combiner": {"type": "concat"},
492 "trainer": { 810 "trainer": {
493 "epochs": epochs, 811 "epochs": epochs,
494 "early_stop": early_stop, 812 "early_stop": early_stop,
495 "batch_size": batch_size_cfg, 813 "batch_size": batch_size_cfg,
506 try: 824 try:
507 yaml_str = yaml.dump(conf, sort_keys=False, indent=2) 825 yaml_str = yaml.dump(conf, sort_keys=False, indent=2)
508 logger.info("LudwigDirectBackend: YAML config generated.") 826 logger.info("LudwigDirectBackend: YAML config generated.")
509 return yaml_str 827 return yaml_str
510 except Exception: 828 except Exception:
511 logger.error("LudwigDirectBackend: Failed to serialize YAML.", exc_info=True) 829 logger.error(
830 "LudwigDirectBackend: Failed to serialize YAML.",
831 exc_info=True,
832 )
512 raise 833 raise
513 834
514 def run_experiment( 835 def run_experiment(
515 self, 836 self,
516 dataset_path: Path, 837 dataset_path: Path,
517 config_path: Path, 838 config_path: Path,
518 output_dir: Path, 839 output_dir: Path,
519 random_seed: int = 42, 840 random_seed: int = 42,
520 ) -> None: 841 ) -> None:
521 """ 842 """Invoke Ludwig's internal experiment_cli function to run the experiment."""
522 Invoke Ludwig's internal experiment_cli function to run the experiment.
523 """
524 logger.info("LudwigDirectBackend: Starting experiment execution.") 843 logger.info("LudwigDirectBackend: Starting experiment execution.")
525 844
526 try: 845 try:
527 from ludwig.experiment import experiment_cli 846 from ludwig.experiment import experiment_cli
528 except ImportError as e: 847 except ImportError as e:
529 logger.error( 848 logger.error(
530 "LudwigDirectBackend: Could not import experiment_cli.", 849 "LudwigDirectBackend: Could not import experiment_cli.",
531 exc_info=True 850 exc_info=True,
532 ) 851 )
533 raise RuntimeError("Ludwig import failed.") from e 852 raise RuntimeError("Ludwig import failed.") from e
534 853
535 output_dir.mkdir(parents=True, exist_ok=True) 854 output_dir.mkdir(parents=True, exist_ok=True)
536 855
539 dataset=str(dataset_path), 858 dataset=str(dataset_path),
540 config=str(config_path), 859 config=str(config_path),
541 output_directory=str(output_dir), 860 output_directory=str(output_dir),
542 random_seed=random_seed, 861 random_seed=random_seed,
543 ) 862 )
544 logger.info(f"LudwigDirectBackend: Experiment completed. Results in {output_dir}") 863 logger.info(
864 f"LudwigDirectBackend: Experiment completed. Results in {output_dir}"
865 )
545 except TypeError as e: 866 except TypeError as e:
546 logger.error( 867 logger.error(
547 "LudwigDirectBackend: Argument mismatch in experiment_cli call.", 868 "LudwigDirectBackend: Argument mismatch in experiment_cli call.",
548 exc_info=True 869 exc_info=True,
549 ) 870 )
550 raise RuntimeError("Ludwig argument error.") from e 871 raise RuntimeError("Ludwig argument error.") from e
551 except Exception: 872 except Exception:
552 logger.error( 873 logger.error(
553 "LudwigDirectBackend: Experiment execution error.", 874 "LudwigDirectBackend: Experiment execution error.",
554 exc_info=True 875 exc_info=True,
555 ) 876 )
556 raise 877 raise
557 878
558 def get_training_process(self, output_dir) -> float: 879 def get_training_process(self, output_dir) -> float:
559 """ 880 """Retrieve the learning rate used in the most recent Ludwig run."""
560 Retrieve the learning rate used in the most recent Ludwig run.
561 Returns:
562 float: learning rate (or None if not found)
563 """
564 output_dir = Path(output_dir) 881 output_dir = Path(output_dir)
565 exp_dirs = sorted( 882 exp_dirs = sorted(
566 output_dir.glob("experiment_run*"), 883 output_dir.glob("experiment_run*"),
567 key=lambda p: p.stat().st_mtime 884 key=lambda p: p.stat().st_mtime,
568 ) 885 )
569 886
570 if not exp_dirs: 887 if not exp_dirs:
571 logger.warning(f"No experiment run directories found in {output_dir}") 888 logger.warning(f"No experiment run directories found in {output_dir}")
572 return None 889 return None
583 "learning_rate": data.get("learning_rate"), 900 "learning_rate": data.get("learning_rate"),
584 "batch_size": data.get("batch_size"), 901 "batch_size": data.get("batch_size"),
585 "epoch": data.get("epoch"), 902 "epoch": data.get("epoch"),
586 } 903 }
587 except Exception as e: 904 except Exception as e:
588 self.logger.warning(f"Failed to read training progress info: {e}") 905 logger.warning(f"Failed to read training progress info: {e}")
589 return {} 906 return {}
590 907
591 def convert_parquet_to_csv(self, output_dir: Path): 908 def convert_parquet_to_csv(self, output_dir: Path):
592 """Convert the predictions Parquet file to CSV.""" 909 """Convert the predictions Parquet file to CSV."""
593 output_dir = Path(output_dir) 910 output_dir = Path(output_dir)
594 exp_dirs = sorted( 911 exp_dirs = sorted(
595 output_dir.glob("experiment_run*"), 912 output_dir.glob("experiment_run*"),
596 key=lambda p: p.stat().st_mtime 913 key=lambda p: p.stat().st_mtime,
597 ) 914 )
598 if not exp_dirs: 915 if not exp_dirs:
599 logger.warning(f"No experiment run dirs found in {output_dir}") 916 logger.warning(f"No experiment run dirs found in {output_dir}")
600 return 917 return
601 exp_dir = exp_dirs[-1] 918 exp_dir = exp_dirs[-1]
607 logger.info(f"Converted Parquet to CSV: {csv_path}") 924 logger.info(f"Converted Parquet to CSV: {csv_path}")
608 except Exception as e: 925 except Exception as e:
609 logger.error(f"Error converting Parquet to CSV: {e}") 926 logger.error(f"Error converting Parquet to CSV: {e}")
610 927
611 def generate_plots(self, output_dir: Path) -> None: 928 def generate_plots(self, output_dir: Path) -> None:
612 """ 929 """Generate all registered Ludwig visualizations for the latest experiment run."""
613 Generate _all_ registered Ludwig visualizations for the latest experiment run.
614 """
615 logger.info("Generating all Ludwig visualizations…") 930 logger.info("Generating all Ludwig visualizations…")
616 931
617 test_plots = { 932 test_plots = {
618 'compare_performance', 933 "compare_performance",
619 'compare_classifiers_performance_from_prob', 934 "compare_classifiers_performance_from_prob",
620 'compare_classifiers_performance_from_pred', 935 "compare_classifiers_performance_from_pred",
621 'compare_classifiers_performance_changing_k', 936 "compare_classifiers_performance_changing_k",
622 'compare_classifiers_multiclass_multimetric', 937 "compare_classifiers_multiclass_multimetric",
623 'compare_classifiers_predictions', 938 "compare_classifiers_predictions",
624 'confidence_thresholding_2thresholds_2d', 939 "confidence_thresholding_2thresholds_2d",
625 'confidence_thresholding_2thresholds_3d', 940 "confidence_thresholding_2thresholds_3d",
626 'confidence_thresholding', 941 "confidence_thresholding",
627 'confidence_thresholding_data_vs_acc', 942 "confidence_thresholding_data_vs_acc",
628 'binary_threshold_vs_metric', 943 "binary_threshold_vs_metric",
629 'roc_curves', 944 "roc_curves",
630 'roc_curves_from_test_statistics', 945 "roc_curves_from_test_statistics",
631 'calibration_1_vs_all', 946 "calibration_1_vs_all",
632 'calibration_multiclass', 947 "calibration_multiclass",
633 'confusion_matrix', 948 "confusion_matrix",
634 'frequency_vs_f1', 949 "frequency_vs_f1",
635 } 950 }
636 train_plots = { 951 train_plots = {
637 'learning_curves', 952 "learning_curves",
638 'compare_classifiers_performance_subset', 953 "compare_classifiers_performance_subset",
639 } 954 }
640 955
641 # 1) find the most recent experiment directory
642 output_dir = Path(output_dir) 956 output_dir = Path(output_dir)
643 exp_dirs = sorted( 957 exp_dirs = sorted(
644 output_dir.glob("experiment_run*"), 958 output_dir.glob("experiment_run*"),
645 key=lambda p: p.stat().st_mtime 959 key=lambda p: p.stat().st_mtime,
646 ) 960 )
647 if not exp_dirs: 961 if not exp_dirs:
648 logger.warning(f"No experiment run dirs found in {output_dir}") 962 logger.warning(f"No experiment run dirs found in {output_dir}")
649 return 963 return
650 exp_dir = exp_dirs[-1] 964 exp_dir = exp_dirs[-1]
651 965
652 # 2) ensure viz output subfolder exists
653 viz_dir = exp_dir / "visualizations" 966 viz_dir = exp_dir / "visualizations"
654 viz_dir.mkdir(exist_ok=True) 967 viz_dir.mkdir(exist_ok=True)
655 train_viz = viz_dir / "train" 968 train_viz = viz_dir / "train"
656 test_viz = viz_dir / "test" 969 test_viz = viz_dir / "test"
657 train_viz.mkdir(parents=True, exist_ok=True) 970 train_viz.mkdir(parents=True, exist_ok=True)
658 test_viz.mkdir(parents=True, exist_ok=True) 971 test_viz.mkdir(parents=True, exist_ok=True)
659 972
660 # 3) helper to check file existence
661 def _check(p: Path) -> Optional[str]: 973 def _check(p: Path) -> Optional[str]:
662 return str(p) if p.exists() else None 974 return str(p) if p.exists() else None
663 975
664 # 4) gather standard Ludwig output files
665 training_stats = _check(exp_dir / "training_statistics.json") 976 training_stats = _check(exp_dir / "training_statistics.json")
666 test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME) 977 test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME)
667 probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME) 978 probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME)
668 gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME) 979 gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME)
669 980
670 # 5) try to read original dataset & split file from description.json
671 dataset_path = None 981 dataset_path = None
672 split_file = None 982 split_file = None
673 desc = exp_dir / DESCRIPTION_FILE_NAME 983 desc = exp_dir / DESCRIPTION_FILE_NAME
674 if desc.exists(): 984 if desc.exists():
675 with open(desc, "r") as f: 985 with open(desc, "r") as f:
676 cfg = json.load(f) 986 cfg = json.load(f)
677 dataset_path = _check(Path(cfg.get("dataset", ""))) 987 dataset_path = _check(Path(cfg.get("dataset", "")))
678 split_file = _check(Path(get_split_path(cfg.get("dataset", "")))) 988 split_file = _check(Path(get_split_path(cfg.get("dataset", ""))))
679 989
680 # 6) infer output feature name
681 output_feature = "" 990 output_feature = ""
682 if desc.exists(): 991 if desc.exists():
683 try: 992 try:
684 output_feature = cfg["config"]["output_features"][0]["name"] 993 output_feature = cfg["config"]["output_features"][0]["name"]
685 except Exception: 994 except Exception:
687 if not output_feature and test_stats: 996 if not output_feature and test_stats:
688 with open(test_stats, "r") as f: 997 with open(test_stats, "r") as f:
689 stats = json.load(f) 998 stats = json.load(f)
690 output_feature = next(iter(stats.keys()), "") 999 output_feature = next(iter(stats.keys()), "")
691 1000
692 # 7) loop through every registered viz
693 viz_registry = get_visualizations_registry() 1001 viz_registry = get_visualizations_registry()
694 for viz_name, viz_func in viz_registry.items(): 1002 for viz_name, viz_func in viz_registry.items():
695 viz_dir_plot = None 1003 viz_dir_plot = None
696 if viz_name in train_plots: 1004 if viz_name in train_plots:
697 viz_dir_plot = train_viz 1005 viz_dir_plot = train_viz
719 logger.warning(f"✘ Skipped {viz_name}: {e}") 1027 logger.warning(f"✘ Skipped {viz_name}: {e}")
720 1028
721 logger.info(f"All visualizations written to {viz_dir}") 1029 logger.info(f"All visualizations written to {viz_dir}")
722 1030
723 def generate_html_report( 1031 def generate_html_report(
724 self, 1032 self,
725 title: str, 1033 title: str,
726 output_dir: str, 1034 output_dir: str,
727 config: dict, 1035 config: dict,
728 split_info: str) -> Path: 1036 split_info: str,
729 """ 1037 ) -> Path:
730 Assemble an HTML report from visualizations under train_val/ and test/ folders. 1038 """Assemble an HTML report from visualizations under train_val/ and test/ folders."""
731 """
732 cwd = Path.cwd() 1039 cwd = Path.cwd()
733 report_name = title.lower().replace(" ", "_") + "_report.html" 1040 report_name = title.lower().replace(" ", "_") + "_report.html"
734 report_path = cwd / report_name 1041 report_path = cwd / report_name
735 output_dir = Path(output_dir) 1042 output_dir = Path(output_dir)
736 1043
737 # Find latest experiment dir 1044 exp_dirs = sorted(
738 exp_dirs = sorted(output_dir.glob("experiment_run*"), key=lambda p: p.stat().st_mtime) 1045 output_dir.glob("experiment_run*"),
1046 key=lambda p: p.stat().st_mtime,
1047 )
739 if not exp_dirs: 1048 if not exp_dirs:
740 raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}") 1049 raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}")
741 exp_dir = exp_dirs[-1] 1050 exp_dir = exp_dirs[-1]
742 1051
743 base_viz_dir = exp_dir / "visualizations" 1052 base_viz_dir = exp_dir / "visualizations"
746 1055
747 html = get_html_template() 1056 html = get_html_template()
748 html += f"<h1>{title}</h1>" 1057 html += f"<h1>{title}</h1>"
749 1058
750 metrics_html = "" 1059 metrics_html = ""
751 1060 train_val_metrics_html = ""
752 # Load and embed metrics table (training/val/test stats) 1061 test_metrics_html = ""
1062
753 try: 1063 try:
754 train_stats_path = exp_dir / "training_statistics.json" 1064 train_stats_path = exp_dir / "training_statistics.json"
755 test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME 1065 test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME
756 if train_stats_path.exists() and test_stats_path.exists(): 1066 if train_stats_path.exists() and test_stats_path.exists():
757 with open(train_stats_path) as f: 1067 with open(train_stats_path) as f:
758 train_stats = json.load(f) 1068 train_stats = json.load(f)
759 with open(test_stats_path) as f: 1069 with open(test_stats_path) as f:
760 test_stats = json.load(f) 1070 test_stats = json.load(f)
761 output_feature = next(iter(train_stats.keys()), "") 1071 output_type = detect_output_type(test_stats)
762 if output_feature: 1072 all_metrics = extract_metrics_from_json(
763 metrics_html += format_stats_table_html(train_stats, test_stats) 1073 train_stats,
1074 test_stats,
1075 output_type,
1076 )
1077 metrics_html = format_stats_table_html(train_stats, test_stats)
1078 train_val_metrics_html = format_train_val_stats_table_html(
1079 train_stats,
1080 test_stats,
1081 )
1082 test_metrics_html = format_test_merged_stats_table_html(
1083 all_metrics["test"],
1084 )
764 except Exception as e: 1085 except Exception as e:
765 logger.warning(f"Could not load stats for HTML report: {e}") 1086 logger.warning(
1087 f"Could not load stats for HTML report: {type(e).__name__}: {e}"
1088 )
766 1089
767 config_html = "" 1090 config_html = ""
768 training_progress = self.get_training_process(output_dir) 1091 training_progress = self.get_training_process(output_dir)
769 try: 1092 try:
770 config_html = format_config_table_html(config, split_info, training_progress) 1093 config_html = format_config_table_html(config, split_info, training_progress)
771 except Exception as e: 1094 except Exception as e:
772 logger.warning(f"Could not load config for HTML report: {e}") 1095 logger.warning(f"Could not load config for HTML report: {e}")
773 1096
774 def render_img_section(title: str, dir_path: Path) -> str: 1097 def render_img_section(title: str, dir_path: Path, output_type: str = None) -> str:
775 if not dir_path.exists(): 1098 if not dir_path.exists():
776 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>" 1099 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>"
777 imgs = sorted(dir_path.glob("*.png")) 1100
1101 imgs = list(dir_path.glob("*.png"))
778 if not imgs: 1102 if not imgs:
779 return f"<h2>{title}</h2><p><em>No plots found.</em></p>" 1103 return f"<h2>{title}</h2><p><em>No plots found.</em></p>"
1104
1105 if title == "Test Visualizations" and output_type == "binary":
1106 order = [
1107 "confusion_matrix__label_top2.png",
1108 "roc_curves_from_prediction_statistics.png",
1109 "compare_performance_label.png",
1110 "confusion_matrix_entropy__label_top2.png",
1111 ]
1112 img_names = {img.name: img for img in imgs}
1113 ordered_imgs = [
1114 img_names[fname] for fname in order if fname in img_names
1115 ]
1116 remaining = sorted(
1117 [
1118 img
1119 for img in imgs
1120 if img.name not in order and img.name != "roc_curves.png"
1121 ]
1122 )
1123 imgs = ordered_imgs + remaining
1124
1125 elif title == "Test Visualizations" and output_type == "category":
1126 unwanted = {
1127 "compare_classifiers_multiclass_multimetric__label_best10.png",
1128 "compare_classifiers_multiclass_multimetric__label_top10.png",
1129 "compare_classifiers_multiclass_multimetric__label_worst10.png",
1130 }
1131 display_order = [
1132 "confusion_matrix__label_top10.png",
1133 "roc_curves.png",
1134 "compare_performance_label.png",
1135 "compare_classifiers_performance_from_prob.png",
1136 "compare_classifiers_multiclass_multimetric__label_sorted.png",
1137 "confusion_matrix_entropy__label_top10.png",
1138 ]
1139 img_names = {img.name: img for img in imgs if img.name not in unwanted}
1140 ordered_imgs = [
1141 img_names[fname] for fname in display_order if fname in img_names
1142 ]
1143 remaining = sorted(
1144 [
1145 img
1146 for img in img_names.values()
1147 if img.name not in display_order
1148 ]
1149 )
1150 imgs = ordered_imgs + remaining
1151
1152 else:
1153 if output_type == "category":
1154 unwanted = {
1155 "compare_classifiers_multiclass_multimetric__label_best10.png",
1156 "compare_classifiers_multiclass_multimetric__label_top10.png",
1157 "compare_classifiers_multiclass_multimetric__label_worst10.png",
1158 }
1159 imgs = sorted([img for img in imgs if img.name not in unwanted])
1160 else:
1161 imgs = sorted(imgs)
780 1162
781 section_html = f"<h2 style='text-align: center;'>{title}</h2><div>" 1163 section_html = f"<h2 style='text-align: center;'>{title}</h2><div>"
782 for img in imgs: 1164 for img in imgs:
783 b64 = encode_image_to_base64(str(img)) 1165 b64 = encode_image_to_base64(str(img))
784 section_html += ( 1166 section_html += (
785 f'<div class="plot" style="margin-bottom:20px;text-align:center;">' 1167 f'<div class="plot" style="margin-bottom:20px;text-align:center;">'
786 f"<h3>{img.stem.replace('_',' ').title()}</h3>" 1168 f"<h3>{img.stem.replace('_', ' ').title()}</h3>"
787 f'<img src="data:image/png;base64,{b64}" ' 1169 f'<img src="data:image/png;base64,{b64}" '
788 'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' 1170 f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />'
789 "</div>" 1171 f"</div>"
790 ) 1172 )
791 section_html += "</div>" 1173 section_html += "</div>"
792 return section_html 1174 return section_html
793 1175
794 train_plots_html = render_img_section("Training & Validation Visualizations", train_viz_dir) 1176 button_html = """
795 test_plots_html = render_img_section("Test Visualizations", test_viz_dir) 1177 <button class="help-modal-btn openMetricsHelp">Model Evaluation Metrics — Help Guide</button>
796 html += build_tabbed_html(config_html + metrics_html, train_plots_html, test_plots_html) 1178 <br><br>
1179 <style>
1180 .help-modal-btn {
1181 background-color: #17623b;
1182 color: #fff;
1183 border: none;
1184 border-radius: 24px;
1185 padding: 10px 28px;
1186 font-size: 1.1rem;
1187 font-weight: bold;
1188 letter-spacing: 0.03em;
1189 cursor: pointer;
1190 transition: background 0.2s, box-shadow 0.2s;
1191 box-shadow: 0 2px 8px rgba(23,98,59,0.07);
1192 }
1193 .help-modal-btn:hover, .help-modal-btn:focus {
1194 background-color: #21895e;
1195 outline: none;
1196 box-shadow: 0 4px 16px rgba(23,98,59,0.14);
1197 }
1198 </style>
1199 """
1200 tab1_content = button_html + config_html + metrics_html
1201 tab2_content = (
1202 button_html
1203 + train_val_metrics_html
1204 + render_img_section("Training & Validation Visualizations", train_viz_dir)
1205 )
1206 tab3_content = (
1207 button_html
1208 + test_metrics_html
1209 + render_img_section("Test Visualizations", test_viz_dir, output_type)
1210 )
1211
1212 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)
1213 modal_html = get_metrics_help_modal()
1214 html += tabbed_html + modal_html
797 html += get_html_closing() 1215 html += get_html_closing()
798 1216
799 try: 1217 try:
800 with open(report_path, "w") as f: 1218 with open(report_path, "w") as f:
801 f.write(html) 1219 f.write(html)
806 1224
807 return report_path 1225 return report_path
808 1226
809 1227
810 class WorkflowOrchestrator: 1228 class WorkflowOrchestrator:
811 """ 1229 """Manages the image-classification workflow."""
812 Manages the image-classification workflow:
813 1. Creates temp dirs
814 2. Extracts images
815 3. Prepares data (CSV + splits)
816 4. Renders a backend config
817 5. Runs the experiment
818 6. Cleans up
819 """
820 1230
821 def __init__(self, args: argparse.Namespace, backend: Backend): 1231 def __init__(self, args: argparse.Namespace, backend: Backend):
822 self.args = args 1232 self.args = args
823 self.backend = backend 1233 self.backend = backend
824 self.temp_dir: Optional[Path] = None 1234 self.temp_dir: Optional[Path] = None
826 logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}") 1236 logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}")
827 1237
828 def _create_temp_dirs(self) -> None: 1238 def _create_temp_dirs(self) -> None:
829 """Create temporary output and image extraction directories.""" 1239 """Create temporary output and image extraction directories."""
830 try: 1240 try:
831 self.temp_dir = Path(tempfile.mkdtemp( 1241 self.temp_dir = Path(
832 dir=self.args.output_dir, 1242 tempfile.mkdtemp(dir=self.args.output_dir, prefix=TEMP_DIR_PREFIX)
833 prefix=TEMP_DIR_PREFIX 1243 )
834 ))
835 self.image_extract_dir = self.temp_dir / "images" 1244 self.image_extract_dir = self.temp_dir / "images"
836 self.image_extract_dir.mkdir() 1245 self.image_extract_dir.mkdir()
837 logger.info(f"Created temp directory: {self.temp_dir}") 1246 logger.info(f"Created temp directory: {self.temp_dir}")
838 except Exception: 1247 except Exception:
839 logger.error("Failed to create temporary directories", exc_info=True) 1248 logger.error("Failed to create temporary directories", exc_info=True)
841 1250
842 def _extract_images(self) -> None: 1251 def _extract_images(self) -> None:
843 """Extract images from ZIP into the temp image directory.""" 1252 """Extract images from ZIP into the temp image directory."""
844 if self.image_extract_dir is None: 1253 if self.image_extract_dir is None:
845 raise RuntimeError("Temp image directory not initialized.") 1254 raise RuntimeError("Temp image directory not initialized.")
846 logger.info(f"Extracting images from {self.args.image_zip} → {self.image_extract_dir}") 1255 logger.info(
1256 f"Extracting images from {self.args.image_zip} → {self.image_extract_dir}"
1257 )
847 try: 1258 try:
848 with zipfile.ZipFile(self.args.image_zip, "r") as z: 1259 with zipfile.ZipFile(self.args.image_zip, "r") as z:
849 z.extractall(self.image_extract_dir) 1260 z.extractall(self.image_extract_dir)
850 logger.info("Image extraction complete.") 1261 logger.info("Image extraction complete.")
851 except Exception: 1262 except Exception:
852 logger.error("Error extracting zip file", exc_info=True) 1263 logger.error("Error extracting zip file", exc_info=True)
853 raise 1264 raise
854 1265
855 def _prepare_data(self) -> Tuple[Path, Dict[str, Any]]: 1266 def _prepare_data(self) -> Tuple[Path, Dict[str, Any]]:
856 """ 1267 """Load CSV, update image paths, handle splits, and write prepared CSV."""
857 Load CSV, update image paths, handle splits, and write prepared CSV.
858 Returns:
859 final_csv_path: Path to the prepared CSV
860 split_config: Dict for backend split settings
861 """
862 if not self.temp_dir or not self.image_extract_dir: 1268 if not self.temp_dir or not self.image_extract_dir:
863 raise RuntimeError("Temp dirs not initialized before data prep.") 1269 raise RuntimeError("Temp dirs not initialized before data prep.")
864 1270
865 # 1) Load
866 try: 1271 try:
867 df = pd.read_csv(self.args.csv_file) 1272 df = pd.read_csv(self.args.csv_file)
868 logger.info(f"Loaded CSV: {self.args.csv_file}") 1273 logger.info(f"Loaded CSV: {self.args.csv_file}")
869 except Exception: 1274 except Exception:
870 logger.error("Error loading CSV file", exc_info=True) 1275 logger.error("Error loading CSV file", exc_info=True)
871 raise 1276 raise
872 1277
873 # 2) Validate columns
874 required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME} 1278 required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME}
875 missing = required - set(df.columns) 1279 missing = required - set(df.columns)
876 if missing: 1280 if missing:
877 raise ValueError(f"Missing CSV columns: {', '.join(missing)}") 1281 raise ValueError(f"Missing CSV columns: {', '.join(missing)}")
878 1282
879 # 3) Update image paths
880 try: 1283 try:
881 df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply( 1284 df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply(
882 lambda p: str((self.image_extract_dir / p).resolve()) 1285 lambda p: str((self.image_extract_dir / p).resolve())
883 ) 1286 )
884 except Exception: 1287 except Exception:
885 logger.error("Error updating image paths", exc_info=True) 1288 logger.error("Error updating image paths", exc_info=True)
886 raise 1289 raise
887 1290
888 # 4) Handle splits
889 if SPLIT_COLUMN_NAME in df.columns: 1291 if SPLIT_COLUMN_NAME in df.columns:
890 df, split_config, split_info = self._process_fixed_split(df) 1292 df, split_config, split_info = self._process_fixed_split(df)
891 else: 1293 else:
892 logger.info("No split column; using random split") 1294 logger.info("No split column; using random split")
893 split_config = { 1295 split_config = {
894 "type": "random", 1296 "type": "random",
895 "probabilities": self.args.split_probabilities 1297 "probabilities": self.args.split_probabilities,
896 } 1298 }
897 split_info = ( 1299 split_info = (
898 f"No split column in CSV. Used random split: " 1300 f"No split column in CSV. Used random split: "
899 f"{[int(p*100) for p in self.args.split_probabilities]}% for train/val/test." 1301 f"{[int(p * 100) for p in self.args.split_probabilities]}% "
900 ) 1302 f"for train/val/test."
901 1303 )
902 # 5) Write out prepared CSV 1304
903 final_csv = TEMP_CSV_FILENAME 1305 final_csv = TEMP_CSV_FILENAME
904 try: 1306 try:
905 df.to_csv(final_csv, index=False) 1307 df.to_csv(final_csv, index=False)
906 logger.info(f"Saved prepared data to {final_csv}") 1308 logger.info(f"Saved prepared data to {final_csv}")
907 except Exception: 1309 except Exception:
913 def _process_fixed_split(self, df: pd.DataFrame) -> Dict[str, Any]: 1315 def _process_fixed_split(self, df: pd.DataFrame) -> Dict[str, Any]:
914 """Process a fixed split column (0=train,1=val,2=test).""" 1316 """Process a fixed split column (0=train,1=val,2=test)."""
915 logger.info(f"Fixed split column '{SPLIT_COLUMN_NAME}' detected.") 1317 logger.info(f"Fixed split column '{SPLIT_COLUMN_NAME}' detected.")
916 try: 1318 try:
917 col = df[SPLIT_COLUMN_NAME] 1319 col = df[SPLIT_COLUMN_NAME]
918 df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype(pd.Int64Dtype()) 1320 df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype(
1321 pd.Int64Dtype()
1322 )
919 if df[SPLIT_COLUMN_NAME].isna().any(): 1323 if df[SPLIT_COLUMN_NAME].isna().any():
920 logger.warning("Split column contains non-numeric/missing values.") 1324 logger.warning("Split column contains non-numeric/missing values.")
921 1325
922 unique = set(df[SPLIT_COLUMN_NAME].dropna().unique()) 1326 unique = set(df[SPLIT_COLUMN_NAME].dropna().unique())
923 logger.info(f"Unique split values: {unique}") 1327 logger.info(f"Unique split values: {unique}")
924 1328
925 if unique == {0, 2}: 1329 if unique == {0, 2}:
926 df = split_data_0_2( 1330 df = split_data_0_2(
927 df, SPLIT_COLUMN_NAME, 1331 df,
1332 SPLIT_COLUMN_NAME,
928 validation_size=self.args.validation_size, 1333 validation_size=self.args.validation_size,
929 label_column=LABEL_COLUMN_NAME, 1334 label_column=LABEL_COLUMN_NAME,
930 random_state=self.args.random_seed 1335 random_state=self.args.random_seed,
931 ) 1336 )
932 split_info = ( 1337 split_info = (
933 "Detected a split column (with values 0 and 2) in the input CSV. " 1338 "Detected a split column (with values 0 and 2) in the input CSV. "
934 f"Used this column as a base and" 1339 f"Used this column as a base and reassigned "
935 f"reassigned {self.args.validation_size * 100:.1f}% " 1340 f"{self.args.validation_size * 100:.1f}% "
936 "of the training set (originally labeled 0) to validation (labeled 1)." 1341 "of the training set (originally labeled 0) to validation (labeled 1)."
937 ) 1342 )
938
939 logger.info("Applied custom 0/2 split.") 1343 logger.info("Applied custom 0/2 split.")
940 elif unique.issubset({0, 1, 2}): 1344 elif unique.issubset({0, 1, 2}):
941 split_info = "Used user-defined split column from CSV." 1345 split_info = "Used user-defined split column from CSV."
942 logger.info("Using fixed split as-is.") 1346 logger.info("Using fixed split as-is.")
943 else: 1347 else:
948 except Exception: 1352 except Exception:
949 logger.error("Error processing fixed split", exc_info=True) 1353 logger.error("Error processing fixed split", exc_info=True)
950 raise 1354 raise
951 1355
952 def _cleanup_temp_dirs(self) -> None: 1356 def _cleanup_temp_dirs(self) -> None:
953 """Remove any temporary directories."""
954 if self.temp_dir and self.temp_dir.exists(): 1357 if self.temp_dir and self.temp_dir.exists():
955 logger.info(f"Cleaning up temp directory: {self.temp_dir}") 1358 logger.info(f"Cleaning up temp directory: {self.temp_dir}")
956 shutil.rmtree(self.temp_dir, ignore_errors=True) 1359 shutil.rmtree(self.temp_dir, ignore_errors=True)
957 self.temp_dir = None 1360 self.temp_dir = None
958 self.image_extract_dir = None 1361 self.image_extract_dir = None
978 "preprocessing_num_processes": self.args.preprocessing_num_processes, 1381 "preprocessing_num_processes": self.args.preprocessing_num_processes,
979 "split_probabilities": self.args.split_probabilities, 1382 "split_probabilities": self.args.split_probabilities,
980 "learning_rate": self.args.learning_rate, 1383 "learning_rate": self.args.learning_rate,
981 "random_seed": self.args.random_seed, 1384 "random_seed": self.args.random_seed,
982 "early_stop": self.args.early_stop, 1385 "early_stop": self.args.early_stop,
1386 "label_column_data_path": csv_path,
983 } 1387 }
984 yaml_str = self.backend.prepare_config(backend_args, split_cfg) 1388 yaml_str = self.backend.prepare_config(backend_args, split_cfg)
985 1389
986 config_file = self.temp_dir / TEMP_CONFIG_FILENAME 1390 config_file = self.temp_dir / TEMP_CONFIG_FILENAME
987 config_file.write_text(yaml_str) 1391 config_file.write_text(yaml_str)
989 1393
990 self.backend.run_experiment( 1394 self.backend.run_experiment(
991 csv_path, 1395 csv_path,
992 config_file, 1396 config_file,
993 self.args.output_dir, 1397 self.args.output_dir,
994 self.args.random_seed 1398 self.args.random_seed,
995 ) 1399 )
996 logger.info("Workflow completed successfully.") 1400 logger.info("Workflow completed successfully.")
997 self.backend.generate_plots(self.args.output_dir) 1401 self.backend.generate_plots(self.args.output_dir)
998 report_file = self.backend.generate_html_report( 1402 report_file = self.backend.generate_html_report(
999 "Image Classification Results", 1403 "Image Classification Results",
1000 self.args.output_dir, 1404 self.args.output_dir,
1001 backend_args, 1405 backend_args,
1002 split_info 1406 split_info,
1003 ) 1407 )
1004 logger.info(f"HTML report generated at: {report_file}") 1408 logger.info(f"HTML report generated at: {report_file}")
1005 self.backend.convert_parquet_to_csv(self.args.output_dir) 1409 self.backend.convert_parquet_to_csv(self.args.output_dir)
1006 logger.info("Converted Parquet to CSV.") 1410 logger.info("Converted Parquet to CSV.")
1007 except Exception: 1411 except Exception:
1008 logger.error("Workflow execution failed", exc_info=True) 1412 logger.error("Workflow execution failed", exc_info=True)
1009 raise 1413 raise
1010
1011 finally: 1414 finally:
1012 self._cleanup_temp_dirs() 1415 self._cleanup_temp_dirs()
1013 1416
1014 1417
1015 def parse_learning_rate(s): 1418 def parse_learning_rate(s):
1019 return None 1422 return None
1020 1423
1021 1424
1022 class SplitProbAction(argparse.Action): 1425 class SplitProbAction(argparse.Action):
1023 def __call__(self, parser, namespace, values, option_string=None): 1426 def __call__(self, parser, namespace, values, option_string=None):
1024 # values is a list of three floats
1025 train, val, test = values 1427 train, val, test = values
1026 total = train + val + test 1428 total = train + val + test
1027 if abs(total - 1.0) > 1e-6: 1429 if abs(total - 1.0) > 1e-6:
1028 parser.error( 1430 parser.error(
1029 f"--split-probabilities must sum to 1.0; " 1431 f"--split-probabilities must sum to 1.0; "
1031 ) 1433 )
1032 setattr(namespace, self.dest, values) 1434 setattr(namespace, self.dest, values)
1033 1435
1034 1436
1035 def main(): 1437 def main():
1036
1037 parser = argparse.ArgumentParser( 1438 parser = argparse.ArgumentParser(
1038 description="Image Classification Learner with Pluggable Backends" 1439 description="Image Classification Learner with Pluggable Backends",
1039 ) 1440 )
1040 parser.add_argument( 1441 parser.add_argument(
1041 "--csv-file", required=True, type=Path, 1442 "--csv-file",
1042 help="Path to the input CSV" 1443 required=True,
1444 type=Path,
1445 help="Path to the input CSV",
1043 ) 1446 )
1044 parser.add_argument( 1447 parser.add_argument(
1045 "--image-zip", required=True, type=Path, 1448 "--image-zip",
1046 help="Path to the images ZIP" 1449 required=True,
1450 type=Path,
1451 help="Path to the images ZIP",
1047 ) 1452 )
1048 parser.add_argument( 1453 parser.add_argument(
1049 "--model-name", required=True, 1454 "--model-name",
1455 required=True,
1050 choices=MODEL_ENCODER_TEMPLATES.keys(), 1456 choices=MODEL_ENCODER_TEMPLATES.keys(),
1051 help="Which model template to use" 1457 help="Which model template to use",
1052 ) 1458 )
1053 parser.add_argument( 1459 parser.add_argument(
1054 "--use-pretrained", action="store_true", 1460 "--use-pretrained",
1055 help="Use pretrained weights for the model" 1461 action="store_true",
1462 help="Use pretrained weights for the model",
1056 ) 1463 )
1057 parser.add_argument( 1464 parser.add_argument(
1058 "--fine-tune", action="store_true", 1465 "--fine-tune",
1059 help="Enable fine-tuning" 1466 action="store_true",
1467 help="Enable fine-tuning",
1060 ) 1468 )
1061 parser.add_argument( 1469 parser.add_argument(
1062 "--epochs", type=int, default=10, 1470 "--epochs",
1063 help="Number of training epochs" 1471 type=int,
1472 default=10,
1473 help="Number of training epochs",
1064 ) 1474 )
1065 parser.add_argument( 1475 parser.add_argument(
1066 "--early-stop", type=int, default=5, 1476 "--early-stop",
1067 help="Early stopping patience" 1477 type=int,
1478 default=5,
1479 help="Early stopping patience",
1068 ) 1480 )
1069 parser.add_argument( 1481 parser.add_argument(
1070 "--batch-size", type=int, 1482 "--batch-size",
1071 help="Batch size (None = auto)" 1483 type=int,
1484 help="Batch size (None = auto)",
1072 ) 1485 )
1073 parser.add_argument( 1486 parser.add_argument(
1074 "--output-dir", type=Path, default=Path("learner_output"), 1487 "--output-dir",
1075 help="Where to write outputs" 1488 type=Path,
1489 default=Path("learner_output"),
1490 help="Where to write outputs",
1076 ) 1491 )
1077 parser.add_argument( 1492 parser.add_argument(
1078 "--validation-size", type=float, default=0.15, 1493 "--validation-size",
1079 help="Fraction for validation (0.0–1.0)" 1494 type=float,
1495 default=0.15,
1496 help="Fraction for validation (0.0–1.0)",
1080 ) 1497 )
1081 parser.add_argument( 1498 parser.add_argument(
1082 "--preprocessing-num-processes", type=int, 1499 "--preprocessing-num-processes",
1500 type=int,
1083 default=max(1, os.cpu_count() // 2), 1501 default=max(1, os.cpu_count() // 2),
1084 help="CPU processes for data prep" 1502 help="CPU processes for data prep",
1085 ) 1503 )
1086 parser.add_argument( 1504 parser.add_argument(
1087 "--split-probabilities", type=float, nargs=3, 1505 "--split-probabilities",
1506 type=float,
1507 nargs=3,
1088 metavar=("train", "val", "test"), 1508 metavar=("train", "val", "test"),
1089 action=SplitProbAction, 1509 action=SplitProbAction,
1090 default=[0.7, 0.1, 0.2], 1510 default=[0.7, 0.1, 0.2],
1091 help="Random split proportions (e.g., 0.7 0.1 0.2). Only used if no split column is present." 1511 help="Random split proportions (e.g., 0.7 0.1 0.2). Only used if no split column.",
1092 ) 1512 )
1093 parser.add_argument( 1513 parser.add_argument(
1094 "--random-seed", type=int, default=42, 1514 "--random-seed",
1095 help="Random seed used for dataset splitting (default: 42)" 1515 type=int,
1516 default=42,
1517 help="Random seed used for dataset splitting (default: 42)",
1096 ) 1518 )
1097 parser.add_argument( 1519 parser.add_argument(
1098 "--learning-rate", type=parse_learning_rate, default=None, 1520 "--learning-rate",
1099 help="Learning rate. If not provided, Ludwig will auto-select it." 1521 type=parse_learning_rate,
1522 default=None,
1523 help="Learning rate. If not provided, Ludwig will auto-select it.",
1100 ) 1524 )
1101 1525
1102 args = parser.parse_args() 1526 args = parser.parse_args()
1103 1527
1104 # -- Validation --
1105 if not 0.0 <= args.validation_size <= 1.0: 1528 if not 0.0 <= args.validation_size <= 1.0:
1106 parser.error("validation-size must be between 0.0 and 1.0") 1529 parser.error("validation-size must be between 0.0 and 1.0")
1107 if not args.csv_file.is_file(): 1530 if not args.csv_file.is_file():
1108 parser.error(f"CSV not found: {args.csv_file}") 1531 parser.error(f"CSV not found: {args.csv_file}")
1109 if not args.image_zip.is_file(): 1532 if not args.image_zip.is_file():
1110 parser.error(f"ZIP not found: {args.image_zip}") 1533 parser.error(f"ZIP not found: {args.image_zip}")
1111 1534
1112 # --- Instantiate Backend and Orchestrator ---
1113 # Use the new LudwigDirectBackend
1114 backend_instance = LudwigDirectBackend() 1535 backend_instance = LudwigDirectBackend()
1115 orchestrator = WorkflowOrchestrator(args, backend_instance) 1536 orchestrator = WorkflowOrchestrator(args, backend_instance)
1116 1537
1117 # --- Run Workflow ---
1118 exit_code = 0 1538 exit_code = 0
1119 try: 1539 try:
1120 orchestrator.run() 1540 orchestrator.run()
1121 logger.info("Main script finished successfully.") 1541 logger.info("Main script finished successfully.")
1122 except Exception as e: 1542 except Exception as e:
1124 exit_code = 1 1544 exit_code = 1
1125 finally: 1545 finally:
1126 sys.exit(exit_code) 1546 sys.exit(exit_code)
1127 1547
1128 1548
if __name__ == "__main__":
    # Fail fast with an actionable message when the core dependency is absent.
    try:
        import ludwig
    except ImportError:
        logger.error(
            "Ludwig library not found. Please ensure Ludwig is installed "
            "('pip install ludwig[image]')"
        )
        sys.exit(1)
    else:
        logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}")

    main()