changeset 6:871957823d0c (draft, default branch, tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit 6e49ad44dd8572382ee203926690a30d7e888203
| field | value |
|---|---|
| author | goeckslab |
| date | Mon, 26 Jan 2026 18:44:07 +0000 |
| parents | 975512caae22 |
| children | (none) |
| files | multimodal_learner.py, multimodal_learner.xml, training_pipeline.py, utils.py |
| diffstat | 4 files changed, 168 insertions(+), 1 deletion(-) |
line diff
```diff
--- a/multimodal_learner.py	Fri Jan 23 23:06:00 2026 +0000
+++ b/multimodal_learner.py	Mon Jan 26 18:44:07 2026 +0000
@@ -63,6 +63,10 @@
     parser.add_argument("--epochs", type=int, default=None)
     parser.add_argument("--learning_rate", type=float, default=None)
     parser.add_argument("--batch_size", type=int, default=None)
+    parser.add_argument("--num_workers", type=int, default=None,
+                        help="DataLoader worker count (0 disables multiprocessing).")
+    parser.add_argument("--num_workers_eval", type=int, default=None,
+                        help="DataLoader workers for evaluation; defaults to --num_workers.")
     parser.add_argument("--backbone_image", type=str, default="swin_base_patch4_window7_224")
     parser.add_argument("--backbone_text", type=str, default="microsoft/deberta-v3-base")
     parser.add_argument("--validation_size", type=float, default=0.2)
@@ -369,6 +373,8 @@
         epochs=args.epochs,
         learning_rate=args.learning_rate,
         batch_size=args.batch_size,
+        num_workers=args.num_workers,
+        num_workers_evaluation=args.num_workers_eval,
         backbone_image=args.backbone_image,
         backbone_text=args.backbone_text,
         preset=args.preset,
```
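For orientation, a minimal self-contained sketch of how the two new flags behave at the argparse level (only the options touched by this changeset are reproduced; the real script defines many more arguments):

```python
import argparse

# Stand-in parser containing only the two options added in this changeset.
parser = argparse.ArgumentParser()
parser.add_argument("--num_workers", type=int, default=None,
                    help="DataLoader worker count (0 disables multiprocessing).")
parser.add_argument("--num_workers_eval", type=int, default=None,
                    help="DataLoader workers for evaluation; defaults to --num_workers.")

args = parser.parse_args(["--num_workers", "4"])
assert args.num_workers == 4
# Left unset here; training_pipeline.py later falls back to the training value.
assert args.num_workers_eval is None
```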
```diff
--- a/multimodal_learner.xml	Fri Jan 23 23:06:00 2026 +0000
+++ b/multimodal_learner.xml	Mon Jan 26 18:44:07 2026 +0000
@@ -1,4 +1,4 @@
-<tool id="multimodal_learner" name="Multimodal Learner" version="0.1.2" profile="22.01">
+<tool id="multimodal_learner" name="Multimodal Learner" version="0.1.3" profile="22.01">
     <description>Train and evaluate an AutoGluon Multimodal model (tabular + image + text)</description>
     <requirements>
```
```diff
--- a/training_pipeline.py	Fri Jan 23 23:06:00 2026 +0000
+++ b/training_pipeline.py	Mon Jan 26 18:44:07 2026 +0000
@@ -20,6 +20,53 @@
 logger = logging.getLogger(__name__)
 
+_LOW_SHM_BYTES = 1 << 30  # 1 GiB
+
+
+def _get_env_int(keys: List[str]) -> Optional[int]:
+    for key in keys:
+        if key not in os.environ:
+            continue
+        raw = os.environ.get(key)
+        try:
+            return int(raw)
+        except (TypeError, ValueError):
+            logger.warning("Ignoring non-integer %s=%s", key, raw)
+    return None
+
+
+def _get_shm_bytes() -> Optional[int]:
+    try:
+        stat = os.statvfs("/dev/shm")
+    except Exception:
+        return None
+    return int(stat.f_frsize * stat.f_blocks)
+
+
+def _resolve_num_workers(
+    explicit_value: Optional[int],
+    env_keys: List[str],
+    label: str,
+    shm_bytes: Optional[int],
+    default_value: Optional[int] = None,
+) -> Optional[int]:
+    if explicit_value is not None:
+        return int(explicit_value)
+    env_val = _get_env_int(env_keys)
+    if env_val is not None:
+        return env_val
+    if shm_bytes is not None and shm_bytes < _LOW_SHM_BYTES:
+        logger.warning(
+            "Detected small /dev/shm (%.1f MB); setting %s num_workers=0 to avoid DataLoader shm errors.",
+            shm_bytes / (1024 * 1024),
+            label,
+        )
+        return 0
+    if default_value is not None:
+        logger.info("Using default %s num_workers=%d (heuristic).", label, int(default_value))
+        return int(default_value)
+    return None
+
 
 # ---------------------- small utilities ----------------------
@@ -390,6 +437,8 @@
     epochs,
     learning_rate,
     batch_size,
+    num_workers,
+    num_workers_evaluation,
     backbone_image,
     backbone_text,
     preset,
@@ -421,6 +470,37 @@
         env_cfg["seed"] = int(random_seed)
     if batch_size is not None:
         env_cfg["per_gpu_batch_size"] = int(batch_size)
+    shm_bytes = _get_shm_bytes()
+    default_workers = None
+    if shm_bytes is None or shm_bytes >= _LOW_SHM_BYTES:
+        cpu_count = os.cpu_count() or 1
+        default_workers = max(1, min(8, cpu_count // 2))
+    resolved_num_workers = _resolve_num_workers(
+        num_workers,
+        ["AG_MM_NUM_WORKERS", "AG_NUM_WORKERS", "AUTOMM_NUM_WORKERS"],
+        "training",
+        shm_bytes,
+        default_value=default_workers,
+    )
+    resolved_num_workers_inference = _resolve_num_workers(
+        num_workers_evaluation,
+        [
+            "AG_MM_NUM_WORKERS_INFERENCE",
+            "AG_MM_NUM_WORKERS_EVAL",
+            "AG_MM_NUM_WORKERS_EVALUATION",
+            "AUTOMM_NUM_WORKERS_EVAL",
+        ],
+        "inference",
+        shm_bytes,
+        default_value=default_workers,
+    )
+    if resolved_num_workers_inference is None and resolved_num_workers is not None:
+        resolved_num_workers_inference = resolved_num_workers
+    if resolved_num_workers is not None:
+        env_cfg["num_workers"] = int(resolved_num_workers)
+    if resolved_num_workers_inference is not None:
+        key = "num_workers_inference"
+        env_cfg[key] = int(resolved_num_workers_inference)
 
     optim_cfg = {}
     if epochs is not None:
@@ -463,6 +543,10 @@
         hp["optim.per_device_train_batch_size"] = bs_val
         hp["optim.batch_size"] = bs_val
         hp["env.per_gpu_batch_size"] = bs_val
+    if resolved_num_workers is not None:
+        hp["env.num_workers"] = int(resolved_num_workers)
+    if resolved_num_workers_inference is not None:
+        hp[f"env.{key}"] = int(resolved_num_workers_inference)
     if backbone_image:
         hp["model.timm_image.checkpoint_name"] = str(backbone_image)
     if backbone_text:
```
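The new helpers establish a clear precedence for the worker count: an explicit CLI value wins, then the first environment variable that parses as an integer, then a forced 0 when /dev/shm is smaller than 1 GiB, and finally the CPU-derived default. The standalone sketch below restates that precedence for illustration; it is not imported from the module, and the env-var values are invented:

```python
import os

def resolve_workers(cli_value, env_keys, shm_bytes, default_value):
    """Simplified restatement of the precedence in _resolve_num_workers."""
    if cli_value is not None:                      # 1. explicit CLI flag wins
        return int(cli_value)
    for key in env_keys:                           # 2. first integer env var
        raw = os.environ.get(key)
        if raw is not None:
            try:
                return int(raw)
            except ValueError:
                pass                               # non-integers are skipped
    if shm_bytes is not None and shm_bytes < (1 << 30):
        return 0                                   # 3. small /dev/shm forces 0
    return default_value                           # 4. CPU-based heuristic

os.environ["AG_MM_NUM_WORKERS"] = "2"
assert resolve_workers(None, ["AG_MM_NUM_WORKERS"], shm_bytes=2 << 30, default_value=4) == 2
assert resolve_workers(6, ["AG_MM_NUM_WORKERS"], shm_bytes=2 << 30, default_value=4) == 6
assert resolve_workers(None, ["MISSING"], shm_bytes=512 << 20, default_value=4) == 0
```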
```diff
--- a/utils.py	Fri Jan 23 23:06:00 2026 +0000
+++ b/utils.py	Mon Jan 26 18:44:07 2026 +0000
@@ -30,6 +30,7 @@
 _MAX_EXTRACTED_INDEX_CACHE_SIZE = 2
 _MAX_EXTRACTED_INDEX_FILES = 100000
 _EXTRACTED_INDEX_CACHE = OrderedDict()
+_EXTRACTED_PATH_CACHE = OrderedDict()
 
 
 def str2bool(val) -> bool:
@@ -140,6 +141,33 @@
     return index
 
 
+def _build_extracted_maps(extracted_root: Optional[Path]) -> tuple[dict, dict]:
+    if extracted_root is None:
+        return {}, {}
+    rel_map: dict[str, str] = {}
+    name_map: dict[str, str] = {}
+    name_collisions = set()
+    count = 0
+    for root, _dirs, files in os.walk(extracted_root):
+        rel_root = os.path.relpath(root, extracted_root)
+        for fname in files:
+            ext = os.path.splitext(fname)[1].lower()
+            if ext not in _IMAGE_EXTENSIONS:
+                continue
+            count += 1
+            rel_path = fname if rel_root == "." else os.path.join(rel_root, fname)
+            rel_norm = rel_path.replace("\\", "/")
+            abs_path = os.path.join(root, fname)
+            rel_map[rel_norm] = abs_path
+            if fname in name_map and name_map[fname] != abs_path:
+                name_collisions.add(fname)
+            else:
+                name_map[fname] = abs_path
+    for name in name_collisions:
+        name_map.pop(name, None)
+    return rel_map, name_map
+
+
 def _get_cached_extracted_index(extracted_root: Optional[Path]) -> set:
     if extracted_root is None:
         return set()
@@ -175,6 +203,37 @@
     return index
 
 
+def _get_cached_extracted_maps(extracted_root: Optional[Path]) -> tuple[dict, dict]:
+    if extracted_root is None:
+        return {}, {}
+    try:
+        root = extracted_root.resolve()
+    except Exception:
+        root = extracted_root
+    cache_key = str(root)
+    try:
+        mtime_ns = root.stat().st_mtime_ns
+    except OSError:
+        _EXTRACTED_PATH_CACHE.pop(cache_key, None)
+        return _build_extracted_maps(root)
+    cached = _EXTRACTED_PATH_CACHE.get(cache_key)
+    if cached:
+        cached_mtime, rel_map, name_map = cached
+        if cached_mtime == mtime_ns:
+            _EXTRACTED_PATH_CACHE.move_to_end(cache_key)
+            LOG.debug("Using cached extracted path map for %s (%d entries)", root, len(rel_map))
+            return rel_map, name_map
+        _EXTRACTED_PATH_CACHE.pop(cache_key, None)
+        LOG.debug("Invalidated extracted path map cache for %s (mtime changed)", root)
+    rel_map, name_map = _build_extracted_maps(root)
+    if rel_map and len(rel_map) <= _MAX_EXTRACTED_INDEX_FILES:
+        _EXTRACTED_PATH_CACHE[cache_key] = (mtime_ns, rel_map, name_map)
+        _EXTRACTED_PATH_CACHE.move_to_end(cache_key)
+        while len(_EXTRACTED_PATH_CACHE) > _MAX_EXTRACTED_INDEX_CACHE_SIZE:
+            _EXTRACTED_PATH_CACHE.popitem(last=False)
+    return rel_map, name_map
+
+
 def prepare_image_search_dirs(args) -> Optional[Path]:
     if not args.images_zip:
         return None
@@ -204,6 +263,7 @@
     image_columns = [c for c in (image_columns or []) if c in df.columns]
 
     extracted_index = None
+    extracted_maps = None
 
     def get_extracted_index() -> set:
         nonlocal extracted_index
@@ -211,6 +271,12 @@
             extracted_index = _get_cached_extracted_index(extracted_root)
         return extracted_index
 
+    def get_extracted_maps() -> tuple[dict, dict]:
+        nonlocal extracted_maps
+        if extracted_maps is None:
+            extracted_maps = _get_cached_extracted_maps(extracted_root)
+        return extracted_maps
+
     def resolve(p):
         if pd.isna(p):
             return None
@@ -232,6 +298,17 @@
                 if e.errno == errno.ENAMETOOLONG:
                     LOG.warning("Path too long for filesystem: %s", cand)
                 continue
+        if extracted_root is not None:
+            rel_map, name_map = get_extracted_maps()
+            if rel_map:
+                norm = raw.replace("\\", "/").lstrip("./")
+                mapped = rel_map.get(norm)
+                if mapped:
+                    return str(Path(mapped).resolve())
+                base = Path(norm).name
+                mapped = name_map.get(base)
+                if mapped:
+                    return str(Path(mapped).resolve())
         return None
 
     def matches_extracted(p) -> bool:
```
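To make the utils.py change concrete: path resolution now falls back first to an exact archive-relative lookup, then to a basename lookup that only answers when the basename is unambiguous. A toy illustration follows; the directory layout and file names are invented for the example:

```python
from pathlib import PurePosixPath

# Hypothetical extracted-archive contents -> absolute paths.
rel_map = {
    "train/cat1.png": "/tmp/extracted/train/cat1.png",
    "val/cat1.png": "/tmp/extracted/val/cat1.png",
    "train/dog.png": "/tmp/extracted/train/dog.png",
}
# "cat1.png" appears under two directories, so it is dropped as ambiguous;
# "dog.png" is unique and kept in the basename map.
name_map = {"dog.png": "/tmp/extracted/train/dog.png"}

def lookup(raw: str):
    norm = raw.replace("\\", "/").lstrip("./")
    hit = rel_map.get(norm)                        # 1. exact relative path
    if hit:
        return hit
    return name_map.get(PurePosixPath(norm).name)  # 2. unique basename fallback

assert lookup("train/cat1.png") == "/tmp/extracted/train/cat1.png"
assert lookup("somewhere/else/dog.png") == "/tmp/extracted/train/dog.png"
assert lookup("cat1.png") is None  # ambiguous basename -> unresolved
```

Dropping colliding basenames from the map trades recall for safety: leaving a path unresolved is preferable to silently attaching the wrong image to a row.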
