comparison multimodal_learner.py @ 3:25bb80df7c0c draft

planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
author goeckslab
date Sat, 17 Jan 2026 22:53:42 +0000
parents 375c36923da1
children de753cf07008
comparison
equal deleted inserted replaced
2:b708d0e210e6 3:25bb80df7c0c
66 parser.add_argument("--backbone_image", type=str, default="swin_base_patch4_window7_224") 66 parser.add_argument("--backbone_image", type=str, default="swin_base_patch4_window7_224")
67 parser.add_argument("--backbone_text", type=str, default="microsoft/deberta-v3-base") 67 parser.add_argument("--backbone_text", type=str, default="microsoft/deberta-v3-base")
68 parser.add_argument("--validation_size", type=float, default=0.2) 68 parser.add_argument("--validation_size", type=float, default=0.2)
69 parser.add_argument("--split_probabilities", type=float, nargs=3, 69 parser.add_argument("--split_probabilities", type=float, nargs=3,
70 default=[0.7, 0.1, 0.2], metavar=("train", "val", "test")) 70 default=[0.7, 0.1, 0.2], metavar=("train", "val", "test"))
71 parser.add_argument("--sample_id_column", default=None)
71 parser.add_argument("--preset", choices=["medium_quality", "high_quality", "best_quality"], 72 parser.add_argument("--preset", choices=["medium_quality", "high_quality", "best_quality"],
72 default="medium_quality") 73 default="medium_quality")
73 parser.add_argument("--eval_metric", default="roc_auc") 74 parser.add_argument("--eval_metric", default="roc_auc")
74 parser.add_argument("--hyperparameters", default=None) 75 parser.add_argument("--hyperparameters", default=None)
75 76
101 try: 102 try:
102 use_stratified = y.dtype == object or y.nunique() <= 20 103 use_stratified = y.dtype == object or y.nunique() <= 20
103 except Exception: 104 except Exception:
104 use_stratified = False 105 use_stratified = False
105 106
106 kf = StratifiedKFold(n_splits=int(args.num_folds), shuffle=True, random_state=int(args.random_seed)) if use_stratified else KFold(n_splits=int(args.num_folds), shuffle=True, random_state=int(args.random_seed)) 107 if args.sample_id_column and args.sample_id_column in df_full.columns:
108 groups = df_full[args.sample_id_column]
109 if use_stratified:
110 try:
111 from sklearn.model_selection import StratifiedGroupKFold
112
113 kf = StratifiedGroupKFold(
114 n_splits=int(args.num_folds),
115 shuffle=True,
116 random_state=int(args.random_seed),
117 )
118 except Exception as exc:
119 logger.warning(
120 "StratifiedGroupKFold unavailable (%s); falling back to GroupKFold.",
121 exc,
122 )
123 from sklearn.model_selection import GroupKFold
124
125 kf = GroupKFold(n_splits=int(args.num_folds))
126 use_stratified = False
127 else:
128 from sklearn.model_selection import GroupKFold
129
130 kf = GroupKFold(n_splits=int(args.num_folds))
131 else:
132 kf = StratifiedKFold(n_splits=int(args.num_folds), shuffle=True, random_state=int(args.random_seed)) if use_stratified else KFold(n_splits=int(args.num_folds), shuffle=True, random_state=int(args.random_seed))
133
134 if args.sample_id_column and test_dataset is not None:
135 test_dataset = test_dataset.drop(columns=[args.sample_id_column], errors="ignore")
107 136
108 raw_folds = [] 137 raw_folds = []
109 ag_folds = [] 138 ag_folds = []
110 folds_info = [] 139 folds_info = []
111 last_predictor = None 140 last_predictor = None
112 last_data_ctx = None 141 last_data_ctx = None
113 142
114 for fold_idx, (train_idx, val_idx) in enumerate(kf.split(df_full, y if use_stratified else None), start=1): 143 if args.sample_id_column and args.sample_id_column in df_full.columns:
144 split_iter = kf.split(df_full, y if use_stratified else None, groups)
145 else:
146 split_iter = kf.split(df_full, y if use_stratified else None)
147
148 for fold_idx, (train_idx, val_idx) in enumerate(split_iter, start=1):
115 logger.info(f"CV fold {fold_idx}/{args.num_folds}") 149 logger.info(f"CV fold {fold_idx}/{args.num_folds}")
116 df_tr = df_full.iloc[train_idx].copy() 150 df_tr = df_full.iloc[train_idx].copy()
117 df_va = df_full.iloc[val_idx].copy() 151 df_va = df_full.iloc[val_idx].copy()
152 if args.sample_id_column:
153 df_tr = df_tr.drop(columns=[args.sample_id_column], errors="ignore")
154 df_va = df_va.drop(columns=[args.sample_id_column], errors="ignore")
118 155
119 df_tr["split"] = "train" 156 df_tr["split"] = "train"
120 df_va["split"] = "val" 157 df_va["split"] = "val"
121 fold_dataset = pd.concat([df_tr, df_va], ignore_index=True) 158 fold_dataset = pd.concat([df_tr, df_va], ignore_index=True)
122 159
250 test_dataset=test_dataset, 287 test_dataset=test_dataset,
251 target_column=args.target_column, 288 target_column=args.target_column,
252 split_probabilities=args.split_probabilities, 289 split_probabilities=args.split_probabilities,
253 validation_size=args.validation_size, 290 validation_size=args.validation_size,
254 random_seed=args.random_seed, 291 random_seed=args.random_seed,
292 sample_id_column=args.sample_id_column,
255 ) 293 )
256 294
257 logger.info("Preprocessing complete — ready for AutoGluon training!") 295 logger.info("Preprocessing complete — ready for AutoGluon training!")
258 logger.info(f"Final split counts:\n{train_dataset['split'].value_counts().sort_index()}") 296 logger.info(f"Final split counts:\n{train_dataset['split'].value_counts().sort_index()}")
259 297
333 "raw_metrics": raw_metrics, 371 "raw_metrics": raw_metrics,
334 "ag_eval": ag_by_split, 372 "ag_eval": ag_by_split,
335 "fit_summary": None, 373 "fit_summary": None,
336 } 374 }
337 else: 375 else:
376 # Drop sample-id column before training so it does not leak into modeling.
377 if args.sample_id_column:
378 train_dataset = train_dataset.drop(columns=[args.sample_id_column], errors="ignore")
379 if test_dataset is not None:
380 test_dataset = test_dataset.drop(columns=[args.sample_id_column], errors="ignore")
338 predictor, data_ctx = run_autogluon_experiment( 381 predictor, data_ctx = run_autogluon_experiment(
339 train_dataset=train_dataset, 382 train_dataset=train_dataset,
340 test_dataset=test_dataset, 383 test_dataset=test_dataset,
341 target_column=args.target_column, 384 target_column=args.target_column,
342 image_columns=image_cols, 385 image_columns=image_cols,