diff points2label.py @ 6:22bb32eae6a1 draft default tip

planemo upload for repository https://github.com/BMCV/galaxy-image-analysis/tree/master/tools/points2labelimage/ commit edac062b00490276ef00d094e0594abdb3a3f23c
author imgteam
date Thu, 06 Nov 2025 09:59:34 +0000
parents 4a49f74a3c14
children
line wrap: on
line diff
--- a/points2label.py	Mon May 12 14:01:26 2025 +0000
+++ b/points2label.py	Thu Nov 06 09:59:34 2025 +0000
@@ -1,12 +1,11 @@
 import argparse
 import json
-import os
 import warnings
 from typing import (
+    Any,
     Dict,
-    List,
+    Optional,
     Tuple,
-    Union,
 )
 
 import giatools.pandas
@@ -14,219 +13,276 @@
 import numpy.typing as npt
 import pandas as pd
 import scipy.ndimage as ndi
+import skimage.draw
 import skimage.io
 import skimage.segmentation
 
 
-def is_rectangular(points: Union[List[Tuple[float, float]], npt.NDArray]) -> bool:
-    points = np.asarray(points)
-
-    # Rectangle must have 5 points, where first and last are identical
-    if len(points) != 5 or not (points[0] == points[-1]).all():
-        return False
-
-    # Check that all edges align with the axes
-    edges = points[1:] - points[:-1]
-    if any((edge == 0).sum() != 1 for edge in edges):
-        return False
-
-    # All checks have passed, the geometry is rectangular
-    return True
+def get_list_depth(nested_list: Any) -> int:
+    """
+    Determine how deeply lists are nested inside `nested_list`.
+
+    Anything that is not a `list` has depth 0, an empty list has depth 1, and
+    a non-empty list is one level deeper than its most deeply nested item.
+    """
+    # Guard clause: scalars, tuples, dicts, etc. do not count as a level
+    if not isinstance(nested_list, list):
+        return 0
+
+    # `default=0` covers the empty list, which then yields a depth of 1
+    return 1 + max((get_list_depth(item) for item in nested_list), default=0)
 
 
-def geojson_to_tabular(geojson: Dict):
-    rows = []
-    labels = []
-    for feature in geojson['features']:
-        assert feature['geometry']['type'].lower() == 'polygon', (
-            f'Unsupported geometry type: "{feature["geometry"]["type"]}"'
-        )
-        coords = feature['geometry']['coordinates'][0]
+class AutoLabel:
+    """
+    Generator of unique, non-negative labels that never collide with a given
+    collection of reserved labels.
+
+    Labels are handed out in ascending order, starting from 0.
+    """
 
-        # Properties and name (label) are optional
-        try:
-            label = feature['properties']['name']
-        except KeyError:
-            label = max(labels, default=0) + 1
-        labels.append(label)
+    def __init__(self, reserved_labels):
+        # Candidate for the label that is handed out next
+        self.next_autolabel = 0
+        # Labels that must never be handed out (only used for `in` tests)
+        self.reserved_labels = reserved_labels
 
-        # Read geometry
-        xs = [pt[0] for pt in coords]
-        ys = [pt[1] for pt in coords]
+    def next(self):
+        """
+        Hand out the smallest label that is neither reserved nor was handed
+        out before.
+        """
+        # Skip over all candidates that are reserved
+        while self.next_autolabel in self.reserved_labels:
+            self.next_autolabel += 1
 
-        x = min(xs)
-        y = min(ys)
+        # Remember the free label, advance the candidate, then hand it out
+        free_label = self.next_autolabel
+        self.next_autolabel = free_label + 1
+        return free_label
 
-        width = max(xs) + 1 - x
-        height = max(ys) + 1 - y
-
-        # Validate geometry (must be rectangular)
-        assert is_rectangular(list(zip(xs, ys)))
 
