Mercurial > repos > goeckslab > image_learner
annotate MetaFormer/metaformer_stacked_cnn.py @ 11:c5150cceab47 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
| author | goeckslab | 
|---|---|
| date | Sat, 18 Oct 2025 03:17:09 +0000 | 
| parents | |
| children | 
| rev | line source | 
|---|---|
| 11 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 1 import logging | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 2 import os | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 3 import sys | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 4 from typing import Dict, List, Optional | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 5 | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 6 import torch | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 7 import torch.nn as nn | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 8 | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
# Put this file's directory first on the import path so sibling modules
# (notably ``metaformer_models``) resolve regardless of the caller's CWD.
sys.path.insert(0, os.path.dirname(__file__))

# NOTE(review): basicConfig at import time configures the root logger for the
# whole process; kept as-is for behavioral compatibility.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Model-name prefixes of the MetaFormer families; presumably consulted
# elsewhere to recognise MetaFormer model names — confirm against callers.
SUPPORTED_PREFIXES = (
    'identityformer_', 'randformer_', 'poolformerv2_', 'convformer_', 'caformer_',
)

# ``metaformer_models`` is optional. Record whether it imported so that code
# needing it (e.g. backbone loading) can raise a clear error later instead of
# this module failing at import time.
try:
    from metaformer_models import default_cfgs as META_DEFAULT_CFGS
except Exception as import_error:
    META_MODELS_AVAILABLE = False
    logger.warning(f"MetaFormer models not available: {import_error}")
