view sam3_semantic_segmentation.py @ 2:809f339deae2 draft default tip

planemo upload for repository https://github.com/galaxyecology/tools-ecology/tree/master/tools/Sam3 commit 7b696f5e3039fc7f6f1b8ceb3d5262e230d0ab57
author ecology
date Tue, 10 Mar 2026 10:42:43 +0000
parents 1c90ba574b70
children
line wrap: on
line source

import argparse
import hashlib
import json
import os
from pathlib import Path
from typing import Any, Dict, List

import cv2

import numpy as np

from pycocotools.coco import COCO

import ultralytics
from ultralytics.models.sam import SAM3SemanticPredictor
from ultralytics.models.sam.predict import SAM3VideoSemanticPredictor

# -------- Constants --------
# File suffixes treated as video input; is_video() compares the lower-cased
# suffix of a path against this set.
VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov", ".mkv"}
# Default value of the --conf command-line option.
DEFAULT_CONFIDENCE = 0.25
# Default value of the --vid_stride command-line option.
DEFAULT_VID_STRIDE = 5


# -------- Arguments --------
def parse_arguments() -> argparse.Namespace:
    """Build the command-line parser and parse the process arguments.

    Returns:
        Namespace exposing ``model``, ``outdir``, ``prompts``, ``conf``,
        ``vid_stride``, ``outputs`` and ``name_file``. Only ``--prompts``
        is mandatory; every other option has a default.
    """
    parser = argparse.ArgumentParser(
        description="SAM3 to COCO/YOLO exporter with semantic segmentation"
    )
    parser.add_argument(
        "--model", type=str, help="Path to the SAM3 model weights file"
    )
    parser.add_argument(
        "--outdir",
        type=str,
        default="outputs/",
        help="Output directory for annotations and images",
    )
    parser.add_argument(
        "--prompts",
        type=str,
        required=True,
        # Trailing space keeps the two adjacent literals from running together
        help="Comma-separated list of class text prompts "
        "(e.g., 'human,elephant')",
    )
    parser.add_argument(
        "--conf",
        type=float,
        default=DEFAULT_CONFIDENCE,
        help="Confidence threshold for predictions",
    )
    parser.add_argument(
        "--vid_stride",
        type=int,
        default=DEFAULT_VID_STRIDE,
        help="Frame stride for video prediction (process every Nth frame)",
    )
    parser.add_argument(
        "--outputs",
        type=str,
        default="",
        help="Comma-separated output formats: 'coco', 'yolo_bbox', 'yolo_seg'",
    )
    parser.add_argument(
        "--name_file",
        type=str,
        default=None,
        help="Specific filename to process (optional)",
    )
    return parser.parse_args()


# -------- Functions --------


def is_video(file_path: str) -> bool:
    """Tell whether *file_path* ends with one of the known video suffixes.

    The comparison is case-insensitive: the suffix is lower-cased before
    being looked up in ``VIDEO_EXTENSIONS``.
    """
    extension = Path(file_path).suffix
    return extension.lower() in VIDEO_EXTENSIONS


def compute_file_hash(filepath: Path) -> str:
    """Return the hexadecimal SHA256 digest of the file at *filepath*."""
    digest = hashlib.sha256()
    with open(filepath, "rb") as stream:
        # Feed the file through in 8192-byte blocks so that large files
        # never have to be held in memory at once.
        while True:
            block = stream.read(8192)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()


def validate_coco_format(annotation_file: Path) -> bool:
    """Check that the annotation file can be loaded by pycocotools.

    Returns True when ``COCO`` loads the file without raising. Any
    exception is reported on stdout and turned into a False result.
    """
    try:
        COCO(str(annotation_file))
        print(f"✓ COCO file is valid: {annotation_file}")
    except Exception as error:
        print(f"COCO format error: {error}")
        return False
    return True


def create_coco_categories(text_prompts: List[str]) -> List[Dict[str, Any]]:
    """
    Build the COCO ``categories`` list from the text prompts.
    Each prompt becomes one category; ids are 1-based and follow the
    order of the prompts.
    Returns: List of category dictionaries
    """
    categories = []
    for category_id, prompt in enumerate(text_prompts, start=1):
        categories.append({"id": category_id, "name": prompt})
    return categories


