Comparison: multimodal_learner.py @ 3:25bb80df7c0c (draft)
planemo upload for repository https://github.com/goeckslab/gleam.git commit f0daa5846b336584c708d88f6d7f1b5ee8dc3093
| author | goeckslab |
|---|---|
| date | Sat, 17 Jan 2026 22:53:42 +0000 |
| parents | 375c36923da1 |
| children | de753cf07008 |
| 2:b708d0e210e6 (old) | 3:25bb80df7c0c (new) |
|---|---|
| 66 parser.add_argument("--backbone_image", type=str, default="swin_base_patch4_window7_224") | 66 parser.add_argument("--backbone_image", type=str, default="swin_base_patch4_window7_224") |
| 67 parser.add_argument("--backbone_text", type=str, default="microsoft/deberta-v3-base") | 67 parser.add_argument("--backbone_text", type=str, default="microsoft/deberta-v3-base") |
| 68 parser.add_argument("--validation_size", type=float, default=0.2) | 68 parser.add_argument("--validation_size", type=float, default=0.2) |
| 69 parser.add_argument("--split_probabilities", type=float, nargs=3, | 69 parser.add_argument("--split_probabilities", type=float, nargs=3, |
| 70 default=[0.7, 0.1, 0.2], metavar=("train", "val", "test")) | 70 default=[0.7, 0.1, 0.2], metavar=("train", "val", "test")) |
| 71 parser.add_argument("--sample_id_column", default=None) | |
| 71 parser.add_argument("--preset", choices=["medium_quality", "high_quality", "best_quality"], | 72 parser.add_argument("--preset", choices=["medium_quality", "high_quality", "best_quality"], |
| 72 default="medium_quality") | 73 default="medium_quality") |
| 73 parser.add_argument("--eval_metric", default="roc_auc") | 74 parser.add_argument("--eval_metric", default="roc_auc") |
| 74 parser.add_argument("--hyperparameters", default=None) | 75 parser.add_argument("--hyperparameters", default=None) |
| 75 | 76 |
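
For orientation, here is a minimal, hypothetical sketch (not part of the diff) of how the new `--sample_id_column` option sits next to the cross-validation options it interacts with; the extra defaults and the example argv are illustrative only:

```python
import argparse

# Stand-in parser showing only the option added in this revision plus two
# related options it interacts with; defaults here are illustrative.
parser = argparse.ArgumentParser()
parser.add_argument("--sample_id_column", default=None,
                    help="Optional column that groups rows (e.g. a patient ID); "
                         "rows sharing a value stay in the same CV fold.")
parser.add_argument("--num_folds", type=int, default=5)
parser.add_argument("--random_seed", type=int, default=42)

# Hypothetical invocation: group cross-validation folds by a 'patient_id' column.
args = parser.parse_args(["--sample_id_column", "patient_id"])
print(args.sample_id_column)  # -> patient_id
```
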
| 2:b708d0e210e6 (old) | 3:25bb80df7c0c (new) |
|---|---|
| 101 try: | 102 try: |
| 102 use_stratified = y.dtype == object or y.nunique() <= 20 | 103 use_stratified = y.dtype == object or y.nunique() <= 20 |
| 103 except Exception: | 104 except Exception: |
| 104 use_stratified = False | 105 use_stratified = False |
| 105 | 106 |
| 106 kf = StratifiedKFold(n_splits=int(args.num_folds), shuffle=True, random_state=int(args.random_seed)) if use_stratified else KFold(n_splits=int(args.num_folds), shuffle=True, random_state=int(args.random_seed)) | 107 if args.sample_id_column and args.sample_id_column in df_full.columns: |
| | 108 groups = df_full[args.sample_id_column] |
| | 109 if use_stratified: |
| | 110 try: |
| | 111 from sklearn.model_selection import StratifiedGroupKFold |
| | 112 |
| | 113 kf = StratifiedGroupKFold( |
| | 114 n_splits=int(args.num_folds), |
| | 115 shuffle=True, |
| | 116 random_state=int(args.random_seed), |
| | 117 ) |
| | 118 except Exception as exc: |
| | 119 logger.warning( |
| | 120 "StratifiedGroupKFold unavailable (%s); falling back to GroupKFold.", |
| | 121 exc, |
| | 122 ) |
| | 123 from sklearn.model_selection import GroupKFold |
| | 124 |
| | 125 kf = GroupKFold(n_splits=int(args.num_folds)) |
| | 126 use_stratified = False |
| | 127 else: |
| | 128 from sklearn.model_selection import GroupKFold |
| | 129 |
| | 130 kf = GroupKFold(n_splits=int(args.num_folds)) |
| | 131 else: |
| | 132 kf = StratifiedKFold(n_splits=int(args.num_folds), shuffle=True, random_state=int(args.random_seed)) if use_stratified else KFold(n_splits=int(args.num_folds), shuffle=True, random_state=int(args.random_seed)) |
| | 133 |
| | 134 if args.sample_id_column and test_dataset is not None: |
| | 135 test_dataset = test_dataset.drop(columns=[args.sample_id_column], errors="ignore") |
| 107 | 136 |
| 108 raw_folds = [] | 137 raw_folds = [] |
| 109 ag_folds = [] | 138 ag_folds = [] |
| 110 folds_info = [] | 139 folds_info = [] |
| 111 last_predictor = None | 140 last_predictor = None |
| 112 last_data_ctx = None | 141 last_data_ctx = None |
| 113 | 142 |
| 114 for fold_idx, (train_idx, val_idx) in enumerate(kf.split(df_full, y if use_stratified else None), start=1): | 143 if args.sample_id_column and args.sample_id_column in df_full.columns: |
| | 144 split_iter = kf.split(df_full, y if use_stratified else None, groups) |
| | 145 else: |
| | 146 split_iter = kf.split(df_full, y if use_stratified else None) |
| | 147 |
| | 148 for fold_idx, (train_idx, val_idx) in enumerate(split_iter, start=1): |
| 115 logger.info(f"CV fold {fold_idx}/{args.num_folds}") | 149 logger.info(f"CV fold {fold_idx}/{args.num_folds}") |
| 116 df_tr = df_full.iloc[train_idx].copy() | 150 df_tr = df_full.iloc[train_idx].copy() |
| 117 df_va = df_full.iloc[val_idx].copy() | 151 df_va = df_full.iloc[val_idx].copy() |
| | 152 if args.sample_id_column: |
| | 153 df_tr = df_tr.drop(columns=[args.sample_id_column], errors="ignore") |
| | 154 df_va = df_va.drop(columns=[args.sample_id_column], errors="ignore") |
| 118 | 155 |
| 119 df_tr["split"] = "train" | 156 df_tr["split"] = "train" |
| 120 df_va["split"] = "val" | 157 df_va["split"] = "val" |
| 121 fold_dataset = pd.concat([df_tr, df_va], ignore_index=True) | 158 fold_dataset = pd.concat([df_tr, df_va], ignore_index=True) |
| 122 | 159 |
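
Pulled out of the diff context, the new splitter selection reduces to the sketch below: when a sample-id column is present it prefers scikit-learn's StratifiedGroupKFold (class-balanced folds that never split a group) and falls back to GroupKFold when that class is unavailable; without a sample-id column the original StratifiedKFold/KFold choice is kept. This is a standalone illustration with assumed inputs (`df_full`, `y`, `num_folds`, `seed`, `use_stratified`), not the tool's exact code:

```python
import logging
from typing import Optional

import pandas as pd
from sklearn.model_selection import KFold, StratifiedKFold

logger = logging.getLogger(__name__)


def make_splitter(df_full: pd.DataFrame, y: pd.Series, sample_id_column: Optional[str],
                  num_folds: int, seed: int, use_stratified: bool):
    """Return (splitter, groups, use_stratified) following the logic added in this revision."""
    groups = None
    if sample_id_column and sample_id_column in df_full.columns:
        groups = df_full[sample_id_column]
        if use_stratified:
            try:
                # Keeps class proportions while never placing one group in two folds.
                from sklearn.model_selection import StratifiedGroupKFold
                kf = StratifiedGroupKFold(n_splits=num_folds, shuffle=True, random_state=seed)
            except Exception as exc:  # older scikit-learn builds without StratifiedGroupKFold
                logger.warning("StratifiedGroupKFold unavailable (%s); falling back to GroupKFold.", exc)
                from sklearn.model_selection import GroupKFold
                kf = GroupKFold(n_splits=num_folds)
                use_stratified = False
        else:
            from sklearn.model_selection import GroupKFold
            kf = GroupKFold(n_splits=num_folds)
    else:
        kf = (StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=seed)
              if use_stratified
              else KFold(n_splits=num_folds, shuffle=True, random_state=seed))
    return kf, groups, use_stratified


# Usage mirroring the fold loop in the diff: groups are passed to split() only when they exist.
# kf, groups, use_stratified = make_splitter(df_full, y, args.sample_id_column,
#                                            int(args.num_folds), int(args.random_seed), use_stratified)
# split_iter = (kf.split(df_full, y if use_stratified else None, groups)
#               if groups is not None
#               else kf.split(df_full, y if use_stratified else None))
```
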
| 2:b708d0e210e6 (old) | 3:25bb80df7c0c (new) |
|---|---|
| 250 test_dataset=test_dataset, | 287 test_dataset=test_dataset, |
| 251 target_column=args.target_column, | 288 target_column=args.target_column, |
| 252 split_probabilities=args.split_probabilities, | 289 split_probabilities=args.split_probabilities, |
| 253 validation_size=args.validation_size, | 290 validation_size=args.validation_size, |
| 254 random_seed=args.random_seed, | 291 random_seed=args.random_seed, |
| | 292 sample_id_column=args.sample_id_column, |
| 255 ) | 293 ) |
| 256 | 294 |
| 257 logger.info("Preprocessing complete — ready for AutoGluon training!") | 295 logger.info("Preprocessing complete — ready for AutoGluon training!") |
| 258 logger.info(f"Final split counts:\n{train_dataset['split'].value_counts().sort_index()}") | 296 logger.info(f"Final split counts:\n{train_dataset['split'].value_counts().sort_index()}") |
| 259 | 297 |
| 333 "raw_metrics": raw_metrics, | 371 "raw_metrics": raw_metrics, |
| 334 "ag_eval": ag_by_split, | 372 "ag_eval": ag_by_split, |
| 335 "fit_summary": None, | 373 "fit_summary": None, |
| 336 } | 374 } |
| 337 else: | 375 else: |
| | 376 # Drop sample-id column before training so it does not leak into modeling. |
| | 377 if args.sample_id_column: |
| | 378 train_dataset = train_dataset.drop(columns=[args.sample_id_column], errors="ignore") |
| | 379 if test_dataset is not None: |
| | 380 test_dataset = test_dataset.drop(columns=[args.sample_id_column], errors="ignore") |
| 338 predictor, data_ctx = run_autogluon_experiment( | 381 predictor, data_ctx = run_autogluon_experiment( |
| 339 train_dataset=train_dataset, | 382 train_dataset=train_dataset, |
| 340 test_dataset=test_dataset, | 383 test_dataset=test_dataset, |
| 341 target_column=args.target_column, | 384 target_column=args.target_column, |
| 342 image_columns=image_cols, | 385 image_columns=image_cols, |
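
The drop-before-training pattern added in this hunk (and in the per-fold code above) can be summarized as a small standalone helper; the helper name, `patient_id` column, and toy data are illustrative, and `errors="ignore"` makes the drop a no-op when the column is already gone:

```python
from typing import Optional

import pandas as pd


def drop_sample_id(df: Optional[pd.DataFrame], sample_id_column: Optional[str]) -> Optional[pd.DataFrame]:
    """Remove the grouping column before fitting so it cannot leak into the model's features."""
    if df is None or not sample_id_column:
        return df
    return df.drop(columns=[sample_id_column], errors="ignore")


# Toy example: the 'patient_id' grouping column is removed, feature and label columns stay.
train = pd.DataFrame({"patient_id": ["a", "a", "b"], "feature": [1, 2, 3], "label": [0, 1, 0]})
print(drop_sample_id(train, "patient_id").columns.tolist())  # ['feature', 'label']
```
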