-        # Append the rectangle
-        rows.append({
-            'pos_x': x,
-            'pos_y': y,
-            'width': width,
-            'height': height,
-            'label': label,
-        })
-    df = pd.DataFrame(rows)
-    point_file = './point_file.tabular'
-    df.to_csv(point_file, sep='\t', index=False)
-    return point_file
+def get_feature_label(feature: Dict) -> Optional[int]:
+    """
+    Get the label of a GeoJSON feature, or `None` if there is no proper label.
+
+    The label is read from the `name` property of the feature. It is accepted
+    if it is a non-negative integer, or a string that can be parsed as one.
+    A missing `properties` member, a `properties` member that is `null`, a
+    missing `name`, and any other kind of value all yield `None`.
+    """
+    # GeoJSON permits `"properties": null`, in which case `dict.get` returns
+    # `None` rather than the default, so fall back to an empty mapping
+    properties = feature.get('properties') or {}
+    label = properties.get('name')
+
+    # If the `label` is given as a string, try to parse as integer
+    if isinstance(label, str):
+        try:
+            label = int(label)
+        except ValueError:
+            return None
+
+    # Finally, if `label` is an integer, only use it if it is non-negative
+    if isinstance(label, int) and label >= 0:
+        return label
+    return None
 
 
-def rasterize(point_file, out_file, shape, has_header=False, swap_xy=False, bg_value=0, fg_value=None):
+def rasterize(
+    geojson: Dict,
+    shape: Tuple[int, int],
+    bg_value: int = 0,
+    fg_value: Optional[int] = None,
+) -> npt.NDArray:
+    """
+    Rasterize GeoJSON into a pixel image, that is returned as a NumPy array.
+
+    Features are drawn in the order in which they appear in `geojson`, so a
+    feature overwrites those drawn before it wherever they overlap.
+
+    :param geojson: GeoJSON object with a `features` list. Polygon and point
+        geometries are supported, where a point with a positive `radius`
+        property is rasterized as a disk.
+    :param shape: Shape of the image, given as `(height, width)`.
+    :param bg_value: Value of the pixels that are not covered by any feature.
+    :param fg_value: Value used for all features. If `None`, each feature is
+        rasterized using its own label (see `get_feature_label`), or using an
+        automatically chosen unused label if it has none.
+    :return: The rasterized image (`uint16`).
+    :raises ValueError: If a geometry type is unsupported.
+    :raises IndexError: If a point lies outside of the image.
+    """
 
-    img = np.full(shape, dtype=np.uint16, fill_value=bg_value)
-    if os.path.exists(point_file) and os.path.getsize(point_file) > 0:
+    # Determine which labels are reserved (not used by auto-label)
+    reserved_labels = [bg_value]
+    if fg_value is None:
+        for feature in geojson['features']:
+            label = get_feature_label(feature)
+            if label is not None:
+                reserved_labels.append(label)
 
-        # Read the tabular file with information from the header
-        if has_header:
-            df = pd.read_csv(point_file, delimiter='\t')
+    # Convert `reserved_labels` into a `set` for faster look-ups
+    reserved_labels = frozenset(reserved_labels)
+
+    # Define routine to retrieve the next auto-label
+    autolabel = AutoLabel(reserved_labels)
 
-            pos_x_column = giatools.pandas.find_column(df, ['pos_x', 'POS_X'])
-            pos_y_column = giatools.pandas.find_column(df, ['pos_y', 'POS_Y'])
-            pos_x_list = df[pos_x_column].round().astype(int)
-            pos_y_list = df[pos_y_column].round().astype(int)
-            assert len(pos_x_list) == len(pos_y_list)
+    # Rasterize the image
+    img = np.full(shape, dtype=np.uint16, fill_value=bg_value)
+    for feature in geojson['features']:
+        geom_type = feature['geometry']['type'].lower()
+        coords = feature['geometry']['coordinates']
+
+        # Rasterize a `mask` separately for each feature
+        if geom_type == 'polygon':
+
+            # Normalization: Let there always be a list of polygons
+            if get_list_depth(coords) == 2:
+                coords = [coords]
 
