Mercurial > repos > goeckslab > tabular_learner
changeset 15:01e7c5481f13 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit f632803cda732005bdcf3ac3e8fe7a807a82c1d9
| author | goeckslab |
|---|---|
| date | Mon, 19 Jan 2026 05:54:52 +0000 |
| parents | edd515746388 |
| children | |
| files | base_model_trainer.py pycaret_macros.xml pycaret_train.py tabular_learner.xml |
| diffstat | 4 files changed, 178 insertions(+), 3 deletions(-) [+] |
line wrap: on
line diff
--- a/base_model_trainer.py Mon Dec 29 20:34:38 2025 +0000 +++ b/base_model_trainer.py Mon Jan 19 05:54:52 2026 +0000 @@ -119,6 +119,13 @@ ) self.target = names[target_index] + sample_id_column = getattr(self, "sample_id_column", None) + if sample_id_column: + sample_id_column = sample_id_column.replace(".", "_") + self.sample_id_column = sample_id_column + else: + self.sample_id_column = None + self.sample_id_series = None # Conditional drop: only if 'prediction_label' exists and is not # the target @@ -154,8 +161,24 @@ names = self.data.columns.to_list() LOG.info(f"Dataset columns after processing: {names}") - self.features_name = [n for n in names if n != self.target] - self.plot_feature_names = self._select_plot_features(self.features_name) + sample_id_valid = False + if sample_id_column: + if sample_id_column not in self.data.columns: + LOG.warning( + "Sample ID column '%s' not found; proceeding without group-aware split.", + sample_id_column, + ) + sample_id_column = None + self.sample_id_column = None + elif sample_id_column == self.target: + LOG.warning( + "Sample ID column '%s' matches target column; skipping group-aware split.", + sample_id_column, + ) + sample_id_column = None + self.sample_id_column = None + else: + sample_id_valid = True if self.test_file: LOG.info(f"Loading test data from {self.test_file}") @@ -165,6 +188,113 @@ df_test.columns = df_test.columns.str.replace(".", "_") self.test_data = df_test + if sample_id_valid and self.test_data is None: + train_size = getattr(self, "train_size", None) + if train_size is None: + train_size = 0.7 + if train_size <= 0 or train_size >= 1: + LOG.warning( + "Invalid train_size=%s; skipping group-aware split.", + train_size, + ) + else: + rng = np.random.RandomState(self.random_seed) + + def _allocate_split_counts(n_total: int, probs: list) -> list: + if n_total <= 0: + return [0 for _ in probs] + counts = [0 for _ in probs] + active = [i for i, p in enumerate(probs) if p > 0] + remainder = n_total + if active and n_total >= len(active): + for i in active: + counts[i] = 1 + remainder -= len(active) + if remainder > 0: + probs_arr = np.array(probs, dtype=float) + probs_arr = probs_arr / probs_arr.sum() + raw = remainder * probs_arr + floors = np.floor(raw).astype(int) + for i, value in enumerate(floors.tolist()): + counts[i] += value + leftover = remainder - int(floors.sum()) + if leftover > 0 and active: + frac = raw - floors + order = sorted(active, key=lambda i: (-frac[i], i)) + for i in range(leftover): + counts[order[i % len(order)]] += 1 + return counts + + def _choose_split(counts: list, targets: list, active: list) -> int: + remaining = [targets[i] - counts[i] for i in range(len(targets))] + best = max(active, key=lambda i: (remaining[i], -counts[i], -targets[i])) + if remaining[best] <= 0: + best = min(active, key=lambda i: counts[i]) + return best + + probs = [train_size, 1.0 - train_size] + targets = _allocate_split_counts(len(self.data), probs) + counts = [0, 0] + active = [0, 1] + train_idx = [] + test_idx = [] + + group_series = self.data[sample_id_column].astype(object) + missing_mask = group_series.isna() + if missing_mask.any(): + group_series = group_series.copy() + group_series.loc[missing_mask] = [ + f"__missing__{idx}" for idx in group_series.index[missing_mask] + ] + group_to_indices = {} + for idx, group_id in group_series.items(): + group_to_indices.setdefault(group_id, []).append(idx) + + group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x)) + rng.shuffle(group_ids) + + for group_id in group_ids: + split_idx = _choose_split(counts, targets, active) + counts[split_idx] += len(group_to_indices[group_id]) + if split_idx == 0: + train_idx.extend(group_to_indices[group_id]) + else: + test_idx.extend(group_to_indices[group_id]) + + missing_splits = [] + if not train_idx: + missing_splits.append("train") + if not test_idx: + missing_splits.append("test") + if missing_splits: + LOG.warning( + "Group-aware split using '%s' produced empty %s set; " + "falling back to default split.", + sample_id_column, + " and ".join(missing_splits), + ) + else: + self.test_data = self.data.loc[test_idx].reset_index(drop=True) + self.data = self.data.loc[train_idx].reset_index(drop=True) + LOG.info( + "Applied group-aware split using '%s' (train=%s, test=%s).", + sample_id_column, + len(train_idx), + len(test_idx), + ) + + if sample_id_valid: + self.sample_id_series = self.data[sample_id_column].copy() + if sample_id_column in self.data.columns: + self.data = self.data.drop(columns=[sample_id_column]) + if self.test_data is not None and sample_id_column in self.test_data.columns: + self.test_data = self.test_data.drop(columns=[sample_id_column]) + + # Refresh feature lists after any sample-id column removal. + names = self.data.columns.to_list() + self.features_name = [n for n in names if n != self.target] + self.plot_feature_names = self._select_plot_features(self.features_name) + def _select_plot_features(self, all_features): limit = getattr(self, "plot_feature_limit", 30) if not isinstance(limit, int) or limit <= 0: @@ -242,6 +372,25 @@ self.setup_params["fold"] = self.cross_validation_folds LOG.info(self.setup_params) + group_series = getattr(self, "sample_id_series", None) + if group_series is not None and getattr(self, "cross_validation", None) is not False: + n_groups = pd.Series(group_series).nunique(dropna=False) + fold_count = getattr(self, "cross_validation_folds", None) + if fold_count is not None and fold_count > n_groups: + LOG.warning( + "cross_validation_folds=%s exceeds unique groups=%s; " + "skipping group-aware CV.", + fold_count, + n_groups, + ) + else: + self.setup_params["fold_strategy"] = "groupkfold" + self.setup_params["fold_groups"] = pd.Series(group_series).reset_index(drop=True) + LOG.info( + "Enabled group-aware CV with %s unique groups.", + n_groups, + ) + if self.task_type == "classification": from pycaret.classification import ClassificationExperiment
--- a/pycaret_macros.xml Mon Dec 29 20:34:38 2025 +0000 +++ b/pycaret_macros.xml Mon Jan 19 05:54:52 2026 +0000 @@ -1,5 +1,5 @@ <macros> - <token name="@TABULAR_LEARNER_VERSION@">0.1.3</token> + <token name="@TABULAR_LEARNER_VERSION@">0.1.4</token> <token name="@PYCARET_VERSION@">3.3.2</token> <token name="@SUFFIX@">2</token> <token name="@PYCARET_PREDICT_VERSION@">@PYCARET_VERSION@+@SUFFIX@</token>
--- a/pycaret_train.py Mon Dec 29 20:34:38 2025 +0000 +++ b/pycaret_train.py Mon Jan 19 05:54:52 2026 +0000 @@ -134,6 +134,15 @@ default=None, help="Metric used to select the best model (e.g. AUC, Accuracy, R2, RMSE).", ) + parser.add_argument( + "--sample-id-column", + type=str, + default=None, + help=( + "Optional column name used to group samples during splitting " + "to prevent data leakage (e.g., patient_id or slide_id)." + ), + ) args = parser.parse_args() @@ -170,6 +179,7 @@ "n_jobs": n_jobs, "probability_threshold": args.probability_threshold, "best_model_metric": args.best_model_metric, + "sample_id_column": args.sample_id_column, } LOG.info(f"Model kwargs: {model_kwargs}")
--- a/tabular_learner.xml Mon Dec 29 20:34:38 2025 +0000 +++ b/tabular_learner.xml Mon Jan 19 05:54:52 2026 +0000 @@ -7,6 +7,9 @@ <command> <