changeset 15:01e7c5481f13 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit f632803cda732005bdcf3ac3e8fe7a807a82c1d9
author goeckslab
date Mon, 19 Jan 2026 05:54:52 +0000
parents edd515746388
children
files base_model_trainer.py pycaret_macros.xml pycaret_train.py tabular_learner.xml
diffstat 4 files changed, 178 insertions(+), 3 deletions(-) [+]
line wrap: on
line diff
--- a/base_model_trainer.py	Mon Dec 29 20:34:38 2025 +0000
+++ b/base_model_trainer.py	Mon Jan 19 05:54:52 2026 +0000
@@ -119,6 +119,13 @@
             )
 
         self.target = names[target_index]
+        sample_id_column = getattr(self, "sample_id_column", None)
+        if sample_id_column:
+            sample_id_column = sample_id_column.replace(".", "_")
+            self.sample_id_column = sample_id_column
+        else:
+            self.sample_id_column = None
+        self.sample_id_series = None
 
         # Conditional drop: only if 'prediction_label' exists and is not
         # the target
@@ -154,8 +161,24 @@
         names = self.data.columns.to_list()
         LOG.info(f"Dataset columns after processing: {names}")
 
-        self.features_name = [n for n in names if n != self.target]
-        self.plot_feature_names = self._select_plot_features(self.features_name)
+        sample_id_valid = False
+        if sample_id_column:
+            if sample_id_column not in self.data.columns:
+                LOG.warning(
+                    "Sample ID column '%s' not found; proceeding without group-aware split.",
+                    sample_id_column,
+                )
+                sample_id_column = None
+                self.sample_id_column = None
+            elif sample_id_column == self.target:
+                LOG.warning(
+                    "Sample ID column '%s' matches target column; skipping group-aware split.",
+                    sample_id_column,
+                )
+                sample_id_column = None
+                self.sample_id_column = None
+            else:
+                sample_id_valid = True
 
         if self.test_file:
             LOG.info(f"Loading test data from {self.test_file}")
@@ -165,6 +188,113 @@
             df_test.columns = df_test.columns.str.replace(".", "_")
             self.test_data = df_test
 
+        if sample_id_valid and self.test_data is None:
+            train_size = getattr(self, "train_size", None)
+            if train_size is None:
+                train_size = 0.7
+            if train_size <= 0 or train_size >= 1:
+                LOG.warning(
+                    "Invalid train_size=%s; skipping group-aware split.",
+                    train_size,
+                )
+            else:
+                rng = np.random.RandomState(self.random_seed)
+
+                def _allocate_split_counts(n_total: int, probs: list) -> list:
+                    if n_total <= 0:
+                        return [0 for _ in probs]
+                    counts = [0 for _ in probs]
+                    active = [i for i, p in enumerate(probs) if p > 0]
+                    remainder = n_total
+                    if active and n_total >= len(active):
+                        for i in active:
+                            counts[i] = 1
+                        remainder -= len(active)
+                    if remainder > 0:
+                        probs_arr = np.array(probs, dtype=float)
+                        probs_arr = probs_arr / probs_arr.sum()
+                        raw = remainder * probs_arr
+                        floors = np.floor(raw).astype(int)
+                        for i, value in enumerate(floors.tolist()):
+                            counts[i] += value
+                        leftover = remainder - int(floors.sum())
+                        if leftover > 0 and active:
+                            frac = raw - floors
+                            order = sorted(active, key=lambda i: (-frac[i], i))
+                            for i in range(leftover):
+                                counts[order[i % len(order)]] += 1
+                    return counts
+
+                def _choose_split(counts: list, targets: list, active: list) -> int:
+                    remaining = [targets[i] - counts[i] for i in range(len(targets))]
+                    best = max(active, key=lambda i: (remaining[i], -counts[i], -targets[i]))
+                    if remaining[best] <= 0:
+                        best = min(active, key=lambda i: counts[i])
+                    return best
+
+                probs = [train_size, 1.0 - train_size]
+                targets = _allocate_split_counts(len(self.data), probs)
+                counts = [0, 0]
+                active = [0, 1]
+                train_idx = []
+                test_idx = []
+
+                group_series = self.data[sample_id_column].astype(object)
+                missing_mask = group_series.isna()
+                if missing_mask.any():
+                    group_series = group_series.copy()
+                    group_series.loc[missing_mask] = [
+                        f"__missing__{idx}" for idx in group_series.index[missing_mask]
+                    ]
+                group_to_indices = {}
+                for idx, group_id in group_series.items():
+                    group_to_indices.setdefault(group_id, []).append(idx)
+
+                group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
+                rng.shuffle(group_ids)
+
+                for group_id in group_ids:
+                    split_idx = _choose_split(counts, targets, active)
+                    counts[split_idx] += len(group_to_indices[group_id])
+                    if split_idx == 0:
+                        train_idx.extend(group_to_indices[group_id])
+                    else:
+                        test_idx.extend(group_to_indices[group_id])
+
+                missing_splits = []
+                if not train_idx:
+                    missing_splits.append("train")
+                if not test_idx:
+                    missing_splits.append("test")
+                if missing_splits:
+                    LOG.warning(
+                        "Group-aware split using '%s' produced empty %s set; "
+                        "falling back to default split.",
+                        sample_id_column,
+                        " and ".join(missing_splits),
+                    )
+                else:
+                    self.test_data = self.data.loc[test_idx].reset_index(drop=True)
+                    self.data = self.data.loc[train_idx].reset_index(drop=True)
+                    LOG.info(
+                        "Applied group-aware split using '%s' (train=%s, test=%s).",
+                        sample_id_column,
+                        len(train_idx),
+                        len(test_idx),
+                    )
+
+        if sample_id_valid:
+            self.sample_id_series = self.data[sample_id_column].copy()
+            if sample_id_column in self.data.columns:
+                self.data = self.data.drop(columns=[sample_id_column])
+            if self.test_data is not None and sample_id_column in self.test_data.columns:
+                self.test_data = self.test_data.drop(columns=[sample_id_column])
+
+        # Refresh feature lists after any sample-id column removal.
+        names = self.data.columns.to_list()
+        self.features_name = [n for n in names if n != self.target]
+        self.plot_feature_names = self._select_plot_features(self.features_name)
+
     def _select_plot_features(self, all_features):
         limit = getattr(self, "plot_feature_limit", 30)
         if not isinstance(limit, int) or limit <= 0:
