changeset 20:64872c48a21f draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit d4b122527a2402e43512f9b4bda00c7bff0ec9e9
| Field | Value |
|---|---|
| author | goeckslab |
| date | Tue, 06 Jan 2026 15:35:11 +0000 |
| parents | c460abae83eb |
| children | |
| files | image_learner.xml image_learner_cli.py image_workflow.py ludwig_backend.py split_data.py |
| diffstat | 5 files changed, 298 insertions(+), 58 deletions(-) |
--- a/image_learner.xml  Thu Dec 18 16:59:58 2025 +0000
+++ b/image_learner.xml  Tue Jan 06 15:35:11 2026 +0000
@@ -29,6 +29,7 @@
         ln -sf '$input_csv' "./${sanitized_input_csv}";
     #end if
 
+    #set $selected_validation_metric = ""
     #if $task_selection.task == "binary"
         #set $selected_validation_metric = $task_selection.validation_metric_binary
     #elif $task_selection.task == "classification"
@@ -36,7 +37,9 @@
     #elif $task_selection.task == "regression"
         #set $selected_validation_metric = $task_selection.validation_metric_regression
     #else
-        #set $selected_validation_metric = None
+        #if $task_selection.validation_metric_auto
+            #set $selected_validation_metric = $task_selection.validation_metric_auto
+        #end if
     #end if
 
     python '$__tool_directory__/image_learner_cli.py'
@@ -81,6 +84,9 @@
             --image-column "$column_override.image_column"
         #end if
     #end if
+    #if $sample_id_column
+        --sample-id-column "$sample_id_column"
+    #end if
     --image-resize "$image_resize"
    --random-seed "$random_seed"
    --output-dir "." &&
@@ -103,26 +109,26 @@
            <option value="regression">Regression</option>
        </param>
        <when value="binary">
-            <param name="validation_metric_binary" type="select" optional="true" label="Validation metric (binary)" help="Metrics accepted by Ludwig for binary outputs.">
-                <option value="roc_auc" selected="true">ROC-AUC</option>
+            <param name="validation_metric_binary" type="select" optional="true" label="Validation metric (binary)" help="Metrics accepted by Ludwig 0.10.1 for binary outputs.">
+                <option value="" selected="true">Auto (use task default)</option>
+                <option value="roc_auc">ROC-AUC</option>
                <option value="accuracy">Accuracy</option>
                <option value="precision">Precision</option>
                <option value="recall">Recall</option>
-                <option value="specificity">Specificity</option>
                <option value="loss">Loss</option>
            </param>
        </when>
        <when value="classification">
-            <param name="validation_metric_multiclass" type="select" optional="true" label="Validation metric (multi-class)" help="Metrics accepted by Ludwig for multi-class outputs.">
-                <option value="accuracy" selected="true">Accuracy</option>
-                <option value="balanced_accuracy">Balanced Accuracy</option>
-                <option value="hits_at_k">Hits at K (top-k)</option>
+            <param name="validation_metric_multiclass" type="select" optional="true" label="Validation metric (multi-class)" help="Metrics accepted by Ludwig 0.10.1 for multi-class outputs.">
+                <option value="" selected="true">Auto (use task default)</option>
+                <option value="accuracy">Accuracy</option>
                <option value="loss">Loss</option>
            </param>
        </when>
        <when value="regression">
-            <param name="validation_metric_regression" type="select" optional="true" label="Validation metric (regression)" help="Metrics accepted by Ludwig for regression outputs.">
-                <option value="mean_squared_error" selected="true">Mean Squared Error</option>
+            <param name="validation_metric_regression" type="select" optional="true" label="Validation metric (regression)" help="Metrics accepted by Ludwig 0.10.1 for regression outputs.">
+                <option value="" selected="true">Auto (use task default)</option>
+                <option value="mean_squared_error">Mean Squared Error</option>
                <option value="mean_absolute_error">Mean Absolute Error</option>
                <option value="root_mean_squared_error">Root Mean Squared Error</option>
                <option value="root_mean_squared_percentage_error">Root Mean Squared Percentage Error</option>
@@ -130,7 +136,9 @@
            </param>
        </when>
        <when value="auto">
-            <!-- No validation metric selection; tool will infer task and metric. -->
+            <param name="validation_metric_auto" type="select" optional="true" label="Validation metric (auto)" help="Auto defers to the inferred task and picks the best default metric; use this only to override the choice.">
+                <option value="" selected="true"></option>
+            </param>
        </when>
    </conditional>
    <conditional name="column_override">
@@ -139,13 +147,14 @@
            <option value="true">Yes</option>
        </param>
        <when value="true">
-            <param name="target_column" type="text" optional="true" label="Target/label column name" help="Overrides the default 'label' column name in the metadata CSV." />
-            <param name="image_column" type="text" optional="true" label="Image column name" help="Overrides the default 'image_path' column name in the metadata CSV." />
+            <param name="target_column" type="data_column" data_ref="input_csv" use_header_names="true" optional="true" label="Target/label column name" help="Overrides the default 'label' column name in the metadata CSV." />
+            <param name="image_column" type="data_column" data_ref="input_csv" use_header_names="true" optional="true" label="Image column name" help="Overrides the default 'image_path' column name in the metadata CSV." />
        </when>
        <when value="false">
            <!-- No additional parameters -->
        </when>
    </conditional>
+    <param name="sample_id_column" type="data_column" data_ref="input_csv" use_header_names="true" optional="true" label="Sample ID column (optional)" help="Optional column used to group samples during splitting to prevent data leakage (e.g., patient_id or slide_id). Only used when no split column is provided." />
 
    <param name="model_name" type="select" optional="false" label="Select a model for your experiment" >
        <option value="resnet18">Resnet18</option>
@@ -564,7 +573,7 @@
        <param name="advanced_settings|customize_defaults" value="true" />
        <param name="advanced_settings|threshold" value="0.6" />
        <param name="task_selection|task" value="classification" />
-        <param name="task_selection|validation_metric_multiclass" value="balanced_accuracy" />
+        <param name="task_selection|validation_metric_multiclass" value="accuracy" />
        <output name="output_report">
            <assert_contents>
                <has_text text="Config and Overall Performance Summary" />
@@ -591,6 +600,7 @@
 The metadata csv should contain a column with the name 'image_path' and a column with the name 'label'.
 Optionally, you can also add a column with the name 'split' to specify which split each row belongs to (train, val, test).
 If you do not provide a split column, the tool will automatically split the data into train, val, and test sets based on the proportions you specify or [0.7, 0.1, 0.2] by default.
+You can optionally specify a sample ID column to keep related samples in the same split and prevent data leakage.
 
 **Models Available**
 This tool supports a wide range of state-of-the-art image classification models including:
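For readers unfamiliar with the expected input, here is a minimal sketch of a metadata CSV that exercises the new sample ID option. The 'image_path' and 'label' column names are the tool's documented defaults; 'patient_id' is only an illustrative name for the optional sample ID column, not something the tool requires.

```python
# Hypothetical metadata file for the image_learner tool. Rows that share a
# patient_id are kept in the same train/val/test split when that column is
# passed as the sample ID column.
import pandas as pd

metadata = pd.DataFrame(
    {
        "image_path": ["images/p01_a.png", "images/p01_b.png", "images/p02_a.png"],
        "label": ["tumor", "tumor", "normal"],
        "patient_id": ["p01", "p01", "p02"],  # illustrative grouping column
    }
)
metadata.to_csv("metadata.csv", index=False)
```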
--- a/image_learner_cli.py  Thu Dec 18 16:59:58 2025 +0000
+++ b/image_learner_cli.py  Tue Jan 06 15:35:11 2026 +0000
@@ -163,6 +163,15 @@
         default=None,
         help="Name of the image column in the metadata file (defaults to 'image_path').",
     )
+    parser.add_argument(
+        "--sample-id-column",
+        type=str,
+        default=None,
+        help=(
+            "Optional column name used to group samples during splitting "
+            "to prevent data leakage (e.g., patient_id or slide_id)."
+        ),
+    )
 
     args = parser.parse_args()
--- a/image_workflow.py  Thu Dec 18 16:59:58 2025 +0000
+++ b/image_workflow.py  Tue Jan 06 15:35:11 2026 +0000
@@ -168,6 +168,7 @@
                     split_probabilities=self.args.split_probabilities,
                     random_state=self.args.random_seed,
                     label_column=LABEL_COLUMN_NAME,
+                    group_column=self.args.sample_id_column,
                 )
                 split_config = {
                     "type": "fixed",
@@ -178,6 +179,11 @@
                     f"{[int(p * 100) for p in self.args.split_probabilities]}% "
                     f"for train/val/test with balanced label distribution."
                 )
+                if self.args.sample_id_column:
+                    split_info += (
+                        f" Grouped by sample ID column '{self.args.sample_id_column}' "
+                        "to prevent data leakage."
+                    )
 
             final_csv = self.temp_dir / TEMP_CSV_FILENAME
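Once the grouped split has been written, a quick pandas check can confirm that no sample ID leaks across splits, which is what the wiring above is meant to guarantee. This is a sketch under assumptions: the output file name, the 'split' column (0/1/2 for train/val/test, as assigned in split_data.py), and the 'patient_id' column are placeholders to substitute with whatever your run actually produces.

```python
# Leakage sanity check: every group should map to exactly one split value.
import pandas as pd

df = pd.read_csv("metadata_with_split.csv")  # hypothetical output path
splits_per_group = df.groupby("patient_id")["split"].nunique()
leaky = splits_per_group[splits_per_group > 1]
if leaky.empty:
    print("No group spans more than one split.")
else:
    print("Groups appearing in multiple splits:", leaky.index.tolist())
```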
--- a/ludwig_backend.py  Thu Dec 18 16:59:58 2025 +0000
+++ b/ludwig_backend.py  Tue Jan 06 15:35:11 2026 +0000
@@ -416,7 +416,8 @@
             "binary": "roc_auc",
             "category": "accuracy",
         }
-        allowed_map = {
+        # Safe defaults when Ludwig's registry isn't available.
+        safe_allowed_map = {
             "regression": {
                 "mean_absolute_error",
                 "mean_squared_error",
@@ -429,13 +430,10 @@
                 "accuracy",
                 "precision",
                 "recall",
-                "specificity",
                 "loss",
             },
             "category": {
                 "accuracy",
-                "balanced_accuracy",
-                "hits_at_k",
                 "loss",
             },
         }
@@ -472,6 +470,16 @@
                 )
                 if isinstance(metrics_attr, dict):
                     registry_metrics = set(metrics_attr.keys())
+                elif isinstance(metrics_attr, (list, tuple, set)):
+                    extracted = set()
+                    for item in metrics_attr:
+                        if isinstance(item, str):
+                            extracted.add(item)
+                        elif hasattr(item, "name"):
+                            extracted.add(str(item.name))
+                        elif hasattr(item, "__name__"):
+                            extracted.add(str(item.__name__))
+                    registry_metrics = extracted or None
         except Exception as exc:
             logger.debug(
                 "Could not inspect Ludwig metrics for output type %s: %s",
@@ -479,12 +487,10 @@
                 exc,
             )
 
-        allowed = set(allowed_map.get(task, set()))
-        if registry_metrics:
-            # Only keep metrics that Ludwig actually exposes for this output type;
-            # if the intersection is empty, fall back to the registry set.
-            intersected = allowed.intersection(registry_metrics)
-            allowed = intersected or registry_metrics
+        allowed = set(safe_allowed_map.get(task, set()))
+        if registry_metrics is not None:
+            # Use Ludwig's registry when available; fall back to safe defaults if it's empty.
+            allowed = registry_metrics or allowed
 
         if allowed and metric not in allowed:
             fallback_candidates = [
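The registry inspection added above accepts several shapes for a metrics registry entry: a dict keyed by metric name, or a list/tuple/set of strings or classes. Here is a standalone, non-Ludwig sketch of that name-extraction pattern; RocAucMetric is a stand-in class for illustration, not a real Ludwig symbol.

```python
# Pull metric names from whatever shape the registry entry happens to have.
def extract_metric_names(metrics_attr):
    if isinstance(metrics_attr, dict):
        return set(metrics_attr.keys())
    if isinstance(metrics_attr, (list, tuple, set)):
        names = set()
        for item in metrics_attr:
            if isinstance(item, str):
                names.add(item)
            elif hasattr(item, "name"):        # objects exposing a .name attribute
                names.add(str(item.name))
            elif hasattr(item, "__name__"):    # plain classes or functions
                names.add(str(item.__name__))
        return names or None
    return None


class RocAucMetric:  # stand-in for a metric class
    pass


print(extract_metric_names({"accuracy": object(), "loss": object()}))  # {'accuracy', 'loss'}
print(extract_metric_names(["roc_auc", RocAucMetric]))                 # {'roc_auc', 'RocAucMetric'}
```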
--- a/split_data.py  Thu Dec 18 16:59:58 2025 +0000
+++ b/split_data.py  Tue Jan 06 15:35:11 2026 +0000
@@ -81,6 +81,7 @@
     split_probabilities: list = [0.7, 0.1, 0.2],
     random_state: int = 42,
     label_column: Optional[str] = None,
+    group_column: Optional[str] = None,
 ) -> pd.DataFrame:
     """Create a stratified random split when no split column exists."""
     out = df.copy()
@@ -88,22 +89,103 @@
     # initialize split column
     out[split_column] = 0
 
+    if group_column and group_column not in out.columns:
+        logger.warning(
+            "Group column '%s' not found in data; proceeding without group-aware split.",
+            group_column,
+        )
+        group_column = None
+
+    def _allocate_split_counts(n_total: int, probs: list) -> list:
+        """Allocate exact split counts using largest remainder rounding."""
+        if n_total <= 0:
+            return [0 for _ in probs]
+
+        counts = [0 for _ in probs]
+        active = [i for i, p in enumerate(probs) if p > 0]
+        remainder = n_total
+
+        if active and n_total >= len(active):
+            for i in active:
+                counts[i] = 1
+            remainder -= len(active)
+
+        if remainder > 0:
+            probs_arr = np.array(probs, dtype=float)
+            probs_arr = probs_arr / probs_arr.sum()
+            raw = remainder * probs_arr
+            floors = np.floor(raw).astype(int)
+            for i, value in enumerate(floors.tolist()):
+                counts[i] += value
+            leftover = remainder - int(floors.sum())
+            if leftover > 0 and active:
+                frac = raw - floors
+                order = sorted(active, key=lambda i: (-frac[i], i))
+                for i in range(leftover):
+                    counts[order[i % len(order)]] += 1
+
+        return counts
+
+    def _choose_split(counts: list, targets: list, active: list) -> int:
+        remaining = [targets[i] - counts[i] for i in range(len(targets))]
+        best = max(active, key=lambda i: (remaining[i], -counts[i], -targets[i]))
+        if remaining[best] <= 0:
+            best = min(active, key=lambda i: counts[i])
+        return best
+
     if not label_column or label_column not in out.columns:
         logger.warning(
             "No label column found; using random split without stratification"
         )
-        # fall back to simple random assignment
+        # fall back to random assignment (group-aware if requested)
         indices = out.index.tolist()
-        np.random.seed(random_state)
-        np.random.shuffle(indices)
+        rng = np.random.RandomState(random_state)
+
+        if group_column:
+            group_series = out[group_column].astype(object)
+            missing_mask = group_series.isna()
+            if missing_mask.any():
+                group_series = group_series.copy()
+                group_series.loc[missing_mask] = [
+                    f"__missing__{idx}" for idx in group_series.index[missing_mask]
+                ]
+            group_to_indices = {}
+            for idx, group_id in group_series.items():
+                group_to_indices.setdefault(group_id, []).append(idx)
+            group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
+            rng.shuffle(group_ids)
 
-        n_total = len(indices)
-        n_train = int(n_total * split_probabilities[0])
-        n_val = int(n_total * split_probabilities[1])
+            targets = _allocate_split_counts(len(indices), split_probabilities)
+            counts = [0 for _ in split_probabilities]
+            active = [i for i, p in enumerate(split_probabilities) if p > 0]
+            train_idx = []
+            val_idx = []
+            test_idx = []
+
+            for group_id in group_ids:
+                size = len(group_to_indices[group_id])
+                split_idx = _choose_split(counts, targets, active)
+                counts[split_idx] += size
+                if split_idx == 0:
+                    train_idx.extend(group_to_indices[group_id])
+                elif split_idx == 1:
+                    val_idx.extend(group_to_indices[group_id])
+                else:
+                    test_idx.extend(group_to_indices[group_id])
+
+            out.loc[train_idx, split_column] = 0
+            out.loc[val_idx, split_column] = 1
+            out.loc[test_idx, split_column] = 2
+            return out.astype({split_column: int})
+
+        rng.shuffle(indices)
+
+        targets = _allocate_split_counts(len(indices), split_probabilities)
+        n_train, n_val, n_test = targets
 
         out.loc[indices[:n_train], split_column] = 0
         out.loc[indices[n_train:n_train + n_val], split_column] = 1
-        out.loc[indices[n_train + n_val:], split_column] = 2
+        out.loc[indices[n_train + n_val:n_train + n_val + n_test], split_column] = 2
 
         return out.astype({split_column: int})
 
@@ -112,48 +194,175 @@
     min_samples_per_class = label_counts.min()
 
     # ensure we have enough samples for stratification:
-    # Each class must have at least as many samples as the number of splits,
+    # Each class must have at least as many samples as the number of nonzero splits,
     # so that each split can receive at least one sample per class.
-    min_samples_required = len(split_probabilities)
+    active_splits = [p for p in split_probabilities if p > 0]
+    min_samples_required = len(active_splits)
 
     if min_samples_per_class < min_samples_required:
         logger.warning(
             f"Insufficient samples per class for stratification (min: {min_samples_per_class}, required: {min_samples_required}); using random split"
         )
-        # fall back to simple random assignment
+        # fall back to simple random assignment (group-aware if requested)
         indices = out.index.tolist()
-        np.random.seed(random_state)
-        np.random.shuffle(indices)
+        rng = np.random.RandomState(random_state)
+
+        if group_column:
+            group_series = out[group_column].astype(object)
+            missing_mask = group_series.isna()
+            if missing_mask.any():
+                group_series = group_series.copy()
+                group_series.loc[missing_mask] = [
+                    f"__missing__{idx}" for idx in group_series.index[missing_mask]
+                ]
+            group_to_indices = {}
+            for idx, group_id in group_series.items():
+                group_to_indices.setdefault(group_id, []).append(idx)
+            group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
+            rng.shuffle(group_ids)
 
-        n_total = len(indices)
-        n_train = int(n_total * split_probabilities[0])
-        n_val = int(n_total * split_probabilities[1])
+            targets = _allocate_split_counts(len(indices), split_probabilities)
+            counts = [0 for _ in split_probabilities]
+            active = [i for i, p in enumerate(split_probabilities) if p > 0]
+            train_idx = []
+            val_idx = []
+            test_idx = []
+
+            for group_id in group_ids:
+                size = len(group_to_indices[group_id])
+                split_idx = _choose_split(counts, targets, active)
+                counts[split_idx] += size
+                if split_idx == 0:
+                    train_idx.extend(group_to_indices[group_id])
+                elif split_idx == 1:
+                    val_idx.extend(group_to_indices[group_id])
+                else:
+                    test_idx.extend(group_to_indices[group_id])
+
+            out.loc[train_idx, split_column] = 0
+            out.loc[val_idx, split_column] = 1
+            out.loc[test_idx, split_column] = 2
+            return out.astype({split_column: int})
+
+        rng.shuffle(indices)
+
+        targets = _allocate_split_counts(len(indices), split_probabilities)
+        n_train, n_val, n_test = targets
 
         out.loc[indices[:n_train], split_column] = 0
         out.loc[indices[n_train:n_train + n_val], split_column] = 1
-        out.loc[indices[n_train + n_val:], split_column] = 2
+        out.loc[indices[n_train + n_val:n_train + n_val + n_test], split_column] = 2
 
         return out.astype({split_column: int})
 
-    logger.info("Using stratified random split for train/validation/test sets")
+    if group_column:
+        logger.info(
+            "Using stratified random split for train/validation/test sets (grouped by '%s')",
+            group_column,
+        )
+        rng = np.random.RandomState(random_state)
+
+        group_series = out[group_column].astype(object)
+        missing_mask = group_series.isna()
+        if missing_mask.any():
+            group_series = group_series.copy()
+            group_series.loc[missing_mask] = [
+                f"__missing__{idx}" for idx in group_series.index[missing_mask]
+            ]
+
+        group_to_indices = {}
+        for idx, group_id in group_series.items():
+            group_to_indices.setdefault(group_id, []).append(idx)
+
+        group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
+        group_labels = {}
+        mixed_groups = []
+        label_series = out[label_column]
 
-    # first split: separate test set
-    train_val_idx, test_idx = train_test_split(
-        out.index.tolist(),
-        test_size=split_probabilities[2],
-        random_state=random_state,
-        stratify=out[label_column],
-    )
+        for group_id in group_ids:
+            labels = label_series.loc[group_to_indices[group_id]].dropna().unique()
+            if len(labels) == 1:
+                group_labels[group_id] = labels[0]
+            elif len(labels) == 0:
+                group_labels[group_id] = None
+            else:
+                mode_vals = label_series.loc[group_to_indices[group_id]].mode(dropna=True)
+                group_labels[group_id] = mode_vals.iloc[0] if not mode_vals.empty else labels[0]
+                mixed_groups.append(group_id)
+
+        if mixed_groups:
+            logger.warning(
+                "Detected %d groups with multiple labels; using the most common label per group for stratification.",
+                len(mixed_groups),
+            )
+
+        train_idx = []
+        val_idx = []
+        test_idx = []
+        active = [i for i, p in enumerate(split_probabilities) if p > 0]
+
+        for label_value in sorted(label_counts.index.tolist(), key=lambda x: str(x)):
+            label_groups = [g for g in group_ids if group_labels.get(g) == label_value]
+            if not label_groups:
+                continue
+            rng.shuffle(label_groups)
+            label_total = sum(len(group_to_indices[g]) for g in label_groups)
+            targets = _allocate_split_counts(label_total, split_probabilities)
+            counts = [0 for _ in split_probabilities]
 
-    # second split: separate training and validation from remaining data
-    val_size_adjusted = split_probabilities[1] / (
-        split_probabilities[0] + split_probabilities[1]
-    )
-    train_idx, val_idx = train_test_split(
-        train_val_idx,
-        test_size=val_size_adjusted,
-        random_state=random_state,
-        stratify=out.loc[train_val_idx, label_column] if label_column and label_column in out.columns else None,
-    )
+            for group_id in label_groups:
+                size = len(group_to_indices[group_id])
+                split_idx = _choose_split(counts, targets, active)
+                counts[split_idx] += size
+                if split_idx == 0:
+                    train_idx.extend(group_to_indices[group_id])
+                elif split_idx == 1:
+                    val_idx.extend(group_to_indices[group_id])
+                else:
+                    test_idx.extend(group_to_indices[group_id])
+
+        # Assign groups without a label (or missing labels) using overall targets.
+        unlabeled_groups = [g for g in group_ids if group_labels.get(g) is None]
+        if unlabeled_groups:
+            rng.shuffle(unlabeled_groups)
+            total_unlabeled = sum(len(group_to_indices[g]) for g in unlabeled_groups)
+            targets = _allocate_split_counts(total_unlabeled, split_probabilities)
+            counts = [0 for _ in split_probabilities]
+            for group_id in unlabeled_groups:
+                size = len(group_to_indices[group_id])
+                split_idx = _choose_split(counts, targets, active)
+                counts[split_idx] += size
+                if split_idx == 0:
+                    train_idx.extend(group_to_indices[group_id])
+                elif split_idx == 1:
+                    val_idx.extend(group_to_indices[group_id])
+                else:
+                    test_idx.extend(group_to_indices[group_id])
+
+        out.loc[train_idx, split_column] = 0
+        out.loc[val_idx, split_column] = 1
+        out.loc[test_idx, split_column] = 2
+
+        logger.info("Successfully applied stratified random split")
+        logger.info(
+            f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}"
+        )
+        return out.astype({split_column: int})
+
+    logger.info("Using stratified random split for train/validation/test sets (per-class allocation)")
+
+    rng = np.random.RandomState(random_state)
+    label_values = sorted(label_counts.index.tolist(), key=lambda x: str(x))
+    train_idx = []
+    val_idx = []
+    test_idx = []
+
+    for label_value in label_values:
+        label_indices = out.index[out[label_column] == label_value].tolist()
+        rng.shuffle(label_indices)
+        n_train, n_val, n_test = _allocate_split_counts(len(label_indices), split_probabilities)
+        train_idx.extend(label_indices[:n_train])
+        val_idx.extend(label_indices[n_train:n_train + n_val])
+        test_idx.extend(label_indices[n_train + n_val:n_train + n_val + n_test])
 
     # assign split values
     out.loc[train_idx, split_column] = 0
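To illustrate the allocation strategy introduced in split_data.py, here is a standalone sketch mirroring the logic of `_allocate_split_counts`: every split with a nonzero probability is guaranteed at least one sample, and the remainder is distributed by largest-remainder rounding so the counts always sum exactly to the total. The function name and the printed numbers below are from this sketch, not from the tool's output.

```python
# Largest-remainder allocation with a guaranteed minimum per active split.
import numpy as np


def allocate_split_counts(n_total: int, probs: list) -> list:
    counts = [0] * len(probs)
    active = [i for i, p in enumerate(probs) if p > 0]
    remainder = n_total
    if active and n_total >= len(active):
        for i in active:          # guarantee one sample per nonzero split
            counts[i] = 1
        remainder -= len(active)
    if remainder > 0:
        p = np.array(probs, dtype=float)
        raw = remainder * (p / p.sum())
        floors = np.floor(raw).astype(int)
        for i, v in enumerate(floors.tolist()):
            counts[i] += v
        leftover = remainder - int(floors.sum())
        if leftover > 0 and active:
            frac = raw - floors   # hand leftovers to the largest fractions first
            order = sorted(active, key=lambda i: (-frac[i], i))
            for i in range(leftover):
                counts[order[i % len(order)]] += 1
    return counts


print(allocate_split_counts(10, [0.7, 0.1, 0.2]))  # [6, 2, 2], sums to 10
print(allocate_split_counts(3, [0.7, 0.1, 0.2]))   # [1, 1, 1], every split gets one
```

The guaranteed-minimum step is what lets small datasets still populate the validation and test splits, at the cost of deviating slightly from the requested proportions.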
