MetaFormer/metaformer_stacked_cnn.py @ 11:c5150cceab47 (draft, default, tip)

planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
author goeckslab
date Sat, 18 Oct 2025 03:17:09 +0000
1 import logging
2 import os
3 import sys
4 from typing import Dict, List, Optional
5
6 import torch
7 import torch.nn as nn
8
9 sys.path.insert(0, os.path.dirname(__file__))
10
11 logging.basicConfig(
12 level=logging.INFO,
13 format="%(asctime)s %(levelname)s %(name)s: %(message)s",
14 )
15 logger = logging.getLogger(__name__)
16
17 SUPPORTED_PREFIXES = (
18 'identityformer_',
19 'randformer_',
20 'poolformerv2_',
21 'convformer_',
22 'caformer_',
23 )
24
25 try:
26 from metaformer_models import default_cfgs as META_DEFAULT_CFGS
27 META_MODELS_AVAILABLE = True
28 logger.info("MetaFormer models imported successfully")
29 except Exception as e:
30 META_MODELS_AVAILABLE = False
31 logger.warning(f"MetaFormer models not available: {e}")
32
33
34 def _resolve_metaformer_ctor(model_name: str):
35 # Look the factory up dynamically in metaformer_models' namespace so we avoid importing every constructor explicitly
36 try:
37 # Import the module itself for dynamic access
38 import metaformer_models
39 _factories = metaformer_models.__dict__
40 if model_name in _factories and callable(_factories[model_name]):
41 return _factories[model_name]
42 except Exception:
43 pass
44 return None
45
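# Illustrative sketch (hypothetical snippet, not executed at import time): the helper above
# resolves a factory by name from metaformer_models' namespace, so a supported model name
# maps to a callable while an unknown name simply yields None, e.g.
#
#     ctor = _resolve_metaformer_ctor("identityformer_s12")
#     if ctor is not None:
#         model = ctor(pretrained=False, num_classes=1000)
#     assert _resolve_metaformer_ctor("not_a_metaformer_name") is None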
46
47 class MetaFormerStackedCNN(nn.Module):
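"""Ludwig-compatible image encoder backed by a MetaFormer backbone.

Wraps one of the supported MetaFormer families (identityformer, randformer,
poolformerv2, convformer, caformer), adds a 1x1-conv channel adapter and
adaptive-pool size adapters when the requested input geometry differs from the
backbone's expected input (typically 3x224x224), and projects the pooled
backbone features through a configurable fully connected stack to `output_size`.
"""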
48 def __init__(
49 self,
50 height: int = 224,
51 width: int = 224,
52 num_channels: int = 3,
53 output_size: int = 128,
54 custom_model: str = "identityformer_s12",
55 use_pretrained: bool = True,
56 trainable: bool = True,
57 conv_layers: Optional[List[Dict]] = None,
58 num_conv_layers: Optional[int] = None,
59 conv_activation: str = "relu",
60 conv_dropout: float = 0.0,
61 conv_norm: Optional[str] = None,
62 conv_use_bias: bool = True,
63 fc_layers: Optional[List[Dict]] = None,
64 num_fc_layers: int = 1,
65 fc_activation: str = "relu",
66 fc_dropout: float = 0.0,
67 fc_norm: Optional[str] = None,
68 fc_use_bias: bool = True,
69 **kwargs,
70 ):
71 super().__init__()
72 logger.info("MetaFormerStackedCNN encoder instantiated")
73 logger.info(f"Using MetaFormer model: {custom_model}")
74
75 try:
76 height = int(height)
77 width = int(width)
78 num_channels = int(num_channels)
79 except (TypeError, ValueError) as exc:
80 raise ValueError("MetaFormerStackedCNN requires integer height, width, and num_channels.") from exc
81
82 if height <= 0 or width <= 0:
83 raise ValueError(f"MetaFormerStackedCNN received non-positive dimensions: {height}x{width}.")
84 if num_channels <= 0:
85 raise ValueError(f"MetaFormerStackedCNN requires num_channels > 0, received {num_channels}.")
86
87 self.height = height
88 self.width = width
89 self.num_channels = num_channels
90 self.output_size = output_size
91 self.custom_model = custom_model
92 self.use_pretrained = use_pretrained
93 self.trainable = trainable
94
95 cfg = META_DEFAULT_CFGS.get(custom_model, {}) if META_MODELS_AVAILABLE else {}
96 input_size = cfg.get('input_size', (3, 224, 224))
97 if isinstance(input_size, (list, tuple)) and len(input_size) == 3:
98 expected_channels, expected_height, expected_width = input_size
99 else:
100 expected_channels, expected_height, expected_width = 3, 224, 224
101
102 self.expected_channels = expected_channels
103 self.expected_height = expected_height
104 self.expected_width = expected_width
105
106 logger.info(f"Initializing MetaFormerStackedCNN with model: {custom_model}")
107 logger.info(
108 "Input: %sx%sx%s -> Output: %s (expected backbone size: %sx%s)",
109 num_channels,
110 height,
111 width,
112 output_size,
113 self.expected_height,
114 self.expected_width,
115 )
116
117 self.channel_adapter: Optional[nn.Conv2d] = None
118 if num_channels != self.expected_channels:
119 self.channel_adapter = nn.Conv2d(
120 num_channels, self.expected_channels, kernel_size=1, stride=1, padding=0
121 )
122 logger.info(
123 "Added channel adapter: %s -> %s channels",
124 num_channels,
125 self.expected_channels,
126 )
127
128 self.size_adapter: Optional[nn.Module] = None
129 if height != self.expected_height or width != self.expected_width:
130 self.size_adapter = nn.AdaptiveAvgPool2d((height, width))
131 logger.info(
132 "Configured size adapter to requested input: %sx%s",
133 height,
134 width,
135 )
136 self.backbone_adapter: Optional[nn.Module] = None
137
138 self.backbone = self._load_metaformer_backbone()
139 self.feature_dim = self._get_feature_dim()
140
141 self.fc_layers = self._create_fc_layers(
142 input_dim=self.feature_dim,
143 output_dim=output_size,
144 num_layers=num_fc_layers,
145 activation=fc_activation,
146 dropout=fc_dropout,
147 norm=fc_norm,
148 use_bias=fc_use_bias,
149 fc_layers_config=fc_layers,
150 )
151
152 if not trainable:
153 for param in self.backbone.parameters():
154 param.requires_grad = False
155 logger.info("MetaFormer backbone frozen (trainable=False)")
156
157 logger.info("MetaFormerStackedCNN initialized successfully")
158
159 def _load_metaformer_backbone(self):
160 if not META_MODELS_AVAILABLE:
161 raise ImportError("MetaFormer models are not available")
162
163 ctor = _resolve_metaformer_ctor(self.custom_model)
164 if ctor is None:
165 raise ValueError(f"Unknown MetaFormer model: {self.custom_model}")
166
167 cfg = META_DEFAULT_CFGS.get(self.custom_model, {})
168 weights_url = cfg.get('url')
169 # Track whether pretrained weights actually get downloaded, and from which URL
170 self._pretrained_loaded = False
171 self._loaded_weights_url: Optional[str] = None
172 if self.use_pretrained and weights_url:
173 print(f"LOADING MetaFormer pretrained weights from: {weights_url}")
174 logger.info(f"Loading pretrained weights from: {weights_url}")
175 # Ensure we log whenever the factories call torch.hub.load_state_dict_from_url
176 orig_loader = getattr(torch.hub, 'load_state_dict_from_url', None)
177
178 def _wrapped_loader(url, *args, **kwargs):
179 print(f"DOWNLOADING weights from: {url}")
180 logger.info(f"DOWNLOADING weights from: {url}")
181 self._pretrained_loaded = True
182 self._loaded_weights_url = url
183 result = orig_loader(url, *args, **kwargs)
184 print(f"WEIGHTS DOWNLOADED successfully from: {url}")
185 return result
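# Temporarily swapping torch.hub.load_state_dict_from_url for the wrapper above lets us
# record whether the factory really fetched pretrained weights (and from which URL)
# without modifying metaformer_models; the original loader is restored in the finally
# block below.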
186 try:
187 if self.use_pretrained and orig_loader is not None:
188 torch.hub.load_state_dict_from_url = _wrapped_loader # type: ignore[attr-defined]
189 print(f"CREATING MetaFormer model: {self.custom_model} (pretrained={self.use_pretrained})")
190 try:
191 model = ctor(pretrained=self.use_pretrained, num_classes=1000)
192 print(f"MetaFormer model CREATED: {self.custom_model}")
193 except Exception as model_error:
194 if self.use_pretrained:
195 print(f"⚠ Warning: Failed to load {self.custom_model} with pretrained weights: {model_error}")
196 print("Attempting to load without pretrained weights as fallback...")
197 logger.warning(f"Failed to load {self.custom_model} with pretrained weights: {model_error}")
198 model = ctor(pretrained=False, num_classes=1000)
199 print(f"✓ Successfully loaded {self.custom_model} without pretrained weights")
200 self.use_pretrained = False # Update state to reflect actual loading
201 else:
202 raise model_error
203 finally:
204 if orig_loader is not None:
205 torch.hub.load_state_dict_from_url = orig_loader # type: ignore[attr-defined]
206 self._metaformer_weights_url = weights_url
207 if self.use_pretrained:
208 if self._pretrained_loaded:
209 print(f"MetaFormer: pretrained weights loaded from {self._loaded_weights_url}")
210 logger.info(f"MetaFormer: pretrained weights loaded from {self._loaded_weights_url}")
211 else:
212 # Warn but don't fail - weights may have failed to load but model creation succeeded
213 print("⚠ Warning: MetaFormer pretrained weights were requested but not confirmed as loaded")
214 logger.warning("MetaFormer: pretrained weights were requested but not confirmed as loaded")
215 else:
216 print(f"MetaFormer: using randomly initialized weights for {self.custom_model}")
217 logger.info(f"MetaFormer: using randomly initialized weights for {self.custom_model}")
218 logger.info(f"Loaded MetaFormer backbone: {self.custom_model} (pretrained={self.use_pretrained})")
219 return model
220
221 def _get_feature_dim(self):
222 with torch.no_grad():
223 dummy_input = torch.randn(1, self.expected_channels, self.expected_height, self.expected_width)
224 features = self.backbone.forward_features(dummy_input)
225 feature_dim = features.shape[-1]
226 logger.info(f"MetaFormer feature dimension: {feature_dim}")
227 return feature_dim
228
229 def _create_fc_layers(self, input_dim, output_dim, num_layers, activation, dropout, norm, use_bias, fc_layers_config):
230 layers = []
231 if fc_layers_config:
232 current_dim = input_dim
233 for i, layer_config in enumerate(fc_layers_config):
234 layer_output_dim = layer_config.get('output_size', output_dim if i == len(fc_layers_config) - 1 else current_dim)
235 layers.append(nn.Linear(current_dim, layer_output_dim, bias=use_bias))
236 if i < len(fc_layers_config) - 1:
237 if activation == "relu":
238 layers.append(nn.ReLU())
239 elif activation == "tanh":
240 layers.append(nn.Tanh())
241 elif activation == "sigmoid":
242 layers.append(nn.Sigmoid())
243 elif activation == "leaky_relu":
244 layers.append(nn.LeakyReLU())
245 if dropout > 0:
246 layers.append(nn.Dropout(dropout))
247 if norm == "batch":
248 layers.append(nn.BatchNorm1d(layer_output_dim))
249 elif norm == "layer":
250 layers.append(nn.LayerNorm(layer_output_dim))
251 current_dim = layer_output_dim
252 else:
253 if num_layers == 1:
254 layers.append(nn.Linear(input_dim, output_dim, bias=use_bias))
255 else:
256 intermediate_dims = [input_dim]
257 for i in range(num_layers - 1):
258 intermediate_dim = int(input_dim * (0.5 ** (i + 1)))
259 intermediate_dim = max(intermediate_dim, output_dim)
260 intermediate_dims.append(intermediate_dim)
261 intermediate_dims.append(output_dim)
262 for i in range(num_layers):
263 layers.append(nn.Linear(intermediate_dims[i], intermediate_dims[i + 1], bias=use_bias))
264 if i < num_layers - 1:
265 if activation == "relu":
266 layers.append(nn.ReLU())
267 elif activation == "tanh":
268 layers.append(nn.Tanh())
269 elif activation == "sigmoid":
270 layers.append(nn.Sigmoid())
271 elif activation == "leaky_relu":
272 layers.append(nn.LeakyReLU())
273 if dropout > 0:
274 layers.append(nn.Dropout(dropout))
275 if norm == "batch":
276 layers.append(nn.BatchNorm1d(intermediate_dims[i + 1]))
277 elif norm == "layer":
278 layers.append(nn.LayerNorm(intermediate_dims[i + 1]))
279 return nn.Sequential(*layers)
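# Shape note (illustrative, assuming feature_dim=512, output_size=128, num_fc_layers=3 and
# no explicit fc_layers config): hidden widths halve each step but never drop below
# output_size, giving Linear(512->256) -> act -> Linear(256->128) -> act -> Linear(128->128),
# with Dropout / BatchNorm1d / LayerNorm inserted after each hidden activation when configured.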
280
281 def forward(self, x):
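# The adapters below are rebuilt lazily whenever the incoming batch does not match the
# geometry anticipated in __init__: channels are mapped with a 1x1 conv, the spatial size
# is pooled to the requested height/width, and finally to the backbone's expected size.
# Adapters created here are instantiated on the fly during forward, so they are not part
# of the module's parameters until the first mismatched batch is seen.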
282 if x.shape[1] != self.expected_channels:
283 if (
284 self.channel_adapter is None
285 or self.channel_adapter.in_channels != x.shape[1]
286 or self.channel_adapter.out_channels != self.expected_channels
287 ):
288 self.channel_adapter = nn.Conv2d(
289 x.shape[1],
290 self.expected_channels,
291 kernel_size=1,
292 stride=1,
293 padding=0,
294 ).to(x.device)
295 logger.info(
296 "Created dynamic channel adapter: %s -> %s channels",
297 x.shape[1],
298 self.expected_channels,
299 )
300 x = self.channel_adapter(x)
301
302 target_height, target_width = self.height, self.width
303 if x.shape[2] != target_height or x.shape[3] != target_width:
304 if (
305 self.size_adapter is None
306 or getattr(self.size_adapter, "output_size", None)
307 != (target_height, target_width)
308 ):
309 self.size_adapter = nn.AdaptiveAvgPool2d(
310 (target_height, target_width)
311 ).to(x.device)
312 logger.info(
313 "Created size adapter: %sx%s -> %sx%s",
314 x.shape[2],
315 x.shape[3],
316 target_height,
317 target_width,
318 )
319 x = self.size_adapter(x)
320
321 if target_height != self.expected_height or target_width != self.expected_width:
322 if (
323 self.backbone_adapter is None
324 or getattr(self.backbone_adapter, "output_size", None)
325 != (self.expected_height, self.expected_width)
326 ):
327 self.backbone_adapter = nn.AdaptiveAvgPool2d(
328 (self.expected_height, self.expected_width)
329 ).to(x.device)
330 logger.info(
331 "Aligning to MetaFormer backbone size: %sx%s",
332 self.expected_height,
333 self.expected_width,
334 )
335 x = self.backbone_adapter(x)
336
337 features = self.backbone.forward_features(x)
338 output = self.fc_layers(features)
339 return {'encoder_output': output}
340
341 @property
342 def output_shape(self):
343 return [self.output_size]
344
345
346 def create_metaformer_stacked_cnn(model_name: str, **kwargs) -> MetaFormerStackedCNN:
347 encoder = MetaFormerStackedCNN(custom_model=model_name, **kwargs)
348 return encoder
349
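# Hedged usage sketch (illustrative values, not executed here): a 1-channel 128x128 input
# is adapted to the backbone's expected 3x224x224 and projected to `output_size` features.
#
#     encoder = create_metaformer_stacked_cnn(
#         "identityformer_s12", height=128, width=128, num_channels=1,
#         output_size=64, use_pretrained=False,
#     )
#     out = encoder(torch.randn(2, 1, 128, 128))
#     assert out["encoder_output"].shape == (2, 64)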
350
351 def patch_ludwig_stacked_cnn():
352 # Only patch Ludwig if MetaFormer models are available in this runtime
353 if not META_MODELS_AVAILABLE:
354 logger.warning("MetaFormer models unavailable; skipping Ludwig patch for stacked_cnn.")
355 return False
356 return patch_ludwig_direct()
357
358
359 def _is_supported_metaformer(custom_model: Optional[str]) -> bool:
360 return bool(custom_model) and custom_model.startswith(SUPPORTED_PREFIXES)
361
362
363 def patch_ludwig_direct():
364 try:
365 from ludwig.encoders.image.base import Stacked2DCNN
366 original_stacked_cnn_init = Stacked2DCNN.__init__
367
368 def patched_stacked_cnn_init(self, *args, **kwargs):
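# The model name can arrive either as a `custom_model` kwarg from the encoder config or
# via the module-level hint stored by set_current_metaformer_model(); the hint is reset in
# the finally block below so it only influences a single encoder construction.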
369 custom_model = kwargs.pop("custom_model", None)
370 if custom_model is None:
371 custom_model = getattr(patch_ludwig_direct, '_metaformer_model', None)
372
373 try:
374 if META_MODELS_AVAILABLE and _is_supported_metaformer(custom_model):
375 print(f"DETECTED MetaFormer model: {custom_model}")
376 print("MetaFormer encoder is being loaded and used.")
377 # Initialize base class to keep Ludwig internals intact
378 original_stacked_cnn_init(self, *args, **kwargs)
379 # Create our MetaFormer encoder and graft behavior
380 mf_encoder = create_metaformer_stacked_cnn(custom_model, **kwargs)
381 # Replace the base encoder's layer attributes with the MetaFormer encoder's so the stock CNN stack is never used
382 for attr in ("conv_layers", "fc_layers", "combiner", "output_shape", "reduce_output"):
383 if hasattr(self, attr):
384 try:
385 setattr(self, attr, getattr(mf_encoder, attr, None))
386 except Exception:
387 pass
388 self.forward = mf_encoder.forward
389 if hasattr(mf_encoder, 'backbone'):
390 self.backbone = mf_encoder.backbone
391 if hasattr(mf_encoder, 'fc_layers'):
392 self.fc_layers = mf_encoder.fc_layers
393 if hasattr(mf_encoder, 'custom_model'):
394 self.custom_model = mf_encoder.custom_model
395 # explicit confirmation logs
396 try:
397 url_info = getattr(mf_encoder, '_loaded_weights_url', None)
398 loaded_flag = getattr(mf_encoder, '_pretrained_loaded', False)
399 if loaded_flag and url_info:
400 print(f"CONFIRMED: MetaFormer '{custom_model}' using pretrained weights from: {url_info}")
401 logger.info(f"CONFIRMED: MetaFormer '{custom_model}' using pretrained weights from: {url_info}")
402 else:
403 print(f"CONFIRMED: MetaFormer '{custom_model}' using randomly initialized weights (no pretrained)")
404 logger.info(f"CONFIRMED: MetaFormer '{custom_model}' using randomly initialized weights")
405 except Exception:
406 pass
407 else:
408 original_stacked_cnn_init(self, *args, **kwargs)
409 finally:
410 if hasattr(patch_ludwig_direct, '_metaformer_model'):
411 patch_ludwig_direct._metaformer_model = None
412
413 Stacked2DCNN.__init__ = patched_stacked_cnn_init
414 return True
415 except Exception as e:
416 logger.error(f"Failed to apply MetaFormer direct patch: {e}")
417 return False
418
419
420 def set_current_metaformer_model(model_name: str):
421 """Store the current MetaFormer model name for the patch to use."""
422 setattr(patch_ludwig_direct, '_metaformer_model', model_name)
423
424
425 def clear_current_metaformer_model():
426 """Remove any cached MetaFormer model hint."""
427 if hasattr(patch_ludwig_direct, '_metaformer_model'):
428 delattr(patch_ludwig_direct, '_metaformer_model')
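

# Hedged end-to-end sketch (assumes Ludwig and the bundled metaformer_models are importable;
# the model name is just an example): register the desired MetaFormer variant, patch Ludwig's
# Stacked2DCNN, and subsequent stacked_cnn encoders whose custom_model matches a supported
# prefix will be redirected to MetaFormerStackedCNN.
if __name__ == "__main__":
    if META_MODELS_AVAILABLE:
        set_current_metaformer_model("identityformer_s12")
        patched = patch_ludwig_stacked_cnn()
        logger.info("Ludwig stacked_cnn patch applied: %s", patched)
        clear_current_metaformer_model()
    else:
        logger.info("metaformer_models not importable; nothing to demo")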