changeset 21:d5c582cf74bc draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit eed8c1e1d99a8a0c8f3a6bfdd8af48a5bfa19444
author goeckslab
date Tue, 20 Jan 2026 01:25:35 +0000
parents 64872c48a21f
children
files image_learner.xml image_learner_cli.py image_workflow.py
diffstat 3 files changed, 23 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- a/image_learner.xml	Tue Jan 06 15:35:11 2026 +0000
+++ b/image_learner.xml	Tue Jan 20 01:25:35 2026 +0000
@@ -28,6 +28,8 @@
             #set $sanitized_input_csv = re.sub('[^\w\-_\.]', '_', $input_csv.element_identifier.strip())
             ln -sf '$input_csv' "./${sanitized_input_csv}";
             #end if
+            #set $TORCH_HOME = "./torch_cache";
+            mkdir -p "$TORCH_HOME";
 
             #set $selected_validation_metric = ""
             #if $task_selection.task == "binary"
@@ -89,6 +91,7 @@
                 #end if
                 --image-resize "$image_resize"
                 --random-seed "$random_seed"
+                --torch-home "$TORCH_HOME"
                 --output-dir "." &&
 
             mkdir -p '$output_model.extra_files_path' &&
--- a/image_learner_cli.py	Tue Jan 06 15:35:11 2026 +0000
+++ b/image_learner_cli.py	Tue Jan 20 01:25:35 2026 +0000
@@ -172,6 +172,15 @@
             "to prevent data leakage (e.g., patient_id or slide_id)."
         ),
     )
+    parser.add_argument(
+        "--torch-home",
+        type=Path,
+        default=None,
+        help=(
+            "Directory for Torch Hub cache (pretrained weights). "
+            "Overrides TORCH_HOME for this run."
+        ),
+    )
 
     args = parser.parse_args()
 
--- a/image_workflow.py	Tue Jan 06 15:35:11 2026 +0000
+++ b/image_workflow.py	Tue Jan 20 01:25:35 2026 +0000
@@ -367,6 +367,17 @@
         self.args.output_dir.mkdir(parents=True, exist_ok=True)
 
         try:
+            if getattr(self.args, "torch_home", None):
+                torch_home = Path(self.args.torch_home).expanduser().resolve()
+                torch_home.mkdir(parents=True, exist_ok=True)
+                os.environ["TORCH_HOME"] = str(torch_home)
+                try:
+                    import torch
+
+                    torch.hub.set_dir(str(torch_home))
+                except Exception as exc:
+                    logger.warning("Unable to set Torch Hub cache dir: %s", exc)
+
             self._create_temp_dirs()
             self._extract_images()
             csv_path, split_cfg, split_info = self._prepare_data()