Mercurial > repos > goeckslab > image_learner
comparison image_learner_cli.py @ 1:39202fe5cf97 draft
planemo upload for repository https://github.com/goeckslab/gleam.git commit 06c0da44ac93256dfb616a6b40276b5485a71e8e
author | goeckslab |
---|---|
date | Wed, 02 Jul 2025 18:59:10 +0000 |
parents | 54b871dfc51e |
children | 186424a7eca7 |
comparison
equal
deleted
inserted
replaced
0:54b871dfc51e | 1:39202fe5cf97 |
---|---|
22 from ludwig.visualize import get_visualizations_registry | 22 from ludwig.visualize import get_visualizations_registry |
23 from sklearn.model_selection import train_test_split | 23 from sklearn.model_selection import train_test_split |
24 from utils import encode_image_to_base64, get_html_closing, get_html_template | 24 from utils import encode_image_to_base64, get_html_closing, get_html_template |
25 | 25 |
# --- Constants ---
# Column names used in the processed dataframe handed to Ludwig.
SPLIT_COLUMN_NAME = "split"
LABEL_COLUMN_NAME = "label"
IMAGE_PATH_COLUMN_NAME = "image_path"
# Train / validation / test fractions (sum to 1.0) applied by default.
DEFAULT_SPLIT_PROBABILITIES = [0.7, 0.1, 0.2]
# Scratch-file names written into the temporary working directory.
TEMP_CSV_FILENAME = "processed_data_for_ludwig.csv"
TEMP_CONFIG_FILENAME = "ludwig_config.yaml"
# Prefix for the temporary working directory itself.
TEMP_DIR_PREFIX = "ludwig_api_work_"
# Maps a user-facing model name to the Ludwig image-encoder spec it expands to:
# either a plain encoder-type string, or a dict with "type" and (optionally)
# "model_variant" selecting a torchvision variant.
MODEL_ENCODER_TEMPLATES: Dict[str, Any] = {
    "stacked_cnn": "stacked_cnn",
    "resnet18": {"type": "resnet", "model_variant": 18},
    "resnet34": {"type": "resnet", "model_variant": 34},
    "resnet50": {"type": "resnet", "model_variant": 50},
    "resnet101": {"type": "resnet", "model_variant": 101},
    "resnet152": {"type": "resnet", "model_variant": 152},
    "resnext50_32x4d": {"type": "resnext", "model_variant": "50_32x4d"},
    "resnext101_32x8d": {"type": "resnext", "model_variant": "101_32x8d"},
    "resnext101_64x4d": {"type": "resnext", "model_variant": "101_64x4d"},
    "resnext152_32x8d": {"type": "resnext", "model_variant": "152_32x8d"},
    "wide_resnet50_2": {"type": "wide_resnet", "model_variant": "50_2"},
    "wide_resnet101_2": {"type": "wide_resnet", "model_variant": "101_2"},
    # NOTE(review): torchvision only ships wide_resnet50_2 / wide_resnet101_2;
    # "103_2" looks like a typo — confirm against Ludwig's encoder registry.
    "wide_resnet103_2": {"type": "wide_resnet", "model_variant": "103_2"},
    "efficientnet_b0": {"type": "efficientnet", "model_variant": "b0"},
    "efficientnet_b1": {"type": "efficientnet", "model_variant": "b1"},
    "efficientnet_b2": {"type": "efficientnet", "model_variant": "b2"},
    "efficientnet_b3": {"type": "efficientnet", "model_variant": "b3"},
    "efficientnet_b4": {"type": "efficientnet", "model_variant": "b4"},
    "efficientnet_b5": {"type": "efficientnet", "model_variant": "b5"},
    "efficientnet_b6": {"type": "efficientnet", "model_variant": "b6"},
    "efficientnet_b7": {"type": "efficientnet", "model_variant": "b7"},
    "efficientnet_v2_s": {"type": "efficientnet", "model_variant": "v2_s"},
    "efficientnet_v2_m": {"type": "efficientnet", "model_variant": "v2_m"},
    "efficientnet_v2_l": {"type": "efficientnet", "model_variant": "v2_l"},
    "regnet_y_400mf": {"type": "regnet", "model_variant": "y_400mf"},
    "regnet_y_800mf": {"type": "regnet", "model_variant": "y_800mf"},
    "regnet_y_1_6gf": {"type": "regnet", "model_variant": "y_1_6gf"},
    "regnet_y_3_2gf": {"type": "regnet", "model_variant": "y_3_2gf"},
    "regnet_y_8gf": {"type": "regnet", "model_variant": "y_8gf"},
    "regnet_y_16gf": {"type": "regnet", "model_variant": "y_16gf"},
    "regnet_y_32gf": {"type": "regnet", "model_variant": "y_32gf"},
    "regnet_y_128gf": {"type": "regnet", "model_variant": "y_128gf"},
    "regnet_x_400mf": {"type": "regnet", "model_variant": "x_400mf"},
    "regnet_x_800mf": {"type": "regnet", "model_variant": "x_800mf"},
    "regnet_x_1_6gf": {"type": "regnet", "model_variant": "x_1_6gf"},
    "regnet_x_3_2gf": {"type": "regnet", "model_variant": "x_3_2gf"},
    "regnet_x_8gf": {"type": "regnet", "model_variant": "x_8gf"},
    "regnet_x_16gf": {"type": "regnet", "model_variant": "x_16gf"},
    "regnet_x_32gf": {"type": "regnet", "model_variant": "x_32gf"},
    # VGG: plain variants use int variants, batch-norm ones use "_bn" strings.
    "vgg11": {"type": "vgg", "model_variant": 11},
    "vgg11_bn": {"type": "vgg", "model_variant": "11_bn"},
    "vgg13": {"type": "vgg", "model_variant": 13},
    "vgg13_bn": {"type": "vgg", "model_variant": "13_bn"},
    "vgg16": {"type": "vgg", "model_variant": 16},
    "vgg16_bn": {"type": "vgg", "model_variant": "16_bn"},
    "vgg19": {"type": "vgg", "model_variant": 19},
    "vgg19_bn": {"type": "vgg", "model_variant": "19_bn"},
    "shufflenet_v2_x0_5": {"type": "shufflenet_v2", "model_variant": "x0_5"},
    "shufflenet_v2_x1_0": {"type": "shufflenet_v2", "model_variant": "x1_0"},
    "shufflenet_v2_x1_5": {"type": "shufflenet_v2", "model_variant": "x1_5"},
    "shufflenet_v2_x2_0": {"type": "shufflenet_v2", "model_variant": "x2_0"},
    "squeezenet1_0": {"type": "squeezenet", "model_variant": "1_0"},
    "squeezenet1_1": {"type": "squeezenet", "model_variant": "1_1"},
    "swin_t": {"type": "swin_transformer", "model_variant": "t"},
    "swin_s": {"type": "swin_transformer", "model_variant": "s"},
    "swin_b": {"type": "swin_transformer", "model_variant": "b"},
    "swin_v2_t": {"type": "swin_transformer", "model_variant": "v2_t"},
    "swin_v2_s": {"type": "swin_transformer", "model_variant": "v2_s"},
    "swin_v2_b": {"type": "swin_transformer", "model_variant": "v2_b"},
    "vit_b_16": {"type": "vision_transformer", "model_variant": "b_16"},
    "vit_b_32": {"type": "vision_transformer", "model_variant": "b_32"},
    "vit_l_16": {"type": "vision_transformer", "model_variant": "l_16"},
    "vit_l_32": {"type": "vision_transformer", "model_variant": "l_32"},
    "vit_h_14": {"type": "vision_transformer", "model_variant": "h_14"},
    "convnext_tiny": {"type": "convnext", "model_variant": "tiny"},
    "convnext_small": {"type": "convnext", "model_variant": "small"},
    "convnext_base": {"type": "convnext", "model_variant": "base"},
    "convnext_large": {"type": "convnext", "model_variant": "large"},
    "maxvit_t": {"type": "maxvit", "model_variant": "t"},
    # Single-variant architectures: only the encoder type is needed.
    "alexnet": {"type": "alexnet"},
    "googlenet": {"type": "googlenet"},
    "inception_v3": {"type": "inception_v3"},
    "mobilenet_v2": {"type": "mobilenet_v2"},
    "mobilenet_v3_large": {"type": "mobilenet_v3_large"},
    "mobilenet_v3_small": {"type": "mobilenet_v3_small"},
}
111 METRIC_DISPLAY_NAMES = { | |
112 "accuracy": "Accuracy", | |
113 "accuracy_micro": "Accuracy-Micro", | |
114 "loss": "Loss", | |
115 "roc_auc": "ROC-AUC", | |
116 "roc_auc_macro": "ROC-AUC-Macro", | |
117 "roc_auc_micro": "ROC-AUC-Micro", | |
118 "hits_at_k": "Hits at K", | |
119 "precision": "Precision", | |
120 "recall": "Recall", | |
121 "specificity": "Specificity", | |
122 "kappa_score": "Cohen's Kappa", | |
123 "token_accuracy": "Token Accuracy", | |
124 "avg_precision_macro": "Precision-Macro", | |
125 "avg_recall_macro": "Recall-Macro", | |
126 "avg_f1_score_macro": "F1-score-Macro", | |
127 "avg_precision_micro": "Precision-Micro", | |
128 "avg_recall_micro": "Recall-Micro", | |
129 "avg_f1_score_micro": "F1-score-Micro", | |
130 "avg_precision_weighted": "Precision-Weighted", | |
131 "avg_recall_weighted": "Recall-Weighted", | |
132 "avg_f1_score_weighted": "F1-score-Weighted", | |
133 "average_precision_macro": " Precision-Average-Macro", | |
134 "average_precision_micro": "Precision-Average-Micro", | |
135 "average_precision_samples": "Precision-Average-Samples", | |
110 } | 136 } |
111 | 137 |
# --- Logging Setup ---
# Configure the root logger once at import time; INFO keeps Ludwig's own
# progress messages visible without drowning the output in debug noise.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
# Module-wide logger used throughout this CLI.
logger = logging.getLogger("ImageLearner")
119 | 145 |
def get_metrics_help_modal() -> str:
    """Build the self-contained CSS + HTML + JS for the metrics help dialog.

    The modal is hidden by default (``display: none``); the inline script
    wires it up on ``DOMContentLoaded``: any element with class
    ``openMetricsHelp`` opens it, and the close button or a click outside
    the dialog hides it again.

    Returns:
        One HTML string (styles, then markup, then script) ready to be
        embedded in the generated report page.
    """
    # Static markup: a guide explaining every metric the report can show.
    modal_html = """
<div id="metricsHelpModal" class="modal">
  <div class="modal-content">
    <span class="close">×</span>
    <h2>Model Evaluation Metrics — Help Guide</h2>
    <div class="metrics-guide">
      <h3>1) General Metrics</h3>
      <p><strong>Loss:</strong> Measures the difference between predicted and actual values. Lower is better. Often used for optimization during training.</p>
      <p><strong>Accuracy:</strong> Proportion of correct predictions among all predictions. Simple but can be misleading for imbalanced datasets.</p>
      <p><strong>Micro Accuracy:</strong> Calculates accuracy by summing up all individual true positives and true negatives across all classes, making it suitable for multiclass or multilabel problems.</p>
      <p><strong>Token Accuracy:</strong> Measures how often the predicted tokens (e.g., in sequences) match the true tokens. Useful in sequence prediction tasks like NLP.</p>
      <h3>2) Precision, Recall & Specificity</h3>
      <p><strong>Precision:</strong> Out of all positive predictions, how many were correct. Precision = TP / (TP + FP). Helps when false positives are costly.</p>
      <p><strong>Recall (Sensitivity):</strong> Out of all actual positives, how many were predicted correctly. Recall = TP / (TP + FN). Important when missing positives is risky.</p>
      <p><strong>Specificity:</strong> True negative rate. Measures how well the model identifies negatives. Specificity = TN / (TN + FP). Useful in medical testing to avoid false alarms.</p>
      <h3>3) Macro, Micro, and Weighted Averages</h3>
      <p><strong>Macro Precision / Recall / F1:</strong> Averages the metric across all classes, treating each class equally, regardless of class frequency. Best when class sizes are balanced.</p>
      <p><strong>Micro Precision / Recall / F1:</strong> Aggregates TP, FP, FN across all classes before computing the metric. Gives a global view and is ideal for class-imbalanced problems.</p>
      <p><strong>Weighted Precision / Recall / F1:</strong> Averages each metric across classes, weighted by the number of true instances per class. Balances importance of classes based on frequency.</p>
      <h3>4) Average Precision (PR-AUC Variants)</h3>
      <p><strong>Average Precision Macro:</strong> Precision-Recall AUC averaged across all classes equally. Useful for balanced multi-class problems.</p>
      <p><strong>Average Precision Micro:</strong> Global Precision-Recall AUC using all instances. Best for imbalanced data or multi-label classification.</p>
      <p><strong>Average Precision Samples:</strong> Precision-Recall AUC averaged across individual samples (not classes). Ideal for multi-label problems where each sample can belong to multiple classes.</p>
      <h3>5) ROC-AUC Variants</h3>
      <p><strong>ROC-AUC:</strong> Measures model's ability to distinguish between classes. AUC = 1 is perfect; 0.5 is random guessing. Use for binary classification.</p>
      <p><strong>Macro ROC-AUC:</strong> Averages the AUC across all classes equally. Suitable when classes are balanced and of equal importance.</p>
      <p><strong>Micro ROC-AUC:</strong> Computes AUC from aggregated predictions across all classes. Useful in multiclass or multilabel settings with imbalance.</p>
      <h3>6) Ranking Metrics</h3>
      <p><strong>Hits at K:</strong> Measures whether the true label is among the top-K predictions. Common in recommendation systems and retrieval tasks.</p>
      <h3>7) Confusion Matrix Stats (Per Class)</h3>
      <p><strong>True Positives / Negatives (TP / TN):</strong> Correct predictions for positives and negatives respectively.</p>
      <p><strong>False Positives / Negatives (FP / FN):</strong> Incorrect predictions — false alarms and missed detections.</p>
      <h3>8) Other Useful Metrics</h3>
      <p><strong>Cohen's Kappa:</strong> Measures agreement between predicted and actual values adjusted for chance. Useful for multiclass classification with imbalanced labels.</p>
      <p><strong>Matthews Correlation Coefficient (MCC):</strong> Balanced measure of prediction quality that takes into account TP, TN, FP, and FN. Particularly effective for imbalanced datasets.</p>
      <h3>9) Metric Recommendations</h3>
      <ul>
        <li>Use <strong>Accuracy + F1</strong> for balanced data.</li>
        <li>Use <strong>Precision, Recall, ROC-AUC</strong> for imbalanced datasets.</li>
        <li>Use <strong>Average Precision Micro</strong> for multilabel or class-imbalanced problems.</li>
        <li>Use <strong>Macro scores</strong> when all classes should be treated equally.</li>
        <li>Use <strong>Weighted scores</strong> when class imbalance should be accounted for without ignoring small classes.</li>
        <li>Use <strong>Confusion Matrix stats</strong> to analyze class-wise performance.</li>
        <li>Use <strong>Hits at K</strong> for recommendation or ranking-based tasks.</li>
      </ul>
    </div>
  </div>
</div>
"""
    # Styling for the overlay, the dialog box, and the guide text.
    modal_css = """
<style>
.modal {
  display: none;
  position: fixed;
  z-index: 1;
  left: 0;
  top: 0;
  width: 100%;
  height: 100%;
  overflow: auto;
  background-color: rgba(0,0,0,0.4);
}
.modal-content {
  background-color: #fefefe;
  margin: 15% auto;
  padding: 20px;
  border: 1px solid #888;
  width: 80%;
  max-width: 800px;
}
.close {
  color: #aaa;
  float: right;
  font-size: 28px;
  font-weight: bold;
}
.close:hover,
.close:focus {
  color: black;
  text-decoration: none;
  cursor: pointer;
}
.metrics-guide h3 {
  margin-top: 20px;
}
.metrics-guide p {
  margin: 5px 0;
}
.metrics-guide ul {
  margin: 10px 0;
  padding-left: 20px;
}
</style>
"""
    # Behavior: open on any ".openMetricsHelp" click, close on the close
    # button or on a click anywhere outside the dialog.
    modal_js = """
<script>
document.addEventListener("DOMContentLoaded", function() {
  var modal = document.getElementById("metricsHelpModal");
  var closeBtn = document.getElementsByClassName("close")[0];

  document.querySelectorAll(".openMetricsHelp").forEach(btn => {
    btn.onclick = function() {
      modal.style.display = "block";
    };
  });

  if (closeBtn) {
    closeBtn.onclick = function() {
      modal.style.display = "none";
    };
  }

  window.onclick = function(event) {
    if (event.target == modal) {
      modal.style.display = "none";
    }
  }
});
</script>
"""
    # CSS must precede the markup so the modal never flashes unstyled.
    return modal_css + modal_html + modal_js
