annotate split_data.py @ 23:2c6624cae3c5 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
author goeckslab
date Sun, 25 Jan 2026 01:09:56 +0000
parents 64872c48a21f
children
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
rev   line source
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
1 import argparse
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
2 import logging
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
3 from typing import Optional
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
4
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
5 import numpy as np
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
6 import pandas as pd
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
7 from sklearn.model_selection import train_test_split
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
8
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
9 logger = logging.getLogger("ImageLearner")
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
10
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
11
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
def split_data_0_2(
    df: pd.DataFrame,
    split_column: str,
    validation_size: float = 0.1,
    random_state: int = 42,
    label_column: Optional[str] = None,
) -> pd.DataFrame:
    """Carve a validation subset (split=1) out of the training rows (split=0).

    The input is expected to carry only the values {0, 2} in ``split_column``
    (train / test). A fraction of the rows marked 0 is re-labelled as 1;
    rows with any other split value are left untouched.

    Args:
        df: Input table; it is copied, never modified in place.
        split_column: Name of the column holding the split assignment.
        validation_size: Fraction of the split=0 rows to move to validation.
            ``<= 0`` keeps every row as train, ``>= 1`` moves every train row
            to validation.
        random_state: Seed forwarded to ``train_test_split``.
        label_column: Optional column used to stratify the validation draw.
            Ignored when it is missing from ``df`` or when the train rows
            hold fewer than two distinct labels.

    Returns:
        A copy of ``df`` whose ``split_column`` is integer typed and contains
        the new validation assignments.

    Raises:
        ValueError: If ``split_column`` contains missing or non-numeric
            values, or if ``train_test_split`` cannot produce a split even
            without stratification (e.g. a single train row).
    """
    out = df.copy()

    # Fail with an explicit message instead of letting the integer cast trip
    # over the NaNs that ``errors="coerce"`` produces for bad values.
    split_numeric = pd.to_numeric(out[split_column], errors="coerce")
    if split_numeric.isna().any():
        raise ValueError(
            f"Column '{split_column}' contains missing or non-numeric values; "
            "cannot interpret it as a split assignment."
        )
    split_values = split_numeric.astype(int).to_numpy(copy=True)

    # Work with row positions rather than index labels so that a non-unique
    # DataFrame index can never cause rows outside the train set to be
    # re-labelled.
    train_pos = np.flatnonzero(split_values == 0)

    if train_pos.size == 0:
        logger.info("No rows with split=0; nothing to do.")
    elif validation_size <= 0:
        logger.info("validation_size <= 0; keeping all as train.")
    elif validation_size >= 1:
        logger.info("validation_size >= 1; moving all train → validation.")
        split_values[train_pos] = 1
    else:
        stratify_labels = None
        if label_column and label_column in out.columns:
            train_labels = out[label_column].iloc[train_pos]
            if train_labels.nunique() > 1:
                stratify_labels = train_labels
                logger.info("Using stratified split for validation set")
            else:
                logger.warning("Only one label class found; cannot stratify")

        val_pos = None
        if stratify_labels is not None:
            # Stratification can be infeasible (e.g. a class with a single
            # member); in that case fall back to a plain random draw.
            try:
                _, val_pos = train_test_split(
                    train_pos,
                    test_size=validation_size,
                    random_state=random_state,
                    stratify=stratify_labels,
                )
                logger.info("Successfully applied stratified split")
            except ValueError as e:
                logger.warning(
                    "Stratified split failed (%s); falling back to random split.",
                    e,
                )

        if val_pos is None:
            _, val_pos = train_test_split(
                train_pos,
                test_size=validation_size,
                random_state=random_state,
                stratify=None,
            )
            logger.info("Applied random split for validation set")

        split_values[val_pos] = 1

    out[split_column] = split_values
    return out
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
76
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
77
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
78 def create_stratified_random_split(
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
79 df: pd.DataFrame,
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
80 split_column: str,
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
81 split_probabilities: list = [0.7, 0.1, 0.2],
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
82 random_state: int = 42,
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
83 label_column: Optional[str] = None,
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
84 group_column: Optional[str] = None,
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
85 ) -> pd.DataFrame:
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
86 """Create a stratified random split when no split column exists."""
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
87 out = df.copy()
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
88
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
89 # initialize split column
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
90 out[split_column] = 0
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
91
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
92 if group_column and group_column not in out.columns:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
93 logger.warning(
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
94 "Group column '%s' not found in data; proceeding without group-aware split.",
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
95 group_column,
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
96 )
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
97 group_column = None
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
98
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
99 def _allocate_split_counts(n_total: int, probs: list) -> list:
23
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
100 """Allocate exact split counts using largest remainder rounding.
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
101
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
102 Ensures at least one sample per active split *after* proportional allocation,
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
103 by moving samples from other splits when possible. Tie-breaking for leftovers
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
104 and corrections prioritizes: train (0), test (2), validation (1).
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
105 """
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
106 if n_total <= 0:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
107 return [0 for _ in probs]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
108
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
109 counts = [0 for _ in probs]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
110 active = [i for i, p in enumerate(probs) if p > 0]
23
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
111 if not active:
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
112 return counts
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
113
23
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
114 # If there are fewer samples than active splits, fill in order and return.
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
115 if n_total < len(active):
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
116 priority_order = [0, 2, 1]
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
117 ordered = [i for i in priority_order if i in active] + [
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
118 i for i in active if i not in priority_order
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
119 ]
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
120 remaining = n_total
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
121 for idx in ordered:
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
122 if remaining <= 0:
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
123 break
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
124 counts[idx] = 1
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
125 remaining -= 1
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
126 return counts
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
127
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
128 probs_arr = np.array(probs, dtype=float)
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
129 total_prob = probs_arr.sum()
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
130 if total_prob <= 0:
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
131 return counts
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
132 probs_arr = probs_arr / total_prob
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
133 raw = n_total * probs_arr
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
134 floors = np.floor(raw).astype(int)
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
135 counts = floors.tolist()
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
136
23
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
137 leftover = n_total - int(floors.sum())
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
138 if leftover > 0:
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
139 frac = raw - floors
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
140 priority_order = [0, 2, 1]
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
141 order = sorted(
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
142 active,
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
143 key=lambda i: (-frac[i], priority_order.index(i) if i in priority_order else 999),
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
144 )
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
145 for i in range(leftover):
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
146 counts[order[i % len(order)]] += 1
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
147
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
148 # Ensure at least one per active split by moving from other splits.
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
149 missing = [i for i in active if counts[i] == 0]
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
150 if missing:
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
151 priority_order = [0, 2, 1]
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
152 missing_ordered = [i for i in priority_order if i in missing] + [
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
153 i for i in missing if i not in priority_order
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
154 ]
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
155 for idx in missing_ordered:
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
156 donors = [i for i in active if counts[i] > 1 and i != idx]
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
157 if not donors:
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
158 break
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
159 # Prefer taking from lower-priority splits first (val -> test -> train)
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
160 donor_priority = [1, 2, 0]
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
161 donors_sorted = sorted(
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
162 donors,
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
163 key=lambda i: (
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
164 -counts[i],
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
165 donor_priority.index(i) if i in donor_priority else 999,
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
166 ),
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
167 )
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
168 donor = donors_sorted[0]
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
169 counts[donor] -= 1
2c6624cae3c5 planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
goeckslab
parents: 20
diff changeset
170 counts[idx] += 1
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
171
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
172 return counts
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
173
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
174 def _choose_split(counts: list, targets: list, active: list) -> int:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
175 remaining = [targets[i] - counts[i] for i in range(len(targets))]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
176 best = max(active, key=lambda i: (remaining[i], -counts[i], -targets[i]))
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
177 if remaining[best] <= 0:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
178 best = min(active, key=lambda i: counts[i])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
179 return best
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
180
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
181 if not label_column or label_column not in out.columns:
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
182 logger.warning(
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
183 "No label column found; using random split without stratification"
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
184 )
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
185 # fall back to random assignment (group-aware if requested)
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
186 indices = out.index.tolist()
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
187 rng = np.random.RandomState(random_state)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
188
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
189 if group_column:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
190 group_series = out[group_column].astype(object)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
191 missing_mask = group_series.isna()
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
192 if missing_mask.any():
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
193 group_series = group_series.copy()
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
194 group_series.loc[missing_mask] = [
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
195 f"__missing__{idx}" for idx in group_series.index[missing_mask]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
196 ]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
197 group_to_indices = {}
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
198 for idx, group_id in group_series.items():
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
199 group_to_indices.setdefault(group_id, []).append(idx)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
200 group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
201 rng.shuffle(group_ids)
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
202
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
203 targets = _allocate_split_counts(len(indices), split_probabilities)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
204 counts = [0 for _ in split_probabilities]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
205 active = [i for i, p in enumerate(split_probabilities) if p > 0]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
206 train_idx = []
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
207 val_idx = []
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
208 test_idx = []
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
209
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
210 for group_id in group_ids:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
211 size = len(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
212 split_idx = _choose_split(counts, targets, active)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
213 counts[split_idx] += size
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
214 if split_idx == 0:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
215 train_idx.extend(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
216 elif split_idx == 1:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
217 val_idx.extend(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
218 else:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
219 test_idx.extend(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
220
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
221 out.loc[train_idx, split_column] = 0
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
222 out.loc[val_idx, split_column] = 1
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
223 out.loc[test_idx, split_column] = 2
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
224 return out.astype({split_column: int})
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
225
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
226 rng.shuffle(indices)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
227
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
228 targets = _allocate_split_counts(len(indices), split_probabilities)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
229 n_train, n_val, n_test = targets
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
230
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
231 out.loc[indices[:n_train], split_column] = 0
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
232 out.loc[indices[n_train:n_train + n_val], split_column] = 1
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
233 out.loc[indices[n_train + n_val:n_train + n_val + n_test], split_column] = 2
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
234
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
235 return out.astype({split_column: int})
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
236
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
237 # check if stratification is possible
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
238 label_counts = out[label_column].value_counts()
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
239 min_samples_per_class = label_counts.min()
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
240
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
241 # ensure we have enough samples for stratification:
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
242 # Each class must have at least as many samples as the number of nonzero splits,
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
243 # so that each split can receive at least one sample per class.
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
244 active_splits = [p for p in split_probabilities if p > 0]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
245 min_samples_required = len(active_splits)
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
246 if min_samples_per_class < min_samples_required:
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
247 logger.warning(
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
248 f"Insufficient samples per class for stratification (min: {min_samples_per_class}, required: {min_samples_required}); using random split"
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
249 )
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
250 # fall back to simple random assignment (group-aware if requested)
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
251 indices = out.index.tolist()
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
252 rng = np.random.RandomState(random_state)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
253
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
254 if group_column:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
255 group_series = out[group_column].astype(object)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
256 missing_mask = group_series.isna()
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
257 if missing_mask.any():
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
258 group_series = group_series.copy()
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
259 group_series.loc[missing_mask] = [
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
260 f"__missing__{idx}" for idx in group_series.index[missing_mask]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
261 ]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
262 group_to_indices = {}
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
263 for idx, group_id in group_series.items():
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
264 group_to_indices.setdefault(group_id, []).append(idx)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
265 group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
266 rng.shuffle(group_ids)
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
267
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
268 targets = _allocate_split_counts(len(indices), split_probabilities)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
269 counts = [0 for _ in split_probabilities]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
270 active = [i for i, p in enumerate(split_probabilities) if p > 0]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
271 train_idx = []
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
272 val_idx = []
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
273 test_idx = []
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
274
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
275 for group_id in group_ids:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
276 size = len(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
277 split_idx = _choose_split(counts, targets, active)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
278 counts[split_idx] += size
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
279 if split_idx == 0:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
280 train_idx.extend(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
281 elif split_idx == 1:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
282 val_idx.extend(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
283 else:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
284 test_idx.extend(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
285
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
286 out.loc[train_idx, split_column] = 0
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
287 out.loc[val_idx, split_column] = 1
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
288 out.loc[test_idx, split_column] = 2
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
289 return out.astype({split_column: int})
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
290
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
291 rng.shuffle(indices)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
292
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
293 targets = _allocate_split_counts(len(indices), split_probabilities)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
294 n_train, n_val, n_test = targets
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
295
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
296 out.loc[indices[:n_train], split_column] = 0
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
297 out.loc[indices[n_train:n_train + n_val], split_column] = 1
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
298 out.loc[indices[n_train + n_val:n_train + n_val + n_test], split_column] = 2
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
299
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
300 return out.astype({split_column: int})
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
301
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
302 if group_column:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
303 logger.info(
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
304 "Using stratified random split for train/validation/test sets (grouped by '%s')",
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
305 group_column,
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
306 )
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
307 rng = np.random.RandomState(random_state)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
308
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
309 group_series = out[group_column].astype(object)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
310 missing_mask = group_series.isna()
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
311 if missing_mask.any():
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
312 group_series = group_series.copy()
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
313 group_series.loc[missing_mask] = [
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
314 f"__missing__{idx}" for idx in group_series.index[missing_mask]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
315 ]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
316
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
317 group_to_indices = {}
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
318 for idx, group_id in group_series.items():
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
319 group_to_indices.setdefault(group_id, []).append(idx)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
320
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
321 group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
322 group_labels = {}
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
323 mixed_groups = []
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
324 label_series = out[label_column]
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
325
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
326 for group_id in group_ids:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
327 labels = label_series.loc[group_to_indices[group_id]].dropna().unique()
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
328 if len(labels) == 1:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
329 group_labels[group_id] = labels[0]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
330 elif len(labels) == 0:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
331 group_labels[group_id] = None
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
332 else:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
333 mode_vals = label_series.loc[group_to_indices[group_id]].mode(dropna=True)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
334 group_labels[group_id] = mode_vals.iloc[0] if not mode_vals.empty else labels[0]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
335 mixed_groups.append(group_id)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
336
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
337 if mixed_groups:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
338 logger.warning(
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
339 "Detected %d groups with multiple labels; using the most common label per group for stratification.",
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
340 len(mixed_groups),
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
341 )
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
342
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
343 train_idx = []
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
344 val_idx = []
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
345 test_idx = []
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
346 active = [i for i, p in enumerate(split_probabilities) if p > 0]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
347
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
348 for label_value in sorted(label_counts.index.tolist(), key=lambda x: str(x)):
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
349 label_groups = [g for g in group_ids if group_labels.get(g) == label_value]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
350 if not label_groups:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
351 continue
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
352 rng.shuffle(label_groups)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
353 label_total = sum(len(group_to_indices[g]) for g in label_groups)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
354 targets = _allocate_split_counts(label_total, split_probabilities)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
355 counts = [0 for _ in split_probabilities]
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
356
20
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
357 for group_id in label_groups:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
358 size = len(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
359 split_idx = _choose_split(counts, targets, active)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
360 counts[split_idx] += size
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
361 if split_idx == 0:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
362 train_idx.extend(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
363 elif split_idx == 1:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
364 val_idx.extend(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
365 else:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
366 test_idx.extend(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
367
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
368 # Assign groups without a label (or missing labels) using overall targets.
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
369 unlabeled_groups = [g for g in group_ids if group_labels.get(g) is None]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
370 if unlabeled_groups:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
371 rng.shuffle(unlabeled_groups)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
372 total_unlabeled = sum(len(group_to_indices[g]) for g in unlabeled_groups)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
373 targets = _allocate_split_counts(total_unlabeled, split_probabilities)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
374 counts = [0 for _ in split_probabilities]
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
375 for group_id in unlabeled_groups:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
376 size = len(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
377 split_idx = _choose_split(counts, targets, active)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
378 counts[split_idx] += size
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
379 if split_idx == 0:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
380 train_idx.extend(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
381 elif split_idx == 1:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
382 val_idx.extend(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
383 else:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
384 test_idx.extend(group_to_indices[group_id])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
385
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
386 out.loc[train_idx, split_column] = 0
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
387 out.loc[val_idx, split_column] = 1
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
388 out.loc[test_idx, split_column] = 2
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
389
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
390 logger.info("Successfully applied stratified random split")
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
391 logger.info(
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
392 f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}"
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
393 )
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
394 return out.astype({split_column: int})
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
395
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
396 logger.info("Using stratified random split for train/validation/test sets (per-class allocation)")
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
397
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
398 rng = np.random.RandomState(random_state)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
399 label_values = sorted(label_counts.index.tolist(), key=lambda x: str(x))
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
400 train_idx = []
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
401 val_idx = []
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
402 test_idx = []
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
403
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
404 for label_value in label_values:
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
405 label_indices = out.index[out[label_column] == label_value].tolist()
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
406 rng.shuffle(label_indices)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
407 n_train, n_val, n_test = _allocate_split_counts(len(label_indices), split_probabilities)
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
408 train_idx.extend(label_indices[:n_train])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
409 val_idx.extend(label_indices[n_train:n_train + n_val])
64872c48a21f planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
goeckslab
parents: 12
diff changeset
410 test_idx.extend(label_indices[n_train + n_val:n_train + n_val + n_test])
12
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
411
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
412 # assign split values
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
413 out.loc[train_idx, split_column] = 0
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
414 out.loc[val_idx, split_column] = 1
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
415 out.loc[test_idx, split_column] = 2
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
416
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
417 logger.info("Successfully applied stratified random split")
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
418 logger.info(
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
419 f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}"
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
420 )
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
421 return out.astype({split_column: int})
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
422
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
423
bcfa2e234a80 planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff changeset
class SplitProbAction(argparse.Action):
    """Argparse action that validates a (train, val, test) probability triple.

    The option is expected to deliver exactly three numeric values. Each value
    must be a non-negative number and the three must sum to 1.0 (within an
    absolute tolerance of 1e-6). Invalid input is reported through
    ``parser.error``, which prints a usage message and exits. Valid input is
    stored unchanged on the namespace under ``self.dest``.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        """Validate ``values`` and store them on ``namespace``.

        Args:
            parser: The ``ArgumentParser`` invoking this action; used for
                error reporting.
            namespace: The namespace object receiving the parsed value.
            values: Sequence of three numbers (train, val, test).
            option_string: The option string that triggered the action
                (unused).
        """
        train, val, test = values

        # ``p >= 0.0`` is False for NaN as well as for negative numbers, so a
        # NaN entry is rejected here. Without this guard NaN slips through the
        # sum check below, because ``abs(nan - 1.0) > 1e-6`` is False.
        if not all(p >= 0.0 for p in (train, val, test)):
            parser.error(
                "--split-probabilities values must be non-negative numbers; "
                f"got {train:.3f}, {val:.3f}, {test:.3f}"
            )

        total = train + val + test
        if abs(total - 1.0) > 1e-6:
            parser.error(
                f"--split-probabilities must sum to 1.0; "
                f"got {train:.3f} + {val:.3f} + {test:.3f} = {total:.3f}"
            )
        setattr(namespace, self.dest, values)