annotate split_logic.py @ 3:25bb80df7c0c draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
author goeckslab
date Sat, 17 Jan 2026 22:53:42 +0000
parents 375c36923da1
children
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
rev   line source
0
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1 import logging
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
2 from typing import List, Optional
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
3
3
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
4 import numpy as np
0
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
5 import pandas as pd
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
6 from sklearn.model_selection import train_test_split
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
7
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Column holding each row's split label; this module writes or recognizes
# the values "train", "val", "validation" and "test" in it.
SPLIT_COL: str = "split"
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
10
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
11
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
12 def _can_stratify(y: pd.Series) -> bool:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
13 return y.nunique() >= 2 and (y.value_counts() >= 2).all()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
14
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
15
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
16 def split_dataset(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
17 train_dataset: pd.DataFrame,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
18 test_dataset: Optional[pd.DataFrame],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
19 target_column: str,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
20 split_probabilities: List[float],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
21 validation_size: float,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
22 random_seed: int = 42,
3
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
23 sample_id_column: Optional[str] = None,
0
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
24 ) -> None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
25 if target_column not in train_dataset.columns:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
26 raise ValueError(f"Target column '{target_column}' not found")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
27
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
28 # Drop NaN labels early
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
29 before = len(train_dataset)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
30 train_dataset.dropna(subset=[target_column], inplace=True)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
31 if len(train_dataset) == 0:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
32 raise ValueError("No rows remain after dropping NaN targets")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
33 if before != len(train_dataset):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
34 logger.warning(f"Dropped {before - len(train_dataset)} rows with NaN target")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
35 y = train_dataset[target_column]
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
36
3
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
37 if sample_id_column and sample_id_column not in train_dataset.columns:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
38 logger.warning(
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
39 "Sample ID column '%s' not found; proceeding without group-aware split.",
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
40 sample_id_column,
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
41 )
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
42 sample_id_column = None
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
43 if sample_id_column and sample_id_column == target_column:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
44 logger.warning(
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
45 "Sample ID column '%s' matches target column; proceeding without group-aware split.",
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
46 sample_id_column,
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
47 )
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
48 sample_id_column = None
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
49
0
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
50 # Respect existing valid split column
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
51 if SPLIT_COL in train_dataset.columns:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
52 unique = set(train_dataset[SPLIT_COL].dropna().unique())
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
53 valid = {"train", "val", "validation", "test"}
3
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
54 if unique.issubset(valid):
0
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
55 train_dataset[SPLIT_COL] = train_dataset[SPLIT_COL].replace("validation", "val")
3
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
56 normalized = set(train_dataset[SPLIT_COL].dropna().unique())
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
57 required = {"train"} if test_dataset is not None else {"train", "test"}
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
58 missing = required - normalized
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
59 if missing:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
60 missing_list = ", ".join(sorted(missing))
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
61 if test_dataset is not None:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
62 raise ValueError(
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
63 "Pre-existing 'split' column is missing required split(s): "
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
64 f"{missing_list}. Expected at least train when an external test set is provided, "
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
65 "or remove the 'split' column to let the tool create splits."
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
66 )
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
67 raise ValueError(
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
68 "Pre-existing 'split' column is missing required split(s): "
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
69 f"{missing_list}. Expected at least train and test, "
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
70 "or remove the 'split' column to let the tool create splits."
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
71 )
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
72 logger.info(f"Using pre-existing 'split' column: {sorted(normalized)}")
0
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
73 return
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
74
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
75 train_dataset[SPLIT_COL] = "train"
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
76
3
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
77 def _allocate_split_counts(n_total: int, probs: list) -> list:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
78 if n_total <= 0:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
79 return [0 for _ in probs]
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
80 counts = [0 for _ in probs]
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
81 active = [i for i, p in enumerate(probs) if p > 0]
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
82 remainder = n_total
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
83 if active and n_total >= len(active):
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
84 for i in active:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
85 counts[i] = 1
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
86 remainder -= len(active)
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
87 if remainder > 0:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
88 probs_arr = np.array(probs, dtype=float)
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
89 probs_arr = probs_arr / probs_arr.sum()
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
90 raw = remainder * probs_arr
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
91 floors = np.floor(raw).astype(int)
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
92 for i, value in enumerate(floors.tolist()):
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
93 counts[i] += value
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
94 leftover = remainder - int(floors.sum())
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
95 if leftover > 0 and active:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
96 frac = raw - floors
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
97 order = sorted(active, key=lambda i: (-frac[i], i))
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
98 for i in range(leftover):
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
99 counts[order[i % len(order)]] += 1
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
100 return counts
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
101
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
102 def _choose_split(counts: list, targets: list, active: list) -> int:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
103 remaining = [targets[i] - counts[i] for i in range(len(targets))]
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
104 best = max(active, key=lambda i: (remaining[i], -counts[i], -targets[i]))
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
105 if remaining[best] <= 0:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
106 best = min(active, key=lambda i: counts[i])
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
107 return best
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
108
0
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
109 if test_dataset is not None:
3
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
110 if sample_id_column and sample_id_column in test_dataset.columns:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
111 train_ids = set(train_dataset[sample_id_column].dropna().astype(object).unique())
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
112 test_ids = set(test_dataset[sample_id_column].dropna().astype(object).unique())
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
113 overlap = train_ids & test_ids
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
114 if overlap:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
115 logger.warning(
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
116 "Sample ID column '%s' has %d overlapping IDs between train and external test sets; "
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
117 "consider removing overlaps to avoid leakage.",
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
118 sample_id_column,
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
119 len(overlap),
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
120 )
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
121 if sample_id_column:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
122 rng = np.random.RandomState(random_seed)
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
123 group_series = train_dataset[sample_id_column].astype(object)
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
124 missing_mask = group_series.isna()
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
125 if missing_mask.any():
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
126 group_series = group_series.copy()
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
127 group_series.loc[missing_mask] = [
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
128 f"__missing__{idx}" for idx in group_series.index[missing_mask]
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
129 ]
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
130 group_to_indices = {}
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
131 for idx, group_id in group_series.items():
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
132 group_to_indices.setdefault(group_id, []).append(idx)
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
133 group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
134 rng.shuffle(group_ids)
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
135 targets = _allocate_split_counts(len(train_dataset), [1.0 - validation_size, validation_size])
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
136 counts = [0, 0]
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
137 active = [0, 1]
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
138 train_idx = []
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
139 val_idx = []
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
140 for group_id in group_ids:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
141 split_idx = _choose_split(counts, targets, active)
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
142 counts[split_idx] += len(group_to_indices[group_id])
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
143 if split_idx == 0:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
144 train_idx.extend(group_to_indices[group_id])
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
145 else:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
146 val_idx.extend(group_to_indices[group_id])
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
147 train_dataset.loc[val_idx, SPLIT_COL] = "val"
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
148 else:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
149 if validation_size <= 0:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
150 logger.warning(
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
151 "validation_size is %.3f; skipping validation split to avoid train_test_split errors.",
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
152 validation_size,
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
153 )
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
154 elif validation_size >= 1:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
155 logger.warning(
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
156 "validation_size is %.3f; assigning all rows to validation to avoid train_test_split errors.",
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
157 validation_size,
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
158 )
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
159 train_dataset[SPLIT_COL] = "val"
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
160 else:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
161 stratify = y if _can_stratify(y) else None
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
162 train_idx, val_idx = train_test_split(
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
163 train_dataset.index, test_size=validation_size,
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
164 random_state=random_seed, stratify=stratify
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
165 )
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
166 train_dataset.loc[val_idx, SPLIT_COL] = "val"
0
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
167 logger.info(f"External test set → created val split ({validation_size:.0%})")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
168
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
169 else:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
170 p_train, p_val, p_test = split_probabilities
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
171 if abs(p_train + p_val + p_test - 1.0) > 1e-6:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
172 raise ValueError("split_probabilities must sum to 1.0")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
173
3
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
174 if sample_id_column:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
175 rng = np.random.RandomState(random_seed)
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
176 group_series = train_dataset[sample_id_column].astype(object)
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
177 missing_mask = group_series.isna()
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
178 if missing_mask.any():
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
179 group_series = group_series.copy()
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
180 group_series.loc[missing_mask] = [
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
181 f"__missing__{idx}" for idx in group_series.index[missing_mask]
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
182 ]
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
183 group_to_indices = {}
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
184 for idx, group_id in group_series.items():
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
185 group_to_indices.setdefault(group_id, []).append(idx)
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
186 group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
187 rng.shuffle(group_ids)
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
188 targets = _allocate_split_counts(len(train_dataset), [p_train, p_val, p_test])
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
189 counts = [0, 0, 0]
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
190 active = [0, 1, 2]
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
191 train_idx = []
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
192 val_idx = []
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
193 test_idx = []
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
194 for group_id in group_ids:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
195 split_idx = _choose_split(counts, targets, active)
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
196 counts[split_idx] += len(group_to_indices[group_id])
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
197 if split_idx == 0:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
198 train_idx.extend(group_to_indices[group_id])
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
199 elif split_idx == 1:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
200 val_idx.extend(group_to_indices[group_id])
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
201 else:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
202 test_idx.extend(group_to_indices[group_id])
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
203 train_dataset.loc[val_idx, SPLIT_COL] = "val"
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
204 train_dataset.loc[test_idx, SPLIT_COL] = "test"
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
205 else:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
206 stratify = y if _can_stratify(y) else None
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
207 if p_test <= 0:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
208 logger.warning(
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
209 "split_probabilities specify 0 test size; skipping test split to avoid train_test_split errors."
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
210 )
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
211 tv_idx = train_dataset.index
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
212 test_idx = train_dataset.index[:0]
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
213 elif p_test >= 1:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
214 logger.warning(
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
215 "split_probabilities specify 100% test size; assigning all rows to test to avoid train_test_split errors."
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
216 )
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
217 tv_idx = train_dataset.index[:0]
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
218 test_idx = train_dataset.index
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
219 else:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
220 tv_idx, test_idx = train_test_split(
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
221 train_dataset.index, test_size=p_test,
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
222 random_state=random_seed, stratify=stratify
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
223 )
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
224 rel_val = p_val / (p_train + p_val) if (p_train + p_val) > 0 else 0
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
225 train_idx = train_dataset.index[:0]
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
226 val_idx = train_dataset.index[:0]
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
227 if len(tv_idx):
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
228 if rel_val <= 0:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
229 logger.warning(
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
230 "split_probabilities specify 0 validation size; skipping validation split to avoid train_test_split errors."
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
231 )
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
232 train_idx = tv_idx
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
233 elif rel_val >= 1:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
234 logger.warning(
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
235 "split_probabilities specify 100% validation size; assigning all remaining rows to validation "
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
236 "to avoid train_test_split errors."
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
237 )
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
238 val_idx = tv_idx
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
239 else:
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
240 strat_tv = y.loc[tv_idx] if _can_stratify(y.loc[tv_idx]) else None
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
241 train_idx, val_idx = train_test_split(
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
242 tv_idx, test_size=rel_val,
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
243 random_state=random_seed, stratify=strat_tv
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
244 )
0
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
245
3
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
246 train_dataset.loc[val_idx, SPLIT_COL] = "val"
25bb80df7c0c planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents: 0
diff changeset
247 train_dataset.loc[test_idx, SPLIT_COL] = "test"
0
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
248 logger.info(f"3-way split → train:{len(train_idx)}, val:{len(val_idx)}, test:{len(test_idx)}")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
249
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
250 logger.info(f"Final split distribution:\n{train_dataset[SPLIT_COL].value_counts().sort_index()}")