268 | |
269 | |
120 def format_config_table_html( | 270 def format_config_table_html( |
121 config: dict, | 271 config: dict, |
122 split_info: Optional[str] = None, | 272 split_info: Optional[str] = None, |
123 training_progress: dict = None) -> str: | 273 training_progress: dict = None, |
274 ) -> str: | |
124 display_keys = [ | 275 display_keys = [ |
125 "model_name", | 276 "model_name", |
126 "epochs", | 277 "epochs", |
127 "batch_size", | 278 "batch_size", |
128 "fine_tune", | 279 "fine_tune", |
141 val = int(val) | 292 val = int(val) |
142 else: | 293 else: |
143 if training_progress: | 294 if training_progress: |
144 val = "Auto-selected batch size by Ludwig:<br>" | 295 val = "Auto-selected batch size by Ludwig:<br>" |
145 resolved_val = training_progress.get("batch_size") | 296 resolved_val = training_progress.get("batch_size") |
146 val += ( | 297 val += f"<span style='font-size: 0.85em;'>{resolved_val}</span><br>" |
147 f"<span style='font-size: 0.85em;'>{resolved_val}</span><br>" | |
148 ) | |
149 else: | 298 else: |
150 val = "auto" | 299 val = "auto" |
151 if key == "learning_rate": | 300 if key == "learning_rate": |
152 resolved_val = None | 301 resolved_val = None |
153 if val is None or val == "auto": | 302 if val is None or val == "auto": |
154 if training_progress: | 303 if training_progress: |
155 resolved_val = training_progress.get("learning_rate") | 304 resolved_val = training_progress.get("learning_rate") |
156 val = ( | 305 val = ( |
157 "Auto-selected learning rate by Ludwig:<br>" | 306 "Auto-selected learning rate by Ludwig:<br>" |
158 f"<span style='font-size: 0.85em;'>{resolved_val if resolved_val else val}</span><br>" | 307 f"<span style='font-size: 0.85em;'>" |
308 f"{resolved_val if resolved_val else val}</span><br>" | |
159 "<span style='font-size: 0.85em;'>" | 309 "<span style='font-size: 0.85em;'>" |
160 "Based on model architecture and training setup (e.g., fine-tuning).<br>" | 310 "Based on model architecture and training setup " |
161 "See <a href='https://ludwig.ai/latest/configuration/trainer/#trainer-parameters' " | 311 "(e.g., fine-tuning).<br>" |
162 "target='_blank'>Ludwig Trainer Parameters</a> for details." | 312 "See <a href='https://ludwig.ai/latest/configuration/trainer/" |
313 "#trainer-parameters' target='_blank'>" | |
314 "Ludwig Trainer Parameters</a> for details." | |
163 "</span>" | 315 "</span>" |
164 ) | 316 ) |
165 else: | 317 else: |
166 val = ( | 318 val = ( |
167 "Auto-selected by Ludwig<br>" | 319 "Auto-selected by Ludwig<br>" |
168 "<span style='font-size: 0.85em;'>" | 320 "<span style='font-size: 0.85em;'>" |
169 "Automatically tuned based on architecture and dataset.<br>" | 321 "Automatically tuned based on architecture and dataset.<br>" |
170 "See <a href='https://ludwig.ai/latest/configuration/trainer/#trainer-parameters' " | 322 "See <a href='https://ludwig.ai/latest/configuration/trainer/" |
171 "target='_blank'>Ludwig Trainer Parameters</a> for details." | 323 "#trainer-parameters' target='_blank'>" |
324 "Ludwig Trainer Parameters</a> for details." | |
172 "</span>" | 325 "</span>" |
173 ) | 326 ) |
174 else: | 327 else: |
175 val = f"{val:.6f}" | 328 val = f"{val:.6f}" |
176 if key == "epochs": | 329 if key == "epochs": |
177 if training_progress and "epoch" in training_progress and val > training_progress["epoch"]: | 330 if ( |
331 training_progress | |
332 and "epoch" in training_progress | |
333 and val > training_progress["epoch"] | |
334 ): | |
178 val = ( | 335 val = ( |
179 f"Because of early stopping: the training" | 336 f"Because of early stopping: the training " |
180 f"stopped at epoch {training_progress['epoch']}" | 337 f"stopped at epoch {training_progress['epoch']}" |
181 ) | 338 ) |
182 | 339 |
183 if val is None: | 340 if val is None: |
184 continue | 341 continue |
185 rows.append( | 342 rows.append( |
186 f"<tr>" | 343 f"<tr>" |
187 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" | 344 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" |
188 f"{key.replace('_', ' ').title()}</td>" | 345 f"{key.replace('_', ' ').title()}</td>" |
189 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{val}</td>" | 346 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>" |
347 f"{val}</td>" | |
190 f"</tr>" | 348 f"</tr>" |
191 ) | 349 ) |
192 | 350 |
193 if split_info: | 351 if split_info: |
194 rows.append( | 352 rows.append( |
195 f"<tr>" | 353 f"<tr>" |
196 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Data Split</td>" | 354 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" |
197 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{split_info}</td>" | 355 f"Data Split</td>" |
356 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>" | |
357 f"{split_info}</td>" | |
198 f"</tr>" | 358 f"</tr>" |
199 ) | 359 ) |
200 | 360 |
201 return ( | 361 return ( |
202 "<h2 style='text-align: center;'>Training Setup</h2>" | 362 "<h2 style='text-align: center;'>Training Setup</h2>" |
203 "<div style='display: flex; justify-content: center;'>" | 363 "<div style='display: flex; justify-content: center;'>" |
204 "<table style='border-collapse: collapse; width: 60%; table-layout: auto;'>" | 364 "<table style='border-collapse: collapse; width: 60%; table-layout: auto;'>" |
205 "<thead><tr>" | 365 "<thead><tr>" |
206 "<th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>Parameter</th>" | 366 "<th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>" |
207 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Value</th>" | 367 "Parameter</th>" |
368 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>" | |
369 "Value</th>" | |
208 "</tr></thead><tbody>" + "".join(rows) + "</tbody></table></div><br>" | 370 "</tr></thead><tbody>" + "".join(rows) + "</tbody></table></div><br>" |
209 "<p style='text-align: center; font-size: 0.9em;'>" | 371 "<p style='text-align: center; font-size: 0.9em;'>" |
210 "Model trained using Ludwig.<br>" | 372 "Model trained using Ludwig.<br>" |
211 "If want to learn more about Ludwig default settings," | 373 "If want to learn more about Ludwig default settings," |
212 "please check the their <a href='https://ludwig.ai' target='_blank'>website(ludwig.ai)</a>." | 374 "please check the their <a href='https://ludwig.ai' target='_blank'>" |
375 "website(ludwig.ai)</a>." | |
213 "</p><hr>" | 376 "</p><hr>" |
214 ) | 377 ) |
215 | 378 |
216 | 379 |
def detect_output_type(test_stats):
    """Classify the task as "binary" or "category" from the test statistics.

    Exactly two entries in the label's per-class stats means a binary
    problem; anything else (including missing stats) is treated as
    multi-class.
    """
    per_class_stats = test_stats.get("label", {}).get("per_class_stats", {})
    return "binary" if len(per_class_stats) == 2 else "category"
387 | |
388 | |
def extract_metrics_from_json(
    train_stats: dict,
    test_stats: dict,
    output_type: str,
) -> dict:
    """Extract train/validation/test metrics from Ludwig statistics dicts.

    Args:
        train_stats: Ludwig training statistics with "training" and
            "validation" splits, each holding per-metric epoch lists under
            the "label" feature.
        test_stats: Ludwig test statistics ("label" and "combined" keys).
        output_type: "binary" or "category"; selects which metric keys are
            pulled for the train/validation splits and which non-scalar
            payloads are excluded from the test metrics.

    Returns:
        {"training": {...}, "validation": {...}, "test": {...}} with scalar
        metric values (missing metrics map to None for train/validation,
        and are simply absent for test).
    """
    metrics = {"training": {}, "validation": {}, "test": {}}

    def get_last_value(stats, key):
        # Ludwig records per-epoch lists; report the final epoch's value.
        val = stats.get(key)
        if isinstance(val, list) and val:
            return val[-1]
        elif isinstance(val, (int, float)):
            return val
        return None

    # Metric keys reported for the train/validation splits, per output type.
    if output_type == "binary":
        wanted_keys = ("accuracy", "loss", "precision", "recall", "specificity", "roc_auc")
    else:
        wanted_keys = ("accuracy", "accuracy_micro", "loss", "roc_auc", "hits_at_k")

    for split in ["training", "validation"]:
        split_stats = train_stats.get(split, {})
        if not split_stats:
            # Use the module logger (not the root logger) for consistency
            # with the rest of this file, and lazy %-formatting.
            logger.warning("No statistics found for %s split", split)
            continue
        label_stats = split_stats.get("label", {})
        if not label_stats:
            logger.warning("No label statistics found for %s split", split)
            continue
        metrics[split] = {key: get_last_value(label_stats, key) for key in wanted_keys}

    # Test metrics: dynamic extraction according to exclusions
    test_label_stats = test_stats.get("label", {})
    if not test_label_stats:
        logger.warning("No label statistics found for test split")
    else:
        combined_stats = test_stats.get("combined", {})
        overall_stats = test_label_stats.get("overall_stats", {})

        # Non-scalar payloads that must never appear in the summary table.
        if output_type == "binary":
            exclude = {"per_class_stats", "precision_recall_curve", "roc_curve"}
        else:
            exclude = {"per_class_stats", "confusion_matrix"}

        # 1. Keep every scalar test metric that is not excluded.
        test_metrics = {}
        for k, v in test_label_stats.items():
            if k in exclude or k == "overall_stats":
                continue
            if isinstance(v, (int, float, str, bool)):
                test_metrics[k] = v

        # 2. Flatten overall_stats into the same namespace.
        test_metrics.update(overall_stats)

        # 3. Fall back to the combined loss if the label stats lack one.
        if "loss" in combined_stats and "loss" not in test_metrics:
            test_metrics["loss"] = combined_stats["loss"]

        metrics["test"] = test_metrics

    return metrics
467 | |
468 | |
def generate_table_row(cells, styles):
    """Render one HTML ``<tr>``, applying *styles* inline to every cell."""
    rendered_cells = [f"<td style='{styles}'>{cell}</td>" for cell in cells]
    return "<tr>" + "".join(rendered_cells) + "</tr>"
476 | |
477 | |
478 def format_stats_table_html(train_stats: dict, test_stats: dict) -> str: | |
479 """Formats a combined HTML table for training, validation, and test metrics.""" | |
480 output_type = detect_output_type(test_stats) | |
481 all_metrics = extract_metrics_from_json(train_stats, test_stats, output_type) | |
232 rows = [] | 482 rows = [] |
233 for metric in sorted(all_metrics): | 483 for metric_key in sorted(all_metrics["training"].keys()): |
234 t = get_last_value(train_metrics, metric) | 484 if ( |
235 v = get_last_value(val_metrics, metric) | 485 metric_key in all_metrics["validation"] |
236 te = get_last_value(test_metrics, metric) | 486 and metric_key in all_metrics["test"] |
237 if all(x is not None for x in [t, v, te]): | 487 ): |
238 row = ( | 488 display_name = METRIC_DISPLAY_NAMES.get( |
239 f"<tr>" | 489 metric_key, |
240 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>{metric}</td>" | 490 metric_key.replace("_", " ").title(), |
241 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{t:.4f}</td>" | 491 ) |
242 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{v:.4f}</td>" | 492 t = all_metrics["training"].get(metric_key) |
243 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{te:.4f}</td>" | 493 v = all_metrics["validation"].get(metric_key) |
244 f"</tr>" | 494 te = all_metrics["test"].get(metric_key) |
245 ) | 495 if all(x is not None for x in [t, v, te]): |
246 rows.append(row) | 496 rows.append([display_name, f"{t:.4f}", f"{v:.4f}", f"{te:.4f}"]) |
247 | 497 |
248 if not rows: | 498 if not rows: |
249 return "<p><em>No metric values found.</em></p>" | 499 return "<table><tr><td>No metric values found.</td></tr></table>" |
250 | 500 |
251 return ( | 501 html = ( |
252 "<h2 style='text-align: center;'>Model Performance Summary</h2>" | 502 "<h2 style='text-align: center;'>Model Performance Summary</h2>" |
253 "<div style='display: flex; justify-content: center;'>" | 503 "<div style='display: flex; justify-content: center;'>" |
254 "<table style='border-collapse: collapse; width: 80%; table-layout: fixed;'>" | 504 "<table style='border-collapse: collapse; table-layout: auto;'>" |
255 "<colgroup>" | |
256 "<col style='width: 40%;'>" | |
257 "<col style='width: 20%;'>" | |
258 "<col style='width: 20%;'>" | |
259 "<col style='width: 20%;'>" | |
260 "</colgroup>" | |
261 "<thead><tr>" | 505 "<thead><tr>" |
262 "<th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>Metric</th>" | 506 "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; " |
263 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Train</th>" | 507 "white-space: nowrap;'>Metric</th>" |
264 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Validation</th>" | 508 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; " |
265 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Test</th>" | 509 "white-space: nowrap;'>Train</th>" |
266 "</tr></thead><tbody>" + "".join(rows) + "</tbody></table></div><br>" | 510 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; " |
267 ) | 511 "white-space: nowrap;'>Validation</th>" |
268 | 512 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; " |
269 | 513 "white-space: nowrap;'>Test</th>" |
270 def build_tabbed_html( | 514 "</tr></thead><tbody>" |
271 metrics_html: str, | 515 ) |
272 train_viz_html: str, | 516 for row in rows: |
273 test_viz_html: str) -> str: | 517 html += generate_table_row( |
518 row, | |
519 "padding: 10px; border: 1px solid #ccc; text-align: center; " | |
520 "white-space: nowrap;", | |
521 ) | |
522 html += "</tbody></table></div><br>" | |
523 return html | |
524 | |
525 | |
526 def format_train_val_stats_table_html(train_stats: dict, test_stats: dict) -> str: | |
527 """Formats an HTML table for training and validation metrics.""" | |
528 output_type = detect_output_type(test_stats) | |
529 all_metrics = extract_metrics_from_json(train_stats, test_stats, output_type) | |
530 rows = [] | |
531 for metric_key in sorted(all_metrics["training"].keys()): | |
532 if metric_key in all_metrics["validation"]: | |
533 display_name = METRIC_DISPLAY_NAMES.get( | |
534 metric_key, | |
535 metric_key.replace("_", " ").title(), | |
536 ) | |
537 t = all_metrics["training"].get(metric_key) | |
538 v = all_metrics["validation"].get(metric_key) | |
539 if t is not None and v is not None: | |
540 rows.append([display_name, f"{t:.4f}", f"{v:.4f}"]) | |
541 | |
542 if not rows: | |
543 return "<table><tr><td>No metric values found for Train/Validation.</td></tr></table>" | |
544 | |
545 html = ( | |
546 "<h2 style='text-align: center;'>Train/Validation Performance Summary</h2>" | |
547 "<div style='display: flex; justify-content: center;'>" | |
548 "<table style='border-collapse: collapse; table-layout: auto;'>" | |
549 "<thead><tr>" | |
550 "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; " | |
551 "white-space: nowrap;'>Metric</th>" | |
552 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; " | |
553 "white-space: nowrap;'>Train</th>" | |
554 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; " | |
555 "white-space: nowrap;'>Validation</th>" | |
556 "</tr></thead><tbody>" | |
557 ) | |
558 for row in rows: | |
559 html += generate_table_row( | |
560 row, | |
561 "padding: 10px; border: 1px solid #ccc; text-align: center; " | |
562 "white-space: nowrap;", | |
563 ) | |
564 html += "</tbody></table></div><br>" | |
565 return html | |
566 | |
567 | |
568 def format_test_merged_stats_table_html(test_metrics: Dict[str, Optional[float]]) -> str: | |
569 """Formats an HTML table for test metrics.""" | |
570 rows = [] | |
571 for key in sorted(test_metrics.keys()): | |
572 display_name = METRIC_DISPLAY_NAMES.get(key, key.replace("_", " ").title()) | |
573 value = test_metrics[key] | |
574 if value is not None: | |
575 rows.append([display_name, f"{value:.4f}"]) | |
576 | |
577 if not rows: | |
578 return "<table><tr><td>No test metric values found.</td></tr></table>" | |
579 | |
580 html = ( | |
581 "<h2 style='text-align: center;'>Test Performance Summary</h2>" | |
582 "<div style='display: flex; justify-content: center;'>" | |
583 "<table style='border-collapse: collapse; table-layout: auto;'>" | |
584 "<thead><tr>" | |
585 "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; " | |
586 "white-space: nowrap;'>Metric</th>" | |
587 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; " | |
588 "white-space: nowrap;'>Test</th>" | |
589 "</tr></thead><tbody>" | |
590 ) | |
591 for row in rows: | |
592 html += generate_table_row( | |
593 row, | |
594 "padding: 10px; border: 1px solid #ccc; text-align: center; " | |
595 "white-space: nowrap;", | |
596 ) | |
597 html += "</tbody></table></div><br>" | |
598 return html | |
599 | |
600 | |
601 def build_tabbed_html(metrics_html: str, train_val_html: str, test_html: str) -> str: | |
274 return f""" | 602 return f""" |
275 <style> | 603 <style> |
276 .tabs {{ | 604 .tabs {{ |
277 display: flex; | 605 display: flex; |
278 border-bottom: 2px solid #ccc; | 606 border-bottom: 2px solid #ccc; |
300 }} | 628 }} |
301 .tab-content.active {{ | 629 .tab-content.active {{ |
302 display: block; | 630 display: block; |
303 }} | 631 }} |
304 </style> | 632 </style> |
305 | |
306 <div class="tabs"> | 633 <div class="tabs"> |
307 <div class="tab active" onclick="showTab('metrics')"> Config & Metrics</div> | 634 <div class="tab active" onclick="showTab('metrics')"> Config & Results Summary</div> |
308 <div class="tab" onclick="showTab('trainval')"> Train/Validation Plots</div> | 635 <div class="tab" onclick="showTab('trainval')"> Train/Validation Results</div> |
309 <div class="tab" onclick="showTab('test')"> Test Plots</div> | 636 <div class="tab" onclick="showTab('test')"> Test Results</div> |
310 </div> | 637 </div> |
311 | |
312 <div id="metrics" class="tab-content active"> | 638 <div id="metrics" class="tab-content active"> |
313 {metrics_html} | 639 {metrics_html} |
314 </div> | 640 </div> |
315 <div id="trainval" class="tab-content"> | 641 <div id="trainval" class="tab-content"> |
316 {train_viz_html} | 642 {train_val_html} |
317 </div> | 643 </div> |
318 <div id="test" class="tab-content"> | 644 <div id="test" class="tab-content"> |
319 {test_viz_html} | 645 {test_html} |
320 </div> | 646 </div> |
321 | |
322 <script> | 647 <script> |
323 function showTab(id) {{ | 648 function showTab(id) {{ |
324 document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active')); | 649 document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active')); |
325 document.querySelectorAll('.tab').forEach(el => el.classList.remove('active')); | 650 document.querySelectorAll('.tab').forEach(el => el.classList.remove('active')); |
326 document.getElementById(id).classList.add('active'); | 651 document.getElementById(id).classList.add('active'); |
335 split_column: str, | 660 split_column: str, |
336 validation_size: float = 0.15, | 661 validation_size: float = 0.15, |
337 random_state: int = 42, | 662 random_state: int = 42, |
338 label_column: Optional[str] = None, | 663 label_column: Optional[str] = None, |
339 ) -> pd.DataFrame: | 664 ) -> pd.DataFrame: |
340 """ | 665 """Given a DataFrame whose split_column only contains {0,2}, re-assign a portion of the 0s to become 1s (validation).""" |
341 Given a DataFrame whose split_column only contains {0,2}, re-assign | |
342 a portion of the 0s to become 1s (validation). Returns a fresh DataFrame. | |
343 """ | |
344 # Work on a copy | |
345 out = df.copy() | 666 out = df.copy() |
346 # Ensure split col is integer dtype | |
347 out[split_column] = pd.to_numeric(out[split_column], errors="coerce").astype(int) | 667 out[split_column] = pd.to_numeric(out[split_column], errors="coerce").astype(int) |
348 | 668 |
349 idx_train = out.index[out[split_column] == 0].tolist() | 669 idx_train = out.index[out[split_column] == 0].tolist() |
350 | 670 |
351 if not idx_train: | 671 if not idx_train: |
352 logger.info("No rows with split=0; nothing to do.") | 672 logger.info("No rows with split=0; nothing to do.") |
353 return out | 673 return out |
354 | |
355 # Determine stratify array if possible | |
356 stratify_arr = None | 674 stratify_arr = None |
357 if label_column and label_column in out.columns: | 675 if label_column and label_column in out.columns: |
358 # Only stratify if at least two classes and enough samples | |
359 label_counts = out.loc[idx_train, label_column].value_counts() | 676 label_counts = out.loc[idx_train, label_column].value_counts() |
360 if label_counts.size > 1 and (label_counts.min() * validation_size) >= 1: | 677 if label_counts.size > 1 and (label_counts.min() * validation_size) >= 1: |
361 stratify_arr = out.loc[idx_train, label_column] | 678 stratify_arr = out.loc[idx_train, label_column] |
362 else: | 679 else: |
363 logger.warning("Cannot stratify (too few labels); splitting without stratify.") | 680 logger.warning( |
364 | 681 "Cannot stratify (too few labels); splitting without stratify." |
365 # Edge cases | 682 ) |
366 if validation_size <= 0: | 683 if validation_size <= 0: |
367 logger.info("validation_size <= 0; keeping all as train.") | 684 logger.info("validation_size <= 0; keeping all as train.") |
368 return out | 685 return out |
369 if validation_size >= 1: | 686 if validation_size >= 1: |
370 logger.info("validation_size >= 1; moving all train → validation.") | 687 logger.info("validation_size >= 1; moving all train → validation.") |
371 out.loc[idx_train, split_column] = 1 | 688 out.loc[idx_train, split_column] = 1 |
372 return out | 689 return out |
373 | |
374 # Do the split | |
375 try: | 690 try: |
376 train_idx, val_idx = train_test_split( | 691 train_idx, val_idx = train_test_split( |
377 idx_train, | 692 idx_train, |
378 test_size=validation_size, | 693 test_size=validation_size, |
379 random_state=random_state, | 694 random_state=random_state, |
380 stratify=stratify_arr | 695 stratify=stratify_arr, |
381 ) | 696 ) |
382 except ValueError as e: | 697 except ValueError as e: |
383 logger.warning(f"Stratified split failed ({e}); retrying without stratify.") | 698 logger.warning(f"Stratified split failed ({e}); retrying without stratify.") |
384 train_idx, val_idx = train_test_split( | 699 train_idx, val_idx = train_test_split( |
385 idx_train, | 700 idx_train, |
386 test_size=validation_size, | 701 test_size=validation_size, |
387 random_state=random_state, | 702 random_state=random_state, |
388 stratify=None | 703 stratify=None, |
389 ) | 704 ) |
390 | |
391 # Assign new splits | |
392 out.loc[train_idx, split_column] = 0 | 705 out.loc[train_idx, split_column] = 0 |
393 out.loc[val_idx, split_column] = 1 | 706 out.loc[val_idx, split_column] = 1 |
394 # idx_test stays at 2 | |
395 | |
396 # Cast back to a clean integer type | |
397 out[split_column] = out[split_column].astype(int) | 707 out[split_column] = out[split_column].astype(int) |
398 # print(out) | |
399 return out | 708 return out |
400 | 709 |
401 | 710 |
402 class Backend(Protocol): | 711 class Backend(Protocol): |
403 """Interface for a machine learning backend.""" | 712 """Interface for a machine learning backend.""" |
713 | |
404 def prepare_config( | 714 def prepare_config( |
405 self, | 715 self, |
406 config_params: Dict[str, Any], | 716 config_params: Dict[str, Any], |
407 split_config: Dict[str, Any] | 717 split_config: Dict[str, Any], |
408 ) -> str: | 718 ) -> str: |
409 ... | 719 ... |
410 | 720 |
411 def run_experiment( | 721 def run_experiment( |
412 self, | 722 self, |
430 ) -> Path: | 740 ) -> Path: |
431 ... | 741 ... |
432 | 742 |
433 | 743 |
434 class LudwigDirectBackend: | 744 class LudwigDirectBackend: |
435 """ | 745 """Backend for running Ludwig experiments directly via the internal experiment_cli function.""" |
436 Backend for running Ludwig experiments directly via the internal experiment_cli function. | |
437 """ | |
438 | 746 |
439 def prepare_config( | 747 def prepare_config( |
440 self, | 748 self, |
441 config_params: Dict[str, Any], | 749 config_params: Dict[str, Any], |
442 split_config: Dict[str, Any], | 750 split_config: Dict[str, Any], |
443 ) -> str: | 751 ) -> str: |
444 """ | 752 """Build and serialize the Ludwig YAML configuration.""" |
445 Build and serialize the Ludwig YAML configuration. | |
446 """ | |
447 logger.info("LudwigDirectBackend: Preparing YAML configuration.") | 753 logger.info("LudwigDirectBackend: Preparing YAML configuration.") |
448 | 754 |
449 model_name = config_params.get("model_name", "resnet18") | 755 model_name = config_params.get("model_name", "resnet18") |
450 use_pretrained = config_params.get("use_pretrained", False) | 756 use_pretrained = config_params.get("use_pretrained", False) |
451 fine_tune = config_params.get("fine_tune", False) | 757 fine_tune = config_params.get("fine_tune", False) |
458 trainable = fine_tune or (not use_pretrained) | 764 trainable = fine_tune or (not use_pretrained) |
459 if not use_pretrained and not trainable: | 765 if not use_pretrained and not trainable: |
460 logger.warning("trainable=False; use_pretrained=False is ignored.") | 766 logger.warning("trainable=False; use_pretrained=False is ignored.") |
461 logger.warning("Setting trainable=True to train the model from scratch.") | 767 logger.warning("Setting trainable=True to train the model from scratch.") |
462 trainable = True | 768 trainable = True |
463 | |
464 # Encoder setup | |
465 raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name) | 769 raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name) |
466 if isinstance(raw_encoder, dict): | 770 if isinstance(raw_encoder, dict): |
467 encoder_config = { | 771 encoder_config = { |
468 **raw_encoder, | 772 **raw_encoder, |
469 "use_pretrained": use_pretrained, | 773 "use_pretrained": use_pretrained, |
470 "trainable": trainable, | 774 "trainable": trainable, |
471 } | 775 } |
472 else: | 776 else: |
473 encoder_config = {"type": raw_encoder} | 777 encoder_config = {"type": raw_encoder} |
474 | 778 |
475 # Trainer & optimizer | |
476 # optimizer = {"type": "adam", "learning_rate": 5e-5} if fine_tune else {"type": "adam"} | |
477 batch_size_cfg = batch_size or "auto" | 779 batch_size_cfg = batch_size or "auto" |
780 | |
781 label_column_path = config_params.get("label_column_data_path") | |
782 if label_column_path is not None and Path(label_column_path).exists(): | |
783 try: | |
784 label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME] | |
785 num_unique_labels = label_series.nunique() | |
786 except Exception as e: | |
787 logger.warning( | |
788 f"Could not determine label cardinality, defaulting to 'binary': {e}" | |
789 ) | |
790 num_unique_labels = 2 | |
791 else: | |
792 logger.warning( | |
793 "label_column_data_path not provided, defaulting to 'binary'" | |
794 ) | |
795 num_unique_labels = 2 | |
796 | |
797 output_type = "binary" if num_unique_labels == 2 else "category" | |
478 | 798 |
479 conf: Dict[str, Any] = { | 799 conf: Dict[str, Any] = { |
480 "model_type": "ecd", | 800 "model_type": "ecd", |
481 "input_features": [ | 801 "input_features": [ |
482 { | 802 { |
483 "name": IMAGE_PATH_COLUMN_NAME, | 803 "name": IMAGE_PATH_COLUMN_NAME, |
484 "type": "image", | 804 "type": "image", |
485 "encoder": encoder_config, | 805 "encoder": encoder_config, |
486 } | 806 } |
487 ], | 807 ], |
488 "output_features": [ | 808 "output_features": [{"name": LABEL_COLUMN_NAME, "type": output_type}], |
489 {"name": LABEL_COLUMN_NAME, "type": "category"} | |
490 ], | |
491 "combiner": {"type": "concat"}, | 809 "combiner": {"type": "concat"}, |
492 "trainer": { | 810 "trainer": { |
493 "epochs": epochs, | 811 "epochs": epochs, |
494 "early_stop": early_stop, | 812 "early_stop": early_stop, |
495 "batch_size": batch_size_cfg, | 813 "batch_size": batch_size_cfg, |
506 try: | 824 try: |
507 yaml_str = yaml.dump(conf, sort_keys=False, indent=2) | 825 yaml_str = yaml.dump(conf, sort_keys=False, indent=2) |
508 logger.info("LudwigDirectBackend: YAML config generated.") | 826 logger.info("LudwigDirectBackend: YAML config generated.") |
509 return yaml_str | 827 return yaml_str |
510 except Exception: | 828 except Exception: |
511 logger.error("LudwigDirectBackend: Failed to serialize YAML.", exc_info=True) | 829 logger.error( |
830 "LudwigDirectBackend: Failed to serialize YAML.", | |
831 exc_info=True, | |
832 ) | |
512 raise | 833 raise |
513 | 834 |
514 def run_experiment( | 835 def run_experiment( |
515 self, | 836 self, |
516 dataset_path: Path, | 837 dataset_path: Path, |
517 config_path: Path, | 838 config_path: Path, |
518 output_dir: Path, | 839 output_dir: Path, |
519 random_seed: int = 42, | 840 random_seed: int = 42, |
520 ) -> None: | 841 ) -> None: |
521 """ | 842 """Invoke Ludwig's internal experiment_cli function to run the experiment.""" |
522 Invoke Ludwig's internal experiment_cli function to run the experiment. | |
523 """ | |
524 logger.info("LudwigDirectBackend: Starting experiment execution.") | 843 logger.info("LudwigDirectBackend: Starting experiment execution.") |
525 | 844 |
526 try: | 845 try: |
527 from ludwig.experiment import experiment_cli | 846 from ludwig.experiment import experiment_cli |
528 except ImportError as e: | 847 except ImportError as e: |
529 logger.error( | 848 logger.error( |
530 "LudwigDirectBackend: Could not import experiment_cli.", | 849 "LudwigDirectBackend: Could not import experiment_cli.", |
531 exc_info=True | 850 exc_info=True, |
532 ) | 851 ) |
533 raise RuntimeError("Ludwig import failed.") from e | 852 raise RuntimeError("Ludwig import failed.") from e |
534 | 853 |
535 output_dir.mkdir(parents=True, exist_ok=True) | 854 output_dir.mkdir(parents=True, exist_ok=True) |
536 | 855 |
539 dataset=str(dataset_path), | 858 dataset=str(dataset_path), |
540 config=str(config_path), | 859 config=str(config_path), |
541 output_directory=str(output_dir), | 860 output_directory=str(output_dir), |
542 random_seed=random_seed, | 861 random_seed=random_seed, |
543 ) | 862 ) |
544 logger.info(f"LudwigDirectBackend: Experiment completed. Results in {output_dir}") | 863 logger.info( |
864 f"LudwigDirectBackend: Experiment completed. Results in {output_dir}" | |
865 ) | |
545 except TypeError as e: | 866 except TypeError as e: |
546 logger.error( | 867 logger.error( |
547 "LudwigDirectBackend: Argument mismatch in experiment_cli call.", | 868 "LudwigDirectBackend: Argument mismatch in experiment_cli call.", |
548 exc_info=True | 869 exc_info=True, |
549 ) | 870 ) |
550 raise RuntimeError("Ludwig argument error.") from e | 871 raise RuntimeError("Ludwig argument error.") from e |
551 except Exception: | 872 except Exception: |
552 logger.error( | 873 logger.error( |
553 "LudwigDirectBackend: Experiment execution error.", | 874 "LudwigDirectBackend: Experiment execution error.", |
554 exc_info=True | 875 exc_info=True, |
555 ) | 876 ) |
556 raise | 877 raise |
557 | 878 |
558 def get_training_process(self, output_dir) -> float: | 879 def get_training_process(self, output_dir) -> float: |
559 """ | 880 """Retrieve the learning rate used in the most recent Ludwig run.""" |
560 Retrieve the learning rate used in the most recent Ludwig run. | |
561 Returns: | |
562 float: learning rate (or None if not found) | |
563 """ | |
564 output_dir = Path(output_dir) | 881 output_dir = Path(output_dir) |
565 exp_dirs = sorted( | 882 exp_dirs = sorted( |
566 output_dir.glob("experiment_run*"), | 883 output_dir.glob("experiment_run*"), |
567 key=lambda p: p.stat().st_mtime | 884 key=lambda p: p.stat().st_mtime, |
568 ) | 885 ) |
569 | 886 |
570 if not exp_dirs: | 887 if not exp_dirs: |
571 logger.warning(f"No experiment run directories found in {output_dir}") | 888 logger.warning(f"No experiment run directories found in {output_dir}") |
572 return None | 889 return None |
583 "learning_rate": data.get("learning_rate"), | 900 "learning_rate": data.get("learning_rate"), |
584 "batch_size": data.get("batch_size"), | 901 "batch_size": data.get("batch_size"), |
585 "epoch": data.get("epoch"), | 902 "epoch": data.get("epoch"), |
586 } | 903 } |
587 except Exception as e: | 904 except Exception as e: |
588 self.logger.warning(f"Failed to read training progress info: {e}") | 905 logger.warning(f"Failed to read training progress info: {e}") |
589 return {} | 906 return {} |
590 | 907 |
591 def convert_parquet_to_csv(self, output_dir: Path): | 908 def convert_parquet_to_csv(self, output_dir: Path): |
592 """Convert the predictions Parquet file to CSV.""" | 909 """Convert the predictions Parquet file to CSV.""" |
593 output_dir = Path(output_dir) | 910 output_dir = Path(output_dir) |
594 exp_dirs = sorted( | 911 exp_dirs = sorted( |
595 output_dir.glob("experiment_run*"), | 912 output_dir.glob("experiment_run*"), |
596 key=lambda p: p.stat().st_mtime | 913 key=lambda p: p.stat().st_mtime, |
597 ) | 914 ) |
598 if not exp_dirs: | 915 if not exp_dirs: |
599 logger.warning(f"No experiment run dirs found in {output_dir}") | 916 logger.warning(f"No experiment run dirs found in {output_dir}") |
600 return | 917 return |
601 exp_dir = exp_dirs[-1] | 918 exp_dir = exp_dirs[-1] |
607 logger.info(f"Converted Parquet to CSV: {csv_path}") | 924 logger.info(f"Converted Parquet to CSV: {csv_path}") |
608 except Exception as e: | 925 except Exception as e: |
609 logger.error(f"Error converting Parquet to CSV: {e}") | 926 logger.error(f"Error converting Parquet to CSV: {e}") |
610 | 927 |
611 def generate_plots(self, output_dir: Path) -> None: | 928 def generate_plots(self, output_dir: Path) -> None: |
612 """ | 929 """Generate all registered Ludwig visualizations for the latest experiment run.""" |
613 Generate _all_ registered Ludwig visualizations for the latest experiment run. | |
614 """ | |
615 logger.info("Generating all Ludwig visualizations…") | 930 logger.info("Generating all Ludwig visualizations…") |
616 | 931 |
617 test_plots = { | 932 test_plots = { |
618 'compare_performance', | 933 "compare_performance", |
619 'compare_classifiers_performance_from_prob', | 934 "compare_classifiers_performance_from_prob", |
620 'compare_classifiers_performance_from_pred', | 935 "compare_classifiers_performance_from_pred", |
621 'compare_classifiers_performance_changing_k', | 936 "compare_classifiers_performance_changing_k", |
622 'compare_classifiers_multiclass_multimetric', | 937 "compare_classifiers_multiclass_multimetric", |
623 'compare_classifiers_predictions', | 938 "compare_classifiers_predictions", |
624 'confidence_thresholding_2thresholds_2d', | 939 "confidence_thresholding_2thresholds_2d", |
625 'confidence_thresholding_2thresholds_3d', | 940 "confidence_thresholding_2thresholds_3d", |
626 'confidence_thresholding', | 941 "confidence_thresholding", |
627 'confidence_thresholding_data_vs_acc', | 942 "confidence_thresholding_data_vs_acc", |
628 'binary_threshold_vs_metric', | 943 "binary_threshold_vs_metric", |
629 'roc_curves', | 944 "roc_curves", |
630 'roc_curves_from_test_statistics', | 945 "roc_curves_from_test_statistics", |
631 'calibration_1_vs_all', | 946 "calibration_1_vs_all", |
632 'calibration_multiclass', | 947 "calibration_multiclass", |
633 'confusion_matrix', | 948 "confusion_matrix", |
634 'frequency_vs_f1', | 949 "frequency_vs_f1", |
635 } | 950 } |
636 train_plots = { | 951 train_plots = { |
637 'learning_curves', | 952 "learning_curves", |
638 'compare_classifiers_performance_subset', | 953 "compare_classifiers_performance_subset", |
639 } | 954 } |
640 | 955 |
641 # 1) find the most recent experiment directory | |
642 output_dir = Path(output_dir) | 956 output_dir = Path(output_dir) |
643 exp_dirs = sorted( | 957 exp_dirs = sorted( |
644 output_dir.glob("experiment_run*"), | 958 output_dir.glob("experiment_run*"), |
645 key=lambda p: p.stat().st_mtime | 959 key=lambda p: p.stat().st_mtime, |
646 ) | 960 ) |
647 if not exp_dirs: | 961 if not exp_dirs: |
648 logger.warning(f"No experiment run dirs found in {output_dir}") | 962 logger.warning(f"No experiment run dirs found in {output_dir}") |
649 return | 963 return |
650 exp_dir = exp_dirs[-1] | 964 exp_dir = exp_dirs[-1] |
651 | 965 |
652 # 2) ensure viz output subfolder exists | |
653 viz_dir = exp_dir / "visualizations" | 966 viz_dir = exp_dir / "visualizations" |
654 viz_dir.mkdir(exist_ok=True) | 967 viz_dir.mkdir(exist_ok=True) |
655 train_viz = viz_dir / "train" | 968 train_viz = viz_dir / "train" |
656 test_viz = viz_dir / "test" | 969 test_viz = viz_dir / "test" |
657 train_viz.mkdir(parents=True, exist_ok=True) | 970 train_viz.mkdir(parents=True, exist_ok=True) |
658 test_viz.mkdir(parents=True, exist_ok=True) | 971 test_viz.mkdir(parents=True, exist_ok=True) |
659 | 972 |
660 # 3) helper to check file existence | |
661 def _check(p: Path) -> Optional[str]: | 973 def _check(p: Path) -> Optional[str]: |
662 return str(p) if p.exists() else None | 974 return str(p) if p.exists() else None |
663 | 975 |
664 # 4) gather standard Ludwig output files | |
665 training_stats = _check(exp_dir / "training_statistics.json") | 976 training_stats = _check(exp_dir / "training_statistics.json") |
666 test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME) | 977 test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME) |
667 probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME) | 978 probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME) |
668 gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME) | 979 gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME) |
669 | 980 |
670 # 5) try to read original dataset & split file from description.json | |
671 dataset_path = None | 981 dataset_path = None |
672 split_file = None | 982 split_file = None |
673 desc = exp_dir / DESCRIPTION_FILE_NAME | 983 desc = exp_dir / DESCRIPTION_FILE_NAME |
674 if desc.exists(): | 984 if desc.exists(): |
675 with open(desc, "r") as f: | 985 with open(desc, "r") as f: |
676 cfg = json.load(f) | 986 cfg = json.load(f) |
677 dataset_path = _check(Path(cfg.get("dataset", ""))) | 987 dataset_path = _check(Path(cfg.get("dataset", ""))) |
678 split_file = _check(Path(get_split_path(cfg.get("dataset", "")))) | 988 split_file = _check(Path(get_split_path(cfg.get("dataset", "")))) |
679 | 989 |
680 # 6) infer output feature name | |
681 output_feature = "" | 990 output_feature = "" |
682 if desc.exists(): | 991 if desc.exists(): |
683 try: | 992 try: |
684 output_feature = cfg["config"]["output_features"][0]["name"] | 993 output_feature = cfg["config"]["output_features"][0]["name"] |
685 except Exception: | 994 except Exception: |
687 if not output_feature and test_stats: | 996 if not output_feature and test_stats: |
688 with open(test_stats, "r") as f: | 997 with open(test_stats, "r") as f: |
689 stats = json.load(f) | 998 stats = json.load(f) |
690 output_feature = next(iter(stats.keys()), "") | 999 output_feature = next(iter(stats.keys()), "") |
691 | 1000 |
692 # 7) loop through every registered viz | |
693 viz_registry = get_visualizations_registry() | 1001 viz_registry = get_visualizations_registry() |
694 for viz_name, viz_func in viz_registry.items(): | 1002 for viz_name, viz_func in viz_registry.items(): |
695 viz_dir_plot = None | 1003 viz_dir_plot = None |
696 if viz_name in train_plots: | 1004 if viz_name in train_plots: |
697 viz_dir_plot = train_viz | 1005 viz_dir_plot = train_viz |
719 logger.warning(f"✘ Skipped {viz_name}: {e}") | 1027 logger.warning(f"✘ Skipped {viz_name}: {e}") |
720 | 1028 |
721 logger.info(f"All visualizations written to {viz_dir}") | 1029 logger.info(f"All visualizations written to {viz_dir}") |
722 | 1030 |
723 def generate_html_report( | 1031 def generate_html_report( |
724 self, | 1032 self, |
725 title: str, | 1033 title: str, |
726 output_dir: str, | 1034 output_dir: str, |
727 config: dict, | 1035 config: dict, |
728 split_info: str) -> Path: | 1036 split_info: str, |
729 """ | 1037 ) -> Path: |
730 Assemble an HTML report from visualizations under train_val/ and test/ folders. | 1038 """Assemble an HTML report from visualizations under train_val/ and test/ folders.""" |
731 """ | |
732 cwd = Path.cwd() | 1039 cwd = Path.cwd() |
733 report_name = title.lower().replace(" ", "_") + "_report.html" | 1040 report_name = title.lower().replace(" ", "_") + "_report.html" |
734 report_path = cwd / report_name | 1041 report_path = cwd / report_name |
735 output_dir = Path(output_dir) | 1042 output_dir = Path(output_dir) |
736 | 1043 |
737 # Find latest experiment dir | 1044 exp_dirs = sorted( |
738 exp_dirs = sorted(output_dir.glob("experiment_run*"), key=lambda p: p.stat().st_mtime) | 1045 output_dir.glob("experiment_run*"), |
1046 key=lambda p: p.stat().st_mtime, | |
1047 ) | |
739 if not exp_dirs: | 1048 if not exp_dirs: |
740 raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}") | 1049 raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}") |
741 exp_dir = exp_dirs[-1] | 1050 exp_dir = exp_dirs[-1] |
742 | 1051 |
743 base_viz_dir = exp_dir / "visualizations" | 1052 base_viz_dir = exp_dir / "visualizations" |
746 | 1055 |
747 html = get_html_template() | 1056 html = get_html_template() |
748 html += f"<h1>{title}</h1>" | 1057 html += f"<h1>{title}</h1>" |
749 | 1058 |
750 metrics_html = "" | 1059 metrics_html = "" |
751 | 1060 train_val_metrics_html = "" |
752 # Load and embed metrics table (training/val/test stats) | 1061 test_metrics_html = "" |
1062 | |
753 try: | 1063 try: |
754 train_stats_path = exp_dir / "training_statistics.json" | 1064 train_stats_path = exp_dir / "training_statistics.json" |
755 test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME | 1065 test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME |
756 if train_stats_path.exists() and test_stats_path.exists(): | 1066 if train_stats_path.exists() and test_stats_path.exists(): |
757 with open(train_stats_path) as f: | 1067 with open(train_stats_path) as f: |
758 train_stats = json.load(f) | 1068 train_stats = json.load(f) |
759 with open(test_stats_path) as f: | 1069 with open(test_stats_path) as f: |
760 test_stats = json.load(f) | 1070 test_stats = json.load(f) |
761 output_feature = next(iter(train_stats.keys()), "") | 1071 output_type = detect_output_type(test_stats) |
762 if output_feature: | 1072 all_metrics = extract_metrics_from_json( |
763 metrics_html += format_stats_table_html(train_stats, test_stats) | 1073 train_stats, |
1074 test_stats, | |
1075 output_type, | |
1076 ) | |
1077 metrics_html = format_stats_table_html(train_stats, test_stats) | |
1078 train_val_metrics_html = format_train_val_stats_table_html( | |
1079 train_stats, | |
1080 test_stats, | |
1081 ) | |
1082 test_metrics_html = format_test_merged_stats_table_html( | |
1083 all_metrics["test"], | |
1084 ) | |
764 except Exception as e: | 1085 except Exception as e: |
765 logger.warning(f"Could not load stats for HTML report: {e}") | 1086 logger.warning( |
1087 f"Could not load stats for HTML report: {type(e).__name__}: {e}" | |
1088 ) | |
766 | 1089 |
767 config_html = "" | 1090 config_html = "" |
768 training_progress = self.get_training_process(output_dir) | 1091 training_progress = self.get_training_process(output_dir) |
769 try: | 1092 try: |
770 config_html = format_config_table_html(config, split_info, training_progress) | 1093 config_html = format_config_table_html(config, split_info, training_progress) |
771 except Exception as e: | 1094 except Exception as e: |
772 logger.warning(f"Could not load config for HTML report: {e}") | 1095 logger.warning(f"Could not load config for HTML report: {e}") |
773 | 1096 |
774 def render_img_section(title: str, dir_path: Path) -> str: | 1097 def render_img_section(title: str, dir_path: Path, output_type: str = None) -> str: |
775 if not dir_path.exists(): | 1098 if not dir_path.exists(): |
776 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>" | 1099 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>" |
777 imgs = sorted(dir_path.glob("*.png")) | 1100 |
1101 imgs = list(dir_path.glob("*.png")) | |
778 if not imgs: | 1102 if not imgs: |
779 return f"<h2>{title}</h2><p><em>No plots found.</em></p>" | 1103 return f"<h2>{title}</h2><p><em>No plots found.</em></p>" |
1104 | |
1105 if title == "Test Visualizations" and output_type == "binary": | |
1106 order = [ | |
1107 "confusion_matrix__label_top2.png", | |
1108 "roc_curves_from_prediction_statistics.png", | |
1109 "compare_performance_label.png", | |
1110 "confusion_matrix_entropy__label_top2.png", | |
1111 ] | |
1112 img_names = {img.name: img for img in imgs} | |
1113 ordered_imgs = [ | |
1114 img_names[fname] for fname in order if fname in img_names | |
1115 ] | |
1116 remaining = sorted( | |
1117 [ | |
1118 img | |
1119 for img in imgs | |
1120 if img.name not in order and img.name != "roc_curves.png" | |
1121 ] | |
1122 ) | |
1123 imgs = ordered_imgs + remaining | |
1124 | |
1125 elif title == "Test Visualizations" and output_type == "category": | |
1126 unwanted = { | |
1127 "compare_classifiers_multiclass_multimetric__label_best10.png", | |
1128 "compare_classifiers_multiclass_multimetric__label_top10.png", | |
1129 "compare_classifiers_multiclass_multimetric__label_worst10.png", | |
1130 } | |
1131 display_order = [ | |
1132 "confusion_matrix__label_top10.png", | |
1133 "roc_curves.png", | |
1134 "compare_performance_label.png", | |
1135 "compare_classifiers_performance_from_prob.png", | |
1136 "compare_classifiers_multiclass_multimetric__label_sorted.png", | |
1137 "confusion_matrix_entropy__label_top10.png", | |
1138 ] | |
1139 img_names = {img.name: img for img in imgs if img.name not in unwanted} | |
1140 ordered_imgs = [ | |
1141 img_names[fname] for fname in display_order if fname in img_names | |
1142 ] | |
1143 remaining = sorted( | |
1144 [ | |
1145 img | |
1146 for img in img_names.values() | |
1147 if img.name not in display_order | |
1148 ] | |
1149 ) | |
1150 imgs = ordered_imgs + remaining | |
1151 | |
1152 else: | |
1153 if output_type == "category": | |
1154 unwanted = { | |
1155 "compare_classifiers_multiclass_multimetric__label_best10.png", | |
1156 "compare_classifiers_multiclass_multimetric__label_top10.png", | |
1157 "compare_classifiers_multiclass_multimetric__label_worst10.png", | |
1158 } | |
1159 imgs = sorted([img for img in imgs if img.name not in unwanted]) | |
1160 else: | |
1161 imgs = sorted(imgs) | |
780 | 1162 |
781 section_html = f"<h2 style='text-align: center;'>{title}</h2><div>" | 1163 section_html = f"<h2 style='text-align: center;'>{title}</h2><div>" |
782 for img in imgs: | 1164 for img in imgs: |
783 b64 = encode_image_to_base64(str(img)) | 1165 b64 = encode_image_to_base64(str(img)) |
784 section_html += ( | 1166 section_html += ( |
785 f'<div class="plot" style="margin-bottom:20px;text-align:center;">' | 1167 f'<div class="plot" style="margin-bottom:20px;text-align:center;">' |
786 f"<h3>{img.stem.replace('_',' ').title()}</h3>" | 1168 f"<h3>{img.stem.replace('_', ' ').title()}</h3>" |
787 f'<img src="data:image/png;base64,{b64}" ' | 1169 f'<img src="data:image/png;base64,{b64}" ' |
788 'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' | 1170 f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' |
789 "</div>" | 1171 f"</div>" |
790 ) | 1172 ) |
791 section_html += "</div>" | 1173 section_html += "</div>" |
792 return section_html | 1174 return section_html |
793 | 1175 |
794 train_plots_html = render_img_section("Training & Validation Visualizations", train_viz_dir) | 1176 button_html = """ |
795 test_plots_html = render_img_section("Test Visualizations", test_viz_dir) | 1177 <button class="help-modal-btn openMetricsHelp">Model Evaluation Metrics — Help Guide</button> |
796 html += build_tabbed_html(config_html + metrics_html, train_plots_html, test_plots_html) | 1178 <br><br> |
1179 <style> | |
1180 .help-modal-btn { | |
1181 background-color: #17623b; | |
1182 color: #fff; | |
1183 border: none; | |
1184 border-radius: 24px; | |
1185 padding: 10px 28px; | |
1186 font-size: 1.1rem; | |
1187 font-weight: bold; | |
1188 letter-spacing: 0.03em; | |
1189 cursor: pointer; | |
1190 transition: background 0.2s, box-shadow 0.2s; | |
1191 box-shadow: 0 2px 8px rgba(23,98,59,0.07); | |
1192 } | |
1193 .help-modal-btn:hover, .help-modal-btn:focus { | |
1194 background-color: #21895e; | |
1195 outline: none; | |
1196 box-shadow: 0 4px 16px rgba(23,98,59,0.14); | |
1197 } | |
1198 </style> | |
1199 """ | |
1200 tab1_content = button_html + config_html + metrics_html | |
1201 tab2_content = ( | |
1202 button_html | |
1203 + train_val_metrics_html | |
1204 + render_img_section("Training & Validation Visualizations", train_viz_dir) | |
1205 ) | |
1206 tab3_content = ( | |
1207 button_html | |
1208 + test_metrics_html | |
1209 + render_img_section("Test Visualizations", test_viz_dir, output_type) | |
1210 ) | |
1211 | |
1212 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) | |
1213 modal_html = get_metrics_help_modal() | |
1214 html += tabbed_html + modal_html | |
797 html += get_html_closing() | 1215 html += get_html_closing() |
798 | 1216 |
799 try: | 1217 try: |
800 with open(report_path, "w") as f: | 1218 with open(report_path, "w") as f: |
801 f.write(html) | 1219 f.write(html) |
806 | 1224 |
807 return report_path | 1225 return report_path |
808 | 1226 |
809 | 1227 |
810 class WorkflowOrchestrator: | 1228 class WorkflowOrchestrator: |
811 """ | 1229 """Manages the image-classification workflow.""" |
812 Manages the image-classification workflow: | |
813 1. Creates temp dirs | |
814 2. Extracts images | |
815 3. Prepares data (CSV + splits) | |
816 4. Renders a backend config | |
817 5. Runs the experiment | |
818 6. Cleans up | |
819 """ | |
820 | 1230 |
821 def __init__(self, args: argparse.Namespace, backend: Backend): | 1231 def __init__(self, args: argparse.Namespace, backend: Backend): |
822 self.args = args | 1232 self.args = args |
823 self.backend = backend | 1233 self.backend = backend |
824 self.temp_dir: Optional[Path] = None | 1234 self.temp_dir: Optional[Path] = None |
826 logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}") | 1236 logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}") |
827 | 1237 |
828 def _create_temp_dirs(self) -> None: | 1238 def _create_temp_dirs(self) -> None: |
829 """Create temporary output and image extraction directories.""" | 1239 """Create temporary output and image extraction directories.""" |
830 try: | 1240 try: |
831 self.temp_dir = Path(tempfile.mkdtemp( | 1241 self.temp_dir = Path( |
832 dir=self.args.output_dir, | 1242 tempfile.mkdtemp(dir=self.args.output_dir, prefix=TEMP_DIR_PREFIX) |
833 prefix=TEMP_DIR_PREFIX | 1243 ) |
834 )) | |
835 self.image_extract_dir = self.temp_dir / "images" | 1244 self.image_extract_dir = self.temp_dir / "images" |
836 self.image_extract_dir.mkdir() | 1245 self.image_extract_dir.mkdir() |
837 logger.info(f"Created temp directory: {self.temp_dir}") | 1246 logger.info(f"Created temp directory: {self.temp_dir}") |
838 except Exception: | 1247 except Exception: |
839 logger.error("Failed to create temporary directories", exc_info=True) | 1248 logger.error("Failed to create temporary directories", exc_info=True) |
841 | 1250 |
842 def _extract_images(self) -> None: | 1251 def _extract_images(self) -> None: |
843 """Extract images from ZIP into the temp image directory.""" | 1252 """Extract images from ZIP into the temp image directory.""" |
844 if self.image_extract_dir is None: | 1253 if self.image_extract_dir is None: |
845 raise RuntimeError("Temp image directory not initialized.") | 1254 raise RuntimeError("Temp image directory not initialized.") |
846 logger.info(f"Extracting images from {self.args.image_zip} → {self.image_extract_dir}") | 1255 logger.info( |
1256 f"Extracting images from {self.args.image_zip} → {self.image_extract_dir}" | |
1257 ) | |
847 try: | 1258 try: |
848 with zipfile.ZipFile(self.args.image_zip, "r") as z: | 1259 with zipfile.ZipFile(self.args.image_zip, "r") as z: |
849 z.extractall(self.image_extract_dir) | 1260 z.extractall(self.image_extract_dir) |
850 logger.info("Image extraction complete.") | 1261 logger.info("Image extraction complete.") |
851 except Exception: | 1262 except Exception: |
852 logger.error("Error extracting zip file", exc_info=True) | 1263 logger.error("Error extracting zip file", exc_info=True) |
853 raise | 1264 raise |
854 | 1265 |
855 def _prepare_data(self) -> Tuple[Path, Dict[str, Any]]: | 1266 def _prepare_data(self) -> Tuple[Path, Dict[str, Any]]: |
856 """ | 1267 """Load CSV, update image paths, handle splits, and write prepared CSV.""" |
857 Load CSV, update image paths, handle splits, and write prepared CSV. | |
858 Returns: | |
859 final_csv_path: Path to the prepared CSV | |
860 split_config: Dict for backend split settings | |
861 """ | |
862 if not self.temp_dir or not self.image_extract_dir: | 1268 if not self.temp_dir or not self.image_extract_dir: |
863 raise RuntimeError("Temp dirs not initialized before data prep.") | 1269 raise RuntimeError("Temp dirs not initialized before data prep.") |
864 | 1270 |
865 # 1) Load | |
866 try: | 1271 try: |
867 df = pd.read_csv(self.args.csv_file) | 1272 df = pd.read_csv(self.args.csv_file) |
868 logger.info(f"Loaded CSV: {self.args.csv_file}") | 1273 logger.info(f"Loaded CSV: {self.args.csv_file}") |
869 except Exception: | 1274 except Exception: |
870 logger.error("Error loading CSV file", exc_info=True) | 1275 logger.error("Error loading CSV file", exc_info=True) |
871 raise | 1276 raise |
872 | 1277 |
873 # 2) Validate columns | |
874 required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME} | 1278 required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME} |
875 missing = required - set(df.columns) | 1279 missing = required - set(df.columns) |
876 if missing: | 1280 if missing: |
877 raise ValueError(f"Missing CSV columns: {', '.join(missing)}") | 1281 raise ValueError(f"Missing CSV columns: {', '.join(missing)}") |
878 | 1282 |
879 # 3) Update image paths | |
880 try: | 1283 try: |
881 df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply( | 1284 df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply( |
882 lambda p: str((self.image_extract_dir / p).resolve()) | 1285 lambda p: str((self.image_extract_dir / p).resolve()) |
883 ) | 1286 ) |
884 except Exception: | 1287 except Exception: |
885 logger.error("Error updating image paths", exc_info=True) | 1288 logger.error("Error updating image paths", exc_info=True) |
886 raise | 1289 raise |
887 | 1290 |
888 # 4) Handle splits | |
889 if SPLIT_COLUMN_NAME in df.columns: | 1291 if SPLIT_COLUMN_NAME in df.columns: |
890 df, split_config, split_info = self._process_fixed_split(df) | 1292 df, split_config, split_info = self._process_fixed_split(df) |
891 else: | 1293 else: |
892 logger.info("No split column; using random split") | 1294 logger.info("No split column; using random split") |
893 split_config = { | 1295 split_config = { |
894 "type": "random", | 1296 "type": "random", |
895 "probabilities": self.args.split_probabilities | 1297 "probabilities": self.args.split_probabilities, |
896 } | 1298 } |
897 split_info = ( | 1299 split_info = ( |
898 f"No split column in CSV. Used random split: " | 1300 f"No split column in CSV. Used random split: " |
899 f"{[int(p*100) for p in self.args.split_probabilities]}% for train/val/test." | 1301 f"{[int(p * 100) for p in self.args.split_probabilities]}% " |
900 ) | 1302 f"for train/val/test." |
901 | 1303 ) |
902 # 5) Write out prepared CSV | 1304 |
903 final_csv = TEMP_CSV_FILENAME | 1305 final_csv = TEMP_CSV_FILENAME |
904 try: | 1306 try: |
905 df.to_csv(final_csv, index=False) | 1307 df.to_csv(final_csv, index=False) |
906 logger.info(f"Saved prepared data to {final_csv}") | 1308 logger.info(f"Saved prepared data to {final_csv}") |
907 except Exception: | 1309 except Exception: |
913 def _process_fixed_split(self, df: pd.DataFrame) -> Dict[str, Any]: | 1315 def _process_fixed_split(self, df: pd.DataFrame) -> Dict[str, Any]: |
914 """Process a fixed split column (0=train,1=val,2=test).""" | 1316 """Process a fixed split column (0=train,1=val,2=test).""" |
915 logger.info(f"Fixed split column '{SPLIT_COLUMN_NAME}' detected.") | 1317 logger.info(f"Fixed split column '{SPLIT_COLUMN_NAME}' detected.") |
916 try: | 1318 try: |
917 col = df[SPLIT_COLUMN_NAME] | 1319 col = df[SPLIT_COLUMN_NAME] |
918 df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype(pd.Int64Dtype()) | 1320 df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype( |
1321 pd.Int64Dtype() | |
1322 ) | |
919 if df[SPLIT_COLUMN_NAME].isna().any(): | 1323 if df[SPLIT_COLUMN_NAME].isna().any(): |
920 logger.warning("Split column contains non-numeric/missing values.") | 1324 logger.warning("Split column contains non-numeric/missing values.") |
921 | 1325 |
922 unique = set(df[SPLIT_COLUMN_NAME].dropna().unique()) | 1326 unique = set(df[SPLIT_COLUMN_NAME].dropna().unique()) |
923 logger.info(f"Unique split values: {unique}") | 1327 logger.info(f"Unique split values: {unique}") |
924 | 1328 |
925 if unique == {0, 2}: | 1329 if unique == {0, 2}: |
926 df = split_data_0_2( | 1330 df = split_data_0_2( |
927 df, SPLIT_COLUMN_NAME, | 1331 df, |
1332 SPLIT_COLUMN_NAME, | |
928 validation_size=self.args.validation_size, | 1333 validation_size=self.args.validation_size, |
929 label_column=LABEL_COLUMN_NAME, | 1334 label_column=LABEL_COLUMN_NAME, |
930 random_state=self.args.random_seed | 1335 random_state=self.args.random_seed, |
931 ) | 1336 ) |
932 split_info = ( | 1337 split_info = ( |
933 "Detected a split column (with values 0 and 2) in the input CSV. " | 1338 "Detected a split column (with values 0 and 2) in the input CSV. " |
934 f"Used this column as a base and" | 1339 f"Used this column as a base and reassigned " |
935 f"reassigned {self.args.validation_size * 100:.1f}% " | 1340 f"{self.args.validation_size * 100:.1f}% " |
936 "of the training set (originally labeled 0) to validation (labeled 1)." | 1341 "of the training set (originally labeled 0) to validation (labeled 1)." |
937 ) | 1342 ) |
938 | |
939 logger.info("Applied custom 0/2 split.") | 1343 logger.info("Applied custom 0/2 split.") |
940 elif unique.issubset({0, 1, 2}): | 1344 elif unique.issubset({0, 1, 2}): |
941 split_info = "Used user-defined split column from CSV." | 1345 split_info = "Used user-defined split column from CSV." |
942 logger.info("Using fixed split as-is.") | 1346 logger.info("Using fixed split as-is.") |
943 else: | 1347 else: |
948 except Exception: | 1352 except Exception: |
949 logger.error("Error processing fixed split", exc_info=True) | 1353 logger.error("Error processing fixed split", exc_info=True) |
950 raise | 1354 raise |
951 | 1355 |
952 def _cleanup_temp_dirs(self) -> None: | 1356 def _cleanup_temp_dirs(self) -> None: |
953 """Remove any temporary directories.""" | |
954 if self.temp_dir and self.temp_dir.exists(): | 1357 if self.temp_dir and self.temp_dir.exists(): |
955 logger.info(f"Cleaning up temp directory: {self.temp_dir}") | 1358 logger.info(f"Cleaning up temp directory: {self.temp_dir}") |
956 shutil.rmtree(self.temp_dir, ignore_errors=True) | 1359 shutil.rmtree(self.temp_dir, ignore_errors=True) |
957 self.temp_dir = None | 1360 self.temp_dir = None |
958 self.image_extract_dir = None | 1361 self.image_extract_dir = None |
978 "preprocessing_num_processes": self.args.preprocessing_num_processes, | 1381 "preprocessing_num_processes": self.args.preprocessing_num_processes, |
979 "split_probabilities": self.args.split_probabilities, | 1382 "split_probabilities": self.args.split_probabilities, |
980 "learning_rate": self.args.learning_rate, | 1383 "learning_rate": self.args.learning_rate, |
981 "random_seed": self.args.random_seed, | 1384 "random_seed": self.args.random_seed, |
982 "early_stop": self.args.early_stop, | 1385 "early_stop": self.args.early_stop, |
1386 "label_column_data_path": csv_path, | |
983 } | 1387 } |
984 yaml_str = self.backend.prepare_config(backend_args, split_cfg) | 1388 yaml_str = self.backend.prepare_config(backend_args, split_cfg) |
985 | 1389 |
986 config_file = self.temp_dir / TEMP_CONFIG_FILENAME | 1390 config_file = self.temp_dir / TEMP_CONFIG_FILENAME |
987 config_file.write_text(yaml_str) | 1391 config_file.write_text(yaml_str) |
989 | 1393 |
990 self.backend.run_experiment( | 1394 self.backend.run_experiment( |
991 csv_path, | 1395 csv_path, |
992 config_file, | 1396 config_file, |
993 self.args.output_dir, | 1397 self.args.output_dir, |
994 self.args.random_seed | 1398 self.args.random_seed, |
995 ) | 1399 ) |
996 logger.info("Workflow completed successfully.") | 1400 logger.info("Workflow completed successfully.") |
997 self.backend.generate_plots(self.args.output_dir) | 1401 self.backend.generate_plots(self.args.output_dir) |
998 report_file = self.backend.generate_html_report( | 1402 report_file = self.backend.generate_html_report( |
999 "Image Classification Results", | 1403 "Image Classification Results", |
1000 self.args.output_dir, | 1404 self.args.output_dir, |
1001 backend_args, | 1405 backend_args, |
1002 split_info | 1406 split_info, |
1003 ) | 1407 ) |
1004 logger.info(f"HTML report generated at: {report_file}") | 1408 logger.info(f"HTML report generated at: {report_file}") |
1005 self.backend.convert_parquet_to_csv(self.args.output_dir) | 1409 self.backend.convert_parquet_to_csv(self.args.output_dir) |
1006 logger.info("Converted Parquet to CSV.") | 1410 logger.info("Converted Parquet to CSV.") |
1007 except Exception: | 1411 except Exception: |
1008 logger.error("Workflow execution failed", exc_info=True) | 1412 logger.error("Workflow execution failed", exc_info=True) |
1009 raise | 1413 raise |
1010 | |
1011 finally: | 1414 finally: |
1012 self._cleanup_temp_dirs() | 1415 self._cleanup_temp_dirs() |
1013 | 1416 |
1014 | 1417 |
1015 def parse_learning_rate(s): | 1418 def parse_learning_rate(s): |
1019 return None | 1422 return None |
1020 | 1423 |
1021 | 1424 |
1022 class SplitProbAction(argparse.Action): | 1425 class SplitProbAction(argparse.Action): |
1023 def __call__(self, parser, namespace, values, option_string=None): | 1426 def __call__(self, parser, namespace, values, option_string=None): |
1024 # values is a list of three floats | |
1025 train, val, test = values | 1427 train, val, test = values |
1026 total = train + val + test | 1428 total = train + val + test |
1027 if abs(total - 1.0) > 1e-6: | 1429 if abs(total - 1.0) > 1e-6: |
1028 parser.error( | 1430 parser.error( |
1029 f"--split-probabilities must sum to 1.0; " | 1431 f"--split-probabilities must sum to 1.0; " |
1031 ) | 1433 ) |
1032 setattr(namespace, self.dest, values) | 1434 setattr(namespace, self.dest, values) |
1033 | 1435 |
1034 | 1436 |
1035 def main(): | 1437 def main(): |
1036 | |
1037 parser = argparse.ArgumentParser( | 1438 parser = argparse.ArgumentParser( |
1038 description="Image Classification Learner with Pluggable Backends" | 1439 description="Image Classification Learner with Pluggable Backends", |
1039 ) | 1440 ) |
1040 parser.add_argument( | 1441 parser.add_argument( |
1041 "--csv-file", required=True, type=Path, | 1442 "--csv-file", |
1042 help="Path to the input CSV" | 1443 required=True, |
1444 type=Path, | |
1445 help="Path to the input CSV", | |
1043 ) | 1446 ) |
1044 parser.add_argument( | 1447 parser.add_argument( |
1045 "--image-zip", required=True, type=Path, | 1448 "--image-zip", |
1046 help="Path to the images ZIP" | 1449 required=True, |
1450 type=Path, | |
1451 help="Path to the images ZIP", | |
1047 ) | 1452 ) |
1048 parser.add_argument( | 1453 parser.add_argument( |
1049 "--model-name", required=True, | 1454 "--model-name", |
1455 required=True, | |
1050 choices=MODEL_ENCODER_TEMPLATES.keys(), | 1456 choices=MODEL_ENCODER_TEMPLATES.keys(), |
1051 help="Which model template to use" | 1457 help="Which model template to use", |
1052 ) | 1458 ) |
1053 parser.add_argument( | 1459 parser.add_argument( |
1054 "--use-pretrained", action="store_true", | 1460 "--use-pretrained", |
1055 help="Use pretrained weights for the model" | 1461 action="store_true", |
1462 help="Use pretrained weights for the model", | |
1056 ) | 1463 ) |
1057 parser.add_argument( | 1464 parser.add_argument( |
1058 "--fine-tune", action="store_true", | 1465 "--fine-tune", |
1059 help="Enable fine-tuning" | 1466 action="store_true", |
1467 help="Enable fine-tuning", | |
1060 ) | 1468 ) |
1061 parser.add_argument( | 1469 parser.add_argument( |
1062 "--epochs", type=int, default=10, | 1470 "--epochs", |
1063 help="Number of training epochs" | 1471 type=int, |
1472 default=10, | |
1473 help="Number of training epochs", | |
1064 ) | 1474 ) |
1065 parser.add_argument( | 1475 parser.add_argument( |
1066 "--early-stop", type=int, default=5, | 1476 "--early-stop", |
1067 help="Early stopping patience" | 1477 type=int, |
1478 default=5, | |
1479 help="Early stopping patience", | |
1068 ) | 1480 ) |
1069 parser.add_argument( | 1481 parser.add_argument( |
1070 "--batch-size", type=int, | 1482 "--batch-size", |
1071 help="Batch size (None = auto)" | 1483 type=int, |
1484 help="Batch size (None = auto)", | |
1072 ) | 1485 ) |
1073 parser.add_argument( | 1486 parser.add_argument( |
1074 "--output-dir", type=Path, default=Path("learner_output"), | 1487 "--output-dir", |
1075 help="Where to write outputs" | 1488 type=Path, |
1489 default=Path("learner_output"), | |
1490 help="Where to write outputs", | |
1076 ) | 1491 ) |
1077 parser.add_argument( | 1492 parser.add_argument( |
1078 "--validation-size", type=float, default=0.15, | 1493 "--validation-size", |
1079 help="Fraction for validation (0.0–1.0)" | 1494 type=float, |
1495 default=0.15, | |
1496 help="Fraction for validation (0.0–1.0)", | |
1080 ) | 1497 ) |
1081 parser.add_argument( | 1498 parser.add_argument( |
1082 "--preprocessing-num-processes", type=int, | 1499 "--preprocessing-num-processes", |
1500 type=int, | |
1083 default=max(1, os.cpu_count() // 2), | 1501 default=max(1, os.cpu_count() // 2), |
1084 help="CPU processes for data prep" | 1502 help="CPU processes for data prep", |
1085 ) | 1503 ) |
1086 parser.add_argument( | 1504 parser.add_argument( |
1087 "--split-probabilities", type=float, nargs=3, | 1505 "--split-probabilities", |
1506 type=float, | |
1507 nargs=3, | |
1088 metavar=("train", "val", "test"), | 1508 metavar=("train", "val", "test"), |
1089 action=SplitProbAction, | 1509 action=SplitProbAction, |
1090 default=[0.7, 0.1, 0.2], | 1510 default=[0.7, 0.1, 0.2], |
1091 help="Random split proportions (e.g., 0.7 0.1 0.2). Only used if no split column is present." | 1511 help="Random split proportions (e.g., 0.7 0.1 0.2). Only used if no split column.", |
1092 ) | 1512 ) |
1093 parser.add_argument( | 1513 parser.add_argument( |
1094 "--random-seed", type=int, default=42, | 1514 "--random-seed", |
1095 help="Random seed used for dataset splitting (default: 42)" | 1515 type=int, |
1516 default=42, | |
1517 help="Random seed used for dataset splitting (default: 42)", | |
1096 ) | 1518 ) |
1097 parser.add_argument( | 1519 parser.add_argument( |
1098 "--learning-rate", type=parse_learning_rate, default=None, | 1520 "--learning-rate", |
1099 help="Learning rate. If not provided, Ludwig will auto-select it." | 1521 type=parse_learning_rate, |
1522 default=None, | |
1523 help="Learning rate. If not provided, Ludwig will auto-select it.", | |
1100 ) | 1524 ) |
1101 | 1525 |
1102 args = parser.parse_args() | 1526 args = parser.parse_args() |
1103 | 1527 |
1104 # -- Validation -- | |
1105 if not 0.0 <= args.validation_size <= 1.0: | 1528 if not 0.0 <= args.validation_size <= 1.0: |
1106 parser.error("validation-size must be between 0.0 and 1.0") | 1529 parser.error("validation-size must be between 0.0 and 1.0") |
1107 if not args.csv_file.is_file(): | 1530 if not args.csv_file.is_file(): |
1108 parser.error(f"CSV not found: {args.csv_file}") | 1531 parser.error(f"CSV not found: {args.csv_file}") |
1109 if not args.image_zip.is_file(): | 1532 if not args.image_zip.is_file(): |
1110 parser.error(f"ZIP not found: {args.image_zip}") | 1533 parser.error(f"ZIP not found: {args.image_zip}") |
1111 | 1534 |
1112 # --- Instantiate Backend and Orchestrator --- | |
1113 # Use the new LudwigDirectBackend | |
1114 backend_instance = LudwigDirectBackend() | 1535 backend_instance = LudwigDirectBackend() |
1115 orchestrator = WorkflowOrchestrator(args, backend_instance) | 1536 orchestrator = WorkflowOrchestrator(args, backend_instance) |
1116 | 1537 |
1117 # --- Run Workflow --- | |
1118 exit_code = 0 | 1538 exit_code = 0 |
1119 try: | 1539 try: |
1120 orchestrator.run() | 1540 orchestrator.run() |
1121 logger.info("Main script finished successfully.") | 1541 logger.info("Main script finished successfully.") |
1122 except Exception as e: | 1542 except Exception as e: |
1124 exit_code = 1 | 1544 exit_code = 1 |
1125 finally: | 1545 finally: |
1126 sys.exit(exit_code) | 1546 sys.exit(exit_code) |
1127 | 1547 |
1128 | 1548 |
1129 if __name__ == '__main__': | 1549 if __name__ == "__main__": |
1130 try: | 1550 try: |
1131 import ludwig | 1551 import ludwig |
1552 | |
1132 logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}") | 1553 logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}") |
1133 except ImportError: | 1554 except ImportError: |
1134 logger.error("Ludwig library not found. Please ensure Ludwig is installed ('pip install ludwig[image]')") | 1555 logger.error( |
1556 "Ludwig library not found. Please ensure Ludwig is installed " | |
1557 "('pip install ludwig[image]')" | |
1558 ) | |
1135 sys.exit(1) | 1559 sys.exit(1) |
1136 | 1560 |
1137 main() | 1561 main() |