Mercurial > repos > goeckslab > extract_embeddings
comparison pytorch_embedding.py @ 0:38333676a029 draft
planemo upload for repository https://github.com/goeckslab/gleam.git commit f57ec1ad637e8299db265ee08be0fa9d4d829b93
| author | goeckslab |
|---|---|
| date | Thu, 19 Jun 2025 23:33:23 +0000 |
| parents | |
| children | 84f96c952c2c |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:38333676a029 |
|---|---|
| 1 """ | |
| 2 This module provides functionality to extract image embeddings | |
| 3 using a specified | |
| 4 pretrained model from the torchvision library. It includes functions to: | |
| 5 - List image files directly from a ZIP file without extraction. | |
| 6 - Apply model-specific preprocessing and transformations. | |
| 7 - Extract embeddings using various models. | |
| 8 - Save the resulting embeddings into a CSV file. | |
| 9 Modules required: | |
| 10 - argparse: For command-line argument parsing. | |
| 11 - os, csv, zipfile: For file handling (ZIP file reading, CSV writing). | |
| 12 - inspect: For inspecting function signatures and models. | |
| 13 - torch, torchvision: For loading and using pretrained models | |
| 14 to extract embeddings. | |
| 15 - PIL, cv2: For image processing tasks such as resizing, normalization, | |
| 16 and conversion. | |
| 17 """ | |
| 18 | |
| 19 import argparse | |
| 20 import csv | |
| 21 import inspect | |
| 22 import logging | |
| 23 import os | |
| 24 import zipfile | |
| 25 from inspect import signature | |
| 26 | |
| 27 import cv2 | |
| 28 import numpy as np | |
| 29 import torch | |
| 30 import torchvision.models as models | |
| 31 from PIL import Image | |
| 32 from torch.utils.data import DataLoader, Dataset | |
| 33 from torchvision import transforms | |
| 34 | |
# Route this tool's log records (DEBUG and above) to a fixed log file,
# appending to whatever earlier runs already wrote there.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(levelname)s - %(message)s",
    filemode="a",
    filename="/tmp/ludwig_embeddings.log",
)
| 42 | |
# Make sure a cache directory exists under the current working directory.
# Failure to create it is logged and treated as fatal.
# NOTE(review): nothing in this file reads `cache_dir` afterwards; if it is
# meant to hold downloaded weights, torch would need TORCH_HOME pointed at
# it — confirm the intent.
cache_dir = os.path.join(os.getcwd(), "hf_cache")
try:
    os.makedirs(cache_dir, exist_ok=True)
except OSError as err:
    logging.error(f"Failed to create cache directory {cache_dir}: {err}")
    raise
else:
    logging.info(
        f"Cache directory created: {cache_dir}, "
        f"writable: {os.access(cache_dir, os.W_OK)}"
    )
| 51 | |
def _builders_accepting_weights(module):
    """Return {attribute name: callable} for every callable attribute of
    *module* whose signature has a ``weights`` parameter, in ``dir()`` order.
    """
    builders = {}
    for attr in dir(module):
        candidate = getattr(module, attr)
        if not callable(candidate):
            continue
        if "weights" in signature(candidate).parameters:
            builders[attr] = candidate
    return builders


# torchvision.models builders that accept a `weights` argument; these names
# are the valid values for --model_name.
AVAILABLE_MODELS = _builders_accepting_weights(models)
| 60 | |
# ImageNet-1k channel statistics. torchvision's pretrained classification
# weights (including the Swin and ViT builders listed below) document these
# values in their `weights.transforms()` presets.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Per-model preprocessing overrides. "resize" is the (height, width) passed to
# transforms.Resize; an entry may also provide "normalize" as a (mean, std)
# pair. Entries without one fall back to ImageNet statistics below.
# (The previous swin_* mean of [0.5, 0.0, 0.5] was a typo, and the 0.5/0.5
# ViT stats did not match torchvision's ViT weights; both now use ImageNet.)
_MODEL_OVERRIDES = {
    "default": {"resize": (224, 224)},
    "efficientnet_b1": {"resize": (240, 240)},
    "efficientnet_b2": {"resize": (260, 260)},
    "efficientnet_b3": {"resize": (300, 300)},
    "efficientnet_b4": {"resize": (380, 380)},
    "efficientnet_b5": {"resize": (456, 456)},
    "efficientnet_b6": {"resize": (528, 528)},
    "efficientnet_b7": {"resize": (600, 600)},
    "inception_v3": {"resize": (299, 299)},
    "swin_b": {"resize": (224, 224)},
    "swin_s": {"resize": (224, 224)},
    "swin_t": {"resize": (224, 224)},
    "vit_b_16": {"resize": (224, 224)},
    "vit_b_32": {"resize": (224, 224)},
}

