comparison split_logic.py @ 3:25bb80df7c0c (draft, default, tip)

planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093

author    goeckslab
date      Sat, 17 Jan 2026 22:53:42 +0000
parents   375c36923da1
children  (none)

comparing revision 2:b708d0e210e6 with 3:25bb80df7c0c ("-" = removed, "+" = added)
 import logging
 from typing import List, Optional
 
+import numpy as np
 import pandas as pd
 from sklearn.model_selection import train_test_split
 
 logger = logging.getLogger(__name__)
 SPLIT_COL = "split"
... (unchanged lines omitted)
     test_dataset: Optional[pd.DataFrame],
     target_column: str,
     split_probabilities: List[float],
     validation_size: float,
     random_seed: int = 42,
+    sample_id_column: Optional[str] = None,
 ) -> None:
     if target_column not in train_dataset.columns:
         raise ValueError(f"Target column '{target_column}' not found")
 
     # Drop NaN labels early
... (unchanged lines omitted)
         raise ValueError("No rows remain after dropping NaN targets")
     if before != len(train_dataset):
         logger.warning(f"Dropped {before - len(train_dataset)} rows with NaN target")
     y = train_dataset[target_column]
 
+    if sample_id_column and sample_id_column not in train_dataset.columns:
+        logger.warning(
+            "Sample ID column '%s' not found; proceeding without group-aware split.",
+            sample_id_column,
+        )
+        sample_id_column = None
+    if sample_id_column and sample_id_column == target_column:
+        logger.warning(
+            "Sample ID column '%s' matches target column; proceeding without group-aware split.",
+            sample_id_column,
+        )
+        sample_id_column = None
+
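The new sample_id_column parameter enables a group-aware split: all rows sharing an ID end up in the same partition, which prevents leakage when, for example, several samples come from one patient. For comparison, a minimal sketch (illustrative data, not part of this changeset) using scikit-learn's GroupShuffleSplit, which also keeps groups intact but samples whole groups rather than balancing row counts the way the allocator introduced further down does:

    import pandas as pd
    from sklearn.model_selection import GroupShuffleSplit

    df = pd.DataFrame({
        "feature": range(6),
        "target": [0, 1, 0, 1, 0, 1],
        "sample_id": ["a", "a", "b", "b", "c", "c"],
    })
    gss = GroupShuffleSplit(n_splits=1, test_size=0.34, random_state=42)
    train_rows, val_rows = next(gss.split(df, groups=df["sample_id"]))
    # No sample_id appears on both sides of the split.
    assert not set(df.loc[train_rows, "sample_id"]) & set(df.loc[val_rows, "sample_id"])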
     # Respect existing valid split column
     if SPLIT_COL in train_dataset.columns:
         unique = set(train_dataset[SPLIT_COL].dropna().unique())
         valid = {"train", "val", "validation", "test"}
-        if unique.issubset(valid | {"validation"}):
+        if unique.issubset(valid):
             train_dataset[SPLIT_COL] = train_dataset[SPLIT_COL].replace("validation", "val")
-            logger.info(f"Using pre-existing 'split' column: {sorted(unique)}")
+            normalized = set(train_dataset[SPLIT_COL].dropna().unique())
+            required = {"train"} if test_dataset is not None else {"train", "test"}
+            missing = required - normalized
+            if missing:
+                missing_list = ", ".join(sorted(missing))
+                if test_dataset is not None:
+                    raise ValueError(
+                        "Pre-existing 'split' column is missing required split(s): "
+                        f"{missing_list}. Expected at least train when an external test set is provided, "
+                        "or remove the 'split' column to let the tool create splits."
+                    )
+                raise ValueError(
+                    "Pre-existing 'split' column is missing required split(s): "
+                    f"{missing_list}. Expected at least train and test, "
+                    "or remove the 'split' column to let the tool create splits."
+                )
+            logger.info(f"Using pre-existing 'split' column: {sorted(normalized)}")
             return
 
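For instance (illustrative data), a table uploaded with its own split column now passes the stricter validation only when the required splits are present, and "validation" is still normalized to "val":

    import pandas as pd

    df = pd.DataFrame({
        "target": [0, 1, 0, 1],
        "split": ["train", "validation", "train", "test"],
    })
    # unique = {"train", "validation", "test"} is a subset of the valid labels;
    # with no external test set, required = {"train", "test"} is satisfied, so
    # the column is kept ("validation" rewritten to "val") and the function
    # returns early.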
     train_dataset[SPLIT_COL] = "train"
 
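+    # Largest-remainder allocation: turn split probabilities into integer
+    # row targets that sum to n_total, seeding every nonzero-probability
+    # split with at least one row when enough rows are available.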
+    def _allocate_split_counts(n_total: int, probs: list) -> list:
+        if n_total <= 0:
+            return [0 for _ in probs]
+        counts = [0 for _ in probs]
+        active = [i for i, p in enumerate(probs) if p > 0]
+        remainder = n_total
+        if active and n_total >= len(active):
+            for i in active:
+                counts[i] = 1
+            remainder -= len(active)
+        if remainder > 0:
+            probs_arr = np.array(probs, dtype=float)
+            probs_arr = probs_arr / probs_arr.sum()
+            raw = remainder * probs_arr
+            floors = np.floor(raw).astype(int)
+            for i, value in enumerate(floors.tolist()):
+                counts[i] += value
+            leftover = remainder - int(floors.sum())
+            if leftover > 0 and active:
+                frac = raw - floors
+                order = sorted(active, key=lambda i: (-frac[i], i))
+                for i in range(leftover):
+                    counts[order[i % len(order)]] += 1
+        return counts
+
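+    # Greedy chooser: route the next group to the split furthest below its
+    # target; once every target is met, fall back to the split with the
+    # fewest rows so overshoot is spread evenly.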
+    def _choose_split(counts: list, targets: list, active: list) -> int:
+        remaining = [targets[i] - counts[i] for i in range(len(targets))]
+        best = max(active, key=lambda i: (remaining[i], -counts[i], -targets[i]))
+        if remaining[best] <= 0:
+            best = min(active, key=lambda i: counts[i])
+        return best
+
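For illustration, a standalone sketch of what these helpers compute. In the changeset they are nested inside the splitting function, so the top-level name `allocate` below is a hypothetical re-creation of the same arithmetic:

    import numpy as np

    def allocate(n_total, probs):
        # Re-creation of _allocate_split_counts: seed each active split with
        # one row, then distribute the remainder by largest fraction.
        counts = [0] * len(probs)
        active = [i for i, p in enumerate(probs) if p > 0]
        remainder = n_total
        if active and n_total >= len(active):
            for i in active:
                counts[i] = 1
            remainder -= len(active)
        if remainder > 0:
            p = np.array(probs, dtype=float)
            raw = remainder * (p / p.sum())
            floors = np.floor(raw).astype(int)
            counts = [c + f for c, f in zip(counts, floors.tolist())]
            leftover = remainder - int(floors.sum())
            if leftover > 0 and active:
                frac = raw - floors
                order = sorted(active, key=lambda i: (-frac[i], i))
                for i in range(leftover):
                    counts[order[i % len(order)]] += 1
        return counts

    print(allocate(10, [0.7, 0.15, 0.15]))  # [6, 2, 2]

    # _choose_split then fills such targets one group at a time. With
    # hypothetical targets [8, 2] and shuffled group sizes [3, 3, 2, 2]:
    #   remaining [8, 2] -> train, counts [3, 0]
    #   remaining [5, 2] -> train, counts [6, 0]
    #   remaining [2, 2] -> tie on remaining; -counts prefers val, counts [6, 2]
    #   remaining [2, 0] -> train, counts [8, 2]  (targets met exactly)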
     if test_dataset is not None:
-        stratify = y if _can_stratify(y) else None
-        train_idx, val_idx = train_test_split(
-            train_dataset.index, test_size=validation_size,
-            random_state=random_seed, stratify=stratify
-        )
-        train_dataset.loc[val_idx, SPLIT_COL] = "val"
+        if sample_id_column and sample_id_column in test_dataset.columns:
+            train_ids = set(train_dataset[sample_id_column].dropna().astype(object).unique())
+            test_ids = set(test_dataset[sample_id_column].dropna().astype(object).unique())
+            overlap = train_ids & test_ids
+            if overlap:
+                logger.warning(
+                    "Sample ID column '%s' has %d overlapping IDs between train and external test sets; "
+                    "consider removing overlaps to avoid leakage.",
+                    sample_id_column,
+                    len(overlap),
+                )
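+                # e.g. train IDs {"p1", "p2"} vs. test IDs {"p2", "p3"} would
+                # log one overlapping ID ("p2") as potential leakage.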
+        if sample_id_column:
+            rng = np.random.RandomState(random_seed)
+            group_series = train_dataset[sample_id_column].astype(object)
+            missing_mask = group_series.isna()
+            if missing_mask.any():
+                group_series = group_series.copy()
+                group_series.loc[missing_mask] = [
+                    f"__missing__{idx}" for idx in group_series.index[missing_mask]
+                ]
+            group_to_indices = {}
+            for idx, group_id in group_series.items():
+                group_to_indices.setdefault(group_id, []).append(idx)
+            group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
+            rng.shuffle(group_ids)
+            targets = _allocate_split_counts(len(train_dataset), [1.0 - validation_size, validation_size])
+            counts = [0, 0]
+            active = [0, 1]
+            train_idx = []
+            val_idx = []
+            for group_id in group_ids:
+                split_idx = _choose_split(counts, targets, active)
+                counts[split_idx] += len(group_to_indices[group_id])
+                if split_idx == 0:
+                    train_idx.extend(group_to_indices[group_id])
+                else:
+                    val_idx.extend(group_to_indices[group_id])
+            train_dataset.loc[val_idx, SPLIT_COL] = "val"
+        else:
+            if validation_size <= 0:
+                logger.warning(
+                    "validation_size is %.3f; skipping validation split to avoid train_test_split errors.",
+                    validation_size,
+                )
+            elif validation_size >= 1:
+                logger.warning(
+                    "validation_size is %.3f; assigning all rows to validation to avoid train_test_split errors.",
+                    validation_size,
+                )
+                train_dataset[SPLIT_COL] = "val"
+            else:
+                stratify = y if _can_stratify(y) else None
+                train_idx, val_idx = train_test_split(
+                    train_dataset.index, test_size=validation_size,
+                    random_state=random_seed, stratify=stratify
+                )
+                train_dataset.loc[val_idx, SPLIT_COL] = "val"
         logger.info(f"External test set → created val split ({validation_size:.0%})")
 
     else:
         p_train, p_val, p_test = split_probabilities
         if abs(p_train + p_val + p_test - 1.0) > 1e-6:
             raise ValueError("split_probabilities must sum to 1.0")
 
-        stratify = y if _can_stratify(y) else None
-        tv_idx, test_idx = train_test_split(
-            train_dataset.index, test_size=p_test,
-            random_state=random_seed, stratify=stratify
-        )
-        rel_val = p_val / (p_train + p_val) if (p_train + p_val) > 0 else 0
-        strat_tv = y.loc[tv_idx] if _can_stratify(y.loc[tv_idx]) else None
-        train_idx, val_idx = train_test_split(
-            tv_idx, test_size=rel_val,
-            random_state=random_seed, stratify=strat_tv
-        )
-
-        train_dataset.loc[val_idx, SPLIT_COL] = "val"
-        train_dataset.loc[test_idx, SPLIT_COL] = "test"
+        if sample_id_column:
+            rng = np.random.RandomState(random_seed)
+            group_series = train_dataset[sample_id_column].astype(object)
+            missing_mask = group_series.isna()
+            if missing_mask.any():
+                group_series = group_series.copy()
+                group_series.loc[missing_mask] = [
+                    f"__missing__{idx}" for idx in group_series.index[missing_mask]
+                ]
+            group_to_indices = {}
+            for idx, group_id in group_series.items():
+                group_to_indices.setdefault(group_id, []).append(idx)
+            group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
+            rng.shuffle(group_ids)
+            targets = _allocate_split_counts(len(train_dataset), [p_train, p_val, p_test])
+            counts = [0, 0, 0]
+            active = [0, 1, 2]
+            train_idx = []
+            val_idx = []
+            test_idx = []
+            for group_id in group_ids:
+                split_idx = _choose_split(counts, targets, active)
+                counts[split_idx] += len(group_to_indices[group_id])
+                if split_idx == 0:
+                    train_idx.extend(group_to_indices[group_id])
+                elif split_idx == 1:
+                    val_idx.extend(group_to_indices[group_id])
+                else:
+                    test_idx.extend(group_to_indices[group_id])
+            train_dataset.loc[val_idx, SPLIT_COL] = "val"
+            train_dataset.loc[test_idx, SPLIT_COL] = "test"
+        else:
+            stratify = y if _can_stratify(y) else None
+            if p_test <= 0:
+                logger.warning(
+                    "split_probabilities specify 0 test size; skipping test split to avoid train_test_split errors."
+                )
+                tv_idx = train_dataset.index
+                test_idx = train_dataset.index[:0]
+            elif p_test >= 1:
+                logger.warning(
+                    "split_probabilities specify 100% test size; assigning all rows to test to avoid train_test_split errors."
+                )
+                tv_idx = train_dataset.index[:0]
+                test_idx = train_dataset.index
+            else:
+                tv_idx, test_idx = train_test_split(
+                    train_dataset.index, test_size=p_test,
+                    random_state=random_seed, stratify=stratify
+                )
+            rel_val = p_val / (p_train + p_val) if (p_train + p_val) > 0 else 0
+            train_idx = train_dataset.index[:0]
+            val_idx = train_dataset.index[:0]
+            if len(tv_idx):
+                if rel_val <= 0:
+                    logger.warning(
+                        "split_probabilities specify 0 validation size; skipping validation split to avoid train_test_split errors."
+                    )
+                    train_idx = tv_idx
+                elif rel_val >= 1:
+                    logger.warning(
+                        "split_probabilities specify 100% validation size; assigning all remaining rows to validation "
+                        "to avoid train_test_split errors."
+                    )
+                    val_idx = tv_idx
+                else:
+                    strat_tv = y.loc[tv_idx] if _can_stratify(y.loc[tv_idx]) else None
+                    train_idx, val_idx = train_test_split(
+                        tv_idx, test_size=rel_val,
+                        random_state=random_seed, stratify=strat_tv
+                    )
+
+            train_dataset.loc[val_idx, SPLIT_COL] = "val"
+            train_dataset.loc[test_idx, SPLIT_COL] = "test"
         logger.info(f"3-way split → train:{len(train_idx)}, val:{len(val_idx)}, test:{len(test_idx)}")
 
     logger.info(f"Final split distribution:\n{train_dataset[SPLIT_COL].value_counts().sort_index()}")
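For reference, a hedged end-to-end sketch of the new group-aware path. Assumptions not visible in this comparison: the enclosing function is imported here under the hypothetical name split_data (its def line sits outside the shown hunks), it mutates train_dataset in place (it returns None), and _can_stratify is defined elsewhere in split_logic.py:

    import pandas as pd
    from split_logic import split_data  # hypothetical name for the function above

    df = pd.DataFrame({
        "feature": range(10),
        "target": [0, 1] * 5,
        "patient_id": ["p1", "p1", "p2", "p2", "p3", "p3", "p4", "p4", "p5", "p5"],
    })
    split_data(
        train_dataset=df,
        test_dataset=None,
        target_column="target",
        split_probabilities=[0.7, 0.15, 0.15],
        validation_size=0.15,
        random_seed=42,
        sample_id_column="patient_id",
    )
    # Rows sharing a patient_id always land in the same split:
    print(df.groupby("patient_id")["split"].nunique().max())  # 1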