changeset 20:64872c48a21f draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
author goeckslab
date Tue, 06 Jan 2026 15:35:11 +0000
parents c460abae83eb
children
files image_learner.xml image_learner_cli.py image_workflow.py ludwig_backend.py split_data.py
diffstat 5 files changed, 298 insertions(+), 58 deletions(-)
--- a/image_learner.xml	Thu Dec 18 16:59:58 2025 +0000
+++ b/image_learner.xml	Tue Jan 06 15:35:11 2026 +0000
@@ -29,6 +29,7 @@
             ln -sf '$input_csv' "./${sanitized_input_csv}";
             #end if
 
+            #set $selected_validation_metric = ""
             #if $task_selection.task == "binary"
                 #set $selected_validation_metric = $task_selection.validation_metric_binary
             #elif $task_selection.task == "classification"
@@ -36,7 +37,9 @@
             #elif $task_selection.task == "regression"
                 #set $selected_validation_metric = $task_selection.validation_metric_regression
             #else
-                #set $selected_validation_metric = None
+                #if $task_selection.validation_metric_auto
+                    #set $selected_validation_metric = $task_selection.validation_metric_auto
+                #end if
             #end if
 
             python '$__tool_directory__/image_learner_cli.py'
@@ -81,6 +84,9 @@
                         --image-column "$column_override.image_column"
                     #end if
                 #end if
+                #if $sample_id_column
+                    --sample-id-column "$sample_id_column"
+                #end if
                 --image-resize "$image_resize"
                 --random-seed "$random_seed"
                 --output-dir "." &&
@@ -103,26 +109,26 @@
                 <option value="regression">Regression</option>
             </param>
             <when value="binary">
-                <param name="validation_metric_binary" type="select" optional="true" label="Validation metric (binary)" help="Metrics accepted by Ludwig for binary outputs.">
-                    <option value="roc_auc" selected="true">ROC-AUC</option>
+                <param name="validation_metric_binary" type="select" optional="true" label="Validation metric (binary)" help="Metrics accepted by Ludwig 0.10.1 for binary outputs.">
+                    <option value="" selected="true">Auto (use task default)</option>
+                    <option value="roc_auc">ROC-AUC</option>
                     <option value="accuracy">Accuracy</option>
                     <option value="precision">Precision</option>
                     <option value="recall">Recall</option>
-                    <option value="specificity">Specificity</option>
                     <option value="loss">Loss</option>
                 </param>
             </when>
             <when value="classification">
-                <param name="validation_metric_multiclass" type="select" optional="true" label="Validation metric (multi-class)" help="Metrics accepted by Ludwig for multi-class outputs.">
-                    <option value="accuracy" selected="true">Accuracy</option>
-                    <option value="balanced_accuracy">Balanced Accuracy</option>
-                    <option value="hits_at_k">Hits at K (top-k)</option>
+                <param name="validation_metric_multiclass" type="select" optional="true" label="Validation metric (multi-class)" help="Metrics accepted by Ludwig 0.10.1 for multi-class outputs.">
+                    <option value="" selected="true">Auto (use task default)</option>
+                    <option value="accuracy">Accuracy</option>
                     <option value="loss">Loss</option>
                 </param>
             </when>
             <when value="regression">
-                <param name="validation_metric_regression" type="select" optional="true" label="Validation metric (regression)" help="Metrics accepted by Ludwig for regression outputs.">
-                    <option value="mean_squared_error" selected="true">Mean Squared Error</option>
+                <param name="validation_metric_regression" type="select" optional="true" label="Validation metric (regression)" help="Metrics accepted by Ludwig 0.10.1 for regression outputs.">
+                    <option value="" selected="true">Auto (use task default)</option>
+                    <option value="mean_squared_error">Mean Squared Error</option>
                     <option value="mean_absolute_error">Mean Absolute Error</option>
                     <option value="root_mean_squared_error">Root Mean Squared Error</option>
                     <option value="root_mean_squared_percentage_error">Root Mean Squared Percentage Error</option>
@@ -130,7 +136,9 @@
                 </param>
             </when>
             <when value="auto">
-                <!-- No validation metric selection; tool will infer task and metric. -->
+                <param name="validation_metric_auto" type="select" optional="true" label="Validation metric (auto)" help="Auto defers to the inferred task and picks the best default metric; use this only to override the choice.">
+                    <option value="" selected="true"></option>
+                </param>
             </when>
         </conditional>
         <conditional name="column_override">
@@ -139,13 +147,14 @@
                 <option value="true">Yes</option>
             </param>
             <when value="true">
-                <param name="target_column" type="text" optional="true" label="Target/label column name" help="Overrides the default 'label' column name in the metadata CSV." />
-                <param name="image_column" type="text" optional="true" label="Image column name" help="Overrides the default 'image_path' column name in the metadata CSV." />
+                <param name="target_column" type="data_column" data_ref="input_csv" use_header_names="true" optional="true" label="Target/label column name" help="Overrides the default 'label' column name in the metadata CSV." />
+                <param name="image_column" type="data_column" data_ref="input_csv" use_header_names="true" optional="true" label="Image column name" help="Overrides the default 'image_path' column name in the metadata CSV." />
             </when>
             <when value="false">
                 <!-- No additional parameters -->
             </when>
         </conditional>
+        <param name="sample_id_column" type="data_column" data_ref="input_csv" use_header_names="true" optional="true" label="Sample ID column (optional)" help="Optional column used to group samples during splitting to prevent data leakage (e.g., patient_id or slide_id). Only used when no split column is provided." />
         <param name="model_name" type="select" optional="false" label="Select a model for your experiment" >
 
             <option value="resnet18">Resnet18</option>
@@ -564,7 +573,7 @@
             <param name="advanced_settings|customize_defaults" value="true" />
             <param name="advanced_settings|threshold" value="0.6" />
             <param name="task_selection|task" value="classification" />
-            <param name="task_selection|validation_metric_multiclass" value="balanced_accuracy" />
+            <param name="task_selection|validation_metric_multiclass" value="accuracy" />
             <output name="output_report">
                 <assert_contents>
                     <has_text text="Config and Overall Performance Summary" />
@@ -591,6 +600,7 @@
 The metadata csv should contain a column with the name 'image_path' and a column with the name 'label'.
 Optionally, you can also add a column with the name 'split' to specify which split each row belongs to (train, val, test). 
 If you do not provide a split column, the tool will automatically split the data into train, val, and test sets based on the proportions you specify or [0.7, 0.1, 0.2] by default.
+You can optionally specify a sample ID column to keep related samples in the same split and prevent data leakage.
 
 **Models Available**
 This tool supports a wide range of state-of-the-art image classification models including:
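
With these changes every task exposes an "Auto (use task default)" option whose value is the empty string, so the wrapper only forwards a validation metric when the user picks a concrete one. A minimal Python sketch of the equivalent resolution logic (the Cheetah block above is the actual source; the helper name and dict-based selections below are illustrative assumptions, not part of the tool):

from typing import Optional


def resolve_validation_metric(task: str, selections: dict) -> Optional[str]:
    # Mirror of the Cheetah logic: 'Auto' / empty string means "defer to the task default".
    key = {
        "binary": "validation_metric_binary",
        "classification": "validation_metric_multiclass",
        "regression": "validation_metric_regression",
    }.get(task, "validation_metric_auto")
    value = selections.get(key) or ""
    return value or None


# resolve_validation_metric("binary", {"validation_metric_binary": ""}) -> None
# resolve_validation_metric("regression",
#     {"validation_metric_regression": "mean_absolute_error"}) -> "mean_absolute_error"
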
--- a/image_learner_cli.py	Thu Dec 18 16:59:58 2025 +0000
+++ b/image_learner_cli.py	Tue Jan 06 15:35:11 2026 +0000
@@ -163,6 +163,15 @@
         default=None,
         help="Name of the image column in the metadata file (defaults to 'image_path').",
     )
+    parser.add_argument(
+        "--sample-id-column",
+        type=str,
+        default=None,
+        help=(
+            "Optional column name used to group samples during splitting "
+            "to prevent data leakage (e.g., patient_id or slide_id)."
+        ),
+    )
 
     args = parser.parse_args()
 
--- a/image_workflow.py	Thu Dec 18 16:59:58 2025 +0000
+++ b/image_workflow.py	Tue Jan 06 15:35:11 2026 +0000
@@ -168,6 +168,7 @@
                 split_probabilities=self.args.split_probabilities,
                 random_state=self.args.random_seed,
                 label_column=LABEL_COLUMN_NAME,
+                group_column=self.args.sample_id_column,
             )
             split_config = {
                 "type": "fixed",
@@ -178,6 +179,11 @@
                 f"{[int(p * 100) for p in self.args.split_probabilities]}% "
                 f"for train/val/test with balanced label distribution."
             )
+            if self.args.sample_id_column:
+                split_info += (
+                    f" Grouped by sample ID column '{self.args.sample_id_column}' "
+                    "to prevent data leakage."
+                )
 
         final_csv = self.temp_dir / TEMP_CSV_FILENAME
 
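
The practical guarantee of passing the sample ID column through to the splitter is that all rows sharing an ID land in the same split. A small pandas sanity check of that property on an already-split frame (column names 'sample_id' and 'split' are illustrative; the tool's split column encodes train/val/test as 0/1/2):

import pandas as pd

# Hypothetical post-split check: with group-aware splitting, each sample ID
# must map to exactly one of the 0/1/2 split values.
df = pd.DataFrame({
    "sample_id": ["p1", "p1", "p2", "p3", "p3", "p3"],
    "split":     [0,    0,    2,    1,    1,    1],
})
splits_per_group = df.groupby("sample_id")["split"].nunique()
assert (splits_per_group == 1).all(), "a sample ID leaked across splits"
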
--- a/ludwig_backend.py	Thu Dec 18 16:59:58 2025 +0000
+++ b/ludwig_backend.py	Tue Jan 06 15:35:11 2026 +0000
@@ -416,7 +416,8 @@
                 "binary": "roc_auc",
                 "category": "accuracy",
             }
-            allowed_map = {
+            # Safe defaults when Ludwig's registry isn't available.
+            safe_allowed_map = {
                 "regression": {
                     "mean_absolute_error",
                     "mean_squared_error",
@@ -429,13 +430,10 @@
                     "accuracy",
                     "precision",
                     "recall",
-                    "specificity",
                     "loss",
                 },
                 "category": {
                     "accuracy",
-                    "balanced_accuracy",
-                    "hits_at_k",
                     "loss",
                 },
             }
@@ -472,6 +470,16 @@
                     )
                     if isinstance(metrics_attr, dict):
                         registry_metrics = set(metrics_attr.keys())
+                    elif isinstance(metrics_attr, (list, tuple, set)):
+                        extracted = set()
+                        for item in metrics_attr:
+                            if isinstance(item, str):
+                                extracted.add(item)
+                            elif hasattr(item, "name"):
+                                extracted.add(str(item.name))
+                            elif hasattr(item, "__name__"):
+                                extracted.add(str(item.__name__))
+                        registry_metrics = extracted or None
             except Exception as exc:
                 logger.debug(
                     "Could not inspect Ludwig metrics for output type %s: %s",
@@ -479,12 +487,10 @@
                     exc,
                 )
 
-            allowed = set(allowed_map.get(task, set()))
-            if registry_metrics:
-                # Only keep metrics that Ludwig actually exposes for this output type;
-                # if the intersection is empty, fall back to the registry set.
-                intersected = allowed.intersection(registry_metrics)
-                allowed = intersected or registry_metrics
+            allowed = set(safe_allowed_map.get(task, set()))
+            if registry_metrics is not None:
+                # Use Ludwig's registry when available; fall back to safe defaults if it's empty.
+                allowed = registry_metrics or allowed
 
             if allowed and metric not in allowed:
                 fallback_candidates = [
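
The registry probe now accepts both mapping-shaped and sequence-shaped metric registries instead of assuming a dict. A standalone sketch of that extraction (the objects below are stand-ins, not Ludwig's real metric classes):

from typing import Optional, Set


def extract_metric_names(metrics_attr) -> Optional[Set[str]]:
    # Dict registries expose metric names as keys; sequence registries may hold
    # plain strings or classes/objects carrying a .name or __name__.
    if isinstance(metrics_attr, dict):
        return set(metrics_attr.keys())
    if isinstance(metrics_attr, (list, tuple, set)):
        extracted = set()
        for item in metrics_attr:
            if isinstance(item, str):
                extracted.add(item)
            elif hasattr(item, "name"):
                extracted.add(str(item.name))
            elif hasattr(item, "__name__"):
                extracted.add(str(item.__name__))
        return extracted or None
    return None


class FakeMetric:  # stand-in class: resolved via __name__
    pass


assert extract_metric_names({"accuracy": object(), "loss": object()}) == {"accuracy", "loss"}
assert extract_metric_names(["loss", FakeMetric]) == {"loss", "FakeMetric"}
assert extract_metric_names(None) is None
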
--- a/split_data.py	Thu Dec 18 16:59:58 2025 +0000
+++ b/split_data.py	Tue Jan 06 15:35:11 2026 +0000
@@ -81,6 +81,7 @@
     split_probabilities: list = [0.7, 0.1, 0.2],
     random_state: int = 42,
     label_column: Optional[str] = None,
+    group_column: Optional[str] = None,
 ) -> pd.DataFrame:
     """Create a stratified random split when no split column exists."""
     out = df.copy()
@@ -88,22 +89,103 @@
     # initialize split column
     out[split_column] = 0
 
+    if group_column and group_column not in out.columns:
+        logger.warning(
+            "Group column '%s' not found in data; proceeding without group-aware split.",
+            group_column,
+        )
+        group_column = None
+
+    def _allocate_split_counts(n_total: int, probs: list) -> list:
+        """Allocate exact split counts using largest remainder rounding."""
+        if n_total <= 0:
+            return [0 for _ in probs]
+
+        counts = [0 for _ in probs]
+        active = [i for i, p in enumerate(probs) if p > 0]
+        remainder = n_total
+
+        if active and n_total >= len(active):
+            for i in active:
+                counts[i] = 1
+            remainder -= len(active)
+
+        if remainder > 0:
+            probs_arr = np.array(probs, dtype=float)
+            probs_arr = probs_arr / probs_arr.sum()
+            raw = remainder * probs_arr
+            floors = np.floor(raw).astype(int)
+            for i, value in enumerate(floors.tolist()):
+                counts[i] += value
+            leftover = remainder - int(floors.sum())
+            if leftover > 0 and active:
+                frac = raw - floors
+                order = sorted(active, key=lambda i: (-frac[i], i))
+                for i in range(leftover):
+                    counts[order[i % len(order)]] += 1
+
+        return counts
+
+    def _choose_split(counts: list, targets: list, active: list) -> int:
+        remaining = [targets[i] - counts[i] for i in range(len(targets))]
+        best = max(active, key=lambda i: (remaining[i], -counts[i], -targets[i]))
+        if remaining[best] <= 0:
+            best = min(active, key=lambda i: counts[i])
+        return best
+
     if not label_column or label_column not in out.columns:
         logger.warning(
             "No label column found; using random split without stratification"
         )
-        # fall back to simple random assignment
+        # fall back to random assignment (group-aware if requested)
         indices = out.index.tolist()
-        np.random.seed(random_state)
-        np.random.shuffle(indices)
+        rng = np.random.RandomState(random_state)
+
+        if group_column:
+            group_series = out[group_column].astype(object)
+            missing_mask = group_series.isna()
+            if missing_mask.any():
+                group_series = group_series.copy()
+                group_series.loc[missing_mask] = [
+                    f"__missing__{idx}" for idx in group_series.index[missing_mask]
+                ]
+            group_to_indices = {}
+            for idx, group_id in group_series.items():
+                group_to_indices.setdefault(group_id, []).append(idx)
+            group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
+            rng.shuffle(group_ids)
 
-        n_total = len(indices)
-        n_train = int(n_total * split_probabilities[0])
-        n_val = int(n_total * split_probabilities[1])
+            targets = _allocate_split_counts(len(indices), split_probabilities)
+            counts = [0 for _ in split_probabilities]
+            active = [i for i, p in enumerate(split_probabilities) if p > 0]
+            train_idx = []
+            val_idx = []
+            test_idx = []
+
+            for group_id in group_ids:
+                size = len(group_to_indices[group_id])
+                split_idx = _choose_split(counts, targets, active)
+                counts[split_idx] += size
+                if split_idx == 0:
+                    train_idx.extend(group_to_indices[group_id])
+                elif split_idx == 1:
+                    val_idx.extend(group_to_indices[group_id])
+                else:
+                    test_idx.extend(group_to_indices[group_id])
+
+            out.loc[train_idx, split_column] = 0
+            out.loc[val_idx, split_column] = 1
+            out.loc[test_idx, split_column] = 2
+            return out.astype({split_column: int})
+
+        rng.shuffle(indices)
+
+        targets = _allocate_split_counts(len(indices), split_probabilities)
+        n_train, n_val, n_test = targets
 
         out.loc[indices[:n_train], split_column] = 0
         out.loc[indices[n_train:n_train + n_val], split_column] = 1
-        out.loc[indices[n_train + n_val:], split_column] = 2
+        out.loc[indices[n_train + n_val:n_train + n_val + n_test], split_column] = 2
 
         return out.astype({split_column: int})
 
@@ -112,48 +194,175 @@
     min_samples_per_class = label_counts.min()
 
     # ensure we have enough samples for stratification:
-    # Each class must have at least as many samples as the number of splits,
+    # Each class must have at least as many samples as the number of nonzero splits,
     # so that each split can receive at least one sample per class.
-    min_samples_required = len(split_probabilities)
+    active_splits = [p for p in split_probabilities if p > 0]
+    min_samples_required = len(active_splits)
     if min_samples_per_class < min_samples_required:
         logger.warning(
             f"Insufficient samples per class for stratification (min: {min_samples_per_class}, required: {min_samples_required}); using random split"
         )
-        # fall back to simple random assignment
+        # fall back to simple random assignment (group-aware if requested)
         indices = out.index.tolist()
-        np.random.seed(random_state)
-        np.random.shuffle(indices)
+        rng = np.random.RandomState(random_state)
+
+        if group_column:
+            group_series = out[group_column].astype(object)
+            missing_mask = group_series.isna()
+            if missing_mask.any():
+                group_series = group_series.copy()
+                group_series.loc[missing_mask] = [
+                    f"__missing__{idx}" for idx in group_series.index[missing_mask]
+                ]
+            group_to_indices = {}
+            for idx, group_id in group_series.items():
+                group_to_indices.setdefault(group_id, []).append(idx)
+            group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
+            rng.shuffle(group_ids)
 
-        n_total = len(indices)
-        n_train = int(n_total * split_probabilities[0])
-        n_val = int(n_total * split_probabilities[1])
+            targets = _allocate_split_counts(len(indices), split_probabilities)
+            counts = [0 for _ in split_probabilities]
+            active = [i for i, p in enumerate(split_probabilities) if p > 0]
+            train_idx = []
+            val_idx = []
+            test_idx = []
+
+            for group_id in group_ids:
+                size = len(group_to_indices[group_id])
+                split_idx = _choose_split(counts, targets, active)
+                counts[split_idx] += size
+                if split_idx == 0:
+                    train_idx.extend(group_to_indices[group_id])
+                elif split_idx == 1:
+                    val_idx.extend(group_to_indices[group_id])
+                else:
+                    test_idx.extend(group_to_indices[group_id])
+
+            out.loc[train_idx, split_column] = 0
+            out.loc[val_idx, split_column] = 1
+            out.loc[test_idx, split_column] = 2
+            return out.astype({split_column: int})
+
+        rng.shuffle(indices)
+
+        targets = _allocate_split_counts(len(indices), split_probabilities)
+        n_train, n_val, n_test = targets
 
         out.loc[indices[:n_train], split_column] = 0
         out.loc[indices[n_train:n_train + n_val], split_column] = 1
-        out.loc[indices[n_train + n_val:], split_column] = 2
+        out.loc[indices[n_train + n_val:n_train + n_val + n_test], split_column] = 2
 
         return out.astype({split_column: int})
 
-    logger.info("Using stratified random split for train/validation/test sets")
+    if group_column:
+        logger.info(
+            "Using stratified random split for train/validation/test sets (grouped by '%s')",
+            group_column,
+        )
+        rng = np.random.RandomState(random_state)
+
+        group_series = out[group_column].astype(object)
+        missing_mask = group_series.isna()
+        if missing_mask.any():
+            group_series = group_series.copy()
+            group_series.loc[missing_mask] = [
+                f"__missing__{idx}" for idx in group_series.index[missing_mask]
+            ]
+
+        group_to_indices = {}
+        for idx, group_id in group_series.items():
+            group_to_indices.setdefault(group_id, []).append(idx)
+
+        group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
+        group_labels = {}
+        mixed_groups = []
+        label_series = out[label_column]
 
-    # first split: separate test set
-    train_val_idx, test_idx = train_test_split(
-        out.index.tolist(),
-        test_size=split_probabilities[2],
-        random_state=random_state,
-        stratify=out[label_column],
-    )
+        for group_id in group_ids:
+            labels = label_series.loc[group_to_indices[group_id]].dropna().unique()
+            if len(labels) == 1:
+                group_labels[group_id] = labels[0]
+            elif len(labels) == 0:
+                group_labels[group_id] = None
+            else:
+                mode_vals = label_series.loc[group_to_indices[group_id]].mode(dropna=True)
+                group_labels[group_id] = mode_vals.iloc[0] if not mode_vals.empty else labels[0]
+                mixed_groups.append(group_id)
+
+        if mixed_groups:
+            logger.warning(
+                "Detected %d groups with multiple labels; using the most common label per group for stratification.",
+                len(mixed_groups),
+            )
+
+        train_idx = []
+        val_idx = []
+        test_idx = []
+        active = [i for i, p in enumerate(split_probabilities) if p > 0]
+
+        for label_value in sorted(label_counts.index.tolist(), key=lambda x: str(x)):
+            label_groups = [g for g in group_ids if group_labels.get(g) == label_value]
+            if not label_groups:
+                continue
+            rng.shuffle(label_groups)
+            label_total = sum(len(group_to_indices[g]) for g in label_groups)
+            targets = _allocate_split_counts(label_total, split_probabilities)
+            counts = [0 for _ in split_probabilities]
 
-    # second split: separate training and validation from remaining data
-    val_size_adjusted = split_probabilities[1] / (
-        split_probabilities[0] + split_probabilities[1]
-    )
-    train_idx, val_idx = train_test_split(
-        train_val_idx,
-        test_size=val_size_adjusted,
-        random_state=random_state,
-        stratify=out.loc[train_val_idx, label_column] if label_column and label_column in out.columns else None,
-    )
+            for group_id in label_groups:
+                size = len(group_to_indices[group_id])
+                split_idx = _choose_split(counts, targets, active)
+                counts[split_idx] += size
+                if split_idx == 0:
+                    train_idx.extend(group_to_indices[group_id])
+                elif split_idx == 1:
+                    val_idx.extend(group_to_indices[group_id])
+                else:
+                    test_idx.extend(group_to_indices[group_id])
+
+        # Assign groups without a label (or missing labels) using overall targets.
+        unlabeled_groups = [g for g in group_ids if group_labels.get(g) is None]
+        if unlabeled_groups:
+            rng.shuffle(unlabeled_groups)
+            total_unlabeled = sum(len(group_to_indices[g]) for g in unlabeled_groups)
+            targets = _allocate_split_counts(total_unlabeled, split_probabilities)
+            counts = [0 for _ in split_probabilities]
+            for group_id in unlabeled_groups:
+                size = len(group_to_indices[group_id])
+                split_idx = _choose_split(counts, targets, active)
+                counts[split_idx] += size
+                if split_idx == 0:
+                    train_idx.extend(group_to_indices[group_id])
+                elif split_idx == 1:
+                    val_idx.extend(group_to_indices[group_id])
+                else:
+                    test_idx.extend(group_to_indices[group_id])
+
+        out.loc[train_idx, split_column] = 0
+        out.loc[val_idx, split_column] = 1
+        out.loc[test_idx, split_column] = 2
+
+        logger.info("Successfully applied stratified random split")
+        logger.info(
+            f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}"
+        )
+        return out.astype({split_column: int})
+
+    logger.info("Using stratified random split for train/validation/test sets (per-class allocation)")
+
+    rng = np.random.RandomState(random_state)
+    label_values = sorted(label_counts.index.tolist(), key=lambda x: str(x))
+    train_idx = []
+    val_idx = []
+    test_idx = []
+
+    for label_value in label_values:
+        label_indices = out.index[out[label_column] == label_value].tolist()
+        rng.shuffle(label_indices)
+        n_train, n_val, n_test = _allocate_split_counts(len(label_indices), split_probabilities)
+        train_idx.extend(label_indices[:n_train])
+        val_idx.extend(label_indices[n_train:n_train + n_val])
+        test_idx.extend(label_indices[n_train + n_val:n_train + n_val + n_test])
 
     # assign split values
     out.loc[train_idx, split_column] = 0
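
For a sense of how the largest-remainder allocation behaves, here is a self-contained restatement of the rounding scheme with worked examples (it mirrors the _allocate_split_counts helper added above; note that the per-split minimum of one row can nudge exact proportions for small totals):

import numpy as np


def allocate_split_counts(n_total: int, probs: list) -> list:
    # Largest-remainder rounding: exact integer counts summing to n_total,
    # with every non-zero split receiving at least one row when possible.
    if n_total <= 0:
        return [0 for _ in probs]
    counts = [0 for _ in probs]
    active = [i for i, p in enumerate(probs) if p > 0]
    remainder = n_total
    if active and n_total >= len(active):
        for i in active:
            counts[i] = 1
        remainder -= len(active)
    if remainder > 0:
        probs_arr = np.array(probs, dtype=float)
        probs_arr = probs_arr / probs_arr.sum()
        raw = remainder * probs_arr
        floors = np.floor(raw).astype(int)
        for i, value in enumerate(floors.tolist()):
            counts[i] += value
        leftover = remainder - int(floors.sum())
        if leftover > 0 and active:
            frac = raw - floors
            order = sorted(active, key=lambda i: (-frac[i], i))
            for i in range(leftover):
                counts[order[i % len(order)]] += 1
    return counts


assert allocate_split_counts(3, [0.7, 0.1, 0.2]) == [1, 1, 1]    # one row per active split
assert allocate_split_counts(10, [0.7, 0.1, 0.2]) == [6, 2, 2]   # minimum-one rule shifts small totals
assert sum(allocate_split_counts(101, [0.7, 0.1, 0.2])) == 101   # counts always sum exactly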