# Final per-model settings: every entry has both "resize" and "normalize".
# Each entry gets its own list copies, so mutating one cannot affect another,
# and no loop variables leak into the module namespace.
MODEL_DEFAULTS = {
    name: {"normalize": (list(IMAGENET_MEAN), list(IMAGENET_STD)), **overrides}
    for name, overrides in _MODEL_OVERRIDES.items()
}
| 94 | |
| 95 | |
| 96 # Custom transform classes | |
class CLAHETransform:
    """Contrast Limited Adaptive Histogram Equalization (CLAHE) transform.

    The PIL input is converted to 8-bit grayscale ("L") and equalized with an
    OpenCV CLAHE operator. The result is returned as a 3-channel RGB PIL
    image, so the downstream pipeline always sees RGB.

    Args:
        clip_limit: contrast clip limit forwarded to cv2.createCLAHE.
        tile_grid_size: (cols, rows) tile grid forwarded to cv2.createCLAHE.
    """

    def __init__(self, clip_limit=2.0, tile_grid_size=(8, 8)):
        # Build the OpenCV operator once; every __call__ reuses it.
        self.clahe = cv2.createCLAHE(
            clipLimit=clip_limit, tileGridSize=tile_grid_size
        )

    def __call__(self, img):
        gray = np.array(img.convert("L"))
        equalized = self.clahe.apply(gray)
        return Image.fromarray(equalized).convert("RGB")
| 108 | |
| 109 | |
class CannyTransform:
    """Replace an image with its Canny edge map, replicated to RGB.

    Args:
        threshold1: first hysteresis threshold passed to cv2.Canny.
        threshold2: second hysteresis threshold passed to cv2.Canny.
    """

    def __init__(self, threshold1=100, threshold2=200):
        self.threshold1, self.threshold2 = threshold1, threshold2

    def __call__(self, img):
        # cv2.Canny works on a single-channel 8-bit array.
        luminance = np.array(img.convert("L"))
        edge_map = cv2.Canny(luminance, self.threshold1, self.threshold2)
        return Image.fromarray(edge_map).convert("RGB")
| 119 | |
| 120 | |
class RGBAtoRGBTransform:
    """Flatten transparency onto an opaque white background and return RGB.

    The image is composited over white (255, 255, 255) before dropping alpha
    when it either:

    * has an explicit alpha channel (modes "RGBA", "LA", "PA"), or
    * is a "P", "L" or "RGB" image carrying a "transparency" entry in
      ``img.info`` (palette / colour-key transparency).

    Previously only "RGBA" was composited, so transparent regions of the
    other cases kept whatever colour sat under the alpha (often black).
    Every other image is simply converted to "RGB".
    """

    _ALPHA_MODES = ("RGBA", "LA", "PA")
    _COLOR_KEY_MODES = ("P", "L", "RGB")

    def __call__(self, img):
        has_transparency = img.mode in self._ALPHA_MODES or (
            img.mode in self._COLOR_KEY_MODES and "transparency" in img.info
        )
        if not has_transparency:
            return img.convert("RGB")
        # alpha_composite needs two RGBA images of the same size.
        rgba = img if img.mode == "RGBA" else img.convert("RGBA")
        background = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        return Image.alpha_composite(background, rgba).convert("RGB")
