diff MetaFormer/metaformer_stacked_cnn.py @ 11:c5150cceab47 (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
| author | goeckslab |
| --- | --- |
| date | Sat, 18 Oct 2025 03:17:09 +0000 |
| parents | |
| children | |
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/MetaFormer/metaformer_stacked_cnn.py	Sat Oct 18 03:17:09 2025 +0000
@@ -0,0 +1,428 @@

import logging
import os
import sys
from typing import Dict, List, Optional

import torch
import torch.nn as nn

sys.path.insert(0, os.path.dirname(__file__))

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)

SUPPORTED_PREFIXES = (
    'identityformer_',
    'randformer_',
    'poolformerv2_',
    'convformer_',
    'caformer_',
)

try:
    from metaformer_models import default_cfgs as META_DEFAULT_CFGS
    META_MODELS_AVAILABLE = True
    logger.info("MetaFormer models imported successfully")
except Exception as e:
    META_MODELS_AVAILABLE = False
    logger.warning(f"MetaFormer models not available: {e}")


def _resolve_metaformer_ctor(model_name: str):
    # Prefer getattr to avoid importing every factory explicitly
    try:
        # Import the module itself for dynamic access
        import metaformer_models
        _factories = metaformer_models.__dict__
        if model_name in _factories and callable(_factories[model_name]):
            return _factories[model_name]
    except Exception:
        pass
    return None


class MetaFormerStackedCNN(nn.Module):
    def __init__(
        self,
        height: int = 224,
        width: int = 224,
        num_channels: int = 3,
        output_size: int = 128,
        custom_model: str = "identityformer_s12",
        use_pretrained: bool = True,
        trainable: bool = True,
        conv_layers: Optional[List[Dict]] = None,
        num_conv_layers: Optional[int] = None,
        conv_activation: str = "relu",
        conv_dropout: float = 0.0,
        conv_norm: Optional[str] = None,
        conv_use_bias: bool = True,
        fc_layers: Optional[List[Dict]] = None,
        num_fc_layers: int = 1,
        fc_activation: str = "relu",
        fc_dropout: float = 0.0,
        fc_norm: Optional[str] = None,
        fc_use_bias: bool = True,
        **kwargs,
    ):
        super().__init__()
        logger.info("MetaFormerStackedCNN encoder instantiated")
        logger.info(f"Using MetaFormer model: {custom_model}")

        try:
            height = int(height)
            width = int(width)
            num_channels = int(num_channels)
        except (TypeError, ValueError) as exc:
            raise ValueError("MetaFormerStackedCNN requires integer height, width, and num_channels.") from exc

        if height <= 0 or width <= 0:
            raise ValueError(f"MetaFormerStackedCNN received non-positive dimensions: {height}x{width}.")
        if num_channels <= 0:
            raise ValueError(f"MetaFormerStackedCNN requires num_channels > 0, received {num_channels}.")

        self.height = height
        self.width = width
        self.num_channels = num_channels
        self.output_size = output_size
        self.custom_model = custom_model
        self.use_pretrained = use_pretrained
        self.trainable = trainable

        cfg = META_DEFAULT_CFGS.get(custom_model, {})
        input_size = cfg.get('input_size', (3, 224, 224))
        if isinstance(input_size, (list, tuple)) and len(input_size) == 3:
            expected_channels, expected_height, expected_width = input_size
        else:
            expected_channels, expected_height, expected_width = 3, 224, 224

        self.expected_channels = expected_channels
        self.expected_height = expected_height
        self.expected_width = expected_width

        logger.info(f"Initializing MetaFormerStackedCNN with model: {custom_model}")
        logger.info(
            "Input: %sx%sx%s -> Output: %s (expected backbone size: %sx%s)",
            num_channels,
            height,
            width,
            output_size,
            self.expected_height,
            self.expected_width,
        )

        self.channel_adapter: Optional[nn.Conv2d] = None
        if num_channels != self.expected_channels:
            self.channel_adapter = nn.Conv2d(
                num_channels, self.expected_channels, kernel_size=1, stride=1, padding=0
            )
            logger.info(
                "Added channel adapter: %s -> %s channels",
                num_channels,
                self.expected_channels,
            )

        self.size_adapter: Optional[nn.Module] = None
        if height != self.expected_height or width != self.expected_width:
            self.size_adapter = nn.AdaptiveAvgPool2d((height, width))
            logger.info(
                "Configured size adapter to requested input: %sx%s",
                height,
                width,
            )
        self.backbone_adapter: Optional[nn.Module] = None

        self.backbone = self._load_metaformer_backbone()
        self.feature_dim = self._get_feature_dim()

        self.fc_layers = self._create_fc_layers(
            input_dim=self.feature_dim,
            output_dim=output_size,
            num_layers=num_fc_layers,
            activation=fc_activation,
            dropout=fc_dropout,
            norm=fc_norm,
            use_bias=fc_use_bias,
            fc_layers_config=fc_layers,
        )

        if not trainable:
            for param in self.backbone.parameters():
                param.requires_grad = False
            logger.info("MetaFormer backbone frozen (trainable=False)")

        logger.info("MetaFormerStackedCNN initialized successfully")

    def _load_metaformer_backbone(self):
        if not META_MODELS_AVAILABLE:
            raise ImportError("MetaFormer models are not available")

        ctor = _resolve_metaformer_ctor(self.custom_model)
        if ctor is None:
            raise ValueError(f"Unknown MetaFormer model: {self.custom_model}")

        cfg = META_DEFAULT_CFGS.get(self.custom_model, {})
        weights_url = cfg.get('url')
        # track loading
        self._pretrained_loaded = False
        self._loaded_weights_url: Optional[str] = None
        if self.use_pretrained and weights_url:
            print(f"LOADING MetaFormer pretrained weights from: {weights_url}")
            logger.info(f"Loading pretrained weights from: {weights_url}")
        # Ensure we log whenever the factories call torch.hub.load_state_dict_from_url
        orig_loader = getattr(torch.hub, 'load_state_dict_from_url', None)

        def _wrapped_loader(url, *args, **kwargs):
            print(f"DOWNLOADING weights from: {url}")
            logger.info(f"DOWNLOADING weights from: {url}")
            self._pretrained_loaded = True
            self._loaded_weights_url = url
            result = orig_loader(url, *args, **kwargs)
            print(f"WEIGHTS DOWNLOADED successfully from: {url}")
            return result

        try:
            if self.use_pretrained and orig_loader is not None:
                torch.hub.load_state_dict_from_url = _wrapped_loader  # type: ignore[attr-defined]
            print(f"CREATING MetaFormer model: {self.custom_model} (pretrained={self.use_pretrained})")
            try:
                model = ctor(pretrained=self.use_pretrained, num_classes=1000)
                print(f"MetaFormer model CREATED: {self.custom_model}")
            except Exception as model_error:
                if self.use_pretrained:
                    print(f"⚠ Warning: Failed to load {self.custom_model} with pretrained weights: {model_error}")
                    print("Attempting to load without pretrained weights as fallback...")
                    logger.warning(f"Failed to load {self.custom_model} with pretrained weights: {model_error}")
                    model = ctor(pretrained=False, num_classes=1000)
                    print(f"✓ Successfully loaded {self.custom_model} without pretrained weights")
                    self.use_pretrained = False  # Update state to reflect actual loading
                else:
                    raise model_error
        finally:
            if orig_loader is not None:
                torch.hub.load_state_dict_from_url = orig_loader  # type: ignore[attr-defined]

        self._metaformer_weights_url = weights_url
        if self.use_pretrained:
            if self._pretrained_loaded:
                print(f"MetaFormer: pretrained weights loaded from {self._loaded_weights_url}")
                logger.info(f"MetaFormer: pretrained weights loaded from {self._loaded_weights_url}")
            else:
                # Warn but don't fail - weights may have failed to load but model creation succeeded
                print("⚠ Warning: MetaFormer pretrained weights were requested but not confirmed as loaded")
                logger.warning("MetaFormer: pretrained weights were requested but not confirmed as loaded")
        else:
            print(f"MetaFormer: using randomly initialized weights for {self.custom_model}")
            logger.info(f"MetaFormer: using randomly initialized weights for {self.custom_model}")
        logger.info(f"Loaded MetaFormer backbone: {self.custom_model} (pretrained={self.use_pretrained})")
        return model

    def _get_feature_dim(self):
        with torch.no_grad():
            dummy_input = torch.randn(1, 3, 224, 224)
            features = self.backbone.forward_features(dummy_input)
            feature_dim = features.shape[-1]
        logger.info(f"MetaFormer feature dimension: {feature_dim}")
        return feature_dim

    def _create_fc_layers(self, input_dim, output_dim, num_layers, activation, dropout, norm, use_bias, fc_layers_config):
        layers = []
        if fc_layers_config:
            current_dim = input_dim
            for i, layer_config in enumerate(fc_layers_config):
                layer_output_dim = layer_config.get('output_size', output_dim if i == len(fc_layers_config) - 1 else current_dim)
                layers.append(nn.Linear(current_dim, layer_output_dim, bias=use_bias))
                if i < len(fc_layers_config) - 1:
                    if activation == "relu":
                        layers.append(nn.ReLU())
                    elif activation == "tanh":
                        layers.append(nn.Tanh())
                    elif activation == "sigmoid":
                        layers.append(nn.Sigmoid())
                    elif activation == "leaky_relu":
                        layers.append(nn.LeakyReLU())
                    if dropout > 0:
                        layers.append(nn.Dropout(dropout))
                    if norm == "batch":
                        layers.append(nn.BatchNorm1d(layer_output_dim))
                    elif norm == "layer":
                        layers.append(nn.LayerNorm(layer_output_dim))
                current_dim = layer_output_dim
        else:
            if num_layers == 1:
                layers.append(nn.Linear(input_dim, output_dim, bias=use_bias))
            else:
                intermediate_dims = [input_dim]
                for i in range(num_layers - 1):
                    intermediate_dim = int(input_dim * (0.5 ** (i + 1)))
                    intermediate_dim = max(intermediate_dim, output_dim)
                    intermediate_dims.append(intermediate_dim)
                intermediate_dims.append(output_dim)
                for i in range(num_layers):
                    layers.append(nn.Linear(intermediate_dims[i], intermediate_dims[i + 1], bias=use_bias))
                    if i < num_layers - 1:
                        if activation == "relu":
                            layers.append(nn.ReLU())
                        elif activation == "tanh":
                            layers.append(nn.Tanh())
                        elif activation == "sigmoid":
                            layers.append(nn.Sigmoid())
                        elif activation == "leaky_relu":
                            layers.append(nn.LeakyReLU())
                        if dropout > 0:
                            layers.append(nn.Dropout(dropout))
                        if norm == "batch":
                            layers.append(nn.BatchNorm1d(intermediate_dims[i + 1]))
                        elif norm == "layer":
                            layers.append(nn.LayerNorm(intermediate_dims[i + 1]))
        return nn.Sequential(*layers)

    def forward(self, x):
        if x.shape[1] != self.expected_channels:
            if (
                self.channel_adapter is None
                or self.channel_adapter.in_channels != x.shape[1]
                or self.channel_adapter.out_channels != self.expected_channels
            ):
                self.channel_adapter = nn.Conv2d(
                    x.shape[1],
                    self.expected_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                ).to(x.device)
                logger.info(
                    "Created dynamic channel adapter: %s -> %s channels",
                    x.shape[1],
                    self.expected_channels,
                )
            x = self.channel_adapter(x)

        target_height, target_width = self.height, self.width
        if x.shape[2] != target_height or x.shape[3] != target_width:
            if (
                self.size_adapter is None
                or getattr(self.size_adapter, "output_size", None)
                != (target_height, target_width)
            ):
                self.size_adapter = nn.AdaptiveAvgPool2d(
                    (target_height, target_width)
                ).to(x.device)
                logger.info(
                    "Created size adapter: %sx%s -> %sx%s",
                    x.shape[2],
                    x.shape[3],
                    target_height,
                    target_width,
                )
            x = self.size_adapter(x)

        if target_height != self.expected_height or target_width != self.expected_width:
            if (
                self.backbone_adapter is None
                or getattr(self.backbone_adapter, "output_size", None)
                != (self.expected_height, self.expected_width)
            ):
                self.backbone_adapter = nn.AdaptiveAvgPool2d(
                    (self.expected_height, self.expected_width)
                ).to(x.device)
                logger.info(
                    "Aligning to MetaFormer backbone size: %sx%s",
                    self.expected_height,
                    self.expected_width,
                )
            x = self.backbone_adapter(x)

        features = self.backbone.forward_features(x)
        output = self.fc_layers(features)
        return {'encoder_output': output}

    @property
    def output_shape(self):
        return [self.output_size]


def create_metaformer_stacked_cnn(model_name: str, **kwargs) -> MetaFormerStackedCNN:
    encoder = MetaFormerStackedCNN(custom_model=model_name, **kwargs)
    return encoder


def patch_ludwig_stacked_cnn():
    # Only patch Ludwig if MetaFormer models are available in this runtime
    if not META_MODELS_AVAILABLE:
        logger.warning("MetaFormer models unavailable; skipping Ludwig patch for stacked_cnn.")
        return False
    return patch_ludwig_direct()


def _is_supported_metaformer(custom_model: Optional[str]) -> bool:
    return bool(custom_model) and custom_model.startswith(SUPPORTED_PREFIXES)


def patch_ludwig_direct():
    try:
        from ludwig.encoders.image.base import Stacked2DCNN
        original_stacked_cnn_init = Stacked2DCNN.__init__

        def patched_stacked_cnn_init(self, *args, **kwargs):
            custom_model = kwargs.pop("custom_model", None)
            if custom_model is None:
                custom_model = getattr(patch_ludwig_direct, '_metaformer_model', None)

            try:
                if META_MODELS_AVAILABLE and _is_supported_metaformer(custom_model):
                    print(f"DETECTED MetaFormer model: {custom_model}")
                    print("MetaFormer encoder is being loaded and used.")
                    # Initialize base class to keep Ludwig internals intact
                    original_stacked_cnn_init(self, *args, **kwargs)
                    # Create our MetaFormer encoder and graft behavior
                    mf_encoder = create_metaformer_stacked_cnn(custom_model, **kwargs)
                    # ensure base attributes won't be used accidentally
                    for attr in ("conv_layers", "fc_layers", "combiner", "output_shape", "reduce_output"):
                        if hasattr(self, attr):
                            try:
                                setattr(self, attr, getattr(mf_encoder, attr, None))
                            except Exception:
                                pass
                    self.forward = mf_encoder.forward
                    if hasattr(mf_encoder, 'backbone'):
                        self.backbone = mf_encoder.backbone
                    if hasattr(mf_encoder, 'fc_layers'):
                        self.fc_layers = mf_encoder.fc_layers
                    if hasattr(mf_encoder, 'custom_model'):
                        self.custom_model = mf_encoder.custom_model
                    # explicit confirmation logs
                    try:
                        url_info = getattr(mf_encoder, '_loaded_weights_url', None)
                        loaded_flag = getattr(mf_encoder, '_pretrained_loaded', False)
                        if loaded_flag and url_info:
                            print(f"CONFIRMED: MetaFormer '{custom_model}' using pretrained weights from: {url_info}")
                            logger.info(f"CONFIRMED: MetaFormer '{custom_model}' using pretrained weights from: {url_info}")
                        else:
                            print(f"CONFIRMED: MetaFormer '{custom_model}' using randomly initialized weights (no pretrained)")
                            logger.info(f"CONFIRMED: MetaFormer '{custom_model}' using randomly initialized weights")
                    except Exception:
                        pass
                else:
                    original_stacked_cnn_init(self, *args, **kwargs)
            finally:
                if hasattr(patch_ludwig_direct, '_metaformer_model'):
                    patch_ludwig_direct._metaformer_model = None

        Stacked2DCNN.__init__ = patched_stacked_cnn_init
        return True
    except Exception as e:
        logger.error(f"Failed to apply MetaFormer direct patch: {e}")
        return False


def set_current_metaformer_model(model_name: str):
    """Store the current MetaFormer model name for the patch to use."""
    setattr(patch_ludwig_direct, '_metaformer_model', model_name)


def clear_current_metaformer_model():
    """Remove any cached MetaFormer model hint."""
    if hasattr(patch_ludwig_direct, '_metaformer_model'):
        delattr(patch_ludwig_direct, '_metaformer_model')
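
A minimal usage sketch of the two entry points this module exposes: constructing `MetaFormerStackedCNN` directly, and routing Ludwig's `Stacked2DCNN` through the monkey patch. It assumes the file is importable as `metaformer_stacked_cnn` and that `metaformer_models` (plus, for pretrained weights, network access to the URLs in its `default_cfgs`) is available; the model names, input sizes, and printed shape are illustrative assumptions, not part of the changeset.

```python
import torch

from metaformer_stacked_cnn import (
    MetaFormerStackedCNN,
    clear_current_metaformer_model,
    patch_ludwig_stacked_cnn,
    set_current_metaformer_model,
)

# Direct use outside Ludwig: grayscale 128x128 inputs are adapted to the
# backbone's expected 3x224x224 via the 1x1 channel adapter and adaptive pooling.
encoder = MetaFormerStackedCNN(
    height=128,
    width=128,
    num_channels=1,
    output_size=64,
    custom_model="identityformer_s12",
    use_pretrained=False,  # skip the weight download for a quick smoke test
)
out = encoder(torch.randn(2, 1, 128, 128))
print(out['encoder_output'].shape)  # expected: torch.Size([2, 64])

# Ludwig route: register the model name, patch Stacked2DCNN, then build the
# Ludwig image model as usual; the patch resets the hint after the first __init__.
set_current_metaformer_model("caformer_s18")
if patch_ludwig_stacked_cnn():
    ...  # construct and train the Ludwig model here
clear_current_metaformer_model()
```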