diff split_data.py @ 12:bcfa2e234a80 draft
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
| author | goeckslab |
|---|---|
| date | Fri, 21 Nov 2025 15:58:13 +0000 |
| parents | |
| children | |
```diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/split_data.py	Fri Nov 21 15:58:13 2025 +0000
@@ -0,0 +1,179 @@
+import argparse
+import logging
+from typing import Optional
+
+import numpy as np
+import pandas as pd
+from sklearn.model_selection import train_test_split
+
+logger = logging.getLogger("ImageLearner")
+
+
+def split_data_0_2(
+    df: pd.DataFrame,
+    split_column: str,
+    validation_size: float = 0.1,
+    random_state: int = 42,
+    label_column: Optional[str] = None,
+) -> pd.DataFrame:
+    """Given a DataFrame whose split_column only contains {0,2}, re-assign a portion of the 0s to become 1s (validation)."""
+    out = df.copy()
+    out[split_column] = pd.to_numeric(out[split_column], errors="coerce").astype(int)
+
+    idx_train = out.index[out[split_column] == 0].tolist()
+
+    if not idx_train:
+        logger.info("No rows with split=0; nothing to do.")
+        return out
+    stratify_arr = None
+    if label_column and label_column in out.columns:
+        label_counts = out.loc[idx_train, label_column].value_counts()
+        if label_counts.size > 1:
+            # Force stratify even with fewer samples - adjust validation_size if needed
+            min_samples_per_class = label_counts.min()
+            if min_samples_per_class * validation_size < 1:
+                # Adjust validation_size to ensure at least 1 sample per class,
+                # but do not exceed the original validation_size
+                adjusted_validation_size = min(
+                    validation_size, 1.0 / min_samples_per_class
+                )
+                if adjusted_validation_size != validation_size:
+                    validation_size = adjusted_validation_size
+                    logger.info(
+                        f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation"
+                    )
+            stratify_arr = out.loc[idx_train, label_column]
+            logger.info("Using stratified split for validation set")
+        else:
+            logger.warning("Only one label class found; cannot stratify")
+    if validation_size <= 0:
+        logger.info("validation_size <= 0; keeping all as train.")
+        return out
+    if validation_size >= 1:
+        logger.info("validation_size >= 1; moving all train → validation.")
+        out.loc[idx_train, split_column] = 1
+        return out
+    # Always try stratified split first
+    try:
+        train_idx, val_idx = train_test_split(
+            idx_train,
+            test_size=validation_size,
+            random_state=random_state,
+            stratify=stratify_arr,
+        )
+        logger.info("Successfully applied stratified split")
+    except ValueError as e:
+        logger.warning(f"Stratified split failed ({e}); falling back to random split.")
+        train_idx, val_idx = train_test_split(
+            idx_train,
+            test_size=validation_size,
+            random_state=random_state,
+            stratify=None,
+        )
+    out.loc[train_idx, split_column] = 0
+    out.loc[val_idx, split_column] = 1
+    out[split_column] = out[split_column].astype(int)
+    return out
+
+
+def create_stratified_random_split(
+    df: pd.DataFrame,
+    split_column: str,
+    split_probabilities: list = [0.7, 0.1, 0.2],
+    random_state: int = 42,
+    label_column: Optional[str] = None,
+) -> pd.DataFrame:
+    """Create a stratified random split when no split column exists."""
+    out = df.copy()
+
+    # initialize split column
+    out[split_column] = 0
+
+    if not label_column or label_column not in out.columns:
+        logger.warning(
+            "No label column found; using random split without stratification"
+        )
+        # fall back to simple random assignment
+        indices = out.index.tolist()
+        np.random.seed(random_state)
+        np.random.shuffle(indices)
+
+        n_total = len(indices)
+        n_train = int(n_total * split_probabilities[0])
+        n_val = int(n_total * split_probabilities[1])
+
+        out.loc[indices[:n_train], split_column] = 0
+        out.loc[indices[n_train:n_train + n_val], split_column] = 1
+        out.loc[indices[n_train + n_val:], split_column] = 2
+
+        return out.astype({split_column: int})
+
+    # check if stratification is possible
+    label_counts = out[label_column].value_counts()
+    min_samples_per_class = label_counts.min()
+
+    # ensure we have enough samples for stratification:
+    # Each class must have at least as many samples as the number of splits,
+    # so that each split can receive at least one sample per class.
+    min_samples_required = len(split_probabilities)
+    if min_samples_per_class < min_samples_required:
+        logger.warning(
+            f"Insufficient samples per class for stratification (min: {min_samples_per_class}, required: {min_samples_required}); using random split"
+        )
+        # fall back to simple random assignment
+        indices = out.index.tolist()
+        np.random.seed(random_state)
+        np.random.shuffle(indices)
+
+        n_total = len(indices)
+        n_train = int(n_total * split_probabilities[0])
+        n_val = int(n_total * split_probabilities[1])
+
+        out.loc[indices[:n_train], split_column] = 0
+        out.loc[indices[n_train:n_train + n_val], split_column] = 1
+        out.loc[indices[n_train + n_val:], split_column] = 2
+
+        return out.astype({split_column: int})
+
+    logger.info("Using stratified random split for train/validation/test sets")
+
+    # first split: separate test set
+    train_val_idx, test_idx = train_test_split(
+        out.index.tolist(),
+        test_size=split_probabilities[2],
+        random_state=random_state,
+        stratify=out[label_column],
+    )
+
+    # second split: separate training and validation from remaining data
+    val_size_adjusted = split_probabilities[1] / (
+        split_probabilities[0] + split_probabilities[1]
+    )
+    train_idx, val_idx = train_test_split(
+        train_val_idx,
+        test_size=val_size_adjusted,
+        random_state=random_state,
+        stratify=out.loc[train_val_idx, label_column]
+        if label_column and label_column in out.columns
+        else None,
+    )
+
+    # assign split values
+    out.loc[train_idx, split_column] = 0
+    out.loc[val_idx, split_column] = 1
+    out.loc[test_idx, split_column] = 2
+
+    logger.info("Successfully applied stratified random split")
+    logger.info(
+        f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}"
+    )
+    return out.astype({split_column: int})
+
+
+class SplitProbAction(argparse.Action):
+    def __call__(self, parser, namespace, values, option_string=None):
+        train, val, test = values
+        total = train + val + test
+        if abs(total - 1.0) > 1e-6:
+            parser.error(
+                f"--split-probabilities must sum to 1.0; "
+                f"got {train:.3f} + {val:.3f} + {test:.3f} = {total:.3f}"
+            )
+        setattr(namespace, self.dest, values)
```
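For context, a minimal usage sketch of `split_data_0_2` follows. It assumes `split_data.py` is importable from the working directory; the toy column names (`image_path`, `label`, `split`) are illustrative and not taken from the tool's actual metadata format.

```python
import pandas as pd

from split_data import split_data_0_2  # assumes split_data.py is on the path

# Toy metadata: 8 rows pre-assigned to train (0) and 2 to test (2); no validation yet.
df = pd.DataFrame(
    {
        "image_path": [f"img_{i}.png" for i in range(10)],
        "label": ["cat", "dog"] * 5,
        "split": [0, 0, 0, 0, 0, 0, 0, 0, 2, 2],
    }
)

# Move ~20% of the split=0 rows into split=1 (validation), stratified on "label".
result = split_data_0_2(
    df,
    split_column="split",
    validation_size=0.2,
    random_state=42,
    label_column="label",
)

print(result["split"].value_counts().sort_index())
# Expected: 6 rows with split=0, 2 with split=1, 2 with split=2.
```

The 0/1/2 values encode train, validation, and test respectively, as implied by the function names and log messages above.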

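Similarly, here is a hedged sketch of how `create_stratified_random_split` and `SplitProbAction` might be wired together in a CLI. Only the `--split-probabilities` flag name is taken from the action's error message; the `nargs`, `type`, `default`, and `metavar` settings, and the toy DataFrame, are assumptions for illustration.

```python
import argparse

import pandas as pd

from split_data import SplitProbAction, create_stratified_random_split

# Hypothetical CLI wiring; SplitProbAction validates that the three
# probabilities sum to 1.0 and otherwise calls parser.error().
parser = argparse.ArgumentParser(description="Assign train/validation/test splits")
parser.add_argument(
    "--split-probabilities",
    type=float,
    nargs=3,
    default=[0.7, 0.1, 0.2],
    action=SplitProbAction,
    metavar=("TRAIN", "VAL", "TEST"),
)
args = parser.parse_args(["--split-probabilities", "0.7", "0.1", "0.2"])

# Toy DataFrame with no split column: two balanced classes, 20 rows.
df = pd.DataFrame(
    {
        "image_path": [f"img_{i}.png" for i in range(20)],
        "label": ["cat", "dog"] * 10,
    }
)

# Create the split column from scratch, stratified on "label".
result = create_stratified_random_split(
    df,
    split_column="split",
    split_probabilities=args.split_probabilities,
    random_state=0,
    label_column="label",
)
print(result.groupby(["split", "label"]).size())
# A malformed triple such as 0.6/0.2/0.3 would make the parser exit with an error.
```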