| 129 | |
| 130 | |
def get_image_files_from_zip(
        zip_file,
        extensions=(".png", ".jpg", ".jpeg", ".bmp", ".gif")):
    """Return the names of image members of a ZIP archive, in archive order.

    Matching on *extensions* is case-insensitive. Directory entries never
    match because their names end in "/". macOS metadata entries
    (anything under ``__MACOSX/`` and AppleDouble ``._*`` files) are skipped:
    they carry image extensions but are not images.

    Args:
        zip_file: path (or file-like object) of the ZIP archive.
        extensions: file suffix or tuple of suffixes to accept.

    Returns:
        List of member names (full in-archive paths).

    Raises:
        RuntimeError: if the archive is not a valid ZIP or cannot be read;
            the message names the archive and the original exception is
            chained as ``__cause__``.
    """
    if isinstance(extensions, str):
        # A bare ".png" would otherwise be iterated character by character.
        extensions = (extensions,)
    wanted = tuple(ext.lower() for ext in extensions)

    try:
        with zipfile.ZipFile(zip_file, "r") as zip_ref:
            members = zip_ref.namelist()
    except zipfile.BadZipFile as exc:
        raise RuntimeError(f"Invalid ZIP file: {zip_file}") from exc
    except Exception as exc:
        raise RuntimeError(
            f"Error reading ZIP file {zip_file}: {exc}"
        ) from exc

    def _is_macos_metadata(name):
        return (name.startswith("__MACOSX/")
                or name.rsplit("/", 1)[-1].startswith("._"))

    return [
        name for name in members
        if name.lower().endswith(wanted) and not _is_macos_metadata(name)
    ]
| 145 | |
| 146 | |
def load_model(model_name, device):
    """Load a pretrained torchvision model and turn it into a feature extractor.

    Args:
        model_name: key of AVAILABLE_MODELS.
        device: torch.device the model is moved to.

    Returns:
        The model in eval mode. Its classification head has been replaced by
        torch.nn.Identity: the first existing attribute among ``fc``,
        ``classifier``, ``head`` and ``heads``.

    Raises:
        ValueError: if *model_name* is not in AVAILABLE_MODELS.
        Exception: any error raised while building the model (e.g. a weight
            download failure) is logged and re-raised unchanged.
    """
    if model_name not in AVAILABLE_MODELS:
        raise ValueError(
            f"Unsupported model: {model_name}. "
            f"Available models: {list(AVAILABLE_MODELS.keys())}"
        )

    builder = AVAILABLE_MODELS[model_name]
    try:
        if "weights" in inspect.signature(builder).parameters:
            model = builder(weights="DEFAULT").to(device)
        else:
            model = builder().to(device)
        logging.info("Model loaded")
    except Exception as e:
        logging.error(f"Failed to load model {model_name}: {e}")
        raise

    # Strip the classification head so the forward pass returns features.
    # torchvision's VisionTransformer calls its head `heads`; without that
    # entry the vit_* models returned class logits instead of embeddings.
    for head_attr in ("fc", "classifier", "head", "heads"):
        if hasattr(model, head_attr):
            setattr(model, head_attr, torch.nn.Identity())
            break

    model.eval()
    return model
| 174 | |
| 175 | |
def write_csv(output_csv, list_embeddings, ludwig_format=False):
    """Write embedding rows to *output_csv* (UTF-8, csv module dialect).

    Each row of *list_embeddings* is ``[sample_name, v1, v2, ...]``.

    * Default layout: header ``sample_name, vector1..vectorN`` (N taken from
      the first row), then the rows unchanged.
    * Ludwig layout (*ludwig_format*): header ``sample_name, embedding``,
      with each vector joined into one space-separated string.
    * No rows: only the header for the selected layout is written.
    """
    with open(output_csv, mode="w", encoding="utf-8", newline="") as csv_file:
        writer = csv.writer(csv_file)

        if not list_embeddings:
            empty_header = (
                ["sample_name", "embedding"] if ludwig_format
                else ["sample_name"]
            )
            writer.writerow(empty_header)
            logging.info("No valid images found. Empty CSV created.")
            return

        if ludwig_format:
            # Build every row before writing anything, as the header must
            # only appear once all rows have been formatted.
            ludwig_rows = [
                [row[0], " ".join(str(value) for value in row[1:])]
                for row in list_embeddings
            ]
            writer.writerow(["sample_name", "embedding"])
            writer.writerows(ludwig_rows)
            logging.info("CSV created in Ludwig format")
            return

        vector_count = len(list_embeddings[0]) - 1
        column_names = [f"vector{n}" for n in range(1, vector_count + 1)]
        writer.writerow(["sample_name", *column_names])
        writer.writerows(list_embeddings)
        logging.info("CSV created")
