goeckslab / image_learner (Mercurial repository)
comparison: MetaFormer/metaformer_stacked_cnn.py @ 11:c5150cceab47 (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225

author     | goeckslab
date       | Sat, 18 Oct 2025 03:17:09 +0000
comparison | 10:b0d893d04d4c -> 11:c5150cceab47
import logging
import os
import sys
from typing import Dict, List, Optional

import torch
import torch.nn as nn

sys.path.insert(0, os.path.dirname(__file__))

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)

SUPPORTED_PREFIXES = (
    'identityformer_',
    'randformer_',
    'poolformerv2_',
    'convformer_',
    'caformer_',
)

try:
    from metaformer_models import default_cfgs as META_DEFAULT_CFGS
    META_MODELS_AVAILABLE = True
    logger.info("MetaFormer models imported successfully")
except Exception as e:
    META_MODELS_AVAILABLE = False
    logger.warning(f"MetaFormer models not available: {e}")


def _resolve_metaformer_ctor(model_name: str):
    # Prefer getattr to avoid importing every factory explicitly
    try:
        # Import the module itself for dynamic access
        import metaformer_models
        _factories = metaformer_models.__dict__
        if model_name in _factories and callable(_factories[model_name]):
            return _factories[model_name]
    except Exception:
        pass
    return None
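# Illustrative note (not exercised by the tool flow itself): a call such as
# _resolve_metaformer_ctor("identityformer_s12") returns the
# metaformer_models.identityformer_s12 factory if that attribute exists and is
# callable, and None otherwise; which names are present depends on the bundled
# metaformer_models module.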


class MetaFormerStackedCNN(nn.Module):
    def __init__(
        self,
        height: int = 224,
        width: int = 224,
        num_channels: int = 3,
        output_size: int = 128,
        custom_model: str = "identityformer_s12",
        use_pretrained: bool = True,
        trainable: bool = True,
        conv_layers: Optional[List[Dict]] = None,
        num_conv_layers: Optional[int] = None,
        conv_activation: str = "relu",
        conv_dropout: float = 0.0,
        conv_norm: Optional[str] = None,
        conv_use_bias: bool = True,
        fc_layers: Optional[List[Dict]] = None,
        num_fc_layers: int = 1,
        fc_activation: str = "relu",
        fc_dropout: float = 0.0,
        fc_norm: Optional[str] = None,
        fc_use_bias: bool = True,
        **kwargs,
    ):
        super().__init__()
        logger.info("MetaFormerStackedCNN encoder instantiated")
        logger.info(f"Using MetaFormer model: {custom_model}")

        try:
            height = int(height)
            width = int(width)
            num_channels = int(num_channels)
        except (TypeError, ValueError) as exc:
            raise ValueError("MetaFormerStackedCNN requires integer height, width, and num_channels.") from exc

        if height <= 0 or width <= 0:
            raise ValueError(f"MetaFormerStackedCNN received non-positive dimensions: {height}x{width}.")
        if num_channels <= 0:
            raise ValueError(f"MetaFormerStackedCNN requires num_channels > 0, received {num_channels}.")

        self.height = height
        self.width = width
        self.num_channels = num_channels
        self.output_size = output_size
        self.custom_model = custom_model
        self.use_pretrained = use_pretrained
        self.trainable = trainable

        cfg = META_DEFAULT_CFGS.get(custom_model, {})
        input_size = cfg.get('input_size', (3, 224, 224))
        if isinstance(input_size, (list, tuple)) and len(input_size) == 3:
            expected_channels, expected_height, expected_width = input_size
        else:
            expected_channels, expected_height, expected_width = 3, 224, 224

        self.expected_channels = expected_channels
        self.expected_height = expected_height
        self.expected_width = expected_width

        logger.info(f"Initializing MetaFormerStackedCNN with model: {custom_model}")
        logger.info(
            "Input: %sx%sx%s -> Output: %s (expected backbone size: %sx%s)",
            num_channels,
            height,
            width,
            output_size,
            self.expected_height,
            self.expected_width,
        )

        self.channel_adapter: Optional[nn.Conv2d] = None
        if num_channels != self.expected_channels:
            self.channel_adapter = nn.Conv2d(
                num_channels, self.expected_channels, kernel_size=1, stride=1, padding=0
            )
            logger.info(
                "Added channel adapter: %s -> %s channels",
                num_channels,
                self.expected_channels,
            )

        self.size_adapter: Optional[nn.Module] = None
        if height != self.expected_height or width != self.expected_width:
            self.size_adapter = nn.AdaptiveAvgPool2d((height, width))
            logger.info(
                "Configured size adapter to requested input: %sx%s",
                height,
                width,
            )
        self.backbone_adapter: Optional[nn.Module] = None

        self.backbone = self._load_metaformer_backbone()
        self.feature_dim = self._get_feature_dim()

        self.fc_layers = self._create_fc_layers(
            input_dim=self.feature_dim,
            output_dim=output_size,
            num_layers=num_fc_layers,
            activation=fc_activation,
            dropout=fc_dropout,
            norm=fc_norm,
            use_bias=fc_use_bias,
            fc_layers_config=fc_layers,
        )

        if not trainable:
            for param in self.backbone.parameters():
                param.requires_grad = False
            logger.info("MetaFormer backbone frozen (trainable=False)")

        logger.info("MetaFormerStackedCNN initialized successfully")

    def _load_metaformer_backbone(self):
        if not META_MODELS_AVAILABLE:
            raise ImportError("MetaFormer models are not available")

        ctor = _resolve_metaformer_ctor(self.custom_model)
        if ctor is None:
            raise ValueError(f"Unknown MetaFormer model: {self.custom_model}")

        cfg = META_DEFAULT_CFGS.get(self.custom_model, {})
        weights_url = cfg.get('url')
        # track loading
        self._pretrained_loaded = False
        self._loaded_weights_url: Optional[str] = None
        if self.use_pretrained and weights_url:
            print(f"LOADING MetaFormer pretrained weights from: {weights_url}")
            logger.info(f"Loading pretrained weights from: {weights_url}")
        # Ensure we log whenever the factories call torch.hub.load_state_dict_from_url
        orig_loader = getattr(torch.hub, 'load_state_dict_from_url', None)

        def _wrapped_loader(url, *args, **kwargs):
            print(f"DOWNLOADING weights from: {url}")
            logger.info(f"DOWNLOADING weights from: {url}")
            self._pretrained_loaded = True
            self._loaded_weights_url = url
            result = orig_loader(url, *args, **kwargs)
            print(f"WEIGHTS DOWNLOADED successfully from: {url}")
            return result
        try:
            if self.use_pretrained and orig_loader is not None:
                torch.hub.load_state_dict_from_url = _wrapped_loader  # type: ignore[attr-defined]
            print(f"CREATING MetaFormer model: {self.custom_model} (pretrained={self.use_pretrained})")
            try:
                model = ctor(pretrained=self.use_pretrained, num_classes=1000)
                print(f"MetaFormer model CREATED: {self.custom_model}")
            except Exception as model_error:
                if self.use_pretrained:
                    print(f"⚠ Warning: Failed to load {self.custom_model} with pretrained weights: {model_error}")
                    print("Attempting to load without pretrained weights as fallback...")
                    logger.warning(f"Failed to load {self.custom_model} with pretrained weights: {model_error}")
                    model = ctor(pretrained=False, num_classes=1000)
                    print(f"✓ Successfully loaded {self.custom_model} without pretrained weights")
                    self.use_pretrained = False  # Update state to reflect actual loading
                else:
                    raise model_error
        finally:
            if orig_loader is not None:
                torch.hub.load_state_dict_from_url = orig_loader  # type: ignore[attr-defined]
        self._metaformer_weights_url = weights_url
        if self.use_pretrained:
            if self._pretrained_loaded:
                print(f"MetaFormer: pretrained weights loaded from {self._loaded_weights_url}")
                logger.info(f"MetaFormer: pretrained weights loaded from {self._loaded_weights_url}")
            else:
                # Warn but don't fail - weights may have failed to load but model creation succeeded
                print("⚠ Warning: MetaFormer pretrained weights were requested but not confirmed as loaded")
                logger.warning("MetaFormer: pretrained weights were requested but not confirmed as loaded")
        else:
            print(f"MetaFormer: using randomly initialized weights for {self.custom_model}")
            logger.info(f"MetaFormer: using randomly initialized weights for {self.custom_model}")
        logger.info(f"Loaded MetaFormer backbone: {self.custom_model} (pretrained={self.use_pretrained})")
        return model

    def _get_feature_dim(self):
        with torch.no_grad():
            dummy_input = torch.randn(1, 3, 224, 224)
            features = self.backbone.forward_features(dummy_input)
            feature_dim = features.shape[-1]
        logger.info(f"MetaFormer feature dimension: {feature_dim}")
        return feature_dim

    def _create_fc_layers(self, input_dim, output_dim, num_layers, activation, dropout, norm, use_bias, fc_layers_config):
        layers = []
        if fc_layers_config:
            current_dim = input_dim
            for i, layer_config in enumerate(fc_layers_config):
                layer_output_dim = layer_config.get('output_size', output_dim if i == len(fc_layers_config) - 1 else current_dim)
                layers.append(nn.Linear(current_dim, layer_output_dim, bias=use_bias))
                if i < len(fc_layers_config) - 1:
                    if activation == "relu":
                        layers.append(nn.ReLU())
                    elif activation == "tanh":
                        layers.append(nn.Tanh())
                    elif activation == "sigmoid":
                        layers.append(nn.Sigmoid())
                    elif activation == "leaky_relu":
                        layers.append(nn.LeakyReLU())
                    if dropout > 0:
                        layers.append(nn.Dropout(dropout))
                    if norm == "batch":
                        layers.append(nn.BatchNorm1d(layer_output_dim))
                    elif norm == "layer":
                        layers.append(nn.LayerNorm(layer_output_dim))
                current_dim = layer_output_dim
        else:
            if num_layers == 1:
                layers.append(nn.Linear(input_dim, output_dim, bias=use_bias))
            else:
                intermediate_dims = [input_dim]
                for i in range(num_layers - 1):
                    intermediate_dim = int(input_dim * (0.5 ** (i + 1)))
                    intermediate_dim = max(intermediate_dim, output_dim)
                    intermediate_dims.append(intermediate_dim)
                intermediate_dims.append(output_dim)
                for i in range(num_layers):
                    layers.append(nn.Linear(intermediate_dims[i], intermediate_dims[i + 1], bias=use_bias))
                    if i < num_layers - 1:
                        if activation == "relu":
                            layers.append(nn.ReLU())
                        elif activation == "tanh":
                            layers.append(nn.Tanh())
                        elif activation == "sigmoid":
                            layers.append(nn.Sigmoid())
                        elif activation == "leaky_relu":
                            layers.append(nn.LeakyReLU())
                        if dropout > 0:
                            layers.append(nn.Dropout(dropout))
                        if norm == "batch":
                            layers.append(nn.BatchNorm1d(intermediate_dims[i + 1]))
                        elif norm == "layer":
                            layers.append(nn.LayerNorm(intermediate_dims[i + 1]))
        return nn.Sequential(*layers)
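    # Default sizing sketch (illustrative numbers, not fixed values): with
    # input_dim=512, output_dim=128 and num_layers=3, the halving path above
    # builds intermediate_dims [512, 256, 128, 128], i.e.
    # Linear(512, 256) -> Linear(256, 128) -> Linear(128, 128), inserting the
    # activation/dropout/norm blocks after every layer except the last.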

    def forward(self, x):
        if x.shape[1] != self.expected_channels:
            if (
                self.channel_adapter is None
                or self.channel_adapter.in_channels != x.shape[1]
                or self.channel_adapter.out_channels != self.expected_channels
            ):
                self.channel_adapter = nn.Conv2d(
                    x.shape[1],
                    self.expected_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                ).to(x.device)
                logger.info(
                    "Created dynamic channel adapter: %s -> %s channels",
                    x.shape[1],
                    self.expected_channels,
                )
            x = self.channel_adapter(x)

        target_height, target_width = self.height, self.width
        if x.shape[2] != target_height or x.shape[3] != target_width:
            if (
                self.size_adapter is None
                or getattr(self.size_adapter, "output_size", None)
                != (target_height, target_width)
            ):
                self.size_adapter = nn.AdaptiveAvgPool2d(
                    (target_height, target_width)
                ).to(x.device)
                logger.info(
                    "Created size adapter: %sx%s -> %sx%s",
                    x.shape[2],
                    x.shape[3],
                    target_height,
                    target_width,
                )
            x = self.size_adapter(x)

        if target_height != self.expected_height or target_width != self.expected_width:
            if (
                self.backbone_adapter is None
                or getattr(self.backbone_adapter, "output_size", None)
                != (self.expected_height, self.expected_width)
            ):
                self.backbone_adapter = nn.AdaptiveAvgPool2d(
                    (self.expected_height, self.expected_width)
                ).to(x.device)
                logger.info(
                    "Aligning to MetaFormer backbone size: %sx%s",
                    self.expected_height,
                    self.expected_width,
                )
            x = self.backbone_adapter(x)

        features = self.backbone.forward_features(x)
        output = self.fc_layers(features)
        return {'encoder_output': output}

    @property
    def output_shape(self):
        return [self.output_size]


def create_metaformer_stacked_cnn(model_name: str, **kwargs) -> MetaFormerStackedCNN:
    encoder = MetaFormerStackedCNN(custom_model=model_name, **kwargs)
    return encoder


def patch_ludwig_stacked_cnn():
    # Only patch Ludwig if MetaFormer models are available in this runtime
    if not META_MODELS_AVAILABLE:
        logger.warning("MetaFormer models unavailable; skipping Ludwig patch for stacked_cnn.")
        return False
    return patch_ludwig_direct()


def _is_supported_metaformer(custom_model: Optional[str]) -> bool:
    return bool(custom_model) and custom_model.startswith(SUPPORTED_PREFIXES)
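# Illustrative note: _is_supported_metaformer("caformer_s18") is True, while
# _is_supported_metaformer("resnet50") and _is_supported_metaformer(None) are False.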


def patch_ludwig_direct():
    try:
        from ludwig.encoders.image.base import Stacked2DCNN
        original_stacked_cnn_init = Stacked2DCNN.__init__

        def patched_stacked_cnn_init(self, *args, **kwargs):
            custom_model = kwargs.pop("custom_model", None)
            if custom_model is None:
                custom_model = getattr(patch_ludwig_direct, '_metaformer_model', None)

            try:
                if META_MODELS_AVAILABLE and _is_supported_metaformer(custom_model):
                    print(f"DETECTED MetaFormer model: {custom_model}")
                    print("MetaFormer encoder is being loaded and used.")
                    # Initialize base class to keep Ludwig internals intact
                    original_stacked_cnn_init(self, *args, **kwargs)
                    # Create our MetaFormer encoder and graft behavior
                    mf_encoder = create_metaformer_stacked_cnn(custom_model, **kwargs)
                    # ensure base attributes won't be used accidentally
                    for attr in ("conv_layers", "fc_layers", "combiner", "output_shape", "reduce_output"):
                        if hasattr(self, attr):
                            try:
                                setattr(self, attr, getattr(mf_encoder, attr, None))
                            except Exception:
                                pass
                    self.forward = mf_encoder.forward
                    if hasattr(mf_encoder, 'backbone'):
                        self.backbone = mf_encoder.backbone
                    if hasattr(mf_encoder, 'fc_layers'):
                        self.fc_layers = mf_encoder.fc_layers
                    if hasattr(mf_encoder, 'custom_model'):
                        self.custom_model = mf_encoder.custom_model
                    # explicit confirmation logs
                    try:
                        url_info = getattr(mf_encoder, '_loaded_weights_url', None)
                        loaded_flag = getattr(mf_encoder, '_pretrained_loaded', False)
                        if loaded_flag and url_info:
                            print(f"CONFIRMED: MetaFormer '{custom_model}' using pretrained weights from: {url_info}")
                            logger.info(f"CONFIRMED: MetaFormer '{custom_model}' using pretrained weights from: {url_info}")
                        else:
                            print(f"CONFIRMED: MetaFormer '{custom_model}' using randomly initialized weights (no pretrained)")
                            logger.info(f"CONFIRMED: MetaFormer '{custom_model}' using randomly initialized weights")
                    except Exception:
                        pass
                else:
                    original_stacked_cnn_init(self, *args, **kwargs)
            finally:
                if hasattr(patch_ludwig_direct, '_metaformer_model'):
                    patch_ludwig_direct._metaformer_model = None

        Stacked2DCNN.__init__ = patched_stacked_cnn_init
        return True
    except Exception as e:
        logger.error(f"Failed to apply MetaFormer direct patch: {e}")
        return False


def set_current_metaformer_model(model_name: str):
    """Store the current MetaFormer model name for the patch to use."""
    setattr(patch_ludwig_direct, '_metaformer_model', model_name)


def clear_current_metaformer_model():
    """Remove any cached MetaFormer model hint."""
    if hasattr(patch_ludwig_direct, '_metaformer_model'):
        delattr(patch_ludwig_direct, '_metaformer_model')
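

# Illustrative usage sketch (an assumption for documentation purposes, not part
# of the committed tool flow): it exercises the helpers above directly, applying
# the Ludwig patch and running a dummy forward pass with randomly initialized
# weights so no download is required.
if __name__ == "__main__":
    if not META_MODELS_AVAILABLE:
        logger.error("metaformer_models is not importable; skipping demo.")
        sys.exit(1)

    # Route Ludwig's Stacked2DCNN through the MetaFormer encoder; this returns
    # False if Ludwig itself is not installed in the current environment.
    patched = patch_ludwig_stacked_cnn()
    logger.info(f"Ludwig patch applied: {patched}")

    # A caller would normally register the requested variant before Ludwig
    # constructs its encoders, then clear the hint afterwards.
    set_current_metaformer_model("identityformer_s12")
    clear_current_metaformer_model()

    # Build the encoder directly and check the output shape.
    encoder = create_metaformer_stacked_cnn(
        "identityformer_s12",
        height=224,
        width=224,
        num_channels=3,
        output_size=128,
        use_pretrained=False,
    )
    out = encoder(torch.randn(2, 3, 224, 224))["encoder_output"]
    logger.info(f"Encoder output shape: {tuple(out.shape)}")  # expected (2, 128)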