diff utils.py @ 6:871957823d0c draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit 6e49ad44dd8572382ee203926690a30d7e888203
| author | goeckslab |
|---|---|
| date | Mon, 26 Jan 2026 18:44:07 +0000 |
| parents | b708d0e210e6 |
| children | |
--- a/utils.py	Fri Jan 23 23:06:00 2026 +0000
+++ b/utils.py	Mon Jan 26 18:44:07 2026 +0000
@@ -30,6 +30,7 @@
 _MAX_EXTRACTED_INDEX_CACHE_SIZE = 2
 _MAX_EXTRACTED_INDEX_FILES = 100000
 _EXTRACTED_INDEX_CACHE = OrderedDict()
+_EXTRACTED_PATH_CACHE = OrderedDict()
 
 
 def str2bool(val) -> bool:
@@ -140,6 +141,33 @@
     return index
 
 
+def _build_extracted_maps(extracted_root: Optional[Path]) -> tuple[dict, dict]:
+    if extracted_root is None:
+        return {}, {}
+    rel_map: dict[str, str] = {}
+    name_map: dict[str, str] = {}
+    name_collisions = set()
+    count = 0
+    for root, _dirs, files in os.walk(extracted_root):
+        rel_root = os.path.relpath(root, extracted_root)
+        for fname in files:
+            ext = os.path.splitext(fname)[1].lower()
+            if ext not in _IMAGE_EXTENSIONS:
+                continue
+            count += 1
+            rel_path = fname if rel_root == "." else os.path.join(rel_root, fname)
+            rel_norm = rel_path.replace("\\", "/")
+            abs_path = os.path.join(root, fname)
+            rel_map[rel_norm] = abs_path
+            if fname in name_map and name_map[fname] != abs_path:
+                name_collisions.add(fname)
+            else:
+                name_map[fname] = abs_path
+    for name in name_collisions:
+        name_map.pop(name, None)
+    return rel_map, name_map
+
+
 def _get_cached_extracted_index(extracted_root: Optional[Path]) -> set:
     if extracted_root is None:
         return set()
@@ -175,6 +203,37 @@
     return index
 
 
+def _get_cached_extracted_maps(extracted_root: Optional[Path]) -> tuple[dict, dict]:
+    if extracted_root is None:
+        return {}, {}
+    try:
+        root = extracted_root.resolve()
+    except Exception:
+        root = extracted_root
+    cache_key = str(root)
+    try:
+        mtime_ns = root.stat().st_mtime_ns
+    except OSError:
+        _EXTRACTED_PATH_CACHE.pop(cache_key, None)
+        return _build_extracted_maps(root)
+    cached = _EXTRACTED_PATH_CACHE.get(cache_key)
+    if cached:
+        cached_mtime, rel_map, name_map = cached
+        if cached_mtime == mtime_ns:
+            _EXTRACTED_PATH_CACHE.move_to_end(cache_key)
+            LOG.debug("Using cached extracted path map for %s (%d entries)", root, len(rel_map))
+            return rel_map, name_map
+        _EXTRACTED_PATH_CACHE.pop(cache_key, None)
+        LOG.debug("Invalidated extracted path map cache for %s (mtime changed)", root)
+    rel_map, name_map = _build_extracted_maps(root)
+    if rel_map and len(rel_map) <= _MAX_EXTRACTED_INDEX_FILES:
+        _EXTRACTED_PATH_CACHE[cache_key] = (mtime_ns, rel_map, name_map)
+        _EXTRACTED_PATH_CACHE.move_to_end(cache_key)
+        while len(_EXTRACTED_PATH_CACHE) > _MAX_EXTRACTED_INDEX_CACHE_SIZE:
+            _EXTRACTED_PATH_CACHE.popitem(last=False)
+    return rel_map, name_map
+
+
 def prepare_image_search_dirs(args) -> Optional[Path]:
     if not args.images_zip:
         return None
@@ -204,6 +263,7 @@
     image_columns = [c for c in (image_columns or []) if c in df.columns]
 
     extracted_index = None
+    extracted_maps = None
 
     def get_extracted_index() -> set:
         nonlocal extracted_index
@@ -211,6 +271,12 @@
         extracted_index = _get_cached_extracted_index(extracted_root)
         return extracted_index
 
+    def get_extracted_maps() -> tuple[dict, dict]:
+        nonlocal extracted_maps
+        if extracted_maps is None:
+            extracted_maps = _get_cached_extracted_maps(extracted_root)
+        return extracted_maps
+
     def resolve(p):
         if pd.isna(p):
             return None
@@ -232,6 +298,17 @@
                 if e.errno == errno.ENAMETOOLONG:
                     LOG.warning("Path too long for filesystem: %s", cand)
                 continue
+        if extracted_root is not None:
+            rel_map, name_map = get_extracted_maps()
+            if rel_map:
+                norm = raw.replace("\\", "/").lstrip("./")
+                mapped = rel_map.get(norm)
+                if mapped:
+                    return str(Path(mapped).resolve())
+                base = Path(norm).name
+                mapped = name_map.get(base)
+                if mapped:
+                    return str(Path(mapped).resolve())
         return None
 
     def matches_extracted(p) -> bool:
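The new fallback matches a raw image reference against the extracted ZIP contents in two steps: first by its path relative to the extraction root (backslashes normalized to `/`, leading `./` characters stripped), then by bare filename, which `_build_extracted_maps` only keeps when it points to exactly one extracted file. The snippet below is a minimal standalone sketch of that lookup order; the maps, paths, and the helper name `resolve_reference` are invented for illustration and are not part of `utils.py`.

```python
from pathlib import Path
from typing import Optional

# Hand-built stand-ins for the maps _build_extracted_maps() produces:
# normalized relative path -> absolute path, and basename -> absolute path
# (basenames that occur more than once are dropped from name_map).
rel_map = {
    "slides/case_01.png": "/tmp/extracted/slides/case_01.png",
    "slides/case_02.png": "/tmp/extracted/slides/case_02.png",
}
name_map = {
    "case_01.png": "/tmp/extracted/slides/case_01.png",
    "case_02.png": "/tmp/extracted/slides/case_02.png",
}


def resolve_reference(raw: str) -> Optional[str]:
    """Illustrative lookup order: relative path first, then unambiguous basename."""
    norm = raw.replace("\\", "/").lstrip("./")
    mapped = rel_map.get(norm)                   # 1) match the path relative to the ZIP root
    if mapped is None:
        mapped = name_map.get(Path(norm).name)   # 2) fall back to the bare filename
    return mapped


print(resolve_reference("./slides\\case_01.png"))  # /tmp/extracted/slides/case_01.png
print(resolve_reference("case_02.png"))            # /tmp/extracted/slides/case_02.png
print(resolve_reference("missing.png"))            # None
```

Because colliding basenames are removed from `name_map`, an ambiguous filename simply falls through to `None` instead of silently picking one of several candidate files.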