def create_coco_output(
    results: List[Any], text_prompts: List[str], metadata: Dict[str, Any]
) -> Dict[str, Any]:
    """Convert SAM3 results to a COCO-format dictionary.

    Geometry is written in absolute pixel coordinates, as the COCO format
    expects: ``segmentation`` is a flattened polygon, ``bbox`` is
    ``[x, y, width, height]`` measured from the top-left corner, and
    ``area`` is the polygon area in square pixels.

    Args:
        results: Prediction results, one per image/frame. Each is assumed
            to expose ultralytics ``Results``-style attributes
            (``masks.xy``, ``boxes.xyxy``, ``boxes.cls``, ``orig_shape``,
            ``path``).
        text_prompts: Class names; the position in the list is the class
            index predicted by the model.
        metadata: Dictionary stored verbatim under the ``info`` key.

    Returns:
        Dictionary with ``info``, ``images``, ``annotations`` and
        ``categories`` keys. Results without masks are skipped entirely
        (no image entry is emitted for them).
    """
    coco_output = {
        "info": metadata,
        "images": [],
        "annotations": [],
        "categories": create_coco_categories(text_prompts),
    }

    annotation_id = 1

    # Image ids are 1-based positions in ``results``; skipped results leave
    # gaps so ids stay stable with respect to the input order.
    for image_id, result in enumerate(results, start=1):
        if result.masks is None:
            continue

        height, width = result.orig_shape

        coco_output["images"].append(
            {
                "id": image_id,
                "file_name": Path(result.path).name,
                "width": width,
                "height": height,
            }
        )

        # Pixel-space polygons and boxes (not the normalized *n variants):
        # COCO consumers interpret these values as pixels.
        for polygon, bbox, class_id in zip(
            result.masks.xy, result.boxes.xyxy, result.boxes.cls
        ):
            x1, y1, x2, y2 = bbox[:4].tolist()

            # Polygon area in square pixels
            area = float(cv2.contourArea(polygon.astype(np.float32)))

            coco_output["annotations"].append(
                {
                    "id": annotation_id,
                    "image_id": image_id,
                    # Category ids are 1-based, class indices 0-based
                    "category_id": int(class_id) + 1,
                    "segmentation": [polygon.flatten().tolist()],
                    "area": area,
                    # Corner format -> COCO [x, y, width, height]
                    "bbox": [x1, y1, x2 - x1, y2 - y1],
                    "iscrowd": 0,
                }
            )
            annotation_id += 1

    return coco_output


def create_yolo_bbox_annotation(box: np.ndarray, class_id: int) -> str:
    """Format one detection as a YOLO bounding-box label line.

    The first four values of *box* are read as corner coordinates
    ``x1, y1, x2, y2`` and rewritten as
    ``class_id center_x center_y width height`` with six decimals each.
    """
    left, top, right, bottom = box[:4].tolist()

    # Corner pair -> centre point followed by extent
    fields = (
        (left + right) / 2,
        (top + bottom) / 2,
        right - left,
        bottom - top,
    )
    return f"{class_id} " + " ".join(f"{value:.6f}" for value in fields)


def create_yolo_seg_annotation(polygon: np.ndarray, class_id: int) -> str:
    """Format one polygon as a YOLO segmentation label line.

    The line is the class id followed by every coordinate of every point,
    in order, each written with six decimals and separated by spaces.
    """
    coords = []
    for point in polygon:
        for value in point:
            coords.append(f"{value:.6f}")
    return f"{class_id} " + " ".join(coords)


def create_yolo_output(
    annotation_type: str, results: List[Any], output_dir: Path
) -> None:
    """Export annotations in YOLO format for images.

    For every result the source image is copied to ``output_dir/images``
    and a label file with the same stem is written to
    ``output_dir/labels``. A result without detections still gets an
    (empty) label file.

    Args:
        annotation_type: ``"bbox"`` writes bounding-box lines; any other
            value writes segmentation polygon lines.
        results: Prediction results; each is assumed to expose
            ultralytics ``Results``-style ``path``, ``boxes`` and
            ``masks`` attributes.
        output_dir: Directory receiving the ``images`` and ``labels``
            sub-directories (created when missing).
    """
    # Local import: shutil is not among this module's top-level imports.
    import shutil

    images_dir = output_dir / "images"
    labels_dir = output_dir / "labels"
    images_dir.mkdir(exist_ok=True, parents=True)
    labels_dir.mkdir(exist_ok=True, parents=True)

    for result in results:
        image_src = Path(result.path)

        # Keep the original file name so image and label share a stem
        shutil.copy2(image_src, images_dir / image_src.name)

        lines = []
        boxes = result.boxes
        # None when the masks (or their normalized polygons) are unavailable
        polygons = getattr(result.masks, "xyn", None)

        # A result may carry no boxes at all; it then yields an empty label
        # file, matching what the video exporter does.
        if boxes is not None:
            for i, (box, class_id) in enumerate(zip(boxes.xyxyn, boxes.cls)):
                class_id = int(class_id)

                if annotation_type == "bbox":
                    lines.append(create_yolo_bbox_annotation(box, class_id))
                elif polygons is not None:
                    lines.append(
                        create_yolo_seg_annotation(polygons[i], class_id)
                    )

        with open(labels_dir / f"{image_src.stem}.txt", "w") as f:
            f.write("\n".join(lines))

    print(f"✓ Created {len(results)} images and labels in {output_dir}")