else:
    META_MODELS_AVAILABLE = True
    logger.info("MetaFormer models imported successfully")
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 32 | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 33 | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 34 def _resolve_metaformer_ctor(model_name: str): | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 35 # Prefer getattr to avoid importing every factory explicitly | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 36 try: | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 37 # Import the module itself for dynamic access | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 38 import metaformer_models | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 39 _factories = metaformer_models.__dict__ | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 40 if model_name in _factories and callable(_factories[model_name]): | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 41 return _factories[model_name] | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 42 except Exception: | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 43 pass | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 44 return None | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 45 | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 46 | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 47 class MetaFormerStackedCNN(nn.Module): | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
def __init__(
    self,
    height: int = 224,
    width: int = 224,
    num_channels: int = 3,
    output_size: int = 128,
    custom_model: str = "identityformer_s12",
    use_pretrained: bool = True,
    trainable: bool = True,
    conv_layers: Optional[List[Dict]] = None,
    num_conv_layers: Optional[int] = None,
    conv_activation: str = "relu",
    conv_dropout: float = 0.0,
    conv_norm: Optional[str] = None,
    conv_use_bias: bool = True,
    fc_layers: Optional[List[Dict]] = None,
    num_fc_layers: int = 1,
    fc_activation: str = "relu",
    fc_dropout: float = 0.0,
    fc_norm: Optional[str] = None,
    fc_use_bias: bool = True,
    **kwargs,
):
    """Build a MetaFormer backbone plus input adapters and a fully connected head.

    Args:
        height: Requested input height; coerced with ``int()``, must be > 0.
        width: Requested input width; coerced with ``int()``, must be > 0.
        num_channels: Requested input channels; coerced with ``int()``, must be > 0.
            When it differs from the backbone's expected channel count a 1x1
            ``nn.Conv2d`` channel adapter is created.
        output_size: Output width passed to ``_create_fc_layers``.
        custom_model: Name of a MetaFormer factory (e.g. ``"identityformer_s12"``);
            also used to look up the backbone's default config.
        use_pretrained: Stored as ``self.use_pretrained``; presumably consumed by
            ``_load_metaformer_backbone`` — confirm there.
        trainable: When False, every backbone parameter gets
            ``requires_grad = False`` (adapters and FC head stay trainable).
        conv_layers, num_conv_layers, conv_activation, conv_dropout, conv_norm,
        conv_use_bias: Accepted for interface compatibility; not referenced here.
        fc_layers, num_fc_layers, fc_activation, fc_dropout, fc_norm, fc_use_bias:
            Forwarded to ``_create_fc_layers`` to build the head.
        **kwargs: Accepted and not referenced here.

    Raises:
        ValueError: If height/width/num_channels are not integer-convertible or
            are not positive, or (from ``_load_metaformer_backbone``) if
            ``custom_model`` cannot be resolved.
        ImportError: From ``_load_metaformer_backbone`` when the optional
            ``metaformer_models`` module is unavailable.
    """
    super().__init__()
    logger.info("MetaFormerStackedCNN encoder instantiated")
    logger.info(f"Using MetaFormer model: {custom_model}")

    try:
        height = int(height)
        width = int(width)
        num_channels = int(num_channels)
    except (TypeError, ValueError) as exc:
        raise ValueError("MetaFormerStackedCNN requires integer height, width, and num_channels.") from exc

    if height <= 0 or width <= 0:
        raise ValueError(f"MetaFormerStackedCNN received non-positive dimensions: {height}x{width}.")
    if num_channels <= 0:
        raise ValueError(f"MetaFormerStackedCNN requires num_channels > 0, received {num_channels}.")

    self.height = height
    self.width = width
    self.num_channels = num_channels
    self.output_size = output_size
    self.custom_model = custom_model
    self.use_pretrained = use_pretrained
    self.trainable = trainable

    # META_DEFAULT_CFGS is only bound when the optional module-level import
    # succeeded. Without this guard an unavailable install fails here with a
    # NameError instead of the explicit ImportError raised by
    # _load_metaformer_backbone(). `or {}` also tolerates a None config entry.
    cfg = (META_DEFAULT_CFGS.get(custom_model, {}) if META_MODELS_AVAILABLE else {}) or {}
    input_size = cfg.get('input_size', (3, 224, 224))
    if isinstance(input_size, (list, tuple)) and len(input_size) == 3:
        expected_channels, expected_height, expected_width = input_size
    else:
        expected_channels, expected_height, expected_width = 3, 224, 224

    self.expected_channels = expected_channels
    self.expected_height = expected_height
    self.expected_width = expected_width

    logger.info(f"Initializing MetaFormerStackedCNN with model: {custom_model}")
    logger.info(
        "Input: %sx%sx%s -> Output: %s (expected backbone size: %sx%s)",
        num_channels,
        height,
        width,
        output_size,
        self.expected_height,
        self.expected_width,
    )

    # 1x1 convolution that remaps the requested channel count to the count the
    # backbone config expects.
    self.channel_adapter: Optional[nn.Conv2d] = None
    if num_channels != self.expected_channels:
        self.channel_adapter = nn.Conv2d(
            num_channels, self.expected_channels, kernel_size=1, stride=1, padding=0
        )
        logger.info(
            "Added channel adapter: %s -> %s channels",
            num_channels,
            self.expected_channels,
        )

    # NOTE(review): the pool targets the *requested* (height, width), as the log
    # message states, not the backbone's expected size; how forward() uses it
    # is not visible here — confirm this is intended.
    self.size_adapter: Optional[nn.Module] = None
    if height != self.expected_height or width != self.expected_width:
        self.size_adapter = nn.AdaptiveAvgPool2d((height, width))
        logger.info(
            "Configured size adapter to requested input: %sx%s",
            height,
            width,
        )
    self.backbone_adapter: Optional[nn.Module] = None

    self.backbone = self._load_metaformer_backbone()
    self.feature_dim = self._get_feature_dim()

    self.fc_layers = self._create_fc_layers(
        input_dim=self.feature_dim,
        output_dim=output_size,
        num_layers=num_fc_layers,
        activation=fc_activation,
        dropout=fc_dropout,
        norm=fc_norm,
        use_bias=fc_use_bias,
        fc_layers_config=fc_layers,
    )

    # Freeze only the backbone; adapters and the FC head remain trainable.
    if not trainable:
        for param in self.backbone.parameters():
            param.requires_grad = False
        logger.info("MetaFormer backbone frozen (trainable=False)")

    logger.info("MetaFormerStackedCNN initialized successfully")
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 158 | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
def _load_metaformer_backbone(self):
    """Instantiate the MetaFormer backbone named by ``self.custom_model``.

    The constructor is resolved with ``_resolve_metaformer_ctor`` and called
    with ``pretrained=self.use_pretrained`` and ``num_classes=1000``.  While it
    runs, ``torch.hub.load_state_dict_from_url`` is temporarily wrapped (only
    when pretrained weights are requested) so that a weight download can be
    logged and recorded; the original function is always restored.

    If building the model with pretrained weights raises, the model is rebuilt
    with ``pretrained=False`` and ``self.use_pretrained`` is set to ``False`` so
    later code reflects what was actually loaded.

    Attributes written on ``self``:
        _pretrained_loaded: True only if a wrapped download returned
            successfully and the pretrained build was kept.
        _loaded_weights_url: URL of that download, otherwise None.
        _metaformer_weights_url: the ``'url'`` entry of the model's default
            cfg from ``META_DEFAULT_CFGS`` (may be None).
        use_pretrained: reset to False when the fallback path is taken.

    Returns:
        The backbone module returned by the constructor.

    Raises:
        ImportError: if MetaFormer models are not available.
        ValueError: if ``self.custom_model`` does not resolve to a constructor.
        Exception: whatever the constructor raises when pretrained weights were
            not requested, or when the non-pretrained fallback also fails.
    """
    if not META_MODELS_AVAILABLE:
        raise ImportError("MetaFormer models are not available")

    ctor = _resolve_metaformer_ctor(self.custom_model)
    if ctor is None:
        raise ValueError(f"Unknown MetaFormer model: {self.custom_model}")

    cfg = META_DEFAULT_CFGS.get(self.custom_model, {})
    weights_url = cfg.get('url')
    # Track whether pretrained weights were actually fetched during construction.
    self._pretrained_loaded = False
    self._loaded_weights_url: Optional[str] = None
    if self.use_pretrained and weights_url:
        print(f"LOADING MetaFormer pretrained weights from: {weights_url}")
        logger.info("Loading pretrained weights from: %s", weights_url)

    # Ensure we log whenever the factories call torch.hub.load_state_dict_from_url.
    # NOTE(review): this patches a module-global attribute, so it is not safe if
    # another thread builds a model at the same time. It also only intercepts
    # callers that look the function up on ``torch.hub`` at call time; a factory
    # that imported it by name would bypass the wrapper (which would surface as
    # the "not confirmed as loaded" warning below).
    orig_loader = getattr(torch.hub, 'load_state_dict_from_url', None)

    def _wrapped_loader(url, *args, **kwargs):
        print(f"DOWNLOADING weights from: {url}")
        logger.info("DOWNLOADING weights from: %s", url)
        result = orig_loader(url, *args, **kwargs)
        # Record success only after the real loader returned, so a failed
        # download never leaves the "loaded" flag set.
        self._pretrained_loaded = True
        self._loaded_weights_url = url
        print(f"WEIGHTS DOWNLOADED successfully from: {url}")
        return result

    try:
        if self.use_pretrained and orig_loader is not None:
            torch.hub.load_state_dict_from_url = _wrapped_loader  # type: ignore[attr-defined]
        print(f"CREATING MetaFormer model: {self.custom_model} (pretrained={self.use_pretrained})")
        try:
            model = ctor(pretrained=self.use_pretrained, num_classes=1000)
            print(f"MetaFormer model CREATED: {self.custom_model}")
        except Exception as model_error:
            if not self.use_pretrained:
                # No fallback available without pretrained weights: propagate as-is.
                raise
            print(f"⚠ Warning: Failed to load {self.custom_model} with pretrained weights: {model_error}")
            print("Attempting to load without pretrained weights as fallback...")
            logger.warning("Failed to load %s with pretrained weights: %s", self.custom_model, model_error)
            model = ctor(pretrained=False, num_classes=1000)
            print(f"✓ Successfully loaded {self.custom_model} without pretrained weights")
            # Update state to reflect actual loading: the kept model is randomly
            # initialized, even if a download had succeeded before the failure.
            self.use_pretrained = False
            self._pretrained_loaded = False
            self._loaded_weights_url = None
    finally:
        if orig_loader is not None:
            torch.hub.load_state_dict_from_url = orig_loader  # type: ignore[attr-defined]

    self._metaformer_weights_url = weights_url
    if self.use_pretrained:
        if self._pretrained_loaded:
            print(f"MetaFormer: pretrained weights loaded from {self._loaded_weights_url}")
            logger.info("MetaFormer: pretrained weights loaded from %s", self._loaded_weights_url)
        else:
            # Warn but don't fail - model creation succeeded, but no download was observed.
            print("⚠ Warning: MetaFormer pretrained weights were requested but not confirmed as loaded")
            logger.warning("MetaFormer: pretrained weights were requested but not confirmed as loaded")
    else:
        print(f"MetaFormer: using randomly initialized weights for {self.custom_model}")
        logger.info("MetaFormer: using randomly initialized weights for %s", self.custom_model)
    logger.info("Loaded MetaFormer backbone: %s (pretrained=%s)", self.custom_model, self.use_pretrained)
    return model
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 220 | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
def _get_feature_dim(self):
    """Return the size of the last axis of ``self.backbone.forward_features``.

    A single random probe of shape ``(1, 3, 224, 224)`` is passed through the
    backbone under ``torch.no_grad()``.  The probe is created on the device and
    with the floating dtype of the backbone's first parameter (CPU and the
    default dtype if it has none).  The probe runs in eval mode, so dropout is
    off and normalization running statistics are not updated by the random
    input.  Each submodule's original train/eval flag is restored afterwards.

    NOTE(review): using ``shape[-1]`` as the feature size assumes
    ``forward_features`` returns pooled ``(B, C)`` or channels-last features;
    confirm against the MetaFormer implementation in use.

    Returns:
        int-like size of the last dimension of the backbone features.
    """
    first_param = next(self.backbone.parameters(), None)
    if first_param is not None:
        device = first_param.device
        dtype = first_param.dtype if first_param.is_floating_point() else torch.get_default_dtype()
    else:
        device = torch.device("cpu")
        dtype = torch.get_default_dtype()

    # Remember every submodule's mode so a mixed train/eval setup is restored exactly.
    saved_modes = [(module, module.training) for module in self.backbone.modules()]
    self.backbone.eval()
    try:
        with torch.no_grad():
            dummy_input = torch.randn(1, 3, 224, 224, device=device, dtype=dtype)
            features = self.backbone.forward_features(dummy_input)
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    feature_dim = features.shape[-1]
    logger.info("MetaFormer feature dimension: %s", feature_dim)
    return feature_dim
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 228 | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
def _create_fc_layers(self, input_dim, output_dim, num_layers, activation, dropout, norm, use_bias, fc_layers_config):
    """Build the fully connected head mapping ``input_dim`` features to ``output_dim``.

    Two ways of choosing layer widths:

    * ``fc_layers_config`` non-empty: one ``nn.Linear`` per entry.  An entry's
      ``'output_size'`` sets that layer's width.  When the key is missing or
      None, the last layer uses ``output_dim`` and earlier layers keep the
      current width.
    * otherwise: ``num_layers`` Linear layers.  Hidden widths are
      ``int(input_dim * 0.5 ** k)`` for k = 1, 2, ..., never below
      ``output_dim``, and the last layer outputs ``output_dim``.  A
      ``num_layers`` below 1 is treated as 1, so the head always projects to
      ``output_dim`` instead of returning an empty (identity) Sequential.

    Every Linear except the last is followed, in this order, by:
    * the activation ("relu", "tanh", "sigmoid", "leaky_relu"; other values add nothing);
    * ``nn.Dropout(dropout)`` when dropout is truthy and > 0;
    * normalization ("batch" -> BatchNorm1d, "layer" -> LayerNorm; other values add nothing).

    Args:
        input_dim: number of input features.
        output_dim: number of output features of the final layer.
        num_layers: layer count used when ``fc_layers_config`` is empty.
        activation: activation name for hidden layers.
        dropout: dropout probability for hidden layers (None/0 disables it).
        norm: normalization name for hidden layers.
        use_bias: passed as ``bias`` to every ``nn.Linear``.
        fc_layers_config: optional sequence of per-layer dicts.

    Returns:
        ``nn.Sequential`` containing the layers.
    """
    activation_classes = {
        "relu": nn.ReLU,
        "tanh": nn.Tanh,
        "sigmoid": nn.Sigmoid,
        "leaky_relu": nn.LeakyReLU,
    }
    norm_classes = {"batch": nn.BatchNorm1d, "layer": nn.LayerNorm}

    def _hidden_extras(width):
        """Return the activation/dropout/norm modules that follow a hidden Linear of size ``width``."""
        extras = []
        act_cls = activation_classes.get(activation)
        if act_cls is not None:
            extras.append(act_cls())
        if dropout and dropout > 0:
            extras.append(nn.Dropout(dropout))
        norm_cls = norm_classes.get(norm)
        if norm_cls is not None:
            extras.append(norm_cls(width))
        return extras

    # widths[i] -> widths[i + 1] is the shape of the i-th Linear layer.
    widths = [input_dim]
    if fc_layers_config:
        last_index = len(fc_layers_config) - 1
        for i, layer_config in enumerate(fc_layers_config):
            size = layer_config.get('output_size')
            if size is None:
                size = output_dim if i == last_index else widths[-1]
            widths.append(size)
    else:
        effective_layers = max(1, num_layers)
        for i in range(effective_layers - 1):
            widths.append(max(int(input_dim * (0.5 ** (i + 1))), output_dim))
        widths.append(output_dim)

    layers = []
    num_linear = len(widths) - 1
    for i in range(num_linear):
        layers.append(nn.Linear(widths[i], widths[i + 1], bias=use_bias))
        if i < num_linear - 1:
            layers.extend(_hidden_extras(widths[i + 1]))
    return nn.Sequential(*layers)
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 280 | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 281 def forward(self, x): | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 282 if x.shape[1] != self.expected_channels: | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 283 if ( | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 284 self.channel_adapter is None | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 285 or self.channel_adapter.in_channels != x.shape[1] | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 286 or self.channel_adapter.out_channels != self.expected_channels | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 287 ): | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 288 self.channel_adapter = nn.Conv2d( | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 289 x.shape[1], | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 290 self.expected_channels, | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 291 kernel_size=1, | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 292 stride=1, | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 293 padding=0, | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 294 ).to(x.device) | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 295 logger.info( | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 296 "Created dynamic channel adapter: %s -> %s channels", | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 297 x.shape[1], | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 298 self.expected_channels, | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 299 ) | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 300 x = self.channel_adapter(x) | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 301 | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 302 target_height, target_width = self.height, self.width | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 303 if x.shape[2] != target_height or x.shape[3] != target_width: | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 304 if ( | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 305 self.size_adapter is None | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 306 or getattr(self.size_adapter, "output_size", None) | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 307 != (target_height, target_width) | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 308 ): | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 309 self.size_adapter = nn.AdaptiveAvgPool2d( | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 310 (target_height, target_width) | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 311 ).to(x.device) | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 312 logger.info( | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 313 "Created size adapter: %sx%s -> %sx%s", | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 314 x.shape[2], | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 315 x.shape[3], | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 316 target_height, | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 317 target_width, | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 318 ) | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 319 x = self.size_adapter(x) | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 320 | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 321 if target_height != self.expected_height or target_width != self.expected_width: | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 322 if ( | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 323 self.backbone_adapter is None | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 324 or getattr(self.backbone_adapter, "output_size", None) | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 325 != (self.expected_height, self.expected_width) | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 326 ): | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 327 self.backbone_adapter = nn.AdaptiveAvgPool2d( | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 328 (self.expected_height, self.expected_width) | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 329 ).to(x.device) | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 330 logger.info( | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 331 "Aligning to MetaFormer backbone size: %sx%s", | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 332 self.expected_height, | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 333 self.expected_width, | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 334 ) | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 335 x = self.backbone_adapter(x) | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 336 | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 337 features = self.backbone.forward_features(x) | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 338 output = self.fc_layers(features) | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 339 return {'encoder_output': output} | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 340 | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
@property
def output_shape(self):
    """Report the encoder's output shape.

    Returns:
        A one-element list holding ``self.output_size``.
    """
    size = self.output_size
    return [size]
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 344 | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 345 | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
def create_metaformer_stacked_cnn(model_name: str, **kwargs) -> MetaFormerStackedCNN:
    """Build a ``MetaFormerStackedCNN`` encoder for ``model_name``.

    ``model_name`` is passed as the encoder's ``custom_model``; every other
    keyword argument is forwarded to the constructor unchanged.
    """
    return MetaFormerStackedCNN(custom_model=model_name, **kwargs)
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 349 | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 350 | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
def patch_ludwig_stacked_cnn():
    """Patch Ludwig's stacked_cnn encoder so it can use MetaFormer backbones.

    Returns:
        False (after logging a warning) when ``META_MODELS_AVAILABLE`` is falsy;
        otherwise whatever ``patch_ludwig_direct()`` returns.
    """
    if META_MODELS_AVAILABLE:
        return patch_ludwig_direct()
    # Nothing to graft in this runtime, so leave Ludwig untouched.
    logger.warning("MetaFormer models unavailable; skipping Ludwig patch for stacked_cnn.")
    return False
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 357 | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 358 | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 359 def _is_supported_metaformer(custom_model: Optional[str]) -> bool: | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 360 return bool(custom_model) and custom_model.startswith(SUPPORTED_PREFIXES) | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 361 | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 362 | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
def _graft_metaformer_encoder(encoder, mf_encoder) -> None:
    """Make a Ludwig ``Stacked2DCNN`` instance delegate to ``mf_encoder``.

    Overwrites selected base attributes on ``encoder``, routes
    ``encoder.forward`` to ``mf_encoder.forward`` and copies the MetaFormer
    ``backbone``, ``fc_layers`` and ``custom_model`` attributes when present.
    """
    # Ensure base attributes won't be used accidentally; attributes the
    # MetaFormer encoder lacks are replaced with None.
    for attr in ("conv_layers", "fc_layers", "combiner", "output_shape", "reduce_output"):
        if not hasattr(encoder, attr):
            continue
        try:
            setattr(encoder, attr, getattr(mf_encoder, attr, None))
        except Exception as err:
            # Best effort: e.g. a read-only property cannot be assigned on an
            # instance and keeps its base-class value.
            logger.debug("Could not override %r on Stacked2DCNN: %s", attr, err)
    # The instance attribute shadows the class-level forward.
    encoder.forward = mf_encoder.forward
    # NOTE(review): mf_encoder itself is not registered as a submodule of
    # encoder; only the attributes copied here are. Modules mf_encoder creates
    # later (its forward builds adapters lazily) will not appear in
    # encoder.parameters()/state_dict() -- confirm this is intended.
    for attr in ("backbone", "fc_layers", "custom_model"):
        if hasattr(mf_encoder, attr):
            setattr(encoder, attr, getattr(mf_encoder, attr))


