Mercurial > repos > goeckslab > image_learner
annotate split_data.py @ 13:1a9c42974a5a draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9f96da4ea7ab3b572af86698ff51b870125cd674
| author | goeckslab |
|---|---|
| date | Fri, 21 Nov 2025 17:35:00 +0000 |
| parents | bcfa2e234a80 |
| children |
| rev | line source |
|---|---|
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1 import argparse |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
2 import logging |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
3 from typing import Optional |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
4 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
5 import numpy as np |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
6 import pandas as pd |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
7 from sklearn.model_selection import train_test_split |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
8 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
9 logger = logging.getLogger("ImageLearner") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
10 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
11 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
def split_data_0_2(
    df: pd.DataFrame,
    split_column: str,
    validation_size: float = 0.1,
    random_state: int = 42,
    label_column: Optional[str] = None,
) -> pd.DataFrame:
    """Carve a validation subset out of the train rows of a {0, 2} split.

    Rows whose ``split_column`` equals 0 (train) are partitioned so that a
    ``validation_size`` fraction of them is relabelled 1 (validation). Rows
    holding any other value (e.g. 2) are left untouched. ``df`` itself is
    never modified; the work is done on a copy.

    Args:
        df: Input frame containing ``split_column``.
        split_column: Name of the column holding the split codes.
        validation_size: Fraction of the train rows to move to validation.
            Values <= 0 keep every train row as train; values >= 1 move
            every train row to validation.
        random_state: Seed forwarded to ``train_test_split``.
        label_column: Optional column used to stratify the split. It is
            ignored when absent from ``df`` or when the train rows hold
            fewer than two distinct labels.

    Returns:
        A copy of ``df`` whose ``split_column`` is cast to ``int`` and in
        which the selected train rows are marked 1.

    Raises:
        ValueError: If ``split_column`` contains missing or non-numeric
            values, or if ``train_test_split`` cannot perform the random
            split (for example because there are too few train rows).
    """
    out = df.copy()

    # Coercion turns unparsable entries into NaN; report them explicitly
    # instead of letting the integer cast fail with an opaque message.
    split_values = pd.to_numeric(out[split_column], errors="coerce")
    if split_values.isna().any():
        raise ValueError(
            f"Column '{split_column}' contains missing or non-numeric values; "
            "cannot interpret it as a split column."
        )
    out[split_column] = split_values.astype(int)

    idx_train = out.index[out[split_column] == 0].tolist()
    if not idx_train:
        logger.info("No rows with split=0; nothing to do.")
        return out

    # Degenerate sizes need no sampling at all, so settle them before doing
    # any label bookkeeping.
    if validation_size <= 0:
        logger.info("validation_size <= 0; keeping all as train.")
        return out
    if validation_size >= 1:
        logger.info("validation_size >= 1; moving all train → validation.")
        out.loc[idx_train, split_column] = 1
        return out

    stratify_arr = None
    if label_column and label_column in out.columns:
        train_labels = out.loc[idx_train, label_column]
        label_counts = train_labels.value_counts()
        if label_counts.size > 1:
            # The requested fraction is never changed here: shrinking it
            # cannot help a rare class, and growing it would exceed what the
            # caller asked for. Just make the consequence visible.
            if label_counts.min() * validation_size < 1:
                logger.warning(
                    "Smallest label class has fewer than one expected validation sample; "
                    "it may be missing from the validation set"
                )
            stratify_arr = train_labels
            logger.info("Using stratified split for validation set")
        else:
            logger.warning("Only one label class found; cannot stratify")

    val_idx = None
    if stratify_arr is not None:
        try:
            _, val_idx = train_test_split(
                idx_train,
                test_size=validation_size,
                random_state=random_state,
                stratify=stratify_arr,
            )
            logger.info("Successfully applied stratified split")
        except ValueError as e:
            logger.warning(f"Stratified split failed ({e}); falling back to random split.")

    if val_idx is None:
        # Either stratification was not requested/possible or it failed.
        # A failure here is not retried: the same call would fail again.
        _, val_idx = train_test_split(
            idx_train,
            test_size=validation_size,
            random_state=random_state,
            stratify=None,
        )
        logger.info("Successfully applied random split")

    # Train rows already hold 0, so only the validation rows need rewriting.
    out.loc[val_idx, split_column] = 1
    return out
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
76 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
77 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
78 def create_stratified_random_split( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
79 df: pd.DataFrame, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
80 split_column: str, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
81 split_probabilities: list = [0.7, 0.1, 0.2], |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
82 random_state: int = 42, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
83 label_column: Optional[str] = None, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
84 ) -> pd.DataFrame: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
85 """Create a stratified random split when no split column exists.""" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
86 out = df.copy() |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
87 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
88 # initialize split column |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
89 out[split_column] = 0 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
90 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
91 if not label_column or label_column not in out.columns: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
92 logger.warning( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
93 "No label column found; using random split without stratification" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
94 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
95 # fall back to simple random assignment |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
96 indices = out.index.tolist() |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
97 np.random.seed(random_state) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
98 np.random.shuffle(indices) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
99 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
100 n_total = len(indices) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
101 n_train = int(n_total * split_probabilities[0]) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
102 n_val = int(n_total * split_probabilities[1]) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
103 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
104 out.loc[indices[:n_train], split_column] = 0 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
105 out.loc[indices[n_train:n_train + n_val], split_column] = 1 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
106 out.loc[indices[n_train + n_val:], split_column] = 2 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
107 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
108 return out.astype({split_column: int}) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
109 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
110 # check if stratification is possible |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
111 label_counts = out[label_column].value_counts() |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
112 min_samples_per_class = label_counts.min() |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
113 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
114 # ensure we have enough samples for stratification: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
115 # Each class must have at least as many samples as the number of splits, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
116 # so that each split can receive at least one sample per class. |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
117 min_samples_required = len(split_probabilities) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
118 if min_samples_per_class < min_samples_required: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
119 logger.warning( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
120 f"Insufficient samples per class for stratification (min: {min_samples_per_class}, required: {min_samples_required}); using random split" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
121 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
122 # fall back to simple random assignment |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
123 indices = out.index.tolist() |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
124 np.random.seed(random_state) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
125 np.random.shuffle(indices) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
126 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
127 n_total = len(indices) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
128 n_train = int(n_total * split_probabilities[0]) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
129 n_val = int(n_total * split_probabilities[1]) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
130 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
131 out.loc[indices[:n_train], split_column] = 0 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
132 out.loc[indices[n_train:n_train + n_val], split_column] = 1 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
133 out.loc[indices[n_train + n_val:], split_column] = 2 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
134 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
135 return out.astype({split_column: int}) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
136 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
137 logger.info("Using stratified random split for train/validation/test sets") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
138 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
139 # first split: separate test set |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
140 train_val_idx, test_idx = train_test_split( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
141 out.index.tolist(), |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
142 test_size=split_probabilities[2], |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
143 random_state=random_state, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
144 stratify=out[label_column], |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
145 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
146 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
147 # second split: separate training and validation from remaining data |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
148 val_size_adjusted = split_probabilities[1] / ( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
149 split_probabilities[0] + split_probabilities[1] |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
150 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
151 train_idx, val_idx = train_test_split( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
152 train_val_idx, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
153 test_size=val_size_adjusted, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
154 random_state=random_state, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
155 stratify=out.loc[train_val_idx, label_column] if label_column and label_column in out.columns else None, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
156 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
157 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
158 # assign split values |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
159 out.loc[train_idx, split_column] = 0 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
160 out.loc[val_idx, split_column] = 1 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
161 out.loc[test_idx, split_column] = 2 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
162 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
163 logger.info("Successfully applied stratified random split") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
164 logger.info( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
165 f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
166 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
167 return out.astype({split_column: int}) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
168 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
169 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
class SplitProbAction(argparse.Action):
    """Argparse action that validates a (train, val, test) probability triple.

    The three supplied values are unpacked as train, validation and test
    fractions.  They are accepted only when they sum to 1.0 (within 1e-6)
    and each value lies in the closed interval [0.0, 1.0].  On success the
    values are stored unchanged on the namespace under ``self.dest``; on
    failure ``parser.error`` is called, which reports the problem and exits.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        """Validate ``values`` and store them on ``namespace``.

        Args:
            parser: The argparse parser; its ``error`` method is used to
                report invalid input.
            namespace: Object that receives the validated values.
            values: Sequence of exactly three numbers (train, val, test).
            option_string: Option string used on the command line (unused).
        """
        train, val, test = values
        total = train + val + test
        if abs(total - 1.0) > 1e-6:
            parser.error(
                "--split-probabilities must sum to 1.0; "
                f"got {train:.3f} + {val:.3f} + {test:.3f} = {total:.3f}"
            )
        # The sum check alone is not sufficient: NaN makes every comparison
        # above evaluate to False, and triples such as (1.5, -0.3, -0.2) sum
        # to 1.0 while being meaningless as split fractions.  The chained
        # comparison below is False for NaN, so NaN is rejected here as well.
        if not all(0.0 <= prob <= 1.0 for prob in (train, val, test)):
            parser.error(
                "--split-probabilities values must each be between 0.0 and 1.0; "
                f"got {train} {val} {test}"
            )
        setattr(namespace, self.dest, values)
