diff base_model_trainer.py @ 15:01e7c5481f13 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit f632803cda732005bdcf3ac3e8fe7a807a82c1d9
author goeckslab
date Mon, 19 Jan 2026 05:54:52 +0000
parents bf0df21a1ea3
children
line wrap: on
line diff
--- a/base_model_trainer.py	Mon Dec 29 20:34:38 2025 +0000
+++ b/base_model_trainer.py	Mon Jan 19 05:54:52 2026 +0000
@@ -119,6 +119,13 @@
             )
 
         self.target = names[target_index]
+        sample_id_column = getattr(self, "sample_id_column", None)
+        if sample_id_column:
+            sample_id_column = sample_id_column.replace(".", "_")
+            self.sample_id_column = sample_id_column
+        else:
+            self.sample_id_column = None
+        self.sample_id_series = None
 
         # Conditional drop: only if 'prediction_label' exists and is not
         # the target
@@ -154,8 +161,24 @@
         names = self.data.columns.to_list()
         LOG.info(f"Dataset columns after processing: {names}")
 
-        self.features_name = [n for n in names if n != self.target]
-        self.plot_feature_names = self._select_plot_features(self.features_name)
+        sample_id_valid = False
+        if sample_id_column:
+            if sample_id_column not in self.data.columns:
+                LOG.warning(
+                    "Sample ID column '%s' not found; proceeding without group-aware split.",
+                    sample_id_column,
+                )
+                sample_id_column = None
+                self.sample_id_column = None
+            elif sample_id_column == self.target:
+                LOG.warning(
+                    "Sample ID column '%s' matches target column; skipping group-aware split.",
+                    sample_id_column,
+                )
+                sample_id_column = None
+                self.sample_id_column = None
+            else:
+                sample_id_valid = True
 
         if self.test_file:
             LOG.info(f"Loading test data from {self.test_file}")
@@ -165,6 +188,113 @@
             df_test.columns = df_test.columns.str.replace(".", "_")
             self.test_data = df_test
 
+        if sample_id_valid and self.test_data is None:
+            train_size = getattr(self, "train_size", None)
+            if train_size is None:
+                train_size = 0.7
+            if train_size <= 0 or train_size >= 1:
+                LOG.warning(
+                    "Invalid train_size=%s; skipping group-aware split.",
+                    train_size,
+                )
+            else:
+                rng = np.random.RandomState(self.random_seed)
+
+                def _allocate_split_counts(n_total: int, probs: list) -> list:
+                    """Return per-split row targets for n_total rows, proportional to probs.
+
+                    Positive-probability splits get one row first when n_total allows; the
+                    rest use largest-remainder rounding. Returns zeros if nothing is active.
+                    """
+                    counts = [0] * len(probs)
+                    active = [i for i, p in enumerate(probs) if p > 0]
+                    if n_total <= 0 or not active:  # also avoids 0/0 when every prob is 0
+                        return counts
+                    remainder = n_total
+                    if n_total >= len(active):
+                        counts = [1 if p > 0 else 0 for p in probs]
+                        remainder -= len(active)
+                    if remainder > 0:
+                        weights = np.array(probs, dtype=float)
+                        raw = remainder * (weights / weights.sum())
+                        floors = np.floor(raw).astype(int)
+                        counts = [c + int(f) for c, f in zip(counts, floors)]
+                        # Rows lost to flooring go to the largest fractional parts first.
+                        order = sorted(active, key=lambda i: (-(raw[i] - floors[i]), i))
+                        for k in range(remainder - int(floors.sum())):
+                            counts[order[k % len(order)]] += 1
+                    return counts
+
+                def _choose_split(counts: list, targets: list, active: list) -> int:
+                    """Return the active split furthest below target, else the least-filled."""
+                    deficit = [want - have for want, have in zip(targets, counts)]
+                    pick = max(active, key=lambda s: (deficit[s], -counts[s], -targets[s]))
+                    # Every split already met its target: fall back to the emptiest split.
+                    return pick if deficit[pick] > 0 else min(active, key=counts.__getitem__)
+
+                probs = [train_size, 1.0 - train_size]
+                targets = _allocate_split_counts(len(self.data), probs)
+                counts = [0, 0]
+                active = [0, 1]
+                train_idx = []
+                test_idx = []
+
+                group_series = self.data[sample_id_column].astype(object)
+                missing_mask = group_series.isna()
+                if missing_mask.any():
+                    group_series = group_series.copy()
+                    group_series.loc[missing_mask] = [
+                        f"__missing__{idx}" for idx in group_series.index[missing_mask]
+                    ]
+                group_to_indices = {}
+                for idx, group_id in group_series.items():
+                    group_to_indices.setdefault(group_id, []).append(idx)
+
+                group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
+                rng.shuffle(group_ids)
+
+                for group_id in group_ids:
+                    split_idx = _choose_split(counts, targets, active)
+                    counts[split_idx] += len(group_to_indices[group_id])
+                    if split_idx == 0:
+                        train_idx.extend(group_to_indices[group_id])
+                    else:
+                        test_idx.extend(group_to_indices[group_id])
+
+                missing_splits = []
+                if not train_idx:
+                    missing_splits.append("train")
+                if not test_idx:
+                    missing_splits.append("test")
+                if missing_splits:
+                    LOG.warning(
+                        "Group-aware split using '%s' produced empty %s set; "
+                        "falling back to default split.",
+                        sample_id_column,
+                        " and ".join(missing_splits),
+                    )
+                else:
+                    self.test_data = self.data.loc[test_idx].reset_index(drop=True)
+                    self.data = self.data.loc[train_idx].reset_index(drop=True)
+                    LOG.info(
+                        "Applied group-aware split using '%s' (train=%s, test=%s).",
+                        sample_id_column,
+                        len(train_idx),
+                        len(test_idx),
+                    )
+
+        if sample_id_valid:
+            self.sample_id_series = self.data[sample_id_column].copy()
+            if sample_id_column in self.data.columns:
+                self.data = self.data.drop(columns=[sample_id_column])
+            if self.test_data is not None and sample_id_column in self.test_data.columns:
+                self.test_data = self.test_data.drop(columns=[sample_id_column])
+
+        # Refresh feature lists after any sample-id column removal.
+        names = self.data.columns.to_list()
+        self.features_name = [n for n in names if n != self.target]
+        self.plot_feature_names = self._select_plot_features(self.features_name)
+
     def _select_plot_features(self, all_features):
         limit = getattr(self, "plot_feature_limit", 30)
         if not isinstance(limit, int) or limit <= 0:
@@ -242,6 +372,25 @@
             self.setup_params["fold"] = self.cross_validation_folds
         LOG.info(self.setup_params)
 
+        group_series = getattr(self, "sample_id_series", None)
+        if group_series is not None and getattr(self, "cross_validation", None) is not False:
+            n_groups = pd.Series(group_series).nunique(dropna=False)
+            fold_count = getattr(self, "cross_validation_folds", None)
+            if fold_count is not None and fold_count > n_groups:
+                LOG.warning(
+                    "cross_validation_folds=%s exceeds unique groups=%s; "
+                    "skipping group-aware CV.",
+                    fold_count,
+                    n_groups,
+                )
+            else:
+                self.setup_params["fold_strategy"] = "groupkfold"
+                self.setup_params["fold_groups"] = pd.Series(group_series).reset_index(drop=True)
+                LOG.info(
+                    "Enabled group-aware CV with %s unique groups.",
+                    n_groups,
+                )
+
         if self.task_type == "classification":
             from pycaret.classification import ClassificationExperiment