diff utils.py @ 12:bcfa2e234a80 draft
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
| author | goeckslab |
|---|---|
| date | Fri, 21 Nov 2025 15:58:13 +0000 |
| parents | c5150cceab47 |
| children | |
--- a/utils.py	Sat Oct 18 03:17:09 2025 +0000
+++ b/utils.py	Fri Nov 21 15:58:13 2025 +0000
@@ -1,530 +1,166 @@
-import base64
-import json
+import logging
+from pathlib import Path
+
+import pandas as pd
+
+logger = logging.getLogger("ImageLearner")
+
+
+def load_metadata_table(file_path: Path) -> pd.DataFrame:
+    """Load image metadata allowing either CSV or TSV delimiters."""
+    logger.info("Loading metadata table from %s", file_path)
+    return pd.read_csv(file_path, sep=None, engine="python")
+
+
+def detect_output_type(test_stats):
+    """Detect whether the output type is 'regression', 'binary', or 'category' from test statistics."""
+    label_stats = test_stats.get("label", {})
+    if "mean_squared_error" in label_stats:
+        return "regression"
+    per_class = label_stats.get("per_class_stats", {})
+    if len(per_class) == 2:
+        return "binary"
+    return "category"
+
+
+def aug_parse(aug_string: str):
+    """
+    Parse comma-separated augmentation keys into Ludwig augmentation dicts.
+    Raises ValueError on unknown key.
+    """
+    mapping = {
+        "random_horizontal_flip": {"type": "random_horizontal_flip"},
+        "random_vertical_flip": {"type": "random_vertical_flip"},
+        "random_rotate": {"type": "random_rotate", "degree": 10},
+        "random_blur": {"type": "random_blur", "kernel_size": 3},
+        "random_brightness": {"type": "random_brightness", "min": 0.5, "max": 2.0},
+        "random_contrast": {"type": "random_contrast", "min": 0.5, "max": 2.0},
+    }
+    aug_list = []
+    for tok in aug_string.split(","):
+        key = tok.strip()
+        if not key:
+            continue
+        if key not in mapping:
+            valid = ", ".join(mapping.keys())
+            raise ValueError(f"Unknown augmentation '{key}'. Valid choices: {valid}")
+        aug_list.append(mapping[key])
+    return aug_list
+
+
+def argument_checker(args, parser):
+    if not 0.0 <= args.validation_size <= 1.0:
+        parser.error("validation-size must be between 0.0 and 1.0")
+    if not args.csv_file.is_file():
+        parser.error(f"Metadata file not found: {args.csv_file}")
+    if not (args.image_zip.is_file() or args.image_zip.is_dir()):
+        parser.error(f"ZIP or directory not found: {args.image_zip}")
+    if args.augmentation is not None:
+        try:
+            augmentation_setup = aug_parse(args.augmentation)
+            setattr(args, "augmentation", augmentation_setup)
+        except ValueError as e:
+            parser.error(str(e))
+
+
+def parse_learning_rate(s):
+    try:
+        return float(s)
+    except (TypeError, ValueError):
+        return None
-
-
-def get_html_template():
-    """
-    Returns the opening HTML, <head> (with CSS/JS), and opens <body> + .container.
-
-    Includes:
-      - Base styling for layout and tables
-      - Sortable table headers with 3-state arrows (none ⇅, asc ↑, desc ↓)
-      - A scroll helper class (.scroll-rows-30) that approximates ~30 visible rows
-      - A guarded script so initializing runs only once even if injected twice
-    """
-    return """
-<!DOCTYPE html>
-<html>
-<head>
-    <meta charset="UTF-8">
-    <title>Galaxy-Ludwig Report</title>
-    <style>
-      body {
-          font-family: Arial, sans-serif;
-          margin: 0;
-          padding: 20px;
-          background-color: #f4f4f4;
-      }
-      .container {
-          max-width: 1200px;
-          margin: auto;
-          background: white;
-          padding: 20px;
-          box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
-          overflow-x: auto;
-      }
-      h1 {
-          text-align: center;
-          color: #333;
-      }
-      h2 {
-          border-bottom: 2px solid #4CAF50;
-          color: #4CAF50;
-          padding-bottom: 5px;
-          margin-top: 28px;
-      }
-
-      /* baseline table setup */
-      table {
-          border-collapse: collapse;
-          margin: 20px 0;
-          width: 100%;
-          table-layout: fixed;
-          background: #fff;
-      }
-      table, th, td {
-          border: 1px solid #ddd;
-      }
-      th, td {
-          padding: 10px;
-          text-align: center;
-          vertical-align: middle;
-          word-break: break-word;
-          white-space: normal;
-          overflow-wrap: anywhere;
-      }
-      th {
-          background-color: #4CAF50;
-          color: white;
-      }
-
-      .plot {
-          text-align: center;
-          margin: 20px 0;
-      }
-      .plot img {
-          max-width: 100%;
-          height: auto;
-          border: 1px solid #ddd;
-      }
-
-      /* -------------------
-         sortable columns (3-state: none ⇅, asc ↑, desc ↓)
-         ------------------- */
-      table.performance-summary th.sortable {
-          cursor: pointer;
-          position: relative;
-          user-select: none;
-      }
-      /* default icon space */
-      table.performance-summary th.sortable::after {
-          content: '⇅';
-          position: absolute;
-          right: 12px;
-          top: 50%;
-          transform: translateY(-50%);
-          font-size: 0.8em;
-          color: #eaf5ea;  /* light on green */
-          text-shadow: 0 0 1px rgba(0,0,0,0.15);
-      }
-      /* three states override the default */
-      table.performance-summary th.sortable.sorted-none::after { content: '⇅'; color: #eaf5ea; }
-      table.performance-summary th.sortable.sorted-asc::after { content: '↑'; color: #ffffff; }
-      table.performance-summary th.sortable.sorted-desc::after { content: '↓'; color: #ffffff; }
-
-      /* show ~30 rows with a scrollbar (tweak if you want) */
-      .scroll-rows-30 {
-          max-height: 900px;  /* ~30 rows depending on row height */
-          overflow-y: auto;   /* vertical scrollbar ("sidebar") */
-          overflow-x: auto;
-      }
+
+
+def extract_metrics_from_json(
+    train_stats: dict,
+    test_stats: dict,
+    output_type: str,
+) -> dict:
+    """Extracts relevant metrics from training and test statistics based on the output type."""
+    metrics = {"training": {}, "validation": {}, "test": {}}
-
-      /* Tabs + Help button (used by build_tabbed_html) */
-      .tabs {
-          display: flex;
-          align-items: center;
-          border-bottom: 2px solid #ccc;
-          margin-bottom: 1rem;
-          gap: 6px;
-          flex-wrap: wrap;
-      }
-      .tab {
-          padding: 10px 20px;
-          cursor: pointer;
-          border: 1px solid #ccc;
-          border-bottom: none;
-          background: #f9f9f9;
-          margin-right: 5px;
-          border-top-left-radius: 8px;
-          border-top-right-radius: 8px;
-      }
-      .tab.active {
-          background: white;
-          font-weight: bold;
-      }
-      .help-btn {
-          margin-left: auto;
-          padding: 6px 12px;
-          font-size: 0.9rem;
-          border: 1px solid #4CAF50;
-          border-radius: 4px;
-          background: #4CAF50;
-          color: white;
-          cursor: pointer;
-      }
-      .tab-content {
-          display: none;
-          padding: 20px;
-          border: 1px solid #ccc;
-          border-top: none;
-          background: #fff;
-      }
-      .tab-content.active {
-          display: block;
-      }
-
-      /* Modal (used by get_metrics_help_modal) */
-      .modal {
-          display: none;
-          position: fixed;
-          z-index: 9999;
-          left: 0; top: 0;
-          width: 100%; height: 100%;
-          overflow: auto;
-          background-color: rgba(0,0,0,0.4);
-      }
-      .modal-content {
-          background-color: #fefefe;
-          margin: 8% auto;
-          padding: 20px;
-          border: 1px solid #888;
-          width: 90%;
-          max-width: 900px;
-          border-radius: 8px;
-      }
-      .modal .close {
-          color: #777;
-          float: right;
-          font-size: 28px;
-          font-weight: bold;
-          line-height: 1;
-          margin-left: 8px;
-      }
-      .modal .close:hover,
-      .modal .close:focus {
-          color: black;
-          text-decoration: none;
-          cursor: pointer;
-      }
-      .metrics-guide h3 { margin-top: 20px; }
-      .metrics-guide p { margin: 6px 0; }
-      .metrics-guide ul { margin: 10px 0; padding-left: 20px; }
-    </style>
-
-    <script>
-      // Guard to avoid double-initialization if this block is included twice
-      (function(){
-        if (window.__perfSummarySortInit) return;
-        window.__perfSummarySortInit = true;
-
-        function initPerfSummarySorting() {
-          // Record original order for "back to original"
-          document.querySelectorAll('table.performance-summary tbody').forEach(tbody => {
-            Array.from(tbody.rows).forEach((row, i) => { row.dataset.originalOrder = i; });
-          });
-
-          const getText = td => (td?.innerText || '').trim();
-          const cmp = (idx, asc) => (a, b) => {
-            const v1 = getText(a.children[idx]);
-            const v2 = getText(b.children[idx]);
-            const n1 = parseFloat(v1), n2 = parseFloat(v2);
-            if (!isNaN(n1) && !isNaN(n2)) return asc ? n1 - n2 : n2 - n1;  // numeric
-            return asc ? v1.localeCompare(v2) : v2.localeCompare(v1);      // lexical
-          };
-
-          document.querySelectorAll('table.performance-summary th.sortable').forEach(th => {
-            // initialize to "none"
-            th.classList.remove('sorted-asc','sorted-desc');
-            th.classList.add('sorted-none');
-
-            th.addEventListener('click', () => {
-              const table = th.closest('table');
-              const headerRow = th.parentNode;
-              const allTh = headerRow.querySelectorAll('th.sortable');
-              const tbody = table.querySelector('tbody');
-
-              // Determine current state BEFORE clearing
-              const isAsc = th.classList.contains('sorted-asc');
-              const isDesc = th.classList.contains('sorted-desc');
-
-              // Reset all headers in this row
-              allTh.forEach(x => x.classList.remove('sorted-asc','sorted-desc','sorted-none'));
-
-              // Compute next state
-              let next;
-              if (!isAsc && !isDesc) {
-                next = 'asc';
-              } else if (isAsc) {
-                next = 'desc';
-              } else {
-                next = 'none';
-              }
-              th.classList.add('sorted-' + next);
-
-              // Sort rows according to the chosen state
-              const rows = Array.from(tbody.rows);
-              if (next === 'none') {
-                rows.sort((a, b) => (a.dataset.originalOrder - b.dataset.originalOrder));
-              } else {
-                const idx = Array.from(headerRow.children).indexOf(th);
-                rows.sort(cmp(idx, next === 'asc'));
-              }
-              rows.forEach(r => tbody.appendChild(r));
-            });
-          });
-        }
+
+    def get_last_value(stats, key):
+        val = stats.get(key)
+        if isinstance(val, list) and val:
+            return val[-1]
+        elif isinstance(val, (int, float)):
+            return val
+        return None
-
-        // Run after DOM is ready
-        if (document.readyState === 'loading') {
-          document.addEventListener('DOMContentLoaded', initPerfSummarySorting);
-        } else {
-          initPerfSummarySorting();
-        }
-      })();
-    </script>
-</head>
-<body>
-    <div class="container">
-"""
-
-
-def get_html_closing():
-    """Closes .container, body, and html."""
-    return """
-    </div>
-</body>
-</html>
-"""
-
-
-def encode_image_to_base64(image_path: str) -> str:
-    """Convert an image file to a base64 encoded string."""
-    with open(image_path, "rb") as img_file:
-        return base64.b64encode(img_file.read()).decode("utf-8")
-
-
-def json_to_nested_html_table(json_data, depth: int = 0) -> str:
-    """
-    Convert a JSON-able object to an HTML nested table.
-    Renders dicts as two-column tables (key/value) and lists as index/value rows.
-    """
-    # Base case: flat dict (no nested dict/list values)
-    if isinstance(json_data, dict) and all(
-        not isinstance(v, (dict, list)) for v in json_data.values()
-    ):
-        rows = [
-            f"<tr><th>{key}</th><td>{value}</td></tr>"
-            for key, value in json_data.items()
-        ]
-        return f"<table>{''.join(rows)}</table>"
-
-    # Base case: list of simple values
-    if isinstance(json_data, list) and all(
-        not isinstance(v, (dict, list)) for v in json_data
-    ):
-        rows = [
-            f"<tr><th>Index {i}</th><td>{value}</td></tr>"
-            for i, value in enumerate(json_data)
-        ]
-        return f"<table>{''.join(rows)}</table>"
-
-    # Recursive cases
-    if isinstance(json_data, dict):
-        rows = [
-            (
-                f"<tr><th style='text-align:left;padding-left:{depth * 20}px;'>{key}</th>"
-                f"<td>{json_to_nested_html_table(value, depth + 1)}</td></tr>"
-            )
-            for key, value in json_data.items()
-        ]
-        return f"<table>{''.join(rows)}</table>"
-
-    if isinstance(json_data, list):
-        rows = [
-            (
-                f"<tr><th style='text-align:left;padding-left:{depth * 20}px;'>[{i}]</th>"
-                f"<td>{json_to_nested_html_table(value, depth + 1)}</td></tr>"
-            )
-            for i, value in enumerate(json_data)
-        ]
-        return f"<table>{''.join(rows)}</table>"
-
-    # Primitive
-    return f"{json_data}"
-
-
-def json_to_html_table(json_data) -> str:
-    """
-    Convert JSON (dict or string) into a vertically oriented HTML table.
-    """
-    if isinstance(json_data, str):
-        json_data = json.loads(json_data)
-    return json_to_nested_html_table(json_data)
-
-
-def build_tabbed_html(metrics_html: str, train_val_html: str, test_html: str) -> str:
-    """
-    Build a 3-tab interface:
-      - Config and Results Summary
-      - Train/Validation Results
-      - Test Results
-    Includes a persistent "Help" button that toggles the metrics modal.
-    """
-    return f"""
-<div class="tabs">
-  <div class="tab active" onclick="showTab('metrics')">Config and Results Summary</div>
-  <div class="tab" onclick="showTab('trainval')">Train/Validation Results</div>
-  <div class="tab" onclick="showTab('test')">Test Results</div>
-  <button id="openMetricsHelp" class="help-btn" title="Open metrics help">Help</button>
-</div>
-
-<div id="metrics" class="tab-content active">
-  {metrics_html}
-</div>
-<div id="trainval" class="tab-content">
-  {train_val_html}
-</div>
-<div id="test" class="tab-content">
-  {test_html}
-</div>
-
-<script>
-  function showTab(id) {{
-    document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active'));
-    document.querySelectorAll('.tab').forEach(el => el.classList.remove('active'));
-    document.getElementById(id).classList.add('active');
-    // find tab with matching onclick target
-    document.querySelectorAll('.tab').forEach(t => {{
-      if (t.getAttribute('onclick') && t.getAttribute('onclick').includes(id)) {{
-        t.classList.add('active');
-      }}
-    }});
-  }}
-</script>
-"""
+
+    for split in ["training", "validation"]:
+        split_stats = train_stats.get(split, {})
+        if not split_stats:
+            logger.warning("No statistics found for %s split", split)
+            continue
+        label_stats = split_stats.get("label", {})
+        if not label_stats:
+            logger.warning("No label statistics found for %s split", split)
+            continue
+        if output_type == "binary":
+            metrics[split] = {
+                "accuracy": get_last_value(label_stats, "accuracy"),
+                "loss": get_last_value(label_stats, "loss"),
+                "precision": get_last_value(label_stats, "precision"),
+                "recall": get_last_value(label_stats, "recall"),
+                "specificity": get_last_value(label_stats, "specificity"),
+                "roc_auc": get_last_value(label_stats, "roc_auc"),
+            }
+        elif output_type == "regression":
+            metrics[split] = {
+                "loss": get_last_value(label_stats, "loss"),
+                "mean_absolute_error": get_last_value(
+                    label_stats, "mean_absolute_error"
+                ),
+                "mean_absolute_percentage_error": get_last_value(
+                    label_stats, "mean_absolute_percentage_error"
+                ),
+                "mean_squared_error": get_last_value(label_stats, "mean_squared_error"),
+                "root_mean_squared_error": get_last_value(
+                    label_stats, "root_mean_squared_error"
+                ),
+                "root_mean_squared_percentage_error": get_last_value(
+                    label_stats, "root_mean_squared_percentage_error"
+                ),
+                "r2": get_last_value(label_stats, "r2"),
+            }
+        else:
+            metrics[split] = {
+                "accuracy": get_last_value(label_stats, "accuracy"),
+                "accuracy_micro": get_last_value(label_stats, "accuracy_micro"),
+                "loss": get_last_value(label_stats, "loss"),
+                "roc_auc": get_last_value(label_stats, "roc_auc"),
+                "hits_at_k": get_last_value(label_stats, "hits_at_k"),
+            }
-
-
-def get_metrics_help_modal() -> str:
-    """
-    Returns a ready-to-use modal with a comprehensive metrics guide and
-    the small script that wires the "Help" button to open/close the modal.
-    """
-    modal_html = (
-        '<div id="metricsHelpModal" class="modal">'
-        '  <div class="modal-content">'
-        '    <span class="close">×</span>'
-        "    <h2>Model Evaluation Metrics — Help Guide</h2>"
-        '    <div class="metrics-guide">'
-        '      <h3>1) General Metrics (Regression and Classification)</h3>'
-        '      <p><strong>Loss (Regression & Classification):</strong> '
-        'Measures the difference between predicted and actual values, '
-        'optimized during training. Lower is better. '
-        'For regression, this is often Mean Squared Error (MSE) or '
-        'Mean Absolute Error (MAE). For classification, it\'s typically '
-        'cross-entropy or log loss.</p>'
-        '      <h3>2) Regression Metrics</h3>'
-        '      <p><strong>Mean Absolute Error (MAE):</strong> '
-        'Average of absolute differences between predicted and actual values, '
-        'in the same units as the target. Use for interpretable error measurement '
-        'when all errors are equally important. Less sensitive to outliers than MSE.</p>'
-        '      <p><strong>Mean Squared Error (MSE):</strong> '
-        'Average of squared differences between predicted and actual values. '
-        'Penalizes larger errors more heavily, useful when large deviations are critical. '
-        'Often used as the loss function in regression.</p>'
-        '      <p><strong>Root Mean Squared Error (RMSE):</strong> '
-        'Square root of MSE, in the same units as the target. '
-        'Balances interpretability and sensitivity to large errors. '
-        'Widely used for regression evaluation.</p>'
-        '      <p><strong>Mean Absolute Percentage Error (MAPE):</strong> '
-        'Average absolute error as a percentage of actual values. '
-        'Scale-independent, ideal for comparing relative errors across datasets. '
-        'Avoid when actual values are near zero.</p>'
-        '      <p><strong>Root Mean Squared Percentage Error (RMSPE):</strong> '
-        'Square root of mean squared percentage error. Scale-independent, '
-        'penalizes larger relative errors more than MAPE. Use for forecasting '
-        'or when relative accuracy matters.</p>'
-        '      <p><strong>R² Score:</strong> Proportion of variance in the target '
-        'explained by the model. Ranges from negative infinity to 1 (perfect prediction). '
-        'Use to assess model fit; negative values indicate poor performance '
-        'compared to predicting the mean.</p>'
-        '      <h3>3) Classification Metrics</h3>'
-        '      <p><strong>Accuracy:</strong> Proportion of correct predictions '
-        'among all predictions. Simple but misleading for imbalanced datasets, '
-        'where high accuracy may hide poor performance on minority classes.</p>'
-        '      <p><strong>Micro Accuracy:</strong> Sums true positives and true negatives '
-        'across all classes before computing accuracy. Suitable for multiclass or '
-        'multilabel problems with imbalanced data.</p>'
-        '      <p><strong>Token Accuracy:</strong> Measures how often predicted tokens '
-        '(e.g., in sequences) match true tokens. Common in NLP tasks like text generation '
-        'or token classification.</p>'
-        '      <p><strong>Precision:</strong> Proportion of positive predictions that are '
-        'correct (TP / (TP + FP)). Use when false positives are costly, e.g., spam detection.</p>'
-        '      <p><strong>Recall (Sensitivity):</strong> Proportion of actual positives '
-        'correctly predicted (TP / (TP + FN)). Use when missing positives is risky, '
-        'e.g., disease detection.</p>'
-        '      <p><strong>Specificity:</strong> True negative rate (TN / (TN + FP)). '
-        'Measures ability to identify negatives. Useful in medical testing to avoid '
-        'false alarms.</p>'
-        '      <h3>4) Classification: Macro, Micro, and Weighted Averages</h3>'
-        '      <p><strong>Macro Precision / Recall / F1:</strong> Averages the metric '
-        'across all classes, treating each equally. Best for balanced datasets where '
-        'all classes are equally important.</p>'
-        '      <p><strong>Micro Precision / Recall / F1:</strong> Aggregates true positives, '
-        'false positives, and false negatives across all classes before computing. '
-        'Ideal for imbalanced or multilabel classification.</p>'
-        '      <p><strong>Weighted Precision / Recall / F1:</strong> Averages metrics '
-        'across classes, weighted by the number of true instances per class. Balances '
-        'class importance based on frequency.</p>'
-        '      <h3>5) Classification: Average Precision (PR-AUC Variants)</h3>'
-        '      <p><strong>Average Precision Macro:</strong> Precision-Recall AUC averaged '
-        'equally across classes. Use for balanced multiclass problems.</p>'
-        '      <p><strong>Average Precision Micro:</strong> Global Precision-Recall AUC '
-        'using all instances. Best for imbalanced or multilabel classification.</p>'
-        '      <p><strong>Average Precision Samples:</strong> Precision-Recall AUC averaged '
-        'across individual samples. Ideal for multilabel tasks where samples have multiple '
-        'labels.</p>'
-        '      <h3>6) Classification: ROC-AUC Variants</h3>'
-        '      <p><strong>ROC-AUC:</strong> Measures ability to distinguish between classes. '
-        'AUC = 1 is perfect; 0.5 is random guessing. Use for binary classification.</p>'
-        '      <p><strong>Macro ROC-AUC:</strong> Averages AUC across all classes equally. '
-        'Suitable for balanced multiclass problems.</p>'
-        '      <p><strong>Micro ROC-AUC:</strong> Computes AUC from aggregated predictions '
-        'across all classes. Useful for imbalanced or multilabel settings.</p>'
-        '      <h3>7) Classification: Confusion Matrix Stats (Per Class)</h3>'
-        '      <p><strong>True Positives / Negatives (TP / TN):</strong> Correct predictions '
-        'for positives and negatives, respectively.</p>'
-        '      <p><strong>False Positives / Negatives (FP / FN):</strong> Incorrect predictions '
-        '— false alarms and missed detections.</p>'
-        '      <h3>8) Classification: Ranking Metrics</h3>'
-        '      <p><strong>Hits at K:</strong> Measures whether the true label is among the '
-        'top-K predictions. Common in recommendation systems and retrieval tasks.</p>'
-        '      <h3>9) Other Metrics (Classification)</h3>'
-        '      <p><strong>Cohen\'s Kappa:</strong> Measures agreement between predicted and '
-        'actual labels, adjusted for chance. Useful for multiclass classification with '
-        'imbalanced data.</p>'
-        '      <p><strong>Matthews Correlation Coefficient (MCC):</strong> Balanced measure '
-        'using TP, TN, FP, and FN. Effective for imbalanced datasets.</p>'
-        '      <h3>10) Metric Recommendations</h3>'
-        '      <ul>'
-        '        <li><strong>Regression:</strong> Use <strong>RMSE</strong> or '
-        '<strong>MAE</strong> for general evaluation, <strong>MAPE</strong> for relative '
-        'errors, and <strong>R²</strong> to assess model fit. Use <strong>MSE</strong> or '
-        '<strong>RMSPE</strong> when large errors are critical.</li>'
-        '        <li><strong>Classification (Balanced Data):</strong> Use <strong>Accuracy</strong> '
-        'and <strong>F1</strong> for overall performance.</li>'
-        '        <li><strong>Classification (Imbalanced Data):</strong> Use <strong>Precision</strong>, '
-        '<strong>Recall</strong>, and <strong>ROC-AUC</strong> to focus on minority class '
-        'performance.</li>'
-        '        <li><strong>Multilabel or Imbalanced Classification:</strong> Use '
-        '<strong>Micro Precision/Recall/F1</strong> or <strong>Micro ROC-AUC</strong>.</li>'
-        '        <li><strong>Balanced Multiclass:</strong> Use <strong>Macro Precision/Recall/F1</strong> '
-        'or <strong>Macro ROC-AUC</strong>.</li>'
-        '        <li><strong>Class Frequency Matters:</strong> Use <strong>Weighted Precision/Recall/F1</strong> '
-        'to account for class imbalance.</li>'
-        '        <li><strong>Recommendation/Ranking:</strong> Use <strong>Hits at K</strong> for retrieval tasks.</li>'
-        '        <li><strong>Detailed Analysis:</strong> Use <strong>Confusion Matrix stats</strong> '
-        'for class-wise performance in classification.</li>'
-        '      </ul>'
-        '    </div>'
-        '  </div>'
-        '</div>'
-    )
+
+    # Test metrics: dynamic extraction according to exclusions
+    test_label_stats = test_stats.get("label", {})
+    if not test_label_stats:
+        logger.warning("No label statistics found for test split")
+    else:
+        combined_stats = test_stats.get("combined", {})
+        overall_stats = test_label_stats.get("overall_stats", {})
+
+        # Define exclusions
+        if output_type == "binary":
+            exclude = {"per_class_stats", "precision_recall_curve", "roc_curve"}
+        else:
+            exclude = {"per_class_stats", "confusion_matrix"}
-
-    modal_js = (
-        "<script>"
-        "document.addEventListener('DOMContentLoaded', function() {"
-        "  var modal = document.getElementById('metricsHelpModal');"
-        "  var openBtn = document.getElementById('openMetricsHelp');"
-        "  var closeBtn = modal ? modal.querySelector('.close') : null;"
-        "  if (openBtn && modal) {"
-        "    openBtn.addEventListener('click', function(){ modal.style.display = 'block'; });"
-        "  }"
-        "  if (closeBtn && modal) {"
-        "    closeBtn.addEventListener('click', function(){ modal.style.display = 'none'; });"
-        "  }"
-        "  window.addEventListener('click', function(ev){"
-        "    if (ev.target === modal) { modal.style.display = 'none'; }"
-        "  });"
-        "});"
-        "</script>"
-    )
-    return modal_html + modal_js
+        # 1. Get all scalar test_label_stats not excluded
+        test_metrics = {}
+        for k, v in test_label_stats.items():
+            if k in exclude:
+                continue
+            if k == "overall_stats":
+                continue
+            if isinstance(v, (int, float, str, bool)):
+                test_metrics[k] = v
+
+        # 2. Add overall_stats (flattened)
+        for k, v in overall_stats.items():
+            test_metrics[k] = v
+
+        # 3. Optionally include combined/loss if present and not already
+        if "loss" in combined_stats and "loss" not in test_metrics:
+            test_metrics["loss"] = combined_stats["loss"]
+        metrics["test"] = test_metrics
+    return metrics
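
For reviewers trying the new helpers out of context, a minimal sketch of `load_metadata_table`: `sep=None` with the python engine lets pandas sniff the delimiter, so the same call accepts comma- or tab-separated metadata. This assumes only that this revision's `utils.py` is importable; `metadata.tsv` is a hypothetical file created for the demo.

```python
# Minimal sketch (assumes this revision's utils.py is on the import path).
from pathlib import Path

from utils import load_metadata_table

meta = Path("metadata.tsv")  # hypothetical file, written here just for illustration
meta.write_text("image_path\tlabel\nimg_0001.png\t1\nimg_0002.png\t0\n")

df = load_metadata_table(meta)   # delimiter is sniffed, so .csv would work too
print(df.columns.tolist())       # ['image_path', 'label']
```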
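Likewise, a short sketch of the two parsing helpers; the input strings below are illustrative, not taken from the tool's XML.

```python
# Minimal sketch (assumes this revision's utils.py is importable).
from utils import aug_parse, parse_learning_rate

# Comma-separated keys become Ludwig-style augmentation dicts.
augs = aug_parse("random_horizontal_flip, random_rotate")
print(augs)  # [{'type': 'random_horizontal_flip'}, {'type': 'random_rotate', 'degree': 10}]

# Unknown keys raise ValueError, which argument_checker surfaces via parser.error().
try:
    aug_parse("random_crop")  # not in the supported mapping
except ValueError as err:
    print(err)  # message lists the valid choices

# parse_learning_rate returns None on unparseable input instead of raising.
print(parse_learning_rate("0.001"))  # 0.001
print(parse_learning_rate("auto"))   # None
```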
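Finally, a hedged sketch of how `detect_output_type` and `extract_metrics_from_json` compose. The `train_stats`/`test_stats` dicts below are hand-made stand-ins shaped like Ludwig's statistics JSON, containing only the keys these functions read; they are not real training output.

```python
# Minimal sketch (assumes this revision's utils.py is importable; data is illustrative).
from utils import detect_output_type, extract_metrics_from_json

test_stats = {
    "label": {
        "accuracy": 0.90,
        "per_class_stats": {"0": {}, "1": {}},   # exactly two classes -> "binary"
        "overall_stats": {"kappa_score": 0.80},  # flattened into the test metrics
    },
    "combined": {"loss": 0.30},
}
train_stats = {
    "training": {"label": {"accuracy": [0.70, 0.90], "loss": [0.60, 0.40]}},
    "validation": {"label": {"accuracy": [0.65, 0.85], "loss": [0.70, 0.50]}},
}

output_type = detect_output_type(test_stats)  # "binary"
metrics = extract_metrics_from_json(train_stats, test_stats, output_type)
print(metrics["training"]["accuracy"])  # 0.9 -- the last epoch's value
print(metrics["test"])  # {'accuracy': 0.9, 'kappa_score': 0.8, 'loss': 0.3}
```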
