changeset 6:871957823d0c draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit 6e49ad44dd8572382ee203926690a30d7e888203
author goeckslab
date Mon, 26 Jan 2026 18:44:07 +0000
parents 975512caae22
children
files multimodal_learner.py multimodal_learner.xml training_pipeline.py utils.py
diffstat 4 files changed, 168 insertions(+), 1 deletions(-) [+]
line wrap: on
line diff
--- a/multimodal_learner.py	Fri Jan 23 23:06:00 2026 +0000
+++ b/multimodal_learner.py	Mon Jan 26 18:44:07 2026 +0000
@@ -63,6 +63,10 @@
     parser.add_argument("--epochs", type=int, default=None)
     parser.add_argument("--learning_rate", type=float, default=None)
     parser.add_argument("--batch_size", type=int, default=None)
+    parser.add_argument("--num_workers", type=int, default=None,
+                        help="DataLoader worker count (0 disables multiprocessing).")
+    parser.add_argument("--num_workers_eval", type=int, default=None,
+                        help="DataLoader workers for evaluation; defaults to --num_workers.")
     parser.add_argument("--backbone_image", type=str, default="swin_base_patch4_window7_224")
     parser.add_argument("--backbone_text", type=str, default="microsoft/deberta-v3-base")
     parser.add_argument("--validation_size", type=float, default=0.2)
@@ -369,6 +373,8 @@
         epochs=args.epochs,
         learning_rate=args.learning_rate,
         batch_size=args.batch_size,
+        num_workers=args.num_workers,
+        num_workers_evaluation=args.num_workers_eval,
         backbone_image=args.backbone_image,
         backbone_text=args.backbone_text,
         preset=args.preset,
--- a/multimodal_learner.xml	Fri Jan 23 23:06:00 2026 +0000
+++ b/multimodal_learner.xml	Mon Jan 26 18:44:07 2026 +0000
@@ -1,4 +1,4 @@
-<tool id="multimodal_learner" name="Multimodal Learner" version="0.1.2" profile="22.01">
+<tool id="multimodal_learner" name="Multimodal Learner" version="0.1.3" profile="22.01">
   <description>Train and evaluate an AutoGluon Multimodal model (tabular + image + text)</description>
 
   <requirements>
--- a/training_pipeline.py	Fri Jan 23 23:06:00 2026 +0000
+++ b/training_pipeline.py	Mon Jan 26 18:44:07 2026 +0000
@@ -20,6 +20,53 @@
 
 logger = logging.getLogger(__name__)
 
+_LOW_SHM_BYTES = 1 << 30  # 1 GiB
+
+
+def _get_env_int(keys: List[str]) -> Optional[int]:
+    for key in keys:
+        if key not in os.environ:
+            continue
+        raw = os.environ.get(key)
+        try:
+            return int(raw)
+        except (TypeError, ValueError):
+            logger.warning("Ignoring non-integer %s=%s", key, raw)
+    return None
+
+
+def _get_shm_bytes() -> Optional[int]:
+    try:
+        stat = os.statvfs("/dev/shm")
+    except Exception:
+        return None
+    return int(stat.f_frsize * stat.f_blocks)
+
+
+def _resolve_num_workers(
+    explicit_value: Optional[int],
+    env_keys: List[str],
+    label: str,
+    shm_bytes: Optional[int],
+    default_value: Optional[int] = None,
+) -> Optional[int]:
+    if explicit_value is not None:
+        return int(explicit_value)
+    env_val = _get_env_int(env_keys)
+    if env_val is not None:
+        return env_val
+    if shm_bytes is not None and shm_bytes < _LOW_SHM_BYTES:
+        logger.warning(
+            "Detected small /dev/shm (%.1f MB); setting %s num_workers=0 to avoid DataLoader shm errors.",
+            shm_bytes / (1024 * 1024),
+            label,
+        )
+        return 0
+    if default_value is not None:
+        logger.info("Using default %s num_workers=%d (heuristic).", label, int(default_value))
+        return int(default_value)
+    return None
+
 # ---------------------- small utilities ----------------------
 
 
@@ -390,6 +437,8 @@
     epochs,
     learning_rate,
     batch_size,
+    num_workers,
+    num_workers_evaluation,
     backbone_image,
     backbone_text,
     preset,
@@ -421,6 +470,37 @@
         env_cfg["seed"] = int(random_seed)
     if batch_size is not None:
         env_cfg["per_gpu_batch_size"] = int(batch_size)
