diff colorize_labels.py @ 4:5907be5a8d7c draft default tip

planemo upload for repository https://github.com/BMCV/galaxy-image-analysis/tree/master/tools/colorize_labels commit 4344f548f365dba87c20188d6e3c2df8630d2313
author imgteam
date Tue, 24 Sep 2024 17:30:20 +0000
parents 2d1de6e7b113
children
line wrap: on
line diff
--- a/colorize_labels.py	Wed Apr 24 08:12:15 2024 +0000
+++ b/colorize_labels.py	Tue Sep 24 17:30:20 2024 +0000
@@ -1,12 +1,11 @@
 import argparse
 
 import giatools.io
-import matplotlib.colors as mpl
+import matplotlib.pyplot as plt
 import networkx as nx
 import numpy as np
 import scipy.ndimage as ndi
 import skimage.io
-import skimage.morphology as morph
 import skimage.util
 
 
@@ -22,7 +21,6 @@
 
 def build_label_adjacency_graph(im, radius, bg_label):
     G = nx.Graph()
-    selem = morph.disk(radius)
     for label in np.unique(im):
 
         if label == bg_label:
@@ -31,7 +29,7 @@
         G.add_node(label)
 
         cc = (im == label)
-        neighborhood = ndi.binary_dilation(cc, selem)
+        neighborhood = (ndi.distance_transform_edt(~cc) <= radius)
         adjacent_labels = np.unique(im[neighborhood])
 
         for adjacent_label in adjacent_labels:
@@ -45,6 +43,18 @@
     return G
 
 
+def get_n_unique_mpl_colors(n, colormap='jet', cyclic=False):
+    """
+    Yields `n` unique colors from the given `colormap`.
+
+    Set `cyclic` to `True` if the `colormap` is cyclic.
+    """
+    cmap = plt.get_cmap(colormap)
+    m = n if cyclic else n - 1
+    for i in range(n):
+        yield np.multiply(255, cmap(i / m))
+
+
 if __name__ == '__main__':
 
     parser = argparse.ArgumentParser()
@@ -68,7 +78,7 @@
     unique_colors = frozenset(graph_coloring.values())
 
     # Assign colors to nodes based on the greedy coloring
-    graph_color_to_mpl_color = dict(zip(unique_colors, mpl.TABLEAU_COLORS.values()))
+    graph_color_to_mpl_color = dict(zip(unique_colors, get_n_unique_mpl_colors(len(unique_colors))))
     node_colors = [graph_color_to_mpl_color[graph_coloring[n]] for n in G.nodes()]
 
     # Render result
@@ -77,7 +87,6 @@
     for label, label_color in zip(G.nodes(), node_colors):
 
         cc = (im == label)
-        label_color = color_hex_to_rgb_tuple(label_color)
         for ch in range(3):
             result[:, :, ch][cc] = label_color[ch]