comparison split_data.py @ 12:bcfa2e234a80 draft

planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
author goeckslab
date Fri, 21 Nov 2025 15:58:13 +0000
parents
children
comparison
equal deleted inserted replaced
11:c5150cceab47 12:bcfa2e234a80
1 import argparse
2 import logging
3 from typing import Optional
4
5 import numpy as np
6 import pandas as pd
7 from sklearn.model_selection import train_test_split
8
9 logger = logging.getLogger("ImageLearner")
10
11
def split_data_0_2(
    df: pd.DataFrame,
    split_column: str,
    validation_size: float = 0.1,
    random_state: int = 42,
    label_column: Optional[str] = None,
) -> pd.DataFrame:
    """Re-label a fraction of the training rows (split == 0) as validation (split == 1).

    Intended for a DataFrame whose ``split_column`` only contains {0, 2}.
    The input frame is not modified; a copy is returned.

    Args:
        df: Input data.
        split_column: Name of the column holding the split indicator.
        validation_size: Fraction of the split == 0 rows to move to
            validation. Values <= 0 leave the rows untouched; values >= 1
            move every split == 0 row to validation.
        random_state: Seed forwarded to ``train_test_split``.
        label_column: Optional column used to stratify the split. It is
            ignored when absent from ``df`` or when the training rows hold
            a single class.

    Returns:
        A copy of ``df`` whose ``split_column`` is integer-typed.

    Raises:
        ValueError: If ``split_column`` holds missing or non-numeric values.
            A ValueError raised by the final non-stratified
            ``train_test_split`` call is also propagated.
    """
    out = df.copy()

    # Fail with an explicit message instead of the opaque integer-cast error
    # that NaN (produced by errors="coerce") would otherwise trigger.
    numeric_split = pd.to_numeric(out[split_column], errors="coerce")
    if numeric_split.isna().any():
        raise ValueError(
            f"Column '{split_column}' contains missing or non-numeric values; "
            "cannot interpret it as a split indicator."
        )
    out[split_column] = numeric_split.astype(int)

    idx_train = out.index[out[split_column] == 0].tolist()

    # Guard clauses first, so nothing about stratification is logged when no
    # split is actually performed.
    if not idx_train:
        logger.info("No rows with split=0; nothing to do.")
        return out
    if validation_size <= 0:
        logger.info("validation_size <= 0; keeping all as train.")
        return out
    if validation_size >= 1:
        logger.info("validation_size >= 1; moving all train → validation.")
        out.loc[idx_train, split_column] = 1
        return out

    # Stratify only when a usable label column with more than one class exists.
    # validation_size is passed through unchanged: if the requested fraction is
    # too small for a stratified split, train_test_split raises ValueError and
    # the random fallback below takes over.
    stratify_arr = None
    if label_column and label_column in out.columns:
        train_labels = out.loc[idx_train, label_column]
        if train_labels.nunique() > 1:
            stratify_arr = train_labels
            logger.info("Using stratified split for validation set")
        else:
            logger.warning("Only one label class found; cannot stratify")

    train_idx = val_idx = None
    if stratify_arr is not None:
        try:
            train_idx, val_idx = train_test_split(
                idx_train,
                test_size=validation_size,
                random_state=random_state,
                stratify=stratify_arr,
            )
        except ValueError as e:
            logger.warning(
                f"Stratified split failed ({e}); falling back to random split."
            )
        else:
            logger.info("Successfully applied stratified split")

    if val_idx is None:
        # Either stratification was not possible or it failed above.
        train_idx, val_idx = train_test_split(
            idx_train,
            test_size=validation_size,
            random_state=random_state,
            stratify=None,
        )
        logger.info("Applied random (non-stratified) split")

    out.loc[train_idx, split_column] = 0
    out.loc[val_idx, split_column] = 1
    out[split_column] = out[split_column].astype(int)
    return out
76
77
def _assign_random_split(
    out: pd.DataFrame,
    split_column: str,
    split_probabilities: list,
    random_state: int,
) -> pd.DataFrame:
    """Shuffle all rows of ``out`` and label them 0/1/2 by the given proportions."""
    indices = out.index.tolist()
    # A private RandomState produces the same permutation as seeding the global
    # NumPy RNG with ``random_state``, without altering global RNG state.
    np.random.RandomState(random_state).shuffle(indices)

    n_total = len(indices)
    # Train and validation counts are truncated; the remainder goes to test.
    n_train = int(n_total * split_probabilities[0])
    n_val = int(n_total * split_probabilities[1])

    out.loc[indices[:n_train], split_column] = 0
    out.loc[indices[n_train:n_train + n_val], split_column] = 1
    out.loc[indices[n_train + n_val:], split_column] = 2

    return out.astype({split_column: int})


