changeset 3:25bb80df7c0c draft

planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
author goeckslab
date Sat, 17 Jan 2026 22:53:42 +0000
parents b708d0e210e6
children de753cf07008
files README.md multimodal_learner.py multimodal_learner.xml split_logic.py
diffstat 4 files changed, 256 insertions(+), 25 deletions(-)
--- a/README.md	Sat Jan 10 16:13:19 2026 +0000
+++ b/README.md	Sat Jan 17 22:53:42 2026 +0000
@@ -12,6 +12,7 @@
 ## Inputs
 - `Training dataset (CSV/TSV)`: includes the label column and any feature columns; image columns should contain file paths that exist in the provided ZIP archives (or absolute paths).
 - Optional `Test dataset (CSV/TSV)`: if omitted, the tool performs train/validation/test splitting or k-fold CV.
+- Optional `Sample ID column`: when provided, rows that share the same ID are kept together in one split and one cross-validation fold to reduce leakage (see the example after this list).
 - Optional `Image archive(s) (ZIP)`: one or more archives containing the image files referenced in the table.
 - Optional overrides: text and image backbones, evaluation metric, quality preset, threshold for binary tasks, and extra hyperparameters (JSON/YAML string or file path).
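+
+For example, selecting a `patient_id` column as the sample ID keeps all of a patient's rows in the same split or fold, so the model is not evaluated on patients it trained on.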
 
--- a/multimodal_learner.py	Sat Jan 10 16:13:19 2026 +0000
+++ b/multimodal_learner.py	Sat Jan 17 22:53:42 2026 +0000
@@ -68,6 +68,7 @@
     parser.add_argument("--validation_size", type=float, default=0.2)
     parser.add_argument("--split_probabilities", type=float, nargs=3,
                         default=[0.7, 0.1, 0.2], metavar=("train", "val", "test"))
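+    # Optional column that groups related rows (e.g. patient_id); used for
+    # leakage-aware train/val/test splits and group-aware CV folds.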
+    parser.add_argument("--sample_id_column", default=None)
     parser.add_argument("--preset", choices=["medium_quality", "high_quality", "best_quality"],
                         default="medium_quality")
     parser.add_argument("--eval_metric", default="roc_auc")
@@ -103,7 +104,35 @@
     except Exception:
         use_stratified = False
 
-    kf = StratifiedKFold(n_splits=int(args.num_folds), shuffle=True, random_state=int(args.random_seed)) if use_stratified else KFold(n_splits=int(args.num_folds), shuffle=True, random_state=int(args.random_seed))
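+    # Group-aware CV: when a valid sample ID column is present, rows sharing
+    # an ID are kept within a single fold so related records never straddle
+    # train and validation.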
+    if args.sample_id_column and args.sample_id_column in df_full.columns:
+        groups = df_full[args.sample_id_column]
+        if use_stratified:
+            try:
+                from sklearn.model_selection import StratifiedGroupKFold
+
+                kf = StratifiedGroupKFold(
+                    n_splits=int(args.num_folds),
+                    shuffle=True,
+                    random_state=int(args.random_seed),
+                )
+            except Exception as exc:
+                logger.warning(
+                    "StratifiedGroupKFold unavailable (%s); falling back to GroupKFold.",
+                    exc,
+                )
+                from sklearn.model_selection import GroupKFold
+
+                kf = GroupKFold(n_splits=int(args.num_folds))
+                use_stratified = False
+        else:
+            from sklearn.model_selection import GroupKFold
+
+            kf = GroupKFold(n_splits=int(args.num_folds))
+    else:
+        cv_cls = StratifiedKFold if use_stratified else KFold
+        kf = cv_cls(n_splits=int(args.num_folds), shuffle=True, random_state=int(args.random_seed))
+
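+    # The sample ID is bookkeeping only; drop it from the external test set so
+    # it is never treated as a model feature.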
+    if args.sample_id_column and test_dataset is not None:
+        test_dataset = test_dataset.drop(columns=[args.sample_id_column], errors="ignore")
 
     raw_folds = []
     ag_folds = []
@@ -111,10 +140,18 @@
     last_predictor = None
     last_data_ctx = None
 
-    for fold_idx, (train_idx, val_idx) in enumerate(kf.split(df_full, y if use_stratified else None), start=1):
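+    # (Stratified)GroupKFold.split requires the groups array; the plain
+    # (Stratified)KFold splitters do not take one.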
+    if args.sample_id_column and args.sample_id_column in df_full.columns:
+        split_iter = kf.split(df_full, y if use_stratified else None, groups)
+    else:
+        split_iter = kf.split(df_full, y if use_stratified else None)
+
+    for fold_idx, (train_idx, val_idx) in enumerate(split_iter, start=1):
         logger.info(f"CV fold {fold_idx}/{args.num_folds}")
         df_tr = df_full.iloc[train_idx].copy()
         df_va = df_full.iloc[val_idx].copy()
+        if args.sample_id_column:
+            df_tr = df_tr.drop(columns=[args.sample_id_column], errors="ignore")
+            df_va = df_va.drop(columns=[args.sample_id_column], errors="ignore")
 
         df_tr["split"] = "train"
         df_va["split"] = "val"
@@ -252,6 +289,7 @@
         split_probabilities=args.split_probabilities,
         validation_size=args.validation_size,
         random_seed=args.random_seed,
+        sample_id_column=args.sample_id_column,
     )
 
     logger.info("Preprocessing complete — ready for AutoGluon training!")
@@ -335,6 +373,11 @@
             "fit_summary": None,
         }
     else:
+        # Drop sample-id column before training so it does not leak into modeling.
+        if args.sample_id_column:
+            train_dataset = train_dataset.drop(columns=[args.sample_id_column], errors="ignore")
+            if test_dataset is not None:
+                test_dataset = test_dataset.drop(columns=[args.sample_id_column], errors="ignore")
         predictor, data_ctx = run_autogluon_experiment(
             train_dataset=train_dataset,
             test_dataset=test_dataset,
--- a/multimodal_learner.xml	Sat Jan 10 16:13:19 2026 +0000
+++ b/multimodal_learner.xml	Sat Jan 17 22:53:42 2026 +0000
@@ -1,4 +1,4 @@
-<tool id="multimodal_learner" name="Multimodal Learner" version="0.1.1" profile="22.01">
+<tool id="multimodal_learner" name="Multimodal Learner" version="0.1.2" profile="22.01">
   <description>Train and evaluate an AutoGluon Multimodal model (tabular + image + text)</description>
 
   <requirements>
@@ -43,12 +43,15 @@
 ln -sf '$test_dataset_conditional.input_test' 'test_input.csv';
 #end if
 
-python '$__tool_directory__/multimodal_learner.py'
+  python '$__tool_directory__/multimodal_learner.py'
   --input_csv_train 'train_input.csv'
   #if $test_dataset_conditional.has_test_dataset == "yes"
   --input_csv_test 'test_input.csv'
   #end if
   --target_column '$target_column'
+  #if $sample_id_selector.use_sample_id == "yes"
+  --sample_id_column '$sample_id_selector.sample_id_column'
+  #end if
 
   #if $use_images_conditional.use_images == "yes"
     #if $images_zip_cli
@@ -111,6 +114,16 @@
   <inputs>
     <param name="input_csv" type="data" format="csv,tsv" label="Training dataset (CSV/TSV)" help="Must contain the target column and optional image paths"/>
     <param name="target_column" type="data_column" data_ref="input_csv" numerical="false" use_header_names="true" label="Target / Label column"/>
+    <conditional name="sample_id_selector">
+      <param name="use_sample_id" type="select" label="Use a sample ID column for leakage-aware splitting?" help="Select yes to choose a column that groups related records (e.g., patient_id or slide_id).">
+        <option value="no" selected="true">No column selected</option>
+        <option value="yes">Yes</option>
+      </param>
+      <when value="yes">
+        <param name="sample_id_column" type="data_column" data_ref="input_csv" use_header_names="true" label="Sample ID column" help="All rows with the same ID stay in the same split or fold to reduce leakage. Used for internal train/val/test splits and group-aware CV folds." />
+      </when>
+      <when value="no"/>
+    </conditional>
 
     <conditional name="test_dataset_conditional">
       <param name="has_test_dataset" type="boolean" truevalue="yes" falsevalue="no" checked="false" label="Provide separate test dataset?"/>
--- a/split_logic.py	Sat Jan 10 16:13:19 2026 +0000
+++ b/split_logic.py	Sat Jan 17 22:53:42 2026 +0000
@@ -1,6 +1,7 @@
 import logging
 from typing import List, Optional
 
+import numpy as np
 import pandas as pd
 from sklearn.model_selection import train_test_split
 
@@ -19,6 +20,7 @@
     split_probabilities: List[float],
     validation_size: float,
     random_seed: int = 42,
+    sample_id_column: Optional[str] = None,
 ) -> None:
     if target_column not in train_dataset.columns:
         raise ValueError(f"Target column '{target_column}' not found")
@@ -32,24 +34,136 @@
         logger.warning(f"Dropped {before - len(train_dataset)} rows with NaN target")
     y = train_dataset[target_column]
 
+    if sample_id_column and sample_id_column not in train_dataset.columns:
+        logger.warning(
+            "Sample ID column '%s' not found; proceeding without group-aware split.",
+            sample_id_column,
+        )
+        sample_id_column = None
+    if sample_id_column and sample_id_column == target_column:
+        logger.warning(
+            "Sample ID column '%s' matches target column; proceeding without group-aware split.",
+            sample_id_column,
+        )
+        sample_id_column = None
+
     # Respect existing valid split column
     if SPLIT_COL in train_dataset.columns:
         unique = set(train_dataset[SPLIT_COL].dropna().unique())
         valid = {"train", "val", "validation", "test"}
-        if unique.issubset(valid | {"validation"}):
+        if unique.issubset(valid):
             train_dataset[SPLIT_COL] = train_dataset[SPLIT_COL].replace("validation", "val")
-            logger.info(f"Using pre-existing 'split' column: {sorted(unique)}")
+            normalized = set(train_dataset[SPLIT_COL].dropna().unique())
+            required = {"train"} if test_dataset is not None else {"train", "test"}
+            missing = required - normalized
+            if missing:
+                missing_list = ", ".join(sorted(missing))
+                if test_dataset is not None:
+                    raise ValueError(
+                        "Pre-existing 'split' column is missing required split(s): "
+                        f"{missing_list}. Expected at least train when an external test set is provided, "
+                        "or remove the 'split' column to let the tool create splits."
+                    )
+                raise ValueError(
+                    "Pre-existing 'split' column is missing required split(s): "
+                    f"{missing_list}. Expected at least train and test, "
+                    "or remove the 'split' column to let the tool create splits."
+                )
+            logger.info(f"Using pre-existing 'split' column: {sorted(normalized)}")
             return
 
     train_dataset[SPLIT_COL] = "train"
 
+    def _allocate_split_counts(n_total: int, probs: list) -> list:
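+        # Largest-remainder allocation: every split with a nonzero probability
+        # gets at least one row when possible, then the floor of its
+        # proportional share; leftover rows go to the largest fractional parts.
+        # Example: n_total=10, probs=[0.7, 0.1, 0.2] -> seed [1, 1, 1],
+        # remainder 7 -> raw [4.9, 0.7, 1.4] -> [5, 1, 2] after floors, and the
+        # 2 leftovers go to the largest fractions (0.9, 0.7) -> [6, 2, 2].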
+        if n_total <= 0:
+            return [0 for _ in probs]
+        counts = [0 for _ in probs]
+        active = [i for i, p in enumerate(probs) if p > 0]
+        remainder = n_total
+        if active and n_total >= len(active):
+            for i in active:
+                counts[i] = 1
+            remainder -= len(active)
+        if remainder > 0 and active:
+            probs_arr = np.array(probs, dtype=float)
+            probs_arr = probs_arr / probs_arr.sum()
+            raw = remainder * probs_arr
+            floors = np.floor(raw).astype(int)
+            for i, value in enumerate(floors.tolist()):
+                counts[i] += value
+            leftover = remainder - int(floors.sum())
+            if leftover > 0 and active:
+                frac = raw - floors
+                order = sorted(active, key=lambda i: (-frac[i], i))
+                for i in range(leftover):
+                    counts[order[i % len(order)]] += 1
+        return counts
+
+    def _choose_split(counts: list, targets: list, active: list) -> int:
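+        # Greedy choice: send the next group to the split with the largest
+        # remaining deficit; once every target is met, top up the smallest split.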
+        remaining = [targets[i] - counts[i] for i in range(len(targets))]
+        best = max(active, key=lambda i: (remaining[i], -counts[i], -targets[i]))
+        if remaining[best] <= 0:
+            best = min(active, key=lambda i: counts[i])
+        return best
+
     if test_dataset is not None:
-        stratify = y if _can_stratify(y) else None
-        train_idx, val_idx = train_test_split(
-            train_dataset.index, test_size=validation_size,
-            random_state=random_seed, stratify=stratify
-        )
-        train_dataset.loc[val_idx, SPLIT_COL] = "val"
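+        # Overlapping IDs between train and an external test set are a leakage
+        # risk; the tool warns but does not abort.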
+        if sample_id_column and sample_id_column in test_dataset.columns:
+            train_ids = set(train_dataset[sample_id_column].dropna().astype(object).unique())
+            test_ids = set(test_dataset[sample_id_column].dropna().astype(object).unique())
+            overlap = train_ids & test_ids
+            if overlap:
+                logger.warning(
+                    "Sample ID column '%s' has %d overlapping IDs between train and external test sets; "
+                    "consider removing overlaps to avoid leakage.",
+                    sample_id_column,
+                    len(overlap),
+                )
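+        # Group-wise two-way split: shuffle group IDs with a fixed seed, then
+        # greedily assign whole groups toward the train/val row targets.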
+        if sample_id_column:
+            rng = np.random.RandomState(random_seed)
+            group_series = train_dataset[sample_id_column].astype(object)
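+            # Rows with a missing ID each get a unique placeholder group so
+            # they are split independently rather than clumped together.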
+            missing_mask = group_series.isna()
+            if missing_mask.any():
+                group_series = group_series.copy()
+                group_series.loc[missing_mask] = [
+                    f"__missing__{idx}" for idx in group_series.index[missing_mask]
+                ]
+            group_to_indices = {}
+            for idx, group_id in group_series.items():
+                group_to_indices.setdefault(group_id, []).append(idx)
+            group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
+            rng.shuffle(group_ids)
+            targets = _allocate_split_counts(len(train_dataset), [1.0 - validation_size, validation_size])
+            counts = [0, 0]
+            active = [0, 1]
+            train_idx = []
+            val_idx = []
+            for group_id in group_ids:
+                split_idx = _choose_split(counts, targets, active)
+                counts[split_idx] += len(group_to_indices[group_id])
+                if split_idx == 0:
+                    train_idx.extend(group_to_indices[group_id])
+                else:
+                    val_idx.extend(group_to_indices[group_id])
+            train_dataset.loc[val_idx, SPLIT_COL] = "val"
+        else:
+            if validation_size <= 0:
+                logger.warning(
+                    "validation_size is %.3f; skipping validation split to avoid train_test_split errors.",
+                    validation_size,
+                )
+            elif validation_size >= 1:
+                logger.warning(
+                    "validation_size is %.3f; assigning all rows to validation to avoid train_test_split errors.",
+                    validation_size,
+                )
+                train_dataset[SPLIT_COL] = "val"
+            else:
+                stratify = y if _can_stratify(y) else None
+                train_idx, val_idx = train_test_split(
+                    train_dataset.index, test_size=validation_size,
+                    random_state=random_seed, stratify=stratify
+                )
+                train_dataset.loc[val_idx, SPLIT_COL] = "val"
         logger.info(f"External test set → created val split ({validation_size:.0%})")
 
     else:
@@ -57,20 +171,80 @@
         if abs(p_train + p_val + p_test - 1.0) > 1e-6:
             raise ValueError("split_probabilities must sum to 1.0")
 
-        stratify = y if _can_stratify(y) else None
-        tv_idx, test_idx = train_test_split(
-            train_dataset.index, test_size=p_test,
-            random_state=random_seed, stratify=stratify
-        )
-        rel_val = p_val / (p_train + p_val) if (p_train + p_val) > 0 else 0
-        strat_tv = y.loc[tv_idx] if _can_stratify(y.loc[tv_idx]) else None
-        train_idx, val_idx = train_test_split(
-            tv_idx, test_size=rel_val,
-            random_state=random_seed, stratify=strat_tv
-        )
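+        # Same group-wise strategy as the two-way case, extended to three
+        # splits (train/val/test) with targets from split_probabilities.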
+        if sample_id_column:
+            rng = np.random.RandomState(random_seed)
+            group_series = train_dataset[sample_id_column].astype(object)
+            missing_mask = group_series.isna()
+            if missing_mask.any():
+                group_series = group_series.copy()
+                group_series.loc[missing_mask] = [
+                    f"__missing__{idx}" for idx in group_series.index[missing_mask]
+                ]
+            group_to_indices = {}
+            for idx, group_id in group_series.items():
+                group_to_indices.setdefault(group_id, []).append(idx)
+            group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
+            rng.shuffle(group_ids)
+            targets = _allocate_split_counts(len(train_dataset), [p_train, p_val, p_test])
+            counts = [0, 0, 0]
+            active = [0, 1, 2]
+            train_idx = []
+            val_idx = []
+            test_idx = []
+            for group_id in group_ids:
+                split_idx = _choose_split(counts, targets, active)
+                counts[split_idx] += len(group_to_indices[group_id])
+                if split_idx == 0:
+                    train_idx.extend(group_to_indices[group_id])
+                elif split_idx == 1:
+                    val_idx.extend(group_to_indices[group_id])
+                else:
+                    test_idx.extend(group_to_indices[group_id])
+            train_dataset.loc[val_idx, SPLIT_COL] = "val"
+            train_dataset.loc[test_idx, SPLIT_COL] = "test"
+        else:
+            stratify = y if _can_stratify(y) else None
+            if p_test <= 0:
+                logger.warning(
+                    "split_probabilities specify 0 test size; skipping test split to avoid train_test_split errors."
+                )
+                tv_idx = train_dataset.index
+                test_idx = train_dataset.index[:0]
+            elif p_test >= 1:
+                logger.warning(
+                    "split_probabilities specify 100% test size; assigning all rows to test to avoid train_test_split errors."
+                )
+                tv_idx = train_dataset.index[:0]
+                test_idx = train_dataset.index
+            else:
+                tv_idx, test_idx = train_test_split(
+                    train_dataset.index, test_size=p_test,
+                    random_state=random_seed, stratify=stratify
+                )
+            rel_val = p_val / (p_train + p_val) if (p_train + p_val) > 0 else 0
+            train_idx = train_dataset.index[:0]
+            val_idx = train_dataset.index[:0]
+            if len(tv_idx):
+                if rel_val <= 0:
+                    logger.warning(
+                        "split_probabilities specify 0 validation size; skipping validation split to avoid train_test_split errors."
+                    )
+                    train_idx = tv_idx
+                elif rel_val >= 1:
+                    logger.warning(
+                        "split_probabilities specify 100% validation size; assigning all remaining rows to validation "
+                        "to avoid train_test_split errors."
+                    )
+                    val_idx = tv_idx
+                else:
+                    strat_tv = y.loc[tv_idx] if _can_stratify(y.loc[tv_idx]) else None
+                    train_idx, val_idx = train_test_split(
+                        tv_idx, test_size=rel_val,
+                        random_state=random_seed, stratify=strat_tv
+                    )
 
-        train_dataset.loc[val_idx, SPLIT_COL] = "val"
-        train_dataset.loc[test_idx, SPLIT_COL] = "test"
+            train_dataset.loc[val_idx, SPLIT_COL] = "val"
+            train_dataset.loc[test_idx, SPLIT_COL] = "test"
         logger.info(f"3-way split → train:{len(train_idx)}, val:{len(val_idx)}, test:{len(test_idx)}")
 
     logger.info(f"Final split distribution:\n{train_dataset[SPLIT_COL].value_counts().sort_index()}")