Mercurial > repos > goeckslab > multimodal_learner
annotate split_logic.py @ 3:25bb80df7c0c draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
| author | goeckslab |
|---|---|
| date | Sat, 17 Jan 2026 22:53:42 +0000 |
| parents | 375c36923da1 |
| children |
| rev | line source |
|---|---|
|
0
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1 import logging |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
2 from typing import List, Optional |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
3 |
|
3
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
4 import numpy as np |
|
0
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
5 import pandas as pd |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
6 from sklearn.model_selection import train_test_split |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
7 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
logger = logging.getLogger(__name__)

# Column that records each row's split assignment ("train", "val" or "test").
SPLIT_COL = "split"
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
10 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
11 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
12 def _can_stratify(y: pd.Series) -> bool: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
13 return y.nunique() >= 2 and (y.value_counts() >= 2).all() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
14 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
15 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
16 def split_dataset( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
17 train_dataset: pd.DataFrame, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
18 test_dataset: Optional[pd.DataFrame], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
19 target_column: str, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
20 split_probabilities: List[float], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
21 validation_size: float, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
22 random_seed: int = 42, |
|
3
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
23 sample_id_column: Optional[str] = None, |
|
0
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
24 ) -> None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
25 if target_column not in train_dataset.columns: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
26 raise ValueError(f"Target column '{target_column}' not found") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
27 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
28 # Drop NaN labels early |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
29 before = len(train_dataset) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
30 train_dataset.dropna(subset=[target_column], inplace=True) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
31 if len(train_dataset) == 0: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
32 raise ValueError("No rows remain after dropping NaN targets") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
33 if before != len(train_dataset): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
34 logger.warning(f"Dropped {before - len(train_dataset)} rows with NaN target") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
35 y = train_dataset[target_column] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
36 |
|
3
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
37 if sample_id_column and sample_id_column not in train_dataset.columns: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
38 logger.warning( |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
39 "Sample ID column '%s' not found; proceeding without group-aware split.", |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
40 sample_id_column, |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
41 ) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
42 sample_id_column = None |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
43 if sample_id_column and sample_id_column == target_column: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
44 logger.warning( |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
45 "Sample ID column '%s' matches target column; proceeding without group-aware split.", |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
46 sample_id_column, |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
47 ) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
48 sample_id_column = None |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
49 |
|
0
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
50 # Respect existing valid split column |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
51 if SPLIT_COL in train_dataset.columns: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
52 unique = set(train_dataset[SPLIT_COL].dropna().unique()) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
53 valid = {"train", "val", "validation", "test"} |
|
3
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
54 if unique.issubset(valid): |
|
0
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
55 train_dataset[SPLIT_COL] = train_dataset[SPLIT_COL].replace("validation", "val") |
|
3
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
56 normalized = set(train_dataset[SPLIT_COL].dropna().unique()) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
57 required = {"train"} if test_dataset is not None else {"train", "test"} |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
58 missing = required - normalized |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
59 if missing: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
60 missing_list = ", ".join(sorted(missing)) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
61 if test_dataset is not None: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
62 raise ValueError( |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
63 "Pre-existing 'split' column is missing required split(s): " |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
64 f"{missing_list}. Expected at least train when an external test set is provided, " |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
65 "or remove the 'split' column to let the tool create splits." |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
66 ) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
67 raise ValueError( |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
68 "Pre-existing 'split' column is missing required split(s): " |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
69 f"{missing_list}. Expected at least train and test, " |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
70 "or remove the 'split' column to let the tool create splits." |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
71 ) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
72 logger.info(f"Using pre-existing 'split' column: {sorted(normalized)}") |
|
0
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
73 return |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
74 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
75 train_dataset[SPLIT_COL] = "train" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
76 |
|
3
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
77 def _allocate_split_counts(n_total: int, probs: list) -> list: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
78 if n_total <= 0: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
79 return [0 for _ in probs] |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
80 counts = [0 for _ in probs] |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
81 active = [i for i, p in enumerate(probs) if p > 0] |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
82 remainder = n_total |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
83 if active and n_total >= len(active): |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
84 for i in active: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
85 counts[i] = 1 |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
86 remainder -= len(active) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
87 if remainder > 0: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
88 probs_arr = np.array(probs, dtype=float) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
89 probs_arr = probs_arr / probs_arr.sum() |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
90 raw = remainder * probs_arr |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
91 floors = np.floor(raw).astype(int) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
92 for i, value in enumerate(floors.tolist()): |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
93 counts[i] += value |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
94 leftover = remainder - int(floors.sum()) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
95 if leftover > 0 and active: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
96 frac = raw - floors |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
97 order = sorted(active, key=lambda i: (-frac[i], i)) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
98 for i in range(leftover): |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
99 counts[order[i % len(order)]] += 1 |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
100 return counts |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
101 |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
102 def _choose_split(counts: list, targets: list, active: list) -> int: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
103 remaining = [targets[i] - counts[i] for i in range(len(targets))] |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
104 best = max(active, key=lambda i: (remaining[i], -counts[i], -targets[i])) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
105 if remaining[best] <= 0: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
106 best = min(active, key=lambda i: counts[i]) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
107 return best |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
108 |
|
0
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
109 if test_dataset is not None: |
|
3
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
110 if sample_id_column and sample_id_column in test_dataset.columns: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
111 train_ids = set(train_dataset[sample_id_column].dropna().astype(object).unique()) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
112 test_ids = set(test_dataset[sample_id_column].dropna().astype(object).unique()) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
113 overlap = train_ids & test_ids |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
114 if overlap: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
115 logger.warning( |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
116 "Sample ID column '%s' has %d overlapping IDs between train and external test sets; " |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
117 "consider removing overlaps to avoid leakage.", |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
118 sample_id_column, |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
119 len(overlap), |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
120 ) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
121 if sample_id_column: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
122 rng = np.random.RandomState(random_seed) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
123 group_series = train_dataset[sample_id_column].astype(object) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
124 missing_mask = group_series.isna() |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
125 if missing_mask.any(): |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
126 group_series = group_series.copy() |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
127 group_series.loc[missing_mask] = [ |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
128 f"__missing__{idx}" for idx in group_series.index[missing_mask] |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
129 ] |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
130 group_to_indices = {} |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
131 for idx, group_id in group_series.items(): |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
132 group_to_indices.setdefault(group_id, []).append(idx) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
133 group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x)) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
134 rng.shuffle(group_ids) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
135 targets = _allocate_split_counts(len(train_dataset), [1.0 - validation_size, validation_size]) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
136 counts = [0, 0] |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
137 active = [0, 1] |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
138 train_idx = [] |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
139 val_idx = [] |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
140 for group_id in group_ids: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
141 split_idx = _choose_split(counts, targets, active) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
142 counts[split_idx] += len(group_to_indices[group_id]) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
143 if split_idx == 0: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
144 train_idx.extend(group_to_indices[group_id]) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
145 else: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
146 val_idx.extend(group_to_indices[group_id]) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
147 train_dataset.loc[val_idx, SPLIT_COL] = "val" |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
148 else: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
149 if validation_size <= 0: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
150 logger.warning( |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
151 "validation_size is %.3f; skipping validation split to avoid train_test_split errors.", |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
152 validation_size, |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
153 ) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
154 elif validation_size >= 1: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
155 logger.warning( |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
156 "validation_size is %.3f; assigning all rows to validation to avoid train_test_split errors.", |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
157 validation_size, |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
158 ) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
159 train_dataset[SPLIT_COL] = "val" |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
160 else: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
161 stratify = y if _can_stratify(y) else None |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
162 train_idx, val_idx = train_test_split( |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
163 train_dataset.index, test_size=validation_size, |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
164 random_state=random_seed, stratify=stratify |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
165 ) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
166 train_dataset.loc[val_idx, SPLIT_COL] = "val" |
|
0
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
167 logger.info(f"External test set → created val split ({validation_size:.0%})") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
168 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
169 else: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
170 p_train, p_val, p_test = split_probabilities |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
171 if abs(p_train + p_val + p_test - 1.0) > 1e-6: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
172 raise ValueError("split_probabilities must sum to 1.0") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
173 |
|
3
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
174 if sample_id_column: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
175 rng = np.random.RandomState(random_seed) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
176 group_series = train_dataset[sample_id_column].astype(object) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
177 missing_mask = group_series.isna() |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
178 if missing_mask.any(): |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
179 group_series = group_series.copy() |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
180 group_series.loc[missing_mask] = [ |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
181 f"__missing__{idx}" for idx in group_series.index[missing_mask] |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
182 ] |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
183 group_to_indices = {} |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
184 for idx, group_id in group_series.items(): |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
185 group_to_indices.setdefault(group_id, []).append(idx) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
186 group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x)) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
187 rng.shuffle(group_ids) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
188 targets = _allocate_split_counts(len(train_dataset), [p_train, p_val, p_test]) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
189 counts = [0, 0, 0] |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
190 active = [0, 1, 2] |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
191 train_idx = [] |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
192 val_idx = [] |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
193 test_idx = [] |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
194 for group_id in group_ids: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
195 split_idx = _choose_split(counts, targets, active) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
196 counts[split_idx] += len(group_to_indices[group_id]) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
197 if split_idx == 0: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
198 train_idx.extend(group_to_indices[group_id]) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
199 elif split_idx == 1: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
200 val_idx.extend(group_to_indices[group_id]) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
201 else: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
202 test_idx.extend(group_to_indices[group_id]) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
203 train_dataset.loc[val_idx, SPLIT_COL] = "val" |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
204 train_dataset.loc[test_idx, SPLIT_COL] = "test" |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
205 else: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
206 stratify = y if _can_stratify(y) else None |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
207 if p_test <= 0: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
208 logger.warning( |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
209 "split_probabilities specify 0 test size; skipping test split to avoid train_test_split errors." |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
210 ) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
211 tv_idx = train_dataset.index |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
212 test_idx = train_dataset.index[:0] |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
213 elif p_test >= 1: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
214 logger.warning( |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
215 "split_probabilities specify 100% test size; assigning all rows to test to avoid train_test_split errors." |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
216 ) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
217 tv_idx = train_dataset.index[:0] |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
218 test_idx = train_dataset.index |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
219 else: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
220 tv_idx, test_idx = train_test_split( |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
221 train_dataset.index, test_size=p_test, |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
222 random_state=random_seed, stratify=stratify |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
223 ) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
224 rel_val = p_val / (p_train + p_val) if (p_train + p_val) > 0 else 0 |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
225 train_idx = train_dataset.index[:0] |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
226 val_idx = train_dataset.index[:0] |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
227 if len(tv_idx): |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
228 if rel_val <= 0: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
229 logger.warning( |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
230 "split_probabilities specify 0 validation size; skipping validation split to avoid train_test_split errors." |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
231 ) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
232 train_idx = tv_idx |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
233 elif rel_val >= 1: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
234 logger.warning( |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
235 "split_probabilities specify 100% validation size; assigning all remaining rows to validation " |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
236 "to avoid train_test_split errors." |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
237 ) |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
238 val_idx = tv_idx |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
239 else: |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
240 strat_tv = y.loc[tv_idx] if _can_stratify(y.loc[tv_idx]) else None |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
241 train_idx, val_idx = train_test_split( |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
242 tv_idx, test_size=rel_val, |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
243 random_state=random_seed, stratify=strat_tv |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
244 ) |
|
0
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
245 |
|
3
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
246 train_dataset.loc[val_idx, SPLIT_COL] = "val" |
|
25bb80df7c0c
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
goeckslab
parents:
0
diff
changeset
|
247 train_dataset.loc[test_idx, SPLIT_COL] = "test" |
|
0
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
248 logger.info(f"3-way split → train:{len(train_idx)}, val:{len(val_idx)}, test:{len(test_idx)}") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
249 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
250 logger.info(f"Final split distribution:\n{train_dataset[SPLIT_COL].value_counts().sort_index()}") |
