diff points2label.py @ 3:de611b3b5ae8 draft default tip

planemo upload for repository https://github.com/BMCV/galaxy-image-analysis/tree/master/tools/points2labelimage/ commit 6fc9ab8db9ef72ac7ded30d7373768feeae9390d
author imgteam
date Fri, 27 Sep 2024 17:41:21 +0000
parents 30ca5d5d03ec
children
line wrap: on
line diff
--- a/points2label.py	Mon Nov 13 22:11:46 2023 +0000
+++ b/points2label.py	Fri Sep 27 17:41:21 2024 +0000
@@ -1,47 +1,125 @@
 import argparse
-import sys
+import os
 import warnings
 
+import giatools.pandas
 import numpy as np
 import pandas as pd
+import scipy.ndimage as ndi
 import skimage.io
+import skimage.segmentation
 
 
-def points2label(labels, shape, output_file=None, has_header=False, is_TSV=True):
-    labelimg = np.zeros([shape[0], shape[1]], dtype=np.int32)
+def rasterize(point_file, out_file, shape, has_header=False, swap_xy=False, bg_value=0, fg_value=None):
 
-    if is_TSV:
+    img = np.full(shape, dtype=np.uint16, fill_value=bg_value)
+    if os.path.exists(point_file) and os.path.getsize(point_file) > 0:
+
+        # Read the tabular file with information from the header
         if has_header:
-            df = pd.read_csv(labels, sep='\t', skiprows=1, header=None)
-        else:
-            df = pd.read_csv(labels, sep='\t', header=None)
-    else:
-        if has_header:
-            df = pd.read_csv(labels, skiprows=1, header=None)
+            df = pd.read_csv(point_file, delimiter='\t')
+
+            pos_x_column = giatools.pandas.find_column(df, ['pos_x', 'POS_X'])
+            pos_y_column = giatools.pandas.find_column(df, ['pos_y', 'POS_Y'])
+            pos_x_list = df[pos_x_column].round().astype(int)
+            pos_y_list = df[pos_y_column].round().astype(int)
+            assert len(pos_x_list) == len(pos_y_list)
+
+            try:
+                radius_column = giatools.pandas.find_column(df, ['radius', 'RADIUS'])
+                radius_list = df[radius_column]
+                assert len(pos_x_list) == len(radius_list)
+            except KeyError:
+                radius_list = [0] * len(pos_x_list)
+
+            try:
+                label_column = giatools.pandas.find_column(df, ['label', 'LABEL'])
+                label_list = df[label_column]
+                assert len(pos_x_list) == len(label_list)
+            except KeyError:
+                label_list = list(range(1, len(pos_x_list) + 1))
+
+        # Read the tabular file without header
         else:
-            df = pd.read_csv(labels, header=None)
+            df = pd.read_csv(point_file, header=None, delimiter='\t')
+            pos_x_list = df[0].round().astype(int)
+            pos_y_list = df[1].round().astype(int)
+            assert len(pos_x_list) == len(pos_y_list)
+            radius_list = [0] * len(pos_x_list)
+            label_list = list(range(1, len(pos_x_list) + 1))
+
+        # Optionally swap the coordinates
+        if swap_xy:
+            pos_x_list, pos_y_list = pos_y_list, pos_x_list
 
-    for i in range(0, len(df)):
-        a_row = df.iloc[i]
-        labelimg[a_row[0], a_row[1]] = i + 1
+        # Perform the rasterization
+        for y, x, radius, label in zip(pos_y_list, pos_x_list, radius_list, label_list):
+            if fg_value is not None:
+                label = fg_value
+
+            if y < 0 or x < 0 or y >= shape[0] or x >= shape[1]:
+                raise IndexError(f'The point x={x}, y={y} exceeds the bounds of the image (width: {shape[1]}, height: {shape[0]})')
+
+            # Rasterize circle and distribute overlapping image area
+            if radius > 0:
+                mask = np.ones(shape, dtype=bool)
+                mask[y, x] = False
+                mask = (ndi.distance_transform_edt(mask) <= radius)
 
-    if output_file is not None:
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            skimage.io.imsave(output_file, labelimg, plugin='tifffile')
+                # Compute the overlap (pretend there is none if the rasterization is binary)
+                if fg_value is None:
+                    overlap = np.logical_and(img > 0, mask)
+                else:
+                    overlap = np.zeros(shape, dtype=bool)
+
+                # Rasterize the part of the circle which is disjoint from other foreground.
+                #
+                # In the current implementation, the result depends on the order of the rasterized circles if somewhere
+                # more than two circles overlap. This is probably negligible for most applications. To achieve results
+                # that are invariant to the order, first all circles would need to be rasterized independently, and
+                # then blended together. This, however, would either strongly increase the memory consumption, or
+                # require a more complex implementation which exploits the sparsity of the rasterized masks.
+                #
+                disjoint_mask = np.logical_xor(mask, overlap)
+                if disjoint_mask.any():
+                    img[disjoint_mask] = label
+
+                    # Distribute the remaining part of the circle
+                    if overlap.any():
+                        dist = ndi.distance_transform_edt(overlap)
+                        foreground = (img > 0)
+                        img[overlap] = 0
+                        img = skimage.segmentation.watershed(dist, img, mask=foreground)
+
+            # Rasterize point (there is no overlapping area to be distributed)
+            else:
+                img[y, x] = label
+
     else:
-        return labelimg
+        raise Exception("{} is empty or does not exist.".format(point_file))  # appropriate built-in error?
+
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        skimage.io.imsave(out_file, img, plugin='tifffile')  # otherwise we get problems with the .dat extension
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('label_file', type=argparse.FileType('r'), default=sys.stdin, help='label file')
-    parser.add_argument('out_file', type=argparse.FileType('w'), default=sys.stdin, help='out file')
-    parser.add_argument('org_file', type=argparse.FileType('r'), default=sys.stdin, help='input original file')
-    parser.add_argument('--has_header', dest='has_header', type=bool, default=False, help='label file has header')
-    parser.add_argument('--is_tsv', dest='is_tsv', type=bool, default=True, help='label file is TSV')
+    parser.add_argument('point_file', type=argparse.FileType('r'), help='point file')
+    parser.add_argument('out_file', type=str, help='out file (TIFF)')
+    parser.add_argument('shapex', type=int, help='shapex')
+    parser.add_argument('shapey', type=int, help='shapey')
+    parser.add_argument('--has_header', dest='has_header', default=False, help='set True if point file has header')
+    parser.add_argument('--swap_xy', dest='swap_xy', default=False, help='Swap X and Y coordinates')
+    parser.add_argument('--binary', dest='binary', default=False, help='Produce binary image')
+
     args = parser.parse_args()
 
-    original_shape = skimage.io.imread(args.org_file.name, plugin='tifffile').shape
-
-    points2label(args.label_file.name, original_shape, args.out_file.name, args.has_header, args.is_tsv)
+    rasterize(
+        args.point_file.name,
+        args.out_file,
+        (args.shapey, args.shapex),
+        has_header=args.has_header,
+        swap_xy=args.swap_xy,
+        fg_value=0xffff if args.binary else None,
+    )