Mercurial > repos > goeckslab > image_learner
changeset 21:d5c582cf74bc draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit eed8c1e1d99a8a0c8f3a6bfdd8af48a5bfa19444
| author | goeckslab |
|---|---|
| date | Tue, 20 Jan 2026 01:25:35 +0000 |
| parents | 64872c48a21f |
| children | |
| files | image_learner.xml image_learner_cli.py image_workflow.py |
| diffstat | 3 files changed, 23 insertions(+), 0 deletions(-) [+] |
line wrap: on
line diff
--- a/image_learner.xml Tue Jan 06 15:35:11 2026 +0000 +++ b/image_learner.xml Tue Jan 20 01:25:35 2026 +0000 @@ -28,6 +28,8 @@ #set $sanitized_input_csv = re.sub('[^\w\-_\.]', '_', $input_csv.element_identifier.strip()) ln -sf '$input_csv' "./${sanitized_input_csv}"; #end if + #set $TORCH_HOME = "./torch_cache"; + mkdir -p "$TORCH_HOME"; #set $selected_validation_metric = "" #if $task_selection.task == "binary" @@ -89,6 +91,7 @@ #end if --image-resize "$image_resize" --random-seed "$random_seed" + --torch-home "$TORCH_HOME" --output-dir "." && mkdir -p '$output_model.extra_files_path' &&
--- a/image_learner_cli.py Tue Jan 06 15:35:11 2026 +0000 +++ b/image_learner_cli.py Tue Jan 20 01:25:35 2026 +0000 @@ -172,6 +172,15 @@ "to prevent data leakage (e.g., patient_id or slide_id)." ), ) + parser.add_argument( + "--torch-home", + type=Path, + default=None, + help=( + "Directory for Torch Hub cache (pretrained weights). " + "Overrides TORCH_HOME for this run." + ), + ) args = parser.parse_args()
--- a/image_workflow.py Tue Jan 06 15:35:11 2026 +0000 +++ b/image_workflow.py Tue Jan 20 01:25:35 2026 +0000 @@ -367,6 +367,17 @@ self.args.output_dir.mkdir(parents=True, exist_ok=True) try: + if getattr(self.args, "torch_home", None): + torch_home = Path(self.args.torch_home).expanduser().resolve() + torch_home.mkdir(parents=True, exist_ok=True) + os.environ["TORCH_HOME"] = str(torch_home) + try: + import torch + + torch.hub.set_dir(str(torch_home)) + except Exception as exc: + logger.warning("Unable to set Torch Hub cache dir: %s", exc) + self._create_temp_dirs() self._extract_images() csv_path, split_cfg, split_info = self._prepare_data()
