comparison ludwig_backend.py @ 16:8729f69e9207 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
author goeckslab
date Wed, 03 Dec 2025 01:28:52 +0000
parents d17e3a1b8659
children
comparison
equal deleted inserted replaced
15:d17e3a1b8659 16:8729f69e9207
1 import json 1 import json
2 import logging 2 import logging
3 import os
3 from pathlib import Path 4 from pathlib import Path
4 from typing import Any, Dict, Optional, Protocol, Tuple 5 from typing import Any, Dict, Optional, Protocol, Tuple
5 6
6 import pandas as pd 7 import pandas as pd
7 import pandas.api.types as ptypes 8 import pandas.api.types as ptypes
160 custom_model = raw_encoder["custom_model"] 161 custom_model = raw_encoder["custom_model"]
161 else: 162 else:
162 custom_model = model_name 163 custom_model = model_name
163 164
164 logger.info(f"DETECTED MetaFormer model: {custom_model}") 165 logger.info(f"DETECTED MetaFormer model: {custom_model}")
166 # Stash the model name for patched Stacked2DCNN in case Ludwig drops custom_model from kwargs
167 try:
168 from MetaFormer.metaformer_stacked_cnn import set_current_metaformer_model
169
170 set_current_metaformer_model(custom_model)
171 except Exception:
172 logger.debug("Could not set current MetaFormer model hint; proceeding without global override")
173 # Also pass via environment to survive process boundaries (e.g., Ray workers)
174 os.environ["GLEAM_META_FORMER_MODEL"] = custom_model
165 cfg_channels, cfg_height, cfg_width = 3, 224, 224 175 cfg_channels, cfg_height, cfg_width = 3, 224, 224
176 model_cfg = {}
166 if META_DEFAULT_CFGS: 177 if META_DEFAULT_CFGS:
167 model_cfg = META_DEFAULT_CFGS.get(custom_model, {}) 178 model_cfg = META_DEFAULT_CFGS.get(custom_model, {})
168 input_size = model_cfg.get("input_size") 179 input_size = model_cfg.get("input_size")
169 if isinstance(input_size, (list, tuple)) and len(input_size) == 3: 180 if isinstance(input_size, (list, tuple)) and len(input_size) == 3:
170 cfg_channels, cfg_height, cfg_width = ( 181 cfg_channels, cfg_height, cfg_width = (
171 int(input_size[0]), 182 int(input_size[0]),
172 int(input_size[1]), 183 int(input_size[1]),
173 int(input_size[2]), 184 int(input_size[2]),
174 ) 185 )
175 186
176 target_height, target_width = cfg_height, cfg_width 187 weights_url = None
188 if isinstance(model_cfg, dict):
189 weights_url = model_cfg.get("url")
190 logger.info(
191 "MetaFormer cfg lookup: model=%s has_cfg=%s url=%s use_pretrained=%s",
192 custom_model,
193 bool(model_cfg),
194 weights_url,
195 use_pretrained,
196 )
197 if use_pretrained and not weights_url:
198 logger.warning(
199 "MetaFormer pretrained requested for %s but no URL found in default cfgs; model will be randomly initialized",
200 custom_model,
201 )
202
177 resize_value = config_params.get("image_resize") 203 resize_value = config_params.get("image_resize")
178 if resize_value and resize_value != "original": 204 if resize_value and resize_value != "original":
179 try: 205 try:
180 dimensions = resize_value.split("x") 206 dimensions = resize_value.split("x")
181 if len(dimensions) == 2: 207 if len(dimensions) == 2:
196 ) 222 )
197 target_height, target_width = cfg_height, cfg_width 223 target_height, target_width = cfg_height, cfg_width
198 else: 224 else:
199 image_zip_path = config_params.get("image_zip", "") 225 image_zip_path = config_params.get("image_zip", "")
200 detected_height, detected_width = self._detect_image_dimensions(image_zip_path) 226 detected_height, detected_width = self._detect_image_dimensions(image_zip_path)
201 if use_pretrained: 227 target_height, target_width = detected_height, detected_width
202 if (detected_height, detected_width) != (cfg_height, cfg_width): 228 if use_pretrained and (detected_height, detected_width) != (cfg_height, cfg_width):
203 logger.info( 229 logger.info(
204 "MetaFormer pretrained weights expect %sx%s; resizing from detected %sx%s", 230 "MetaFormer pretrained weights expect %sx%s; proceeding with detected %sx%s",
205 cfg_height, 231 cfg_height,
206 cfg_width, 232 cfg_width,
207 detected_height, 233 detected_height,
208 detected_width, 234 detected_width,
209 ) 235 )
210 else:
211 target_height, target_width = detected_height, detected_width
212 if target_height <= 0 or target_width <= 0: 236 if target_height <= 0 or target_width <= 0:
213 raise ValueError( 237 raise ValueError(
214 f"Invalid detected image dimensions for MetaFormer: {target_height}x{target_width}." 238 f"Invalid detected image dimensions for MetaFormer: {target_height}x{target_width}."
215 ) 239 )
216 240