-            try:
-                radius_column = giatools.pandas.find_column(df, ['radius', 'RADIUS'])
-                radius_list = df[radius_column]
-                assert len(pos_x_list) == len(radius_list)
-            except KeyError:
-                radius_list = [0] * len(pos_x_list)
+            # Rasterize each polygon separately, then join via XOR (a polygon
+            # that lies inside of a previous one is thereby cut out of it)
+            mask = np.zeros(shape, dtype=bool)
+            for polygon_coords in coords:
+                polygon_mask = skimage.draw.polygon2mask(
+                    shape,
+                    # GeoJSON positions are (x, y), the image is (row, column)
+                    [point[::-1] for point in polygon_coords],
+                )
+                mask = np.logical_xor(mask, polygon_mask)
+
+        elif geom_type == 'point':
+            # Pixel indices must be integers, GeoJSON positions can be floats
+            x, y = (int(round(value)) for value in coords[:2])
+
+            # Negative indices would silently wrap around to the opposite
+            # image border, thus reject out-of-bounds points explicitly
+            if not (0 <= y < shape[0] and 0 <= x < shape[1]):
+                raise IndexError(
+                    f'The point x={x}, y={y} exceeds the bounds of the image '
+                    f'(width: {shape[1]}, height: {shape[0]})',
+                )
+
+            mask = np.zeros(shape, dtype=bool)
+            mask[y, x] = True
+
+            # `properties` may be `null` in GeoJSON, and so may be `radius`
+            radius = (feature.get('properties') or {}).get('radius') or 0
+            if radius > 0:
+                # Grow the pixel to all pixels within `radius` of it
+                mask = (ndi.distance_transform_edt(~mask) <= radius)
+
+        else:
+            raise ValueError(
+                f'Unsupported geometry type: "{feature["geometry"]["type"]}"',
+            )
 
-            try:
-                width_column = giatools.pandas.find_column(df, ['width', 'WIDTH'])
-                height_column = giatools.pandas.find_column(df, ['height', 'HEIGHT'])
-                width_list = df[width_column]
-                height_list = df[height_column]
-                assert len(pos_x_list) == len(width_list)
-                assert len(pos_x_list) == len(height_list)
-            except KeyError:
-                width_list = [0] * len(pos_x_list)
-                height_list = [0] * len(pos_x_list)
+        # Determine the `label` for the current `mask`
+        if fg_value is None:
+            label = get_feature_label(feature)
+            if label is None:
+                label = autolabel.next()
+        else:
+            label = fg_value
+
+        # Blend the current `mask` with the rasterized image
+        img[mask] = label
+
+    # Return the rasterized image
+    return img
+
+
+def convert_tabular_to_geojson(
+    tabular_file: str,
+    has_header: bool,
+) -> dict:
+    """
+    Read a tabular file and convert it to GeoJSON.
+
+    The GeoJSON data is returned as a dictionary.
+
+    Each row of the tabular file yields one feature. A row with a positive
+    width and height becomes a `Polygon` (rectangle), any other row becomes a
+    `Point`, which in addition carries a `radius` property if the radius is
+    positive. The label of a row is stored as the `name` property.
+
+    :param tabular_file: Path of the tab-separated file.
+    :param has_header: Whether the first row of the file names the columns.
+        Without a header, the first two columns are the X and Y coordinates,
+        and the rows are labeled consecutively, starting with 1.
+    :raises ValueError: If a row has a positive radius as well as a positive
+        width and height.
+    """
+
+    # Read the tabular file with information from the header
+    if has_header:
+        df = pd.read_csv(tabular_file, delimiter='\t')
 
-            try:
-                label_column = giatools.pandas.find_column(df, ['label', 'LABEL'])
-                label_list = df[label_column]
-                assert len(pos_x_list) == len(label_list)
-            except KeyError:
-                label_list = list(range(1, len(pos_x_list) + 1))
+        # The coordinates are rounded to integers
+        pos_x_column = giatools.pandas.find_column(df, ['pos_x', 'POS_X'])
+        pos_y_column = giatools.pandas.find_column(df, ['pos_y', 'POS_Y'])
+        pos_x_list = df[pos_x_column].round().astype(int)
+        pos_y_list = df[pos_y_column].round().astype(int)
+        assert len(pos_x_list) == len(pos_y_list)
+
+        # The radius column is optional, the radius defaults to 0 for all rows
+        # (presumably `find_column` raises `KeyError` if there is no match)
+        try:
+            radius_column = giatools.pandas.find_column(df, ['radius', 'RADIUS'])
+            radius_list = df[radius_column]
+            assert len(pos_x_list) == len(radius_list)
+        except KeyError:
+            radius_list = [0] * len(pos_x_list)
 
