Mercurial > repos > goeckslab > image_learner
annotate split_data.py @ 21:d5c582cf74bc draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit eed8c1e1d99a8a0c8f3a6bfdd8af48a5bfa19444
| author | goeckslab |
|---|---|
| date | Tue, 20 Jan 2026 01:25:35 +0000 |
| parents | 64872c48a21f |
| children |
| rev | line source |
|---|---|
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1 import argparse |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
2 import logging |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
3 from typing import Optional |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
4 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
5 import numpy as np |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
6 import pandas as pd |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
7 from sklearn.model_selection import train_test_split |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
8 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
# Module logger. It is named explicitly ("ImageLearner") rather than via
# __name__, presumably so this module's records go to the tool-wide logger
# -- confirm against the tool's logging setup.
logger: logging.Logger = logging.getLogger("ImageLearner")
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
10 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
11 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
def split_data_0_2(
    df: pd.DataFrame,
    split_column: str,
    validation_size: float = 0.1,
    random_state: int = 42,
    label_column: Optional[str] = None,
) -> pd.DataFrame:
    """Carve a validation set (split=1) out of the training rows (split=0).

    Intended for frames whose split column only contains {0, 2}
    (0 = train, 2 = test). A ``validation_size`` fraction of the split=0
    rows is re-labelled 1; rows with any other split value are left as-is.
    When ``label_column`` exists and the training rows contain more than one
    label class, the selection is stratified by label. If scikit-learn
    rejects the stratified split (for example, a class with a single
    member), a plain random split is used instead.

    Args:
        df: Input frame. It is not modified; a copy is returned.
        split_column: Name of the split column. Every value must be numeric.
        validation_size: Fraction of split=0 rows to move to validation.
            ``<= 0`` keeps every row as train; ``>= 1`` moves every split=0
            row to validation.
        random_state: Seed forwarded to ``train_test_split``.
        label_column: Optional label column used for stratification.

    Returns:
        A copy of ``df`` whose split column has been cast to ``int``.

    Raises:
        ValueError: If the split column contains missing, non-numeric or
            infinite values, which cannot be cast to ``int``.
    """
    out = df.copy()

    # Coerce first so bad cells can be reported explicitly. Casting NaN/inf
    # straight to int would fail with an opaque pandas error instead.
    numeric = pd.to_numeric(out[split_column], errors="coerce")
    bad_mask = ~np.isfinite(numeric.astype(float))
    if bad_mask.any():
        examples = out.loc[bad_mask, split_column].unique().tolist()[:5]
        raise ValueError(
            f"Split column '{split_column}' has {int(bad_mask.sum())} missing, "
            f"non-numeric or infinite value(s) (e.g. {examples}); "
            "expected integer split codes."
        )
    out[split_column] = numeric.astype(int)

    idx_train = out.index[out[split_column] == 0].tolist()
    if not idx_train:
        logger.info("No rows with split=0; nothing to do.")
        return out

    # Handle the degenerate sizes before any stratification work, so nothing
    # about stratification is logged when no split happens.
    if validation_size <= 0:
        logger.info("validation_size <= 0; keeping all as train.")
        return out
    if validation_size >= 1:
        logger.info("validation_size >= 1; moving all train → validation.")
        out.loc[idx_train, split_column] = 1
        return out
    if len(idx_train) < 2:
        # train_test_split cannot divide a single row: it raises because the
        # train side would be empty. Keep the row as training data.
        logger.warning(
            "Only one row with split=0; cannot carve out a validation set, "
            "keeping it as train."
        )
        return out

    stratify_arr = None
    if label_column and label_column in out.columns:
        train_labels = out.loc[idx_train, label_column]
        label_counts = train_labels.value_counts()
        if label_counts.size > 1:
            min_samples_per_class = int(label_counts.min())
            if min_samples_per_class * validation_size < 1:
                # The rarest class expects fewer than one validation row, so
                # stratification cannot guarantee it appears in validation.
                logger.warning(
                    "Rarest label class has only %d training rows; with "
                    "validation_size=%.3f it may be absent from validation.",
                    min_samples_per_class,
                    validation_size,
                )
            stratify_arr = train_labels
            logger.info("Using stratified split for validation set")
        else:
            logger.warning("Only one label class found; cannot stratify")

    val_idx = None
    if stratify_arr is not None:
        try:
            _, val_idx = train_test_split(
                idx_train,
                test_size=validation_size,
                random_state=random_state,
                stratify=stratify_arr,
            )
            logger.info("Successfully applied stratified split")
        except ValueError as e:
            logger.warning(f"Stratified split failed ({e}); falling back to random split.")

    if val_idx is None:
        # Either stratification was not possible or scikit-learn rejected it.
        _, val_idx = train_test_split(
            idx_train,
            test_size=validation_size,
            random_state=random_state,
            stratify=None,
        )

    # Rows not chosen for validation keep their existing split=0.
    out.loc[val_idx, split_column] = 1
    out[split_column] = out[split_column].astype(int)
    return out
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
76 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
77 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
78 def create_stratified_random_split( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
79 df: pd.DataFrame, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
80 split_column: str, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
81 split_probabilities: list = [0.7, 0.1, 0.2], |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
82 random_state: int = 42, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
83 label_column: Optional[str] = None, |
|
20
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
84 group_column: Optional[str] = None, |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
85 ) -> pd.DataFrame: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
86 """Create a stratified random split when no split column exists.""" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
87 out = df.copy() |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
88 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
89 # initialize split column |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
90 out[split_column] = 0 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
91 |
|
20
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
92 if group_column and group_column not in out.columns: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
93 logger.warning( |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
94 "Group column '%s' not found in data; proceeding without group-aware split.", |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
95 group_column, |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
96 ) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
97 group_column = None |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
98 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
99 def _allocate_split_counts(n_total: int, probs: list) -> list: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
100 """Allocate exact split counts using largest remainder rounding.""" |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
101 if n_total <= 0: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
102 return [0 for _ in probs] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
103 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
104 counts = [0 for _ in probs] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
105 active = [i for i, p in enumerate(probs) if p > 0] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
106 remainder = n_total |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
107 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
108 if active and n_total >= len(active): |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
109 for i in active: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
110 counts[i] = 1 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
111 remainder -= len(active) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
112 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
113 if remainder > 0: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
114 probs_arr = np.array(probs, dtype=float) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
115 probs_arr = probs_arr / probs_arr.sum() |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
116 raw = remainder * probs_arr |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
117 floors = np.floor(raw).astype(int) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
118 for i, value in enumerate(floors.tolist()): |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
119 counts[i] += value |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
120 leftover = remainder - int(floors.sum()) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
121 if leftover > 0 and active: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
122 frac = raw - floors |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
123 order = sorted(active, key=lambda i: (-frac[i], i)) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
124 for i in range(leftover): |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
125 counts[order[i % len(order)]] += 1 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
126 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
127 return counts |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
128 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
129 def _choose_split(counts: list, targets: list, active: list) -> int: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
130 remaining = [targets[i] - counts[i] for i in range(len(targets))] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
131 best = max(active, key=lambda i: (remaining[i], -counts[i], -targets[i])) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
132 if remaining[best] <= 0: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
133 best = min(active, key=lambda i: counts[i]) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
134 return best |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
135 |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
136 if not label_column or label_column not in out.columns: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
137 logger.warning( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
138 "No label column found; using random split without stratification" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
139 ) |
|
20
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
140 # fall back to random assignment (group-aware if requested) |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
141 indices = out.index.tolist() |
|
20
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
142 rng = np.random.RandomState(random_state) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
143 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
144 if group_column: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
145 group_series = out[group_column].astype(object) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
146 missing_mask = group_series.isna() |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
147 if missing_mask.any(): |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
148 group_series = group_series.copy() |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
149 group_series.loc[missing_mask] = [ |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
150 f"__missing__{idx}" for idx in group_series.index[missing_mask] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
151 ] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
152 group_to_indices = {} |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
153 for idx, group_id in group_series.items(): |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
154 group_to_indices.setdefault(group_id, []).append(idx) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
155 group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x)) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
156 rng.shuffle(group_ids) |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
157 |
|
20
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
158 targets = _allocate_split_counts(len(indices), split_probabilities) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
159 counts = [0 for _ in split_probabilities] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
160 active = [i for i, p in enumerate(split_probabilities) if p > 0] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
161 train_idx = [] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
162 val_idx = [] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
163 test_idx = [] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
164 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
165 for group_id in group_ids: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
166 size = len(group_to_indices[group_id]) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
167 split_idx = _choose_split(counts, targets, active) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
168 counts[split_idx] += size |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
169 if split_idx == 0: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
170 train_idx.extend(group_to_indices[group_id]) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
171 elif split_idx == 1: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
172 val_idx.extend(group_to_indices[group_id]) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
173 else: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
174 test_idx.extend(group_to_indices[group_id]) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
175 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
176 out.loc[train_idx, split_column] = 0 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
177 out.loc[val_idx, split_column] = 1 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
178 out.loc[test_idx, split_column] = 2 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
179 return out.astype({split_column: int}) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
180 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
181 rng.shuffle(indices) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
182 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
183 targets = _allocate_split_counts(len(indices), split_probabilities) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
184 n_train, n_val, n_test = targets |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
185 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
186 out.loc[indices[:n_train], split_column] = 0 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
187 out.loc[indices[n_train:n_train + n_val], split_column] = 1 |
|
20
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
188 out.loc[indices[n_train + n_val:n_train + n_val + n_test], split_column] = 2 |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
189 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
190 return out.astype({split_column: int}) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
191 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
192 # check if stratification is possible |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
193 label_counts = out[label_column].value_counts() |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
194 min_samples_per_class = label_counts.min() |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
195 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
196 # ensure we have enough samples for stratification: |
|
20
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
197 # Each class must have at least as many samples as the number of nonzero splits, |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
198 # so that each split can receive at least one sample per class. |
|
20
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
199 active_splits = [p for p in split_probabilities if p > 0] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
200 min_samples_required = len(active_splits) |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
201 if min_samples_per_class < min_samples_required: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
202 logger.warning( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
203 f"Insufficient samples per class for stratification (min: {min_samples_per_class}, required: {min_samples_required}); using random split" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
204 ) |
|
20
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
205 # fall back to simple random assignment (group-aware if requested) |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
206 indices = out.index.tolist() |
|
20
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
207 rng = np.random.RandomState(random_state) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
208 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
209 if group_column: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
210 group_series = out[group_column].astype(object) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
211 missing_mask = group_series.isna() |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
212 if missing_mask.any(): |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
213 group_series = group_series.copy() |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
214 group_series.loc[missing_mask] = [ |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
215 f"__missing__{idx}" for idx in group_series.index[missing_mask] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
216 ] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
217 group_to_indices = {} |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
218 for idx, group_id in group_series.items(): |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
219 group_to_indices.setdefault(group_id, []).append(idx) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
220 group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x)) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
221 rng.shuffle(group_ids) |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
222 |
|
20
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
223 targets = _allocate_split_counts(len(indices), split_probabilities) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
224 counts = [0 for _ in split_probabilities] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
225 active = [i for i, p in enumerate(split_probabilities) if p > 0] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
226 train_idx = [] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
227 val_idx = [] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
228 test_idx = [] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
229 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
230 for group_id in group_ids: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
231 size = len(group_to_indices[group_id]) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
232 split_idx = _choose_split(counts, targets, active) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
233 counts[split_idx] += size |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
234 if split_idx == 0: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
235 train_idx.extend(group_to_indices[group_id]) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
236 elif split_idx == 1: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
237 val_idx.extend(group_to_indices[group_id]) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
238 else: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
239 test_idx.extend(group_to_indices[group_id]) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
240 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
241 out.loc[train_idx, split_column] = 0 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
242 out.loc[val_idx, split_column] = 1 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
243 out.loc[test_idx, split_column] = 2 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
244 return out.astype({split_column: int}) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
245 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
246 rng.shuffle(indices) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
247 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
248 targets = _allocate_split_counts(len(indices), split_probabilities) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
249 n_train, n_val, n_test = targets |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
250 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
251 out.loc[indices[:n_train], split_column] = 0 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
252 out.loc[indices[n_train:n_train + n_val], split_column] = 1 |
|
20
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
253 out.loc[indices[n_train + n_val:n_train + n_val + n_test], split_column] = 2 |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
254 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
255 return out.astype({split_column: int}) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
256 |
|
20
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
257 if group_column: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
258 logger.info( |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
259 "Using stratified random split for train/validation/test sets (grouped by '%s')", |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
260 group_column, |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
261 ) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
262 rng = np.random.RandomState(random_state) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
263 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
264 group_series = out[group_column].astype(object) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
265 missing_mask = group_series.isna() |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
266 if missing_mask.any(): |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
267 group_series = group_series.copy() |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
268 group_series.loc[missing_mask] = [ |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
269 f"__missing__{idx}" for idx in group_series.index[missing_mask] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
270 ] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
271 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
272 group_to_indices = {} |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
273 for idx, group_id in group_series.items(): |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
274 group_to_indices.setdefault(group_id, []).append(idx) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
275 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
276 group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x)) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
277 group_labels = {} |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
278 mixed_groups = [] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
279 label_series = out[label_column] |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
280 |
|
20
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
281 for group_id in group_ids: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
282 labels = label_series.loc[group_to_indices[group_id]].dropna().unique() |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
283 if len(labels) == 1: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
284 group_labels[group_id] = labels[0] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
285 elif len(labels) == 0: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
286 group_labels[group_id] = None |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
287 else: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
288 mode_vals = label_series.loc[group_to_indices[group_id]].mode(dropna=True) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
289 group_labels[group_id] = mode_vals.iloc[0] if not mode_vals.empty else labels[0] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
290 mixed_groups.append(group_id) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
291 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
292 if mixed_groups: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
293 logger.warning( |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
294 "Detected %d groups with multiple labels; using the most common label per group for stratification.", |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
295 len(mixed_groups), |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
296 ) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
297 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
298 train_idx = [] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
299 val_idx = [] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
300 test_idx = [] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
301 active = [i for i, p in enumerate(split_probabilities) if p > 0] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
302 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
303 for label_value in sorted(label_counts.index.tolist(), key=lambda x: str(x)): |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
304 label_groups = [g for g in group_ids if group_labels.get(g) == label_value] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
305 if not label_groups: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
306 continue |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
307 rng.shuffle(label_groups) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
308 label_total = sum(len(group_to_indices[g]) for g in label_groups) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
309 targets = _allocate_split_counts(label_total, split_probabilities) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
310 counts = [0 for _ in split_probabilities] |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
311 |
|
20
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
312 for group_id in label_groups: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
313 size = len(group_to_indices[group_id]) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
314 split_idx = _choose_split(counts, targets, active) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
315 counts[split_idx] += size |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
316 if split_idx == 0: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
317 train_idx.extend(group_to_indices[group_id]) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
318 elif split_idx == 1: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
319 val_idx.extend(group_to_indices[group_id]) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
320 else: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
321 test_idx.extend(group_to_indices[group_id]) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
322 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
323 # Assign groups without a label (or missing labels) using overall targets. |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
324 unlabeled_groups = [g for g in group_ids if group_labels.get(g) is None] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
325 if unlabeled_groups: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
326 rng.shuffle(unlabeled_groups) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
327 total_unlabeled = sum(len(group_to_indices[g]) for g in unlabeled_groups) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
328 targets = _allocate_split_counts(total_unlabeled, split_probabilities) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
329 counts = [0 for _ in split_probabilities] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
330 for group_id in unlabeled_groups: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
331 size = len(group_to_indices[group_id]) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
332 split_idx = _choose_split(counts, targets, active) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
333 counts[split_idx] += size |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
334 if split_idx == 0: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
335 train_idx.extend(group_to_indices[group_id]) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
336 elif split_idx == 1: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
337 val_idx.extend(group_to_indices[group_id]) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
338 else: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
339 test_idx.extend(group_to_indices[group_id]) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
340 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
341 out.loc[train_idx, split_column] = 0 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
342 out.loc[val_idx, split_column] = 1 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
343 out.loc[test_idx, split_column] = 2 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
344 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
345 logger.info("Successfully applied stratified random split") |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
346 logger.info( |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
347 f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}" |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
348 ) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
349 return out.astype({split_column: int}) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
350 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
351 logger.info("Using stratified random split for train/validation/test sets (per-class allocation)") |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
352 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
353 rng = np.random.RandomState(random_state) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
354 label_values = sorted(label_counts.index.tolist(), key=lambda x: str(x)) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
355 train_idx = [] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
356 val_idx = [] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
357 test_idx = [] |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
358 |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
359 for label_value in label_values: |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
360 label_indices = out.index[out[label_column] == label_value].tolist() |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
361 rng.shuffle(label_indices) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
362 n_train, n_val, n_test = _allocate_split_counts(len(label_indices), split_probabilities) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
363 train_idx.extend(label_indices[:n_train]) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
364 val_idx.extend(label_indices[n_train:n_train + n_val]) |
|
64872c48a21f
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents:
12
diff
changeset
|
365 test_idx.extend(label_indices[n_train + n_val:n_train + n_val + n_test]) |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
366 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
367 # assign split values |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
368 out.loc[train_idx, split_column] = 0 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
369 out.loc[val_idx, split_column] = 1 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
370 out.loc[test_idx, split_column] = 2 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
371 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
372 logger.info("Successfully applied stratified random split") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
373 logger.info( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
374 f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
375 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
376 return out.astype({split_column: int}) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
377 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
378 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
379 class SplitProbAction(argparse.Action): |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
380 def __call__(self, parser, namespace, values, option_string=None): |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
381 train, val, test = values |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
382 total = train + val + test |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
383 if abs(total - 1.0) > 1e-6: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
384 parser.error( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
385 f"--split-probabilities must sum to 1.0; " |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
386 f"got {train:.3f} + {val:.3f} + {test:.3f} = {total:.3f}" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
387 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
388 setattr(namespace, self.dest, values) |