def _log_metaformer_weights(mf_encoder, custom_model) -> None:
    """Report (print + log) whether ``mf_encoder`` loaded pretrained weights.

    Reporting is best effort and never raises.
    """
    try:
        url_info = getattr(mf_encoder, '_loaded_weights_url', None)
        loaded_flag = getattr(mf_encoder, '_pretrained_loaded', False)
        if loaded_flag and url_info:
            print(f"CONFIRMED: MetaFormer '{custom_model}' using pretrained weights from: {url_info}")
            logger.info(
                "CONFIRMED: MetaFormer '%s' using pretrained weights from: %s",
                custom_model,
                url_info,
            )
        else:
            print(f"CONFIRMED: MetaFormer '{custom_model}' using randomly initialized weights (no pretrained)")
            logger.info("CONFIRMED: MetaFormer '%s' using randomly initialized weights", custom_model)
    except Exception:
        # Diagnostics must never break encoder construction.
        logger.debug("Could not report MetaFormer weight source", exc_info=True)


def patch_ludwig_direct():
    """Monkey-patch Ludwig's ``Stacked2DCNN.__init__`` to build MetaFormer encoders.

    The patched constructor takes the model name from a ``custom_model``
    keyword argument (popped so the stock constructor never sees it) or,
    failing that, from the hint stored by ``set_current_metaformer_model``.

    If MetaFormer models are available and the name is supported, it does two
    things. First, it runs the stock constructor to keep Ludwig internals
    intact. Second, it makes the instance delegate to a MetaFormer encoder.
    Otherwise the stock constructor runs unchanged.

    The stored hint is reset to None after every construction attempt,
    including failed ones.

    Calling this more than once is safe: an already-patched constructor is left
    in place rather than wrapped again. Re-wrapping would build (and possibly
    download) the MetaFormer backbone twice per encoder.

    Returns:
        True if the patch is in place, False if it could not be applied.
    """
    try:
        from ludwig.encoders.image.base import Stacked2DCNN

        original_stacked_cnn_init = Stacked2DCNN.__init__
        if getattr(original_stacked_cnn_init, "_metaformer_patched", False):
            logger.info("Ludwig Stacked2DCNN already patched for MetaFormer; skipping re-patch.")
            return True

        def patched_stacked_cnn_init(self, *args, **kwargs):
            # An explicit kwarg wins over the stored hint.
            custom_model = kwargs.pop("custom_model", None)
            if custom_model is None:
                custom_model = getattr(patch_ludwig_direct, '_metaformer_model', None)

            try:
                if META_MODELS_AVAILABLE and _is_supported_metaformer(custom_model):
                    print(f"DETECTED MetaFormer model: {custom_model}")
                    print("MetaFormer encoder is being loaded and used.")
                    # Initialize base class to keep Ludwig internals intact
                    original_stacked_cnn_init(self, *args, **kwargs)
                    # Create our MetaFormer encoder and graft its behavior
                    mf_encoder = create_metaformer_stacked_cnn(custom_model, **kwargs)
                    _graft_metaformer_encoder(self, mf_encoder)
                    _log_metaformer_weights(mf_encoder, custom_model)
                else:
                    original_stacked_cnn_init(self, *args, **kwargs)
            finally:
                # The hint applies to a single construction only.
                if hasattr(patch_ludwig_direct, '_metaformer_model'):
                    patch_ludwig_direct._metaformer_model = None

        # Marker used by the idempotency check above.
        patched_stacked_cnn_init._metaformer_patched = True
        Stacked2DCNN.__init__ = patched_stacked_cnn_init
        return True
    except Exception as e:
        logger.error("Failed to apply MetaFormer direct patch: %s", e)
        return False
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 418 | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 419 | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
def set_current_metaformer_model(model_name: str):
    """Remember ``model_name`` for the patched Stacked2DCNN constructor.

    The value is stored on ``patch_ludwig_direct`` as ``_metaformer_model``.
    The patched constructor reads it when no explicit ``custom_model`` is given
    and resets it after use.
    """
    patch_ludwig_direct._metaformer_model = model_name
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 423 | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
changeset | 424 | 
| 
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
 goeckslab parents: diff
def clear_current_metaformer_model():
    """Forget any MetaFormer model hint stored on ``patch_ludwig_direct``."""
    try:
        del patch_ludwig_direct._metaformer_model
    except AttributeError:
        # No hint was stored; clearing is a no-op.
        pass
