diff image_learner_cli.py @ 7:801a8b6973fb draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit 67df782ea551181e1d240d463764016ba528eba9
| author | goeckslab |
|---|---|
| date | Fri, 08 Aug 2025 13:06:28 +0000 |
| parents | 09904b1f61f5 |
| children | |
```diff
--- a/image_learner_cli.py	Mon Jul 14 14:47:32 2025 +0000
+++ b/image_learner_cli.py	Fri Aug 08 13:06:28 2025 +0000
@@ -9,6 +9,7 @@
 from pathlib import Path
 from typing import Any, Dict, Optional, Protocol, Tuple
 
+import numpy as np
 import pandas as pd
 import pandas.api.types as ptypes
 import yaml
@@ -418,7 +419,7 @@
 def split_data_0_2(
     df: pd.DataFrame,
     split_column: str,
-    validation_size: float = 0.15,
+    validation_size: float = 0.1,
     random_state: int = 42,
     label_column: Optional[str] = None,
 ) -> pd.DataFrame:
@@ -431,15 +432,25 @@
     if not idx_train:
         logger.info("No rows with split=0; nothing to do.")
         return out
+
+    # Always use stratify if possible
     stratify_arr = None
     if label_column and label_column in out.columns:
         label_counts = out.loc[idx_train, label_column].value_counts()
-        if label_counts.size > 1 and (label_counts.min() * validation_size) >= 1:
+        if label_counts.size > 1:
+            # Force stratify even with fewer samples - adjust validation_size if needed
+            min_samples_per_class = label_counts.min()
+            if min_samples_per_class * validation_size < 1:
+                # Adjust validation_size to ensure at least 1 sample per class, but do not exceed original validation_size
+                adjusted_validation_size = min(validation_size, 1.0 / min_samples_per_class)
+                if adjusted_validation_size != validation_size:
+                    validation_size = adjusted_validation_size
+                    logger.info(f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation")
             stratify_arr = out.loc[idx_train, label_column]
+            logger.info("Using stratified split for validation set")
         else:
-            logger.warning(
-                "Cannot stratify (too few labels); splitting without stratify."
-            )
+            logger.warning("Only one label class found; cannot stratify")
+
     if validation_size <= 0:
         logger.info("validation_size <= 0; keeping all as train.")
         return out
@@ -447,6 +458,8 @@
         logger.info("validation_size >= 1; moving all train → validation.")
         out.loc[idx_train, split_column] = 1
         return out
+
+    # Always try stratified split first
     try:
         train_idx, val_idx = train_test_split(
             idx_train,
@@ -454,20 +467,109 @@
             random_state=random_state,
             stratify=stratify_arr,
         )
+        logger.info("Successfully applied stratified split")
     except ValueError as e:
-        logger.warning(f"Stratified split failed ({e}); retrying without stratify.")
+        logger.warning(f"Stratified split failed ({e}); falling back to random split.")
         train_idx, val_idx = train_test_split(
             idx_train,
             test_size=validation_size,
             random_state=random_state,
             stratify=None,
         )
+
     out.loc[train_idx, split_column] = 0
     out.loc[val_idx, split_column] = 1
     out[split_column] = out[split_column].astype(int)
     return out
 
 
+def create_stratified_random_split(
+    df: pd.DataFrame,
+    split_column: str,
+    split_probabilities: list = [0.7, 0.1, 0.2],
+    random_state: int = 42,
+    label_column: Optional[str] = None,
+) -> pd.DataFrame:
+    """Create a stratified random split when no split column exists."""
+    out = df.copy()
+
+    # initialize split column
+    out[split_column] = 0
+
+    if not label_column or label_column not in out.columns:
+        logger.warning("No label column found; using random split without stratification")
+        # fall back to simple random assignment
+        indices = out.index.tolist()
+        np.random.seed(random_state)
+        np.random.shuffle(indices)
+
+        n_total = len(indices)
+        n_train = int(n_total * split_probabilities[0])
+        n_val = int(n_total * split_probabilities[1])
+
+        out.loc[indices[:n_train], split_column] = 0
+        out.loc[indices[n_train:n_train + n_val], split_column] = 1
+        out.loc[indices[n_train + n_val:], split_column] = 2
+
+        return out.astype({split_column: int})
+
+    # check if stratification is possible
+    label_counts = out[label_column].value_counts()
+    min_samples_per_class = label_counts.min()
+
+    # ensure we have enough samples for stratification:
+    # Each class must have at least as many samples as the number of splits,
+    # so that each split can receive at least one sample per class.
+    min_samples_required = len(split_probabilities)
+    if min_samples_per_class < min_samples_required:
+        logger.warning(
+            f"Insufficient samples per class for stratification (min: {min_samples_per_class}, required: {min_samples_required}); using random split"
+        )
+        # fall back to simple random assignment
+        indices = out.index.tolist()
+        np.random.seed(random_state)
+        np.random.shuffle(indices)
+
+        n_total = len(indices)
+        n_train = int(n_total * split_probabilities[0])
+        n_val = int(n_total * split_probabilities[1])
+
+        out.loc[indices[:n_train], split_column] = 0
+        out.loc[indices[n_train:n_train + n_val], split_column] = 1
+        out.loc[indices[n_train + n_val:], split_column] = 2
+
+        return out.astype({split_column: int})
+
+    logger.info("Using stratified random split for train/validation/test sets")
+
+    # first split: separate test set
+    train_val_idx, test_idx = train_test_split(
+        out.index.tolist(),
+        test_size=split_probabilities[2],
+        random_state=random_state,
+        stratify=out[label_column],
+    )
+
+    # second split: separate training and validation from remaining data
+    val_size_adjusted = split_probabilities[1] / (split_probabilities[0] + split_probabilities[1])
+    train_idx, val_idx = train_test_split(
+        train_val_idx,
+        test_size=val_size_adjusted,
+        random_state=random_state,
+        stratify=out.loc[train_val_idx, label_column],
+    )
+
+    # assign split values
+    out.loc[train_idx, split_column] = 0
+    out.loc[val_idx, split_column] = 1
+    out.loc[test_idx, split_column] = 2
+
+    logger.info("Successfully applied stratified random split")
+    logger.info(f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}")
+
+    return out.astype({split_column: int})
+
+
 class Backend(Protocol):
     """Interface for a machine learning backend."""
 
@@ -1089,15 +1191,22 @@
         if SPLIT_COLUMN_NAME in df.columns:
             df, split_config, split_info = self._process_fixed_split(df)
         else:
-            logger.info("No split column; using random split")
+            logger.info("No split column; creating stratified random split")
+            df = create_stratified_random_split(
+                df=df,
+                split_column=SPLIT_COLUMN_NAME,
+                split_probabilities=self.args.split_probabilities,
+                random_state=self.args.random_seed,
+                label_column=LABEL_COLUMN_NAME,
+            )
             split_config = {
-                "type": "random",
-                "probabilities": self.args.split_probabilities,
+                "type": "fixed",
+                "column": SPLIT_COLUMN_NAME,
             }
             split_info = (
-                f"No split column in CSV. Used random split: "
+                f"No split column in CSV. Created stratified random split: "
                 f"{[int(p * 100) for p in self.args.split_probabilities]}% "
-                f"for train/val/test."
+                f"for train/val/test with balanced label distribution."
             )
 
         final_csv = self.temp_dir / TEMP_CSV_FILENAME
@@ -1139,7 +1248,7 @@
                 "Detected a split column (with values 0 and 2) in the input CSV. "
                 f"Used this column as a base and reassigned "
                 f"{self.args.validation_size * 100:.1f}% "
-                "of the training set (originally labeled 0) to validation (labeled 1)."
+                "of the training set (originally labeled 0) to validation (labeled 1) using stratified sampling."
             )
             logger.info("Applied custom 0/2 split.")
         elif unique.issubset({0, 1, 2}):
@@ -1319,7 +1428,7 @@
     parser.add_argument(
         "--validation-size",
        type=float,
-        default=0.15,
+        default=0.1,
         help="Fraction for validation (0.0–1.0)",
     )
    parser.add_argument(
```
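For readers skimming the diff, the core of the new `create_stratified_random_split` helper is a two-stage scikit-learn split: the test rows are carved out first, then the remainder is re-split into train/validation with the validation fraction rescaled. Below is a minimal, self-contained sketch of that logic on a toy DataFrame; the 70/10/20 fractions mirror the defaults in the diff, while the `image_path`/`label`/`split` column names are illustrative assumptions, not part of the tool's interface.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Toy frame with an imbalanced label column (column names are illustrative only).
df = pd.DataFrame({
    "image_path": [f"img_{i}.png" for i in range(20)],
    "label": ["cat"] * 14 + ["dog"] * 6,
})

# Stage 1: carve out the 20% test set, stratified on the label.
train_val_idx, test_idx = train_test_split(
    df.index.tolist(), test_size=0.2, random_state=42, stratify=df["label"]
)

# Stage 2: split the remaining rows into train/validation, re-stratified,
# with the validation fraction rescaled to the remaining data.
val_size_adjusted = 0.1 / (0.7 + 0.1)
train_idx, val_idx = train_test_split(
    train_val_idx,
    test_size=val_size_adjusted,
    random_state=42,
    stratify=df.loc[train_val_idx, "label"],
)

# Encode the splits the same way the diff does: 0=train, 1=validation, 2=test.
df["split"] = 0
df.loc[val_idx, "split"] = 1
df.loc[test_idx, "split"] = 2

# Each split keeps roughly the same cat/dog ratio as the full frame.
print(df.groupby("split")["label"].value_counts())
```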