def create_stratified_random_split(
    df: pd.DataFrame,
    split_column: str,
    split_probabilities: list = [0.7, 0.1, 0.2],  # never mutated here
    random_state: int = 42,
    label_column: Optional[str] = None,
) -> pd.DataFrame:
    """Create a train/validation/test split column when none exists.

    Rows are labelled 0 (train), 1 (validation) or 2 (test). The split is
    stratified on ``label_column`` when possible; otherwise rows are assigned
    by a seeded random shuffle. The input frame is not modified.

    Args:
        df: Input data.
        split_column: Name of the column to create (overwritten if present).
        split_probabilities: ``[train, validation, test]`` fractions.
        random_state: Seed for both the stratified and the random split.
        label_column: Optional column to stratify on.

    Returns:
        A copy of ``df`` with an integer ``split_column``.
    """
    out = df.copy()

    # initialize split column
    out[split_column] = 0

    if not label_column or label_column not in out.columns:
        logger.warning(
            "No label column found; using random split without stratification"
        )
        return _assign_random_split(
            out, split_column, split_probabilities, random_state
        )

    # Each class must have at least as many samples as there are splits,
    # so that each split can receive at least one sample per class.
    min_samples_per_class = out[label_column].value_counts().min()
    min_samples_required = len(split_probabilities)
    if min_samples_per_class < min_samples_required:
        logger.warning(
            f"Insufficient samples per class for stratification (min: {min_samples_per_class}, required: {min_samples_required}); using random split"
        )
        return _assign_random_split(
            out, split_column, split_probabilities, random_state
        )

    logger.info("Using stratified random split for train/validation/test sets")

    try:
        # first split: separate test set
        train_val_idx, test_idx = train_test_split(
            out.index.tolist(),
            test_size=split_probabilities[2],
            random_state=random_state,
            stratify=out[label_column],
        )

        # second split: separate training and validation from remaining data;
        # the validation fraction is rescaled to the train+validation subset.
        val_size_adjusted = split_probabilities[1] / (
            split_probabilities[0] + split_probabilities[1]
        )
        train_idx, val_idx = train_test_split(
            train_val_idx,
            test_size=val_size_adjusted,
            random_state=random_state,
            stratify=out.loc[train_val_idx, label_column],
        )
    except ValueError as e:
        # The per-class count check above does not cover every case in which
        # train_test_split refuses to split (e.g. a zero-sized partition), so
        # fall back to the random assignment instead of crashing.
        logger.warning(
            f"Stratified split failed ({e}); falling back to random split."
        )
        return _assign_random_split(
            out, split_column, split_probabilities, random_state
        )

    # assign split values
    out.loc[train_idx, split_column] = 0
    out.loc[val_idx, split_column] = 1
    out.loc[test_idx, split_column] = 2

    logger.info("Successfully applied stratified random split")
    logger.info(
        f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}"
    )
    return out.astype({split_column: int})
168
169
class SplitProbAction(argparse.Action):
    """argparse action that validates split probabilities before storing them.

    Expects exactly three numeric values (train, validation, test), each in
    the range [0, 1], that sum to 1.0 within a tolerance of 1e-6. Any
    violation is reported through ``parser.error``.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # Report a wrong number of values as a usage error rather than
        # letting the tuple unpacking raise a traceback.
        try:
            train, val, test = values
        except (TypeError, ValueError):
            parser.error(
                "--split-probabilities expects exactly 3 values "
                "(train, validation, test)"
            )

        # Written as "not (0 <= p <= 1)" so that NaN is rejected as well.
        if any(not (0.0 <= p <= 1.0) for p in (train, val, test)):
            parser.error(
                f"--split-probabilities values must each be between 0 and 1; "
                f"got {train:.3f}, {val:.3f}, {test:.3f}"
            )

        total = train + val + test
        if abs(total - 1.0) > 1e-6:
            parser.error(
                f"--split-probabilities must sum to 1.0; "
                f"got {train:.3f} + {val:.3f} + {test:.3f} = {total:.3f}"
            )
        setattr(namespace, self.dest, values)