+    shm_bytes = _get_shm_bytes()
+    default_workers = None
+    if shm_bytes is None or shm_bytes >= _LOW_SHM_BYTES:
+        cpu_count = os.cpu_count() or 1
+        default_workers = max(1, min(8, cpu_count // 2))
+    resolved_num_workers = _resolve_num_workers(
+        num_workers,
+        ["AG_MM_NUM_WORKERS", "AG_NUM_WORKERS", "AUTOMM_NUM_WORKERS"],
+        "training",
+        shm_bytes,
+        default_value=default_workers,
+    )
+    resolved_num_workers_inference = _resolve_num_workers(
+        num_workers_evaluation,
+        [
+            "AG_MM_NUM_WORKERS_INFERENCE",
+            "AG_MM_NUM_WORKERS_EVAL",
+            "AG_MM_NUM_WORKERS_EVALUATION",
+            "AUTOMM_NUM_WORKERS_EVAL",
+        ],
+        "inference",
+        shm_bytes,
+        default_value=default_workers,
+    )
+    if resolved_num_workers_inference is None and resolved_num_workers is not None:
+        resolved_num_workers_inference = resolved_num_workers
+    if resolved_num_workers is not None:
+        env_cfg["num_workers"] = int(resolved_num_workers)
+    if resolved_num_workers_inference is not None:
+        key = "num_workers_inference"
+        env_cfg[key] = int(resolved_num_workers_inference)
 
     optim_cfg = {}
     if epochs is not None:
@@ -463,6 +543,10 @@
         hp["optim.per_device_train_batch_size"] = bs_val
         hp["optim.batch_size"] = bs_val
         hp["env.per_gpu_batch_size"] = bs_val
+    if resolved_num_workers is not None:
+        hp["env.num_workers"] = int(resolved_num_workers)
+    if resolved_num_workers_inference is not None:
+        hp[f"env.{key}"] = int(resolved_num_workers_inference)
     if backbone_image:
         hp["model.timm_image.checkpoint_name"] = str(backbone_image)
     if backbone_text:
--- a/utils.py	Fri Jan 23 23:06:00 2026 +0000
+++ b/utils.py	Mon Jan 26 18:44:07 2026 +0000
@@ -30,6 +30,7 @@
 _MAX_EXTRACTED_INDEX_CACHE_SIZE = 2
 _MAX_EXTRACTED_INDEX_FILES = 100000
 _EXTRACTED_INDEX_CACHE = OrderedDict()
+_EXTRACTED_PATH_CACHE = OrderedDict()
 
 
 def str2bool(val) -> bool:
@@ -140,6 +141,33 @@
     return index
 
 
+def _build_extracted_maps(extracted_root: Optional[Path]) -> tuple[dict, dict]:
+    if extracted_root is None:
+        return {}, {}
+    rel_map: dict[str, str] = {}
+    name_map: dict[str, str] = {}
+    name_collisions = set()
+    count = 0
+    for root, _dirs, files in os.walk(extracted_root):
+        rel_root = os.path.relpath(root, extracted_root)
+        for fname in files:
+            ext = os.path.splitext(fname)[1].lower()
+            if ext not in _IMAGE_EXTENSIONS:
+                continue
+            count += 1
+            rel_path = fname if rel_root == "." else os.path.join(rel_root, fname)
+            rel_norm = rel_path.replace("\\", "/")
+            abs_path = os.path.join(root, fname)
+            rel_map[rel_norm] = abs_path
+            if fname in name_map and name_map[fname] != abs_path:
+                name_collisions.add(fname)
+            else:
+                name_map[fname] = abs_path
+    for name in name_collisions:
+        name_map.pop(name, None)
+    return rel_map, name_map
+
+
 def _get_cached_extracted_index(extracted_root: Optional[Path]) -> set:
     if extracted_root is None:
         return set()
@@ -175,6 +203,37 @@
     return index
 
 
+def _get_cached_extracted_maps(extracted_root: Optional[Path]) -> tuple[dict, dict]:
+    if extracted_root is None:
+        return {}, {}
+    try:
+        root = extracted_root.resolve()
+    except Exception:
+        root = extracted_root
+    cache_key = str(root)
+    try:
+        mtime_ns = root.stat().st_mtime_ns
+    except OSError:
+        _EXTRACTED_PATH_CACHE.pop(cache_key, None)
+        return _build_extracted_maps(root)
+    cached = _EXTRACTED_PATH_CACHE.get(cache_key)
+    if cached:
+        cached_mtime, rel_map, name_map = cached
+        if cached_mtime == mtime_ns:
+            _EXTRACTED_PATH_CACHE.move_to_end(cache_key)
+            LOG.debug("Using cached extracted path map for %s (%d entries)", root, len(rel_map))
+            return rel_map, name_map
+        _EXTRACTED_PATH_CACHE.pop(cache_key, None)
+        LOG.debug("Invalidated extracted path map cache for %s (mtime changed)", root)
+    rel_map, name_map = _build_extracted_maps(root)
+    if rel_map and len(rel_map) <= _MAX_EXTRACTED_INDEX_FILES:
+        _EXTRACTED_PATH_CACHE[cache_key] = (mtime_ns, rel_map, name_map)
+        _EXTRACTED_PATH_CACHE.move_to_end(cache_key)
+        while len(_EXTRACTED_PATH_CACHE) > _MAX_EXTRACTED_INDEX_CACHE_SIZE:
+            _EXTRACTED_PATH_CACHE.popitem(last=False)
+    return rel_map, name_map
+
+
 def prepare_image_search_dirs(args) -> Optional[Path]:
     if not args.images_zip:
         return None
@@ -204,6 +263,7 @@
 
     image_columns = [c for c in (image_columns or []) if c in df.columns]
     extracted_index = None
+    extracted_maps = None
 
     def get_extracted_index() -> set:
         nonlocal extracted_index
@@ -211,6 +271,12 @@
             extracted_index = _get_cached_extracted_index(extracted_root)
         return extracted_index
 
+    def get_extracted_maps() -> tuple[dict, dict]:
+        nonlocal extracted_maps
+        if extracted_maps is None:
+            extracted_maps = _get_cached_extracted_maps(extracted_root)
+        return extracted_maps
+
     def resolve(p):
         if pd.isna(p):
             return None
@@ -232,6 +298,17 @@
                 if e.errno == errno.ENAMETOOLONG:
                     LOG.warning("Path too long for filesystem: %s", cand)
                 continue
+        if extracted_root is not None:
+            rel_map, name_map = get_extracted_maps()
+            if rel_map:
+                norm = raw.replace("\\", "/").lstrip("./")
+                mapped = rel_map.get(norm)
+                if mapped:
+                    return str(Path(mapped).resolve())
+                base = Path(norm).name
+                mapped = name_map.get(base)
+                if mapped:
+                    return str(Path(mapped).resolve())
         return None
 
     def matches_extracted(p) -> bool: