diff segment_text_yolo.py @ 0:28b4dc80d58b draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/image_processing/yolo-utils/doclayoutyolo commit 80167f52fb9bd60b57e4df9d68152876171228d6
author bgruening
date Fri, 13 Jun 2025 14:54:15 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/segment_text_yolo.py	Fri Jun 13 14:54:15 2025 +0000
@@ -0,0 +1,88 @@
+"""
+Segment text using DocLayout Yolo model
+"""
+
+import argparse
+import json
+import os
+
+import cv2
+from doclayout_yolo import YOLOv10
+from geojson import Feature, FeatureCollection
+from shapely.geometry import box, mapping
+
+
+def load_model_and_predict(
+    model_path, input_image_path, input_confidence, image_size, output_image_path
+):
+
+    model = YOLOv10(model=model_path)
+
+    det_res = model.predict(
+        input_image_path, imgsz=int(image_size), conf=float(input_confidence)
+    )
+    annotated_frame = det_res[0].plot(pil=True, line_width=5, font_size=20)
+    cv2.imwrite(output_image_path, annotated_frame)
+    return det_res[0]
+
+
+def extract_bb_crop(results, output_segmentation_coordiates):
+    bounding_boxes = []
+    features = []
+    for bx in results.boxes.xyxy.cpu().numpy():
+        x1, y1, x2, y2 = bx
+        bounding_boxes.append((x1, y1, x2, y2))
+
+    for i, (x1, y1, x2, y2) in enumerate(bounding_boxes):
+        poly = box(x1, y1, x2, y2)
+        feature = Feature(geometry=mapping(poly), properties={"id": i})
+        features.append(feature)
+
+    geojson_obj = FeatureCollection(features)
+
+    with open(output_segmentation_coordiates, "w") as f:
+        json.dump(geojson_obj, f)
+
+
+if __name__ == "__main__":
+    arg_parser = argparse.ArgumentParser()
+    arg_parser.add_argument(
+        "-im", "--yolo_model", required=True, help="Input Yolo model"
+    )
+    arg_parser.add_argument(
+        "-ii", "--input_image", required=True, help="Input image file"
+    )
+    arg_parser.add_argument(
+        "-ie", "--input_image_ext", required=True, help="Input image file extension"
+    )
+    arg_parser.add_argument(
+        "-ic", "--input_confidence", required=True, help="Input confidence"
+    )
+    arg_parser.add_argument(
+        "-is", "--input_image_size", required=True, help="Input image size"
+    )
+    arg_parser.add_argument("-oi", "--output_image", required=True, help="Output image")
+    arg_parser.add_argument(
+        "-ogj", "--output_geojson", required=True, help="Output segmented coordinates"
+    )
+    args = vars(arg_parser.parse_args())
+    model_path = args["yolo_model"]
+    input_image_path = args["input_image"]
+    input_ext = args["input_image_ext"]
+    confidence = args["input_confidence"]
+    image_size = args["input_image_size"]
+    output_image_path = args["output_image"]
+    output_segmentation_coordiates = args["output_geojson"]
+
+    model_link = "yolo_model.pt"
+    input_image = f"input_image.{input_ext}"
+    output_image = f"output_image.{input_ext}"
+
+    os.symlink(model_path, model_link)
+    os.symlink(input_image_path, input_image)
+    os.symlink(output_image_path, output_image)
+
+    segmented_image = load_model_and_predict(
+        model_link, input_image, confidence, image_size, output_image
+    )
+    extract_bb_crop(segmented_image, output_segmentation_coordiates)