Mercurial > repos > goeckslab > image_learner
comparison utils.py @ 12:bcfa2e234a80 draft
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
| author | goeckslab |
|---|---|
| date | Fri, 21 Nov 2025 15:58:13 +0000 |
| parents | c5150cceab47 |
| children |
comparison
equal
deleted
inserted
replaced
| 11:c5150cceab47 | 12:bcfa2e234a80 |
|---|---|
| 1 import base64 | 1 import logging |
| 2 import json | 2 from pathlib import Path |
| 3 | |
| 4 import pandas as pd | |
| 5 | |
| 6 logger = logging.getLogger("ImageLearner") | |
| 3 | 7 |
| 4 | 8 |
| 5 def get_html_template(): | 9 def load_metadata_table(file_path: Path) -> pd.DataFrame: |
| 6 """ | 10 """Load image metadata allowing either CSV or TSV delimiters.""" |
| 7 Returns the opening HTML, <head> (with CSS/JS), and opens <body> + .container. | 11 logger.info("Loading metadata table from %s", file_path) |
| 8 Includes: | 12 return pd.read_csv(file_path, sep=None, engine="python") |
| 9 - Base styling for layout and tables | |
| 10 - Sortable table headers with 3-state arrows (none ⇅, asc ↑, desc ↓) | |
| 11 - A scroll helper class (.scroll-rows-30) that approximates ~30 visible rows | |
| 12 - A guarded script so initializing runs only once even if injected twice | |
| 13 """ | |
| 14 return """ | |
| 15 <!DOCTYPE html> | |
| 16 <html> | |
| 17 <head> | |
| 18 <meta charset="UTF-8"> | |
| 19 <title>Galaxy-Ludwig Report</title> | |
| 20 <style> | |
| 21 body { | |
| 22 font-family: Arial, sans-serif; | |
| 23 margin: 0; | |
| 24 padding: 20px; | |
| 25 background-color: #f4f4f4; | |
| 26 } | |
| 27 .container { | |
| 28 max-width: 1200px; | |
| 29 margin: auto; | |
| 30 background: white; | |
| 31 padding: 20px; | |
| 32 box-shadow: 0 0 10px rgba(0, 0, 0, 0.1); | |
| 33 overflow-x: auto; | |
| 34 } | |
| 35 h1 { | |
| 36 text-align: center; | |
| 37 color: #333; | |
| 38 } | |
| 39 h2 { | |
| 40 border-bottom: 2px solid #4CAF50; | |
| 41 color: #4CAF50; | |
| 42 padding-bottom: 5px; | |
| 43 margin-top: 28px; | |
| 44 } | |
| 45 | |
| 46 /* baseline table setup */ | |
| 47 table { | |
| 48 border-collapse: collapse; | |
| 49 margin: 20px 0; | |
| 50 width: 100%; | |
| 51 table-layout: fixed; | |
| 52 background: #fff; | |
| 53 } | |
| 54 table, th, td { | |
| 55 border: 1px solid #ddd; | |
| 56 } | |
| 57 th, td { | |
| 58 padding: 10px; | |
| 59 text-align: center; | |
| 60 vertical-align: middle; | |
| 61 word-break: break-word; | |
| 62 white-space: normal; | |
| 63 overflow-wrap: anywhere; | |
| 64 } | |
| 65 th { | |
| 66 background-color: #4CAF50; | |
| 67 color: white; | |
| 68 } | |
| 69 | |
| 70 .plot { | |
| 71 text-align: center; | |
| 72 margin: 20px 0; | |
| 73 } | |
| 74 .plot img { | |
| 75 max-width: 100%; | |
| 76 height: auto; | |
| 77 border: 1px solid #ddd; | |
| 78 } | |
| 79 | |
| 80 /* ------------------- | |
| 81 sortable columns (3-state: none ⇅, asc ↑, desc ↓) | |
| 82 ------------------- */ | |
| 83 table.performance-summary th.sortable { | |
| 84 cursor: pointer; | |
| 85 position: relative; | |
| 86 user-select: none; | |
| 87 } | |
| 88 /* default icon space */ | |
| 89 table.performance-summary th.sortable::after { | |
| 90 content: '⇅'; | |
| 91 position: absolute; | |
| 92 right: 12px; | |
| 93 top: 50%; | |
| 94 transform: translateY(-50%); | |
| 95 font-size: 0.8em; | |
| 96 color: #eaf5ea; /* light on green */ | |
| 97 text-shadow: 0 0 1px rgba(0,0,0,0.15); | |
| 98 } | |
| 99 /* three states override the default */ | |
| 100 table.performance-summary th.sortable.sorted-none::after { content: '⇅'; color: #eaf5ea; } | |
| 101 table.performance-summary th.sortable.sorted-asc::after { content: '↑'; color: #ffffff; } | |
| 102 table.performance-summary th.sortable.sorted-desc::after { content: '↓'; color: #ffffff; } | |
| 103 | |
| 104 /* show ~30 rows with a scrollbar (tweak if you want) */ | |
| 105 .scroll-rows-30 { | |
| 106 max-height: 900px; /* ~30 rows depending on row height */ | |
| 107 overflow-y: auto; /* vertical scrollbar ("sidebar") */ | |
| 108 overflow-x: auto; | |
| 109 } | |
| 110 | |
| 111 /* Tabs + Help button (used by build_tabbed_html) */ | |
| 112 .tabs { | |
| 113 display: flex; | |
| 114 align-items: center; | |
| 115 border-bottom: 2px solid #ccc; | |
| 116 margin-bottom: 1rem; | |
| 117 gap: 6px; | |
| 118 flex-wrap: wrap; | |
| 119 } | |
| 120 .tab { | |
| 121 padding: 10px 20px; | |
| 122 cursor: pointer; | |
| 123 border: 1px solid #ccc; | |
| 124 border-bottom: none; | |
| 125 background: #f9f9f9; | |
| 126 margin-right: 5px; | |
| 127 border-top-left-radius: 8px; | |
| 128 border-top-right-radius: 8px; | |
| 129 } | |
| 130 .tab.active { | |
| 131 background: white; | |
| 132 font-weight: bold; | |
| 133 } | |
| 134 .help-btn { | |
| 135 margin-left: auto; | |
| 136 padding: 6px 12px; | |
| 137 font-size: 0.9rem; | |
| 138 border: 1px solid #4CAF50; | |
| 139 border-radius: 4px; | |
| 140 background: #4CAF50; | |
| 141 color: white; | |
| 142 cursor: pointer; | |
| 143 } | |
| 144 .tab-content { | |
| 145 display: none; | |
| 146 padding: 20px; | |
| 147 border: 1px solid #ccc; | |
| 148 border-top: none; | |
| 149 background: #fff; | |
| 150 } | |
| 151 .tab-content.active { | |
| 152 display: block; | |
| 153 } | |
| 154 | |
| 155 /* Modal (used by get_metrics_help_modal) */ | |
| 156 .modal { | |
| 157 display: none; | |
| 158 position: fixed; | |
| 159 z-index: 9999; | |
| 160 left: 0; top: 0; | |
| 161 width: 100%; height: 100%; | |
| 162 overflow: auto; | |
| 163 background-color: rgba(0,0,0,0.4); | |
| 164 } | |
| 165 .modal-content { | |
| 166 background-color: #fefefe; | |
| 167 margin: 8% auto; | |
| 168 padding: 20px; | |
| 169 border: 1px solid #888; | |
| 170 width: 90%; | |
| 171 max-width: 900px; | |
| 172 border-radius: 8px; | |
| 173 } | |
| 174 .modal .close { | |
| 175 color: #777; | |
| 176 float: right; | |
| 177 font-size: 28px; | |
| 178 font-weight: bold; | |
| 179 line-height: 1; | |
| 180 margin-left: 8px; | |
| 181 } | |
| 182 .modal .close:hover, | |
| 183 .modal .close:focus { | |
| 184 color: black; | |
| 185 text-decoration: none; | |
| 186 cursor: pointer; | |
| 187 } | |
| 188 .metrics-guide h3 { margin-top: 20px; } | |
| 189 .metrics-guide p { margin: 6px 0; } | |
| 190 .metrics-guide ul { margin: 10px 0; padding-left: 20px; } | |
| 191 </style> | |
| 192 | |
| 193 <script> | |
| 194 // Guard to avoid double-initialization if this block is included twice | |
| 195 (function(){ | |
| 196 if (window.__perfSummarySortInit) return; | |
| 197 window.__perfSummarySortInit = true; | |
| 198 | |
| 199 function initPerfSummarySorting() { | |
| 200 // Record original order for "back to original" | |
| 201 document.querySelectorAll('table.performance-summary tbody').forEach(tbody => { | |
| 202 Array.from(tbody.rows).forEach((row, i) => { row.dataset.originalOrder = i; }); | |
| 203 }); | |
| 204 | |
| 205 const getText = td => (td?.innerText || '').trim(); | |
| 206 const cmp = (idx, asc) => (a, b) => { | |
| 207 const v1 = getText(a.children[idx]); | |
| 208 const v2 = getText(b.children[idx]); | |
| 209 const n1 = parseFloat(v1), n2 = parseFloat(v2); | |
| 210 if (!isNaN(n1) && !isNaN(n2)) return asc ? n1 - n2 : n2 - n1; // numeric | |
| 211 return asc ? v1.localeCompare(v2) : v2.localeCompare(v1); // lexical | |
| 212 }; | |
| 213 | |
| 214 document.querySelectorAll('table.performance-summary th.sortable').forEach(th => { | |
| 215 // initialize to "none" | |
| 216 th.classList.remove('sorted-asc','sorted-desc'); | |
| 217 th.classList.add('sorted-none'); | |
| 218 | |
| 219 th.addEventListener('click', () => { | |
| 220 const table = th.closest('table'); | |
| 221 const headerRow = th.parentNode; | |
| 222 const allTh = headerRow.querySelectorAll('th.sortable'); | |
| 223 const tbody = table.querySelector('tbody'); | |
| 224 | |
| 225 // Determine current state BEFORE clearing | |
| 226 const isAsc = th.classList.contains('sorted-asc'); | |
| 227 const isDesc = th.classList.contains('sorted-desc'); | |
| 228 | |
| 229 // Reset all headers in this row | |
| 230 allTh.forEach(x => x.classList.remove('sorted-asc','sorted-desc','sorted-none')); | |
| 231 | |
| 232 // Compute next state | |
| 233 let next; | |
| 234 if (!isAsc && !isDesc) { | |
| 235 next = 'asc'; | |
| 236 } else if (isAsc) { | |
| 237 next = 'desc'; | |
| 238 } else { | |
| 239 next = 'none'; | |
| 240 } | |
| 241 th.classList.add('sorted-' + next); | |
| 242 | |
| 243 // Sort rows according to the chosen state | |
| 244 const rows = Array.from(tbody.rows); | |
| 245 if (next === 'none') { | |
| 246 rows.sort((a, b) => (a.dataset.originalOrder - b.dataset.originalOrder)); | |
| 247 } else { | |
| 248 const idx = Array.from(headerRow.children).indexOf(th); | |
| 249 rows.sort(cmp(idx, next === 'asc')); | |
| 250 } | |
| 251 rows.forEach(r => tbody.appendChild(r)); | |
| 252 }); | |
| 253 }); | |
| 254 } | |
| 255 | |
| 256 // Run after DOM is ready | |
| 257 if (document.readyState === 'loading') { | |
| 258 document.addEventListener('DOMContentLoaded', initPerfSummarySorting); | |
| 259 } else { | |
| 260 initPerfSummarySorting(); | |
| 261 } | |
| 262 })(); | |
| 263 </script> | |
| 264 </head> | |
| 265 <body> | |
| 266 <div class="container"> | |
| 267 """ | |
| 268 | 13 |
| 269 | 14 |
| 270 def get_html_closing(): | 15 def detect_output_type(test_stats): |
| 271 """Closes .container, body, and html.""" | 16 """Detect the output type ('regression', 'binary', or 'category') from test statistics.""" |
| 272 return """ | 17 label_stats = test_stats.get("label", {}) |
| 273 </div> | 18 if "mean_squared_error" in label_stats: |
| 274 </body> | 19 return "regression" |
| 275 </html> | 20 per_class = label_stats.get("per_class_stats", {}) |
| 276 """ | 21 if len(per_class) == 2: |
| 22 return "binary" | |
| 23 return "category" | |
| 277 | 24 |
| 278 | 25 |
| 279 def encode_image_to_base64(image_path: str) -> str: | 26 def aug_parse(aug_string: str): |
| 280 """Convert an image file to a base64 encoded string.""" | 27 """ |
| 281 with open(image_path, "rb") as img_file: | 28 Parse comma-separated augmentation keys into Ludwig augmentation dicts. |
| 282 return base64.b64encode(img_file.read()).decode("utf-8") | 29 Raises ValueError on unknown key. |
| 30 """ | |
| 31 mapping = { | |
| 32 "random_horizontal_flip": {"type": "random_horizontal_flip"}, | |
| 33 "random_vertical_flip": {"type": "random_vertical_flip"}, | |
| 34 "random_rotate": {"type": "random_rotate", "degree": 10}, | |
| 35 "random_blur": {"type": "random_blur", "kernel_size": 3}, | |
| 36 "random_brightness": {"type": "random_brightness", "min": 0.5, "max": 2.0}, | |
| 37 "random_contrast": {"type": "random_contrast", "min": 0.5, "max": 2.0}, | |
| 38 } | |
| 39 aug_list = [] | |
| 40 for tok in aug_string.split(","): | |
| 41 key = tok.strip() | |
| 42 if not key: | |
| 43 continue | |
| 44 if key not in mapping: | |
| 45 valid = ", ".join(mapping.keys()) | |
| 46 raise ValueError(f"Unknown augmentation '{key}'. Valid choices: {valid}") | |
| 47 aug_list.append(mapping[key]) | |
| 48 return aug_list | |
| 283 | 49 |
| 284 | 50 |
| 285 def json_to_nested_html_table(json_data, depth: int = 0) -> str: | 51 def argument_checker(args, parser): |
| 286 """ | 52 if not 0.0 <= args.validation_size <= 1.0: |
| 287 Convert a JSON-able object to an HTML nested table. | 53 parser.error("validation-size must be between 0.0 and 1.0") |
| 288 Renders dicts as two-column tables (key/value) and lists as index/value rows. | 54 if not args.csv_file.is_file(): |
| 289 """ | 55 parser.error(f"Metada file not found: {args.csv_file}") |
| 290 # Base case: flat dict (no nested dict/list values) | 56 if not (args.image_zip.is_file() or args.image_zip.is_dir()): |
| 291 if isinstance(json_data, dict) and all( | 57 parser.error(f"ZIP or directory not found: {args.image_zip}") |
| 292 not isinstance(v, (dict, list)) for v in json_data.values() | 58 if args.augmentation is not None: |
| 293 ): | 59 try: |
| 294 rows = [ | 60 augmentation_setup = aug_parse(args.augmentation) |
| 295 f"<tr><th>{key}</th><td>{value}</td></tr>" | 61 setattr(args, "augmentation", augmentation_setup) |
| 296 for key, value in json_data.items() | 62 except ValueError as e: |
| 297 ] | 63 parser.error(str(e)) |
| 298 return f"<table>{''.join(rows)}</table>" | |
| 299 | |
| 300 # Base case: list of simple values | |
| 301 if isinstance(json_data, list) and all( | |
| 302 not isinstance(v, (dict, list)) for v in json_data | |
| 303 ): | |
| 304 rows = [ | |
| 305 f"<tr><th>Index {i}</th><td>{value}</td></tr>" | |
| 306 for i, value in enumerate(json_data) | |
| 307 ] | |
| 308 return f"<table>{''.join(rows)}</table>" | |
| 309 | |
| 310 # Recursive cases | |
| 311 if isinstance(json_data, dict): | |
| 312 rows = [ | |
| 313 ( | |
| 314 f"<tr><th style='text-align:left;padding-left:{depth * 20}px;'>{key}</th>" | |
| 315 f"<td>{json_to_nested_html_table(value, depth + 1)}</td></tr>" | |
| 316 ) | |
| 317 for key, value in json_data.items() | |
| 318 ] | |
| 319 return f"<table>{''.join(rows)}</table>" | |
| 320 | |
| 321 if isinstance(json_data, list): | |
| 322 rows = [ | |
| 323 ( | |
| 324 f"<tr><th style='text-align:left;padding-left:{depth * 20}px;'>[{i}]</th>" | |
| 325 f"<td>{json_to_nested_html_table(value, depth + 1)}</td></tr>" | |
| 326 ) | |
| 327 for i, value in enumerate(json_data) | |
| 328 ] | |
| 329 return f"<table>{''.join(rows)}</table>" | |
| 330 | |
| 331 # Primitive | |
| 332 return f"{json_data}" | |
| 333 | 64 |
| 334 | 65 |
| 335 def json_to_html_table(json_data) -> str: | 66 def parse_learning_rate(s): |
| 336 """ | 67 try: |
| 337 Convert JSON (dict or string) into a vertically oriented HTML table. | 68 return float(s) |
| 338 """ | 69 except (TypeError, ValueError): |
| 339 if isinstance(json_data, str): | 70 return None |
| 340 json_data = json.loads(json_data) | |
| 341 return json_to_nested_html_table(json_data) | |
| 342 | 71 |
| 343 | 72 |
| 344 def build_tabbed_html(metrics_html: str, train_val_html: str, test_html: str) -> str: | 73 def extract_metrics_from_json( |
| 345 """ | 74 train_stats: dict, |
| 346 Build a 3-tab interface: | 75 test_stats: dict, |
| 347 - Config and Results Summary | 76 output_type: str, |
| 348 - Train/Validation Results | 77 ) -> dict: |
| 349 - Test Results | 78 """Extracts relevant metrics from training and test statistics based on the output type.""" |
| 350 Includes a persistent "Help" button that toggles the metrics modal. | 79 metrics = {"training": {}, "validation": {}, "test": {}} |
| 351 """ | |
| 352 return f""" | |
| 353 <div class="tabs"> | |
| 354 <div class="tab active" onclick="showTab('metrics')">Config and Results Summary</div> | |
| 355 <div class="tab" onclick="showTab('trainval')">Train/Validation Results</div> | |
| 356 <div class="tab" onclick="showTab('test')">Test Results</div> | |
| 357 <button id="openMetricsHelp" class="help-btn" title="Open metrics help">Help</button> | |
| 358 </div> | |
| 359 | 80 |
| 360 <div id="metrics" class="tab-content active"> | 81 def get_last_value(stats, key): |
| 361 {metrics_html} | 82 val = stats.get(key) |
| 362 </div> | 83 if isinstance(val, list) and val: |
| 363 <div id="trainval" class="tab-content"> | 84 return val[-1] |
| 364 {train_val_html} | 85 elif isinstance(val, (int, float)): |
| 365 </div> | 86 return val |
| 366 <div id="test" class="tab-content"> | 87 return None |
| 367 {test_html} | |
| 368 </div> | |
| 369 | 88 |
| 370 <script> | 89 for split in ["training", "validation"]: |
| 371 function showTab(id) {{ | 90 split_stats = train_stats.get(split, {}) |
| 372 document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active')); | 91 if not split_stats: |
| 373 document.querySelectorAll('.tab').forEach(el => el.classList.remove('active')); | 92 logger.warning("No statistics found for %s split", split) |
| 374 document.getElementById(id).classList.add('active'); | 93 continue |
| 375 // find tab with matching onclick target | 94 label_stats = split_stats.get("label", {}) |
| 376 document.querySelectorAll('.tab').forEach(t => {{ | 95 if not label_stats: |
| 377 if (t.getAttribute('onclick') && t.getAttribute('onclick').includes(id)) {{ | 96 logger.warning("No label statistics found for %s split", split) |
| 378 t.classList.add('active'); | 97 continue |
| 379 }} | 98 if output_type == "binary": |
| 380 }}); | 99 metrics[split] = { |
| 381 }} | 100 "accuracy": get_last_value(label_stats, "accuracy"), |
| 382 </script> | 101 "loss": get_last_value(label_stats, "loss"), |
| 383 """ | 102 "precision": get_last_value(label_stats, "precision"), |
| 103 "recall": get_last_value(label_stats, "recall"), | |
| 104 "specificity": get_last_value(label_stats, "specificity"), | |
| 105 "roc_auc": get_last_value(label_stats, "roc_auc"), | |
| 106 } | |
| 107 elif output_type == "regression": | |
| 108 metrics[split] = { | |
| 109 "loss": get_last_value(label_stats, "loss"), | |
| 110 "mean_absolute_error": get_last_value( | |
| 111 label_stats, "mean_absolute_error" | |
| 112 ), | |
| 113 "mean_absolute_percentage_error": get_last_value( | |
| 114 label_stats, "mean_absolute_percentage_error" | |
| 115 ), | |
| 116 "mean_squared_error": get_last_value(label_stats, "mean_squared_error"), | |
| 117 "root_mean_squared_error": get_last_value( | |
| 118 label_stats, "root_mean_squared_error" | |
| 119 ), | |
| 120 "root_mean_squared_percentage_error": get_last_value( | |
| 121 label_stats, "root_mean_squared_percentage_error" | |
| 122 ), | |
| 123 "r2": get_last_value(label_stats, "r2"), | |
| 124 } | |
| 125 else: | |
| 126 metrics[split] = { | |
| 127 "accuracy": get_last_value(label_stats, "accuracy"), | |
| 128 "accuracy_micro": get_last_value(label_stats, "accuracy_micro"), | |
| 129 "loss": get_last_value(label_stats, "loss"), | |
| 130 "roc_auc": get_last_value(label_stats, "roc_auc"), | |
| 131 "hits_at_k": get_last_value(label_stats, "hits_at_k"), | |
| 132 } | |
| 384 | 133 |
| 134 # Test metrics: dynamic extraction according to exclusions | |
| 135 test_label_stats = test_stats.get("label", {}) | |
| 136 if not test_label_stats: | |
| 137 logger.warning("No label statistics found for test split") | |
| 138 else: | |
| 139 combined_stats = test_stats.get("combined", {}) | |
| 140 overall_stats = test_label_stats.get("overall_stats", {}) | |
| 385 | 141 |
| 386 def get_metrics_help_modal() -> str: | 142 # Define exclusions |
| 387 """ | 143 if output_type == "binary": |
| 388 Returns a ready-to-use modal with a comprehensive metrics guide and | 144 exclude = {"per_class_stats", "precision_recall_curve", "roc_curve"} |
| 389 the small script that wires the "Help" button to open/close the modal. | 145 else: |
| 390 """ | 146 exclude = {"per_class_stats", "confusion_matrix"} |
| 391 modal_html = ( | |
| 392 '<div id="metricsHelpModal" class="modal">' | |
| 393 ' <div class="modal-content">' | |
| 394 ' <span class="close">×</span>' | |
| 395 " <h2>Model Evaluation Metrics — Help Guide</h2>" | |
| 396 ' <div class="metrics-guide">' | |
| 397 ' <h3>1) General Metrics (Regression and Classification)</h3>' | |
| 398 ' <p><strong>Loss (Regression & Classification):</strong> ' | |
| 399 'Measures the difference between predicted and actual values, ' | |
| 400 'optimized during training. Lower is better. ' | |
| 401 'For regression, this is often Mean Squared Error (MSE) or ' | |
| 402 'Mean Absolute Error (MAE). For classification, it\'s typically ' | |
| 403 'cross-entropy or log loss.</p>' | |
| 404 ' <h3>2) Regression Metrics</h3>' | |
| 405 ' <p><strong>Mean Absolute Error (MAE):</strong> ' | |
| 406 'Average of absolute differences between predicted and actual values, ' | |
| 407 'in the same units as the target. Use for interpretable error measurement ' | |
| 408 'when all errors are equally important. Less sensitive to outliers than MSE.</p>' | |
| 409 ' <p><strong>Mean Squared Error (MSE):</strong> ' | |
| 410 'Average of squared differences between predicted and actual values. ' | |
| 411 'Penalizes larger errors more heavily, useful when large deviations are critical. ' | |
| 412 'Often used as the loss function in regression.</p>' | |
| 413 ' <p><strong>Root Mean Squared Error (RMSE):</strong> ' | |
| 414 'Square root of MSE, in the same units as the target. ' | |
| 415 'Balances interpretability and sensitivity to large errors. ' | |
| 416 'Widely used for regression evaluation.</p>' | |
| 417 ' <p><strong>Mean Absolute Percentage Error (MAPE):</strong> ' | |
| 418 'Average absolute error as a percentage of actual values. ' | |
| 419 'Scale-independent, ideal for comparing relative errors across datasets. ' | |
| 420 'Avoid when actual values are near zero.</p>' | |
| 421 ' <p><strong>Root Mean Squared Percentage Error (RMSPE):</strong> ' | |
| 422 'Square root of mean squared percentage error. Scale-independent, ' | |
| 423 'penalizes larger relative errors more than MAPE. Use for forecasting ' | |
| 424 'or when relative accuracy matters.</p>' | |
| 425 ' <p><strong>R² Score:</strong> Proportion of variance in the target ' | |
| 426 'explained by the model. Ranges from negative infinity to 1 (perfect prediction). ' | |
| 427 'Use to assess model fit; negative values indicate poor performance ' | |
| 428 'compared to predicting the mean.</p>' | |
| 429 ' <h3>3) Classification Metrics</h3>' | |
| 430 ' <p><strong>Accuracy:</strong> Proportion of correct predictions ' | |
| 431 'among all predictions. Simple but misleading for imbalanced datasets, ' | |
| 432 'where high accuracy may hide poor performance on minority classes.</p>' | |
| 433 ' <p><strong>Micro Accuracy:</strong> Sums true positives and true negatives ' | |
| 434 'across all classes before computing accuracy. Suitable for multiclass or ' | |
| 435 'multilabel problems with imbalanced data.</p>' | |
| 436 ' <p><strong>Token Accuracy:</strong> Measures how often predicted tokens ' | |
| 437 '(e.g., in sequences) match true tokens. Common in NLP tasks like text generation ' | |
| 438 'or token classification.</p>' | |
| 439 ' <p><strong>Precision:</strong> Proportion of positive predictions that are ' | |
| 440 'correct (TP / (TP + FP)). Use when false positives are costly, e.g., spam detection.</p>' | |
| 441 ' <p><strong>Recall (Sensitivity):</strong> Proportion of actual positives ' | |
| 442 'correctly predicted (TP / (TP + FN)). Use when missing positives is risky, ' | |
| 443 'e.g., disease detection.</p>' | |
| 444 ' <p><strong>Specificity:</strong> True negative rate (TN / (TN + FP)). ' | |
| 445 'Measures ability to identify negatives. Useful in medical testing to avoid ' | |
| 446 'false alarms.</p>' | |
| 447 ' <h3>4) Classification: Macro, Micro, and Weighted Averages</h3>' | |
| 448 ' <p><strong>Macro Precision / Recall / F1:</strong> Averages the metric ' | |
| 449 'across all classes, treating each equally. Best for balanced datasets where ' | |
| 450 'all classes are equally important.</p>' | |
| 451 ' <p><strong>Micro Precision / Recall / F1:</strong> Aggregates true positives, ' | |
| 452 'false positives, and false negatives across all classes before computing. ' | |
| 453 'Ideal for imbalanced or multilabel classification.</p>' | |
| 454 ' <p><strong>Weighted Precision / Recall / F1:</strong> Averages metrics ' | |
| 455 'across classes, weighted by the number of true instances per class. Balances ' | |
| 456 'class importance based on frequency.</p>' | |
| 457 ' <h3>5) Classification: Average Precision (PR-AUC Variants)</h3>' | |
| 458 ' <p><strong>Average Precision Macro:</strong> Precision-Recall AUC averaged ' | |
| 459 'equally across classes. Use for balanced multiclass problems.</p>' | |
| 460 ' <p><strong>Average Precision Micro:</strong> Global Precision-Recall AUC ' | |
| 461 'using all instances. Best for imbalanced or multilabel classification.</p>' | |
| 462 ' <p><strong>Average Precision Samples:</strong> Precision-Recall AUC averaged ' | |
| 463 'across individual samples. Ideal for multilabel tasks where samples have multiple ' | |
| 464 'labels.</p>' | |
| 465 ' <h3>6) Classification: ROC-AUC Variants</h3>' | |
| 466 ' <p><strong>ROC-AUC:</strong> Measures ability to distinguish between classes. ' | |
| 467 'AUC = 1 is perfect; 0.5 is random guessing. Use for binary classification.</p>' | |
| 468 ' <p><strong>Macro ROC-AUC:</strong> Averages AUC across all classes equally. ' | |
| 469 'Suitable for balanced multiclass problems.</p>' | |
| 470 ' <p><strong>Micro ROC-AUC:</strong> Computes AUC from aggregated predictions ' | |
| 471 'across all classes. Useful for imbalanced or multilabel settings.</p>' | |
| 472 ' <h3>7) Classification: Confusion Matrix Stats (Per Class)</h3>' | |
| 473 ' <p><strong>True Positives / Negatives (TP / TN):</strong> Correct predictions ' | |
| 474 'for positives and negatives, respectively.</p>' | |
| 475 ' <p><strong>False Positives / Negatives (FP / FN):</strong> Incorrect predictions ' | |
| 476 '— false alarms and missed detections.</p>' | |
| 477 ' <h3>8) Classification: Ranking Metrics</h3>' | |
| 478 ' <p><strong>Hits at K:</strong> Measures whether the true label is among the ' | |
| 479 'top-K predictions. Common in recommendation systems and retrieval tasks.</p>' | |
| 480 ' <h3>9) Other Metrics (Classification)</h3>' | |
| 481 ' <p><strong>Cohen\'s Kappa:</strong> Measures agreement between predicted and ' | |
| 482 'actual labels, adjusted for chance. Useful for multiclass classification with ' | |
| 483 'imbalanced data.</p>' | |
| 484 ' <p><strong>Matthews Correlation Coefficient (MCC):</strong> Balanced measure ' | |
| 485 'using TP, TN, FP, and FN. Effective for imbalanced datasets.</p>' | |
| 486 ' <h3>10) Metric Recommendations</h3>' | |
| 487 ' <ul>' | |
| 488 ' <li><strong>Regression:</strong> Use <strong>RMSE</strong> or ' | |
| 489 '<strong>MAE</strong> for general evaluation, <strong>MAPE</strong> for relative ' | |
| 490 'errors, and <strong>R²</strong> to assess model fit. Use <strong>MSE</strong> or ' | |
| 491 '<strong>RMSPE</strong> when large errors are critical.</li>' | |
| 492 ' <li><strong>Classification (Balanced Data):</strong> Use <strong>Accuracy</strong> ' | |
| 493 'and <strong>F1</strong> for overall performance.</li>' | |
| 494 ' <li><strong>Classification (Imbalanced Data):</strong> Use <strong>Precision</strong>, ' | |
| 495 '<strong>Recall</strong>, and <strong>ROC-AUC</strong> to focus on minority class ' | |
| 496 'performance.</li>' | |
| 497 ' <li><strong>Multilabel or Imbalanced Classification:</strong> Use ' | |
| 498 '<strong>Micro Precision/Recall/F1</strong> or <strong>Micro ROC-AUC</strong>.</li>' | |
| 499 ' <li><strong>Balanced Multiclass:</strong> Use <strong>Macro Precision/Recall/F1</strong> ' | |
| 500 'or <strong>Macro ROC-AUC</strong>.</li>' | |
| 501 ' <li><strong>Class Frequency Matters:</strong> Use <strong>Weighted Precision/Recall/F1</strong> ' | |
| 502 'to account for class imbalance.</li>' | |
| 503 ' <li><strong>Recommendation/Ranking:</strong> Use <strong>Hits at K</strong> for retrieval tasks.</li>' | |
| 504 ' <li><strong>Detailed Analysis:</strong> Use <strong>Confusion Matrix stats</strong> ' | |
| 505 'for class-wise performance in classification.</li>' | |
| 506 ' </ul>' | |
| 507 ' </div>' | |
| 508 ' </div>' | |
| 509 '</div>' | |
| 510 ) | |
| 511 | 147 |
| 512 modal_js = ( | 148 # 1. Get all scalar test_label_stats not excluded |
| 513 "<script>" | 149 test_metrics = {} |
| 514 "document.addEventListener('DOMContentLoaded', function() {" | 150 for k, v in test_label_stats.items(): |
| 515 " var modal = document.getElementById('metricsHelpModal');" | 151 if k in exclude: |
| 516 " var openBtn = document.getElementById('openMetricsHelp');" | 152 continue |
| 517 " var closeBtn = modal ? modal.querySelector('.close') : null;" | 153 if k == "overall_stats": |
| 518 " if (openBtn && modal) {" | 154 continue |
| 519 " openBtn.addEventListener('click', function(){ modal.style.display = 'block'; });" | 155 if isinstance(v, (int, float, str, bool)): |
| 520 " }" | 156 test_metrics[k] = v |
| 521 " if (closeBtn && modal) {" | 157 |
| 522 " closeBtn.addEventListener('click', function(){ modal.style.display = 'none'; });" | 158 # 2. Add overall_stats (flattened) |
| 523 " }" | 159 for k, v in overall_stats.items(): |
| 524 " window.addEventListener('click', function(ev){" | 160 test_metrics[k] = v |
| 525 " if (ev.target === modal) { modal.style.display = 'none'; }" | 161 |
| 526 " });" | 162 # 3. Optionally include combined/loss if present and not already |
| 527 "});" | 163 if "loss" in combined_stats and "loss" not in test_metrics: |
| 528 "</script>" | 164 test_metrics["loss"] = combined_stats["loss"] |
| 529 ) | 165 metrics["test"] = test_metrics |
| 530 return modal_html + modal_js | 166 return metrics |
