Mercurial > repos > goeckslab > multimodal_learner
view split_logic.py @ 3:25bb80df7c0c draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
| author | goeckslab |
|---|---|
| date | Sat, 17 Jan 2026 22:53:42 +0000 |
| parents | 375c36923da1 |
| children |
line wrap: on
line source
import logging from typing import List, Optional import numpy as np import pandas as pd from sklearn.model_selection import train_test_split logger = logging.getLogger(__name__) SPLIT_COL = "split" def _can_stratify(y: pd.Series) -> bool: return y.nunique() >= 2 and (y.value_counts() >= 2).all() def split_dataset( train_dataset: pd.DataFrame, test_dataset: Optional[pd.DataFrame], target_column: str, split_probabilities: List[float], validation_size: float, random_seed: int = 42, sample_id_column: Optional[str] = None, ) -> None: if target_column not in train_dataset.columns: raise ValueError(f"Target column '{target_column}' not found") # Drop NaN labels early before = len(train_dataset) train_dataset.dropna(subset=[target_column], inplace=True) if len(train_dataset) == 0: raise ValueError("No rows remain after dropping NaN targets") if before != len(train_dataset): logger.warning(f"Dropped {before - len(train_dataset)} rows with NaN target") y = train_dataset[target_column] if sample_id_column and sample_id_column not in train_dataset.columns: logger.warning( "Sample ID column '%s' not found; proceeding without group-aware split.", sample_id_column, ) sample_id_column = None if sample_id_column and sample_id_column == target_column: logger.warning( "Sample ID column '%s' matches target column; proceeding without group-aware split.", sample_id_column, ) sample_id_column = None # Respect existing valid split column if SPLIT_COL in train_dataset.columns: unique = set(train_dataset[SPLIT_COL].dropna().unique()) valid = {"train", "val", "validation", "test"} if unique.issubset(valid): train_dataset[SPLIT_COL] = train_dataset[SPLIT_COL].replace("validation", "val") normalized = set(train_dataset[SPLIT_COL].dropna().unique()) required = {"train"} if test_dataset is not None else {"train", "test"} missing = required - normalized if missing: missing_list = ", ".join(sorted(missing)) if test_dataset is not None: raise ValueError( "Pre-existing 'split' column is missing required split(s): " f"{missing_list}. Expected at least train when an external test set is provided, " "or remove the 'split' column to let the tool create splits." ) raise ValueError( "Pre-existing 'split' column is missing required split(s): " f"{missing_list}. Expected at least train and test, " "or remove the 'split' column to let the tool create splits." ) logger.info(f"Using pre-existing 'split' column: {sorted(normalized)}") return train_dataset[SPLIT_COL] = "train" def _allocate_split_counts(n_total: int, probs: list) -> list: if n_total <= 0: return [0 for _ in probs] counts = [0 for _ in probs] active = [i for i, p in enumerate(probs) if p > 0] remainder = n_total if active and n_total >= len(active): for i in active: counts[i] = 1 remainder -= len(active) if remainder > 0: probs_arr = np.array(probs, dtype=float) probs_arr = probs_arr / probs_arr.sum() raw = remainder * probs_arr floors = np.floor(raw).astype(int) for i, value in enumerate(floors.tolist()): counts[i] += value leftover = remainder - int(floors.sum()) if leftover > 0 and active: frac = raw - floors order = sorted(active, key=lambda i: (-frac[i], i)) for i in range(leftover): counts[order[i % len(order)]] += 1 return counts def _choose_split(counts: list, targets: list, active: list) -> int: remaining = [targets[i] - counts[i] for i in range(len(targets))] best = max(active, key=lambda i: (remaining[i], -counts[i], -targets[i])) if remaining[best] <= 0: best = min(active, key=lambda i: counts[i]) return best if test_dataset is not None: if sample_id_column and sample_id_column in test_dataset.columns: train_ids = set(train_dataset[sample_id_column].dropna().astype(object).unique()) test_ids = set(test_dataset[sample_id_column].dropna().astype(object).unique()) overlap = train_ids & test_ids if overlap: logger.warning( "Sample ID column '%s' has %d overlapping IDs between train and external test sets; " "consider removing overlaps to avoid leakage.", sample_id_column, len(overlap), ) if sample_id_column: rng = np.random.RandomState(random_seed) group_series = train_dataset[sample_id_column].astype(object) missing_mask = group_series.isna() if missing_mask.any(): group_series = group_series.copy() group_series.loc[missing_mask] = [ f"__missing__{idx}" for idx in group_series.index[missing_mask] ] group_to_indices = {} for idx, group_id in group_series.items(): group_to_indices.setdefault(group_id, []).append(idx) group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x)) rng.shuffle(group_ids) targets = _allocate_split_counts(len(train_dataset), [1.0 - validation_size, validation_size]) counts = [0, 0] active = [0, 1] train_idx = [] val_idx = [] for group_id in group_ids: split_idx = _choose_split(counts, targets, active) counts[split_idx] += len(group_to_indices[group_id]) if split_idx == 0: train_idx.extend(group_to_indices[group_id]) else: val_idx.extend(group_to_indices[group_id]) train_dataset.loc[val_idx, SPLIT_COL] = "val" else: if validation_size <= 0: logger.warning( "validation_size is %.3f; skipping validation split to avoid train_test_split errors.", validation_size, ) elif validation_size >= 1: logger.warning( "validation_size is %.3f; assigning all rows to validation to avoid train_test_split errors.", validation_size, ) train_dataset[SPLIT_COL] = "val" else: stratify = y if _can_stratify(y) else None train_idx, val_idx = train_test_split( train_dataset.index, test_size=validation_size, random_state=random_seed, stratify=stratify ) train_dataset.loc[val_idx, SPLIT_COL] = "val" logger.info(f"External test set → created val split ({validation_size:.0%})") else: p_train, p_val, p_test = split_probabilities if abs(p_train + p_val + p_test - 1.0) > 1e-6: raise ValueError("split_probabilities must sum to 1.0") if sample_id_column: rng = np.random.RandomState(random_seed) group_series = train_dataset[sample_id_column].astype(object) missing_mask = group_series.isna() if missing_mask.any(): group_series = group_series.copy() group_series.loc[missing_mask] = [ f"__missing__{idx}" for idx in group_series.index[missing_mask] ] group_to_indices = {} for idx, group_id in group_series.items(): group_to_indices.setdefault(group_id, []).append(idx) group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x)) rng.shuffle(group_ids) targets = _allocate_split_counts(len(train_dataset), [p_train, p_val, p_test]) counts = [0, 0, 0] active = [0, 1, 2] train_idx = [] val_idx = [] test_idx = [] for group_id in group_ids: split_idx = _choose_split(counts, targets, active) counts[split_idx] += len(group_to_indices[group_id]) if split_idx == 0: train_idx.extend(group_to_indices[group_id]) elif split_idx == 1: val_idx.extend(group_to_indices[group_id]) else: test_idx.extend(group_to_indices[group_id]) train_dataset.loc[val_idx, SPLIT_COL] = "val" train_dataset.loc[test_idx, SPLIT_COL] = "test" else: stratify = y if _can_stratify(y) else None if p_test <= 0: logger.warning( "split_probabilities specify 0 test size; skipping test split to avoid train_test_split errors." ) tv_idx = train_dataset.index test_idx = train_dataset.index[:0] elif p_test >= 1: logger.warning( "split_probabilities specify 100% test size; assigning all rows to test to avoid train_test_split errors." ) tv_idx = train_dataset.index[:0] test_idx = train_dataset.index else: tv_idx, test_idx = train_test_split( train_dataset.index, test_size=p_test, random_state=random_seed, stratify=stratify ) rel_val = p_val / (p_train + p_val) if (p_train + p_val) > 0 else 0 train_idx = train_dataset.index[:0] val_idx = train_dataset.index[:0] if len(tv_idx): if rel_val <= 0: logger.warning( "split_probabilities specify 0 validation size; skipping validation split to avoid train_test_split errors." ) train_idx = tv_idx elif rel_val >= 1: logger.warning( "split_probabilities specify 100% validation size; assigning all remaining rows to validation " "to avoid train_test_split errors." ) val_idx = tv_idx else: strat_tv = y.loc[tv_idx] if _can_stratify(y.loc[tv_idx]) else None train_idx, val_idx = train_test_split( tv_idx, test_size=rel_val, random_state=random_seed, stratify=strat_tv ) train_dataset.loc[val_idx, SPLIT_COL] = "val" train_dataset.loc[test_idx, SPLIT_COL] = "test" logger.info(f"3-way split → train:{len(train_idx)}, val:{len(val_idx)}, test:{len(test_idx)}") logger.info(f"Final split distribution:\n{train_dataset[SPLIT_COL].value_counts().sort_index()}")
