Mercurial > repos > goeckslab > multimodal_learner
diff utils.py @ 2:b708d0e210e6 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit ffd47c4881aaa9fc33e7d3993a8fdf4bd82f3792
| author | goeckslab |
|---|---|
| date | Sat, 10 Jan 2026 16:13:19 +0000 |
| parents | 375c36923da1 |
| children |
line wrap: on
line diff
--- a/utils.py Fri Dec 19 05:12:47 2025 +0000 +++ b/utils.py Sat Jan 10 16:13:19 2026 +0000 @@ -1,3 +1,4 @@ +import errno import json import logging import os @@ -5,6 +6,7 @@ import sys import tempfile import zipfile +from collections import OrderedDict from pathlib import Path from typing import List, Optional @@ -13,6 +15,21 @@ import torch LOG = logging.getLogger(__name__) +_IMAGE_EXTENSIONS = { + ".jpg", + ".jpeg", + ".png", + ".bmp", + ".gif", + ".tif", + ".tiff", + ".webp", + ".svs", +} +_MAX_PATH_COMPONENT = 255 +_MAX_EXTRACTED_INDEX_CACHE_SIZE = 2 +_MAX_EXTRACTED_INDEX_FILES = 100000 +_EXTRACTED_INDEX_CACHE = OrderedDict() def str2bool(val) -> bool: @@ -89,6 +106,75 @@ return pd.read_csv(path, sep=None, engine="python") +def _normalize_path_value(val: object) -> Optional[str]: + if val is None: + return None + s = str(val).strip().strip('"').strip("'") + return s if s else None + + +def _warn_if_long_component(path_str: str) -> None: + for part in path_str.replace("\\", "/").split("/"): + if len(part) > _MAX_PATH_COMPONENT: + LOG.warning( + "Path component exceeds %d chars; resolution may fail: %s", + _MAX_PATH_COMPONENT, + path_str, + ) + return + + +def _build_extracted_index(extracted_root: Optional[Path]) -> set: + if extracted_root is None: + return set() + index = set() + for root, _dirs, files in os.walk(extracted_root): + rel_root = os.path.relpath(root, extracted_root) + for fname in files: + ext = os.path.splitext(fname)[1].lower() + if ext not in _IMAGE_EXTENSIONS: + continue + rel_path = fname if rel_root == "." else os.path.join(rel_root, fname) + index.add(rel_path.replace("\\", "/")) + index.add(fname) + return index + + +def _get_cached_extracted_index(extracted_root: Optional[Path]) -> set: + if extracted_root is None: + return set() + try: + root = extracted_root.resolve() + except Exception: + root = extracted_root + cache_key = str(root) + try: + mtime_ns = root.stat().st_mtime_ns + except OSError: + _EXTRACTED_INDEX_CACHE.pop(cache_key, None) + return _build_extracted_index(root) + cached = _EXTRACTED_INDEX_CACHE.get(cache_key) + if cached: + cached_mtime, cached_index = cached + if cached_mtime == mtime_ns: + _EXTRACTED_INDEX_CACHE.move_to_end(cache_key) + LOG.debug("Using cached extracted index for %s (%d entries)", root, len(cached_index)) + return cached_index + _EXTRACTED_INDEX_CACHE.pop(cache_key, None) + LOG.debug("Invalidated extracted index cache for %s (mtime changed)", root) + else: + LOG.debug("No extracted index cache for %s; building", root) + index = _build_extracted_index(root) + if len(index) <= _MAX_EXTRACTED_INDEX_FILES: + _EXTRACTED_INDEX_CACHE[cache_key] = (mtime_ns, index) + _EXTRACTED_INDEX_CACHE.move_to_end(cache_key) + while len(_EXTRACTED_INDEX_CACHE) > _MAX_EXTRACTED_INDEX_CACHE_SIZE: + _EXTRACTED_INDEX_CACHE.popitem(last=False) + else: + LOG.debug("Extracted index has %d entries; skipping cache for %s", len(index), root) + return index + + def prepare_image_search_dirs(args) -> Optional[Path]: if not args.images_zip: return None @@ -117,21 +203,50 @@ return [] image_columns = [c for c in (image_columns or []) if c in df.columns] + extracted_index = None + + def get_extracted_index() -> set: + nonlocal extracted_index + if extracted_index is None: + extracted_index = _get_cached_extracted_index(extracted_root) + return extracted_index def resolve(p): if pd.isna(p): return None - orig = Path(str(p).strip()) + raw = _normalize_path_value(p) + if not raw: + return None + _warn_if_long_component(raw) + orig = Path(raw) candidates = [] if orig.is_absolute(): candidates.append(orig) if extracted_root is not None: candidates.extend([extracted_root / orig, extracted_root / orig.name]) for cand in candidates: - if cand.exists(): - return str(cand.resolve()) + try: + if cand.exists(): + return str(cand.resolve()) + except OSError as e: + if e.errno == errno.ENAMETOOLONG: + LOG.warning("Path too long for filesystem: %s", cand) + continue return None + def matches_extracted(p) -> bool: + if pd.isna(p): + return False + raw = _normalize_path_value(p) + if not raw: + return False + _warn_if_long_component(raw) + index = get_extracted_index() + if not index: + return False + norm = raw.replace("\\", "/").lstrip("./") + return norm in index + # Infer image columns if none were provided if not image_columns: obj_cols = [c for c in df.columns if str(df[c].dtype) == "object"] @@ -140,6 +255,15 @@ sample = df[col].dropna().head(50) if sample.empty: continue + if extracted_root is not None: + index = get_extracted_index() + else: + index = set() + if index: + matched = sample.apply(matches_extracted) + if matched.any(): + inferred.append(col) + continue resolved_sample = sample.apply(resolve) if resolved_sample.notna().any(): inferred.append(col)
