1 """
2 MetaFormer baselines including IdentityFormer, RandFormer, PoolFormerV2,
3 ConvFormer and CAFormer.
4 Standalone implementation for Galaxy Image Learner tool (no timm dependency).
5 """
6 import logging
7 from functools import partial
8
9 import torch
10 import torch.nn as nn
11 import torch.nn.functional as F
12 from torch.nn.init import trunc_normal_ # use torch's built-in truncated normal
13
14 logger = logging.getLogger(__name__)
15
16
17 def to_2tuple(v):
18 if isinstance(v, (list, tuple)):
19 return tuple(v)
20 return (v, v)
21
22
23 class DropPath(nn.Module):
24 def __init__(self, drop_prob: float = 0.0):
25 super().__init__()
26 self.drop_prob = float(drop_prob)
27
28 def forward(self, x):
29 if self.drop_prob == 0.0 or not self.training:
30 return x
31 keep_prob = 1.0 - self.drop_prob
32 shape = (x.shape[0],) + (1,) * (x.ndim - 1)
33 random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
34 random_tensor.floor_()
35 return x.div(keep_prob) * random_tensor
36
37
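# Illustrative sketch (not part of the upstream model code): how DropPath
# (stochastic depth) behaves in training mode. Shapes and values here are
# arbitrary example choices.
def _demo_drop_path():
    dp = DropPath(drop_prob=0.25)
    dp.train()
    x = torch.ones(8, 4, 4, 16)  # (B, H, W, C)
    y = dp(x)
    # Roughly 25% of the 8 samples are zeroed entirely; survivors are
    # rescaled by 1 / 0.75 so the expected activation is unchanged.
    return y

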
# ImageNet normalization constants
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


def register_model(fn):
    # no-op decorator to mirror timm API without dependency
    return fn


def _cfg(url: str = '', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': 1.0, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head',
        **kwargs
    }


default_cfgs = {
    'identityformer_s12': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s12.pth'),
    'identityformer_s24': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s24.pth'),
    'identityformer_s36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s36.pth'),
    'identityformer_m36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m36.pth'),
    'identityformer_m48': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m48.pth'),

    'randformer_s12': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s12.pth'),
    'randformer_s24': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s24.pth'),
    'randformer_s36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s36.pth'),
    'randformer_m36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m36.pth'),
    'randformer_m48': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m48.pth'),

    'poolformerv2_s12': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s12.pth'),
    'poolformerv2_s24': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s24.pth'),
    'poolformerv2_s36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s36.pth'),
    'poolformerv2_m36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m36.pth'),
    'poolformerv2_m48': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m48.pth'),

    'convformer_s18': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18.pth'),
    'convformer_s18_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384.pth',
        input_size=(3, 384, 384)),
    'convformer_s18_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21ft1k.pth'),
    'convformer_s18_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'convformer_s18_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21k.pth',
        num_classes=21841),

    'convformer_s36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36.pth'),
    'convformer_s36_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384.pth',
        input_size=(3, 384, 384)),
    'convformer_s36_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21ft1k.pth'),
    'convformer_s36_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'convformer_s36_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21k.pth',
        num_classes=21841),

    'convformer_m36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36.pth'),
    'convformer_m36_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384.pth',
        input_size=(3, 384, 384)),
    'convformer_m36_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21ft1k.pth'),
    'convformer_m36_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'convformer_m36_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21k.pth',
        num_classes=21841),

    'convformer_b36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36.pth'),
    'convformer_b36_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384.pth',
        input_size=(3, 384, 384)),
    'convformer_b36_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21ft1k.pth'),
    'convformer_b36_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'convformer_b36_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21k.pth',
        num_classes=21841),

    'caformer_s18': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18.pth'),
    'caformer_s18_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384.pth',
        input_size=(3, 384, 384)),
    'caformer_s18_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21ft1k.pth'),
    'caformer_s18_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'caformer_s18_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21k.pth',
        num_classes=21841),

    'caformer_s36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36.pth'),
    'caformer_s36_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384.pth',
        input_size=(3, 384, 384)),
    'caformer_s36_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21ft1k.pth'),
    'caformer_s36_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'caformer_s36_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21k.pth',
        num_classes=21841),

    'caformer_m36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36.pth'),
    'caformer_m36_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384.pth',
        input_size=(3, 384, 384)),
    'caformer_m36_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21ft1k.pth'),
    'caformer_m36_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'caformer_m36_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21k.pth',
        num_classes=21841),

    'caformer_b36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36.pth'),
    'caformer_b36_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384.pth',
        input_size=(3, 384, 384)),
    'caformer_b36_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21ft1k.pth'),
    'caformer_b36_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'caformer_b36_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21k.pth',
        num_classes=21841),
}


class Downsampling(nn.Module):
    """Downsampling implemented by a layer of convolution."""
    def __init__(self, in_channels, out_channels,
                 kernel_size, stride=1, padding=0,
                 pre_norm=None, post_norm=None, pre_permute=False):
        super().__init__()
        self.pre_norm = pre_norm(in_channels) if pre_norm else nn.Identity()
        self.pre_permute = pre_permute
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding)
        self.post_norm = post_norm(out_channels) if post_norm else nn.Identity()

    def forward(self, x):
        x = self.pre_norm(x)
        if self.pre_permute:
            x = x.permute(0, 3, 1, 2)
        x = self.conv(x)
        x = x.permute(0, 2, 3, 1)
        x = self.post_norm(x)
        return x


class Scale(nn.Module):
    """Scale vector by element multiplications."""
    def __init__(self, dim, init_value=1.0, trainable=True):
        super().__init__()
        self.scale = nn.Parameter(init_value * torch.ones(dim), requires_grad=trainable)

    def forward(self, x):
        return x * self.scale


class SquaredReLU(nn.Module):
    """Squared ReLU: https://arxiv.org/abs/2109.08668"""
    def __init__(self, inplace=False):
        super().__init__()
        self.relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
        return torch.square(self.relu(x))


class StarReLU(nn.Module):
    """StarReLU: s * relu(x) ** 2 + b"""
    def __init__(self, scale_value=1.0, bias_value=0.0,
                 scale_learnable=True, bias_learnable=True,
                 mode=None, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.relu = nn.ReLU(inplace=inplace)
        self.scale = nn.Parameter(scale_value * torch.ones(1),
                                  requires_grad=scale_learnable)
        self.bias = nn.Parameter(bias_value * torch.ones(1),
                                 requires_grad=bias_learnable)

    def forward(self, x):
        return self.scale * self.relu(x) ** 2 + self.bias


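# Illustrative sketch (not part of the upstream model code): with the default
# learnable scale s=1 and bias b=0, StarReLU maps x to relu(x) ** 2,
# e.g. -1 -> 0 and 2 -> 4.
def _demo_star_relu():
    act = StarReLU()
    x = torch.tensor([-1.0, 0.0, 2.0])
    return act(x)  # tensor([0., 0., 4.]) at the default initialization

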
class Attention(nn.Module):
    """Vanilla self-attention from Transformer."""
    def __init__(self, dim, head_dim=32, num_heads=None, qkv_bias=False,
                 attn_drop=0., proj_drop=0., proj_bias=False, **kwargs):
        super().__init__()

        self.head_dim = head_dim
        self.scale = head_dim ** -0.5

        self.num_heads = num_heads if num_heads else dim // head_dim
        if self.num_heads == 0:
            self.num_heads = 1

        self.attention_dim = self.num_heads * self.head_dim
        self.qkv = nn.Linear(dim, self.attention_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(self.attention_dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, H, W, C = x.shape
        N = H * W
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, H, W, self.attention_dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


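# Illustrative sketch (not part of the upstream model code): with dim=64 and
# head_dim=32 the mixer uses 64 // 32 = 2 heads, and it preserves the
# channels-last (B, H, W, C) layout used throughout this file.
def _demo_attention():
    attn = Attention(dim=64, head_dim=32)
    x = torch.randn(2, 14, 14, 64)  # 14 * 14 = 196 tokens per image
    return attn(x)  # same shape: (2, 14, 14, 64)

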
class RandomMixing(nn.Module):
    def __init__(self, num_tokens=196, **kwargs):
        super().__init__()
        self.num_tokens = num_tokens
        base_matrix = torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1)
        self.register_buffer("random_matrix", base_matrix, persistent=True)

    def forward(self, x):
        B, H, W, C = x.shape
        actual_tokens = H * W

        if actual_tokens == self.random_matrix.shape[0]:
            mixing = self.random_matrix
        else:
            base = self.random_matrix
            if base.device != x.device:
                base = base.to(x.device)
            resized = F.interpolate(
                base.unsqueeze(0).unsqueeze(0),
                size=(actual_tokens, actual_tokens),
                mode='bilinear',
                align_corners=False,
            ).squeeze(0).squeeze(0)
            mixing = torch.softmax(resized, dim=-1)

        x = x.reshape(B, actual_tokens, C)
        x = torch.einsum('mn, bnc -> bmc', mixing, x)
        x = x.reshape(B, H, W, C)
        return x


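# Illustrative sketch (not part of the upstream model code): the mixing matrix
# is drawn once at construction and stored as a buffer, so it is frozen (no
# gradient) but still saved in checkpoints; inputs with a different token count
# take the bilinear-resize branch and are re-normalized with softmax.
def _demo_random_mixing():
    mixer = RandomMixing(num_tokens=196)
    x196 = torch.randn(1, 14, 14, 32)  # 196 tokens: uses the stored matrix
    x49 = torch.randn(1, 7, 7, 32)     # 49 tokens: matrix resized on the fly
    return mixer(x196), mixer(x49)

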
class LayerNormGeneral(nn.Module):
    """General LayerNorm for different situations."""
    def __init__(self, affine_shape=None, normalized_dim=(-1,), scale=True,
                 bias=True, eps=1e-5):
        super().__init__()
        self.normalized_dim = normalized_dim
        self.use_scale = scale
        self.use_bias = bias
        self.weight = nn.Parameter(torch.ones(affine_shape)) if scale else None
        self.bias = nn.Parameter(torch.zeros(affine_shape)) if bias else None
        self.eps = eps

    def forward(self, x):
        c = x - x.mean(self.normalized_dim, keepdim=True)
        s = c.pow(2).mean(self.normalized_dim, keepdim=True)
        x = c / torch.sqrt(s + self.eps)
        if self.use_scale:
            x = x * self.weight
        if self.use_bias:
            x = x + self.bias
        return x


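# Illustrative sketch (not part of the upstream model code): with the defaults
# (normalized_dim=(-1,), scale and bias on) this behaves like nn.LayerNorm over
# the channel dim; the factories below instead pass normalized_dim=(1, 2, 3)
# to normalize jointly over the spatial and channel dims.
def _demo_layer_norm_general():
    ln = LayerNormGeneral(affine_shape=64, normalized_dim=(-1,))
    x = torch.randn(2, 14, 14, 64)
    return ln(x)  # per-token normalization over the 64 channels

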
class LayerNormWithoutBias(nn.Module):
    """Equal to partial(LayerNormGeneral, bias=False) but faster."""
    def __init__(self, normalized_shape, eps=1e-5, **kwargs):
        super().__init__()
        self.eps = eps
        self.bias = None
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)


class SepConv(nn.Module):
    """Inverted separable convolution from MobileNetV2."""
    def __init__(self, dim, expansion_ratio=2,
                 act1_layer=StarReLU, act2_layer=nn.Identity,
                 bias=False, kernel_size=7, padding=3,
                 **kwargs):
        super().__init__()
        med_channels = int(expansion_ratio * dim)
        self.pwconv1 = nn.Linear(dim, med_channels, bias=bias)
        self.act1 = act1_layer()
        self.dwconv = nn.Conv2d(
            med_channels, med_channels, kernel_size=kernel_size,
            padding=padding, groups=med_channels, bias=bias)
        self.act2 = act2_layer()
        self.pwconv2 = nn.Linear(med_channels, dim, bias=bias)

    def forward(self, x):
        x = self.pwconv1(x)
        x = self.act1(x)
        x = x.permute(0, 3, 1, 2)
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)
        x = self.act2(x)
        x = self.pwconv2(x)
        return x


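# Illustrative sketch (not part of the upstream model code): SepConv expands
# channels with a pointwise Linear, runs a depthwise 7x7 conv in NCHW, then
# projects back, e.g. dim=64 -> 128 hidden channels -> 64.
def _demo_sep_conv():
    conv = SepConv(dim=64, expansion_ratio=2)
    x = torch.randn(2, 14, 14, 64)  # channels-last (B, H, W, C)
    return conv(x)  # same shape: (2, 14, 14, 64)

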
class Pooling(nn.Module):
    """Pooling for PoolFormer."""
    def __init__(self, pool_size=3, **kwargs):
        super().__init__()
        self.pool = nn.AvgPool2d(pool_size, stride=1, padding=pool_size // 2, count_include_pad=False)

    def forward(self, x):
        y = x.permute(0, 3, 1, 2)
        y = self.pool(y)
        y = y.permute(0, 2, 3, 1)
        return y - x


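# Illustrative sketch (not part of the upstream model code): Pooling returns
# pool(x) - x rather than pool(x), so once the MetaFormerBlock adds the
# residual back, the net effect of the branch is plain average pooling.
def _demo_pooling():
    mixer = Pooling(pool_size=3)
    x = torch.randn(1, 8, 8, 16)
    return x + mixer(x)  # equals the 3x3 average pooling of x

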
class Mlp(nn.Module):
    """MLP used in MetaFormer models."""
    def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0., bias=False, **kwargs):
        super().__init__()
        in_features = dim
        out_features = out_features or in_features
        hidden_features = int(mlp_ratio * in_features)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class MlpHead(nn.Module):
    def __init__(self, dim, num_classes=1000, mlp_ratio=4, act_layer=SquaredReLU,
                 norm_layer=nn.LayerNorm, head_dropout=0., bias=True):
        super().__init__()
        hidden_features = int(mlp_ratio * dim)
        self.fc1 = nn.Linear(dim, hidden_features, bias=bias)
        self.act = act_layer()
        self.norm = norm_layer(hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias)
        self.head_dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.norm(x)
        x = self.head_dropout(x)
        x = self.fc2(x)
        return x


class MetaFormerBlock(nn.Module):
    def __init__(self, dim,
                 token_mixer=nn.Identity, mlp=Mlp,
                 norm_layer=nn.LayerNorm,
                 drop=0., drop_path=0.,
                 layer_scale_init_value=None, res_scale_init_value=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.token_mixer = token_mixer(dim=dim, drop=drop)
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.layer_scale1 = Scale(dim=dim, init_value=layer_scale_init_value) if layer_scale_init_value else nn.Identity()
        self.res_scale1 = Scale(dim=dim, init_value=res_scale_init_value) if res_scale_init_value else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = mlp(dim=dim, drop=drop)
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.layer_scale2 = Scale(dim=dim, init_value=layer_scale_init_value) if layer_scale_init_value else nn.Identity()
        self.res_scale2 = Scale(dim=dim, init_value=res_scale_init_value) if res_scale_init_value else nn.Identity()

    def forward(self, x):
        x = self.res_scale1(x) + self.layer_scale1(self.drop_path1(self.token_mixer(self.norm1(x))))
        x = self.res_scale2(x) + self.layer_scale2(self.drop_path2(self.mlp(self.norm2(x))))
        return x


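# Illustrative sketch (not part of the upstream model code): one block applies
# x = res_scale(x) + layer_scale(drop_path(mixer(norm(x)))), then the same
# pattern with the MLP; with the defaults every Scale/DropPath is an identity.
def _demo_metaformer_block():
    block = MetaFormerBlock(dim=64, token_mixer=Pooling)
    x = torch.randn(2, 14, 14, 64)
    return block(x)  # same shape: (2, 14, 14, 64)

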
DOWNSAMPLE_LAYERS_FOUR_STAGES = [partial(Downsampling,
                                         kernel_size=7, stride=4, padding=2,
                                         post_norm=partial(LayerNormGeneral, bias=False, eps=1e-6)
                                         )] + \
                                [partial(Downsampling,
                                         kernel_size=3, stride=2, padding=1,
                                         pre_norm=partial(LayerNormGeneral, bias=False, eps=1e-6), pre_permute=True
                                         )] * 3


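# Illustrative sketch (not part of the upstream model code): the first stage
# downsamples by stride 4 and the next three by stride 2 each, so a 224x224
# input yields feature maps of 56, 28, 14, and 7 pixels per side; 7 * 7 = 49
# is why the randformer factories pin num_tokens=49 for the last stage.
def _demo_downsampling_schedule():
    side = 224
    sides = []
    for stride in (4, 2, 2, 2):
        side //= stride
        sides.append(side)
    return sides  # [56, 28, 14, 7]

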
class MetaFormer(nn.Module):
    def __init__(self, in_chans=3, num_classes=1000,
                 depths=[2, 2, 6, 2],
                 dims=[64, 128, 320, 512],
                 downsample_layers=DOWNSAMPLE_LAYERS_FOUR_STAGES,
                 token_mixers=nn.Identity,
                 mlps=Mlp,
                 norm_layers=partial(LayerNormWithoutBias, eps=1e-6),
                 drop_path_rate=0.,
                 head_dropout=0.0,
                 layer_scale_init_values=None,
                 res_scale_init_values=[None, None, 1.0, 1.0],
                 output_norm=partial(nn.LayerNorm, eps=1e-6),
                 head_fn=nn.Linear,
                 **kwargs):
        super().__init__()
        self.num_classes = num_classes

        if not isinstance(depths, (list, tuple)):
            depths = [depths]
        if not isinstance(dims, (list, tuple)):
            dims = [dims]

        num_stage = len(depths)
        self.num_stage = num_stage

        if not isinstance(downsample_layers, (list, tuple)):
            downsample_layers = [downsample_layers] * num_stage
        down_dims = [in_chans] + dims
        self.downsample_layers = nn.ModuleList(
            [downsample_layers[i](down_dims[i], down_dims[i + 1]) for i in range(num_stage)]
        )

        if not isinstance(token_mixers, (list, tuple)):
            token_mixers = [token_mixers] * num_stage
        if not isinstance(mlps, (list, tuple)):
            mlps = [mlps] * num_stage
        if not isinstance(norm_layers, (list, tuple)):
            norm_layers = [norm_layers] * num_stage

        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        if not isinstance(layer_scale_init_values, (list, tuple)):
            layer_scale_init_values = [layer_scale_init_values] * num_stage
        if not isinstance(res_scale_init_values, (list, tuple)):
            res_scale_init_values = [res_scale_init_values] * num_stage

        self.stages = nn.ModuleList()
        cur = 0
        for i in range(num_stage):
            stage = nn.Sequential(
                *[MetaFormerBlock(dim=dims[i],
                                  token_mixer=token_mixers[i],
                                  mlp=mlps[i],
                                  norm_layer=norm_layers[i],
                                  drop_path=dp_rates[cur + j],
                                  layer_scale_init_value=layer_scale_init_values[i],
                                  res_scale_init_value=res_scale_init_values[i])
                  for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = output_norm(dims[-1])
        self.head = head_fn(dims[-1], num_classes) if head_dropout <= 0.0 else head_fn(dims[-1], num_classes, head_dropout=head_dropout)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'norm'}

    def forward_features(self, x):
        for i in range(self.num_stage):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        return self.norm(x.mean([1, 2]))

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x


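# Illustrative sketch (not part of the upstream model code): end-to-end shape
# flow with the default configuration (an s12-style model).
def _demo_metaformer_forward():
    model = MetaFormer(num_classes=10)
    model.eval()
    x = torch.randn(1, 3, 224, 224)  # NCHW input; stages run channels-last
    with torch.no_grad():
        logits = model(x)
    return logits.shape  # torch.Size([1, 10])

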
# ---- Model factory functions (subset, extend as needed) ----

@register_model
def identityformer_s12(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[2, 2, 6, 2],
        dims=[64, 128, 320, 512],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['identityformer_s12']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def identityformer_s24(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[4, 4, 12, 4],
        dims=[64, 128, 320, 512],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['identityformer_s24']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def identityformer_s36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[64, 128, 320, 512],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['identityformer_s36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def identityformer_m36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[96, 192, 384, 768],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['identityformer_m36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def identityformer_m48(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[8, 8, 24, 8],
        dims=[96, 192, 384, 768],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['identityformer_m48']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def randformer_s12(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[2, 2, 6, 2],
        dims=[64, 128, 320, 512],
        token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['randformer_s12']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def randformer_s24(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[4, 4, 12, 4],
        dims=[64, 128, 320, 512],
        token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['randformer_s24']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def randformer_s36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[64, 128, 320, 512],
        token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['randformer_s36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def randformer_m36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[96, 192, 384, 768],
        token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['randformer_m36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def randformer_m48(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[8, 8, 24, 8],
        dims=[96, 192, 384, 768],
        token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['randformer_m48']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def poolformerv2_s12(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[2, 2, 6, 2],
        dims=[64, 128, 320, 512],
        token_mixers=Pooling,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['poolformerv2_s12']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def poolformerv2_s24(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[4, 4, 12, 4],
        dims=[64, 128, 320, 512],
        token_mixers=Pooling,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['poolformerv2_s24']
    if pretrained:
        try:
            logger.info("Loading pretrained weights for poolformerv2_s24 from: %s", model.default_cfg['url'])

            # Add a timeout to prevent the download from hanging in CI environments
            import socket
            original_timeout = socket.getdefaulttimeout()
            socket.setdefaulttimeout(60)  # 60-second timeout
            try:
                state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
                model.load_state_dict(state_dict)
                logger.info("Successfully loaded pretrained weights for poolformerv2_s24")
            finally:
                socket.setdefaulttimeout(original_timeout)
        except Exception as e:
            logger.warning("Failed to load pretrained weights for poolformerv2_s24: %s", e)
            logger.info("Continuing with randomly initialized weights...")
    return model


@register_model
def poolformerv2_s36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[64, 128, 320, 512],
        token_mixers=Pooling,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['poolformerv2_s36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def poolformerv2_m36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[96, 192, 384, 768],
        token_mixers=Pooling,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['poolformerv2_m36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def poolformerv2_m48(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[8, 8, 24, 8],
        dims=[96, 192, 384, 768],
        token_mixers=Pooling,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['poolformerv2_m48']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s18(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s18']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s18_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s18_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s18_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s18_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s18_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s18_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s18_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s18_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s36_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s36_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s36_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s36_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s36_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s36_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s36_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s36_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_m36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_m36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_m36_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_m36_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_m36_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_m36_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_m36_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_m36_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_m36_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_m36_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_b36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_b36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_b36_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_b36_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_b36_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_b36_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_b36_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_b36_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_b36_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_b36_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s18(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s18']
    if pretrained:
        try:
            logger.info("Loading pretrained weights for caformer_s18 from: %s", model.default_cfg['url'])
            # Add a timeout to prevent the download from hanging in CI environments
            import socket
            original_timeout = socket.getdefaulttimeout()
            socket.setdefaulttimeout(60)  # 60-second timeout
            try:
                state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
                model.load_state_dict(state_dict)
                logger.info("Successfully loaded pretrained weights for caformer_s18")
            finally:
                socket.setdefaulttimeout(original_timeout)
        except Exception as e:
            logger.warning("Failed to load pretrained weights for caformer_s18: %s", e)
            logger.info("Continuing with randomly initialized weights...")
    return model


@register_model
def caformer_s18_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s18_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s18_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s18_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s18_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s18_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s18_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s18_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s36_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s36_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s36_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s36_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s36_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s36_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s36_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s36_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_m36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_m36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_m36_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_m36_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_m36_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_m36_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_m36_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_m36_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_m36_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_m36_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_b36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_b36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_b36_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_b36_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_b36_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_b36_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_b36_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_b36_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_b36_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_b36_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model
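

if __name__ == "__main__":
    # Quick smoke test (illustrative addition, not part of the upstream code):
    # builds a small model with random weights and checks the output shape.
    # Running with pretrained=True instead would download weights from the
    # URLs in default_cfgs and requires network access.
    m = identityformer_s12(num_classes=5)
    m.eval()
    with torch.no_grad():
        out = m(torch.randn(1, 3, 224, 224))
    print(f"identityformer_s12 output shape: {tuple(out.shape)}")  # (1, 5)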