diff multimodal_learner.py @ 3:25bb80df7c0c draft

planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
author goeckslab
date Sat, 17 Jan 2026 22:53:42 +0000
parents 375c36923da1
children de753cf07008
--- a/multimodal_learner.py	Sat Jan 10 16:13:19 2026 +0000
+++ b/multimodal_learner.py	Sat Jan 17 22:53:42 2026 +0000
@@ -68,6 +68,7 @@
     parser.add_argument("--validation_size", type=float, default=0.2)
     parser.add_argument("--split_probabilities", type=float, nargs=3,
                         default=[0.7, 0.1, 0.2], metavar=("train", "val", "test"))
+    parser.add_argument("--sample_id_column", default=None)
     parser.add_argument("--preset", choices=["medium_quality", "high_quality", "best_quality"],
                         default="medium_quality")
     parser.add_argument("--eval_metric", default="roc_auc")
@@ -103,7 +104,35 @@
     except Exception:
         use_stratified = False
 
-    kf = StratifiedKFold(n_splits=int(args.num_folds), shuffle=True, random_state=int(args.random_seed)) if use_stratified else KFold(n_splits=int(args.num_folds), shuffle=True, random_state=int(args.random_seed))
+    if args.sample_id_column and args.sample_id_column in df_full.columns:
+        groups = df_full[args.sample_id_column]
+        if use_stratified:
+            try:
+                from sklearn.model_selection import StratifiedGroupKFold
+
+                kf = StratifiedGroupKFold(
+                    n_splits=int(args.num_folds),
+                    shuffle=True,
+                    random_state=int(args.random_seed),
+                )
+            except Exception as exc:
+                logger.warning(
+                    "StratifiedGroupKFold unavailable (%s); falling back to GroupKFold.",
+                    exc,
+                )
+                from sklearn.model_selection import GroupKFold
+
+                kf = GroupKFold(n_splits=int(args.num_folds))
+                use_stratified = False
+        else:
+            from sklearn.model_selection import GroupKFold
+
+            kf = GroupKFold(n_splits=int(args.num_folds))
+    else:
+        kf = StratifiedKFold(n_splits=int(args.num_folds), shuffle=True, random_state=int(args.random_seed)) if use_stratified else KFold(n_splits=int(args.num_folds), shuffle=True, random_state=int(args.random_seed))
+
+    if args.sample_id_column and test_dataset is not None:
+        test_dataset = test_dataset.drop(columns=[args.sample_id_column], errors="ignore")
 
     raw_folds = []
     ag_folds = []
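The hunk above selects a group-aware splitter so that all rows sharing a sample id land in the same fold, preferring StratifiedGroupKFold and falling back to GroupKFold when it is unavailable. A self-contained sketch of that selection logic on toy data; the column names and values here are illustrative and not taken from the tool:

    import pandas as pd
    from sklearn.model_selection import GroupKFold

    df_full = pd.DataFrame({
        "sample_id": ["p1", "p1", "p2", "p2", "p3", "p3", "p4", "p4"],
        "feature":   [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
        "label":     [0, 0, 1, 1, 0, 0, 1, 1],
    })
    y = df_full["label"]
    groups = df_full["sample_id"]
    n_splits, seed, use_stratified = 2, 42, True

    try:
        # StratifiedGroupKFold was added in scikit-learn 1.0.
        from sklearn.model_selection import StratifiedGroupKFold
        kf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    except ImportError:
        kf = GroupKFold(n_splits=n_splits)
        use_stratified = False

    for fold, (tr, va) in enumerate(kf.split(df_full, y if use_stratified else None, groups), start=1):
        # No sample id appears on both sides of the split.
        overlap = set(groups.iloc[tr]) & set(groups.iloc[va])
        print(f"fold {fold}: val groups={sorted(set(groups.iloc[va]))}, overlap={overlap}")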
@@ -111,10 +140,18 @@
     last_predictor = None
     last_data_ctx = None
 
-    for fold_idx, (train_idx, val_idx) in enumerate(kf.split(df_full, y if use_stratified else None), start=1):
+    if args.sample_id_column and args.sample_id_column in df_full.columns:
+        split_iter = kf.split(df_full, y if use_stratified else None, groups)
+    else:
+        split_iter = kf.split(df_full, y if use_stratified else None)
+
+    for fold_idx, (train_idx, val_idx) in enumerate(split_iter, start=1):
         logger.info(f"CV fold {fold_idx}/{args.num_folds}")
         df_tr = df_full.iloc[train_idx].copy()
         df_va = df_full.iloc[val_idx].copy()
+        if args.sample_id_column:
+            df_tr = df_tr.drop(columns=[args.sample_id_column], errors="ignore")
+            df_va = df_va.drop(columns=[args.sample_id_column], errors="ignore")
 
         df_tr["split"] = "train"
         df_va["split"] = "val"
@@ -252,6 +289,7 @@
         split_probabilities=args.split_probabilities,
         validation_size=args.validation_size,
         random_seed=args.random_seed,
+        sample_id_column=args.sample_id_column,
     )
 
     logger.info("Preprocessing complete — ready for AutoGluon training!")
@@ -335,6 +373,11 @@
             "fit_summary": None,
         }
     else:
+        # Drop sample-id column before training so it does not leak into modeling.
+        if args.sample_id_column:
+            train_dataset = train_dataset.drop(columns=[args.sample_id_column], errors="ignore")
+            if test_dataset is not None:
+                test_dataset = test_dataset.drop(columns=[args.sample_id_column], errors="ignore")
         predictor, data_ctx = run_autogluon_experiment(
             train_dataset=train_dataset,
             test_dataset=test_dataset,
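The comment about leakage in the final hunk is the motivation for all of the drops: a per-sample identifier can act as a shortcut feature that the model memorises. A toy illustration of that failure mode (not the tool's model); the lookup "model" and the data are invented for the example:

    import pandas as pd

    train = pd.DataFrame({"sample_id": ["a", "b", "c"], "label": [0, 1, 0]})
    # A degenerate "model" that just memorises id -> label.
    lookup = dict(zip(train["sample_id"], train["label"]))
    # Perfect on the training ids...
    print(all(lookup[s] == l for s, l in zip(train["sample_id"], train["label"])))  # True
    # ...but it has nothing to say about an unseen id.
    print(lookup.get("d"))  # None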