diff split_data.py @ 20:64872c48a21f draft
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
| field | value |
|---|---|
| author | goeckslab |
| date | Tue, 06 Jan 2026 15:35:11 +0000 |
| parents | bcfa2e234a80 |
| children | |
```diff
--- a/split_data.py	Thu Dec 18 16:59:58 2025 +0000
+++ b/split_data.py	Tue Jan 06 15:35:11 2026 +0000
@@ -81,6 +81,7 @@
     split_probabilities: list = [0.7, 0.1, 0.2],
     random_state: int = 42,
     label_column: Optional[str] = None,
+    group_column: Optional[str] = None,
 ) -> pd.DataFrame:
     """Create a stratified random split when no split column exists."""
     out = df.copy()
@@ -88,22 +89,103 @@
     # initialize split column
     out[split_column] = 0
 
+    if group_column and group_column not in out.columns:
+        logger.warning(
+            "Group column '%s' not found in data; proceeding without group-aware split.",
+            group_column,
+        )
+        group_column = None
+
+    def _allocate_split_counts(n_total: int, probs: list) -> list:
+        """Allocate exact split counts using largest remainder rounding."""
+        if n_total <= 0:
+            return [0 for _ in probs]
+
+        counts = [0 for _ in probs]
+        active = [i for i, p in enumerate(probs) if p > 0]
+        remainder = n_total
+
+        if active and n_total >= len(active):
+            for i in active:
+                counts[i] = 1
+            remainder -= len(active)
+
+        if remainder > 0:
+            probs_arr = np.array(probs, dtype=float)
+            probs_arr = probs_arr / probs_arr.sum()
+            raw = remainder * probs_arr
+            floors = np.floor(raw).astype(int)
+            for i, value in enumerate(floors.tolist()):
+                counts[i] += value
+            leftover = remainder - int(floors.sum())
+            if leftover > 0 and active:
+                frac = raw - floors
+                order = sorted(active, key=lambda i: (-frac[i], i))
+                for i in range(leftover):
+                    counts[order[i % len(order)]] += 1
+
+        return counts
+
+    def _choose_split(counts: list, targets: list, active: list) -> int:
+        remaining = [targets[i] - counts[i] for i in range(len(targets))]
+        best = max(active, key=lambda i: (remaining[i], -counts[i], -targets[i]))
+        if remaining[best] <= 0:
+            best = min(active, key=lambda i: counts[i])
+        return best
+
     if not label_column or label_column not in out.columns:
         logger.warning(
             "No label column found; using random split without stratification"
         )
-        # fall back to simple random assignment
+        # fall back to random assignment (group-aware if requested)
         indices = out.index.tolist()
-        np.random.seed(random_state)
-        np.random.shuffle(indices)
+        rng = np.random.RandomState(random_state)
+
+        if group_column:
+            group_series = out[group_column].astype(object)
+            missing_mask = group_series.isna()
+            if missing_mask.any():
+                group_series = group_series.copy()
+                group_series.loc[missing_mask] = [
+                    f"__missing__{idx}" for idx in group_series.index[missing_mask]
+                ]
+            group_to_indices = {}
+            for idx, group_id in group_series.items():
+                group_to_indices.setdefault(group_id, []).append(idx)
+            group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
+            rng.shuffle(group_ids)
 
-        n_total = len(indices)
-        n_train = int(n_total * split_probabilities[0])
-        n_val = int(n_total * split_probabilities[1])
+            targets = _allocate_split_counts(len(indices), split_probabilities)
+            counts = [0 for _ in split_probabilities]
+            active = [i for i, p in enumerate(split_probabilities) if p > 0]
+            train_idx = []
+            val_idx = []
+            test_idx = []
+
+            for group_id in group_ids:
+                size = len(group_to_indices[group_id])
+                split_idx = _choose_split(counts, targets, active)
+                counts[split_idx] += size
+                if split_idx == 0:
+                    train_idx.extend(group_to_indices[group_id])
+                elif split_idx == 1:
+                    val_idx.extend(group_to_indices[group_id])
+                else:
+                    test_idx.extend(group_to_indices[group_id])
+
+            out.loc[train_idx, split_column] = 0
+            out.loc[val_idx, split_column] = 1
+            out.loc[test_idx, split_column] = 2
+            return out.astype({split_column: int})
+
+        rng.shuffle(indices)
+
+        targets = _allocate_split_counts(len(indices), split_probabilities)
+        n_train, n_val, n_test = targets
         out.loc[indices[:n_train], split_column] = 0
         out.loc[indices[n_train:n_train + n_val], split_column] = 1
-        out.loc[indices[n_train + n_val:], split_column] = 2
+        out.loc[indices[n_train + n_val:n_train + n_val + n_test], split_column] = 2
         return out.astype({split_column: int})
 
@@ -112,48 +194,175 @@
     min_samples_per_class = label_counts.min()
 
     # ensure we have enough samples for stratification:
-    # Each class must have at least as many samples as the number of splits,
+    # Each class must have at least as many samples as the number of nonzero splits,
     # so that each split can receive at least one sample per class.
-    min_samples_required = len(split_probabilities)
+    active_splits = [p for p in split_probabilities if p > 0]
+    min_samples_required = len(active_splits)
 
     if min_samples_per_class < min_samples_required:
         logger.warning(
             f"Insufficient samples per class for stratification (min: {min_samples_per_class}, required: {min_samples_required}); using random split"
         )
-        # fall back to simple random assignment
+        # fall back to simple random assignment (group-aware if requested)
         indices = out.index.tolist()
-        np.random.seed(random_state)
-        np.random.shuffle(indices)
+        rng = np.random.RandomState(random_state)
+
+        if group_column:
+            group_series = out[group_column].astype(object)
+            missing_mask = group_series.isna()
+            if missing_mask.any():
+                group_series = group_series.copy()
+                group_series.loc[missing_mask] = [
+                    f"__missing__{idx}" for idx in group_series.index[missing_mask]
+                ]
+            group_to_indices = {}
+            for idx, group_id in group_series.items():
+                group_to_indices.setdefault(group_id, []).append(idx)
+            group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
+            rng.shuffle(group_ids)
 
-        n_total = len(indices)
-        n_train = int(n_total * split_probabilities[0])
-        n_val = int(n_total * split_probabilities[1])
+            targets = _allocate_split_counts(len(indices), split_probabilities)
+            counts = [0 for _ in split_probabilities]
+            active = [i for i, p in enumerate(split_probabilities) if p > 0]
+            train_idx = []
+            val_idx = []
+            test_idx = []
+
+            for group_id in group_ids:
+                size = len(group_to_indices[group_id])
+                split_idx = _choose_split(counts, targets, active)
+                counts[split_idx] += size
+                if split_idx == 0:
+                    train_idx.extend(group_to_indices[group_id])
+                elif split_idx == 1:
+                    val_idx.extend(group_to_indices[group_id])
+                else:
+                    test_idx.extend(group_to_indices[group_id])
+
+            out.loc[train_idx, split_column] = 0
+            out.loc[val_idx, split_column] = 1
+            out.loc[test_idx, split_column] = 2
+            return out.astype({split_column: int})
+
+        rng.shuffle(indices)
+
+        targets = _allocate_split_counts(len(indices), split_probabilities)
+        n_train, n_val, n_test = targets
         out.loc[indices[:n_train], split_column] = 0
         out.loc[indices[n_train:n_train + n_val], split_column] = 1
-        out.loc[indices[n_train + n_val:], split_column] = 2
+        out.loc[indices[n_train + n_val:n_train + n_val + n_test], split_column] = 2
         return out.astype({split_column: int})
 
-    logger.info("Using stratified random split for train/validation/test sets")
+    if group_column:
+        logger.info(
+            "Using stratified random split for train/validation/test sets (grouped by '%s')",
+            group_column,
+        )
+        rng = np.random.RandomState(random_state)
+
+        group_series = out[group_column].astype(object)
+        missing_mask = group_series.isna()
+        if missing_mask.any():
+            group_series = group_series.copy()
+            group_series.loc[missing_mask] = [
+                f"__missing__{idx}"
+                for idx in group_series.index[missing_mask]
+            ]
+
+        group_to_indices = {}
+        for idx, group_id in group_series.items():
+            group_to_indices.setdefault(group_id, []).append(idx)
+
+        group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
+        group_labels = {}
+        mixed_groups = []
+        label_series = out[label_column]
 
-    # first split: separate test set
-    train_val_idx, test_idx = train_test_split(
-        out.index.tolist(),
-        test_size=split_probabilities[2],
-        random_state=random_state,
-        stratify=out[label_column],
-    )
+        for group_id in group_ids:
+            labels = label_series.loc[group_to_indices[group_id]].dropna().unique()
+            if len(labels) == 1:
+                group_labels[group_id] = labels[0]
+            elif len(labels) == 0:
+                group_labels[group_id] = None
+            else:
+                mode_vals = label_series.loc[group_to_indices[group_id]].mode(dropna=True)
+                group_labels[group_id] = mode_vals.iloc[0] if not mode_vals.empty else labels[0]
+                mixed_groups.append(group_id)
+
+        if mixed_groups:
+            logger.warning(
+                "Detected %d groups with multiple labels; using the most common label per group for stratification.",
+                len(mixed_groups),
+            )
+
+        train_idx = []
+        val_idx = []
+        test_idx = []
+        active = [i for i, p in enumerate(split_probabilities) if p > 0]
+
+        for label_value in sorted(label_counts.index.tolist(), key=lambda x: str(x)):
+            label_groups = [g for g in group_ids if group_labels.get(g) == label_value]
+            if not label_groups:
+                continue
+            rng.shuffle(label_groups)
+            label_total = sum(len(group_to_indices[g]) for g in label_groups)
+            targets = _allocate_split_counts(label_total, split_probabilities)
+            counts = [0 for _ in split_probabilities]
 
-    # second split: separate training and validation from remaining data
-    val_size_adjusted = split_probabilities[1] / (
-        split_probabilities[0] + split_probabilities[1]
-    )
-    train_idx, val_idx = train_test_split(
-        train_val_idx,
-        test_size=val_size_adjusted,
-        random_state=random_state,
-        stratify=out.loc[train_val_idx, label_column] if label_column and label_column in out.columns else None,
-    )
+            for group_id in label_groups:
+                size = len(group_to_indices[group_id])
+                split_idx = _choose_split(counts, targets, active)
+                counts[split_idx] += size
+                if split_idx == 0:
+                    train_idx.extend(group_to_indices[group_id])
+                elif split_idx == 1:
+                    val_idx.extend(group_to_indices[group_id])
+                else:
+                    test_idx.extend(group_to_indices[group_id])
+
+        # Assign groups without a label (or missing labels) using overall targets.
+        unlabeled_groups = [g for g in group_ids if group_labels.get(g) is None]
+        if unlabeled_groups:
+            rng.shuffle(unlabeled_groups)
+            total_unlabeled = sum(len(group_to_indices[g]) for g in unlabeled_groups)
+            targets = _allocate_split_counts(total_unlabeled, split_probabilities)
+            counts = [0 for _ in split_probabilities]
+            for group_id in unlabeled_groups:
+                size = len(group_to_indices[group_id])
+                split_idx = _choose_split(counts, targets, active)
+                counts[split_idx] += size
+                if split_idx == 0:
+                    train_idx.extend(group_to_indices[group_id])
+                elif split_idx == 1:
+                    val_idx.extend(group_to_indices[group_id])
+                else:
+                    test_idx.extend(group_to_indices[group_id])
+
+        out.loc[train_idx, split_column] = 0
+        out.loc[val_idx, split_column] = 1
+        out.loc[test_idx, split_column] = 2
+
+        logger.info("Successfully applied stratified random split")
+        logger.info(
+            f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}"
+        )
+        return out.astype({split_column: int})
+
+    logger.info("Using stratified random split for train/validation/test sets (per-class allocation)")
+
+    rng = np.random.RandomState(random_state)
+    label_values = sorted(label_counts.index.tolist(), key=lambda x: str(x))
+    train_idx = []
+    val_idx = []
+    test_idx = []
+
+    for label_value in label_values:
+        label_indices = out.index[out[label_column] == label_value].tolist()
+        rng.shuffle(label_indices)
+        n_train, n_val, n_test = _allocate_split_counts(len(label_indices), split_probabilities)
+        train_idx.extend(label_indices[:n_train])
+        val_idx.extend(label_indices[n_train:n_train + n_val])
+        test_idx.extend(label_indices[n_train + n_val:n_train + n_val + n_test])
 
     # assign split values
     out.loc[train_idx, split_column] = 0
```
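The heart of the patch is the pair of private helpers: `_allocate_split_counts` turns the split probabilities into exact integer counts via largest-remainder rounding (guaranteeing every nonzero split at least one sample), and `_choose_split` greedily sends each whole group to whichever split is furthest below its target. Below is a minimal standalone sketch of that same logic, lifted out of the diff so it can be run in isolation; the un-underscored function names are local copies, and the group sizes and probabilities in the demo are illustrative, not taken from the repository:

```python
import numpy as np


def allocate_split_counts(n_total: int, probs: list) -> list:
    """Largest-remainder rounding, mirroring the patch's _allocate_split_counts."""
    if n_total <= 0:
        return [0 for _ in probs]
    counts = [0 for _ in probs]
    active = [i for i, p in enumerate(probs) if p > 0]
    remainder = n_total
    # Every split with a nonzero probability is guaranteed at least one sample.
    if active and n_total >= len(active):
        for i in active:
            counts[i] = 1
        remainder -= len(active)
    if remainder > 0:
        probs_arr = np.array(probs, dtype=float)
        probs_arr = probs_arr / probs_arr.sum()
        raw = remainder * probs_arr
        floors = np.floor(raw).astype(int)
        for i, value in enumerate(floors.tolist()):
            counts[i] += value
        # Hand the leftover units to the largest fractional parts first.
        leftover = remainder - int(floors.sum())
        if leftover > 0 and active:
            frac = raw - floors
            order = sorted(active, key=lambda i: (-frac[i], i))
            for i in range(leftover):
                counts[order[i % len(order)]] += 1
    return counts


def choose_split(counts: list, targets: list, active: list) -> int:
    """Greedy pick mirroring the patch's _choose_split: the split with the
    largest remaining deficit wins; if every split met its target, take the
    one with the fewest samples so far."""
    remaining = [targets[i] - counts[i] for i in range(len(targets))]
    best = max(active, key=lambda i: (remaining[i], -counts[i], -targets[i]))
    if remaining[best] <= 0:
        best = min(active, key=lambda i: counts[i])
    return best


probs = [0.7, 0.1, 0.2]
targets = allocate_split_counts(12, probs)
print(targets)  # [7, 2, 3] -- exact integer counts summing to 12

# Assign whole groups (sizes 4, 3, 2, 2, 1) to splits without ever breaking one up.
counts = [0, 0, 0]
active = [i for i, p in enumerate(probs) if p > 0]
for size in [4, 3, 2, 2, 1]:
    split_idx = choose_split(counts, targets, active)
    counts[split_idx] += size
    print(f"group of {size} -> split {split_idx}")
print(counts)  # [7, 2, 3] -- happens to hit the targets exactly here
```

Because groups are committed whole, per-split totals can miss the targets by up to roughly one group's size; the deficit-first ordering in `choose_split` keeps that drift as small as the group sizes allow (in the demo above it lands on the targets exactly).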
