comparison split_data.py @ 20:64872c48a21f draft

planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
author goeckslab
date Tue, 06 Jan 2026 15:35:11 +0000
parents bcfa2e234a80
children
comparison
equal deleted inserted replaced
19:c460abae83eb 20:64872c48a21f
79 df: pd.DataFrame, 79 df: pd.DataFrame,
80 split_column: str, 80 split_column: str,
81 split_probabilities: list = [0.7, 0.1, 0.2], 81 split_probabilities: list = [0.7, 0.1, 0.2],
82 random_state: int = 42, 82 random_state: int = 42,
83 label_column: Optional[str] = None, 83 label_column: Optional[str] = None,
84 group_column: Optional[str] = None,
84 ) -> pd.DataFrame: 85 ) -> pd.DataFrame:
85 """Create a stratified random split when no split column exists.""" 86 """Create a stratified random split when no split column exists."""
86 out = df.copy() 87 out = df.copy()
87 88
88 # initialize split column 89 # initialize split column
89 out[split_column] = 0 90 out[split_column] = 0
90 91
92 if group_column and group_column not in out.columns:
93 logger.warning(
94 "Group column '%s' not found in data; proceeding without group-aware split.",
95 group_column,
96 )
97 group_column = None
98
99 def _allocate_split_counts(n_total: int, probs: list) -> list:
100 """Allocate exact split counts using largest remainder rounding."""
101 if n_total <= 0:
102 return [0 for _ in probs]
103
104 counts = [0 for _ in probs]
105 active = [i for i, p in enumerate(probs) if p > 0]
106 remainder = n_total
107
108 if active and n_total >= len(active):
109 for i in active:
110 counts[i] = 1
111 remainder -= len(active)
112
113 if remainder > 0:
114 probs_arr = np.array(probs, dtype=float)
115 probs_arr = probs_arr / probs_arr.sum()
116 raw = remainder * probs_arr
117 floors = np.floor(raw).astype(int)
118 for i, value in enumerate(floors.tolist()):
119 counts[i] += value
120 leftover = remainder - int(floors.sum())
121 if leftover > 0 and active:
122 frac = raw - floors
123 order = sorted(active, key=lambda i: (-frac[i], i))
124 for i in range(leftover):
125 counts[order[i % len(order)]] += 1
126
127 return counts
128
129 def _choose_split(counts: list, targets: list, active: list) -> int:
130 remaining = [targets[i] - counts[i] for i in range(len(targets))]
131 best = max(active, key=lambda i: (remaining[i], -counts[i], -targets[i]))
132 if remaining[best] <= 0:
133 best = min(active, key=lambda i: counts[i])
134 return best
135
91 if not label_column or label_column not in out.columns: 136 if not label_column or label_column not in out.columns:
92 logger.warning( 137 logger.warning(
93 "No label column found; using random split without stratification" 138 "No label column found; using random split without stratification"
94 ) 139 )
95 # fall back to simple random assignment 140 # fall back to random assignment (group-aware if requested)
96 indices = out.index.tolist() 141 indices = out.index.tolist()
97 np.random.seed(random_state) 142 rng = np.random.RandomState(random_state)
98 np.random.shuffle(indices) 143
99 144 if group_column:
100 n_total = len(indices) 145 group_series = out[group_column].astype(object)
101 n_train = int(n_total * split_probabilities[0]) 146 missing_mask = group_series.isna()
102 n_val = int(n_total * split_probabilities[1]) 147 if missing_mask.any():
148 group_series = group_series.copy()
149 group_series.loc[missing_mask] = [
150 f"__missing__{idx}" for idx in group_series.index[missing_mask]
151 ]
152 group_to_indices = {}
153 for idx, group_id in group_series.items():
154 group_to_indices.setdefault(group_id, []).append(idx)
155 group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
156 rng.shuffle(group_ids)
157
158 targets = _allocate_split_counts(len(indices), split_probabilities)
159 counts = [0 for _ in split_probabilities]
160 active = [i for i, p in enumerate(split_probabilities) if p > 0]
161 train_idx = []
162 val_idx = []
163 test_idx = []
164
165 for group_id in group_ids:
166 size = len(group_to_indices[group_id])
167 split_idx = _choose_split(counts, targets, active)
168 counts[split_idx] += size
169 if split_idx == 0:
170 train_idx.extend(group_to_indices[group_id])
171 elif split_idx == 1:
172 val_idx.extend(group_to_indices[group_id])
173 else:
174 test_idx.extend(group_to_indices[group_id])
175
176 out.loc[train_idx, split_column] = 0
177 out.loc[val_idx, split_column] = 1
178 out.loc[test_idx, split_column] = 2
179 return out.astype({split_column: int})
180
181 rng.shuffle(indices)
182
183 targets = _allocate_split_counts(len(indices), split_probabilities)
184 n_train, n_val, n_test = targets
103 185
104 out.loc[indices[:n_train], split_column] = 0 186 out.loc[indices[:n_train], split_column] = 0
105 out.loc[indices[n_train:n_train + n_val], split_column] = 1 187 out.loc[indices[n_train:n_train + n_val], split_column] = 1
106 out.loc[indices[n_train + n_val:], split_column] = 2 188 out.loc[indices[n_train + n_val:n_train + n_val + n_test], split_column] = 2
107 189
108 return out.astype({split_column: int}) 190 return out.astype({split_column: int})
109 191
110 # check if stratification is possible 192 # check if stratification is possible
111 label_counts = out[label_column].value_counts() 193 label_counts = out[label_column].value_counts()
112 min_samples_per_class = label_counts.min() 194 min_samples_per_class = label_counts.min()
113 195
114 # ensure we have enough samples for stratification: 196 # ensure we have enough samples for stratification:
115 # Each class must have at least as many samples as the number of splits, 197 # Each class must have at least as many samples as the number of nonzero splits,
116 # so that each split can receive at least one sample per class. 198 # so that each split can receive at least one sample per class.
117 min_samples_required = len(split_probabilities) 199 active_splits = [p for p in split_probabilities if p > 0]
200 min_samples_required = len(active_splits)
118 if min_samples_per_class < min_samples_required: 201 if min_samples_per_class < min_samples_required:
119 logger.warning( 202 logger.warning(
120 f"Insufficient samples per class for stratification (min: {min_samples_per_class}, required: {min_samples_required}); using random split" 203 f"Insufficient samples per class for stratification (min: {min_samples_per_class}, required: {min_samples_required}); using random split"
121 ) 204 )
122 # fall back to simple random assignment 205 # fall back to simple random assignment (group-aware if requested)
123 indices = out.index.tolist() 206 indices = out.index.tolist()
124 np.random.seed(random_state) 207 rng = np.random.RandomState(random_state)
125 np.random.shuffle(indices) 208
126 209 if group_column:
127 n_total = len(indices) 210 group_series = out[group_column].astype(object)
128 n_train = int(n_total * split_probabilities[0]) 211 missing_mask = group_series.isna()
129 n_val = int(n_total * split_probabilities[1]) 212 if missing_mask.any():
213 group_series = group_series.copy()
214 group_series.loc[missing_mask] = [
215 f"__missing__{idx}" for idx in group_series.index[missing_mask]
216 ]
217 group_to_indices = {}
218 for idx, group_id in group_series.items():
219 group_to_indices.setdefault(group_id, []).append(idx)
220 group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
221 rng.shuffle(group_ids)
222
223 targets = _allocate_split_counts(len(indices), split_probabilities)
224 counts = [0 for _ in split_probabilities]
225 active = [i for i, p in enumerate(split_probabilities) if p > 0]
226 train_idx = []
227 val_idx = []
228 test_idx = []
229
230 for group_id in group_ids:
231 size = len(group_to_indices[group_id])
232 split_idx = _choose_split(counts, targets, active)
233 counts[split_idx] += size
234 if split_idx == 0:
235 train_idx.extend(group_to_indices[group_id])
236 elif split_idx == 1:
237 val_idx.extend(group_to_indices[group_id])
238 else:
239 test_idx.extend(group_to_indices[group_id])
240
241 out.loc[train_idx, split_column] = 0
242 out.loc[val_idx, split_column] = 1
243 out.loc[test_idx, split_column] = 2
244 return out.astype({split_column: int})
245
246 rng.shuffle(indices)
247
248 targets = _allocate_split_counts(len(indices), split_probabilities)
249 n_train, n_val, n_test = targets
130 250
131 out.loc[indices[:n_train], split_column] = 0 251 out.loc[indices[:n_train], split_column] = 0
132 out.loc[indices[n_train:n_train + n_val], split_column] = 1 252 out.loc[indices[n_train:n_train + n_val], split_column] = 1
133 out.loc[indices[n_train + n_val:], split_column] = 2 253 out.loc[indices[n_train + n_val:n_train + n_val + n_test], split_column] = 2
134 254
135 return out.astype({split_column: int}) 255 return out.astype({split_column: int})
136 256
137 logger.info("Using stratified random split for train/validation/test sets") 257 if group_column:
138 258 logger.info(
139 # first split: separate test set 259 "Using stratified random split for train/validation/test sets (grouped by '%s')",
140 train_val_idx, test_idx = train_test_split( 260 group_column,
141 out.index.tolist(), 261 )
142 test_size=split_probabilities[2], 262 rng = np.random.RandomState(random_state)
143 random_state=random_state, 263
144 stratify=out[label_column], 264 group_series = out[group_column].astype(object)
145 ) 265 missing_mask = group_series.isna()
146 266 if missing_mask.any():
147 # second split: separate training and validation from remaining data 267 group_series = group_series.copy()
148 val_size_adjusted = split_probabilities[1] / ( 268 group_series.loc[missing_mask] = [
149 split_probabilities[0] + split_probabilities[1] 269 f"__missing__{idx}" for idx in group_series.index[missing_mask]
150 ) 270 ]
151 train_idx, val_idx = train_test_split( 271
152 train_val_idx, 272 group_to_indices = {}
153 test_size=val_size_adjusted, 273 for idx, group_id in group_series.items():
154 random_state=random_state, 274 group_to_indices.setdefault(group_id, []).append(idx)
155 stratify=out.loc[train_val_idx, label_column] if label_column and label_column in out.columns else None, 275
156 ) 276 group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
277 group_labels = {}
278 mixed_groups = []
279 label_series = out[label_column]
280
281 for group_id in group_ids:
282 labels = label_series.loc[group_to_indices[group_id]].dropna().unique()
283 if len(labels) == 1:
284 group_labels[group_id] = labels[0]
285 elif len(labels) == 0:
286 group_labels[group_id] = None
287 else:
288 mode_vals = label_series.loc[group_to_indices[group_id]].mode(dropna=True)
289 group_labels[group_id] = mode_vals.iloc[0] if not mode_vals.empty else labels[0]
290 mixed_groups.append(group_id)
291
292 if mixed_groups:
293 logger.warning(
294 "Detected %d groups with multiple labels; using the most common label per group for stratification.",
295 len(mixed_groups),
296 )
297
298 train_idx = []
299 val_idx = []
300 test_idx = []
301 active = [i for i, p in enumerate(split_probabilities) if p > 0]
302
303 for label_value in sorted(label_counts.index.tolist(), key=lambda x: str(x)):
304 label_groups = [g for g in group_ids if group_labels.get(g) == label_value]
305 if not label_groups:
306 continue
307 rng.shuffle(label_groups)
308 label_total = sum(len(group_to_indices[g]) for g in label_groups)
309 targets = _allocate_split_counts(label_total, split_probabilities)
310 counts = [0 for _ in split_probabilities]
311
312 for group_id in label_groups:
313 size = len(group_to_indices[group_id])
314 split_idx = _choose_split(counts, targets, active)
315 counts[split_idx] += size
316 if split_idx == 0:
317 train_idx.extend(group_to_indices[group_id])
318 elif split_idx == 1:
319 val_idx.extend(group_to_indices[group_id])
320 else:
321 test_idx.extend(group_to_indices[group_id])
322
323 # Assign groups without a label (or missing labels) using overall targets.
324 unlabeled_groups = [g for g in group_ids if group_labels.get(g) is None]
325 if unlabeled_groups:
326 rng.shuffle(unlabeled_groups)
327 total_unlabeled = sum(len(group_to_indices[g]) for g in unlabeled_groups)
328 targets = _allocate_split_counts(total_unlabeled, split_probabilities)
329 counts = [0 for _ in split_probabilities]
330 for group_id in unlabeled_groups:
331 size = len(group_to_indices[group_id])
332 split_idx = _choose_split(counts, targets, active)
333 counts[split_idx] += size
334 if split_idx == 0:
335 train_idx.extend(group_to_indices[group_id])
336 elif split_idx == 1:
337 val_idx.extend(group_to_indices[group_id])
338 else:
339 test_idx.extend(group_to_indices[group_id])
340
341 out.loc[train_idx, split_column] = 0
342 out.loc[val_idx, split_column] = 1
343 out.loc[test_idx, split_column] = 2
344
345 logger.info("Successfully applied stratified random split")
346 logger.info(
347 f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}"
348 )
349 return out.astype({split_column: int})
350
351 logger.info("Using stratified random split for train/validation/test sets (per-class allocation)")
352
353 rng = np.random.RandomState(random_state)
354 label_values = sorted(label_counts.index.tolist(), key=lambda x: str(x))
355 train_idx = []
356 val_idx = []
357 test_idx = []
358
359 for label_value in label_values:
360 label_indices = out.index[out[label_column] == label_value].tolist()
361 rng.shuffle(label_indices)
362 n_train, n_val, n_test = _allocate_split_counts(len(label_indices), split_probabilities)
363 train_idx.extend(label_indices[:n_train])
364 val_idx.extend(label_indices[n_train:n_train + n_val])
365 test_idx.extend(label_indices[n_train + n_val:n_train + n_val + n_test])
157 366
158 # assign split values 367 # assign split values
159 out.loc[train_idx, split_column] = 0 368 out.loc[train_idx, split_column] = 0
160 out.loc[val_idx, split_column] = 1 369 out.loc[val_idx, split_column] = 1
161 out.loc[test_idx, split_column] = 2 370 out.loc[test_idx, split_column] = 2