MetaFormer/metaformer_models.py (goeckslab/image_learner)
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
""" MetaFormer baselines including IdentityFormer, RandFormer, PoolFormerV2, ConvFormer and CAFormer. Standalone implementation for Galaxy Image Learner tool (no timm dependency). """ import logging from functools import partial import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.init import trunc_normal_ # use torch's built-in truncated normal logger = logging.getLogger(__name__) def to_2tuple(v): if isinstance(v, (list, tuple)): return tuple(v) return (v, v) class DropPath(nn.Module): def __init__(self, drop_prob: float = 0.0): super().__init__() self.drop_prob = float(drop_prob) def forward(self, x): if self.drop_prob == 0.0 or not self.training: return x keep_prob = 1.0 - self.drop_prob shape = (x.shape[0],) + (1,) * (x.ndim - 1) random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) random_tensor.floor_() return x.div(keep_prob) * random_tensor # ImageNet normalization constants IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) def register_model(fn): # no-op decorator to mirror timm API without dependency return fn def _cfg(url: str = '', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'crop_pct': 1.0, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head', **kwargs } default_cfgs = { 'identityformer_s12': _cfg( url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s12.pth'), 'identityformer_s24': _cfg( url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s24.pth'), 'identityformer_s36': _cfg( url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s36.pth'), 'identityformer_m36': _cfg( url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m36.pth'), 'identityformer_m48': _cfg( url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m48.pth'), 'randformer_s12': _cfg( url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s12.pth'), 'randformer_s24': _cfg( url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s24.pth'), 'randformer_s36': _cfg( url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s36.pth'), 'randformer_m36': _cfg( url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m36.pth'), 'randformer_m48': _cfg( url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m48.pth'), 'poolformerv2_s12': _cfg( url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s12.pth'), 'poolformerv2_s24': _cfg( url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s24.pth'), 'poolformerv2_s36': _cfg( url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s36.pth'), 'poolformerv2_m36': _cfg( url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m36.pth'), 'poolformerv2_m48': _cfg( url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m48.pth'), 'convformer_s18': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18.pth'), 'convformer_s18_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384.pth', input_size=(3, 384, 384)), 'convformer_s18_in21ft1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21ft1k.pth'), 'convformer_s18_384_in21ft1k': _cfg( 
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'convformer_s18_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21k.pth',
        num_classes=21841),
    'convformer_s36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36.pth'),
    'convformer_s36_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384.pth',
        input_size=(3, 384, 384)),
    'convformer_s36_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21ft1k.pth'),
    'convformer_s36_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'convformer_s36_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21k.pth',
        num_classes=21841),
    'convformer_m36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36.pth'),
    'convformer_m36_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384.pth',
        input_size=(3, 384, 384)),
    'convformer_m36_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21ft1k.pth'),
    'convformer_m36_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'convformer_m36_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21k.pth',
        num_classes=21841),
    'convformer_b36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36.pth'),
    'convformer_b36_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384.pth',
        input_size=(3, 384, 384)),
    'convformer_b36_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21ft1k.pth'),
    'convformer_b36_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'convformer_b36_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21k.pth',
        num_classes=21841),

    'caformer_s18': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18.pth'),
    'caformer_s18_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384.pth',
        input_size=(3, 384, 384)),
    'caformer_s18_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21ft1k.pth'),
    'caformer_s18_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'caformer_s18_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21k.pth',
        num_classes=21841),
    'caformer_s36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36.pth'),
    'caformer_s36_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384.pth',
        input_size=(3, 384, 384)),
    'caformer_s36_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21ft1k.pth'),
    'caformer_s36_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'caformer_s36_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21k.pth',
        num_classes=21841),
    'caformer_m36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36.pth'),
    'caformer_m36_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384.pth',
        input_size=(3, 384, 384)),
    'caformer_m36_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21ft1k.pth'),
    'caformer_m36_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'caformer_m36_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21k.pth',
        num_classes=21841),
    'caformer_b36': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36.pth'),
    'caformer_b36_384': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384.pth',
        input_size=(3, 384, 384)),
    'caformer_b36_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21ft1k.pth'),
    'caformer_b36_384_in21ft1k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384_in21ft1k.pth',
        input_size=(3, 384, 384)),
    'caformer_b36_in21k': _cfg(
        url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21k.pth',
        num_classes=21841),
}


class Downsampling(nn.Module):
    """Downsampling implemented by a layer of convolution."""

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, pre_norm=None, post_norm=None,
                 pre_permute=False):
        super().__init__()
        self.pre_norm = pre_norm(in_channels) if pre_norm else nn.Identity()
        self.pre_permute = pre_permute
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding)
        self.post_norm = post_norm(out_channels) if post_norm else nn.Identity()

    def forward(self, x):
        x = self.pre_norm(x)
        if self.pre_permute:
            x = x.permute(0, 3, 1, 2)
        x = self.conv(x)
        x = x.permute(0, 2, 3, 1)
        x = self.post_norm(x)
        return x


class Scale(nn.Module):
    """Scale vector by element multiplications."""

    def __init__(self, dim, init_value=1.0, trainable=True):
        super().__init__()
        self.scale = nn.Parameter(init_value * torch.ones(dim), requires_grad=trainable)

    def forward(self, x):
        return x * self.scale


class SquaredReLU(nn.Module):
    """Squared ReLU: https://arxiv.org/abs/2109.08668"""

    def __init__(self, inplace=False):
        super().__init__()
        self.relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
        return torch.square(self.relu(x))


class StarReLU(nn.Module):
    """StarReLU: s * relu(x) ** 2 + b"""

    def __init__(self, scale_value=1.0, bias_value=0.0,
                 scale_learnable=True, bias_learnable=True,
                 mode=None, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.relu = nn.ReLU(inplace=inplace)
        self.scale = nn.Parameter(scale_value * torch.ones(1), requires_grad=scale_learnable)
        self.bias = nn.Parameter(bias_value * torch.ones(1), requires_grad=bias_learnable)

    def forward(self, x):
        return self.scale * self.relu(x) ** 2 + self.bias


class Attention(nn.Module):
    """Vanilla self-attention from Transformer."""

    def __init__(self, dim, head_dim=32, num_heads=None, qkv_bias=False,
                 attn_drop=0., proj_drop=0., proj_bias=False, **kwargs):
        super().__init__()
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5
        self.num_heads = num_heads if num_heads else dim // head_dim
        if self.num_heads == 0:
            self.num_heads = 1
        self.attention_dim = self.num_heads * self.head_dim
        self.qkv = nn.Linear(dim, self.attention_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(self.attention_dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, H, W, C = x.shape
        N = H * W
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, H, W, self.attention_dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class RandomMixing(nn.Module):
    def __init__(self, num_tokens=196, **kwargs):
        super().__init__()
        self.num_tokens = num_tokens
        base_matrix = torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1)
        self.register_buffer("random_matrix", base_matrix, persistent=True)

    def forward(self, x):
        B, H, W, C = x.shape
        actual_tokens = H * W
        if actual_tokens == self.random_matrix.shape[0]:
            mixing = self.random_matrix
        else:
            base = self.random_matrix
            if base.device != x.device:
                base = base.to(x.device)
            resized = F.interpolate(
                base.unsqueeze(0).unsqueeze(0),
                size=(actual_tokens, actual_tokens),
                mode='bilinear',
                align_corners=False,
            ).squeeze(0).squeeze(0)
            mixing = torch.softmax(resized, dim=-1)
        x = x.reshape(B, actual_tokens, C)
        x = torch.einsum('mn, bnc -> bmc', mixing, x)
        x = x.reshape(B, H, W, C)
        return x


class LayerNormGeneral(nn.Module):
    """General LayerNorm for different situations."""

    def __init__(self, affine_shape=None, normalized_dim=(-1,), scale=True,
                 bias=True, eps=1e-5):
        super().__init__()
        self.normalized_dim = normalized_dim
        self.use_scale = scale
        self.use_bias = bias
        self.weight = nn.Parameter(torch.ones(affine_shape)) if scale else None
        self.bias = nn.Parameter(torch.zeros(affine_shape)) if bias else None
        self.eps = eps

    def forward(self, x):
        c = x - x.mean(self.normalized_dim, keepdim=True)
        s = c.pow(2).mean(self.normalized_dim, keepdim=True)
        x = c / torch.sqrt(s + self.eps)
        if self.use_scale:
            x = x * self.weight
        if self.use_bias:
            x = x + self.bias
        return x


class LayerNormWithoutBias(nn.Module):
    """Equal to partial(LayerNormGeneral, bias=False) but faster."""

    def __init__(self, normalized_shape, eps=1e-5, **kwargs):
        super().__init__()
        self.eps = eps
        self.bias = None
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        return F.layer_norm(x, self.normalized_shape, weight=self.weight,
                            bias=self.bias, eps=self.eps)


class SepConv(nn.Module):
    """Inverted separable convolution from MobileNetV2."""

    def __init__(self, dim, expansion_ratio=2, act1_layer=StarReLU,
                 act2_layer=nn.Identity, bias=False, kernel_size=7,
                 padding=3, **kwargs):
        super().__init__()
        med_channels = int(expansion_ratio * dim)
        self.pwconv1 = nn.Linear(dim, med_channels, bias=bias)
        self.act1 = act1_layer()
        self.dwconv = nn.Conv2d(
            med_channels, med_channels, kernel_size=kernel_size,
            padding=padding, groups=med_channels, bias=bias)
        self.act2 = act2_layer()
        self.pwconv2 = nn.Linear(med_channels, dim, bias=bias)

    def forward(self, x):
        x = self.pwconv1(x)
        x = self.act1(x)
        x = x.permute(0, 3, 1, 2)
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)
        x = self.act2(x)
        x = self.pwconv2(x)
        return x


class Pooling(nn.Module):
    """Pooling for PoolFormer."""

    def __init__(self, pool_size=3, **kwargs):
        super().__init__()
        self.pool = nn.AvgPool2d(pool_size, stride=1, padding=pool_size // 2,
                                 count_include_pad=False)

    def forward(self, x):
        y = x.permute(0, 3, 1, 2)
        y = self.pool(y)
        y = y.permute(0, 2, 3, 1)
        return y - x
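
# Illustrative sketch: every token mixer above follows one contract -- it
# consumes a channels-last tensor of shape (B, H, W, C) and returns the same
# shape, which is what lets MetaFormerBlock swap mixers freely. The helper
# below is a hypothetical demo and is not called anywhere in this module.
def _token_mixer_shape_demo():
    x = torch.randn(2, 14, 14, 64)  # (batch, height, width, channels)
    for mixer in (Pooling(), SepConv(dim=64), Attention(dim=64),
                  RandomMixing(num_tokens=196)):
        y = mixer(x)
        assert y.shape == x.shape, (type(mixer).__name__, tuple(y.shape))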

class Mlp(nn.Module):
    """MLP used in MetaFormer models."""

    def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU,
                 drop=0., bias=False, **kwargs):
        super().__init__()
        in_features = dim
        out_features = out_features or in_features
        hidden_features = int(mlp_ratio * in_features)
        drop_probs = to_2tuple(drop)
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class MlpHead(nn.Module):
    def __init__(self, dim, num_classes=1000, mlp_ratio=4, act_layer=SquaredReLU,
                 norm_layer=nn.LayerNorm, head_dropout=0., bias=True):
        super().__init__()
        hidden_features = int(mlp_ratio * dim)
        self.fc1 = nn.Linear(dim, hidden_features, bias=bias)
        self.act = act_layer()
        self.norm = norm_layer(hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias)
        self.head_dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.norm(x)
        x = self.head_dropout(x)
        x = self.fc2(x)
        return x


class MetaFormerBlock(nn.Module):
    def __init__(self, dim, token_mixer=nn.Identity, mlp=Mlp,
                 norm_layer=nn.LayerNorm, drop=0., drop_path=0.,
                 layer_scale_init_value=None, res_scale_init_value=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.token_mixer = token_mixer(dim=dim, drop=drop)
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.layer_scale1 = (Scale(dim=dim, init_value=layer_scale_init_value)
                             if layer_scale_init_value else nn.Identity())
        self.res_scale1 = (Scale(dim=dim, init_value=res_scale_init_value)
                           if res_scale_init_value else nn.Identity())
        self.norm2 = norm_layer(dim)
        self.mlp = mlp(dim=dim, drop=drop)
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.layer_scale2 = (Scale(dim=dim, init_value=layer_scale_init_value)
                             if layer_scale_init_value else nn.Identity())
        self.res_scale2 = (Scale(dim=dim, init_value=res_scale_init_value)
                           if res_scale_init_value else nn.Identity())

    def forward(self, x):
        x = self.res_scale1(x) + self.layer_scale1(
            self.drop_path1(self.token_mixer(self.norm1(x))))
        x = self.res_scale2(x) + self.layer_scale2(
            self.drop_path2(self.mlp(self.norm2(x))))
        return x


DOWNSAMPLE_LAYERS_FOUR_STAGES = \
    [partial(Downsampling, kernel_size=7, stride=4, padding=2,
             post_norm=partial(LayerNormGeneral, bias=False, eps=1e-6))] + \
    [partial(Downsampling, kernel_size=3, stride=2, padding=1,
             pre_norm=partial(LayerNormGeneral, bias=False, eps=1e-6),
             pre_permute=True)] * 3


class MetaFormer(nn.Module):
    def __init__(self, in_chans=3, num_classes=1000,
                 depths=[2, 2, 6, 2],
                 dims=[64, 128, 320, 512],
                 downsample_layers=DOWNSAMPLE_LAYERS_FOUR_STAGES,
                 token_mixers=nn.Identity,
                 mlps=Mlp,
                 norm_layers=partial(LayerNormWithoutBias, eps=1e-6),
                 drop_path_rate=0.,
                 head_dropout=0.0,
                 layer_scale_init_values=None,
                 res_scale_init_values=[None, None, 1.0, 1.0],
                 output_norm=partial(nn.LayerNorm, eps=1e-6),
                 head_fn=nn.Linear,
                 **kwargs):
        super().__init__()
        self.num_classes = num_classes
        if not isinstance(depths, (list, tuple)):
            depths = [depths]
        if not isinstance(dims, (list, tuple)):
            dims = [dims]
        num_stage = len(depths)
        self.num_stage = num_stage
        if not isinstance(downsample_layers, (list, tuple)):
            downsample_layers = [downsample_layers] * num_stage
        down_dims = [in_chans] + dims
        self.downsample_layers = nn.ModuleList(
            [downsample_layers[i](down_dims[i], down_dims[i + 1]) for i in range(num_stage)]
        )
        if not isinstance(token_mixers, (list, tuple)):
            token_mixers = [token_mixers] * num_stage
        if not isinstance(mlps, (list, tuple)):
            mlps = [mlps] * num_stage
        if not isinstance(norm_layers, (list, tuple)):
            norm_layers = [norm_layers] * num_stage
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        if not isinstance(layer_scale_init_values, (list, tuple)):
            layer_scale_init_values = [layer_scale_init_values] * num_stage
        if not isinstance(res_scale_init_values, (list, tuple)):
            res_scale_init_values = [res_scale_init_values] * num_stage
        self.stages = nn.ModuleList()
        cur = 0
        for i in range(num_stage):
            stage = nn.Sequential(
                *[MetaFormerBlock(dim=dims[i],
                                  token_mixer=token_mixers[i],
                                  mlp=mlps[i],
                                  norm_layer=norm_layers[i],
                                  drop_path=dp_rates[cur + j],
                                  layer_scale_init_value=layer_scale_init_values[i],
                                  res_scale_init_value=res_scale_init_values[i])
                  for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]
        self.norm = output_norm(dims[-1])
        self.head = (head_fn(dims[-1], num_classes) if head_dropout <= 0.0
                     else head_fn(dims[-1], num_classes, head_dropout=head_dropout))
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'norm'}

    def forward_features(self, x):
        for i in range(self.num_stage):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        return self.norm(x.mean([1, 2]))

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x


# ---- Model factory functions (subset, extend as needed) ----
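
# Illustrative sketch: MetaFormer can also be configured directly, without
# the factory functions below. The per-stage mixer choice here is a
# hypothetical example, not one of the published variants, and the helper
# is never called by this module.
def _direct_metaformer_demo():
    model = MetaFormer(
        depths=[2, 2, 6, 2],
        dims=[64, 128, 320, 512],
        token_mixers=[Pooling, Pooling, Attention, Attention],
        num_classes=10,
    )
    logits = model(torch.randn(1, 3, 224, 224))
    assert logits.shape == (1, 10)
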
@register_model
def identityformer_s12(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[2, 2, 6, 2],
        dims=[64, 128, 320, 512],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3),
                            eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['identityformer_s12']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def identityformer_s24(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[4, 4, 12, 4],
        dims=[64, 128, 320, 512],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3),
                            eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['identityformer_s24']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def identityformer_s36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[64, 128, 320, 512],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3),
                            eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['identityformer_s36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def identityformer_m36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[96, 192, 384, 768],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3),
                            eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['identityformer_m36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def identityformer_m48(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[8, 8, 24, 8],
        dims=[96, 192, 384, 768],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3),
                            eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['identityformer_m48']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def randformer_s12(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[2, 2, 6, 2],
        dims=[64, 128, 320, 512],
        token_mixers=[nn.Identity, nn.Identity, RandomMixing,
                      partial(RandomMixing, num_tokens=49)],
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3),
                            eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['randformer_s12']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def randformer_s24(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[4, 4, 12, 4],
        dims=[64, 128, 320, 512],
        token_mixers=[nn.Identity, nn.Identity, RandomMixing,
                      partial(RandomMixing, num_tokens=49)],
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3),
                            eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['randformer_s24']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def randformer_s36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[64, 128, 320, 512],
        token_mixers=[nn.Identity, nn.Identity, RandomMixing,
                      partial(RandomMixing, num_tokens=49)],
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3),
                            eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['randformer_s36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def randformer_m36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[96, 192, 384, 768],
        token_mixers=[nn.Identity, nn.Identity, RandomMixing,
                      partial(RandomMixing, num_tokens=49)],
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3),
                            eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['randformer_m36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def randformer_m48(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[8, 8, 24, 8],
        dims=[96, 192, 384, 768],
        token_mixers=[nn.Identity, nn.Identity, RandomMixing,
                      partial(RandomMixing, num_tokens=49)],
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3),
                            eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['randformer_m48']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def poolformerv2_s12(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[2, 2, 6, 2],
        dims=[64, 128, 320, 512],
        token_mixers=Pooling,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3),
                            eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['poolformerv2_s12']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def poolformerv2_s24(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[4, 4, 12, 4],
        dims=[64, 128, 320, 512],
        token_mixers=Pooling,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3),
                            eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['poolformerv2_s24']
    if pretrained:
        try:
            logger.info("Loading pretrained weights for poolformerv2_s24 from: %s",
                        model.default_cfg['url'])
            # Add a socket timeout to prevent hanging in CI environments
            import socket
            original_timeout = socket.getdefaulttimeout()
            socket.setdefaulttimeout(60)  # 60-second timeout
            try:
                state_dict = torch.hub.load_state_dict_from_url(
                    url=model.default_cfg['url'], map_location="cpu", check_hash=True)
                model.load_state_dict(state_dict)
                logger.info("Successfully loaded pretrained weights for poolformerv2_s24")
            finally:
                socket.setdefaulttimeout(original_timeout)
        except Exception as e:
            logger.warning("Failed to load pretrained weights for poolformerv2_s24: %s", e)
            logger.info("Continuing with randomly initialized weights...")
    return model


@register_model
def poolformerv2_s36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[64, 128, 320, 512],
        token_mixers=Pooling,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3),
                            eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['poolformerv2_s36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model

@register_model
def poolformerv2_m36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[96, 192, 384, 768],
        token_mixers=Pooling,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3),
                            eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['poolformerv2_m36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def poolformerv2_m48(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[8, 8, 24, 8],
        dims=[96, 192, 384, 768],
        token_mixers=Pooling,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3),
                            eps=1e-6, bias=False),
        **kwargs)
    model.default_cfg = default_cfgs['poolformerv2_m48']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s18(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s18']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s18_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s18_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s18_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s18_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s18_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s18_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s18_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s18_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s36_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s36_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s36_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s36_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s36_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s36_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s36_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_s36_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_m36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_m36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_m36_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_m36_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_m36_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_m36_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_m36_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['convformer_m36_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model
map_location="cpu", check_hash=True) model.load_state_dict(state_dict) return model @register_model def convformer_b36(pretrained=False, **kwargs): model = MetaFormer( depths=[3, 12, 18, 3], dims=[128, 256, 512, 768], token_mixers=SepConv, head_fn=MlpHead, **kwargs) model.default_cfg = default_cfgs['convformer_b36'] if pretrained: state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True) model.load_state_dict(state_dict) return model @register_model def convformer_b36_384(pretrained=False, **kwargs): model = MetaFormer( depths=[3, 12, 18, 3], dims=[128, 256, 512, 768], token_mixers=SepConv, head_fn=MlpHead, **kwargs) model.default_cfg = default_cfgs['convformer_b36_384'] if pretrained: state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True) model.load_state_dict(state_dict) return model @register_model def convformer_b36_in21ft1k(pretrained=False, **kwargs): model = MetaFormer( depths=[3, 12, 18, 3], dims=[128, 256, 512, 768], token_mixers=SepConv, head_fn=MlpHead, **kwargs) model.default_cfg = default_cfgs['convformer_b36_in21ft1k'] if pretrained: state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True) model.load_state_dict(state_dict) return model @register_model def convformer_b36_384_in21ft1k(pretrained=False, **kwargs): model = MetaFormer( depths=[3, 12, 18, 3], dims=[128, 256, 512, 768], token_mixers=SepConv, head_fn=MlpHead, **kwargs) model.default_cfg = default_cfgs['convformer_b36_384_in21ft1k'] if pretrained: state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True) model.load_state_dict(state_dict) return model @register_model def convformer_b36_in21k(pretrained=False, **kwargs): model = MetaFormer( depths=[3, 12, 18, 3], dims=[128, 256, 512, 768], token_mixers=SepConv, head_fn=MlpHead, **kwargs) model.default_cfg = default_cfgs['convformer_b36_in21k'] if pretrained: state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True) model.load_state_dict(state_dict) return model @register_model def caformer_s18(pretrained=False, **kwargs): model = MetaFormer( depths=[3, 3, 9, 3], dims=[64, 128, 320, 512], token_mixers=[SepConv, SepConv, Attention, Attention], head_fn=MlpHead, **kwargs) model.default_cfg = default_cfgs['caformer_s18'] if pretrained: try: print(f"Loading pretrained weights for caformer_s18 from: {model.default_cfg['url']}") # Add timeout to prevent hanging in CI environments import socket original_timeout = socket.getdefaulttimeout() socket.setdefaulttimeout(60) # 60 second timeout try: state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True) model.load_state_dict(state_dict) print("✓ Successfully loaded pretrained weights for caformer_s18") finally: socket.setdefaulttimeout(original_timeout) except Exception as e: print(f"⚠ Warning: Failed to load pretrained weights for caformer_s18: {e}") print("Continuing with randomly initialized weights...") return model @register_model def caformer_s18_384(pretrained=False, **kwargs): model = MetaFormer( depths=[3, 3, 9, 3], dims=[64, 128, 320, 512], token_mixers=[SepConv, SepConv, Attention, Attention], head_fn=MlpHead, **kwargs) model.default_cfg = default_cfgs['caformer_s18_384'] if pretrained: state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", 
check_hash=True) model.load_state_dict(state_dict) return model @register_model def caformer_s18_in21ft1k(pretrained=False, **kwargs): model = MetaFormer( depths=[3, 3, 9, 3], dims=[64, 128, 320, 512], token_mixers=[SepConv, SepConv, Attention, Attention], head_fn=MlpHead, **kwargs) model.default_cfg = default_cfgs['caformer_s18_in21ft1k'] if pretrained: state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True) model.load_state_dict(state_dict) return model @register_model def caformer_s18_384_in21ft1k(pretrained=False, **kwargs): model = MetaFormer( depths=[3, 3, 9, 3], dims=[64, 128, 320, 512], token_mixers=[SepConv, SepConv, Attention, Attention], head_fn=MlpHead, **kwargs) model.default_cfg = default_cfgs['caformer_s18_384_in21ft1k'] if pretrained: state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True) model.load_state_dict(state_dict) return model @register_model def caformer_s18_in21k(pretrained=False, **kwargs): model = MetaFormer( depths=[3, 3, 9, 3], dims=[64, 128, 320, 512], token_mixers=[SepConv, SepConv, Attention, Attention], head_fn=MlpHead, **kwargs) model.default_cfg = default_cfgs['caformer_s18_in21k'] if pretrained: state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True) model.load_state_dict(state_dict) return model @register_model def caformer_s36(pretrained=False, **kwargs): model = MetaFormer( depths=[3, 12, 18, 3], dims=[64, 128, 320, 512], token_mixers=[SepConv, SepConv, Attention, Attention], head_fn=MlpHead, **kwargs) model.default_cfg = default_cfgs['caformer_s36'] if pretrained: state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True) model.load_state_dict(state_dict) return model @register_model def caformer_s36_384(pretrained=False, **kwargs): model = MetaFormer( depths=[3, 12, 18, 3], dims=[64, 128, 320, 512], token_mixers=[SepConv, SepConv, Attention, Attention], head_fn=MlpHead, **kwargs) model.default_cfg = default_cfgs['caformer_s36_384'] if pretrained: state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True) model.load_state_dict(state_dict) return model @register_model def caformer_s36_in21ft1k(pretrained=False, **kwargs): model = MetaFormer( depths=[3, 12, 18, 3], dims=[64, 128, 320, 512], token_mixers=[SepConv, SepConv, Attention, Attention], head_fn=MlpHead, **kwargs) model.default_cfg = default_cfgs['caformer_s36_in21ft1k'] if pretrained: state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True) model.load_state_dict(state_dict) return model @register_model def caformer_s36_384_in21ft1k(pretrained=False, **kwargs): model = MetaFormer( depths=[3, 12, 18, 3], dims=[64, 128, 320, 512], token_mixers=[SepConv, SepConv, Attention, Attention], head_fn=MlpHead, **kwargs) model.default_cfg = default_cfgs['caformer_s36_384_in21ft1k'] if pretrained: state_dict = torch.hub.load_state_dict_from_url(url=model.default_cfg['url'], map_location="cpu", check_hash=True) model.load_state_dict(state_dict) return model @register_model def caformer_s36_in21k(pretrained=False, **kwargs): model = MetaFormer( depths=[3, 12, 18, 3], dims=[64, 128, 320, 512], token_mixers=[SepConv, SepConv, Attention, Attention], head_fn=MlpHead, **kwargs) model.default_cfg = default_cfgs['caformer_s36_in21k'] if pretrained: state_dict = 

@register_model
def caformer_s36_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_s36_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_m36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_m36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_m36_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_m36_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_m36_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_m36_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_m36_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_m36_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_m36_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_m36_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_b36(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_b36']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_b36_384(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_b36_384']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_b36_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_b36_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_b36_384_in21ft1k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_b36_384_in21ft1k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_b36_in21k(pretrained=False, **kwargs):
    model = MetaFormer(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs['caformer_b36_in21k']
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model
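

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch; assumes a local run with no
    # pretrained download): build a small CAFormer with a two-class head and
    # check the output shape.
    net = caformer_s18(pretrained=False, num_classes=2)
    out = net(torch.randn(1, 3, 224, 224))
    print(out.shape)  # expected: torch.Size([1, 2])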