utils.py @ 13:1a9c42974a5a (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9f96da4ea7ab3b572af86698ff51b870125cd674
| field | value |
|---|---|
| author | goeckslab |
| date | Fri, 21 Nov 2025 17:35:00 +0000 |
| parents | bcfa2e234a80 |
| children | |
import logging
from pathlib import Path

import pandas as pd

logger = logging.getLogger("ImageLearner")


def load_metadata_table(file_path: Path) -> pd.DataFrame:
    """Load image metadata allowing either CSV or TSV delimiters."""
    logger.info("Loading metadata table from %s", file_path)
    return pd.read_csv(file_path, sep=None, engine="python")


def detect_output_type(test_stats):
    """Detect whether the output type is 'regression', 'binary', or 'category' based on test statistics."""
    label_stats = test_stats.get("label", {})
    if "mean_squared_error" in label_stats:
        return "regression"
    per_class = label_stats.get("per_class_stats", {})
    if len(per_class) == 2:
        return "binary"
    return "category"


def aug_parse(aug_string: str):
    """
    Parse comma-separated augmentation keys into Ludwig augmentation dicts.
    Raises ValueError on unknown key.
    """
    mapping = {
        "random_horizontal_flip": {"type": "random_horizontal_flip"},
        "random_vertical_flip": {"type": "random_vertical_flip"},
        "random_rotate": {"type": "random_rotate", "degree": 10},
        "random_blur": {"type": "random_blur", "kernel_size": 3},
        "random_brightness": {"type": "random_brightness", "min": 0.5, "max": 2.0},
        "random_contrast": {"type": "random_contrast", "min": 0.5, "max": 2.0},
    }
    aug_list = []
    for tok in aug_string.split(","):
        key = tok.strip()
        if not key:
            continue
        if key not in mapping:
            valid = ", ".join(mapping.keys())
            raise ValueError(f"Unknown augmentation '{key}'. Valid choices: {valid}")
        aug_list.append(mapping[key])
    return aug_list


def argument_checker(args, parser):
    """Validate parsed CLI arguments, reporting problems through parser.error."""
    if not 0.0 <= args.validation_size <= 1.0:
        parser.error("validation-size must be between 0.0 and 1.0")
    if not args.csv_file.is_file():
        parser.error(f"Metadata file not found: {args.csv_file}")
    if not (args.image_zip.is_file() or args.image_zip.is_dir()):
        parser.error(f"ZIP or directory not found: {args.image_zip}")
    if args.augmentation is not None:
        try:
            augmentation_setup = aug_parse(args.augmentation)
            setattr(args, "augmentation", augmentation_setup)
        except ValueError as e:
            parser.error(str(e))


def parse_learning_rate(s):
    """Return the learning rate as a float, or None if the value cannot be parsed."""
    try:
        return float(s)
    except (TypeError, ValueError):
        return None


def extract_metrics_from_json(
    train_stats: dict,
    test_stats: dict,
    output_type: str,
) -> dict:
    """Extract relevant metrics from training and test statistics based on the output type."""
    metrics = {"training": {}, "validation": {}, "test": {}}

    def get_last_value(stats, key):
        val = stats.get(key)
        if isinstance(val, list) and val:
            return val[-1]
        elif isinstance(val, (int, float)):
            return val
        return None

    for split in ["training", "validation"]:
        split_stats = train_stats.get(split, {})
        if not split_stats:
            logger.warning("No statistics found for %s split", split)
            continue
        label_stats = split_stats.get("label", {})
        if not label_stats:
            logger.warning("No label statistics found for %s split", split)
            continue
        if output_type == "binary":
            metrics[split] = {
                "accuracy": get_last_value(label_stats, "accuracy"),
                "loss": get_last_value(label_stats, "loss"),
                "precision": get_last_value(label_stats, "precision"),
                "recall": get_last_value(label_stats, "recall"),
                "specificity": get_last_value(label_stats, "specificity"),
                "roc_auc": get_last_value(label_stats, "roc_auc"),
            }
        elif output_type == "regression":
            metrics[split] = {
                "loss": get_last_value(label_stats, "loss"),
                "mean_absolute_error": get_last_value(
                    label_stats, "mean_absolute_error"
                ),
                "mean_absolute_percentage_error": get_last_value(
                    label_stats, "mean_absolute_percentage_error"
                ),
                "mean_squared_error": get_last_value(label_stats, "mean_squared_error"),
                "root_mean_squared_error": get_last_value(
                    label_stats, "root_mean_squared_error"
                ),
                "root_mean_squared_percentage_error": get_last_value(
                    label_stats, "root_mean_squared_percentage_error"
                ),
                "r2": get_last_value(label_stats, "r2"),
            }
        else:
            metrics[split] = {
                "accuracy": get_last_value(label_stats, "accuracy"),
                "accuracy_micro": get_last_value(label_stats, "accuracy_micro"),
                "loss": get_last_value(label_stats, "loss"),
                "roc_auc": get_last_value(label_stats, "roc_auc"),
                "hits_at_k": get_last_value(label_stats, "hits_at_k"),
            }

    # Test metrics: dynamic extraction according to exclusions
    test_label_stats = test_stats.get("label", {})
    if not test_label_stats:
        logger.warning("No label statistics found for test split")
    else:
        combined_stats = test_stats.get("combined", {})
        overall_stats = test_label_stats.get("overall_stats", {})

        # Define exclusions
        if output_type == "binary":
            exclude = {"per_class_stats", "precision_recall_curve", "roc_curve"}
        else:
            exclude = {"per_class_stats", "confusion_matrix"}

        # 1. Get all scalar test_label_stats not excluded
        test_metrics = {}
        for k, v in test_label_stats.items():
            if k in exclude:
                continue
            if k == "overall_stats":
                continue
            if isinstance(v, (int, float, str, bool)):
                test_metrics[k] = v

        # 2. Add overall_stats (flattened)
        for k, v in overall_stats.items():
            test_metrics[k] = v

        # 3. Optionally include combined/loss if present and not already
        if "loss" in combined_stats and "loss" not in test_metrics:
            test_metrics["loss"] = combined_stats["loss"]

        metrics["test"] = test_metrics

    return metrics
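As an illustration of how the input-handling helpers fit together, the sketch below loads a metadata table, parses an augmentation string, and normalizes a learning-rate value. The file name, augmentation string, and the assumption that this module is importable as `utils` are illustrative, not part of the repository.

```python
from pathlib import Path

from utils import aug_parse, load_metadata_table, parse_learning_rate

# Hypothetical inputs; substitute your own metadata file and CLI values.
df = load_metadata_table(Path("metadata.csv"))  # delimiter (CSV or TSV) is sniffed
augs = aug_parse("random_horizontal_flip,random_rotate")
# -> [{"type": "random_horizontal_flip"}, {"type": "random_rotate", "degree": 10}]
lr = parse_learning_rate("0.001")  # -> 0.001; returns None for unparsable input
print(df.shape, augs, lr)
```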

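The metrics helpers can be exercised against hand-built approximations of Ludwig's statistics dictionaries. The sketch below uses made-up values and keys (for example `kappa_score`) purely to show the expected shape of the inputs and outputs; it is not real training output.

```python
from utils import detect_output_type, extract_metrics_from_json

# Hand-built stand-ins for Ludwig's test and train statistics.
test_stats = {
    "label": {
        "accuracy": 0.9,
        "roc_auc": 0.95,
        "per_class_stats": {"0": {}, "1": {}},   # two classes -> "binary"
        "overall_stats": {"kappa_score": 0.8},
    },
    "combined": {"loss": 0.3},
}
train_stats = {
    "training": {"label": {"accuracy": [0.7, 0.85], "loss": [0.6, 0.4]}},
    "validation": {"label": {"accuracy": [0.65, 0.8], "loss": [0.7, 0.5]}},
}

output_type = detect_output_type(test_stats)  # -> "binary"
metrics = extract_metrics_from_json(train_stats, test_stats, output_type)
# metrics["training"]["accuracy"] == 0.85 (value from the last epoch)
# metrics["test"] contains accuracy, roc_auc, kappa_score, and the combined loss
```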