| 203 | |
| 204 | |
def _to_rgb(img):
    """Convert a PIL image to RGB (module-level so DataLoader can pickle it)."""
    return img.convert("RGB")


def _build_transform(model_name, apply_normalization, transform_type):
    """Return the preprocessing pipeline for *model_name* / *transform_type*."""
    model_settings = MODEL_DEFAULTS.get(model_name, MODEL_DEFAULTS["default"])
    resize = model_settings["resize"]
    normalize = model_settings.get("normalize", (
        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    ))

    if transform_type == "grayscale":
        initial_transform = transforms.Grayscale(num_output_channels=3)
    elif transform_type == "clahe":
        initial_transform = CLAHETransform()
    elif transform_type == "edges":
        initial_transform = CannyTransform()
    elif transform_type == "rgba_to_rgb":
        initial_transform = RGBAtoRGBTransform()
    else:
        if transform_type != "rgb":
            # Unknown values used to fall back to RGB silently.
            logging.warning(
                "Unknown transform_type %r; using 'rgb'", transform_type
            )
        initial_transform = transforms.Lambda(_to_rgb)

    transform_list = [initial_transform,
                      transforms.Resize(resize),
                      transforms.ToTensor()]
    if apply_normalization:
        transform_list.append(
            transforms.Normalize(mean=normalize[0], std=normalize[1])
        )
    return transforms.Compose(transform_list)


class _ZipImageDataset(Dataset):
    """Map-style dataset yielding (transformed image, basename) from a ZIP.

    Unreadable entries yield ``(None, basename)`` and are dropped by
    _collate_skip_missing. Defined at module level so worker processes can
    pickle it.
    """

    def __init__(self, zip_file, file_list, transform=None):
        self.zip_file = zip_file
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        member = self.file_list[idx]
        sample_name = os.path.basename(member)
        try:
            with zipfile.ZipFile(self.zip_file, "r") as zip_ref:
                with zip_ref.open(member) as file:
                    image = Image.open(file)
                    # Transform while the member is open: PIL loads lazily.
                    if self.transform:
                        image = self.transform(image)
        except Exception as e:
            logging.warning("Skipping %s: %s", member, e)
            return None, sample_name
        return image, sample_name


def _collate_skip_missing(batch):
    """Stack the readable images of *batch*; return (None, None) if none."""
    batch = [item for item in batch if item[0] is not None]
    if not batch:
        return None, None
    images, names = zip(*batch)
    return torch.stack(images), names


def _embed_with_dataloader(model, dataset, device):
    """Embed *dataset* in batches; return [[name, *vector], ...]."""
    rows = []
    loader = DataLoader(
        dataset,
        batch_size=16,  # Reduced for lower memory usage
        num_workers=1,  # Reduced to minimize shared memory
        shuffle=False,
        # `device` is a torch.device; comparing it to the str "cuda" is
        # always False, so check its type instead.
        pin_memory=device.type == "cuda",
        collate_fn=_collate_skip_missing,
    )
    for images, names in loader:
        if images is None:
            continue
        # flatten(1): some backbones return (N, C, 1, 1) once their head is
        # an Identity (e.g. torchvision ConvNeXt keeps its Flatten in
        # `classifier`); each row must be a flat vector.
        embeddings = model(images.to(device)).flatten(start_dim=1)
        for name, vector in zip(names, embeddings.cpu().numpy()):
            rows.append([name] + vector.tolist())
    return rows


