Mercurial > repos > perssond > unmicst
comparison toolbox/PartitionOfImage.py @ 1:74fe58ff55a5 draft default tip
planemo upload for repository https://github.com/HMS-IDAC/UnMicst commit e14f76a8803cab0013c6dbe809bc81d7667f2ab9
author | goeckslab |
---|---|
date | Wed, 07 Sep 2022 23:10:14 +0000 |
parents | 6bec4fef6b2e |
children |
comparison
equal
deleted
inserted
replaced
0:6bec4fef6b2e | 1:74fe58ff55a5 |
---|---|
1 import numpy as np | |
2 from toolbox.imtools import * | |
3 # from toolbox.ftools import * | |
4 # import sys | |
5 | |
class PI2D:
    """Split a 2D image into overlapping square patches and stitch results back.

    The image is either single-channel ``(rows, cols)`` or channels-first
    ``(channels, rows, cols)``. All state lives in class attributes, so only
    one partition is active at a time and every method is called on the class
    itself::

        PI2D.setup(image, patchSize, margin, 'accumulate')
        PI2D.createOutput(nChannels)
        for i in range(PI2D.NumPatches):
            PI2D.patchOutput(i, process(PI2D.getPatch(i)))
        result = PI2D.getValidOutput()

    Consecutive patches start ``patchSize - 2*margin`` pixels apart, so
    neighbouring patches overlap by ``2*margin`` pixels.
    """

    Image = None        # image passed to setup()
    PaddedImage = None  # float64 zero-padded copy of Image, tiled exactly by the patches
    PatchSize = 128     # side length of each square patch
    Margin = 14         # overlap border on each side of a patch
    SubPatchSize = 100  # patch stride: PatchSize - 2*Margin
    PC = None           # patch coordinates, list of [r0, r1, c0, c1]
    NumPatches = 0      # len(PC)
    Output = None       # stitched output in padded-image coordinates
    Count = None        # summed blending weights ('accumulate' mode only)
    NR = None           # rows of the original image
    NC = None           # cols of the original image
    NRPI = None         # rows of the padded image
    NCPI = None         # cols of the padded image
    Mode = None         # 'replace' or 'accumulate'
    W = None            # (PatchSize, PatchSize) blending weight mask

    @staticmethod
    def setup(image, patchSize, margin, mode):
        """Zero-pad ``image`` and compute the coordinates of every patch.

        Args:
            image: ``(rows, cols)`` or channels-first ``(channels, rows, cols)``
                array.
            patchSize: side length of each square patch.
            margin: consecutive patches are ``patchSize - 2*margin`` pixels
                apart, so neighbours overlap by ``2*margin`` pixels. With
                ``margin == 0`` the blending mask is zero on the image border,
                so 'accumulate' mode divides by zero there.
            mode: 'replace' (patches overwrite the output) or 'accumulate'
                (patches are blended with the weight mask ``W``).

        Raises:
            ValueError: if ``patchSize <= 2*margin``, ``mode`` is not one of
                the two supported modes, or ``image`` is not 2D or 3D.
        """
        subPatchSize = patchSize - 2 * margin
        if subPatchSize <= 0:
            raise ValueError('patchSize must be larger than 2*margin')
        if mode not in ('replace', 'accumulate'):
            raise ValueError("mode must be 'replace' or 'accumulate', got %r" % (mode,))
        if len(image.shape) == 2:
            nr, nc = image.shape
        elif len(image.shape) == 3:  # channels-first multi-channel image
            nr, nc = image.shape[1:]
        else:
            raise ValueError('expected a 2D or 3D image, got shape %s' % (image.shape,))

        # Blending mask: 0 on the outermost ring of the patch, ring/(2*margin)
        # on ring `ring` for 0 < ring < 2*margin, and 1 further inside.
        W = np.ones((patchSize, patchSize))
        W[[0, -1], :] = 0
        W[:, [0, -1]] = 0
        for ring in range(1, 2 * margin):
            v = ring / (2 * margin)
            W[ring, ring:-ring] = v
            W[-ring - 1, ring:-ring] = v
            W[ring:-ring, ring] = v
            W[ring:-ring, -ring - 1] = v

        npr = int(np.ceil(nr / subPatchSize))  # number of patch rows
        npc = int(np.ceil(nc / subPatchSize))  # number of patch cols
        nrpi = npr * subPatchSize + 2 * margin  # rows in padded image
        ncpi = npc * subPatchSize + 2 * margin  # cols in padded image

        # Leading (channel) axes, if any, are kept; only rows/cols are padded.
        padded = np.zeros(tuple(image.shape[:-2]) + (nrpi, ncpi))
        padded[..., margin:margin + nr, margin:margin + nc] = image

        PI2D.Image = image
        PI2D.PatchSize = patchSize
        PI2D.Margin = margin
        PI2D.SubPatchSize = subPatchSize
        PI2D.W = W
        PI2D.NR, PI2D.NC = nr, nc
        PI2D.NRPI, PI2D.NCPI = nrpi, ncpi
        PI2D.PaddedImage = padded

        s = subPatchSize
        # Row-major patch order: all columns of patch-row 0, then patch-row 1, ...
        PI2D.PC = [[r * s, r * s + patchSize, c * s, c * s + patchSize]
                   for r in range(npr) for c in range(npc)]
        PI2D.NumPatches = len(PI2D.PC)
        PI2D.Mode = mode

    @staticmethod
    def getPatch(i):
        """Return patch ``i`` of the padded image as a view.

        The result is ``(PatchSize, PatchSize)``, or
        ``(channels, PatchSize, PatchSize)`` for a multi-channel image.
        """
        r0, r1, c0, c1 = PI2D.PC[i]
        return PI2D.PaddedImage[..., r0:r1, c0:c1]

    @staticmethod
    def createOutput(nChannels):
        """Allocate zeroed output buffers in padded-image coordinates.

        ``Output`` is ``(NRPI, NCPI)`` when ``nChannels == 1`` and
        ``(nChannels, NRPI, NCPI)`` otherwise. ``Count`` is allocated only in
        'accumulate' mode. Both are float16, presumably to limit memory on
        large images (TODO confirm); accumulation therefore has float16
        precision.
        """
        shape = (PI2D.NRPI, PI2D.NCPI)
        if nChannels != 1:
            shape = (nChannels,) + shape
        PI2D.Output = np.zeros(shape, np.float16)
        if PI2D.Mode == 'accumulate':
            PI2D.Count = np.zeros((PI2D.NRPI, PI2D.NCPI), np.float16)

    @staticmethod
    def patchOutput(i, P):
        """Store the processed result ``P`` of patch ``i`` in ``Output``.

        ``P`` has the patch's spatial size and is 2D or channels-first 3D. Its
        layout must match the one chosen in ``createOutput``. In 'replace'
        mode ``P`` overwrites its region of ``Output``. In 'accumulate' mode
        ``P * W`` is added to ``Output`` and ``W`` to ``Count``.

        Raises:
            ValueError: if ``P`` is neither 2D nor 3D.
        """
        if len(P.shape) not in (2, 3):
            raise ValueError('expected a 2D or 3D patch, got shape %s' % (P.shape,))
        r0, r1, c0, c1 = PI2D.PC[i]
        if PI2D.Mode == 'accumulate':
            PI2D.Count[r0:r1, c0:c1] += PI2D.W
            # W broadcasts across the leading channel axis of a 3D patch.
            PI2D.Output[..., r0:r1, c0:c1] += np.multiply(P, PI2D.W)
        elif PI2D.Mode == 'replace':
            PI2D.Output[..., r0:r1, c0:c1] = P

    @staticmethod
    def getValidOutput():
        """Return ``Output`` cropped to the original image extent.

        In 'accumulate' mode the weighted sums are divided by ``Count``. The
        result is a new array; ``Output`` is never modified, so repeated calls
        return the same values. In 'replace' mode the result is a view of
        ``Output``.
        """
        rows = slice(PI2D.Margin, PI2D.Margin + PI2D.NR)
        cols = slice(PI2D.Margin, PI2D.Margin + PI2D.NC)
        valid = PI2D.Output[..., rows, cols]
        if PI2D.Mode == 'accumulate':
            # Count is 2D; it broadcasts across a leading channel axis.
            return valid / PI2D.Count[rows, cols]
        if PI2D.Mode == 'replace':
            return valid
        return None

    @staticmethod
    def demo():
        """Round-trip a random image through 'replace' mode and display it."""
        I = np.random.rand(128, 128)
        PI2D.setup(I, 64, 4, 'replace')

        nChannels = 2
        PI2D.createOutput(nChannels)

        for i in range(PI2D.NumPatches):
            P = PI2D.getPatch(i)
            Q = np.stack([P] * nChannels)  # (channels, rows, cols)
            PI2D.patchOutput(i, Q)

        J = PI2D.getValidOutput()[0, :, :]

        D = np.abs(I - J)
        print(np.max(D))

        # cat/imshow are presumably provided by `from toolbox.imtools import *`.
        K = cat(1, cat(1, I, J), D)
        imshow(K)
