Mercurial > repos > goeckslab > image_learner
diff split_data.py @ 23:2c6624cae3c5 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
| author | goeckslab |
|---|---|
| date | Sun, 25 Jan 2026 01:09:56 +0000 |
| parents | 64872c48a21f |
| children |
line wrap: on
line diff
--- a/split_data.py Fri Jan 23 20:25:27 2026 +0000 +++ b/split_data.py Sun Jan 25 01:09:56 2026 +0000 @@ -97,32 +97,77 @@ group_column = None def _allocate_split_counts(n_total: int, probs: list) -> list: - """Allocate exact split counts using largest remainder rounding.""" + """Allocate exact split counts using largest remainder rounding. + + Ensures at least one sample per active split *after* proportional allocation, + by moving samples from other splits when possible. Tie-breaking for leftovers + and corrections prioritizes: train (0), test (2), validation (1). + """ if n_total <= 0: return [0 for _ in probs] counts = [0 for _ in probs] active = [i for i, p in enumerate(probs) if p > 0] - remainder = n_total + if not active: + return counts - if active and n_total >= len(active): - for i in active: - counts[i] = 1 - remainder -= len(active) + # If there are fewer samples than active splits, fill in order and return. + if n_total < len(active): + priority_order = [0, 2, 1] + ordered = [i for i in priority_order if i in active] + [ + i for i in active if i not in priority_order + ] + remaining = n_total + for idx in ordered: + if remaining <= 0: + break + counts[idx] = 1 + remaining -= 1 + return counts + + probs_arr = np.array(probs, dtype=float) + total_prob = probs_arr.sum() + if total_prob <= 0: + return counts + probs_arr = probs_arr / total_prob + raw = n_total * probs_arr + floors = np.floor(raw).astype(int) + counts = floors.tolist() - if remainder > 0: - probs_arr = np.array(probs, dtype=float) - probs_arr = probs_arr / probs_arr.sum() - raw = remainder * probs_arr - floors = np.floor(raw).astype(int) - for i, value in enumerate(floors.tolist()): - counts[i] += value - leftover = remainder - int(floors.sum()) - if leftover > 0 and active: - frac = raw - floors - order = sorted(active, key=lambda i: (-frac[i], i)) - for i in range(leftover): - counts[order[i % len(order)]] += 1 + leftover = n_total - int(floors.sum()) + if leftover > 0: + frac = raw - floors + priority_order = [0, 2, 1] + order = sorted( + active, + key=lambda i: (-frac[i], priority_order.index(i) if i in priority_order else 999), + ) + for i in range(leftover): + counts[order[i % len(order)]] += 1 + + # Ensure at least one per active split by moving from other splits. + missing = [i for i in active if counts[i] == 0] + if missing: + priority_order = [0, 2, 1] + missing_ordered = [i for i in priority_order if i in missing] + [ + i for i in missing if i not in priority_order + ] + for idx in missing_ordered: + donors = [i for i in active if counts[i] > 1 and i != idx] + if not donors: + break + # Prefer taking from lower-priority splits first (val -> test -> train) + donor_priority = [1, 2, 0] + donors_sorted = sorted( + donors, + key=lambda i: ( + -counts[i], + donor_priority.index(i) if i in donor_priority else 999, + ), + ) + donor = donors_sorted[0] + counts[donor] -= 1 + counts[idx] += 1 return counts
