Mercurial > repos > imgteam > points2labelimage
diff points2label.py @ 3:de611b3b5ae8 draft default tip
planemo upload for repository https://github.com/BMCV/galaxy-image-analysis/tree/master/tools/points2labelimage/ commit 6fc9ab8db9ef72ac7ded30d7373768feeae9390d
author | imgteam |
---|---|
date | Fri, 27 Sep 2024 17:41:21 +0000 |
parents | 30ca5d5d03ec |
children |
line wrap: on
line diff
--- a/points2label.py Mon Nov 13 22:11:46 2023 +0000 +++ b/points2label.py Fri Sep 27 17:41:21 2024 +0000 @@ -1,47 +1,125 @@ import argparse -import sys +import os import warnings +import giatools.pandas import numpy as np import pandas as pd +import scipy.ndimage as ndi import skimage.io +import skimage.segmentation -def points2label(labels, shape, output_file=None, has_header=False, is_TSV=True): - labelimg = np.zeros([shape[0], shape[1]], dtype=np.int32) +def rasterize(point_file, out_file, shape, has_header=False, swap_xy=False, bg_value=0, fg_value=None): - if is_TSV: + img = np.full(shape, dtype=np.uint16, fill_value=bg_value) + if os.path.exists(point_file) and os.path.getsize(point_file) > 0: + + # Read the tabular file with information from the header if has_header: - df = pd.read_csv(labels, sep='\t', skiprows=1, header=None) - else: - df = pd.read_csv(labels, sep='\t', header=None) - else: - if has_header: - df = pd.read_csv(labels, skiprows=1, header=None) + df = pd.read_csv(point_file, delimiter='\t') + + pos_x_column = giatools.pandas.find_column(df, ['pos_x', 'POS_X']) + pos_y_column = giatools.pandas.find_column(df, ['pos_y', 'POS_Y']) + pos_x_list = df[pos_x_column].round().astype(int) + pos_y_list = df[pos_y_column].round().astype(int) + assert len(pos_x_list) == len(pos_y_list) + + try: + radius_column = giatools.pandas.find_column(df, ['radius', 'RADIUS']) + radius_list = df[radius_column] + assert len(pos_x_list) == len(radius_list) + except KeyError: + radius_list = [0] * len(pos_x_list) + + try: + label_column = giatools.pandas.find_column(df, ['label', 'LABEL']) + label_list = df[label_column] + assert len(pos_x_list) == len(label_list) + except KeyError: + label_list = list(range(1, len(pos_x_list) + 1)) + + # Read the tabular file without header else: - df = pd.read_csv(labels, header=None) + df = pd.read_csv(point_file, header=None, delimiter='\t') + pos_x_list = df[0].round().astype(int) + pos_y_list = df[1].round().astype(int) + assert len(pos_x_list) == len(pos_y_list) + radius_list = [0] * len(pos_x_list) + label_list = list(range(1, len(pos_x_list) + 1)) + + # Optionally swap the coordinates + if swap_xy: + pos_x_list, pos_y_list = pos_y_list, pos_x_list - for i in range(0, len(df)): - a_row = df.iloc[i] - labelimg[a_row[0], a_row[1]] = i + 1 + # Perform the rasterization + for y, x, radius, label in zip(pos_y_list, pos_x_list, radius_list, label_list): + if fg_value is not None: + label = fg_value + + if y < 0 or x < 0 or y >= shape[0] or x >= shape[1]: + raise IndexError(f'The point x={x}, y={y} exceeds the bounds of the image (width: {shape[1]}, height: {shape[0]})') + + # Rasterize circle and distribute overlapping image area + if radius > 0: + mask = np.ones(shape, dtype=bool) + mask[y, x] = False + mask = (ndi.distance_transform_edt(mask) <= radius) - if output_file is not None: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - skimage.io.imsave(output_file, labelimg, plugin='tifffile') + # Compute the overlap (pretend there is none if the rasterization is binary) + if fg_value is None: + overlap = np.logical_and(img > 0, mask) + else: + overlap = np.zeros(shape, dtype=bool) + + # Rasterize the part of the circle which is disjoint from other foreground. + # + # In the current implementation, the result depends on the order of the rasterized circles if somewhere + # more than two circles overlap. This is probably negligable for most applications. To achieve results + # that are invariant to the order, first all circles would need to be rasterized independently, and + # then blended together. This, however, would either strongly increase the memory consumption, or + # require a more complex implementation which exploits the sparsity of the rasterized masks. + # + disjoint_mask = np.logical_xor(mask, overlap) + if disjoint_mask.any(): + img[disjoint_mask] = label + + # Distribute the remaining part of the circle + if overlap.any(): + dist = ndi.distance_transform_edt(overlap) + foreground = (img > 0) + img[overlap] = 0 + img = skimage.segmentation.watershed(dist, img, mask=foreground) + + # Rasterize point (there is no overlapping area to be distributed) + else: + img[y, x] = label + else: - return labelimg + raise Exception("{} is empty or does not exist.".format(point_file)) # appropriate built-in error? + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + skimage.io.imsave(out_file, img, plugin='tifffile') # otherwise we get problems with the .dat extension if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('label_file', type=argparse.FileType('r'), default=sys.stdin, help='label file') - parser.add_argument('out_file', type=argparse.FileType('w'), default=sys.stdin, help='out file') - parser.add_argument('org_file', type=argparse.FileType('r'), default=sys.stdin, help='input original file') - parser.add_argument('--has_header', dest='has_header', type=bool, default=False, help='label file has header') - parser.add_argument('--is_tsv', dest='is_tsv', type=bool, default=True, help='label file is TSV') + parser.add_argument('point_file', type=argparse.FileType('r'), help='point file') + parser.add_argument('out_file', type=str, help='out file (TIFF)') + parser.add_argument('shapex', type=int, help='shapex') + parser.add_argument('shapey', type=int, help='shapey') + parser.add_argument('--has_header', dest='has_header', default=False, help='set True if point file has header') + parser.add_argument('--swap_xy', dest='swap_xy', default=False, help='Swap X and Y coordinates') + parser.add_argument('--binary', dest='binary', default=False, help='Produce binary image') + args = parser.parse_args() - original_shape = skimage.io.imread(args.org_file.name, plugin='tifffile').shape - - points2label(args.label_file.name, original_shape, args.out_file.name, args.has_header, args.is_tsv) + rasterize( + args.point_file.name, + args.out_file, + (args.shapey, args.shapex), + has_header=args.has_header, + swap_xy=args.swap_xy, + fg_value=0xffff if args.binary else None, + )