annotate split_data.py @ 20:64872c48a21f draft

planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
author goeckslab
date Tue, 06 Jan 2026 15:35:11 +0000
parents bcfa2e234a80
children
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
rev   line source
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
1 import argparse
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
2 import logging
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
3 from typing import Optional
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
4
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
5 import numpy as np
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
6 import pandas as pd
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
7 from sklearn.model_selection import train_test_split
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
8
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
9 logger = logging.getLogger("ImageLearner")
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
10
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
11
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
def split_data_0_2(
    df: pd.DataFrame,
    split_column: str,
    validation_size: float = 0.1,
    random_state: int = 42,
    label_column: Optional[str] = None,
) -> pd.DataFrame:
    """Carve a validation subset out of the training rows of a 0/2 split.

    Given a DataFrame whose ``split_column`` only contains {0, 2}, re-assign
    a portion of the 0s to become 1s (validation). Rows whose split value is
    not 0 are never modified. Rows are addressed by position, so the result
    is correct even when ``df`` has a non-unique index.

    Args:
        df: Input data; it is copied, never mutated.
        split_column: Name of the column holding the split codes.
        validation_size: Fraction of the split==0 rows to move to split==1.
            ``<= 0`` keeps every row as train; ``>= 1`` moves every train row
            to validation.
        random_state: Seed forwarded to ``train_test_split``.
        label_column: Optional column used to stratify the validation draw.
            Ignored when missing from ``df`` or when the train rows contain
            a single class.

    Returns:
        A copy of ``df`` whose ``split_column`` is integer-typed and has some
        of its 0 entries replaced by 1.

    Raises:
        ValueError: If ``split_column`` contains values that cannot be
            converted to integers, or if ``train_test_split`` cannot produce
            a non-stratified split (for example, too few train rows).
    """
    out = df.copy()

    split_values = pd.to_numeric(out[split_column], errors="coerce")
    if split_values.isna().any():
        # errors="coerce" maps unparseable entries to NaN, which cannot be
        # cast to int; fail with an explicit message instead of a cast error.
        raise ValueError(
            f"Column '{split_column}' contains missing or non-numeric values; "
            "cannot convert it to integer split codes."
        )
    out[split_column] = split_values.astype(int)

    # Positional (not label-based) row numbers of the current train rows, so
    # duplicated index labels can neither inflate the selection nor leak the
    # re-assignment onto non-train rows.
    train_pos = np.flatnonzero((out[split_column] == 0).to_numpy())
    if train_pos.size == 0:
        logger.info("No rows with split=0; nothing to do.")
        return out

    stratify_arr = None
    if label_column and label_column in out.columns:
        train_labels = out[label_column].iloc[train_pos]
        label_counts = train_labels.value_counts()
        if label_counts.size > 1:
            # Force stratify even with fewer samples - adjust validation_size if needed
            min_samples_per_class = label_counts.min()
            if min_samples_per_class * validation_size < 1:
                # NOTE: the candidate is capped at the original value by
                # min(), so this step can never increase validation_size.
                adjusted_validation_size = min(
                    validation_size, 1.0 / min_samples_per_class
                )
                if adjusted_validation_size != validation_size:
                    validation_size = adjusted_validation_size
                    logger.info(
                        "Adjusted validation_size to %.3f to ensure at least one sample per class in validation",
                        validation_size,
                    )
            stratify_arr = train_labels
            logger.info("Using stratified split for validation set")
        else:
            logger.warning("Only one label class found; cannot stratify")

    if validation_size <= 0:
        logger.info("validation_size <= 0; keeping all as train.")
        return out

    codes = out[split_column].to_numpy().copy()

    if validation_size >= 1:
        logger.info("validation_size >= 1; moving all train → validation.")
        codes[train_pos] = 1
        out[split_column] = codes
        return out

    if stratify_arr is None:
        # Nothing to stratify on: one random split. A ValueError from
        # train_test_split propagates, since retrying the identical call
        # could not succeed.
        _, val_pos = train_test_split(
            train_pos,
            test_size=validation_size,
            random_state=random_state,
            stratify=None,
        )
        logger.info("Applied random (non-stratified) split")
    else:
        try:
            _, val_pos = train_test_split(
                train_pos,
                test_size=validation_size,
                random_state=random_state,
                stratify=stratify_arr,
            )
            logger.info("Successfully applied stratified split")
        except ValueError as e:
            logger.warning(
                "Stratified split failed (%s); falling back to random split.", e
            )
            _, val_pos = train_test_split(
                train_pos,
                test_size=validation_size,
                random_state=random_state,
                stratify=None,
            )

    # Train rows already hold 0; only the validation draw needs rewriting.
    codes[val_pos] = 1
    out[split_column] = codes
    out[split_column] = out[split_column].astype(int)
    return out
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
76
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
77
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
78 def create_stratified_random_split(
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
79 df: pd.DataFrame,
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
80 split_column: str,
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
81 split_probabilities: list = [0.7, 0.1, 0.2],
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
82 random_state: int = 42,
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
83 label_column: Optional[str] = None,
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
84 group_column: Optional[str] = None,
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
85 ) -> pd.DataFrame:
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
86 """Create a stratified random split when no split column exists."""
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
87 out = df.copy()
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
88
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
89 # initialize split column
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
90 out[split_column] = 0
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
91
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
92 if group_column and group_column not in out.columns:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
93 logger.warning(
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
94 "Group column '%s' not found in data; proceeding without group-aware split.",
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
95 group_column,
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
96 )
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
97 group_column = None
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
98
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
99 def _allocate_split_counts(n_total: int, probs: list) -> list:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
100 """Allocate exact split counts using largest remainder rounding."""
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
101 if n_total <= 0:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
102 return [0 for _ in probs]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
103
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
104 counts = [0 for _ in probs]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
105 active = [i for i, p in enumerate(probs) if p > 0]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
106 remainder = n_total
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
107
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
108 if active and n_total >= len(active):
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
109 for i in active:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
110 counts[i] = 1
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
111 remainder -= len(active)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
112
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
113 if remainder > 0:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
114 probs_arr = np.array(probs, dtype=float)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
115 probs_arr = probs_arr / probs_arr.sum()
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
116 raw = remainder * probs_arr
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
117 floors = np.floor(raw).astype(int)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
118 for i, value in enumerate(floors.tolist()):
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
119 counts[i] += value
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
120 leftover = remainder - int(floors.sum())
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
121 if leftover > 0 and active:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
122 frac = raw - floors
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
123 order = sorted(active, key=lambda i: (-frac[i], i))
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
124 for i in range(leftover):
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
125 counts[order[i % len(order)]] += 1
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
126
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
127 return counts
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
128
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
129 def _choose_split(counts: list, targets: list, active: list) -> int:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
130 remaining = [targets[i] - counts[i] for i in range(len(targets))]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
131 best = max(active, key=lambda i: (remaining[i], -counts[i], -targets[i]))
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
132 if remaining[best] <= 0:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
133 best = min(active, key=lambda i: counts[i])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
134 return best
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
135
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
136 if not label_column or label_column not in out.columns:
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
137 logger.warning(
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
138 "No label column found; using random split without stratification"
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
139 )
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
140 # fall back to random assignment (group-aware if requested)
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
141 indices = out.index.tolist()
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
142 rng = np.random.RandomState(random_state)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
143
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
144 if group_column:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
145 group_series = out[group_column].astype(object)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
146 missing_mask = group_series.isna()
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
147 if missing_mask.any():
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
148 group_series = group_series.copy()
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
149 group_series.loc[missing_mask] = [
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
150 f"__missing__{idx}" for idx in group_series.index[missing_mask]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
151 ]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
152 group_to_indices = {}
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
153 for idx, group_id in group_series.items():
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
154 group_to_indices.setdefault(group_id, []).append(idx)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
155 group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
156 rng.shuffle(group_ids)
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
157
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
158 targets = _allocate_split_counts(len(indices), split_probabilities)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
159 counts = [0 for _ in split_probabilities]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
160 active = [i for i, p in enumerate(split_probabilities) if p > 0]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
161 train_idx = []
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
162 val_idx = []
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
163 test_idx = []
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
164
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
165 for group_id in group_ids:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
166 size = len(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
167 split_idx = _choose_split(counts, targets, active)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
168 counts[split_idx] += size
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
169 if split_idx == 0:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
170 train_idx.extend(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
171 elif split_idx == 1:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
172 val_idx.extend(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
173 else:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
174 test_idx.extend(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
175
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
176 out.loc[train_idx, split_column] = 0
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
177 out.loc[val_idx, split_column] = 1
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
178 out.loc[test_idx, split_column] = 2
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
179 return out.astype({split_column: int})
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
180
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
181 rng.shuffle(indices)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
182
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
183 targets = _allocate_split_counts(len(indices), split_probabilities)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
184 n_train, n_val, n_test = targets
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
185
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
186 out.loc[indices[:n_train], split_column] = 0
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
187 out.loc[indices[n_train:n_train + n_val], split_column] = 1
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
188 out.loc[indices[n_train + n_val:n_train + n_val + n_test], split_column] = 2
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
189
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
190 return out.astype({split_column: int})
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
191
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
192 # check if stratification is possible
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
193 label_counts = out[label_column].value_counts()
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
194 min_samples_per_class = label_counts.min()
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
195
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
196 # ensure we have enough samples for stratification:
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
197 # Each class must have at least as many samples as the number of nonzero splits,
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
198 # so that each split can receive at least one sample per class.
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
199 active_splits = [p for p in split_probabilities if p > 0]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
200 min_samples_required = len(active_splits)
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
201 if min_samples_per_class < min_samples_required:
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
202 logger.warning(
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
203 f"Insufficient samples per class for stratification (min: {min_samples_per_class}, required: {min_samples_required}); using random split"
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
204 )
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
205 # fall back to simple random assignment (group-aware if requested)
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
206 indices = out.index.tolist()
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
207 rng = np.random.RandomState(random_state)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
208
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
209 if group_column:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
210 group_series = out[group_column].astype(object)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
211 missing_mask = group_series.isna()
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
212 if missing_mask.any():
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
213 group_series = group_series.copy()
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
214 group_series.loc[missing_mask] = [
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
215 f"__missing__{idx}" for idx in group_series.index[missing_mask]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
216 ]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
217 group_to_indices = {}
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
218 for idx, group_id in group_series.items():
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
219 group_to_indices.setdefault(group_id, []).append(idx)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
220 group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
221 rng.shuffle(group_ids)
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
222
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
223 targets = _allocate_split_counts(len(indices), split_probabilities)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
224 counts = [0 for _ in split_probabilities]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
225 active = [i for i, p in enumerate(split_probabilities) if p > 0]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
226 train_idx = []
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
227 val_idx = []
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
228 test_idx = []
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
229
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
230 for group_id in group_ids:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
231 size = len(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
232 split_idx = _choose_split(counts, targets, active)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
233 counts[split_idx] += size
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
234 if split_idx == 0:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
235 train_idx.extend(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
236 elif split_idx == 1:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
237 val_idx.extend(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
238 else:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
239 test_idx.extend(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
240
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
241 out.loc[train_idx, split_column] = 0
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
242 out.loc[val_idx, split_column] = 1
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
243 out.loc[test_idx, split_column] = 2
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
244 return out.astype({split_column: int})
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
245
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
246 rng.shuffle(indices)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
247
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
248 targets = _allocate_split_counts(len(indices), split_probabilities)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
249 n_train, n_val, n_test = targets
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
250
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
251 out.loc[indices[:n_train], split_column] = 0
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
252 out.loc[indices[n_train:n_train + n_val], split_column] = 1
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
253 out.loc[indices[n_train + n_val:n_train + n_val + n_test], split_column] = 2
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
254
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
255 return out.astype({split_column: int})
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
256
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
257 if group_column:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
258 logger.info(
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
259 "Using stratified random split for train/validation/test sets (grouped by '%s')",
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
260 group_column,
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
261 )
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
262 rng = np.random.RandomState(random_state)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
263
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
264 group_series = out[group_column].astype(object)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
265 missing_mask = group_series.isna()
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
266 if missing_mask.any():
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
267 group_series = group_series.copy()
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
268 group_series.loc[missing_mask] = [
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
269 f"__missing__{idx}" for idx in group_series.index[missing_mask]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
270 ]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
271
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
272 group_to_indices = {}
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
273 for idx, group_id in group_series.items():
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
274 group_to_indices.setdefault(group_id, []).append(idx)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
275
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
276 group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
277 group_labels = {}
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
278 mixed_groups = []
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
279 label_series = out[label_column]
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
280
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
281 for group_id in group_ids:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
282 labels = label_series.loc[group_to_indices[group_id]].dropna().unique()
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
283 if len(labels) == 1:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
284 group_labels[group_id] = labels[0]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
285 elif len(labels) == 0:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
286 group_labels[group_id] = None
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
287 else:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
288 mode_vals = label_series.loc[group_to_indices[group_id]].mode(dropna=True)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
289 group_labels[group_id] = mode_vals.iloc[0] if not mode_vals.empty else labels[0]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
290 mixed_groups.append(group_id)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
291
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
292 if mixed_groups:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
293 logger.warning(
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
294 "Detected %d groups with multiple labels; using the most common label per group for stratification.",
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
295 len(mixed_groups),
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
296 )
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
297
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
298 train_idx = []
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
299 val_idx = []
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
300 test_idx = []
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
301 active = [i for i, p in enumerate(split_probabilities) if p > 0]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
302
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
303 for label_value in sorted(label_counts.index.tolist(), key=lambda x: str(x)):
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
304 label_groups = [g for g in group_ids if group_labels.get(g) == label_value]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
305 if not label_groups:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
306 continue
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
307 rng.shuffle(label_groups)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
308 label_total = sum(len(group_to_indices[g]) for g in label_groups)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
309 targets = _allocate_split_counts(label_total, split_probabilities)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
310 counts = [0 for _ in split_probabilities]
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
311
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
312 for group_id in label_groups:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
313 size = len(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
314 split_idx = _choose_split(counts, targets, active)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
315 counts[split_idx] += size
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
316 if split_idx == 0:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
317 train_idx.extend(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
318 elif split_idx == 1:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
319 val_idx.extend(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
320 else:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
321 test_idx.extend(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
322
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
323 # Assign groups without a label (or missing labels) using overall targets.
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
324 unlabeled_groups = [g for g in group_ids if group_labels.get(g) is None]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
325 if unlabeled_groups:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
326 rng.shuffle(unlabeled_groups)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
327 total_unlabeled = sum(len(group_to_indices[g]) for g in unlabeled_groups)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
328 targets = _allocate_split_counts(total_unlabeled, split_probabilities)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
329 counts = [0 for _ in split_probabilities]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
330 for group_id in unlabeled_groups:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
331 size = len(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
332 split_idx = _choose_split(counts, targets, active)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
333 counts[split_idx] += size
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
334 if split_idx == 0:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
335 train_idx.extend(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
336 elif split_idx == 1:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
337 val_idx.extend(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
338 else:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
339 test_idx.extend(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
340
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
341 out.loc[train_idx, split_column] = 0
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
342 out.loc[val_idx, split_column] = 1
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
343 out.loc[test_idx, split_column] = 2
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
344
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
345 logger.info("Successfully applied stratified random split")
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
346 logger.info(
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
347 f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}"
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
348 )
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
349 return out.astype({split_column: int})
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
350
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
351 logger.info("Using stratified random split for train/validation/test sets (per-class allocation)")
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
352
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
353 rng = np.random.RandomState(random_state)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
354 label_values = sorted(label_counts.index.tolist(), key=lambda x: str(x))
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
355 train_idx = []
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
356 val_idx = []
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
357 test_idx = []
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
358
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
359 for label_value in label_values:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
360 label_indices = out.index[out[label_column] == label_value].tolist()
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
361 rng.shuffle(label_indices)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
362 n_train, n_val, n_test = _allocate_split_counts(len(label_indices), split_probabilities)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
363 train_idx.extend(label_indices[:n_train])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
364 val_idx.extend(label_indices[n_train:n_train + n_val])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
365 test_idx.extend(label_indices[n_train + n_val:n_train + n_val + n_test])
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
366
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
367 # assign split values
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
368 out.loc[train_idx, split_column] = 0
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
369 out.loc[val_idx, split_column] = 1
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
370 out.loc[test_idx, split_column] = 2
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
371
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
372 logger.info("Successfully applied stratified random split")
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
373 logger.info(
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
374 f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}"
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
375 )
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
376 return out.astype({split_column: int})
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
377
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
378
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
class SplitProbAction(argparse.Action):
    """Argparse action that validates train/validation/test split probabilities.

    The option's values are unpacked as ``train, val, test``, so the argument
    must supply exactly three numbers. Validation failures are reported via
    ``parser.error``; valid values are stored unchanged on the namespace.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        """Validate ``values`` and store them on ``namespace`` under ``self.dest``.

        ``parser.error`` is called when any value lies outside ``[0.0, 1.0]``
        (NaN included) or when the three values do not sum to 1.0 within a
        tolerance of 1e-6.
        """
        train, val, test = values

        # A sum check alone lets through sets like (1.5, -0.25, -0.25), and NaN
        # slips past ``abs(total - 1.0) > 1e-6`` because comparisons with NaN
        # are always False. ``not (0.0 <= p <= 1.0)`` is True for both cases.
        if any(not (0.0 <= p <= 1.0) for p in (train, val, test)):
            parser.error(
                "--split-probabilities values must each be between 0.0 and 1.0; "
                f"got {train} {val} {test}"
            )

        total = train + val + test
        if abs(total - 1.0) > 1e-6:
            parser.error(
                f"--split-probabilities must sum to 1.0; "
                f"got {train:.3f} + {val:.3f} + {test:.3f} = {total:.3f}"
            )
        setattr(namespace, self.dest, values)