Mercurial > repos > goeckslab > image_learner
comparison split_data.py @ 20:64872c48a21f draft
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
| author | goeckslab |
|---|---|
| date | Tue, 06 Jan 2026 15:35:11 +0000 |
| parents | bcfa2e234a80 |
| children | |
comparison
Legend: equal · deleted · inserted · replaced
| 19:c460abae83eb | 20:64872c48a21f |
|---|---|
| 79 df: pd.DataFrame, | 79 df: pd.DataFrame, |
| 80 split_column: str, | 80 split_column: str, |
| 81 split_probabilities: list = [0.7, 0.1, 0.2], | 81 split_probabilities: list = [0.7, 0.1, 0.2], |
| 82 random_state: int = 42, | 82 random_state: int = 42, |
| 83 label_column: Optional[str] = None, | 83 label_column: Optional[str] = None, |
| 84 group_column: Optional[str] = None, | |
| 84 ) -> pd.DataFrame: | 85 ) -> pd.DataFrame: |
| 85 """Create a stratified random split when no split column exists.""" | 86 """Create a stratified random split when no split column exists.""" |
| 86 out = df.copy() | 87 out = df.copy() |
| 87 | 88 |
| 88 # initialize split column | 89 # initialize split column |
| 89 out[split_column] = 0 | 90 out[split_column] = 0 |
| 90 | 91 |
| 92 if group_column and group_column not in out.columns: | |
| 93 logger.warning( | |
| 94 "Group column '%s' not found in data; proceeding without group-aware split.", | |
| 95 group_column, | |
| 96 ) | |
| 97 group_column = None | |
| 98 | |
| 99 def _allocate_split_counts(n_total: int, probs: list) -> list: | |
| 100 """Allocate exact split counts using largest remainder rounding.""" | |
| 101 if n_total <= 0: | |
| 102 return [0 for _ in probs] | |
| 103 | |
| 104 counts = [0 for _ in probs] | |
| 105 active = [i for i, p in enumerate(probs) if p > 0] | |
| 106 remainder = n_total | |
| 107 | |
| 108 if active and n_total >= len(active): | |
| 109 for i in active: | |
| 110 counts[i] = 1 | |
| 111 remainder -= len(active) | |
| 112 | |
| 113 if remainder > 0: | |
| 114 probs_arr = np.array(probs, dtype=float) | |
| 115 probs_arr = probs_arr / probs_arr.sum() | |
| 116 raw = remainder * probs_arr | |
| 117 floors = np.floor(raw).astype(int) | |
| 118 for i, value in enumerate(floors.tolist()): | |
| 119 counts[i] += value | |
| 120 leftover = remainder - int(floors.sum()) | |
| 121 if leftover > 0 and active: | |
| 122 frac = raw - floors | |
| 123 order = sorted(active, key=lambda i: (-frac[i], i)) | |
| 124 for i in range(leftover): | |
| 125 counts[order[i % len(order)]] += 1 | |
| 126 | |
| 127 return counts | |
| 128 | |
| 129 def _choose_split(counts: list, targets: list, active: list) -> int: | |
| 130 remaining = [targets[i] - counts[i] for i in range(len(targets))] | |
| 131 best = max(active, key=lambda i: (remaining[i], -counts[i], -targets[i])) | |
| 132 if remaining[best] <= 0: | |
| 133 best = min(active, key=lambda i: counts[i]) | |
| 134 return best | |
| 135 | |
| 91 if not label_column or label_column not in out.columns: | 136 if not label_column or label_column not in out.columns: |
| 92 logger.warning( | 137 logger.warning( |
| 93 "No label column found; using random split without stratification" | 138 "No label column found; using random split without stratification" |
| 94 ) | 139 ) |
| 95 # fall back to simple random assignment | 140 # fall back to random assignment (group-aware if requested) |
| 96 indices = out.index.tolist() | 141 indices = out.index.tolist() |
| 97 np.random.seed(random_state) | 142 rng = np.random.RandomState(random_state) |
| 98 np.random.shuffle(indices) | 143 |
| 99 | 144 if group_column: |
| 100 n_total = len(indices) | 145 group_series = out[group_column].astype(object) |
| 101 n_train = int(n_total * split_probabilities[0]) | 146 missing_mask = group_series.isna() |
| 102 n_val = int(n_total * split_probabilities[1]) | 147 if missing_mask.any(): |
| 148 group_series = group_series.copy() | |
| 149 group_series.loc[missing_mask] = [ | |
| 150 f"__missing__{idx}" for idx in group_series.index[missing_mask] | |
| 151 ] | |
| 152 group_to_indices = {} | |
| 153 for idx, group_id in group_series.items(): | |
| 154 group_to_indices.setdefault(group_id, []).append(idx) | |
| 155 group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x)) | |
| 156 rng.shuffle(group_ids) | |
| 157 | |
| 158 targets = _allocate_split_counts(len(indices), split_probabilities) | |
| 159 counts = [0 for _ in split_probabilities] | |
| 160 active = [i for i, p in enumerate(split_probabilities) if p > 0] | |
| 161 train_idx = [] | |
| 162 val_idx = [] | |
| 163 test_idx = [] | |
| 164 | |
| 165 for group_id in group_ids: | |
| 166 size = len(group_to_indices[group_id]) | |
| 167 split_idx = _choose_split(counts, targets, active) | |
| 168 counts[split_idx] += size | |
| 169 if split_idx == 0: | |
| 170 train_idx.extend(group_to_indices[group_id]) | |
| 171 elif split_idx == 1: | |
| 172 val_idx.extend(group_to_indices[group_id]) | |
| 173 else: | |
| 174 test_idx.extend(group_to_indices[group_id]) | |
| 175 | |
| 176 out.loc[train_idx, split_column] = 0 | |
| 177 out.loc[val_idx, split_column] = 1 | |
| 178 out.loc[test_idx, split_column] = 2 | |
| 179 return out.astype({split_column: int}) | |
| 180 | |
| 181 rng.shuffle(indices) | |
| 182 | |
| 183 targets = _allocate_split_counts(len(indices), split_probabilities) | |
| 184 n_train, n_val, n_test = targets | |
| 103 | 185 |
| 104 out.loc[indices[:n_train], split_column] = 0 | 186 out.loc[indices[:n_train], split_column] = 0 |
| 105 out.loc[indices[n_train:n_train + n_val], split_column] = 1 | 187 out.loc[indices[n_train:n_train + n_val], split_column] = 1 |
| 106 out.loc[indices[n_train + n_val:], split_column] = 2 | 188 out.loc[indices[n_train + n_val:n_train + n_val + n_test], split_column] = 2 |
| 107 | 189 |
| 108 return out.astype({split_column: int}) | 190 return out.astype({split_column: int}) |
| 109 | 191 |
| 110 # check if stratification is possible | 192 # check if stratification is possible |
| 111 label_counts = out[label_column].value_counts() | 193 label_counts = out[label_column].value_counts() |
| 112 min_samples_per_class = label_counts.min() | 194 min_samples_per_class = label_counts.min() |
| 113 | 195 |
| 114 # ensure we have enough samples for stratification: | 196 # ensure we have enough samples for stratification: |
| 115 # Each class must have at least as many samples as the number of splits, | 197 # Each class must have at least as many samples as the number of nonzero splits, |
| 116 # so that each split can receive at least one sample per class. | 198 # so that each split can receive at least one sample per class. |
| 117 min_samples_required = len(split_probabilities) | 199 active_splits = [p for p in split_probabilities if p > 0] |
| 200 min_samples_required = len(active_splits) | |
| 118 if min_samples_per_class < min_samples_required: | 201 if min_samples_per_class < min_samples_required: |
| 119 logger.warning( | 202 logger.warning( |
| 120 f"Insufficient samples per class for stratification (min: {min_samples_per_class}, required: {min_samples_required}); using random split" | 203 f"Insufficient samples per class for stratification (min: {min_samples_per_class}, required: {min_samples_required}); using random split" |
| 121 ) | 204 ) |
| 122 # fall back to simple random assignment | 205 # fall back to simple random assignment (group-aware if requested) |
| 123 indices = out.index.tolist() | 206 indices = out.index.tolist() |
| 124 np.random.seed(random_state) | 207 rng = np.random.RandomState(random_state) |
| 125 np.random.shuffle(indices) | 208 |
| 126 | 209 if group_column: |
| 127 n_total = len(indices) | 210 group_series = out[group_column].astype(object) |
| 128 n_train = int(n_total * split_probabilities[0]) | 211 missing_mask = group_series.isna() |
| 129 n_val = int(n_total * split_probabilities[1]) | 212 if missing_mask.any(): |
| 213 group_series = group_series.copy() | |
| 214 group_series.loc[missing_mask] = [ | |
| 215 f"__missing__{idx}" for idx in group_series.index[missing_mask] | |
| 216 ] | |
| 217 group_to_indices = {} | |
| 218 for idx, group_id in group_series.items(): | |
| 219 group_to_indices.setdefault(group_id, []).append(idx) | |
| 220 group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x)) | |
| 221 rng.shuffle(group_ids) | |
| 222 | |
| 223 targets = _allocate_split_counts(len(indices), split_probabilities) | |
| 224 counts = [0 for _ in split_probabilities] | |
| 225 active = [i for i, p in enumerate(split_probabilities) if p > 0] | |
| 226 train_idx = [] | |
| 227 val_idx = [] | |
| 228 test_idx = [] | |
| 229 | |
| 230 for group_id in group_ids: | |
| 231 size = len(group_to_indices[group_id]) | |
| 232 split_idx = _choose_split(counts, targets, active) | |
| 233 counts[split_idx] += size | |
| 234 if split_idx == 0: | |
| 235 train_idx.extend(group_to_indices[group_id]) | |
| 236 elif split_idx == 1: | |
| 237 val_idx.extend(group_to_indices[group_id]) | |
| 238 else: | |
| 239 test_idx.extend(group_to_indices[group_id]) | |
| 240 | |
| 241 out.loc[train_idx, split_column] = 0 | |
| 242 out.loc[val_idx, split_column] = 1 | |
| 243 out.loc[test_idx, split_column] = 2 | |
| 244 return out.astype({split_column: int}) | |
| 245 | |
| 246 rng.shuffle(indices) | |
| 247 | |
| 248 targets = _allocate_split_counts(len(indices), split_probabilities) | |
| 249 n_train, n_val, n_test = targets | |
| 130 | 250 |
| 131 out.loc[indices[:n_train], split_column] = 0 | 251 out.loc[indices[:n_train], split_column] = 0 |
| 132 out.loc[indices[n_train:n_train + n_val], split_column] = 1 | 252 out.loc[indices[n_train:n_train + n_val], split_column] = 1 |
| 133 out.loc[indices[n_train + n_val:], split_column] = 2 | 253 out.loc[indices[n_train + n_val:n_train + n_val + n_test], split_column] = 2 |
| 134 | 254 |
| 135 return out.astype({split_column: int}) | 255 return out.astype({split_column: int}) |
| 136 | 256 |
| 137 logger.info("Using stratified random split for train/validation/test sets") | 257 if group_column: |
| 138 | 258 logger.info( |
| 139 # first split: separate test set | 259 "Using stratified random split for train/validation/test sets (grouped by '%s')", |
| 140 train_val_idx, test_idx = train_test_split( | 260 group_column, |
| 141 out.index.tolist(), | 261 ) |
| 142 test_size=split_probabilities[2], | 262 rng = np.random.RandomState(random_state) |
| 143 random_state=random_state, | 263 |
| 144 stratify=out[label_column], | 264 group_series = out[group_column].astype(object) |
| 145 ) | 265 missing_mask = group_series.isna() |
| 146 | 266 if missing_mask.any(): |
| 147 # second split: separate training and validation from remaining data | 267 group_series = group_series.copy() |
| 148 val_size_adjusted = split_probabilities[1] / ( | 268 group_series.loc[missing_mask] = [ |
| 149 split_probabilities[0] + split_probabilities[1] | 269 f"__missing__{idx}" for idx in group_series.index[missing_mask] |
| 150 ) | 270 ] |
| 151 train_idx, val_idx = train_test_split( | 271 |
| 152 train_val_idx, | 272 group_to_indices = {} |
| 153 test_size=val_size_adjusted, | 273 for idx, group_id in group_series.items(): |
| 154 random_state=random_state, | 274 group_to_indices.setdefault(group_id, []).append(idx) |
| 155 stratify=out.loc[train_val_idx, label_column] if label_column and label_column in out.columns else None, | 275 |
| 156 ) | 276 group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x)) |
| 277 group_labels = {} | |
| 278 mixed_groups = [] | |
| 279 label_series = out[label_column] | |
| 280 | |
| 281 for group_id in group_ids: | |
| 282 labels = label_series.loc[group_to_indices[group_id]].dropna().unique() | |
| 283 if len(labels) == 1: | |
| 284 group_labels[group_id] = labels[0] | |
| 285 elif len(labels) == 0: | |
| 286 group_labels[group_id] = None | |
| 287 else: | |
| 288 mode_vals = label_series.loc[group_to_indices[group_id]].mode(dropna=True) | |
| 289 group_labels[group_id] = mode_vals.iloc[0] if not mode_vals.empty else labels[0] | |
| 290 mixed_groups.append(group_id) | |
| 291 | |
| 292 if mixed_groups: | |
| 293 logger.warning( | |
| 294 "Detected %d groups with multiple labels; using the most common label per group for stratification.", | |
| 295 len(mixed_groups), | |
| 296 ) | |
| 297 | |
| 298 train_idx = [] | |
| 299 val_idx = [] | |
| 300 test_idx = [] | |
| 301 active = [i for i, p in enumerate(split_probabilities) if p > 0] | |
| 302 | |
| 303 for label_value in sorted(label_counts.index.tolist(), key=lambda x: str(x)): | |
| 304 label_groups = [g for g in group_ids if group_labels.get(g) == label_value] | |
| 305 if not label_groups: | |
| 306 continue | |
| 307 rng.shuffle(label_groups) | |
| 308 label_total = sum(len(group_to_indices[g]) for g in label_groups) | |
| 309 targets = _allocate_split_counts(label_total, split_probabilities) | |
| 310 counts = [0 for _ in split_probabilities] | |
| 311 | |
| 312 for group_id in label_groups: | |
| 313 size = len(group_to_indices[group_id]) | |
| 314 split_idx = _choose_split(counts, targets, active) | |
| 315 counts[split_idx] += size | |
| 316 if split_idx == 0: | |
| 317 train_idx.extend(group_to_indices[group_id]) | |
| 318 elif split_idx == 1: | |
| 319 val_idx.extend(group_to_indices[group_id]) | |
| 320 else: | |
| 321 test_idx.extend(group_to_indices[group_id]) | |
| 322 | |
| 323 # Assign groups without a label (or missing labels) using overall targets. | |
| 324 unlabeled_groups = [g for g in group_ids if group_labels.get(g) is None] | |
| 325 if unlabeled_groups: | |
| 326 rng.shuffle(unlabeled_groups) | |
| 327 total_unlabeled = sum(len(group_to_indices[g]) for g in unlabeled_groups) | |
| 328 targets = _allocate_split_counts(total_unlabeled, split_probabilities) | |
| 329 counts = [0 for _ in split_probabilities] | |
| 330 for group_id in unlabeled_groups: | |
| 331 size = len(group_to_indices[group_id]) | |
| 332 split_idx = _choose_split(counts, targets, active) | |
| 333 counts[split_idx] += size | |
| 334 if split_idx == 0: | |
| 335 train_idx.extend(group_to_indices[group_id]) | |
| 336 elif split_idx == 1: | |
| 337 val_idx.extend(group_to_indices[group_id]) | |
| 338 else: | |
| 339 test_idx.extend(group_to_indices[group_id]) | |
| 340 | |
| 341 out.loc[train_idx, split_column] = 0 | |
| 342 out.loc[val_idx, split_column] = 1 | |
| 343 out.loc[test_idx, split_column] = 2 | |
| 344 | |
| 345 logger.info("Successfully applied stratified random split") | |
| 346 logger.info( | |
| 347 f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}" | |
| 348 ) | |
| 349 return out.astype({split_column: int}) | |
| 350 | |
| 351 logger.info("Using stratified random split for train/validation/test sets (per-class allocation)") | |
| 352 | |
| 353 rng = np.random.RandomState(random_state) | |
| 354 label_values = sorted(label_counts.index.tolist(), key=lambda x: str(x)) | |
| 355 train_idx = [] | |
| 356 val_idx = [] | |
| 357 test_idx = [] | |
| 358 | |
| 359 for label_value in label_values: | |
| 360 label_indices = out.index[out[label_column] == label_value].tolist() | |
| 361 rng.shuffle(label_indices) | |
| 362 n_train, n_val, n_test = _allocate_split_counts(len(label_indices), split_probabilities) | |
| 363 train_idx.extend(label_indices[:n_train]) | |
| 364 val_idx.extend(label_indices[n_train:n_train + n_val]) | |
| 365 test_idx.extend(label_indices[n_train + n_val:n_train + n_val + n_test]) | |
| 157 | 366 |
| 158 # assign split values | 367 # assign split values |
| 159 out.loc[train_idx, split_column] = 0 | 368 out.loc[train_idx, split_column] = 0 |
| 160 out.loc[val_idx, split_column] = 1 | 369 out.loc[val_idx, split_column] = 1 |
| 161 out.loc[test_idx, split_column] = 2 | 370 out.loc[test_idx, split_column] = 2 |
