Mercurial > repos > goeckslab > multimodal_learner
comparison split_logic.py @ 3:25bb80df7c0c draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
| author | goeckslab |
|---|---|
| date | Sat, 17 Jan 2026 22:53:42 +0000 |
| parents | 375c36923da1 |
| children | |
comparison
equal
deleted
inserted
replaced
| 2:b708d0e210e6 | 3:25bb80df7c0c |
|---|---|
| 1 import logging | 1 import logging |
| 2 from typing import List, Optional | 2 from typing import List, Optional |
| 3 | 3 |
| | 4 import numpy as np |
| 4 import pandas as pd | 5 import pandas as pd |
| 5 from sklearn.model_selection import train_test_split | 6 from sklearn.model_selection import train_test_split |
| 6 | 7 |
| 7 logger = logging.getLogger(__name__) | 8 logger = logging.getLogger(__name__) |
| 8 SPLIT_COL = "split" | 9 SPLIT_COL = "split" |
| 17 test_dataset: Optional[pd.DataFrame], | 18 test_dataset: Optional[pd.DataFrame], |
| 18 target_column: str, | 19 target_column: str, |
| 19 split_probabilities: List[float], | 20 split_probabilities: List[float], |
| 20 validation_size: float, | 21 validation_size: float, |
| 21 random_seed: int = 42, | 22 random_seed: int = 42, |
| | 23 sample_id_column: Optional[str] = None, |
| 22 ) -> None: | 24 ) -> None: |
| 23 if target_column not in train_dataset.columns: | 25 if target_column not in train_dataset.columns: |
| 24 raise ValueError(f"Target column '{target_column}' not found") | 26 raise ValueError(f"Target column '{target_column}' not found") |
| 25 | 27 |
| 26 # Drop NaN labels early | 28 # Drop NaN labels early |
| 30 raise ValueError("No rows remain after dropping NaN targets") | 32 raise ValueError("No rows remain after dropping NaN targets") |
| 31 if before != len(train_dataset): | 33 if before != len(train_dataset): |
| 32 logger.warning(f"Dropped {before - len(train_dataset)} rows with NaN target") | 34 logger.warning(f"Dropped {before - len(train_dataset)} rows with NaN target") |
| 33 y = train_dataset[target_column] | 35 y = train_dataset[target_column] |
| 34 | 36 |
| | 37 if sample_id_column and sample_id_column not in train_dataset.columns: |
| | 38 logger.warning( |
| | 39 "Sample ID column '%s' not found; proceeding without group-aware split.", |
| | 40 sample_id_column, |
| | 41 ) |
| | 42 sample_id_column = None |
| | 43 if sample_id_column and sample_id_column == target_column: |
| | 44 logger.warning( |
| | 45 "Sample ID column '%s' matches target column; proceeding without group-aware split.", |
| | 46 sample_id_column, |
| | 47 ) |
| | 48 sample_id_column = None |
| | 49 |
| 35 # Respect existing valid split column | 50 # Respect existing valid split column |
| 36 if SPLIT_COL in train_dataset.columns: | 51 if SPLIT_COL in train_dataset.columns: |
| 37 unique = set(train_dataset[SPLIT_COL].dropna().unique()) | 52 unique = set(train_dataset[SPLIT_COL].dropna().unique()) |
| 38 valid = {"train", "val", "validation", "test"} | 53 valid = {"train", "val", "validation", "test"} |
| 39 if unique.issubset(valid \| {"validation"}): | 54 if unique.issubset(valid): |
| 40 train_dataset[SPLIT_COL] = train_dataset[SPLIT_COL].replace("validation", "val") | 55 train_dataset[SPLIT_COL] = train_dataset[SPLIT_COL].replace("validation", "val") |
| 41 logger.info(f"Using pre-existing 'split' column: {sorted(unique)}") | 56 normalized = set(train_dataset[SPLIT_COL].dropna().unique()) |
| | 57 required = {"train"} if test_dataset is not None else {"train", "test"} |
| | 58 missing = required - normalized |
| | 59 if missing: |
| | 60 missing_list = ", ".join(sorted(missing)) |
| | 61 if test_dataset is not None: |
| | 62 raise ValueError( |
| | 63 "Pre-existing 'split' column is missing required split(s): " |
| | 64 f"{missing_list}. Expected at least train when an external test set is provided, " |
| | 65 "or remove the 'split' column to let the tool create splits." |
| | 66 ) |
| | 67 raise ValueError( |
| | 68 "Pre-existing 'split' column is missing required split(s): " |
| | 69 f"{missing_list}. Expected at least train and test, " |
| | 70 "or remove the 'split' column to let the tool create splits." |
| | 71 ) |
| | 72 logger.info(f"Using pre-existing 'split' column: {sorted(normalized)}") |
| 42 return | 73 return |
| 43 | 74 |
| 44 train_dataset[SPLIT_COL] = "train" | 75 train_dataset[SPLIT_COL] = "train" |
| 45 | 76 |
| | 77 def _allocate_split_counts(n_total: int, probs: list) -> list: |
| | 78 if n_total <= 0: |
| | 79 return [0 for _ in probs] |
| | 80 counts = [0 for _ in probs] |
| | 81 active = [i for i, p in enumerate(probs) if p > 0] |
| | 82 remainder = n_total |
| | 83 if active and n_total >= len(active): |
| | 84 for i in active: |
| | 85 counts[i] = 1 |
| | 86 remainder -= len(active) |
| | 87 if remainder > 0: |
| | 88 probs_arr = np.array(probs, dtype=float) |
| | 89 probs_arr = probs_arr / probs_arr.sum() |
| | 90 raw = remainder * probs_arr |
| | 91 floors = np.floor(raw).astype(int) |
| | 92 for i, value in enumerate(floors.tolist()): |
| | 93 counts[i] += value |
| | 94 leftover = remainder - int(floors.sum()) |
| | 95 if leftover > 0 and active: |
| | 96 frac = raw - floors |
| | 97 order = sorted(active, key=lambda i: (-frac[i], i)) |
| | 98 for i in range(leftover): |
| | 99 counts[order[i % len(order)]] += 1 |
| | 100 return counts |
| | 101 |
| | 102 def _choose_split(counts: list, targets: list, active: list) -> int: |
| | 103 remaining = [targets[i] - counts[i] for i in range(len(targets))] |
| | 104 best = max(active, key=lambda i: (remaining[i], -counts[i], -targets[i])) |
| | 105 if remaining[best] <= 0: |
| | 106 best = min(active, key=lambda i: counts[i]) |
| | 107 return best |
| | 108 |
| 46 if test_dataset is not None: | 109 if test_dataset is not None: |
| 47 stratify = y if _can_stratify(y) else None | 110 if sample_id_column and sample_id_column in test_dataset.columns: |
| 48 train_idx, val_idx = train_test_split( | 111 train_ids = set(train_dataset[sample_id_column].dropna().astype(object).unique()) |
| 49 train_dataset.index, test_size=validation_size, | 112 test_ids = set(test_dataset[sample_id_column].dropna().astype(object).unique()) |
| 50 random_state=random_seed, stratify=stratify | 113 overlap = train_ids & test_ids |
| 51 ) | 114 if overlap: |
| 52 train_dataset.loc[val_idx, SPLIT_COL] = "val" | 115 logger.warning( |
| | 116 "Sample ID column '%s' has %d overlapping IDs between train and external test sets; " |
| | 117 "consider removing overlaps to avoid leakage.", |
| | 118 sample_id_column, |
| | 119 len(overlap), |
| | 120 ) |
| | 121 if sample_id_column: |
| | 122 rng = np.random.RandomState(random_seed) |
| | 123 group_series = train_dataset[sample_id_column].astype(object) |
| | 124 missing_mask = group_series.isna() |
| | 125 if missing_mask.any(): |
| | 126 group_series = group_series.copy() |
| | 127 group_series.loc[missing_mask] = [ |
| | 128 f"__missing__{idx}" for idx in group_series.index[missing_mask] |
| | 129 ] |
| | 130 group_to_indices = {} |
| | 131 for idx, group_id in group_series.items(): |
| | 132 group_to_indices.setdefault(group_id, []).append(idx) |
| | 133 group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x)) |
| | 134 rng.shuffle(group_ids) |
| | 135 targets = _allocate_split_counts(len(train_dataset), [1.0 - validation_size, validation_size]) |
| | 136 counts = [0, 0] |
| | 137 active = [0, 1] |
| | 138 train_idx = [] |
| | 139 val_idx = [] |
| | 140 for group_id in group_ids: |
| | 141 split_idx = _choose_split(counts, targets, active) |
| | 142 counts[split_idx] += len(group_to_indices[group_id]) |
| | 143 if split_idx == 0: |
| | 144 train_idx.extend(group_to_indices[group_id]) |
| | 145 else: |
| | 146 val_idx.extend(group_to_indices[group_id]) |
| | 147 train_dataset.loc[val_idx, SPLIT_COL] = "val" |
| | 148 else: |
| | 149 if validation_size <= 0: |
| | 150 logger.warning( |
| | 151 "validation_size is %.3f; skipping validation split to avoid train_test_split errors.", |
| | 152 validation_size, |
| | 153 ) |
| | 154 elif validation_size >= 1: |
| | 155 logger.warning( |
| | 156 "validation_size is %.3f; assigning all rows to validation to avoid train_test_split errors.", |
| | 157 validation_size, |
| | 158 ) |
| | 159 train_dataset[SPLIT_COL] = "val" |
| | 160 else: |
| | 161 stratify = y if _can_stratify(y) else None |
| | 162 train_idx, val_idx = train_test_split( |
| | 163 train_dataset.index, test_size=validation_size, |
| | 164 random_state=random_seed, stratify=stratify |
| | 165 ) |
| | 166 train_dataset.loc[val_idx, SPLIT_COL] = "val" |
| 53 logger.info(f"External test set → created val split ({validation_size:.0%})") | 167 logger.info(f"External test set → created val split ({validation_size:.0%})") |
| 54 | 168 |
| 55 else: | 169 else: |
| 56 p_train, p_val, p_test = split_probabilities | 170 p_train, p_val, p_test = split_probabilities |
| 57 if abs(p_train + p_val + p_test - 1.0) > 1e-6: | 171 if abs(p_train + p_val + p_test - 1.0) > 1e-6: |
| 58 raise ValueError("split_probabilities must sum to 1.0") | 172 raise ValueError("split_probabilities must sum to 1.0") |
| 59 | 173 |
| 60 stratify = y if _can_stratify(y) else None | 174 if sample_id_column: |
| 61 tv_idx, test_idx = train_test_split( | 175 rng = np.random.RandomState(random_seed) |
| 62 train_dataset.index, test_size=p_test, | 176 group_series = train_dataset[sample_id_column].astype(object) |
| 63 random_state=random_seed, stratify=stratify | 177 missing_mask = group_series.isna() |
| 64 ) | 178 if missing_mask.any(): |
| 65 rel_val = p_val / (p_train + p_val) if (p_train + p_val) > 0 else 0 | 179 group_series = group_series.copy() |
| 66 strat_tv = y.loc[tv_idx] if _can_stratify(y.loc[tv_idx]) else None | 180 group_series.loc[missing_mask] = [ |
| 67 train_idx, val_idx = train_test_split( | 181 f"__missing__{idx}" for idx in group_series.index[missing_mask] |
| 68 tv_idx, test_size=rel_val, | 182 ] |
| 69 random_state=random_seed, stratify=strat_tv | 183 group_to_indices = {} |
| 70 ) | 184 for idx, group_id in group_series.items(): |
| 71 | 185 group_to_indices.setdefault(group_id, []).append(idx) |
| 72 train_dataset.loc[val_idx, SPLIT_COL] = "val" | 186 group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x)) |
| 73 train_dataset.loc[test_idx, SPLIT_COL] = "test" | 187 rng.shuffle(group_ids) |
| | 188 targets = _allocate_split_counts(len(train_dataset), [p_train, p_val, p_test]) |
| | 189 counts = [0, 0, 0] |
| | 190 active = [0, 1, 2] |
| | 191 train_idx = [] |
| | 192 val_idx = [] |
| | 193 test_idx = [] |
| | 194 for group_id in group_ids: |
| | 195 split_idx = _choose_split(counts, targets, active) |
| | 196 counts[split_idx] += len(group_to_indices[group_id]) |
| | 197 if split_idx == 0: |
| | 198 train_idx.extend(group_to_indices[group_id]) |
| | 199 elif split_idx == 1: |
| | 200 val_idx.extend(group_to_indices[group_id]) |
| | 201 else: |
| | 202 test_idx.extend(group_to_indices[group_id]) |
| | 203 train_dataset.loc[val_idx, SPLIT_COL] = "val" |
| | 204 train_dataset.loc[test_idx, SPLIT_COL] = "test" |
| | 205 else: |
| | 206 stratify = y if _can_stratify(y) else None |
| | 207 if p_test <= 0: |
| | 208 logger.warning( |
| | 209 "split_probabilities specify 0 test size; skipping test split to avoid train_test_split errors." |
| | 210 ) |
| | 211 tv_idx = train_dataset.index |
| | 212 test_idx = train_dataset.index[:0] |
| | 213 elif p_test >= 1: |
| | 214 logger.warning( |
| | 215 "split_probabilities specify 100% test size; assigning all rows to test to avoid train_test_split errors." |
| | 216 ) |
| | 217 tv_idx = train_dataset.index[:0] |
| | 218 test_idx = train_dataset.index |
| | 219 else: |
| | 220 tv_idx, test_idx = train_test_split( |
| | 221 train_dataset.index, test_size=p_test, |
| | 222 random_state=random_seed, stratify=stratify |
| | 223 ) |
| | 224 rel_val = p_val / (p_train + p_val) if (p_train + p_val) > 0 else 0 |
| | 225 train_idx = train_dataset.index[:0] |
| | 226 val_idx = train_dataset.index[:0] |
| | 227 if len(tv_idx): |
| | 228 if rel_val <= 0: |
| | 229 logger.warning( |
| | 230 "split_probabilities specify 0 validation size; skipping validation split to avoid train_test_split errors." |
| | 231 ) |
| | 232 train_idx = tv_idx |
| | 233 elif rel_val >= 1: |
| | 234 logger.warning( |
| | 235 "split_probabilities specify 100% validation size; assigning all remaining rows to validation " |
| | 236 "to avoid train_test_split errors." |
| | 237 ) |
| | 238 val_idx = tv_idx |
| | 239 else: |
| | 240 strat_tv = y.loc[tv_idx] if _can_stratify(y.loc[tv_idx]) else None |
| | 241 train_idx, val_idx = train_test_split( |
| | 242 tv_idx, test_size=rel_val, |
| | 243 random_state=random_seed, stratify=strat_tv |
| | 244 ) |
| | 245 |
| | 246 train_dataset.loc[val_idx, SPLIT_COL] = "val" |
| | 247 train_dataset.loc[test_idx, SPLIT_COL] = "test" |
| 74 logger.info(f"3-way split → train:{len(train_idx)}, val:{len(val_idx)}, test:{len(test_idx)}") | 248 logger.info(f"3-way split → train:{len(train_idx)}, val:{len(val_idx)}, test:{len(test_idx)}") |
| 75 | 249 |
| 76 logger.info(f"Final split distribution:\n{train_dataset[SPLIT_COL].value_counts().sort_index()}") | 250 logger.info(f"Final split distribution:\n{train_dataset[SPLIT_COL].value_counts().sort_index()}") |