-        # Read the tabular file without header
-        else:
-            df = pd.read_csv(point_file, header=None, delimiter='\t')
-            pos_x_list = df[0].round().astype(int)
-            pos_y_list = df[1].round().astype(int)
-            assert len(pos_x_list) == len(pos_y_list)
-            radius_list, width_list, height_list = [[0] * len(pos_x_list)] * 3
+        # Width and height both default to 0 if either of the columns fails to
+        # be looked up
+        try:
+            width_column = giatools.pandas.find_column(df, ['width', 'WIDTH'])
+            height_column = giatools.pandas.find_column(df, ['height', 'HEIGHT'])
+            width_list = df[width_column]
+            height_list = df[height_column]
+            assert len(pos_x_list) == len(width_list)
+            assert len(pos_x_list) == len(height_list)
+        except KeyError:
+            width_list = [0] * len(pos_x_list)
+            height_list = [0] * len(pos_x_list)
+
+        # Without a label column, the rows are labeled consecutively from 1
+        try:
+            label_column = giatools.pandas.find_column(df, ['label', 'LABEL'])
+            label_list = df[label_column]
+            assert len(pos_x_list) == len(label_list)
+        except KeyError:
             label_list = list(range(1, len(pos_x_list) + 1))
 
-        # Optionally swap the coordinates
-        if swap_xy:
-            pos_x_list, pos_y_list = pos_y_list, pos_x_list
-
-        # Perform the rasterization
-        for y, x, radius, width, height, label in zip(
-            pos_y_list, pos_x_list, radius_list, width_list, height_list, label_list,
-        ):
-            if fg_value is not None:
-                label = fg_value
-
-            if y < 0 or x < 0 or y >= shape[0] or x >= shape[1]:
-                raise IndexError(f'The point x={x}, y={y} exceeds the bounds of the image (width: {shape[1]}, height: {shape[0]})')
+    # Read the tabular file without header
+    else:
+        df = pd.read_csv(tabular_file, header=None, delimiter='\t')
+        pos_x_list = df[0].round().astype(int)
+        pos_y_list = df[1].round().astype(int)
+        assert len(pos_x_list) == len(pos_y_list)
+        # All three names are bound to the same list of zeros (only read below)
+        radius_list, width_list, height_list = [[0] * len(pos_x_list)] * 3
+        label_list = list(range(1, len(pos_x_list) + 1))
 
-            # Rasterize circle and distribute overlapping image area
-            # Rasterize primitive geometry
-            if radius > 0 or (width > 0 and height > 0):
-
-                # Rasterize circle
-                if radius > 0:
-                    mask = np.ones(shape, dtype=bool)
-                    mask[y, x] = False
-                    mask = (ndi.distance_transform_edt(mask) <= radius)
-                else:
-                    mask = np.zeros(shape, dtype=bool)
-
-                # Rasterize rectangle
-                if width > 0 and height > 0:
-                    mask[
-                        y:min(shape[0], y + width),
-                        x:min(shape[1], x + height)
-                    ] = True
+    # Convert to GeoJSON
+    # (`features` is filled by the loop below and shared with `geojson`)
+    features = []
+    geojson = {
+        'type': 'FeatureCollection',
+        'features': features,
+    }
+    for y, x, radius, width, height, label in zip(
+        pos_y_list, pos_x_list, radius_list, width_list, height_list, label_list,
+    ):
+        if radius > 0 and width > 0 and height > 0:
+            raise ValueError('Ambiguous shape type (circle or rectangle)')
 
-                # Compute the overlap (pretend there is none if the rasterization is binary)
-                if fg_value is None:
-                    overlap = np.logical_and(img > 0, mask)
-                else:
-                    overlap = np.zeros(shape, dtype=bool)
+        # Create a rectangle
+        # (the four corners as [x, y] pairs, the ring is not closed explicitly)
+        if width > 0 and height > 0:
+            geom_type = 'Polygon'
+            coords = [
+                [x, y],
+                [x + width - 1, y],
+                [x + width - 1, y + height - 1],
+                [x, y + height - 1],
+            ]
 
-                # Rasterize the part of the circle which is disjoint from other foreground.
-                #
-                # In the current implementation, the result depends on the order of the rasterized circles if somewhere
-                # more than two circles overlap. This is probably negligable for most applications. To achieve results
-                # that are invariant to the order, first all circles would need to be rasterized independently, and
-                # then blended together. This, however, would either strongly increase the memory consumption, or
-                # require a more complex implementation which exploits the sparsity of the rasterized masks.
-                #
-                disjoint_mask = np.logical_xor(mask, overlap)
-                if disjoint_mask.any():
-                    img[disjoint_mask] = label
+        # Create a point or circle
+        else:
+            geom_type = 'Point'
+            coords = [x, y]
 
