diff auto_threshold.py @ 9:50fa6150e340 draft default tip

planemo upload for repository https://github.com/BMCV/galaxy-image-analysis/tree/master/tools/2d_auto_threshold/ commit 01343602708de3cc7fa4986af9000adc36dd0651
author imgteam
date Sat, 07 Jun 2025 18:38:31 +0000
parents 699a5e9146b3
children
line wrap: on
line diff
--- a/auto_threshold.py	Wed Apr 24 08:11:33 2024 +0000
+++ b/auto_threshold.py	Sat Jun 07 18:38:31 2025 +0000
@@ -7,49 +7,111 @@
 
 import argparse
 
-import giatools.io
 import numpy as np
 import skimage.filters
 import skimage.util
-import tifffile
+from giatools.image import Image
+
+
+class DefaultThresholdingMethod:
+
+    def __init__(self, thres, accept: list[str] | None = None, **kwargs):
+        self.thres = thres
+        self.accept = accept if accept else []
+        self.kwargs = kwargs
+
+    def __call__(self, image, *args, offset=0, **kwargs):
+        accepted_kwargs = self.kwargs.copy()
+        for key, val in kwargs.items():
+            if key in self.accept:
+                accepted_kwargs[key] = val
+        thres = self.thres(image, *args, **accepted_kwargs)
+        return image > thres + offset
+
+
+class ManualThresholding:
+
+    def __call__(self, image, thres1: float, thres2: float | None, **kwargs):
+        if thres2 is None:
+            return image > thres1
+        else:
+            thres1, thres2 = sorted((thres1, thres2))
+            return skimage.filters.apply_hysteresis_threshold(image, thres1, thres2)
 
 
 th_methods = {
-    'manual': lambda thres, **kwargs: thres,
+    'manual': ManualThresholding(),
 
-    'otsu': lambda img_raw, **kwargs: skimage.filters.threshold_otsu(img_raw),
-    'li': lambda img_raw, **kwargs: skimage.filters.threshold_li(img_raw),
-    'yen': lambda img_raw, **kwargs: skimage.filters.threshold_yen(img_raw),
-    'isodata': lambda img_raw, **kwargs: skimage.filters.threshold_isodata(img_raw),
+    'otsu': DefaultThresholdingMethod(skimage.filters.threshold_otsu),
+    'li': DefaultThresholdingMethod(skimage.filters.threshold_li),
+    'yen': DefaultThresholdingMethod(skimage.filters.threshold_yen),
+    'isodata': DefaultThresholdingMethod(skimage.filters.threshold_isodata),
 
-    'loc_gaussian': lambda img_raw, bz, **kwargs: skimage.filters.threshold_local(img_raw, bz, method='gaussian'),
-    'loc_median': lambda img_raw, bz, **kwargs: skimage.filters.threshold_local(img_raw, bz, method='median'),
-    'loc_mean': lambda img_raw, bz, **kwargs: skimage.filters.threshold_local(img_raw, bz, method='mean')
+    'loc_gaussian': DefaultThresholdingMethod(skimage.filters.threshold_local, accept=['block_size'], method='gaussian'),
+    'loc_median': DefaultThresholdingMethod(skimage.filters.threshold_local, accept=['block_size'], method='median'),
+    'loc_mean': DefaultThresholdingMethod(skimage.filters.threshold_local, accept=['block_size'], method='mean'),
 }
 
 
-def do_thresholding(in_fn, out_fn, th_method, block_size, offset, threshold, invert_output=False):
-    img = giatools.io.imread(in_fn)
-    img = np.squeeze(img)
-    assert img.ndim == 2
+def do_thresholding(
+    input_filepath: str,
+    output_filepath: str,
+    th_method: str,
+    block_size: int,
+    offset: float,
+    threshold1: float,
+    threshold2: float | None,
+    invert_output: bool,
+):
+    assert th_method in th_methods, f'Unknown method "{th_method}"'
+
+    # Load image
+    img_in = Image.read(input_filepath)
 
-    th = offset + th_methods[th_method](img_raw=img, bz=block_size, thres=threshold)
-    res = img > th
+    # Perform thresholding
+    result = th_methods[th_method](
+        image=img_in.data,
+        block_size=block_size,
+        offset=offset,
+        thres1=threshold1,
+        thres2=threshold2,
+    )
     if invert_output:
-        res = np.logical_not(res)
+        result = np.logical_not(result)
+
+    # Convert to canonical representation for binary images
+    result = (result * 255).astype(np.uint8)
 
-    tifffile.imwrite(out_fn, skimage.util.img_as_ubyte(res))
+    # Write result
+    Image(
+        data=skimage.util.img_as_ubyte(result),
+        axes=img_in.axes,
+    ).normalize_axes_like(
+        img_in.original_axes,
+    ).write(
+        output_filepath,
+    )
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Automatic image thresholding')
-    parser.add_argument('im_in', help='Path to the input image')
-    parser.add_argument('im_out', help='Path to the output image (uint8)')
+    parser.add_argument('input', type=str, help='Path to the input image')
+    parser.add_argument('output', type=str, help='Path to the output image (uint8)')
     parser.add_argument('th_method', choices=th_methods.keys(), help='Thresholding method')
-    parser.add_argument('block_size', type=int, default=5, help='Odd size of pixel neighborhood for calculating the threshold')
-    parser.add_argument('offset', type=float, default=0, help='Offset of automatically determined threshold value')
-    parser.add_argument('threshold', type=float, default=0, help='Manual threshold value')
+    parser.add_argument('block_size', type=int, help='Odd size of pixel neighborhood for calculating the threshold')
+    parser.add_argument('offset', type=float, help='Offset of automatically determined threshold value')
+    parser.add_argument('threshold1', type=float, help='Manual threshold value')
+    parser.add_argument('--threshold2', type=float, help='Second manual threshold value (for hysteresis thresholding)')
     parser.add_argument('--invert_output', default=False, action='store_true', help='Values below/above the threshold are labeled with 0/255 by default, and with 255/0 if this argument is used')
     args = parser.parse_args()
 
-    do_thresholding(args.im_in, args.im_out, args.th_method, args.block_size, args.offset, args.threshold, args.invert_output)
+    do_thresholding(
+        args.input,
+        args.output,
+        args.th_method,
+        args.block_size,
+        args.offset,
+        args.threshold1,
+        args.threshold2,
+        args.invert_output,
+    )