diff image_learner_cli.py @ 7:801a8b6973fb draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit 67df782ea551181e1d240d463764016ba528eba9
author goeckslab
date Fri, 08 Aug 2025 13:06:28 +0000
parents 09904b1f61f5
children
--- a/image_learner_cli.py	Mon Jul 14 14:47:32 2025 +0000
+++ b/image_learner_cli.py	Fri Aug 08 13:06:28 2025 +0000
@@ -9,6 +9,7 @@
 from pathlib import Path
 from typing import Any, Dict, Optional, Protocol, Tuple
 
+import numpy as np
 import pandas as pd
 import pandas.api.types as ptypes
 import yaml
@@ -418,7 +419,7 @@
 def split_data_0_2(
     df: pd.DataFrame,
     split_column: str,
-    validation_size: float = 0.15,
+    validation_size: float = 0.1,
     random_state: int = 42,
     label_column: Optional[str] = None,
 ) -> pd.DataFrame:
@@ -431,15 +432,25 @@
     if not idx_train:
         logger.info("No rows with split=0; nothing to do.")
         return out
+
+    # Always stratify when possible
     stratify_arr = None
     if label_column and label_column in out.columns:
         label_counts = out.loc[idx_train, label_column].value_counts()
-        if label_counts.size > 1 and (label_counts.min() * validation_size) >= 1:
+        if label_counts.size > 1:
+            # Force stratification even with few samples; adjust validation_size if needed
+            min_samples_per_class = label_counts.min()
+            if min_samples_per_class * validation_size < 1:
+                # Raise validation_size so the rarest class can place at least one sample in validation
+                adjusted_validation_size = max(validation_size, 1.0 / min_samples_per_class)
+                if adjusted_validation_size != validation_size:
+                    validation_size = adjusted_validation_size
+                    logger.info(f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation")
             stratify_arr = out.loc[idx_train, label_column]
+            logger.info("Using stratified split for validation set")
         else:
-            logger.warning(
-                "Cannot stratify (too few labels); splitting without stratify."
-            )
+            logger.warning("Only one label class found; cannot stratify")
+
     if validation_size <= 0:
         logger.info("validation_size <= 0; keeping all as train.")
         return out
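
The adjustment branch above only fires for rare classes: when the smallest class count times the validation fraction is below one, the fraction is raised to 1/min_count so a stratified split can still place one sample of that class in validation. A standalone sketch of that arithmetic (the counts below are assumed example values, not taken from the tool):

min_samples_per_class = 3   # size of the rarest class (assumed example value)
validation_size = 0.1
if min_samples_per_class * validation_size < 1:   # 3 * 0.1 = 0.3 < 1
    validation_size = max(validation_size, 1.0 / min_samples_per_class)
print(round(validation_size, 3))  # 0.333 -> one of the three rare samples lands in validation
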
@@ -447,6 +458,8 @@
         logger.info("validation_size >= 1; moving all train → validation.")
         out.loc[idx_train, split_column] = 1
         return out
+
+    # Always try stratified split first
     try:
         train_idx, val_idx = train_test_split(
             idx_train,
@@ -454,20 +467,109 @@
             random_state=random_state,
             stratify=stratify_arr,
         )
+        logger.info("Successfully applied stratified split")
     except ValueError as e:
-        logger.warning(f"Stratified split failed ({e}); retrying without stratify.")
+        logger.warning(f"Stratified split failed ({e}); falling back to random split.")
         train_idx, val_idx = train_test_split(
             idx_train,
             test_size=validation_size,
             random_state=random_state,
             stratify=None,
         )
+
     out.loc[train_idx, split_column] = 0
     out.loc[val_idx, split_column] = 1
     out[split_column] = out[split_column].astype(int)
     return out
 
 
+def create_stratified_random_split(
+    df: pd.DataFrame,
+    split_column: str,
+    split_probabilities: list = [0.7, 0.1, 0.2],
+    random_state: int = 42,
+    label_column: Optional[str] = None,
+) -> pd.DataFrame:
+    """Create a stratified random split when no split column exists."""
+    out = df.copy()
+
+    # initialize split column
+    out[split_column] = 0
+
+    if not label_column or label_column not in out.columns:
+        logger.warning("No label column found; using random split without stratification")
+        # fall back to simple random assignment
+        indices = out.index.tolist()
+        np.random.seed(random_state)
+        np.random.shuffle(indices)
+
+        n_total = len(indices)
+        n_train = int(n_total * split_probabilities[0])
+        n_val = int(n_total * split_probabilities[1])
+
+        out.loc[indices[:n_train], split_column] = 0
+        out.loc[indices[n_train:n_train + n_val], split_column] = 1
+        out.loc[indices[n_train + n_val:], split_column] = 2
+
+        return out.astype({split_column: int})
+
+    # check if stratification is possible
+    label_counts = out[label_column].value_counts()
+    min_samples_per_class = label_counts.min()
+
+    # ensure we have enough samples for stratification:
+    # Each class must have at least as many samples as the number of splits,
+    # so that each split can receive at least one sample per class.
+    min_samples_required = len(split_probabilities)
+    if min_samples_per_class < min_samples_required:
+        logger.warning(
+            f"Insufficient samples per class for stratification (min: {min_samples_per_class}, required: {min_samples_required}); using random split"
+        )
+        # fall back to simple random assignment
+        indices = out.index.tolist()
+        np.random.seed(random_state)
+        np.random.shuffle(indices)
+
+        n_total = len(indices)
+        n_train = int(n_total * split_probabilities[0])
+        n_val = int(n_total * split_probabilities[1])
+
+        out.loc[indices[:n_train], split_column] = 0
+        out.loc[indices[n_train:n_train + n_val], split_column] = 1
+        out.loc[indices[n_train + n_val:], split_column] = 2
+
+        return out.astype({split_column: int})
+
+    logger.info("Using stratified random split for train/validation/test sets")
+
+    # first split: separate test set
+    train_val_idx, test_idx = train_test_split(
+        out.index.tolist(),
+        test_size=split_probabilities[2],
+        random_state=random_state,
+        stratify=out[label_column],
+    )
+
+    # second split: separate training and validation from remaining data
+    val_size_adjusted = split_probabilities[1] / (split_probabilities[0] + split_probabilities[1])
+    train_idx, val_idx = train_test_split(
+        train_val_idx,
+        test_size=val_size_adjusted,
+        random_state=random_state,
+        stratify=out.loc[train_val_idx, label_column],
+    )
+
+    # assign split values
+    out.loc[train_idx, split_column] = 0
+    out.loc[val_idx, split_column] = 1
+    out.loc[test_idx, split_column] = 2
+
+    logger.info("Successfully applied stratified random split")
+    logger.info(f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}")
+
+    return out.astype({split_column: int})
+
+
 class Backend(Protocol):
     """Interface for a machine learning backend."""
 
@@ -1089,15 +1191,22 @@
         if SPLIT_COLUMN_NAME in df.columns:
             df, split_config, split_info = self._process_fixed_split(df)
         else:
-            logger.info("No split column; using random split")
+            logger.info("No split column; creating stratified random split")
+            df = create_stratified_random_split(
+                df=df,
+                split_column=SPLIT_COLUMN_NAME,
+                split_probabilities=self.args.split_probabilities,
+                random_state=self.args.random_seed,
+                label_column=LABEL_COLUMN_NAME,
+            )
             split_config = {
-                "type": "random",
-                "probabilities": self.args.split_probabilities,
+                "type": "fixed",
+                "column": SPLIT_COLUMN_NAME,
             }
             split_info = (
-                f"No split column in CSV. Used random split: "
+                f"No split column in CSV. Created stratified random split: "
                 f"{[int(p * 100) for p in self.args.split_probabilities]}% "
-                f"for train/val/test."
+                f"for train/val/test with balanced label distribution."
             )
 
         final_csv = self.temp_dir / TEMP_CSV_FILENAME
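
Because the split assignments are now materialized in the CSV before training, the configuration registers them as a fixed split rather than a random one, so the backend reuses the precomputed 0/1/2 column instead of drawing its own. A minimal sketch of the two shapes (keys as in the hunk above; the value of SPLIT_COLUMN_NAME is an assumption):

# Before: the backend re-splits the data itself at training time
split_config = {"type": "random", "probabilities": [0.7, 0.1, 0.2]}

# After: the backend honors the column written by create_stratified_random_split
split_config = {"type": "fixed", "column": "split"}  # "split" assumed for SPLIT_COLUMN_NAME
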
@@ -1139,7 +1248,7 @@
                     "Detected a split column (with values 0 and 2) in the input CSV. "
                     f"Used this column as a base and reassigned "
                     f"{self.args.validation_size * 100:.1f}% "
-                    "of the training set (originally labeled 0) to validation (labeled 1)."
+                    "of the training set (originally labeled 0) to validation (labeled 1) using stratified sampling."
                 )
                 logger.info("Applied custom 0/2 split.")
             elif unique.issubset({0, 1, 2}):
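
A hypothetical round trip through split_data_0_2 for the 0/2 case described above (data and column names invented for illustration):

import pandas as pd

df = pd.DataFrame({
    "split": [0] * 20 + [2] * 5,   # 20 train rows, 5 test rows
    "label": ["cat"] * 10 + ["dog"] * 10 + ["cat"] * 5,
})
df = split_data_0_2(df, "split", validation_size=0.1, label_column="label")
# 10% of the 20 train rows move to validation, stratified by label:
print(df["split"].value_counts().sort_index())  # 0: 18, 1: 2, 2: 5
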
@@ -1319,7 +1428,7 @@
     parser.add_argument(
         "--validation-size",
         type=float,
-        default=0.15,
+        default=0.1,
         help="Fraction for validation (0.0–1.0)",
     )
     parser.add_argument(
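
A quick standalone check of the new default (argparse only; the tool's other arguments are omitted):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--validation-size", type=float, default=0.1,
                    help="Fraction for validation (0.0–1.0)")
args = parser.parse_args([])   # flag not passed -> the new 0.1 default applies
assert args.validation_size == 0.1
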