def create_yolo_video_output(
    annotation_type: str,
    results: List[Any],
    output_dir: Path,
    video_path: str,
    stride: int,
) -> None:
    """Export annotations in YOLO format for videos with frame extraction.

    The video is re-read frame by frame; every selected frame is saved as
    a JPEG in ``output_dir/images`` and paired with a label file in
    ``output_dir/labels`` built from the next entry of *results*.

    Args:
        annotation_type: ``"bbox"`` writes bounding-box lines; any other
            value writes segmentation polygon lines.
        results: Per-frame prediction results, in frame order; each is
            assumed to expose ultralytics ``Results``-style ``boxes`` and
            ``masks`` attributes.
        output_dir: Directory receiving the ``images`` and ``labels``
            sub-directories (created when missing).
        video_path: Path of the video file to extract frames from.
        stride: Frame stride; must be at least 1.

    Raises:
        ValueError: If *stride* is smaller than 1.
        RuntimeError: If the video cannot be opened or a frame image
            cannot be written.
    """
    if stride < 1:
        raise ValueError(f"stride must be a positive integer, got {stride}")

    images_dir = output_dir / "images"
    labels_dir = output_dir / "labels"
    images_dir.mkdir(exist_ok=True, parents=True)
    labels_dir.mkdir(exist_ok=True, parents=True)

    video_name = Path(video_path).stem

    # NOTE(review): the counter starts at 1 for stride > 1, so the selected
    # frames are the stride-th, 2*stride-th, ... ones. Presumably this lines
    # up with the frames the predictor sampled for vid_stride -- confirm
    # against the ultralytics loader.
    frame_idx = 1 if stride > 1 else 0
    saved_idx = 0

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Failed to open video: {video_path}")

        print(f"Processing video frames (stride={stride})...")

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            if frame_idx % stride == 0:
                # Stop before writing anything for a frame that has no
                # prediction, so no image is left without a label file.
                if saved_idx >= len(results):
                    print(
                        f"Warning: No result available for frame {frame_idx}"
                    )
                    break

                frame_base = f"{video_name}_frame_{frame_idx:06d}"
                img_path = images_dir / f"{frame_base}.jpg"
                label_path = labels_dir / f"{frame_base}.txt"

                # imwrite signals failure through its return value
                if not cv2.imwrite(str(img_path), frame):
                    raise RuntimeError(
                        f"Failed to write frame image: {img_path}"
                    )

                result = results[saved_idx]
                lines = []

                if result.boxes is not None:
                    for i, (box, class_id) in enumerate(
                        zip(result.boxes.xyxyn, result.boxes.cls)
                    ):
                        class_id = int(class_id)

                        if annotation_type == "bbox":
                            line = create_yolo_bbox_annotation(box, class_id)
                        else:  # segmentation
                            if result.masks is None or not hasattr(
                                result.masks, "xyn"
                            ):
                                continue
                            polygon = result.masks.xyn[i]
                            line = create_yolo_seg_annotation(
                                polygon, class_id
                            )

                        lines.append(line)

                with open(label_path, "w") as f:
                    f.write("\n".join(lines))

                saved_idx += 1

                if saved_idx % 10 == 0:
                    print(f"  Processed {saved_idx} frames...")

            frame_idx += 1
    finally:
        # Release the capture handle even when an error interrupts the loop
        cap.release()

    print(f"✓ Created {saved_idx} frames and labels in {output_dir}")


def create_metadata(
    text_prompts: List[str], conf_threshold: float, model_path: str
) -> Dict[str, Any]:
    """Create the metadata dictionary stored in the COCO export.

    Args:
        text_prompts: Class prompts used for the prediction.
        conf_threshold: Confidence threshold used for the prediction.
        model_path: Path of the model weights file. May be ``None`` or
            empty, in which case no checksum is computed.

    Returns:
        Dictionary with the export description, model name, SHA256 of the
        weights file (``"N/A"`` when the path is missing or is not a
        regular file), prompts, confidence threshold and the installed
        ultralytics version.
    """
    # --model is optional on the command line, so the path can be None;
    # is_file() also rules out directories, which could not be hashed.
    weights = Path(model_path) if model_path else None
    if weights is not None and weights.is_file():
        model_sha256 = compute_file_hash(weights)
    else:
        model_sha256 = "N/A"

    return {
        "description": "SAM3 semantic segmentation export",
        "model": "sam3.pt",
        "model_sha256": model_sha256,
        "prompts": text_prompts,
        "confidence_threshold": conf_threshold,
        "ultralytics_version": ultralytics.__version__,
    }


