MetaFormer/metaformer_models.py @ 11:c5150cceab47 (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
author: goeckslab
date:   Sat, 18 Oct 2025 03:17:09 +0000
1 """ | |
2 MetaFormer baselines including IdentityFormer, RandFormer, PoolFormerV2, | |
3 ConvFormer and CAFormer. | |
4 Standalone implementation for Galaxy Image Learner tool (no timm dependency). | |
5 """ | |
6 import logging | |
7 from functools import partial | |
8 | |
9 import torch | |
10 import torch.nn as nn | |
11 import torch.nn.functional as F | |
12 from torch.nn.init import trunc_normal_ # use torch's built-in truncated normal | |
13 | |
14 logger = logging.getLogger(__name__) | |
15 | |
16 | |
17 def to_2tuple(v): | |
18 if isinstance(v, (list, tuple)): | |
19 return tuple(v) | |
20 return (v, v) | |
21 | |
22 | |
23 class DropPath(nn.Module): | |
24 def __init__(self, drop_prob: float = 0.0): | |
25 super().__init__() | |
26 self.drop_prob = float(drop_prob) | |
27 | |
28 def forward(self, x): | |
29 if self.drop_prob == 0.0 or not self.training: | |
30 return x | |
31 keep_prob = 1.0 - self.drop_prob | |
32 shape = (x.shape[0],) + (1,) * (x.ndim - 1) | |
33 random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) | |
34 random_tensor.floor_() | |
35 return x.div(keep_prob) * random_tensor | |
36 | |
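# DropPath note: this is stochastic depth applied per sample. With
# drop_prob=p, each sample in the batch keeps its residual branch with
# probability 1 - p, and surviving samples are rescaled by 1 / (1 - p);
# at eval time the module is a no-op.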

# ImageNet normalization constants
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


def register_model(fn):
    # no-op decorator to mirror timm API without dependency
    return fn


def _cfg(url: str = '', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': 1.0, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head',
        **kwargs
    }


default_cfgs = {
    'identityformer_s12': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s12.pth'),
    'identityformer_s24': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s24.pth'),
    'identityformer_s36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s36.pth'),
    'identityformer_m36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m36.pth'),
    'identityformer_m48': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m48.pth'),

    'randformer_s12': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s12.pth'),
    'randformer_s24': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s24.pth'),
    'randformer_s36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s36.pth'),
    'randformer_m36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m36.pth'),
    'randformer_m48': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m48.pth'),

    'poolformerv2_s12': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s12.pth'),
    'poolformerv2_s24': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s24.pth'),
    'poolformerv2_s36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s36.pth'),
    'poolformerv2_m36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m36.pth'),
    'poolformerv2_m48': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m48.pth'),

    'convformer_s18': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18.pth'),
    'convformer_s18_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384.pth',
        input_size=(3, 384, 384)),
    'convformer_s18_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21ft1k.pth'),
    'convformer_s18_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'convformer_s18_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21k.pth',
        num_classes=21841),

    'convformer_s36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36.pth'),
    'convformer_s36_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384.pth',
        input_size=(3, 384, 384)),
    'convformer_s36_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21ft1k.pth'),
    'convformer_s36_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'convformer_s36_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21k.pth',
        num_classes=21841),

    'convformer_m36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36.pth'),
    'convformer_m36_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384.pth',
        input_size=(3, 384, 384)),
    'convformer_m36_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21ft1k.pth'),
    'convformer_m36_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'convformer_m36_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21k.pth',
        num_classes=21841),

    'convformer_b36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36.pth'),
    'convformer_b36_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384.pth',
        input_size=(3, 384, 384)),
    'convformer_b36_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21ft1k.pth'),
    'convformer_b36_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'convformer_b36_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21k.pth',
        num_classes=21841),

    'caformer_s18': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18.pth'),
    'caformer_s18_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384.pth',
        input_size=(3, 384, 384)),
    'caformer_s18_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21ft1k.pth'),
    'caformer_s18_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'caformer_s18_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21k.pth',
        num_classes=21841),

    'caformer_s36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36.pth'),
    'caformer_s36_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384.pth',
        input_size=(3, 384, 384)),
    'caformer_s36_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21ft1k.pth'),
    'caformer_s36_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'caformer_s36_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21k.pth',
        num_classes=21841),

    'caformer_m36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36.pth'),
    'caformer_m36_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384.pth',
        input_size=(3, 384, 384)),
    'caformer_m36_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21ft1k.pth'),
    'caformer_m36_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'caformer_m36_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21k.pth',
        num_classes=21841),

    'caformer_b36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36.pth'),
    'caformer_b36_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384.pth',
        input_size=(3, 384, 384)),
    'caformer_b36_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21ft1k.pth'),
    'caformer_b36_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'caformer_b36_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21k.pth',
        num_classes=21841),
}


class Downsampling(nn.Module):
    """Downsampling implemented by a layer of convolution."""
    def __init__(self, in_channels, out_channels,
                 kernel_size, stride=1, padding=0,
                 pre_norm=None, post_norm=None, pre_permute=False):
        super().__init__()
        self.pre_norm = pre_norm(in_channels) if pre_norm else nn.Identity()
        self.pre_permute = pre_permute
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding)
        self.post_norm = post_norm(out_channels) if post_norm else nn.Identity()

    def forward(self, x):
        x = self.pre_norm(x)
        if self.pre_permute:
            x = x.permute(0, 3, 1, 2)
        x = self.conv(x)
        x = x.permute(0, 2, 3, 1)
        x = self.post_norm(x)
        return x

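# Layout note: the first Downsampling stage receives NCHW images and returns
# NHWC feature maps; the later stages are built with pre_permute=True (see
# DOWNSAMPLE_LAYERS_FOUR_STAGES below), so they convert NHWC back to NCHW
# before the conv. Every token mixer and MLP in this file therefore operates
# on channels-last (B, H, W, C) tensors.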

class Scale(nn.Module):
    """Scale vector by element multiplications."""
    def __init__(self, dim, init_value=1.0, trainable=True):
        super().__init__()
        self.scale = nn.Parameter(init_value * torch.ones(dim), requires_grad=trainable)

    def forward(self, x):
        return x * self.scale


class SquaredReLU(nn.Module):
    """Squared ReLU: https://arxiv.org/abs/2109.08668"""
    def __init__(self, inplace=False):
        super().__init__()
        self.relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
        return torch.square(self.relu(x))


class StarReLU(nn.Module):
    """StarReLU: s * relu(x) ** 2 + b"""
    def __init__(self, scale_value=1.0, bias_value=0.0,
                 scale_learnable=True, bias_learnable=True,
                 mode=None, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.relu = nn.ReLU(inplace=inplace)
        self.scale = nn.Parameter(scale_value * torch.ones(1),
                                  requires_grad=scale_learnable)
        self.bias = nn.Parameter(bias_value * torch.ones(1),
                                 requires_grad=bias_learnable)

    def forward(self, x):
        return self.scale * self.relu(x) ** 2 + self.bias


class Attention(nn.Module):
    """Vanilla self-attention from Transformer."""
    def __init__(self, dim, head_dim=32, num_heads=None, qkv_bias=False,
                 attn_drop=0., proj_drop=0., proj_bias=False, **kwargs):
        super().__init__()

        self.head_dim = head_dim
        self.scale = head_dim ** -0.5

        self.num_heads = num_heads if num_heads else dim // head_dim
        if self.num_heads == 0:
            self.num_heads = 1

        self.attention_dim = self.num_heads * self.head_dim
        self.qkv = nn.Linear(dim, self.attention_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(self.attention_dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, H, W, C = x.shape
        N = H * W
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, H, W, self.attention_dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

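# Attention note: num_heads defaults to dim // head_dim (floored to at least
# 1), so attention_dim = num_heads * head_dim can differ from dim; the output
# projection maps attention_dim back to dim. Input and output are NHWC, with
# the H*W positions treated as the token sequence.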

class RandomMixing(nn.Module):
    def __init__(self, num_tokens=196, **kwargs):
        super().__init__()
        self.num_tokens = num_tokens
        base_matrix = torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1)
        self.register_buffer("random_matrix", base_matrix, persistent=True)

    def forward(self, x):
        B, H, W, C = x.shape
        actual_tokens = H * W

        if actual_tokens == self.random_matrix.shape[0]:
            mixing = self.random_matrix
        else:
            base = self.random_matrix
            if base.device != x.device:
                base = base.to(x.device)
            resized = F.interpolate(
                base.unsqueeze(0).unsqueeze(0),
                size=(actual_tokens, actual_tokens),
                mode='bilinear',
                align_corners=False,
            ).squeeze(0).squeeze(0)
            mixing = torch.softmax(resized, dim=-1)

        x = x.reshape(B, actual_tokens, C)
        x = torch.einsum('mn, bnc -> bmc', mixing, x)
        x = x.reshape(B, H, W, C)
        return x

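# RandomMixing note: the mixing matrix is random but fixed (a persistent
# buffer), sized for num_tokens (default 196 = 14 * 14 at 224-px input). For
# other resolutions it is bilinearly resized to (H*W, H*W) and passed through
# softmax again so each row remains a valid mixing distribution.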

class LayerNormGeneral(nn.Module):
    """General LayerNorm for different situations."""
    def __init__(self, affine_shape=None, normalized_dim=(-1,), scale=True,
                 bias=True, eps=1e-5):
        super().__init__()
        self.normalized_dim = normalized_dim
        self.use_scale = scale
        self.use_bias = bias
        self.weight = nn.Parameter(torch.ones(affine_shape)) if scale else None
        self.bias = nn.Parameter(torch.zeros(affine_shape)) if bias else None
        self.eps = eps

    def forward(self, x):
        c = x - x.mean(self.normalized_dim, keepdim=True)
        s = c.pow(2).mean(self.normalized_dim, keepdim=True)
        x = c / torch.sqrt(s + self.eps)
        if self.use_scale:
            x = x * self.weight
        if self.use_bias:
            x = x + self.bias
        return x


class LayerNormWithoutBias(nn.Module):
    """Equal to partial(LayerNormGeneral, bias=False) but faster."""
    def __init__(self, normalized_shape, eps=1e-5, **kwargs):
        super().__init__()
        self.eps = eps
        self.bias = None
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)


class SepConv(nn.Module):
    """Inverted separable convolution from MobileNetV2."""
    def __init__(self, dim, expansion_ratio=2,
                 act1_layer=StarReLU, act2_layer=nn.Identity,
                 bias=False, kernel_size=7, padding=3,
                 **kwargs):
        super().__init__()
        med_channels = int(expansion_ratio * dim)
        self.pwconv1 = nn.Linear(dim, med_channels, bias=bias)
        self.act1 = act1_layer()
        self.dwconv = nn.Conv2d(
            med_channels, med_channels, kernel_size=kernel_size,
            padding=padding, groups=med_channels, bias=bias)
        self.act2 = act2_layer()
        self.pwconv2 = nn.Linear(med_channels, dim, bias=bias)

    def forward(self, x):
        x = self.pwconv1(x)
        x = self.act1(x)
        x = x.permute(0, 3, 1, 2)
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)
        x = self.act2(x)
        x = self.pwconv2(x)
        return x


class Pooling(nn.Module):
    """Pooling for PoolFormer."""
    def __init__(self, pool_size=3, **kwargs):
        super().__init__()
        self.pool = nn.AvgPool2d(pool_size, stride=1, padding=pool_size // 2, count_include_pad=False)

    def forward(self, x):
        y = x.permute(0, 3, 1, 2)
        y = self.pool(y)
        y = y.permute(0, 2, 3, 1)
        return y - x

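# Pooling note: the mixer returns pool(x) - x rather than pool(x). The
# MetaFormerBlock residual adds x back, so subtracting it here matches the
# PoolFormer formulation, where the token mixer models only the difference
# from the input.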

class Mlp(nn.Module):
    """MLP used in MetaFormer models."""
    def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0., bias=False, **kwargs):
        super().__init__()
        in_features = dim
        out_features = out_features or in_features
        hidden_features = int(mlp_ratio * in_features)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class MlpHead(nn.Module):
    def __init__(self, dim, num_classes=1000, mlp_ratio=4, act_layer=SquaredReLU,
                 norm_layer=nn.LayerNorm, head_dropout=0., bias=True):
        super().__init__()
        hidden_features = int(mlp_ratio * dim)
        self.fc1 = nn.Linear(dim, hidden_features, bias=bias)
        self.act = act_layer()
        self.norm = norm_layer(hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias)
        self.head_dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.norm(x)
        x = self.head_dropout(x)
        x = self.fc2(x)
        return x


class MetaFormerBlock(nn.Module):
    def __init__(self, dim,
                 token_mixer=nn.Identity, mlp=Mlp,
                 norm_layer=nn.LayerNorm,
                 drop=0., drop_path=0.,
                 layer_scale_init_value=None, res_scale_init_value=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.token_mixer = token_mixer(dim=dim, drop=drop)
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.layer_scale1 = Scale(dim=dim, init_value=layer_scale_init_value) if layer_scale_init_value else nn.Identity()
        self.res_scale1 = Scale(dim=dim, init_value=res_scale_init_value) if res_scale_init_value else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = mlp(dim=dim, drop=drop)
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.layer_scale2 = Scale(dim=dim, init_value=layer_scale_init_value) if layer_scale_init_value else nn.Identity()
        self.res_scale2 = Scale(dim=dim, init_value=res_scale_init_value) if res_scale_init_value else nn.Identity()

    def forward(self, x):
        x = self.res_scale1(x) + self.layer_scale1(self.drop_path1(self.token_mixer(self.norm1(x))))
        x = self.res_scale2(x) + self.layer_scale2(self.drop_path2(self.mlp(self.norm2(x))))
        return x

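# The two residual sub-blocks above, written out:
#   x <- res_scale1(x) + layer_scale1(DropPath(TokenMixer(Norm1(x))))
#   x <- res_scale2(x) + layer_scale2(DropPath(MLP(Norm2(x))))
# layer_scale* and res_scale* default to identity unless init values are
# given (MetaFormer below enables res_scale for the last two stages only).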

DOWNSAMPLE_LAYERS_FOUR_STAGES = [partial(Downsampling,
                                         kernel_size=7, stride=4, padding=2,
                                         post_norm=partial(LayerNormGeneral, bias=False, eps=1e-6)
                                         )] + \
                                [partial(Downsampling,
                                         kernel_size=3, stride=2, padding=1,
                                         pre_norm=partial(LayerNormGeneral, bias=False, eps=1e-6), pre_permute=True
                                         )] * 3


class MetaFormer(nn.Module):
    def __init__(self, in_chans=3, num_classes=1000,
                 depths=[2, 2, 6, 2],
                 dims=[64, 128, 320, 512],
                 downsample_layers=DOWNSAMPLE_LAYERS_FOUR_STAGES,
                 token_mixers=nn.Identity,
                 mlps=Mlp,
                 norm_layers=partial(LayerNormWithoutBias, eps=1e-6),
                 drop_path_rate=0.,
                 head_dropout=0.0,
                 layer_scale_init_values=None,
                 res_scale_init_values=[None, None, 1.0, 1.0],
                 output_norm=partial(nn.LayerNorm, eps=1e-6),
                 head_fn=nn.Linear,
                 **kwargs):
        super().__init__()
        self.num_classes = num_classes

        if not isinstance(depths, (list, tuple)):
            depths = [depths]
        if not isinstance(dims, (list, tuple)):
            dims = [dims]

        num_stage = len(depths)
        self.num_stage = num_stage

        if not isinstance(downsample_layers, (list, tuple)):
            downsample_layers = [downsample_layers] * num_stage
        down_dims = [in_chans] + dims
        self.downsample_layers = nn.ModuleList(
            [downsample_layers[i](down_dims[i], down_dims[i + 1]) for i in range(num_stage)]
        )

        if not isinstance(token_mixers, (list, tuple)):
            token_mixers = [token_mixers] * num_stage
        if not isinstance(mlps, (list, tuple)):
            mlps = [mlps] * num_stage
        if not isinstance(norm_layers, (list, tuple)):
            norm_layers = [norm_layers] * num_stage

        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        if not isinstance(layer_scale_init_values, (list, tuple)):
            layer_scale_init_values = [layer_scale_init_values] * num_stage
        if not isinstance(res_scale_init_values, (list, tuple)):
            res_scale_init_values = [res_scale_init_values] * num_stage

        self.stages = nn.ModuleList()
        cur = 0
        for i in range(num_stage):
            stage = nn.Sequential(
                *[MetaFormerBlock(dim=dims[i],
                                  token_mixer=token_mixers[i],
                                  mlp=mlps[i],
                                  norm_layer=norm_layers[i],
                                  drop_path=dp_rates[cur + j],
                                  layer_scale_init_value=layer_scale_init_values[i],
                                  res_scale_init_value=res_scale_init_values[i])
                  for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = output_norm(dims[-1])
        self.head = head_fn(dims[-1], num_classes) if head_dropout <= 0.0 else head_fn(dims[-1], num_classes, head_dropout=head_dropout)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'norm'}

    def forward_features(self, x):
        for i in range(self.num_stage):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        return self.norm(x.mean([1, 2]))

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

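# Minimal smoke-test sketch (an illustration, not part of the tool's
# pipeline; assumes random weights on CPU and the default 4-stage config):
#
#   model = MetaFormer(num_classes=10)
#   logits = model(torch.randn(2, 3, 224, 224))  # NCHW in, shape (2, 10) out
#
# Features are mean-pooled over H and W before the head, so other input
# sizes work as long as they survive the stride-4 stem and the three
# stride-2 downsamplings.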

# ---- Model factory functions (subset, extend as needed) ----

@register_model
def identityformer_s12(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[2, 2, 6, 2],
        dims=[64, 128, 320, 512],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['identityformer_s12']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model

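# Note on the factories below: load_state_dict is called with the default
# strict=True, so requesting pretrained=True together with a num_classes that
# differs from the checkpoint (1000, or 21841 for the *_in21k variants) will
# raise a size-mismatch error on the head weights.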

@register_model
def identityformer_s24(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[4, 4, 12, 4],
        dims=[64, 128, 320, 512],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['identityformer_s24']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def identityformer_s36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[64, 128, 320, 512],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['identityformer_s36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def identityformer_m36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[96, 192, 384, 768],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['identityformer_m36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def identityformer_m48(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[8, 8, 24, 8],
        dims=[96, 192, 384, 768],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['identityformer_m48']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def randformer_s12(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[2, 2, 6, 2],
        dims=[64, 128, 320, 512],
        token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['randformer_s12']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def randformer_s24(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[4, 4, 12, 4],
        dims=[64, 128, 320, 512],
        token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['randformer_s24']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def randformer_s36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[64, 128, 320, 512],
        token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['randformer_s36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def randformer_m36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[96, 192, 384, 768],
        token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['randformer_m36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def randformer_m48(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[8, 8, 24, 8],
        dims=[96, 192, 384, 768],
        token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['randformer_m48']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def poolformerv2_s12(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[2, 2, 6, 2],
        dims=[64, 128, 320, 512],
        token_mixers=Pooling,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['poolformerv2_s12']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def poolformerv2_s24(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[4, 4, 12, 4],
        dims=[64, 128, 320, 512],
        token_mixers=Pooling,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['poolformerv2_s24']
    if pretrained:
        try:
            logger.info("Loading pretrained weights for poolformerv2_s24 from: %s", model.default_cfg['url'])
            # Add a socket timeout so a stalled download cannot hang CI environments
            import socket
            original_timeout = socket.getdefaulttimeout()
            socket.setdefaulttimeout(60)  # 60 second timeout
            try:
                state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
                model.load_state_dict(state_dict)
                logger.info("Successfully loaded pretrained weights for poolformerv2_s24")
            finally:
                socket.setdefaulttimeout(original_timeout)
        except Exception as e:
            logger.warning("Failed to load pretrained weights for poolformerv2_s24: %s", e)
            logger.info("Continuing with randomly initialized weights...")
    return model


@register_model
def poolformerv2_s36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[64, 128, 320, 512],
        token_mixers=Pooling,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['poolformerv2_s36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def poolformerv2_m36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[96, 192, 384, 768],
        token_mixers=Pooling,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['poolformerv2_m36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def poolformerv2_m48(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[8, 8, 24, 8],
        dims=[96, 192, 384, 768],
        token_mixers=Pooling,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['poolformerv2_m48']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s18(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s18']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s18_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s18_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s18_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s18_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s18_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s18_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s18_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s18_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s36_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s36_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s36_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s36_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s36_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s36_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s36_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s36_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_m36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_m36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_m36_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_m36_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_m36_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_m36_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_m36_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_m36_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_m36_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_m36_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_b36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_b36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_b36_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_b36_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_b36_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_b36_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_b36_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_b36_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_b36_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_b36_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s18(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s18']
    if pretrained:
        try:
            logger.info("Loading pretrained weights for caformer_s18 from: %s", model.default_cfg['url'])
            # Add a socket timeout so a stalled download cannot hang CI environments
            import socket
            original_timeout = socket.getdefaulttimeout()
            socket.setdefaulttimeout(60)  # 60 second timeout
            try:
                state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
                model.load_state_dict(state_dict)
                logger.info("Successfully loaded pretrained weights for caformer_s18")
            finally:
                socket.setdefaulttimeout(original_timeout)
        except Exception as e:
            logger.warning("Failed to load pretrained weights for caformer_s18: %s", e)
            logger.info("Continuing with randomly initialized weights...")
    return model


@register_model
def caformer_s18_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s18_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s18_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s18_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s18_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s18_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s18_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s18_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s36_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s36_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s36_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s36_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s36_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s36_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s36_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s36_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_m36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_m36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_m36_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_m36_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_m36_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_m36_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_m36_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_m36_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_m36_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_m36_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_b36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_b36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_b36_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_b36_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_b36_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_b36_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_b36_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_b36_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_b36_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_b36_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model