def _embed_sequentially(model, zip_file, file_list, transform, device):
    """Embed images one at a time; return [[name, *vector], ...]."""
    rows = []
    with zipfile.ZipFile(zip_file, "r") as zip_ref:
        for member in file_list:
            try:
                with zip_ref.open(member) as img_file:
                    tensor = transform(Image.open(img_file))
                # flatten() (not squeeze()) keeps a 1-element output a list.
                vector = model(tensor.unsqueeze(0).to(device)).flatten()
            except Exception as e:
                logging.warning("Skipping %s: %s", member, e)
                continue
            rows.append([os.path.basename(member)] + vector.cpu().numpy().tolist())
    return rows


def extract_embeddings(
        model_name,
        apply_normalization,
        zip_file,
        file_list,
        transform_type="rgb"):
    """Extract one embedding per image listed in *file_list*.

    A batched DataLoader is tried first. If it raises RuntimeError (e.g.
    worker or shared-memory failures, CUDA OOM), every image is embedded
    again one by one.

    Args:
        model_name: key of AVAILABLE_MODELS / MODEL_DEFAULTS.
        apply_normalization: append transforms.Normalize with the model's
            (mean, std) to the pipeline.
        zip_file: path of the ZIP archive holding the images.
        file_list: member names inside *zip_file* to embed.
        transform_type: "rgb" (default), "grayscale", "clahe", "edges" or
            "rgba_to_rgb"; unknown values are logged and treated as "rgb".

    Returns:
        List of rows ``[basename, v1, v2, ...]``. Unreadable images are
        skipped with a warning.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(model_name, device)
    transform = _build_transform(model_name, apply_normalization,
                                 transform_type)

    with torch.inference_mode():
        try:
            dataset = _ZipImageDataset(zip_file, file_list, transform=transform)
            return _embed_with_dataloader(model, dataset, device)
        except RuntimeError as e:
            logging.warning(
                "DataLoader failed: %s. Falling back to sequential processing.",
                e,
            )
            # Start from an empty result: any rows produced before the
            # failure would otherwise be duplicated by the sequential pass.
            return _embed_sequentially(model, zip_file, file_list,
                                       transform, device)
| 316 | |
| 317 | |
def main(zip_file, output_csv, model_name, apply_normalization=False,
         transform_type="rgb", ludwig_format=False):
    """Run the pipeline: list the images in *zip_file*, embed them with
    *model_name*, and write the rows to *output_csv*.

    *apply_normalization*, *transform_type* and *ludwig_format* are forwarded
    unchanged to extract_embeddings / write_csv.
    """
    images_in_zip = get_image_files_from_zip(zip_file)
    logging.info("Image files listed from ZIP")

    embeddings = extract_embeddings(
        model_name=model_name,
        apply_normalization=apply_normalization,
        zip_file=zip_file,
        file_list=images_in_zip,
        transform_type=transform_type,
    )
    logging.info("Embeddings extracted")

    write_csv(output_csv, embeddings, ludwig_format=ludwig_format)
| 334 | |
| 335 | |
if __name__ == "__main__":
    # Command-line entry point; arguments are declared in --help order.
    cli = argparse.ArgumentParser(description="Extract image embeddings.")
    cli.add_argument("--zip_file", required=True,
                     help="Path to the ZIP file containing images.")
    cli.add_argument("--model_name", required=True,
                     choices=AVAILABLE_MODELS.keys(),
                     help="Model for embedding extraction.")
    cli.add_argument("--normalize", action="store_true",
                     help="Whether to apply normalization.")
    cli.add_argument("--transform_type", required=True,
                     help="Image transformation type.")
    cli.add_argument("--output_csv", required=True,
                     help="Path to the output CSV file")
    cli.add_argument("--ludwig_format", action="store_true",
                     help="Prepare CSV file in Ludwig input format")

    opts = cli.parse_args()
    main(
        zip_file=opts.zip_file,
        output_csv=opts.output_csv,
        model_name=opts.model_name,
        apply_normalization=opts.normalize,
        transform_type=opts.transform_type,
        ludwig_format=opts.ludwig_format,
    )