@@ -242,6 +372,25 @@
             self.setup_params["fold"] = self.cross_validation_folds
         LOG.info(self.setup_params)
 
+        group_series = getattr(self, "sample_id_series", None)
+        if group_series is not None and getattr(self, "cross_validation", None) is not False:
+            n_groups = pd.Series(group_series).nunique(dropna=False)
+            fold_count = getattr(self, "cross_validation_folds", None)
+            if fold_count is not None and fold_count > n_groups:
+                LOG.warning(
+                    "cross_validation_folds=%s exceeds unique groups=%s; "
+                    "skipping group-aware CV.",
+                    fold_count,
+                    n_groups,
+                )
+            else:
+                self.setup_params["fold_strategy"] = "groupkfold"
+                self.setup_params["fold_groups"] = pd.Series(group_series).reset_index(drop=True)
+                LOG.info(
+                    "Enabled group-aware CV with %s unique groups.",
+                    n_groups,
+                )
+
         if self.task_type == "classification":
             from pycaret.classification import ClassificationExperiment
 
--- a/pycaret_macros.xml	Mon Dec 29 20:34:38 2025 +0000
+++ b/pycaret_macros.xml	Mon Jan 19 05:54:52 2026 +0000
@@ -1,5 +1,5 @@
 <macros>
-    <token name="@TABULAR_LEARNER_VERSION@">0.1.3</token>
+    <token name="@TABULAR_LEARNER_VERSION@">0.1.4</token>
     <token name="@PYCARET_VERSION@">3.3.2</token>
     <token name="@SUFFIX@">2</token>
     <token name="@PYCARET_PREDICT_VERSION@">@PYCARET_VERSION@+@SUFFIX@</token>
--- a/pycaret_train.py	Mon Dec 29 20:34:38 2025 +0000
+++ b/pycaret_train.py	Mon Jan 19 05:54:52 2026 +0000
@@ -134,6 +134,15 @@
         default=None,
         help="Metric used to select the best model (e.g. AUC, Accuracy, R2, RMSE).",
     )
+    parser.add_argument(
+        "--sample-id-column",
+        type=str,
+        default=None,
+        help=(
+            "Optional column name used to group samples during splitting "
+            "to prevent data leakage (e.g., patient_id or slide_id)."
+        ),
+    )
 
     args = parser.parse_args()
 
@@ -170,6 +179,7 @@
         "n_jobs": n_jobs,
         "probability_threshold": args.probability_threshold,
         "best_model_metric": args.best_model_metric,
+        "sample_id_column": args.sample_id_column,
     }
     LOG.info(f"Model kwargs: {model_kwargs}")
 
--- a/tabular_learner.xml	Mon Dec 29 20:34:38 2025 +0000
+++ b/tabular_learner.xml	Mon Jan 19 05:54:52 2026 +0000
@@ -7,6 +7,9 @@
     <command>
         <![CDATA[
         python $__tool_directory__/pycaret_train.py --input_file '$input_file' --target_col '$target_feature' --output_dir '.' --random_seed '$random_seed' --n-jobs \${GALAXY_SLOTS:-1}
+        #if $sample_id_selector.use_sample_id == "yes"
+            --sample-id-column '$sample_id_selector.sample_id_column'
+        #end if
         #if $model_selection.model_type == "classification"
             #if $model_selection.classification_models
                 --models '$model_selection.classification_models'
@@ -81,6 +84,18 @@
             </when>
         </conditional>
         <param name="target_feature" multiple="false" type="data_column" use_header_names="true" data_ref="input_file" label="Select the target column:" />
+        <conditional name="sample_id_selector">
+            <param name="use_sample_id" type="select" label="Use a sample ID column for leakage-aware splitting?" help="Select yes to choose a column that groups related records (e.g., patient_id or slide_id).">
+                <option value="no" selected="true">No column selected</option>
+                <option value="yes">Yes</option>
+            </param>
+            <when value="yes">
+                <param name="sample_id_column" type="data_column" data_ref="input_file" use_header_names="true" label="Sample ID column" help="All rows with the same ID stay in the same split to reduce leakage. Used for group-aware splitting when no separate test file is provided, and for group-aware cross-validation when enabled." />
+            </when>
+            <when value="no">
+                <!-- No sample ID column -->
+            </when>
+        </conditional>
         <conditional name="model_selection">
             <param name="model_type" type="select" label="Task">
                 <option value="classification">classification</option>
@@ -311,6 +326,7 @@
     <help>
         This tool uses PyCaret to train and evaluate machine learning models.
         It compares different models on a dataset and provides the best model based on the performance metrics.
+        You can optionally select a sample ID column to keep related records in the same split and reduce data leakage when the tool creates splits internally.
 
         **Outputs**