Mercurial > repos > bgruening > cellpose
comparison cp_segmentation.py @ 0:1e7334a51725 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/cellpose commit 06dd9637975e3b9d6d27a3d5a773c85e9a52baf2
| author | bgruening |
|---|---|
| date | Thu, 29 Feb 2024 22:07:26 +0000 |
| parents | |
| children | 32153c43126c |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:1e7334a51725 |
|---|---|
| 1 import argparse | |
| 2 import json | |
| 3 import os | |
| 4 import warnings | |
| 5 | |
| 6 import matplotlib.pyplot as plt | |
| 7 import numpy as np | |
| 8 import skimage.io | |
| 9 from cellpose import models, plot, transforms | |
| 10 | |
| 11 | |
def _parse_channels(chan, chan2):
    """Build the Cellpose ``channels`` argument from the tool parameters.

    Returns None when no segmentation channel is given; otherwise a
    two-element list ``[chan, chan2]`` of ints. A missing ``chan2`` maps to
    0, which Cellpose documents as "no second (nuclear) channel"; ``None``
    is not a valid channel index there.
    """
    if chan is None:
        return None
    return [int(chan), 0 if chan2 is None else int(chan2)]


def _prepare_image(img, img_format, channels, chan_first):
    """Convert a freshly read image array into the layout Cellpose expects."""
    # Transpose TIFFs from channel-first to Ly x Lx x nchann. Only 3-D
    # arrays get the 3-axis permutation: a 2-D (single-channel) TIFF has
    # no channel axis to move and np.transpose(..., (1, 2, 0)) would raise.
    # transforms.reshape then selects/reshapes according to ``channels``.
    if img_format.endswith('tiff') and img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))
    return transforms.reshape(img, channels=channels, chan_first=chan_first)


def main(inputs, img_path, img_format, output_dir):
    """Run Cellpose segmentation on one image and write the results.

    Parameters
    ----------
    inputs : str
        File path to the Galaxy tool parameter JSON file.
    img_path : str
        File path for the input image.
    img_format : str
        One of ['ome.tiff', 'tiff', 'png', 'jpg'].
    output_dir : str
        Folder to save the outputs. ``cp_masks.tif`` is always written;
        ``segm_show.png`` is written when ``show_segmentation`` is set.
    """
    warnings.simplefilter('ignore')

    with open(inputs, 'r') as param_handler:
        params = json.load(param_handler)

    gpu = params['use_gpu']
    model_type = params['model_type']
    chan_first = params['chan_first']
    channels = _parse_channels(params['chan'], params['chan2'])
    # NOTE(review): presumably a dict of extra keyword arguments for
    # Cellpose.eval assembled by the Galaxy wrapper — confirm schema.
    options = params['options']

    img = skimage.io.imread(img_path)
    print(f"Image shape: {img.shape}")
    img = _prepare_image(img, img_format, channels, chan_first)
    print(f"Image shape: {img.shape}")

    model = models.Cellpose(gpu=gpu, model_type=model_type)
    masks, flows, styles, diams = model.eval(img, channels=channels, **options)

    # Save masks to TIFF. Labels are stored as uint16, but fall back to
    # uint32 when there are more objects than uint16 can represent so that
    # label ids do not silently wrap around.
    if masks.max() <= np.iinfo(np.uint16).max:
        mask_dtype = np.uint16
    else:
        mask_dtype = np.uint32
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        skimage.io.imsave(os.path.join(output_dir, 'cp_masks.tif'),
                          masks.astype(mask_dtype))

    # Optional segmentation overview figure.
    if params['show_segmentation']:
        # Re-read the image rather than reusing ``img`` so the figure is
        # built from a freshly prepared copy, as in the original flow.
        img = _prepare_image(skimage.io.imread(img_path), img_format,
                             channels, chan_first)
        flowi = flows[0]
        fig = plt.figure(figsize=(12, 3))
        plot.show_segmentation(fig, img, masks, flowi, channels=channels)
        fig.savefig(os.path.join(output_dir, 'segm_show.png'), dpi=300)
        plt.close(fig)
| 75 | |
| 76 | |
if __name__ == '__main__':
    # Command-line entry point. main() cannot run without an image path and
    # format (it calls img_format.endswith and reads img_path), so those are
    # required; the output folder defaults to the current directory.
    aparser = argparse.ArgumentParser(
        description="Segment an image with Cellpose using Galaxy tool parameters.")
    aparser.add_argument("-i", "--inputs", dest="inputs", required=True,
                         help="JSON file with the Galaxy tool parameters")
    aparser.add_argument("-p", "--img_path", dest="img_path", required=True,
                         help="path to the input image")
    aparser.add_argument("-f", "--img_format", dest="img_format", required=True,
                         help="image format, e.g. ome.tiff, tiff, png, jpg")
    aparser.add_argument("-O", "--output_dir", dest="output_dir", default=".",
                         help="folder to write the outputs to (default: .)")
    args = aparser.parse_args()

    main(args.inputs, args.img_path, args.img_format, args.output_dir)
