MetaFormer/metaformer_models.py @ 11:c5150cceab47 (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
| author | goeckslab |
|---|---|
| date | Sat, 18 Oct 2025 03:17:09 +0000 |
| parents | 10:b0d893d04d4c |
| children | |
| 1 """ | |
| 2 MetaFormer baselines including IdentityFormer, RandFormer, PoolFormerV2, | |
| 3 ConvFormer and CAFormer. | |
| 4 Standalone implementation for Galaxy Image Learner tool (no timm dependency). | |
| 5 """ | |
| 6 import logging | |
| 7 from functools import partial | |
| 8 | |
| 9 import torch | |
| 10 import torch.nn as nn | |
| 11 import torch.nn.functional as F | |
| 12 from torch.nn.init import trunc_normal_ # use torch's built-in truncated normal | |
| 13 | |
| 14 logger = logging.getLogger(__name__) | |
| 15 | |
| 16 | |
| 17 def to_2tuple(v): | |
| 18 if isinstance(v, (list, tuple)): | |
| 19 return tuple(v) | |
| 20 return (v, v) | |
| 21 | |
| 22 | |
| 23 class DropPath(nn.Module): | |
| 24 def __init__(self, drop_prob: float = 0.0): | |
| 25 super().__init__() | |
| 26 self.drop_prob = float(drop_prob) | |
| 27 | |
| 28 def forward(self, x): | |
| 29 if self.drop_prob == 0.0 or not self.training: | |
| 30 return x | |
| 31 keep_prob = 1.0 - self.drop_prob | |
| 32 shape = (x.shape[0],) + (1,) * (x.ndim - 1) | |
| 33 random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) | |
| 34 random_tensor.floor_() | |
| 35 return x.div(keep_prob) * random_tensor | |
| 36 | |
| 37 | |
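# Usage sketch (illustrative addition, not part of the upstream MetaFormer
# code; the `_demo_*` helper name is introduced here). DropPath implements
# stochastic depth: during training it zeroes entire residual branches per
# sample and rescales survivors by 1/keep_prob so the expected activation is
# unchanged; in eval mode it is the identity.
def _demo_drop_path():
    dp = DropPath(drop_prob=0.5)
    x = torch.ones(8, 4, 4, 16)
    dp.train()
    y_train = dp(x)  # each sample is either all zeros or scaled by 2.0
    dp.eval()
    y_eval = dp(x)   # identity at inference time
    return y_train, y_eval

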
# ImageNet normalization constants
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


def register_model(fn):
    # no-op decorator to mirror timm API without dependency
    return fn


def _cfg(url: str = '', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': 1.0, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head',
        **kwargs
    }


default_cfgs = {
    'identityformer_s12': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s12.pth'),
    'identityformer_s24': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s24.pth'),
    'identityformer_s36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s36.pth'),
    'identityformer_m36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m36.pth'),
    'identityformer_m48': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m48.pth'),

    'randformer_s12': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s12.pth'),
    'randformer_s24': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s24.pth'),
    'randformer_s36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s36.pth'),
    'randformer_m36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m36.pth'),
    'randformer_m48': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m48.pth'),

    'poolformerv2_s12': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s12.pth'),
    'poolformerv2_s24': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s24.pth'),
    'poolformerv2_s36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s36.pth'),
    'poolformerv2_m36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m36.pth'),
    'poolformerv2_m48': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m48.pth'),

    'convformer_s18': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18.pth'),
    'convformer_s18_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384.pth',
        input_size=(3, 384, 384)),
    'convformer_s18_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21ft1k.pth'),
    'convformer_s18_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'convformer_s18_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21k.pth',
        num_classes=21841),

    'convformer_s36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36.pth'),
    'convformer_s36_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384.pth',
        input_size=(3, 384, 384)),
    'convformer_s36_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21ft1k.pth'),
    'convformer_s36_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'convformer_s36_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21k.pth',
        num_classes=21841),

    'convformer_m36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36.pth'),
    'convformer_m36_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384.pth',
        input_size=(3, 384, 384)),
    'convformer_m36_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21ft1k.pth'),
    'convformer_m36_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'convformer_m36_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21k.pth',
        num_classes=21841),

    'convformer_b36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36.pth'),
    'convformer_b36_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384.pth',
        input_size=(3, 384, 384)),
    'convformer_b36_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21ft1k.pth'),
    'convformer_b36_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'convformer_b36_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21k.pth',
        num_classes=21841),

    'caformer_s18': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18.pth'),
    'caformer_s18_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384.pth',
        input_size=(3, 384, 384)),
    'caformer_s18_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21ft1k.pth'),
    'caformer_s18_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'caformer_s18_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21k.pth',
        num_classes=21841),

    'caformer_s36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36.pth'),
    'caformer_s36_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384.pth',
        input_size=(3, 384, 384)),
    'caformer_s36_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21ft1k.pth'),
    'caformer_s36_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'caformer_s36_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21k.pth',
        num_classes=21841),

    'caformer_m36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36.pth'),
    'caformer_m36_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384.pth',
        input_size=(3, 384, 384)),
    'caformer_m36_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21ft1k.pth'),
    'caformer_m36_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'caformer_m36_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21k.pth',
        num_classes=21841),

    'caformer_b36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36.pth'),
    'caformer_b36_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384.pth',
        input_size=(3, 384, 384)),
    'caformer_b36_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21ft1k.pth'),
    'caformer_b36_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'caformer_b36_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21k.pth',
        num_classes=21841),
}


class Downsampling(nn.Module):
    """Downsampling implemented by a layer of convolution."""
    def __init__(self, in_channels, out_channels,
                 kernel_size, stride=1, padding=0,
                 pre_norm=None, post_norm=None, pre_permute=False):
        super().__init__()
        self.pre_norm = pre_norm(in_channels) if pre_norm else nn.Identity()
        self.pre_permute = pre_permute
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding)
        self.post_norm = post_norm(out_channels) if post_norm else nn.Identity()

    def forward(self, x):
        x = self.pre_norm(x)
        if self.pre_permute:
            x = x.permute(0, 3, 1, 2)
        x = self.conv(x)
        x = x.permute(0, 2, 3, 1)
        x = self.post_norm(x)
        return x


class Scale(nn.Module):
    """Scale vector by element multiplications."""
    def __init__(self, dim, init_value=1.0, trainable=True):
        super().__init__()
        self.scale = nn.Parameter(init_value * torch.ones(dim), requires_grad=trainable)

    def forward(self, x):
        return x * self.scale


class SquaredReLU(nn.Module):
    """Squared ReLU: https://arxiv.org/abs/2109.08668"""
    def __init__(self, inplace=False):
        super().__init__()
        self.relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
        return torch.square(self.relu(x))


class StarReLU(nn.Module):
    """StarReLU: s * relu(x) ** 2 + b"""
    def __init__(self, scale_value=1.0, bias_value=0.0,
                 scale_learnable=True, bias_learnable=True,
                 mode=None, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.relu = nn.ReLU(inplace=inplace)
        self.scale = nn.Parameter(scale_value * torch.ones(1),
                                  requires_grad=scale_learnable)
        self.bias = nn.Parameter(bias_value * torch.ones(1),
                                 requires_grad=bias_learnable)

    def forward(self, x):
        return self.scale * self.relu(x) ** 2 + self.bias


class Attention(nn.Module):
    """Vanilla self-attention from Transformer."""
    def __init__(self, dim, head_dim=32, num_heads=None, qkv_bias=False,
                 attn_drop=0., proj_drop=0., proj_bias=False, **kwargs):
        super().__init__()

        self.head_dim = head_dim
        self.scale = head_dim ** -0.5

        self.num_heads = num_heads if num_heads else dim // head_dim
        if self.num_heads == 0:
            self.num_heads = 1

        self.attention_dim = self.num_heads * self.head_dim
        self.qkv = nn.Linear(dim, self.attention_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(self.attention_dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, H, W, C = x.shape
        N = H * W
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, H, W, self.attention_dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


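# Shape sketch (illustrative addition; `_demo_attention_shapes` is a
# hypothetical helper, not part of the upstream file). Unlike a ViT-style
# Attention over (B, N, C) token sequences, this module consumes channels-last
# (B, H, W, C) feature maps and flattens the spatial grid internally.
def _demo_attention_shapes():
    attn = Attention(dim=64, head_dim=32)  # 2 heads of width 32
    x = torch.randn(2, 14, 14, 64)
    return attn(x).shape                   # torch.Size([2, 14, 14, 64])

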
class RandomMixing(nn.Module):
    def __init__(self, num_tokens=196, **kwargs):
        super().__init__()
        self.num_tokens = num_tokens
        base_matrix = torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1)
        self.register_buffer("random_matrix", base_matrix, persistent=True)

    def forward(self, x):
        B, H, W, C = x.shape
        actual_tokens = H * W

        if actual_tokens == self.random_matrix.shape[0]:
            mixing = self.random_matrix
        else:
            base = self.random_matrix
            if base.device != x.device:
                base = base.to(x.device)
            resized = F.interpolate(
                base.unsqueeze(0).unsqueeze(0),
                size=(actual_tokens, actual_tokens),
                mode='bilinear',
                align_corners=False,
            ).squeeze(0).squeeze(0)
            mixing = torch.softmax(resized, dim=-1)

        x = x.reshape(B, actual_tokens, C)
        x = torch.einsum('mn, bnc -> bmc', mixing, x)
        x = x.reshape(B, H, W, C)
        return x


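# Illustrative sketch (added here; not upstream code). When the actual token
# count differs from the buffered `num_tokens`, the fixed mixing matrix is
# bilinearly resized and re-normalized with softmax, so RandomMixing works at
# any input resolution while preserving the feature-map shape.
def _demo_random_mixing():
    mixer = RandomMixing(num_tokens=196)  # matrix built for a 14x14 grid
    x_small = torch.randn(2, 7, 7, 32)    # 49 tokens -> resize path
    x_exact = torch.randn(2, 14, 14, 32)  # 196 tokens -> buffered matrix
    return mixer(x_small).shape, mixer(x_exact).shape  # both shapes preserved

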
class LayerNormGeneral(nn.Module):
    """General LayerNorm for different situations."""
    def __init__(self, affine_shape=None, normalized_dim=(-1,), scale=True,
                 bias=True, eps=1e-5):
        super().__init__()
        self.normalized_dim = normalized_dim
        self.use_scale = scale
        self.use_bias = bias
        self.weight = nn.Parameter(torch.ones(affine_shape)) if scale else None
        self.bias = nn.Parameter(torch.zeros(affine_shape)) if bias else None
        self.eps = eps

    def forward(self, x):
        c = x - x.mean(self.normalized_dim, keepdim=True)
        s = c.pow(2).mean(self.normalized_dim, keepdim=True)
        x = c / torch.sqrt(s + self.eps)
        if self.use_scale:
            x = x * self.weight
        if self.use_bias:
            x = x + self.bias
        return x


class LayerNormWithoutBias(nn.Module):
    """Equal to partial(LayerNormGeneral, bias=False) but faster."""
    def __init__(self, normalized_shape, eps=1e-5, **kwargs):
        super().__init__()
        self.eps = eps
        self.bias = None
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)


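# Equivalence sketch (illustrative, added here under the assumption of
# channels-last input with default weights): LayerNormWithoutBias matches
# LayerNormGeneral(dim, bias=False) numerically but dispatches to the fused
# F.layer_norm kernel, which is why the docstring above calls it faster.
def _demo_layernorm_equivalence():
    x = torch.randn(2, 7, 7, 64)
    ln_general = LayerNormGeneral(64, bias=False, eps=1e-6)
    ln_fast = LayerNormWithoutBias(64, eps=1e-6)
    return torch.allclose(ln_general(x), ln_fast(x), atol=1e-5)  # True

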
class SepConv(nn.Module):
    """Inverted separable convolution from MobileNetV2."""
    def __init__(self, dim, expansion_ratio=2,
                 act1_layer=StarReLU, act2_layer=nn.Identity,
                 bias=False, kernel_size=7, padding=3,
                 **kwargs):
        super().__init__()
        med_channels = int(expansion_ratio * dim)
        self.pwconv1 = nn.Linear(dim, med_channels, bias=bias)
        self.act1 = act1_layer()
        self.dwconv = nn.Conv2d(
            med_channels, med_channels, kernel_size=kernel_size,
            padding=padding, groups=med_channels, bias=bias)
        self.act2 = act2_layer()
        self.pwconv2 = nn.Linear(med_channels, dim, bias=bias)

    def forward(self, x):
        x = self.pwconv1(x)
        x = self.act1(x)
        x = x.permute(0, 3, 1, 2)
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)
        x = self.act2(x)
        x = self.pwconv2(x)
        return x


class Pooling(nn.Module):
    """Pooling for PoolFormer."""
    def __init__(self, pool_size=3, **kwargs):
        super().__init__()
        self.pool = nn.AvgPool2d(pool_size, stride=1, padding=pool_size // 2, count_include_pad=False)

    def forward(self, x):
        y = x.permute(0, 3, 1, 2)
        y = self.pool(y)
        y = y.permute(0, 2, 3, 1)
        return y - x


class Mlp(nn.Module):
    """MLP used in MetaFormer models."""
    def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0., bias=False, **kwargs):
        super().__init__()
        in_features = dim
        out_features = out_features or in_features
        hidden_features = int(mlp_ratio * in_features)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class MlpHead(nn.Module):
    def __init__(self, dim, num_classes=1000, mlp_ratio=4, act_layer=SquaredReLU,
                 norm_layer=nn.LayerNorm, head_dropout=0., bias=True):
        super().__init__()
        hidden_features = int(mlp_ratio * dim)
        self.fc1 = nn.Linear(dim, hidden_features, bias=bias)
        self.act = act_layer()
        self.norm = norm_layer(hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias)
        self.head_dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.norm(x)
        x = self.head_dropout(x)
        x = self.fc2(x)
        return x


class MetaFormerBlock(nn.Module):
    def __init__(self, dim,
                 token_mixer=nn.Identity, mlp=Mlp,
                 norm_layer=nn.LayerNorm,
                 drop=0., drop_path=0.,
                 layer_scale_init_value=None, res_scale_init_value=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.token_mixer = token_mixer(dim=dim, drop=drop)
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.layer_scale1 = Scale(dim=dim, init_value=layer_scale_init_value) if layer_scale_init_value else nn.Identity()
        self.res_scale1 = Scale(dim=dim, init_value=res_scale_init_value) if res_scale_init_value else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = mlp(dim=dim, drop=drop)
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.layer_scale2 = Scale(dim=dim, init_value=layer_scale_init_value) if layer_scale_init_value else nn.Identity()
        self.res_scale2 = Scale(dim=dim, init_value=res_scale_init_value) if res_scale_init_value else nn.Identity()

    def forward(self, x):
        x = self.res_scale1(x) + self.layer_scale1(self.drop_path1(self.token_mixer(self.norm1(x))))
        x = self.res_scale2(x) + self.layer_scale2(self.drop_path2(self.mlp(self.norm2(x))))
        return x


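# Minimal sketch (illustrative; `_demo_metaformer_block` is a hypothetical
# helper added here). A block is residual-in, residual-out on channels-last
# tensors, so any token mixer with a matching interface can be dropped in and
# stacking blocks never changes the spatial shape.
def _demo_metaformer_block():
    block = MetaFormerBlock(dim=64, token_mixer=Pooling, mlp=Mlp)
    x = torch.randn(2, 14, 14, 64)
    return block(x).shape  # torch.Size([2, 14, 14, 64])

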
DOWNSAMPLE_LAYERS_FOUR_STAGES = [partial(Downsampling,
                                         kernel_size=7, stride=4, padding=2,
                                         post_norm=partial(LayerNormGeneral, bias=False, eps=1e-6)
                                         )] + \
                                [partial(Downsampling,
                                         kernel_size=3, stride=2, padding=1,
                                         pre_norm=partial(LayerNormGeneral, bias=False, eps=1e-6), pre_permute=True
                                         )] * 3


class MetaFormer(nn.Module):
    def __init__(self, in_chans=3, num_classes=1000,
                 depths=[2, 2, 6, 2],
                 dims=[64, 128, 320, 512],
                 downsample_layers=DOWNSAMPLE_LAYERS_FOUR_STAGES,
                 token_mixers=nn.Identity,
                 mlps=Mlp,
                 norm_layers=partial(LayerNormWithoutBias, eps=1e-6),
                 drop_path_rate=0.,
                 head_dropout=0.0,
                 layer_scale_init_values=None,
                 res_scale_init_values=[None, None, 1.0, 1.0],
                 output_norm=partial(nn.LayerNorm, eps=1e-6),
                 head_fn=nn.Linear,
                 **kwargs):
        super().__init__()
        self.num_classes = num_classes

        if not isinstance(depths, (list, tuple)):
            depths = [depths]
        if not isinstance(dims, (list, tuple)):
            dims = [dims]

        num_stage = len(depths)
        self.num_stage = num_stage

        if not isinstance(downsample_layers, (list, tuple)):
            downsample_layers = [downsample_layers] * num_stage
        down_dims = [in_chans] + dims
        self.downsample_layers = nn.ModuleList(
            [downsample_layers[i](down_dims[i], down_dims[i + 1]) for i in range(num_stage)]
        )

        if not isinstance(token_mixers, (list, tuple)):
            token_mixers = [token_mixers] * num_stage
        if not isinstance(mlps, (list, tuple)):
            mlps = [mlps] * num_stage
        if not isinstance(norm_layers, (list, tuple)):
            norm_layers = [norm_layers] * num_stage

        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        if not isinstance(layer_scale_init_values, (list, tuple)):
            layer_scale_init_values = [layer_scale_init_values] * num_stage
        if not isinstance(res_scale_init_values, (list, tuple)):
            res_scale_init_values = [res_scale_init_values] * num_stage

        self.stages = nn.ModuleList()
        cur = 0
        for i in range(num_stage):
            stage = nn.Sequential(
                *[MetaFormerBlock(dim=dims[i],
                                  token_mixer=token_mixers[i],
                                  mlp=mlps[i],
                                  norm_layer=norm_layers[i],
                                  drop_path=dp_rates[cur + j],
                                  layer_scale_init_value=layer_scale_init_values[i],
                                  res_scale_init_value=res_scale_init_values[i])
                  for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = output_norm(dims[-1])
        self.head = head_fn(dims[-1], num_classes) if head_dropout <= 0.0 else head_fn(dims[-1], num_classes, head_dropout=head_dropout)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'norm'}

    def forward_features(self, x):
        for i in range(self.num_stage):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        return self.norm(x.mean([1, 2]))

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x


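# Usage sketch (illustrative; the tiny depths/dims below are assumptions
# chosen only to keep the example fast, not an upstream configuration).
# forward_features global-average-pools over the spatial dims, so the
# backbone can be reused as a feature extractor independent of the head.
def _demo_feature_extraction():
    net = MetaFormer(num_classes=10, depths=[1, 1, 1, 1], dims=[16, 32, 64, 128])
    x = torch.randn(2, 3, 224, 224)       # NCHW input; stem converts to channels-last
    feats = net.forward_features(x)       # torch.Size([2, 128])
    logits = net(x)                       # torch.Size([2, 10])
    return feats.shape, logits.shape

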
# ---- Model factory functions (subset, extend as needed) ----


@register_model
def identityformer_s12(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[2, 2, 6, 2],
        dims=[64, 128, 320, 512],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['identityformer_s12']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def identityformer_s24(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[4, 4, 12, 4],
        dims=[64, 128, 320, 512],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['identityformer_s24']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def identityformer_s36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[64, 128, 320, 512],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['identityformer_s36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def identityformer_m36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[96, 192, 384, 768],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['identityformer_m36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def identityformer_m48(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[8, 8, 24, 8],
        dims=[96, 192, 384, 768],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['identityformer_m48']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def randformer_s12(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[2, 2, 6, 2],
        dims=[64, 128, 320, 512],
        token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['randformer_s12']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def randformer_s24(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[4, 4, 12, 4],
        dims=[64, 128, 320, 512],
        token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['randformer_s24']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def randformer_s36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[64, 128, 320, 512],
        token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['randformer_s36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def randformer_m36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[96, 192, 384, 768],
        token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['randformer_m36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def randformer_m48(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[8, 8, 24, 8],
        dims=[96, 192, 384, 768],
        token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['randformer_m48']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def poolformerv2_s12(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[2, 2, 6, 2],
        dims=[64, 128, 320, 512],
        token_mixers=Pooling,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['poolformerv2_s12']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def poolformerv2_s24(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[4, 4, 12, 4],
        dims=[64, 128, 320, 512],
        token_mixers=Pooling,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['poolformerv2_s24']
    if pretrained:
        try:
            logger.info("Loading pretrained weights for poolformerv2_s24 from: %s", model.default_cfg['url'])

            # Add a socket timeout to prevent hanging in CI environments
            import socket
            original_timeout = socket.getdefaulttimeout()
            socket.setdefaulttimeout(60)  # 60 second timeout
            try:
                state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
                model.load_state_dict(state_dict)
                logger.info("Successfully loaded pretrained weights for poolformerv2_s24")
            finally:
                socket.setdefaulttimeout(original_timeout)
        except Exception as e:
            logger.warning("Failed to load pretrained weights for poolformerv2_s24: %s", e)
            logger.info("Continuing with randomly initialized weights...")
    return model


@register_model
def poolformerv2_s36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[64, 128, 320, 512],
        token_mixers=Pooling,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['poolformerv2_s36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def poolformerv2_m36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[96, 192, 384, 768],
        token_mixers=Pooling,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['poolformerv2_m36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def poolformerv2_m48(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[8, 8, 24, 8],
        dims=[96, 192, 384, 768],
        token_mixers=Pooling,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['poolformerv2_m48']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s18(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s18']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s18_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s18_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s18_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s18_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s18_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s18_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s18_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s18_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s36_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s36_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s36_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s36_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s36_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s36_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s36_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s36_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_m36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_m36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_m36_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_m36_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_m36_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_m36_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_m36_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_m36_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_m36_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_m36_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_b36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_b36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_b36_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_b36_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_b36_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_b36_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_b36_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_b36_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_b36_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_b36_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s18(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s18']
    if pretrained:
        try:
            logger.info("Loading pretrained weights for caformer_s18 from: %s", model.default_cfg['url'])
            # Add a socket timeout to prevent hanging in CI environments
            import socket
            original_timeout = socket.getdefaulttimeout()
            socket.setdefaulttimeout(60)  # 60 second timeout
            try:
                state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
                model.load_state_dict(state_dict)
                logger.info("Successfully loaded pretrained weights for caformer_s18")
            finally:
                socket.setdefaulttimeout(original_timeout)
        except Exception as e:
            logger.warning("Failed to load pretrained weights for caformer_s18: %s", e)
            logger.info("Continuing with randomly initialized weights...")
    return model


@register_model
def caformer_s18_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s18_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s18_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s18_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s18_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s18_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s18_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s18_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s36_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s36_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s36_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s36_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s36_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s36_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s36_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s36_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_m36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_m36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_m36_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_m36_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_m36_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_m36_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_m36_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_m36_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_m36_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_m36_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_b36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_b36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_b36_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_b36_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_b36_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_b36_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_b36_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_b36_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_b36_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_b36_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model
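

# Smoke test (illustrative addition for local sanity checking, not part of
# the upstream module; pretrained=False avoids any network access). Running
# `python metaformer_models.py` builds a few representative configurations
# and verifies that each produces logits of the expected shape.
if __name__ == "__main__":
    for factory in (identityformer_s12, poolformerv2_s12, convformer_s18, caformer_s18):
        net = factory(pretrained=False, num_classes=10)
        out = net(torch.randn(1, 3, 224, 224))
        print(f"{factory.__name__}: output shape {tuple(out.shape)}")
```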