-                    # Distribute the remaining part of the circle
-                    if overlap.any():
-                        dist = ndi.distance_transform_edt(overlap)
-                        foreground = (img > 0)
-                        img[overlap] = 0
-                        img = skimage.segmentation.watershed(dist, img, mask=foreground)
+        # Create a GeoJSON feature
+        feature = {
+            'type': 'Feature',
+            'geometry': {
+                'type': geom_type,
+                'coordinates': coords,
+            },
+            'properties': {
+                'name': label,
+            },
+        }
+        # A positive radius is attached as property and marks a circle
+        if radius > 0:
+            feature['properties']['radius'] = radius
+            feature['properties']['subType'] = 'Circle'
+        features.append(feature)
 
-            # Rasterize point (there is no overlapping area to be distributed)
-            else:
-                img[y, x] = label
-
-    else:
-        raise Exception('{} is empty or does not exist.'.format(point_file))  # appropriate built-in error?
-
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore")
-        skimage.io.imsave(out_file, img, plugin='tifffile')  # otherwise we get problems with the .dat extension
+    # Return the GeoJSON object
+    return geojson
 
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('in_file', type=argparse.FileType('r'), help='Input point file or GeoJSON file')
-    parser.add_argument('out_file', type=str, help='out file (TIFF)')
-    parser.add_argument('shapex', type=int, help='shapex')
-    parser.add_argument('shapey', type=int, help='shapey')
-    parser.add_argument('--has_header', dest='has_header', default=False, help='set True if point file has header')
+    # Positional arguments: input format and path, output path, image size
+    parser.add_argument('in_ext', type=str, help='Input file format')
+    parser.add_argument('in_file', type=str, help='Input file path (tabular or GeoJSON)')
+    parser.add_argument('out_file', type=str, help='Output file path (TIFF)')
+    parser.add_argument('shapex', type=int, help='Output image width')
+    parser.add_argument('shapey', type=int, help='Output image height')
+    # NOTE(review): the optional flags take a value that is kept as a string,
+    # so any non-empty value (even `False`) is truthy -- confirm that callers
+    # only pass these flags when the option is enabled
+    parser.add_argument('--has_header', dest='has_header', default=False, help='Set True if tabular file has a header')
     parser.add_argument('--swap_xy', dest='swap_xy', default=False, help='Swap X and Y coordinates')
     parser.add_argument('--binary', dest='binary', default=False, help='Produce binary image')
-
     args = parser.parse_args()
 
-    point_file = args.in_file.name
-    has_header = args.has_header
+    # Validate command-line arguments
+    assert args.in_ext in ('tabular', 'geojson'), (
+        f'Unexpected input file format: {args.in_ext}'
+    )
 
-    try:
-        with open(args.in_file.name, 'r') as f:
-            content = json.load(f)
-            if isinstance(content, dict) and content.get('type') == 'FeatureCollection' and isinstance(content.get('features'), list):
-                point_file = geojson_to_tabular(content)
-                has_header = True  # header included in the converted file
-            else:
-                raise ValueError('Input is a JSON file but not a valid GeoJSON file')
-    except json.JSONDecodeError:
-        print('Input is not a valid JSON file. Assuming it a tabular file.')
+    # Load the GeoJSON data (if the input file is tabular, convert to GeoJSON)
+    if args.in_ext == 'tabular':
+        geojson = convert_tabular_to_geojson(args.in_file, args.has_header)
+    else:
+        with open(args.in_file) as f:
+            geojson = json.load(f)
 
-    rasterize(
-        point_file,
-        args.out_file,
-        (args.shapey, args.shapex),
-        has_header=has_header,
-        swap_xy=args.swap_xy,
+    # Rasterize the image from GeoJSON
+    shape = (args.shapey, args.shapex)
+    # Swapping X and Y is done by rasterizing into the transposed shape, and
+    # transposing the result afterwards (which restores the requested shape)
+    img = rasterize(
+        geojson,
+        shape if not args.swap_xy else shape[::-1],
         fg_value=0xffff if args.binary else None,
     )
+    if args.swap_xy:
+        img = img.T
+
+    # Write the rasterized image as TIFF
+    with warnings.catch_warnings():
+        warnings.simplefilter('ignore')
+        skimage.io.imsave(args.out_file, img, plugin='tifffile')  # otherwise we get problems with the .dat extension