diff split_logic.py @ 3:25bb80df7c0c draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
author goeckslab
date Sat, 17 Jan 2026 22:53:42 +0000
parents 375c36923da1
children
--- a/split_logic.py	Sat Jan 10 16:13:19 2026 +0000
+++ b/split_logic.py	Sat Jan 17 22:53:42 2026 +0000
@@ -1,6 +1,7 @@
 import logging
 from typing import List, Optional
 
+import numpy as np
 import pandas as pd
 from sklearn.model_selection import train_test_split
 
@@ -19,6 +20,7 @@
     split_probabilities: List[float],
     validation_size: float,
     random_seed: int = 42,
+    sample_id_column: Optional[str] = None,
 ) -> None:
     if target_column not in train_dataset.columns:
         raise ValueError(f"Target column '{target_column}' not found")
@@ -32,24 +34,136 @@
         logger.warning(f"Dropped {before - len(train_dataset)} rows with NaN target")
     y = train_dataset[target_column]
 
+    # Validate the optional sample ID column; fall back to a plain row-level split
+    # if it is absent or identical to the target.
+    if sample_id_column and sample_id_column not in train_dataset.columns:
+        logger.warning(
+            "Sample ID column '%s' not found; proceeding without a group-aware split.",
+            sample_id_column,
+        )
+        sample_id_column = None
+    if sample_id_column and sample_id_column == target_column:
+        logger.warning(
+            "Sample ID column '%s' matches the target column; proceeding without a group-aware split.",
+            sample_id_column,
+        )
+        sample_id_column = None
+
     # Respect existing valid split column
     if SPLIT_COL in train_dataset.columns:
         unique = set(train_dataset[SPLIT_COL].dropna().unique())
         valid = {"train", "val", "validation", "test"}
-        if unique.issubset(valid | {"validation"}):
+        if unique.issubset(valid):
             train_dataset[SPLIT_COL] = train_dataset[SPLIT_COL].replace("validation", "val")
-            logger.info(f"Using pre-existing 'split' column: {sorted(unique)}")
+            normalized = set(train_dataset[SPLIT_COL].dropna().unique())
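+            # Only 'train' is mandatory when an external test set is supplied;
+            # otherwise the pre-existing column must also define a 'test' split.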
+            required = {"train"} if test_dataset is not None else {"train", "test"}
+            missing = required - normalized
+            if missing:
+                missing_list = ", ".join(sorted(missing))
+                if test_dataset is not None:
+                    raise ValueError(
+                        "Pre-existing 'split' column is missing required split(s): "
+                        f"{missing_list}. Expected at least train when an external test set is provided, "
+                        "or remove the 'split' column to let the tool create splits."
+                    )
+                raise ValueError(
+                    "Pre-existing 'split' column is missing required split(s): "
+                    f"{missing_list}. Expected at least train and test, "
+                    "or remove the 'split' column to let the tool create splits."
+                )
+            logger.info(f"Using pre-existing 'split' column: {sorted(normalized)}")
             return
 
     train_dataset[SPLIT_COL] = "train"
 
+    def _allocate_split_counts(n_total: int, probs: list) -> list:
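+        # Apportion n_total rows across splits in proportion to probs: give each
+        # non-zero-probability split at least one row when enough rows exist, then
+        # hand out the rest by floored shares plus largest fractional remainders.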
+        if n_total <= 0:
+            return [0 for _ in probs]
+        counts = [0 for _ in probs]
+        active = [i for i, p in enumerate(probs) if p > 0]
+        remainder = n_total
+        if active and n_total >= len(active):
+            for i in active:
+                counts[i] = 1
+            remainder -= len(active)
+        if remainder > 0:
+            probs_arr = np.array(probs, dtype=float)
+            probs_arr = probs_arr / probs_arr.sum()
+            raw = remainder * probs_arr
+            floors = np.floor(raw).astype(int)
+            for i, value in enumerate(floors.tolist()):
+                counts[i] += value
+            leftover = remainder - int(floors.sum())
+            if leftover > 0 and active:
+                frac = raw - floors
+                order = sorted(active, key=lambda i: (-frac[i], i))
+                for i in range(leftover):
+                    counts[order[i % len(order)]] += 1
+        return counts
+
+    def _choose_split(counts: list, targets: list, active: list) -> int:
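+        # Pick the split that is furthest below its target count; once every target
+        # is met, fall back to the split with the fewest rows so sizes stay balanced.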
+        remaining = [targets[i] - counts[i] for i in range(len(targets))]
+        best = max(active, key=lambda i: (remaining[i], -counts[i], -targets[i]))
+        if remaining[best] <= 0:
+            best = min(active, key=lambda i: counts[i])
+        return best
+
     if test_dataset is not None:
-        stratify = y if _can_stratify(y) else None
-        train_idx, val_idx = train_test_split(
-            train_dataset.index, test_size=validation_size,
-            random_state=random_seed, stratify=stratify
-        )
-        train_dataset.loc[val_idx, SPLIT_COL] = "val"
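+        # An external test set was supplied, so only a validation split is carved
+        # out of the training data here.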
+        if sample_id_column and sample_id_column in test_dataset.columns:
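+            # Overlapping sample IDs between train and the external test set are a
+            # leakage risk; warn but do not fail.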
+            train_ids = set(train_dataset[sample_id_column].dropna().astype(object).unique())
+            test_ids = set(test_dataset[sample_id_column].dropna().astype(object).unique())
+            overlap = train_ids & test_ids
+            if overlap:
+                logger.warning(
+                    "Sample ID column '%s' has %d overlapping IDs between train and external test sets; "
+                    "consider removing overlaps to avoid leakage.",
+                    sample_id_column,
+                    len(overlap),
+                )
+        if sample_id_column:
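+            # Group-aware validation split: rows that share a sample ID stay together
+            # so the same subject never appears in both train and val.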
+            rng = np.random.RandomState(random_seed)
+            group_series = train_dataset[sample_id_column].astype(object)
+            missing_mask = group_series.isna()
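+            # Rows with a missing sample ID get unique placeholder groups so each is
+            # assigned independently rather than lumped into one shared group.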
+            if missing_mask.any():
+                group_series = group_series.copy()
+                group_series.loc[missing_mask] = [
+                    f"__missing__{idx}" for idx in group_series.index[missing_mask]
+                ]
+            group_to_indices = {}
+            for idx, group_id in group_series.items():
+                group_to_indices.setdefault(group_id, []).append(idx)
+            group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
+            rng.shuffle(group_ids)
+            targets = _allocate_split_counts(len(train_dataset), [1.0 - validation_size, validation_size])
+            counts = [0, 0]
+            active = [0, 1]
+            train_idx = []
+            val_idx = []
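+            # Hand out whole groups, always to the split furthest below its target size.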
+            for group_id in group_ids:
+                split_idx = _choose_split(counts, targets, active)
+                counts[split_idx] += len(group_to_indices[group_id])
+                if split_idx == 0:
+                    train_idx.extend(group_to_indices[group_id])
+                else:
+                    val_idx.extend(group_to_indices[group_id])
+            train_dataset.loc[val_idx, SPLIT_COL] = "val"
+        else:
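+            # No usable sample IDs: fall back to a plain row-level split, guarding
+            # the degenerate validation_size values that train_test_split rejects.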
+            if validation_size <= 0:
+                logger.warning(
+                    "validation_size is %.3f; skipping validation split to avoid train_test_split errors.",
+                    validation_size,
+                )
+            elif validation_size >= 1:
+                logger.warning(
+                    "validation_size is %.3f; assigning all rows to validation to avoid train_test_split errors.",
+                    validation_size,
+                )
+                train_dataset[SPLIT_COL] = "val"
+            else:
+                stratify = y if _can_stratify(y) else None
+                train_idx, val_idx = train_test_split(
+                    train_dataset.index, test_size=validation_size,
+                    random_state=random_seed, stratify=stratify
+                )
+                train_dataset.loc[val_idx, SPLIT_COL] = "val"
         logger.info(f"External test set → created val split ({validation_size:.0%})")
 
     else:
@@ -57,20 +171,80 @@
         if abs(p_train + p_val + p_test - 1.0) > 1e-6:
             raise ValueError("split_probabilities must sum to 1.0")
 
-        stratify = y if _can_stratify(y) else None
-        tv_idx, test_idx = train_test_split(
-            train_dataset.index, test_size=p_test,
-            random_state=random_seed, stratify=stratify
-        )
-        rel_val = p_val / (p_train + p_val) if (p_train + p_val) > 0 else 0
-        strat_tv = y.loc[tv_idx] if _can_stratify(y.loc[tv_idx]) else None
-        train_idx, val_idx = train_test_split(
-            tv_idx, test_size=rel_val,
-            random_state=random_seed, stratify=strat_tv
-        )
+        if sample_id_column:
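+            # Same group-aware strategy as the validation-only case above, now
+            # allocating whole groups across train/val/test targets.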
+            rng = np.random.RandomState(random_seed)
+            group_series = train_dataset[sample_id_column].astype(object)
+            missing_mask = group_series.isna()
+            if missing_mask.any():
+                group_series = group_series.copy()
+                group_series.loc[missing_mask] = [
+                    f"__missing__{idx}" for idx in group_series.index[missing_mask]
+                ]
+            group_to_indices = {}
+            for idx, group_id in group_series.items():
+                group_to_indices.setdefault(group_id, []).append(idx)
+            group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
+            rng.shuffle(group_ids)
+            targets = _allocate_split_counts(len(train_dataset), [p_train, p_val, p_test])
+            counts = [0, 0, 0]
+            active = [0, 1, 2]
+            train_idx = []
+            val_idx = []
+            test_idx = []
+            for group_id in group_ids:
+                split_idx = _choose_split(counts, targets, active)
+                counts[split_idx] += len(group_to_indices[group_id])
+                if split_idx == 0:
+                    train_idx.extend(group_to_indices[group_id])
+                elif split_idx == 1:
+                    val_idx.extend(group_to_indices[group_id])
+                else:
+                    test_idx.extend(group_to_indices[group_id])
+            train_dataset.loc[val_idx, SPLIT_COL] = "val"
+            train_dataset.loc[test_idx, SPLIT_COL] = "test"
+        else:
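+            # No sample IDs: two-stage row-level split (test first, then val), with
+            # guards for the degenerate fractions that train_test_split rejects.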
+            stratify = y if _can_stratify(y) else None
+            if p_test <= 0:
+                logger.warning(
+                    "split_probabilities specify a test fraction of 0; skipping the test split to avoid train_test_split errors."
+                )
+                tv_idx = train_dataset.index
+                test_idx = train_dataset.index[:0]
+            elif p_test >= 1:
+                logger.warning(
+                    "split_probabilities specify a test fraction of 1; assigning all rows to the test split to avoid train_test_split errors."
+                )
+                tv_idx = train_dataset.index[:0]
+                test_idx = train_dataset.index
+            else:
+                tv_idx, test_idx = train_test_split(
+                    train_dataset.index, test_size=p_test,
+                    random_state=random_seed, stratify=stratify
+                )
+            rel_val = p_val / (p_train + p_val) if (p_train + p_val) > 0 else 0
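+            # rel_val rescales the validation share to the rows left after the test
+            # split (e.g. p_train=0.7, p_val=0.1 -> 0.1 / 0.8 = 0.125 of the remainder).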
+            train_idx = train_dataset.index[:0]
+            val_idx = train_dataset.index[:0]
+            if len(tv_idx):
+                if rel_val <= 0:
+                    logger.warning(
+                        "split_probabilities specify a validation fraction of 0; skipping the validation split to avoid train_test_split errors."
+                    )
+                    train_idx = tv_idx
+                elif rel_val >= 1:
+                    logger.warning(
+                        "split_probabilities specify a validation fraction of 1; assigning all remaining rows to validation "
+                        "to avoid train_test_split errors."
+                    )
+                    val_idx = tv_idx
+                else:
+                    strat_tv = y.loc[tv_idx] if _can_stratify(y.loc[tv_idx]) else None
+                    train_idx, val_idx = train_test_split(
+                        tv_idx, test_size=rel_val,
+                        random_state=random_seed, stratify=strat_tv
+                    )
 
-        train_dataset.loc[val_idx, SPLIT_COL] = "val"
-        train_dataset.loc[test_idx, SPLIT_COL] = "test"
+            train_dataset.loc[val_idx, SPLIT_COL] = "val"
+            train_dataset.loc[test_idx, SPLIT_COL] = "test"
         logger.info(f"3-way split → train:{len(train_idx)}, val:{len(val_idx)}, test:{len(test_idx)}")
 
     logger.info(f"Final split distribution:\n{train_dataset[SPLIT_COL].value_counts().sort_index()}")