Mercurial > repos > imgteam > image_registration_affine
comparison image_registration_affine.py @ 0:e34222a620d4 draft
"planemo upload for repository https://github.com/BMCV/galaxy-image-analysis/tools/image_registration_affine/ commit 79c2fd560fce8ded4d7f7fe97e876871794e2f9d"
author | imgteam |
---|---|
date | Wed, 30 Dec 2020 20:24:35 +0000 |
parents | |
children | fa769715b6b0 |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e34222a620d4 |
---|---|
1 import skimage.io | |
2 from skimage.transform import ProjectiveTransform | |
3 from skimage.filters import gaussian | |
4 from scipy.ndimage import map_coordinates | |
5 from scipy.optimize import least_squares | |
6 import numpy as np | |
7 import pandas as pd | |
8 import argparse | |
9 | |
10 | |
11 | |
12 def _stackcopy(a, b): | |
13 if a.ndim == 3: | |
14 a[:] = b[:, :, np.newaxis] | |
15 else: | |
16 a[:] = b | |
17 | |
18 | |
19 | |
def warp_coords_batch(coord_map, shape, dtype=np.float64, batch_size=1000000):
    """Build the source-coordinate grid for an image warp, mapping in batches.

    Every pixel position of an output image of the given ``shape`` is passed
    through ``coord_map`` as an ``(x, y)`` pair, i.e. ``(col, row)``.  The
    positions are mapped ``batch_size`` rows of the point list at a time.

    Parameters
    ----------
    coord_map : callable
        Called with a ``(P, 2)`` array of ``(col, row)`` coordinates; must
        return an array that can be assigned back into a ``(P, 2)`` slice.
        It is never called with an empty array.
    shape : sequence of int
        ``(rows, cols)`` or ``(rows, cols, bands)`` of the output image.
    dtype : dtype, optional
        Data type of the coordinate arrays.
    batch_size : int, optional
        Maximum number of points handed to ``coord_map`` per call.

    Returns
    -------
    coords : ndarray
        Array of shape ``(len(shape), rows, cols[, bands])``.  ``coords[0]``
        holds the mapped row coordinates, ``coords[1]`` the mapped column
        coordinates and, for three-element shapes, ``coords[2]`` the band
        index.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than one.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    rows, cols = shape[0], shape[1]
    has_bands = len(shape) == 3
    coords_shape = [len(shape), rows, cols]
    if has_bands:
        coords_shape.append(shape[2])
    coords = np.empty(coords_shape, dtype=dtype)

    # (P, 2) list of (col, row) pairs covering the whole output grid.
    tf_coords = np.indices((cols, rows), dtype=dtype).reshape(2, -1).T

    # Step by batch_size so that no empty trailing slice is ever mapped
    # (which happened when P was an exact multiple of batch_size).
    for start in range(0, tf_coords.shape[0], batch_size):
        stop = start + batch_size
        tf_coords[start:stop] = coord_map(tf_coords[start:stop])

    # Back to a (2, rows, cols) grid; the explicit leading 2 keeps this valid
    # for zero-sized images, where -1 could not be inferred.
    tf_coords = tf_coords.T.reshape((2, cols, rows)).swapaxes(1, 2)
    if has_bands:
        # Trailing axis lets the 2-D mapping broadcast over all bands.
        tf_coords = tf_coords[..., np.newaxis]

    coords[1, ...] = tf_coords[0]  # mapped x / column coordinates
    coords[0, ...] = tf_coords[1]  # mapped y / row coordinates
    if has_bands:
        coords[2, ...] = range(shape[2])

    return coords
39 | |
40 | |
41 | |
def affine_registration(params, moving, fixed):
    """Return the squared intensity differences for one set of affine parameters.

    Parameters
    ----------
    params : ndarray
        Six values; elements 0-2 form the first row and elements 3-5 the
        second row of a 3x3 transformation matrix whose last row is
        ``[0, 0, 1]``.
    moving : ndarray
        Image that is resampled through the transformation.
    fixed : ndarray
        Reference image; its shape defines the sampling grid.

    Returns
    -------
    ndarray
        Flattened array of ``(resampled_moving - fixed) ** 2``.
    """
    # Start from the identity and overwrite the two free rows in one go.
    matrix = np.eye(3)
    matrix[:2, :] = params.take([0, 1, 2, 3, 4, 5]).reshape(2, 3)

    transform = ProjectiveTransform(matrix=matrix)
    sample_at = warp_coords_batch(transform, fixed.shape)
    resampled = map_coordinates(moving, sample_at, mode='reflect')

    difference = resampled - fixed
    return (difference ** 2).flatten()
53 | |
54 | |
55 | |
def image_registration(fn_moving, fn_fixed, fn_out, smooth_sigma=1):
    """Estimate an affine matrix between two images and save it to a file.

    Both images are read as grayscale and Gaussian-smoothed.  The six affine
    parameters are then fitted with ``least_squares`` starting from the
    identity transform, using ``affine_registration`` as residual function.

    Parameters
    ----------
    fn_moving : str
        Path of the moving image.
    fn_fixed : str
        Path of the fixed image.
    fn_out : str
        Path of the tab-separated output file receiving the 3x3 matrix
        (written without header and without index column).
    smooth_sigma : float, optional
        Sigma passed to the Gaussian filter for both images.
    """
    raw_moving = skimage.io.imread(fn_moving, as_gray=True)
    raw_fixed = skimage.io.imread(fn_fixed, as_gray=True)

    smoothed_moving = gaussian(raw_moving, sigma=smooth_sigma)
    smoothed_fixed = gaussian(raw_fixed, sigma=smooth_sigma)

    # Top two rows of the identity matrix: [1, 0, 0, 0, 1, 0] as float64.
    initial_guess = np.eye(2, 3).flatten()
    fit = least_squares(
        affine_registration,
        initial_guess,
        args=(smoothed_moving, smoothed_fixed),
    )

    # Re-assemble the full matrix: fitted rows on top, [0, 0, 1] below.
    tmat = np.vstack((
        fit.x.take([0, 1, 2]),
        fit.x.take([3, 4, 5]),
        [0.0, 0.0, 1.0],
    ))

    pd.DataFrame(tmat).to_csv(fn_out, header=None, index=False, sep="\t")
71 | |
72 | |
73 | |
if __name__ == "__main__":
    # Command-line entry point: three positional paths, in this order.
    cli = argparse.ArgumentParser(description="Estimate the transformation matrix")
    for arg_name, arg_help in (
        ("fn_moving", "Name of the moving image.png"),
        ("fn_fixed", "Name of the fixed image.png"),
        ("fn_tmat", "Name of output file to save the transformation matrix"),
    ):
        cli.add_argument(arg_name, help=arg_help)
    options = cli.parse_args()

    image_registration(options.fn_moving, options.fn_fixed, options.fn_tmat)