# -------- Main --------


def main():
    """Run SAM3 prediction on ``data_files`` and export the annotations.

    Steps: parse the command line, collect the input files, run the image
    or video predictor (chosen from the first input file's extension) with
    the text prompts, then write the requested COCO / YOLO exports under
    the output directory.

    Raises:
        SystemExit: If ``data_files`` contains no files.
        RuntimeError: If the predictor returns no results.
    """
    args = parse_arguments()

    text_prompts = [prompt.strip() for prompt in args.prompts.split(",")]
    print(f"\nClass prompts: {text_prompts}")
    print(f"args.name_file: {args.name_file}")

    # Inputs are read from "data_files" in the current working directory
    folder = Path("data_files").resolve()
    print(f"folder: {folder}")
    # Sorted so that the video/image decision below, which looks at the
    # first file, does not depend on directory listing order.
    file_paths = sorted(str(f) for f in folder.glob("*") if f.is_file())
    print(f"file_paths: {file_paths}")

    if not file_paths:
        # Exit with a non-zero status so the caller can detect the failure
        raise SystemExit("Error: No files found in data_files directory")

    # The whole folder is handed to the predictor as its source
    source_path = str(folder)
    video_input = is_video(file_paths[0])

    outdir = Path(args.outdir)
    outdir.mkdir(exist_ok=True, parents=True)

    outputs_annotated = outdir / "outputs_annotated"
    outputs_annotated.mkdir(parents=True, exist_ok=True)

    # Drop empty entries (e.g. when --outputs is left at its "" default)
    output_formats = [
        fmt.strip() for fmt in args.outputs.split(",") if fmt.strip()
    ]
    print(f"Output formats: {output_formats}")

    overrides = {
        "conf": args.conf,
        "show_conf": False,
        "task": "segment",
        "mode": "predict",
        "model": args.model,
        "half": True,  # Use FP16 for faster inference
        "save": True,
        "save_dir": str(outputs_annotated),
    }

    if video_input:
        print("\nVideo input detected → using SAM3VideoSemanticPredictor")
        overrides["vid_stride"] = args.vid_stride
        predictor = SAM3VideoSemanticPredictor(overrides=overrides)
    else:
        print("\nImage input detected → using SAM3SemanticPredictor")
        predictor = SAM3SemanticPredictor(overrides=overrides)

    # Wrap the predictor's postprocess step so every result carries the
    # prompt texts as its class names.
    original_postprocess = predictor.postprocess

    def patched_postprocess(preds, img, orig_imgs):
        """Run the original postprocess, then attach the prompt names."""
        results = original_postprocess(preds, img, orig_imgs)
        for r in results:
            r.names = {i: name for i, name in enumerate(text_prompts)}
        return results

    predictor.postprocess = patched_postprocess

    results = predictor(source=source_path, text=text_prompts, stream=False)

    if not results:
        raise RuntimeError("SAM3 returned no results")

    print(f"✓ Processed {len(results)} result(s)")

    metadata = create_metadata(text_prompts, args.conf, args.model)

    print("\n" + "=" * 60)
    print("EXPORTING RESULTS")
    print("=" * 60)

    if "coco" in output_formats:
        print("\n→ Converting to COCO format...")
        coco_output = create_coco_output(results, text_prompts, metadata)

        annotation_file = outdir / "annotations.json"
        with open(annotation_file, "w") as f:
            json.dump(coco_output, f, indent=4)

        print(f"  Saved: {annotation_file}")
        validate_coco_format(annotation_file)

    # Both YOLO flavours share the same export flow; only the output
    # sub-directory and the annotation type differ.
    for format_name, annotation_type in (
        ("yolo_bbox", "bbox"),
        ("yolo_seg", "seg"),
    ):
        if format_name not in output_formats:
            continue

        print(f"\n→ Exporting YOLO {annotation_type} annotations...")
        yolo_dir = outdir / format_name
        yolo_dir.mkdir(parents=True, exist_ok=True)

        if video_input:
            create_yolo_video_output(
                annotation_type,
                results,
                yolo_dir,
                file_paths[0],
                args.vid_stride,
            )
        else:
            create_yolo_output(annotation_type, results, yolo_dir)
            print(f"  Saved to: {yolo_dir}")

    print("\n" + "=" * 60)
    print("✓ EXPORT COMPLETED SUCCESSFULLY!")
    print("=" * 60)

if __name__ == "__main__":
    main()