diff MetaFormer/metaformer_stacked_cnn.py @ 11:c5150cceab47 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
author: goeckslab
date: Sat, 18 Oct 2025 03:17:09 +0000
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/MetaFormer/metaformer_stacked_cnn.py	Sat Oct 18 03:17:09 2025 +0000
@@ -0,0 +1,428 @@
+import logging
+import os
+import sys
+from typing import Dict, List, Optional
+
+import torch
+import torch.nn as nn
+
+sys.path.insert(0, os.path.dirname(__file__))
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+SUPPORTED_PREFIXES = (
+    'identityformer_',
+    'randformer_',
+    'poolformerv2_',
+    'convformer_',
+    'caformer_',
+)
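+# A requested model name is considered supported when it starts with one of
+# these prefixes and resolves to a factory in metaformer_models (for example,
+# the default "identityformer_s12" falls under 'identityformer_').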
+
+try:
+    from metaformer_models import default_cfgs as META_DEFAULT_CFGS
+    META_MODELS_AVAILABLE = True
+    logger.info("MetaFormer models imported successfully")
+except Exception as e:
+    META_MODELS_AVAILABLE = False
+    logger.warning(f"MetaFormer models not available: {e}")
+
+
+def _resolve_metaformer_ctor(model_name: str):
+    # Look the factory up dynamically in the module namespace instead of
+    # importing every constructor explicitly.
+    try:
+        # Import the module itself for dynamic access
+        import metaformer_models
+        _factories = metaformer_models.__dict__
+        if model_name in _factories and callable(_factories[model_name]):
+            return _factories[model_name]
+    except Exception:
+        pass
+    return None
+
+
+class MetaFormerStackedCNN(nn.Module):
+    def __init__(
+        self,
+        height: int = 224,
+        width: int = 224,
+        num_channels: int = 3,
+        output_size: int = 128,
+        custom_model: str = "identityformer_s12",
+        use_pretrained: bool = True,
+        trainable: bool = True,
+        conv_layers: Optional[List[Dict]] = None,
+        num_conv_layers: Optional[int] = None,
+        conv_activation: str = "relu",
+        conv_dropout: float = 0.0,
+        conv_norm: Optional[str] = None,
+        conv_use_bias: bool = True,
+        fc_layers: Optional[List[Dict]] = None,
+        num_fc_layers: int = 1,
+        fc_activation: str = "relu",
+        fc_dropout: float = 0.0,
+        fc_norm: Optional[str] = None,
+        fc_use_bias: bool = True,
+        **kwargs,
+    ):
+        super().__init__()
+        logger.info("MetaFormerStackedCNN encoder instantiated")
+        logger.info(f"Using MetaFormer model: {custom_model}")
+
+        try:
+            height = int(height)
+            width = int(width)
+            num_channels = int(num_channels)
+        except (TypeError, ValueError) as exc:
+            raise ValueError("MetaFormerStackedCNN requires integer height, width, and num_channels.") from exc
+
+        if height <= 0 or width <= 0:
+            raise ValueError(f"MetaFormerStackedCNN received non-positive dimensions: {height}x{width}.")
+        if num_channels <= 0:
+            raise ValueError(f"MetaFormerStackedCNN requires num_channels > 0, received {num_channels}.")
+
+        self.height = height
+        self.width = width
+        self.num_channels = num_channels
+        self.output_size = output_size
+        self.custom_model = custom_model
+        self.use_pretrained = use_pretrained
+        self.trainable = trainable
+
+        if not META_MODELS_AVAILABLE:
+            raise ImportError(
+                "MetaFormer models are not available; MetaFormerStackedCNN cannot be constructed."
+            )
+        cfg = META_DEFAULT_CFGS.get(custom_model, {})
+        input_size = cfg.get('input_size', (3, 224, 224))
+        if isinstance(input_size, (list, tuple)) and len(input_size) == 3:
+            expected_channels, expected_height, expected_width = input_size
+        else:
+            expected_channels, expected_height, expected_width = 3, 224, 224
+
+        self.expected_channels = expected_channels
+        self.expected_height = expected_height
+        self.expected_width = expected_width
+
+        logger.info(f"Initializing MetaFormerStackedCNN with model: {custom_model}")
+        logger.info(
+            "Input: %sx%sx%s -> Output: %s (expected backbone size: %sx%s)",
+            num_channels,
+            height,
+            width,
+            output_size,
+            self.expected_height,
+            self.expected_width,
+        )
+
+        self.channel_adapter: Optional[nn.Conv2d] = None
+        if num_channels != self.expected_channels:
+            self.channel_adapter = nn.Conv2d(
+                num_channels, self.expected_channels, kernel_size=1, stride=1, padding=0
+            )
+            logger.info(
+                "Added channel adapter: %s -> %s channels",
+                num_channels,
+                self.expected_channels,
+            )
+
+        self.size_adapter: Optional[nn.Module] = None
+        if height != self.expected_height or width != self.expected_width:
+            self.size_adapter = nn.AdaptiveAvgPool2d((height, width))
+            logger.info(
+                "Configured size adapter to requested input: %sx%s",
+                height,
+                width,
+            )
+        self.backbone_adapter: Optional[nn.Module] = None
+
+        self.backbone = self._load_metaformer_backbone()
+        self.feature_dim = self._get_feature_dim()
+
+        self.fc_layers = self._create_fc_layers(
+            input_dim=self.feature_dim,
+            output_dim=output_size,
+            num_layers=num_fc_layers,
+            activation=fc_activation,
+            dropout=fc_dropout,
+            norm=fc_norm,
+            use_bias=fc_use_bias,
+            fc_layers_config=fc_layers,
+        )
+
+        if not trainable:
+            for param in self.backbone.parameters():
+                param.requires_grad = False
+            logger.info("MetaFormer backbone frozen (trainable=False)")
+
+        logger.info("MetaFormerStackedCNN initialized successfully")
+
+    def _load_metaformer_backbone(self):
+        if not META_MODELS_AVAILABLE:
+            raise ImportError("MetaFormer models are not available")
+
+        ctor = _resolve_metaformer_ctor(self.custom_model)
+        if ctor is None:
+            raise ValueError(f"Unknown MetaFormer model: {self.custom_model}")
+
+        cfg = META_DEFAULT_CFGS.get(self.custom_model, {})
+        weights_url = cfg.get('url')
+        # Track whether pretrained weights were actually downloaded and applied.
+        self._pretrained_loaded = False
+        self._loaded_weights_url: Optional[str] = None
+        if self.use_pretrained and weights_url:
+            print(f"LOADING MetaFormer pretrained weights from: {weights_url}")
+            logger.info(f"Loading pretrained weights from: {weights_url}")
+        # Ensure we log whenever the factories call torch.hub.load_state_dict_from_url
+        orig_loader = getattr(torch.hub, 'load_state_dict_from_url', None)
+
+        def _wrapped_loader(url, *args, **kwargs):
+            print(f"DOWNLOADING weights from: {url}")
+            logger.info(f"DOWNLOADING weights from: {url}")
+            self._pretrained_loaded = True
+            self._loaded_weights_url = url
+            result = orig_loader(url, *args, **kwargs)
+            print(f"WEIGHTS DOWNLOADED successfully from: {url}")
+            return result
+        try:
+            if self.use_pretrained and orig_loader is not None:
+                torch.hub.load_state_dict_from_url = _wrapped_loader  # type: ignore[attr-defined]
+            print(f"CREATING MetaFormer model: {self.custom_model} (pretrained={self.use_pretrained})")
+            try:
+                model = ctor(pretrained=self.use_pretrained, num_classes=1000)
+                print(f"MetaFormer model CREATED: {self.custom_model}")
+            except Exception as model_error:
+                if self.use_pretrained:
+                    print(f"⚠ Warning: Failed to load {self.custom_model} with pretrained weights: {model_error}")
+                    print("Attempting to load without pretrained weights as fallback...")
+                    logger.warning(f"Failed to load {self.custom_model} with pretrained weights: {model_error}")
+                    model = ctor(pretrained=False, num_classes=1000)
+                    print(f"✓ Successfully loaded {self.custom_model} without pretrained weights")
+                    self.use_pretrained = False  # Update state to reflect actual loading
+                else:
+                    raise model_error
+        finally:
+            if orig_loader is not None:
+                torch.hub.load_state_dict_from_url = orig_loader  # type: ignore[attr-defined]
+        self._metaformer_weights_url = weights_url
+        if self.use_pretrained:
+            if self._pretrained_loaded:
+                print(f"MetaFormer: pretrained weights loaded from {self._loaded_weights_url}")
+                logger.info(f"MetaFormer: pretrained weights loaded from {self._loaded_weights_url}")
+            else:
+                # Warn but don't fail - weights may have failed to load but model creation succeeded
+                print("⚠ Warning: MetaFormer pretrained weights were requested but not confirmed as loaded")
+                logger.warning("MetaFormer: pretrained weights were requested but not confirmed as loaded")
+        else:
+            print(f"MetaFormer: using randomly initialized weights for {self.custom_model}")
+            logger.info(f"MetaFormer: using randomly initialized weights for {self.custom_model}")
+        logger.info(f"Loaded MetaFormer backbone: {self.custom_model} (pretrained={self.use_pretrained})")
+        return model
+
+    def _get_feature_dim(self):
+        with torch.no_grad():
+            dummy_input = torch.randn(
+                1, self.expected_channels, self.expected_height, self.expected_width
+            )
+            features = self.backbone.forward_features(dummy_input)
+            feature_dim = features.shape[-1]
+        logger.info(f"MetaFormer feature dimension: {feature_dim}")
+        return feature_dim
+
+    def _create_fc_layers(self, input_dim, output_dim, num_layers, activation, dropout, norm, use_bias, fc_layers_config):
+        layers = []
+        if fc_layers_config:
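+            # Illustrative fc_layers_config shape (assumed example values):
+            #     [{"output_size": 256}, {"output_size": 128}]
+            # A missing 'output_size' falls back to the encoder's output_size
+            # for the final layer, or to the running dimension otherwise.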
+            current_dim = input_dim
+            for i, layer_config in enumerate(fc_layers_config):
+                layer_output_dim = layer_config.get('output_size', output_dim if i == len(fc_layers_config) - 1 else current_dim)
+                layers.append(nn.Linear(current_dim, layer_output_dim, bias=use_bias))
+                if i < len(fc_layers_config) - 1:
+                    if activation == "relu":
+                        layers.append(nn.ReLU())
+                    elif activation == "tanh":
+                        layers.append(nn.Tanh())
+                    elif activation == "sigmoid":
+                        layers.append(nn.Sigmoid())
+                    elif activation == "leaky_relu":
+                        layers.append(nn.LeakyReLU())
+                if dropout > 0:
+                    layers.append(nn.Dropout(dropout))
+                if norm == "batch":
+                    layers.append(nn.BatchNorm1d(layer_output_dim))
+                elif norm == "layer":
+                    layers.append(nn.LayerNorm(layer_output_dim))
+                current_dim = layer_output_dim
+        else:
+            if num_layers == 1:
+                layers.append(nn.Linear(input_dim, output_dim, bias=use_bias))
+            else:
+                intermediate_dims = [input_dim]
+                for i in range(num_layers - 1):
+                    intermediate_dim = int(input_dim * (0.5 ** (i + 1)))
+                    intermediate_dim = max(intermediate_dim, output_dim)
+                    intermediate_dims.append(intermediate_dim)
+                intermediate_dims.append(output_dim)
+                for i in range(num_layers):
+                    layers.append(nn.Linear(intermediate_dims[i], intermediate_dims[i + 1], bias=use_bias))
+                    if i < num_layers - 1:
+                        if activation == "relu":
+                            layers.append(nn.ReLU())
+                        elif activation == "tanh":
+                            layers.append(nn.Tanh())
+                        elif activation == "sigmoid":
+                            layers.append(nn.Sigmoid())
+                        elif activation == "leaky_relu":
+                            layers.append(nn.LeakyReLU())
+                    if dropout > 0:
+                        layers.append(nn.Dropout(dropout))
+                    if norm == "batch":
+                        layers.append(nn.BatchNorm1d(intermediate_dims[i + 1]))
+                    elif norm == "layer":
+                        layers.append(nn.LayerNorm(intermediate_dims[i + 1]))
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
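+        # Adapter pipeline: (1) a 1x1 conv maps incoming channels to the
+        # backbone's expected channel count, (2) adaptive average pooling
+        # resizes to the requested height/width, and (3) a final pooling step
+        # aligns to the backbone's expected input size before forward_features.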
+        if x.shape[1] != self.expected_channels:
+            if (
+                self.channel_adapter is None
+                or self.channel_adapter.in_channels != x.shape[1]
+                or self.channel_adapter.out_channels != self.expected_channels
+            ):
+                self.channel_adapter = nn.Conv2d(
+                    x.shape[1],
+                    self.expected_channels,
+                    kernel_size=1,
+                    stride=1,
+                    padding=0,
+                ).to(x.device)
+                logger.info(
+                    "Created dynamic channel adapter: %s -> %s channels",
+                    x.shape[1],
+                    self.expected_channels,
+                )
+            x = self.channel_adapter(x)
+
+        target_height, target_width = self.height, self.width
+        if x.shape[2] != target_height or x.shape[3] != target_width:
+            if (
+                self.size_adapter is None
+                or getattr(self.size_adapter, "output_size", None)
+                != (target_height, target_width)
+            ):
+                self.size_adapter = nn.AdaptiveAvgPool2d(
+                    (target_height, target_width)
+                ).to(x.device)
+                logger.info(
+                    "Created size adapter: %sx%s -> %sx%s",
+                    x.shape[2],
+                    x.shape[3],
+                    target_height,
+                    target_width,
+                )
+            x = self.size_adapter(x)
+
+        if target_height != self.expected_height or target_width != self.expected_width:
+            if (
+                self.backbone_adapter is None
+                or getattr(self.backbone_adapter, "output_size", None)
+                != (self.expected_height, self.expected_width)
+            ):
+                self.backbone_adapter = nn.AdaptiveAvgPool2d(
+                    (self.expected_height, self.expected_width)
+                ).to(x.device)
+                logger.info(
+                    "Aligning to MetaFormer backbone size: %sx%s",
+                    self.expected_height,
+                    self.expected_width,
+                )
+            x = self.backbone_adapter(x)
+
+        features = self.backbone.forward_features(x)
+        output = self.fc_layers(features)
+        return {'encoder_output': output}
+
+    @property
+    def output_shape(self):
+        return [self.output_size]
+
+
+def create_metaformer_stacked_cnn(model_name: str, **kwargs) -> MetaFormerStackedCNN:
+    encoder = MetaFormerStackedCNN(custom_model=model_name, **kwargs)
+    return encoder
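+
+# Illustrative direct usage (a sketch, not part of the original module):
+#
+#     encoder = create_metaformer_stacked_cnn(
+#         "identityformer_s12",
+#         height=224, width=224, num_channels=3,
+#         output_size=128, use_pretrained=False,
+#     )
+#     out = encoder(torch.randn(2, 3, 224, 224))["encoder_output"]  # shape (2, 128)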
+
+
+def patch_ludwig_stacked_cnn():
+    # Only patch Ludwig if MetaFormer models are available in this runtime
+    if not META_MODELS_AVAILABLE:
+        logger.warning("MetaFormer models unavailable; skipping Ludwig patch for stacked_cnn.")
+        return False
+    return patch_ludwig_direct()
+
+
+def _is_supported_metaformer(custom_model: Optional[str]) -> bool:
+    return bool(custom_model) and custom_model.startswith(SUPPORTED_PREFIXES)
+
+
+def patch_ludwig_direct():
+    try:
+        from ludwig.encoders.image.base import Stacked2DCNN
+        original_stacked_cnn_init = Stacked2DCNN.__init__
+
+        def patched_stacked_cnn_init(self, *args, **kwargs):
+            custom_model = kwargs.pop("custom_model", None)
+            if custom_model is None:
+                custom_model = getattr(patch_ludwig_direct, '_metaformer_model', None)
+
+            try:
+                if META_MODELS_AVAILABLE and _is_supported_metaformer(custom_model):
+                    print(f"DETECTED MetaFormer model: {custom_model}")
+                    print("MetaFormer encoder is being loaded and used.")
+                    # Initialize base class to keep Ludwig internals intact
+                    original_stacked_cnn_init(self, *args, **kwargs)
+                    # Create our MetaFormer encoder and graft behavior
+                    mf_encoder = create_metaformer_stacked_cnn(custom_model, **kwargs)
+                    # Overwrite base-class attributes with the MetaFormer
+                    # encoder's counterparts so they are not used accidentally
+                    for attr in ("conv_layers", "fc_layers", "combiner", "output_shape", "reduce_output"):
+                        if hasattr(self, attr):
+                            try:
+                                setattr(self, attr, getattr(mf_encoder, attr, None))
+                            except Exception:
+                                pass
+                    self.forward = mf_encoder.forward
+                    if hasattr(mf_encoder, 'backbone'):
+                        self.backbone = mf_encoder.backbone
+                    if hasattr(mf_encoder, 'fc_layers'):
+                        self.fc_layers = mf_encoder.fc_layers
+                    if hasattr(mf_encoder, 'custom_model'):
+                        self.custom_model = mf_encoder.custom_model
+                    # explicit confirmation logs
+                    try:
+                        url_info = getattr(mf_encoder, '_loaded_weights_url', None)
+                        loaded_flag = getattr(mf_encoder, '_pretrained_loaded', False)
+                        if loaded_flag and url_info:
+                            print(f"CONFIRMED: MetaFormer '{custom_model}' using pretrained weights from: {url_info}")
+                            logger.info(f"CONFIRMED: MetaFormer '{custom_model}' using pretrained weights from: {url_info}")
+                        else:
+                            print(f"CONFIRMED: MetaFormer '{custom_model}' using randomly initialized weights (no pretrained)")
+                            logger.info(f"CONFIRMED: MetaFormer '{custom_model}' using randomly initialized weights")
+                    except Exception:
+                        pass
+                else:
+                    original_stacked_cnn_init(self, *args, **kwargs)
+            finally:
+                if hasattr(patch_ludwig_direct, '_metaformer_model'):
+                    patch_ludwig_direct._metaformer_model = None
+
+        Stacked2DCNN.__init__ = patched_stacked_cnn_init
+        return True
+    except Exception as e:
+        logger.error(f"Failed to apply MetaFormer direct patch: {e}")
+        return False
+
+
+def set_current_metaformer_model(model_name: str):
+    """Store the current MetaFormer model name for the patch to use."""
+    setattr(patch_ludwig_direct, '_metaformer_model', model_name)
+
+
+def clear_current_metaformer_model():
+    """Remove any cached MetaFormer model hint."""
+    if hasattr(patch_ludwig_direct, '_metaformer_model'):
+        delattr(patch_ludwig_direct, '_metaformer_model')
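+
+
+# Illustrative end-to-end sketch (assumed workflow, not part of the original
+# module): apply the patch, hint the desired backbone, then let Ludwig build
+# its Stacked2DCNN image encoder as usual.
+#
+#     if patch_ludwig_stacked_cnn():
+#         set_current_metaformer_model("identityformer_s12")
+#         # ... build/train the Ludwig model; the patched Stacked2DCNN.__init__
+#         # grafts the MetaFormer encoder when the hinted name is supported ...
+#         clear_current_metaformer_model()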