split_data.py @ 21:d5c582cf74bc (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit eed8c1e1d99a8a0c8f3a6bfdd8af48a5bfa19444
| author | goeckslab |
|---|---|
| date | Tue, 20 Jan 2026 01:25:35 +0000 |
| parents | 64872c48a21f |
| children | |
```python
import argparse
import logging
from typing import Optional

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

logger = logging.getLogger("ImageLearner")


def split_data_0_2(
    df: pd.DataFrame,
    split_column: str,
    validation_size: float = 0.1,
    random_state: int = 42,
    label_column: Optional[str] = None,
) -> pd.DataFrame:
    """Given a DataFrame whose split_column only contains {0,2},
    re-assign a portion of the 0s to become 1s (validation)."""
    out = df.copy()
    out[split_column] = pd.to_numeric(out[split_column], errors="coerce").astype(int)

    idx_train = out.index[out[split_column] == 0].tolist()
    if not idx_train:
        logger.info("No rows with split=0; nothing to do.")
        return out

    stratify_arr = None
    if label_column and label_column in out.columns:
        label_counts = out.loc[idx_train, label_column].value_counts()
        if label_counts.size > 1:
            # Force stratify even with fewer samples - adjust validation_size if needed
            min_samples_per_class = label_counts.min()
            if min_samples_per_class * validation_size < 1:
                # Adjust validation_size to ensure at least 1 sample per class, but do not exceed original validation_size
                adjusted_validation_size = min(
                    validation_size, 1.0 / min_samples_per_class
                )
                if adjusted_validation_size != validation_size:
                    validation_size = adjusted_validation_size
                    logger.info(
                        f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation"
                    )
            stratify_arr = out.loc[idx_train, label_column]
            logger.info("Using stratified split for validation set")
        else:
            logger.warning("Only one label class found; cannot stratify")

    if validation_size <= 0:
        logger.info("validation_size <= 0; keeping all as train.")
        return out
    if validation_size >= 1:
        logger.info("validation_size >= 1; moving all train → validation.")
        out.loc[idx_train, split_column] = 1
        return out

    # Always try stratified split first
    try:
        train_idx, val_idx = train_test_split(
            idx_train,
            test_size=validation_size,
            random_state=random_state,
            stratify=stratify_arr,
        )
        logger.info("Successfully applied stratified split")
    except ValueError as e:
        logger.warning(f"Stratified split failed ({e}); falling back to random split.")
        train_idx, val_idx = train_test_split(
            idx_train,
            test_size=validation_size,
            random_state=random_state,
            stratify=None,
        )

    out.loc[train_idx, split_column] = 0
    out.loc[val_idx, split_column] = 1
    out[split_column] = out[split_column].astype(int)
    return out


def create_stratified_random_split(
    df: pd.DataFrame,
    split_column: str,
    split_probabilities: list = [0.7, 0.1, 0.2],
    random_state: int = 42,
    label_column: Optional[str] = None,
    group_column: Optional[str] = None,
) -> pd.DataFrame:
    """Create a stratified random split when no split column exists."""
    out = df.copy()

    # initialize split column
    out[split_column] = 0

    if group_column and group_column not in out.columns:
        logger.warning(
            "Group column '%s' not found in data; proceeding without group-aware split.",
            group_column,
        )
        group_column = None

    def _allocate_split_counts(n_total: int, probs: list) -> list:
        """Allocate exact split counts using largest remainder rounding."""
        if n_total <= 0:
            return [0 for _ in probs]
        counts = [0 for _ in probs]
        active = [i for i, p in enumerate(probs) if p > 0]
        remainder = n_total
        if active and n_total >= len(active):
            for i in active:
                counts[i] = 1
            remainder -= len(active)
        if remainder > 0:
            probs_arr = np.array(probs, dtype=float)
            probs_arr = probs_arr / probs_arr.sum()
            raw = remainder * probs_arr
            floors = np.floor(raw).astype(int)
            for i, value in enumerate(floors.tolist()):
                counts[i] += value
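            # Any rows left over after the floor allocation are handed out one at a
            # time to the active splits with the largest fractional remainders
            # (the "largest remainder" step named in the docstring).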
            leftover = remainder - int(floors.sum())
            if leftover > 0 and active:
                frac = raw - floors
                order = sorted(active, key=lambda i: (-frac[i], i))
                for i in range(leftover):
                    counts[order[i % len(order)]] += 1
        return counts

    def _choose_split(counts: list, targets: list, active: list) -> int:
        remaining = [targets[i] - counts[i] for i in range(len(targets))]
        best = max(active, key=lambda i: (remaining[i], -counts[i], -targets[i]))
        if remaining[best] <= 0:
            best = min(active, key=lambda i: counts[i])
        return best

    if not label_column or label_column not in out.columns:
        logger.warning(
            "No label column found; using random split without stratification"
        )
        # fall back to random assignment (group-aware if requested)
        indices = out.index.tolist()
        rng = np.random.RandomState(random_state)
        if group_column:
            group_series = out[group_column].astype(object)
            missing_mask = group_series.isna()
            if missing_mask.any():
                group_series = group_series.copy()
                group_series.loc[missing_mask] = [
                    f"__missing__{idx}" for idx in group_series.index[missing_mask]
                ]
            group_to_indices = {}
            for idx, group_id in group_series.items():
                group_to_indices.setdefault(group_id, []).append(idx)
            group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
            rng.shuffle(group_ids)
            targets = _allocate_split_counts(len(indices), split_probabilities)
            counts = [0 for _ in split_probabilities]
            active = [i for i, p in enumerate(split_probabilities) if p > 0]
            train_idx = []
            val_idx = []
            test_idx = []
            for group_id in group_ids:
                size = len(group_to_indices[group_id])
                split_idx = _choose_split(counts, targets, active)
                counts[split_idx] += size
                if split_idx == 0:
                    train_idx.extend(group_to_indices[group_id])
                elif split_idx == 1:
                    val_idx.extend(group_to_indices[group_id])
                else:
                    test_idx.extend(group_to_indices[group_id])
            out.loc[train_idx, split_column] = 0
            out.loc[val_idx, split_column] = 1
            out.loc[test_idx, split_column] = 2
            return out.astype({split_column: int})
        rng.shuffle(indices)
        targets = _allocate_split_counts(len(indices), split_probabilities)
        n_train, n_val, n_test = targets
        out.loc[indices[:n_train], split_column] = 0
        out.loc[indices[n_train:n_train + n_val], split_column] = 1
        out.loc[indices[n_train + n_val:n_train + n_val + n_test], split_column] = 2
        return out.astype({split_column: int})

    # check if stratification is possible
    label_counts = out[label_column].value_counts()
    min_samples_per_class = label_counts.min()

    # ensure we have enough samples for stratification:
    # Each class must have at least as many samples as the number of nonzero splits,
    # so that each split can receive at least one sample per class.
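    # Example: with the default probabilities [0.7, 0.1, 0.2] all three splits are
    # active, so stratification requires at least 3 samples in every class.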
    active_splits = [p for p in split_probabilities if p > 0]
    min_samples_required = len(active_splits)

    if min_samples_per_class < min_samples_required:
        logger.warning(
            f"Insufficient samples per class for stratification (min: {min_samples_per_class}, required: {min_samples_required}); using random split"
        )
        # fall back to simple random assignment (group-aware if requested)
        indices = out.index.tolist()
        rng = np.random.RandomState(random_state)
        if group_column:
            group_series = out[group_column].astype(object)
            missing_mask = group_series.isna()
            if missing_mask.any():
                group_series = group_series.copy()
                group_series.loc[missing_mask] = [
                    f"__missing__{idx}" for idx in group_series.index[missing_mask]
                ]
            group_to_indices = {}
            for idx, group_id in group_series.items():
                group_to_indices.setdefault(group_id, []).append(idx)
            group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
            rng.shuffle(group_ids)
            targets = _allocate_split_counts(len(indices), split_probabilities)
            counts = [0 for _ in split_probabilities]
            active = [i for i, p in enumerate(split_probabilities) if p > 0]
            train_idx = []
            val_idx = []
            test_idx = []
            for group_id in group_ids:
                size = len(group_to_indices[group_id])
                split_idx = _choose_split(counts, targets, active)
                counts[split_idx] += size
                if split_idx == 0:
                    train_idx.extend(group_to_indices[group_id])
                elif split_idx == 1:
                    val_idx.extend(group_to_indices[group_id])
                else:
                    test_idx.extend(group_to_indices[group_id])
            out.loc[train_idx, split_column] = 0
            out.loc[val_idx, split_column] = 1
            out.loc[test_idx, split_column] = 2
            return out.astype({split_column: int})
        rng.shuffle(indices)
        targets = _allocate_split_counts(len(indices), split_probabilities)
        n_train, n_val, n_test = targets
        out.loc[indices[:n_train], split_column] = 0
        out.loc[indices[n_train:n_train + n_val], split_column] = 1
        out.loc[indices[n_train + n_val:n_train + n_val + n_test], split_column] = 2
        return out.astype({split_column: int})

    if group_column:
        logger.info(
            "Using stratified random split for train/validation/test sets (grouped by '%s')",
            group_column,
        )
        rng = np.random.RandomState(random_state)
        group_series = out[group_column].astype(object)
        missing_mask = group_series.isna()
        if missing_mask.any():
            group_series = group_series.copy()
            group_series.loc[missing_mask] = [
                f"__missing__{idx}" for idx in group_series.index[missing_mask]
            ]
        group_to_indices = {}
        for idx, group_id in group_series.items():
            group_to_indices.setdefault(group_id, []).append(idx)
        group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
        group_labels = {}
        mixed_groups = []
        label_series = out[label_column]
        for group_id in group_ids:
            labels = label_series.loc[group_to_indices[group_id]].dropna().unique()
            if len(labels) == 1:
                group_labels[group_id] = labels[0]
            elif len(labels) == 0:
                group_labels[group_id] = None
            else:
                mode_vals = label_series.loc[group_to_indices[group_id]].mode(dropna=True)
                group_labels[group_id] = mode_vals.iloc[0] if not mode_vals.empty else labels[0]
                mixed_groups.append(group_id)
        if mixed_groups:
            logger.warning(
                "Detected %d groups with multiple labels; using the most common label per group for stratification.",
                len(mixed_groups),
            )
        train_idx = []
        val_idx = []
        test_idx = []
        active = [i for i, p in enumerate(split_probabilities) if p > 0]
        for label_value in sorted(label_counts.index.tolist(), key=lambda x: str(x)):
            label_groups = [g for g in group_ids if group_labels.get(g) == label_value]
            if not label_groups:
                continue
            rng.shuffle(label_groups)
            label_total = sum(len(group_to_indices[g]) for g in label_groups)
            targets = _allocate_split_counts(label_total, split_probabilities)
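            # Groups are assigned whole: each goes to the active split currently
            # furthest below its per-label target (or, once every target is met,
            # to the split with the fewest assigned rows), so rows that share a
            # group never land in different splits.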
            counts = [0 for _ in split_probabilities]
            for group_id in label_groups:
                size = len(group_to_indices[group_id])
                split_idx = _choose_split(counts, targets, active)
                counts[split_idx] += size
                if split_idx == 0:
                    train_idx.extend(group_to_indices[group_id])
                elif split_idx == 1:
                    val_idx.extend(group_to_indices[group_id])
                else:
                    test_idx.extend(group_to_indices[group_id])
        # Assign groups without a label (or missing labels) using overall targets.
        unlabeled_groups = [g for g in group_ids if group_labels.get(g) is None]
        if unlabeled_groups:
            rng.shuffle(unlabeled_groups)
            total_unlabeled = sum(len(group_to_indices[g]) for g in unlabeled_groups)
            targets = _allocate_split_counts(total_unlabeled, split_probabilities)
            counts = [0 for _ in split_probabilities]
            for group_id in unlabeled_groups:
                size = len(group_to_indices[group_id])
                split_idx = _choose_split(counts, targets, active)
                counts[split_idx] += size
                if split_idx == 0:
                    train_idx.extend(group_to_indices[group_id])
                elif split_idx == 1:
                    val_idx.extend(group_to_indices[group_id])
                else:
                    test_idx.extend(group_to_indices[group_id])
        out.loc[train_idx, split_column] = 0
        out.loc[val_idx, split_column] = 1
        out.loc[test_idx, split_column] = 2
        logger.info("Successfully applied stratified random split")
        logger.info(
            f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}"
        )
        return out.astype({split_column: int})

    logger.info("Using stratified random split for train/validation/test sets (per-class allocation)")
    rng = np.random.RandomState(random_state)
    label_values = sorted(label_counts.index.tolist(), key=lambda x: str(x))
    train_idx = []
    val_idx = []
    test_idx = []
    for label_value in label_values:
        label_indices = out.index[out[label_column] == label_value].tolist()
        rng.shuffle(label_indices)
        n_train, n_val, n_test = _allocate_split_counts(len(label_indices), split_probabilities)
        train_idx.extend(label_indices[:n_train])
        val_idx.extend(label_indices[n_train:n_train + n_val])
        test_idx.extend(label_indices[n_train + n_val:n_train + n_val + n_test])

    # assign split values
    out.loc[train_idx, split_column] = 0
    out.loc[val_idx, split_column] = 1
    out.loc[test_idx, split_column] = 2
    logger.info("Successfully applied stratified random split")
    logger.info(
        f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}"
    )
    return out.astype({split_column: int})


class SplitProbAction(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        train, val, test = values
        total = train + val + test
        if abs(total - 1.0) > 1e-6:
            parser.error(
                f"--split-probabilities must sum to 1.0; "
                f"got {train:.3f} + {val:.3f} + {test:.3f} = {total:.3f}"
            )
        setattr(namespace, self.dest, values)
```
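For orientation, here is a minimal sketch of how these helpers could be driven from a script. It is not taken from the repository: the DataFrame contents, the column names (`image_path`, `label`, `split`), and the assumption that the module is importable as `split_data` are illustrative only.

```python
# Hypothetical usage sketch; data and column names are made up for illustration.
import argparse

import pandas as pd

from split_data import SplitProbAction, create_stratified_random_split, split_data_0_2

df = pd.DataFrame(
    {
        "image_path": [f"img_{i}.png" for i in range(10)],
        "label": ["cat", "dog"] * 5,
    }
)

# No split column yet: create a stratified 70/10/20 train/val/test assignment.
df_split = create_stratified_random_split(
    df,
    split_column="split",
    split_probabilities=[0.7, 0.1, 0.2],
    random_state=42,
    label_column="label",
)

# A split column holding only {0, 2}: carve validation rows (1) out of the training rows (0).
df_02 = df.assign(split=[0, 0, 0, 0, 0, 0, 0, 2, 2, 2])
df_012 = split_data_0_2(
    df_02,
    split_column="split",
    validation_size=0.2,
    label_column="label",
)

# SplitProbAction validates that --split-probabilities sums to 1.0 at parse time.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--split-probabilities",
    type=float,
    nargs=3,
    action=SplitProbAction,
    default=[0.7, 0.1, 0.2],
)
args = parser.parse_args(["--split-probabilities", "0.7", "0.1", "0.2"])

print(df_split["split"].value_counts())
print(df_012["split"].value_counts())
print(args.split_probabilities)
```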