148 | |
149 | |
class PI3D:
    """Split a 3D volume into overlapping cubic patches and stitch results back.

    The volume is either ``(planes, rows, cols)`` or multi-channel
    ``(planes, channels, rows, cols)``, with the channel axis second. All
    state lives in class attributes, so only one partition is active at a
    time and every method is called on the class itself::

        PI3D.setup(volume, patchSize, margin, 'accumulate')
        PI3D.createOutput(nChannels)
        for i in range(PI3D.NumPatches):
            PI3D.patchOutput(i, process(PI3D.getPatch(i)))
        result = PI3D.getValidOutput()

    Consecutive patches start ``patchSize - 2*margin`` voxels apart along
    each axis, so neighbouring patches overlap by ``2*margin`` voxels.
    """

    Image = None        # volume passed to setup()
    PaddedImage = None  # float64 zero-padded copy of Image, tiled exactly by the patches
    PatchSize = 128     # side length of each cubic patch
    Margin = 14         # overlap border on each side of a patch
    SubPatchSize = 100  # patch stride: PatchSize - 2*Margin
    PC = None           # patch coordinates, list of [z0, z1, r0, r1, c0, c1]
    NumPatches = 0      # len(PC)
    Output = None       # stitched output in padded-volume coordinates
    Count = None        # summed blending weights ('accumulate' mode only)
    NR = None           # rows of the original volume
    NC = None           # cols of the original volume
    NZ = None           # planes of the original volume
    NRPI = None         # rows of the padded volume
    NCPI = None         # cols of the padded volume
    NZPI = None         # planes of the padded volume
    Mode = None         # 'replace' or 'accumulate'
    W = None            # (PatchSize, PatchSize, PatchSize) blending weight mask

    @staticmethod
    def setup(image, patchSize, margin, mode):
        """Zero-pad ``image`` and compute the coordinates of every patch.

        Args:
            image: ``(planes, rows, cols)`` or
                ``(planes, channels, rows, cols)`` array.
            patchSize: side length of each cubic patch.
            margin: consecutive patches are ``patchSize - 2*margin`` voxels
                apart, so neighbours overlap by ``2*margin`` voxels. With
                ``margin == 0`` the blending mask is zero on the volume
                border, so 'accumulate' mode divides by zero there.
            mode: 'replace' (patches overwrite the output) or 'accumulate'
                (patches are blended with the weight mask ``W``).

        Raises:
            ValueError: if ``patchSize <= 2*margin``, ``mode`` is not one of
                the two supported modes, or ``image`` is not 3D or 4D.
        """
        subPatchSize = patchSize - 2 * margin
        if subPatchSize <= 0:
            raise ValueError('patchSize must be larger than 2*margin')
        if mode not in ('replace', 'accumulate'):
            raise ValueError("mode must be 'replace' or 'accumulate', got %r" % (mode,))
        if len(image.shape) == 3:
            nz, nr, nc = image.shape
        elif len(image.shape) == 4:  # (planes, channels, rows, cols)
            nz, _, nr, nc = image.shape
        else:
            raise ValueError('expected a 3D or 4D image, got shape %s' % (image.shape,))

        # Blending mask: 0 on the outermost shell of the cube, ring/(2*margin)
        # on shell `ring` for 0 < ring < 2*margin, and 1 further inside.
        W = np.ones((patchSize, patchSize, patchSize))
        W[[0, -1], :, :] = 0
        W[:, [0, -1], :] = 0
        W[:, :, [0, -1]] = 0
        for ring in range(1, 2 * margin):
            v = ring / (2 * margin)
            W[[ring, -ring - 1], ring:-ring, ring:-ring] = v
            W[ring:-ring, [ring, -ring - 1], ring:-ring] = v
            W[ring:-ring, ring:-ring, [ring, -ring - 1]] = v

        npz = int(np.ceil(nz / subPatchSize))  # number of patch planes
        npr = int(np.ceil(nr / subPatchSize))  # number of patch rows
        npc = int(np.ceil(nc / subPatchSize))  # number of patch cols
        nzpi = npz * subPatchSize + 2 * margin  # planes in padded volume
        nrpi = npr * subPatchSize + 2 * margin  # rows in padded volume
        ncpi = npc * subPatchSize + 2 * margin  # cols in padded volume

        zs = slice(margin, margin + nz)
        rs = slice(margin, margin + nr)
        cs = slice(margin, margin + nc)
        if len(image.shape) == 3:
            padded = np.zeros((nzpi, nrpi, ncpi))
            padded[zs, rs, cs] = image
        else:
            # The channel axis (second) is kept as-is; it is not padded.
            padded = np.zeros((nzpi, image.shape[1], nrpi, ncpi))
            padded[zs, :, rs, cs] = image

        PI3D.Image = image
        PI3D.PatchSize = patchSize
        PI3D.Margin = margin
        PI3D.SubPatchSize = subPatchSize
        PI3D.W = W
        PI3D.NZ, PI3D.NR, PI3D.NC = nz, nr, nc
        PI3D.NZPI, PI3D.NRPI, PI3D.NCPI = nzpi, nrpi, ncpi
        PI3D.PaddedImage = padded

        s = subPatchSize
        # Patch order: planes outermost, then rows, then columns.
        PI3D.PC = [[z * s, z * s + patchSize,
                    r * s, r * s + patchSize,
                    c * s, c * s + patchSize]
                   for z in range(npz) for r in range(npr) for c in range(npc)]
        PI3D.NumPatches = len(PI3D.PC)
        PI3D.Mode = mode

    @staticmethod
    def getPatch(i):
        """Return patch ``i`` of the padded volume as a view.

        The result is ``(PatchSize,)*3``, or
        ``(PatchSize, channels, PatchSize, PatchSize)`` for a multi-channel
        volume.
        """
        z0, z1, r0, r1, c0, c1 = PI3D.PC[i]
        if len(PI3D.PaddedImage.shape) == 3:
            return PI3D.PaddedImage[z0:z1, r0:r1, c0:c1]
        return PI3D.PaddedImage[z0:z1, :, r0:r1, c0:c1]

    @staticmethod
    def createOutput(nChannels):
        """Allocate zeroed float64 output buffers in padded-volume coordinates.

        ``Output`` is ``(NZPI, NRPI, NCPI)`` when ``nChannels == 1`` and
        ``(NZPI, nChannels, NRPI, NCPI)`` otherwise. ``Count`` is allocated
        only in 'accumulate' mode.
        """
        if nChannels == 1:
            shape = (PI3D.NZPI, PI3D.NRPI, PI3D.NCPI)
        else:
            shape = (PI3D.NZPI, nChannels, PI3D.NRPI, PI3D.NCPI)
        PI3D.Output = np.zeros(shape)
        if PI3D.Mode == 'accumulate':
            PI3D.Count = np.zeros((PI3D.NZPI, PI3D.NRPI, PI3D.NCPI))

    @staticmethod
    def patchOutput(i, P):
        """Store the processed result ``P`` of patch ``i`` in ``Output``.

        ``P`` is 3D or 4D with the channel axis second, matching the layout
        chosen in ``createOutput``. In 'replace' mode ``P`` overwrites its
        region of ``Output``. In 'accumulate' mode ``P * W`` is added to
        ``Output`` and ``W`` to ``Count``.

        Raises:
            ValueError: if ``P`` is neither 3D nor 4D.
        """
        z0, z1, r0, r1, c0, c1 = PI3D.PC[i]
        if len(P.shape) == 3:
            region = (slice(z0, z1), slice(r0, r1), slice(c0, c1))
            W = PI3D.W
        elif len(P.shape) == 4:
            region = (slice(z0, z1), slice(None), slice(r0, r1), slice(c0, c1))
            W = PI3D.W[:, np.newaxis]  # broadcast over the channel axis
        else:
            raise ValueError('expected a 3D or 4D patch, got shape %s' % (P.shape,))

        if PI3D.Mode == 'accumulate':
            PI3D.Count[z0:z1, r0:r1, c0:c1] += PI3D.W
            PI3D.Output[region] += np.multiply(P, W)
        elif PI3D.Mode == 'replace':
            PI3D.Output[region] = P

    @staticmethod
    def getValidOutput():
        """Return ``Output`` cropped to the original volume extent.

        In 'accumulate' mode the weighted sums are divided by ``Count``. The
        result is a new array; ``Output`` is never modified, so repeated calls
        return the same values. In 'replace' mode the result is a view of
        ``Output``.
        """
        m = PI3D.Margin
        zs = slice(m, m + PI3D.NZ)
        rs = slice(m, m + PI3D.NR)
        cs = slice(m, m + PI3D.NC)
        multiChannel = len(PI3D.Output.shape) == 4
        if multiChannel:
            valid = PI3D.Output[zs, :, rs, cs]
        else:
            valid = PI3D.Output[zs, rs, cs]
        if PI3D.Mode == 'accumulate':
            C = PI3D.Count[zs, rs, cs]
            if multiChannel:
                C = C[:, np.newaxis]  # broadcast over the channel axis
            return valid / C
        if PI3D.Mode == 'replace':
            return valid
        return None

    @staticmethod
    def demo():
        """Round-trip a random volume through 'accumulate' mode and display a slice."""
        I = np.random.rand(128, 128, 128)
        PI3D.setup(I, 64, 4, 'accumulate')

        nChannels = 2
        PI3D.createOutput(nChannels)

        for i in range(PI3D.NumPatches):
            P = PI3D.getPatch(i)
            Q = np.stack([P] * nChannels, axis=1)  # (planes, channels, rows, cols)
            PI3D.patchOutput(i, Q)

        J = PI3D.getValidOutput()[:, 0, :, :]

        D = np.abs(I - J)
        print(np.max(D))

        pI = I[64, :, :]
        pJ = J[64, :, :]
        pD = D[64, :, :]

        # cat/imshow are presumably provided by `from toolbox.imtools import *`.
        K = cat(1, cat(1, pI, pJ), pD)
        imshow(K)
305 |