changeset 19:c460abae83eb draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
| field | value |
|---|---|
| author | goeckslab |
| date | Thu, 18 Dec 2025 16:59:58 +0000 |
| parents | bbf30253c99f |
| children | |
| files | html_structure.py image_learner.xml ludwig_backend.py |
| diffstat | 3 files changed, 402 insertions(+), 2 deletions(-) |
```diff
--- a/html_structure.py	Sun Dec 14 03:27:12 2025 +0000
+++ b/html_structure.py	Thu Dec 18 16:59:58 2025 +0000
@@ -231,7 +231,7 @@
 <html>
 <head>
 <meta charset="UTF-8">
-<title>Galaxy-Ludwig Report</title>
+<title>Image Learner Report</title>
 <style>
     body {
         font-family: Arial, sans-serif;
@@ -719,6 +719,16 @@
     ' <li><strong>Detailed Analysis:</strong> Use <strong>Confusion Matrix stats</strong> '
     'for class-wise performance in classification.</li>'
     ' </ul>'
+    ' <h3>11) Grad-CAM Heatmaps (When Available)</h3>'
+    ' <p><strong>Paper:</strong> <em>Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization</em> '
+    '(Selvaraju, Cogswell, Das, Vedantam, Parikh, Batra; ICCV 2017).</p>'
+    ' <p><strong>What it shows:</strong> A heatmap highlighting image regions that most influenced the model’s prediction '
+    'for a small subset of evaluation samples (we prefer the test split when available).</p>'
+    ' <p><strong>How it is computed (high level):</strong> We use the encoder’s preprocessing (resize + normalization), '
+    'take activations from the last convolution layer, weight them by globally-averaged gradients of the target logits, apply ReLU, '
+    'upsample to input resolution, and overlay on the input image.</p>'
+    ' <p><strong>Availability:</strong> Only supported for convolutional encoders. Models without convolution layers may not '
+    'produce Grad-CAM outputs.</p>'
     ' </div>'
     ' </div>'
     '</div>'
```
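Note: the Grad-CAM recipe this help text describes reduces to a few tensor operations. A minimal standalone sketch, assuming a plain torch classifier `model`, a handle on its last `Conv2d` module `target_layer`, and an already resized/normalized input `x` of shape (1, 3, H, W); all names here are hypothetical, not this tool's API:

```python
import torch
import torch.nn.functional as F

def gradcam(model, target_layer, x):
    """Minimal Grad-CAM: weight last-conv activations by pooled gradients."""
    acts, grads = [], []
    h1 = target_layer.register_forward_hook(lambda m, i, o: acts.append(o))
    h2 = target_layer.register_full_backward_hook(lambda m, gi, go: grads.append(go[0]))
    try:
        model.zero_grad(set_to_none=True)
        logits = model(x)                              # (1, num_classes)
        logits[0, logits.argmax()].backward()          # target = predicted class
        w = grads[-1].mean(dim=(2, 3), keepdim=True)   # globally averaged gradients
        cam = torch.relu((w * acts[-1]).sum(dim=1, keepdim=True))  # weighted sum + ReLU
        cam = F.interpolate(cam, size=x.shape[-2:], mode="bilinear", align_corners=False)
        return (cam / cam.max().clamp(min=1e-8)).squeeze()  # (H, W), scaled to [0, 1]
    finally:
        h1.remove()
        h2.remove()
```

The production code in ludwig_backend.py below follows the same steps, adding per-sample error handling, hook cleanup, and overlay rendering.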
```diff
--- a/image_learner.xml	Sun Dec 14 03:27:12 2025 +0000
+++ b/image_learner.xml	Thu Dec 18 16:59:58 2025 +0000
@@ -1,4 +1,4 @@
-<tool id="image_learner" name="Image Learner" version="0.1.4.1" profile="22.01">
+<tool id="image_learner" name="Image Learner" version="0.1.5" profile="22.01">
     <description>trains and evaluates an image classification/regression model</description>
     <requirements>
         <container type="docker">quay.io/goeckslab/galaxy-ludwig-gpu:0.10.1</container>
@@ -389,6 +389,8 @@
             <discover_datasets pattern="(?P<designation>.+)\.json" format="json" directory="experiment_run" />
             <discover_datasets pattern="(?P<designation>.+)\.png" format="png" directory="experiment_run/visualizations/train" />
             <discover_datasets pattern="(?P<designation>.+)\.png" format="png" directory="experiment_run/visualizations/test" />
+            <discover_datasets pattern="(?P<designation>.+)\.png" format="png" directory="experiment_run/feature_importance_examples" />
+            <discover_datasets pattern="(?P<designation>feature_importance_examples\.zip)" format="zip" directory="experiment_run" />
         </collection>
     </outputs>
     <tests>
```
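The two new `discover_datasets` entries pick up the per-sample PNGs and the zip archive written by the backend (see ludwig_backend.py below). A quick sanity check of the patterns against the filenames the backend produces; the sample stem `img001` is hypothetical:

```python
import re

# Patterns copied from the <discover_datasets> entries above.
png_pattern = re.compile(r"(?P<designation>.+)\.png")
zip_pattern = re.compile(r"(?P<designation>feature_importance_examples\.zip)")

# The backend saves "<stem>_original.png" / "<stem>_gradcam.png" pairs.
for name in ("img001_original.png", "img001_gradcam.png"):
    m = png_pattern.fullmatch(name)
    print(name, "->", m.group("designation"))  # named group becomes the dataset designation

# The zip pattern matches exactly one file, designation "feature_importance_examples.zip".
print(zip_pattern.fullmatch("feature_importance_examples.zip").group("designation"))
```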
```diff
--- a/ludwig_backend.py	Sun Dec 14 03:27:12 2025 +0000
+++ b/ludwig_backend.py	Thu Dec 18 16:59:58 2025 +0000
@@ -2,6 +2,7 @@
 import json
 import logging
 import os
+import zipfile
 
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Protocol, Tuple
@@ -712,6 +713,371 @@
         except Exception as e:
             logger.error(f"Error converting Parquet to CSV: {e}")
 
+    def _get_latest_experiment_dir(self, output_dir: Path) -> Optional[Path]:
+        """Return the most recent experiment_run* directory, if present."""
+        output_dir = Path(output_dir)
+        exp_dirs = sorted(
+            output_dir.glob("experiment_run*"),
+            key=lambda p: p.stat().st_mtime,
+        )
+        return exp_dirs[-1] if exp_dirs else None
+
+    def _extract_preprocessing_config(
+        self, exp_dir: Path, config: Dict[str, Any]
+    ) -> Tuple[Optional[Dict[str, Any]], Optional[Path]]:
+        """Parse Ludwig preprocessing settings from train_set_metadata or description."""
+        image_meta: Dict[str, Any] = {}
+        meta_path = exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME
+        if meta_path.exists():
+            try:
+                with meta_path.open("r", encoding="utf-8") as f:
+                    meta_json = json.load(f)
+                image_list = meta_json.get("input_features") or []
+                if image_list:
+                    image_meta = image_list[0] or {}
+            except Exception as exc:
+                logger.warning("Unable to read train_set_metadata: %s", exc)
+
+        # Fallback to description config for preprocessing hints
+        desc_cfg: Dict[str, Any] = {}
+        desc_path = exp_dir / DESCRIPTION_FILE_NAME
+        if desc_path.exists():
+            try:
+                with desc_path.open("r", encoding="utf-8") as f:
+                    desc_json = json.load(f)
+                desc_cfg = desc_json.get("config", {}) if isinstance(desc_json, dict) else {}
+            except Exception as exc:
+                logger.warning("Unable to read description.json for preprocessing: %s", exc)
+
+        preprocessing = {}
+        if isinstance(image_meta, dict):
+            preprocessing = image_meta.get("preprocessing") or {}
+        if not preprocessing and desc_cfg:
+            try:
+                preprocessing = (
+                    desc_cfg.get("input_features", [{}])[0].get("preprocessing") or {}
+                )
+            except Exception:
+                preprocessing = {}
+
+        # If height/width are missing but max inferred dimensions exist, use them as fallback
+        if isinstance(preprocessing, dict):
+            if not preprocessing.get("height") and preprocessing.get("infer_image_max_height"):
+                preprocessing["height"] = preprocessing.get("infer_image_max_height")
+            if not preprocessing.get("width") and preprocessing.get("infer_image_max_width"):
+                preprocessing["width"] = preprocessing.get("infer_image_max_width")
+
+        # Keep label path for downstream sampling
+        label_path = None
+        try:
+            label_path_cfg = config.get("label_column_data_path")
+            if label_path_cfg:
+                label_path = Path(label_path_cfg)
+        except Exception:
+            label_path = None
+
+        return preprocessing if isinstance(preprocessing, dict) else {}, label_path
+
+    def _find_last_conv_layer(self, encoder_obj: Any) -> Optional[Any]:
+        """Identify the last Conv2d layer within the encoder."""
+        try:
+            import torch.nn as nn
+        except Exception:
+            return None
+
+        target_model = encoder_obj
+        if hasattr(encoder_obj, "model"):
+            target_model = encoder_obj.model
+
+        try:
+            modules = list(target_model.named_modules())
+        except Exception:
+            return None
+
+        for _, module in reversed(modules):
+            if isinstance(module, nn.Conv2d):
+                return module
+        return None
+
+    def _generate_gradcam_heatmaps(
+        self,
+        exp_dir: Path,
+        config: Dict[str, Any],
+        output_type: Optional[str],
+    ) -> Dict[str, Any]:
+        """Compute Grad-CAM overlays for convolutional encoders, when possible."""
+        result: Dict[str, Any] = {
+            "status": "skipped",
+            "reason": "",
+            "preview_paths": [],
+            "zip_path": None,
+            "dir_path": None,
+        }
+
+        try:
+            import numpy as np
+            import torch
+            import torch.nn.functional as F
+            from matplotlib import cm
+            from PIL import Image
+            from ludwig.api import LudwigModel
+        except Exception as exc:
+            result["reason"] = f"Missing dependency for Grad-CAM: {exc}"
+            return result
+
+        exp_dir = Path(exp_dir)
+        model_dir = exp_dir / "model"
+        if not model_dir.exists():
+            result["reason"] = "Model directory not found; skipping Grad-CAM."
+            return result
+
+        preprocessing, label_path = self._extract_preprocessing_config(exp_dir, config)
+        height = preprocessing.get("height")
+        width = preprocessing.get("width")
+        if not height or not width:
+            result["reason"] = "Image resize height/width not found in Ludwig preprocessing."
+            return result
+
+        label_csv = label_path if label_path and label_path.exists() else None
+        if not label_csv:
+            result["reason"] = "Prepared label CSV not available for Grad-CAM sampling."
+            return result
+
+        try:
+            df_all = pd.read_csv(label_csv)
+        except Exception as exc:
+            result["reason"] = f"Could not read prepared CSV: {exc}"
+            return result
+
+        if IMAGE_PATH_COLUMN_NAME not in df_all.columns:
+            result["reason"] = "Image column missing from prepared CSV; cannot build Grad-CAM inputs."
+            return result
+
+        # Prefer test split; otherwise fall back to the full dataset
+        df_candidates = df_all
+        if SPLIT_COLUMN_NAME in df_all.columns:
+            try:
+                df_candidates = df_all[df_all[SPLIT_COLUMN_NAME] == 2]
+                if df_candidates.empty:
+                    df_candidates = df_all
+            except Exception:
+                df_candidates = df_all
+
+        # Cap the number of samples
+        df_candidates = df_candidates.head(12)
+        if df_candidates.empty:
+            result["reason"] = "No samples available for Grad-CAM generation."
+            return result
+
+        try:
+            ludwig_model = LudwigModel.load(str(model_dir))
+        except Exception as exc:
+            result["reason"] = f"Unable to load LudwigModel for Grad-CAM: {exc}"
+            return result
+
+        base_model = getattr(ludwig_model, "model", None)
+        if base_model is None:
+            result["reason"] = "Ludwig model missing underlying torch model."
+            return result
+
+        image_feature_name = None
+        image_feature = None
+        try:
+            for name, feat in getattr(base_model, "input_features", {}).items():
+                if hasattr(feat, "encoder_obj"):
+                    image_feature_name = name
+                    image_feature = feat
+                    break
+        except Exception:
+            image_feature_name = None
+
+        if not image_feature or not image_feature_name:
+            result["reason"] = "Image input feature not found; skipping Grad-CAM."
+            return result
+
+        target_layer = self._find_last_conv_layer(getattr(image_feature, "encoder_obj", None))
+        if target_layer is None:
+            result["reason"] = "No convolutional layer detected in the encoder (heatmaps unsupported)."
+            return result
+
+        standardize = preprocessing.get("standardize_image")
+        mean = preprocessing.get("mean") or preprocessing.get("img_mean")
+        std = preprocessing.get("std") or preprocessing.get("img_std")
+        encoder_obj = getattr(image_feature, "encoder_obj", None)
+        if hasattr(encoder_obj, "normalize_mean") and encoder_obj.normalize_mean:
+            mean = encoder_obj.normalize_mean
+        if hasattr(encoder_obj, "normalize_std") and encoder_obj.normalize_std:
+            std = encoder_obj.normalize_std
+        if isinstance(standardize, str) and standardize.lower() == "imagenet1k":
+            mean = [0.485, 0.456, 0.406]
+            std = [0.229, 0.224, 0.225]
+        if mean is None or std is None:
+            result["reason"] = (
+                "Normalization parameters (mean/std) not found in the saved encoder; "
+                "skipping heatmaps to avoid mismatch."
+            )
+            return result
+
+        output_feature_name = LABEL_COLUMN_NAME
+        try:
+            if getattr(base_model, "output_features", None):
+                output_feature_name = next(iter(base_model.output_features.keys()))
+        except Exception:
+            output_feature_name = LABEL_COLUMN_NAME
+
+        device = torch.device("cpu")
+        try:
+            base_model.to(device)
+            base_model.eval()
+        except Exception:
+            logger.debug("Could not move model to CPU for Grad-CAM; continuing on default device.")
+
+        heatmap_dir = exp_dir / "feature_importance_examples"
+        heatmap_dir.mkdir(parents=True, exist_ok=True)
+
+        def _load_tensor(image_path: Path) -> Tuple[Optional[torch.Tensor], Optional[Image.Image]]:
+            try:
+                img = Image.open(image_path).convert("RGB")
+            except Exception:
+                return None, None
+            resized = img.resize((int(width), int(height)))
+            arr = np.asarray(resized).astype("float32") / 255.0
+            arr = np.transpose(arr, (2, 0, 1))
+            tensor = torch.from_numpy(arr)
+            try:
+                mean_tensor = torch.tensor(mean).view(-1, 1, 1)
+                std_tensor = torch.tensor(std).view(-1, 1, 1)
+                tensor = (tensor - mean_tensor) / std_tensor
+            except Exception:
+                return None, None
+            return tensor.unsqueeze(0).to(device), resized
+
+        generated: List[Path] = []
+        pairs: List[Tuple[Path, Path]] = []
+        image_root = label_csv.parent
+
+        for _, row in df_candidates.iterrows():
+            raw_path = row.get(IMAGE_PATH_COLUMN_NAME)
+            if not isinstance(raw_path, str):
+                continue
+            abs_path = (image_root / raw_path).resolve()
+            if not abs_path.exists():
+                continue
+
+            tensor, resized_img = _load_tensor(abs_path)
+            if tensor is None or resized_img is None:
+                continue
+
+            activations: List[torch.Tensor] = []
+            gradients: List[torch.Tensor] = []
+
+            def _fwd_hook(_module, _inp, output):
+                activations.append(output)
+
+            def _bwd_hook(_module, _grad_in, grad_out):
+                if grad_out and isinstance(grad_out[0], torch.Tensor):
+                    gradients.append(grad_out[0])
+
+            handle_fwd = target_layer.register_forward_hook(_fwd_hook)
+            try:
+                handle_bwd = target_layer.register_full_backward_hook(_bwd_hook)
+            except Exception:
+                handle_bwd = target_layer.register_backward_hook(_bwd_hook)
+
+            try:
+                base_model.zero_grad(set_to_none=True)
+                with torch.enable_grad():
+                    outputs = base_model({image_feature_name: tensor})
+
+                logits = None
+                if isinstance(outputs, dict):
+                    feature_out = outputs.get(output_feature_name)
+                    if isinstance(feature_out, dict):
+                        # Avoid `or` here: truth-testing a multi-element tensor raises.
+                        logits = feature_out.get("logits")
+                        if logits is None:
+                            logits = feature_out.get("logit")
+                    elif isinstance(feature_out, torch.Tensor):
+                        logits = feature_out
+
+                    # Ludwig 0.10+ uses namespaced keys: "<feature>::logits"
+                    if logits is None:
+                        ns_key = f"{output_feature_name}::logits"
+                        if isinstance(outputs.get(ns_key), torch.Tensor):
+                            logits = outputs[ns_key]
+
+                    # Fallback: a top-level logits tensor
+                    if logits is None and isinstance(outputs.get("logits"), torch.Tensor):
+                        logits = outputs.get("logits")
+
+                if logits is None:
+                    raise ValueError("Could not locate logits for Grad-CAM.")
+
+                if logits.dim() == 1:
+                    target_logit = logits.unsqueeze(0)
+                else:
+                    target_class = 0
+                    if output_type != "regression" and logits.shape[-1] > 1:
+                        target_class = int(torch.argmax(logits, dim=-1).item())
+                    target_logit = logits[:, target_class]
+
+                target_logit.sum().backward()
+
+                if not activations or not gradients:
+                    raise ValueError("Missing activations or gradients for Grad-CAM.")
+
+                act = activations[-1]
+                grad = gradients[-1]
+                weights = grad.mean(dim=(2, 3), keepdim=True)
+                cam = (weights * act).sum(dim=1)
+                cam = torch.relu(cam)
+                cam = cam.squeeze(0)
+                cam = F.interpolate(
+                    cam.unsqueeze(0).unsqueeze(0),
+                    size=(int(height), int(width)),
+                    mode="bilinear",
+                    align_corners=False,
+                )
+                cam = cam.squeeze().detach().cpu().numpy()
+                if cam.max() > 0:
+                    cam = cam / cam.max()
+                heatmap_rgba = np.uint8(cm.get_cmap("jet")(cam) * 255)
+                heatmap_img = Image.fromarray(heatmap_rgba).convert("RGBA").resize(resized_img.size)
+                overlay = Image.blend(resized_img.convert("RGBA"), heatmap_img, alpha=0.45)
+
+                stem = Path(raw_path).stem
+                out_path = heatmap_dir / f"{stem}_gradcam.png"
+                overlay.save(out_path)
+                orig_path = heatmap_dir / f"{stem}_original.png"
+                try:
+                    resized_img.save(orig_path)
+                except Exception:
+                    orig_path = None
+
+                generated.append(out_path)
+                if orig_path:
+                    pairs.append((orig_path, out_path))
+            except Exception as exc:
+                logger.warning("Grad-CAM failed for %s: %s", raw_path, exc)
+            finally:
+                try:
+                    handle_fwd.remove()
+                    handle_bwd.remove()
+                except Exception:
+                    pass
+
+        if not generated:
+            result["reason"] = "No heatmaps were generated (model may be non-convolutional or preprocessing missing)."
+            return result
+
+        zip_path = exp_dir / "feature_importance_examples.zip"
+        try:
+            with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
+                for png in generated:
+                    zf.write(png, png.name)
+        except Exception as exc:
+            logger.warning("Failed to create Grad-CAM zip: %s", exc)
+
+        result.update(
+            {
+                "status": "generated",
+                "preview_paths": generated[:6],
+                "pairs": pairs[:6],
+                "zip_path": zip_path if zip_path.exists() else None,
+                "dir_path": heatmap_dir,
+            }
+        )
+        return result
+
     @staticmethod
     def _extract_metric_series(stats: Dict[str, Any], split: str, prefer: Optional[str] = None) -> Tuple[Optional[str], Optional[List[float]]]:
         """Pull the first numeric metric list we can find for the requested split."""
@@ -1473,6 +1839,8 @@
 
         tab3_content = test_metrics_html + preds_section
 
+        gradcam_info = self._generate_gradcam_heatmaps(exp_dir, config, output_type)
+
         if output_type == "regression" and train_stats_path.exists():
             try:
                 test_plots = build_regression_test_plots(str(train_stats_path))
@@ -1530,6 +1898,26 @@
         except Exception as e:
             logger.warning(f"Could not generate test diagnostics: {e}")
 
+        if gradcam_info.get("status") == "generated":
+            tab3_content += "<h2 style='text-align: center;'>Grad-CAM Heatmaps</h2>"
+            for orig_path, heat_path in gradcam_info.get("pairs", [])[:4]:
+                try:
+                    display_name = Path(str(orig_path)).name
+                    if display_name.endswith("_original.png"):
+                        display_name = display_name[: -len("_original.png")]
+                    b64_orig = encode_image_to_base64(str(orig_path))
+                    b64_heat = encode_image_to_base64(str(heat_path))
+                    tab3_content += (
+                        "<div class='plot' style='margin-bottom:15px;text-align:center;display:flex;gap:12px;justify-content:center;flex-wrap:wrap;'>"
+                        f"<div><div style='font-weight:600;margin-bottom:4px;'>{display_name}</div>"
+                        f"<img src='data:image/png;base64,{b64_orig}' style='max-width:320px;max-height:320px;border:1px solid #ddd;' /></div>"
+                        f"<div><div style='font-weight:600;margin-bottom:4px;'>Grad-CAM</div>"
+                        f"<img src='data:image/png;base64,{b64_heat}' style='max-width:320px;max-height:320px;border:1px solid #ddd;' /></div>"
+                        "</div>"
+                    )
+                except Exception as exc:
+                    logger.debug("Could not embed Grad-CAM pair %s / %s: %s", orig_path, heat_path, exc)
+
         # Add static TEST PNGs (with default dedupe/exclusions)
         tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)
         modal_html = get_metrics_help_modal()
```
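The result dict returned by `_generate_gradcam_heatmaps` doubles as a status report for the report builder: every skip path fills `reason`, and success populates `pairs`, `zip_path`, and `dir_path`. A minimal sketch of consuming it, assuming `backend` is an instance of this Ludwig backend; the paths and `output_type` value are hypothetical:

```python
from pathlib import Path

# Hypothetical driver code; mirrors the internal call in the report builder.
info = backend._generate_gradcam_heatmaps(
    exp_dir=Path("experiment_run"),
    config={"label_column_data_path": "prepared_labels.csv"},
    output_type="classification",
)

if info["status"] == "generated":
    # Pairs are capped at six for the HTML preview; the full set lives in
    # `dir_path` and in the zip that the Galaxy tool discovers.
    for orig, heat in info.get("pairs", []):
        print(f"{orig.name} -> {heat.name}")
    print("archive:", info["zip_path"])
else:
    # Every skip path records a human-readable reason.
    print("Grad-CAM skipped:", info["reason"])
```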
