Mercurial > repos > goeckslab > tabular_learner
diff base_model_trainer.py @ 15:01e7c5481f13 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit f632803cda732005bdcf3ac3e8fe7a807a82c1d9
| author | goeckslab |
|---|---|
| date | Mon, 19 Jan 2026 05:54:52 +0000 |
| parents | bf0df21a1ea3 |
| children |
line wrap: on
line diff
--- a/base_model_trainer.py Mon Dec 29 20:34:38 2025 +0000 +++ b/base_model_trainer.py Mon Jan 19 05:54:52 2026 +0000 @@ -119,6 +119,13 @@ ) self.target = names[target_index] + sample_id_column = getattr(self, "sample_id_column", None) + if sample_id_column: + sample_id_column = sample_id_column.replace(".", "_") + self.sample_id_column = sample_id_column + else: + self.sample_id_column = None + self.sample_id_series = None # Conditional drop: only if 'prediction_label' exists and is not # the target @@ -154,8 +161,24 @@ names = self.data.columns.to_list() LOG.info(f"Dataset columns after processing: {names}") - self.features_name = [n for n in names if n != self.target] - self.plot_feature_names = self._select_plot_features(self.features_name) + sample_id_valid = False + if sample_id_column: + if sample_id_column not in self.data.columns: + LOG.warning( + "Sample ID column '%s' not found; proceeding without group-aware split.", + sample_id_column, + ) + sample_id_column = None + self.sample_id_column = None + elif sample_id_column == self.target: + LOG.warning( + "Sample ID column '%s' matches target column; skipping group-aware split.", + sample_id_column, + ) + sample_id_column = None + self.sample_id_column = None + else: + sample_id_valid = True if self.test_file: LOG.info(f"Loading test data from {self.test_file}") @@ -165,6 +188,113 @@ df_test.columns = df_test.columns.str.replace(".", "_") self.test_data = df_test + if sample_id_valid and self.test_data is None: + train_size = getattr(self, "train_size", None) + if train_size is None: + train_size = 0.7 + if train_size <= 0 or train_size >= 1: + LOG.warning( + "Invalid train_size=%s; skipping group-aware split.", + train_size, + ) + else: + rng = np.random.RandomState(self.random_seed) + + def _allocate_split_counts(n_total: int, probs: list) -> list: + if n_total <= 0: + return [0 for _ in probs] + counts = [0 for _ in probs] + active = [i for i, p in enumerate(probs) if p > 0] + remainder = n_total + if active and n_total >= len(active):
+ for i in active: + counts[i] = 1 + remainder -= len(active) + if remainder > 0: + probs_arr = np.array(probs, dtype=float) + probs_arr = probs_arr / probs_arr.sum() + raw = remainder * probs_arr + floors = np.floor(raw).astype(int) + for i, value in enumerate(floors.tolist()): + counts[i] += value + leftover = remainder - int(floors.sum()) + if leftover > 0 and active: + frac = raw - floors + order = sorted(active, key=lambda i: (-frac[i], i)) + for i in range(leftover): + counts[order[i % len(order)]] += 1 + return counts + + def _choose_split(counts: list, targets: list, active: list) -> int: + remaining = [targets[i] - counts[i] for i in range(len(targets))] + best = max(active, key=lambda i: (remaining[i], -counts[i], -targets[i])) + if remaining[best] <= 0: + best = min(active, key=lambda i: counts[i]) + return best + + probs = [train_size, 1.0 - train_size] + targets = _allocate_split_counts(len(self.data), probs) + counts = [0, 0] + active = [0, 1] + train_idx = [] + test_idx = [] + + group_series = self.data[sample_id_column].astype(object) + missing_mask = group_series.isna() + if missing_mask.any(): + group_series = group_series.copy() + group_series.loc[missing_mask] = [ + f"__missing__{idx}" for idx in group_series.index[missing_mask] + ] + group_to_indices = {} + for idx, group_id in group_series.items(): + group_to_indices.setdefault(group_id, []).append(idx) + + group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x)) + rng.shuffle(group_ids) + + for group_id in group_ids: + split_idx = _choose_split(counts, targets, active) + counts[split_idx] += len(group_to_indices[group_id]) + if split_idx == 0: + train_idx.extend(group_to_indices[group_id]) + else: + test_idx.extend(group_to_indices[group_id]) + + missing_splits = [] + if not train_idx: + missing_splits.append("train") + if not test_idx: + missing_splits.append("test") + if missing_splits: + LOG.warning( + "Group-aware split using '%s' produced empty %s set; "
+ "falling back to default split.", + sample_id_column, + " and ".join(missing_splits), + ) + else: + self.test_data = self.data.loc[test_idx].reset_index(drop=True) + self.data = self.data.loc[train_idx].reset_index(drop=True) + LOG.info( + "Applied group-aware split using '%s' (train=%s, test=%s).", + sample_id_column, + len(train_idx), + len(test_idx), + ) + + if sample_id_valid: + self.sample_id_series = self.data[sample_id_column].copy() + if sample_id_column in self.data.columns: + self.data = self.data.drop(columns=[sample_id_column]) + if self.test_data is not None and sample_id_column in self.test_data.columns: + self.test_data = self.test_data.drop(columns=[sample_id_column]) + + # Refresh feature lists after any sample-id column removal. + names = self.data.columns.to_list() + self.features_name = [n for n in names if n != self.target] + self.plot_feature_names = self._select_plot_features(self.features_name) + def _select_plot_features(self, all_features): limit = getattr(self, "plot_feature_limit", 30) if not isinstance(limit, int) or limit <= 0: @@ -242,6 +372,25 @@ self.setup_params["fold"] = self.cross_validation_folds LOG.info(self.setup_params) + group_series = getattr(self, "sample_id_series", None) + if group_series is not None and getattr(self, "cross_validation", None) is not False: + n_groups = pd.Series(group_series).nunique(dropna=False) + fold_count = getattr(self, "cross_validation_folds", None) + if fold_count is not None and fold_count > n_groups: + LOG.warning( + "cross_validation_folds=%s exceeds unique groups=%s; " + "skipping group-aware CV.", + fold_count, + n_groups, + ) + else: + self.setup_params["fold_strategy"] = "groupkfold" + self.setup_params["fold_groups"] = pd.Series(group_series).reset_index(drop=True) + LOG.info( + "Enabled group-aware CV with %s unique groups.", + n_groups, + ) + if self.task_type == "classification": from pycaret.classification import ClassificationExperiment
