Mercurial > repos > perssond > unmicst
comparison toolbox/PartitionOfImage.py @ 1:74fe58ff55a5 draft default tip
planemo upload for repository https://github.com/HMS-IDAC/UnMicst commit e14f76a8803cab0013c6dbe809bc81d7667f2ab9
| author | goeckslab |
|---|---|
| date | Wed, 07 Sep 2022 23:10:14 +0000 |
| parents | 6bec4fef6b2e |
| children |
comparison
equal
deleted
inserted
replaced
| 0:6bec4fef6b2e | 1:74fe58ff55a5 |
|---|---|
| 1 import numpy as np | |
| 2 from toolbox.imtools import * | |
| 3 # from toolbox.ftools import * | |
| 4 # import sys | |
| 5 | |
class PI2D:
    """Partition a 2D image into overlapping square patches and reassemble.

    The class is used as a namespace: all state lives in class attributes
    and every function is called on the class itself, e.g.::

        PI2D.setup(image, patchSize, margin, mode)
        PI2D.createOutput(nChannels)
        for i in range(PI2D.NumPatches):
            PI2D.patchOutput(i, process(PI2D.getPatch(i)))
        result = PI2D.getValidOutput()

    Because the state is class-level, only one partition can be active at a
    time; calling ``setup`` again overwrites the previous one.
    """

    Image = None          # image passed to setup()
    PaddedImage = None    # zero-padded copy of Image that patches are cut from
    PatchSize = 128       # side length of a (square) patch
    Margin = 14           # overlap margin on each side of a patch
    SubPatchSize = 100    # PatchSize - 2*Margin; stride between patches
    PC = None             # patch coordinates, list of [r0, r1, c0, c1]
    NumPatches = 0        # len(PC)
    Output = None         # padded output buffer, allocated by createOutput()
    Count = None          # per-pixel sum of weights ('accumulate' mode only)
    NR = None             # rows of the original image
    NC = None             # cols of the original image
    NRPI = None           # rows of the padded image
    NCPI = None           # cols of the padded image
    Mode = None           # 'replace' or 'accumulate'
    W = None              # (PatchSize, PatchSize) blending weights

    @staticmethod
    def setup(image, patchSize, margin, mode):
        """Pad ``image`` and compute the patch grid and blending weights.

        Args:
            image: array of shape (rows, cols) or (channels, rows, cols).
            patchSize: side length of each square patch.
            margin: border width on each side of a patch; consecutive
                patches are ``patchSize - 2*margin`` apart.
            mode: 'replace' (patches overwrite the output) or 'accumulate'
                (patches are blended with the weights in ``PI2D.W``).

        Raises:
            ValueError: if ``image`` is not 2D/3D, ``margin`` is negative, or
                ``patchSize`` is not larger than ``2*margin``.
        """
        rank = len(image.shape)
        if rank not in (2, 3):
            raise ValueError(
                'PI2D.setup expects a 2D or 3D (multi-channel) image, got %d dimensions' % rank)
        subPatchSize = patchSize - 2 * margin
        if margin < 0 or subPatchSize <= 0:
            raise ValueError('patchSize must be larger than 2*margin and margin must be >= 0')

        PI2D.Image = image
        PI2D.PatchSize = patchSize
        PI2D.Margin = margin
        PI2D.SubPatchSize = subPatchSize

        # Weight window: 0 on the patch border, ramping linearly to 1 over
        # 2*margin pixels. With margin == 0 patches do not overlap, so the
        # window stays at 1 everywhere; a zero border there would make the
        # weight sum vanish and produce 0/0 in getValidOutput().
        W = np.ones((patchSize, patchSize))
        if margin > 0:
            W[[0, -1], :] = 0
            W[:, [0, -1]] = 0
            for i in range(1, 2 * margin):
                v = i / (2 * margin)
                W[i, i:-i] = v
                W[-i - 1, i:-i] = v
                W[i:-i, i] = v
                W[i:-i, -i - 1] = v
        PI2D.W = W

        nr, nc = image.shape[-2:]
        PI2D.NR = nr
        PI2D.NC = nc

        npr = int(np.ceil(nr / subPatchSize))  # number of patch rows
        npc = int(np.ceil(nc / subPatchSize))  # number of patch cols

        nrpi = npr * subPatchSize + 2 * margin  # number of rows in padded image
        ncpi = npc * subPatchSize + 2 * margin  # number of cols in padded image
        PI2D.NRPI = nrpi
        PI2D.NCPI = ncpi

        # Zero-padded copy; the original image sits at offset `margin`.
        padded = np.zeros(tuple(image.shape[:-2]) + (nrpi, ncpi))
        padded[..., margin:margin + nr, margin:margin + nc] = image
        PI2D.PaddedImage = padded

        # Patch coordinates [r0, r1, c0, c1], row-major over the patch grid.
        PI2D.PC = [
            [r0, r0 + patchSize, c0, c0 + patchSize]
            for r0 in range(0, npr * subPatchSize, subPatchSize)
            for c0 in range(0, npc * subPatchSize, subPatchSize)
        ]
        PI2D.NumPatches = len(PI2D.PC)
        PI2D.Mode = mode  # 'replace' or 'accumulate'

    @staticmethod
    def getPatch(i):
        """Return patch ``i`` of the padded image (a view, all channels)."""
        r0, r1, c0, c1 = PI2D.PC[i]
        if len(PI2D.PaddedImage.shape) == 2:
            return PI2D.PaddedImage[r0:r1, c0:c1]
        if len(PI2D.PaddedImage.shape) == 3:
            return PI2D.PaddedImage[:, r0:r1, c0:c1]

    @staticmethod
    def createOutput(nChannels):
        """Allocate the padded float16 output buffer.

        The buffer has shape (NRPI, NCPI) when ``nChannels == 1`` and
        (nChannels, NRPI, NCPI) otherwise. In 'accumulate' mode a matching
        (NRPI, NCPI) weight-sum buffer ``PI2D.Count`` is allocated as well.
        """
        if nChannels == 1:
            PI2D.Output = np.zeros((PI2D.NRPI, PI2D.NCPI), np.float16)
        else:
            PI2D.Output = np.zeros((nChannels, PI2D.NRPI, PI2D.NCPI), np.float16)
        if PI2D.Mode == 'accumulate':
            PI2D.Count = np.zeros((PI2D.NRPI, PI2D.NCPI), np.float16)

    @staticmethod
    def patchOutput(i, P):
        """Write the processed patch ``P`` into the output at patch slot ``i``.

        ``P`` is (PatchSize, PatchSize) or (channels, PatchSize, PatchSize).
        In 'accumulate' mode ``P`` is multiplied by ``PI2D.W`` and added, and
        the weights are added to ``PI2D.Count``; in 'replace' mode ``P``
        overwrites the output region.
        """
        r0, r1, c0, c1 = PI2D.PC[i]
        accumulate = PI2D.Mode == 'accumulate'
        replace = PI2D.Mode == 'replace'
        if accumulate:
            PI2D.Count[r0:r1, c0:c1] += PI2D.W
        if len(P.shape) == 2:
            if accumulate:
                PI2D.Output[r0:r1, c0:c1] += np.multiply(P, PI2D.W)
            elif replace:
                PI2D.Output[r0:r1, c0:c1] = P
        elif len(P.shape) == 3:
            if accumulate:
                # Separate loop name so the patch index `i` is not shadowed.
                for ch in range(P.shape[0]):
                    PI2D.Output[ch, r0:r1, c0:c1] += np.multiply(P[ch, :, :], PI2D.W)
            elif replace:
                PI2D.Output[:, r0:r1, c0:c1] = P

    @staticmethod
    def getValidOutput():
        """Return the output cropped to the original image extent.

        In 'accumulate' mode the cropped output is divided by the cropped
        weight sum and returned as a new array; ``PI2D.Output`` itself is
        left untouched, so the function can be called repeatedly with the
        same result. In any other mode the cropped view of ``PI2D.Output``
        is returned as is.
        """
        margin = PI2D.Margin
        rows = slice(margin, margin + PI2D.NR)
        cols = slice(margin, margin + PI2D.NC)
        valid = PI2D.Output[..., rows, cols]
        if PI2D.Mode == 'accumulate':
            # Count is (rows, cols) and broadcasts over a leading channel axis.
            return np.divide(valid, PI2D.Count[rows, cols])
        return valid

    @staticmethod
    def demo():
        """Round-trip a random image through the partition and show the error.

        Relies on ``cat`` and ``imshow``, which this class does not define;
        presumably they come from the ``toolbox.imtools`` star import.
        """
        I = np.random.rand(128, 128)
        PI2D.setup(I, 64, 4, 'replace')

        nChannels = 2
        PI2D.createOutput(nChannels)

        for i in range(PI2D.NumPatches):
            P = PI2D.getPatch(i)
            Q = np.zeros((nChannels, P.shape[0], P.shape[1]))
            for j in range(nChannels):
                Q[j, :, :] = P
            PI2D.patchOutput(i, Q)

        J = PI2D.getValidOutput()
        J = J[0, :, :]

        D = np.abs(I - J)
        print(np.max(D))

        K = cat(1, cat(1, I, J), D)
        imshow(K)
| 148 | |
| 149 | |
class PI3D:
    """Partition a 3D volume into overlapping cubic patches and reassemble.

    The class is used as a namespace: all state lives in class attributes
    and every function is called on the class itself, e.g.::

        PI3D.setup(image, patchSize, margin, mode)
        PI3D.createOutput(nChannels)
        for i in range(PI3D.NumPatches):
            PI3D.patchOutput(i, process(PI3D.getPatch(i)))
        result = PI3D.getValidOutput()

    Because the state is class-level, only one partition can be active at a
    time; calling ``setup`` again overwrites the previous one.
    """

    Image = None          # image passed to setup()
    PaddedImage = None    # zero-padded copy of Image that patches are cut from
    PatchSize = 128       # side length of a (cubic) patch
    Margin = 14           # overlap margin on each side of a patch
    SubPatchSize = 100    # PatchSize - 2*Margin; stride between patches
    PC = None             # patch coordinates, list of [z0, z1, r0, r1, c0, c1]
    NumPatches = 0        # len(PC)
    Output = None         # padded output buffer, allocated by createOutput()
    Count = None          # per-voxel sum of weights ('accumulate' mode only)
    NR = None             # rows
    NC = None             # cols
    NZ = None             # planes
    NRPI = None           # rows of the padded image
    NCPI = None           # cols of the padded image
    NZPI = None           # planes of the padded image
    Mode = None           # 'replace' or 'accumulate'
    W = None              # (PatchSize, PatchSize, PatchSize) blending weights

    @staticmethod
    def setup(image, patchSize, margin, mode):
        """Pad ``image`` and compute the patch grid and blending weights.

        Args:
            image: array of shape (planes, rows, cols) or
                (planes, channels, rows, cols).
            patchSize: side length of each cubic patch.
            margin: border width on each side of a patch; consecutive
                patches are ``patchSize - 2*margin`` apart along every axis.
            mode: 'replace' (patches overwrite the output) or 'accumulate'
                (patches are blended with the weights in ``PI3D.W``).

        Raises:
            ValueError: if ``image`` is not 3D/4D, ``margin`` is negative, or
                ``patchSize`` is not larger than ``2*margin``.
        """
        rank = len(image.shape)
        if rank not in (3, 4):
            raise ValueError(
                'PI3D.setup expects a 3D or 4D (multi-channel) image, got %d dimensions' % rank)
        subPatchSize = patchSize - 2 * margin
        if margin < 0 or subPatchSize <= 0:
            raise ValueError('patchSize must be larger than 2*margin and margin must be >= 0')

        PI3D.Image = image
        PI3D.PatchSize = patchSize
        PI3D.Margin = margin
        PI3D.SubPatchSize = subPatchSize

        # Weight window: 0 on the patch faces, ramping linearly to 1 over
        # 2*margin voxels. With margin == 0 patches do not overlap, so the
        # window stays at 1 everywhere; zero faces there would make the
        # weight sum vanish and produce 0/0 in getValidOutput().
        W = np.ones((patchSize, patchSize, patchSize))
        if margin > 0:
            W[[0, -1], :, :] = 0
            W[:, [0, -1], :] = 0
            W[:, :, [0, -1]] = 0
            for i in range(1, 2 * margin):
                v = i / (2 * margin)
                W[[i, -i - 1], i:-i, i:-i] = v
                W[i:-i, [i, -i - 1], i:-i] = v
                W[i:-i, i:-i, [i, -i - 1]] = v
        PI3D.W = W

        if rank == 3:
            nz, nr, nc = image.shape
        else:  # multi-channel image, channel axis second
            nz, nw, nr, nc = image.shape

        PI3D.NR = nr
        PI3D.NC = nc
        PI3D.NZ = nz

        npr = int(np.ceil(nr / subPatchSize))  # number of patch rows
        npc = int(np.ceil(nc / subPatchSize))  # number of patch cols
        npz = int(np.ceil(nz / subPatchSize))  # number of patch planes

        nrpi = npr * subPatchSize + 2 * margin  # number of rows in padded image
        ncpi = npc * subPatchSize + 2 * margin  # number of cols in padded image
        nzpi = npz * subPatchSize + 2 * margin  # number of plns in padded image

        PI3D.NRPI = nrpi
        PI3D.NCPI = ncpi
        PI3D.NZPI = nzpi

        # Zero-padded copy; the original volume sits at offset `margin`.
        if rank == 3:
            PI3D.PaddedImage = np.zeros((nzpi, nrpi, ncpi))
            PI3D.PaddedImage[margin:margin + nz, margin:margin + nr, margin:margin + nc] = image
        else:
            PI3D.PaddedImage = np.zeros((nzpi, nw, nrpi, ncpi))
            PI3D.PaddedImage[margin:margin + nz, :, margin:margin + nr, margin:margin + nc] = image

        # Patch coordinates [z0, z1, r0, r1, c0, c1]: planes outermost,
        # then rows, then cols.
        PI3D.PC = [
            [z0, z0 + patchSize, r0, r0 + patchSize, c0, c0 + patchSize]
            for z0 in range(0, npz * subPatchSize, subPatchSize)
            for r0 in range(0, npr * subPatchSize, subPatchSize)
            for c0 in range(0, npc * subPatchSize, subPatchSize)
        ]
        PI3D.NumPatches = len(PI3D.PC)
        PI3D.Mode = mode  # 'replace' or 'accumulate'

    @staticmethod
    def getPatch(i):
        """Return patch ``i`` of the padded image (a view, all channels)."""
        z0, z1, r0, r1, c0, c1 = PI3D.PC[i]
        if len(PI3D.PaddedImage.shape) == 3:
            return PI3D.PaddedImage[z0:z1, r0:r1, c0:c1]
        if len(PI3D.PaddedImage.shape) == 4:
            return PI3D.PaddedImage[z0:z1, :, r0:r1, c0:c1]

    @staticmethod
    def createOutput(nChannels):
        """Allocate the padded output buffer.

        The buffer has shape (NZPI, NRPI, NCPI) when ``nChannels == 1`` and
        (NZPI, nChannels, NRPI, NCPI) otherwise. In 'accumulate' mode a
        matching (NZPI, NRPI, NCPI) weight-sum buffer ``PI3D.Count`` is
        allocated as well.
        """
        if nChannels == 1:
            PI3D.Output = np.zeros((PI3D.NZPI, PI3D.NRPI, PI3D.NCPI))
        else:
            PI3D.Output = np.zeros((PI3D.NZPI, nChannels, PI3D.NRPI, PI3D.NCPI))
        if PI3D.Mode == 'accumulate':
            PI3D.Count = np.zeros((PI3D.NZPI, PI3D.NRPI, PI3D.NCPI))

    @staticmethod
    def patchOutput(i, P):
        """Write the processed patch ``P`` into the output at patch slot ``i``.

        ``P`` is (PatchSize, PatchSize, PatchSize) or
        (PatchSize, channels, PatchSize, PatchSize). In 'accumulate' mode
        ``P`` is multiplied by ``PI3D.W`` and added, and the weights are
        added to ``PI3D.Count``; in 'replace' mode ``P`` overwrites the
        output region.
        """
        z0, z1, r0, r1, c0, c1 = PI3D.PC[i]
        accumulate = PI3D.Mode == 'accumulate'
        replace = PI3D.Mode == 'replace'
        if accumulate:
            PI3D.Count[z0:z1, r0:r1, c0:c1] += PI3D.W
        if len(P.shape) == 3:
            if accumulate:
                PI3D.Output[z0:z1, r0:r1, c0:c1] += np.multiply(P, PI3D.W)
            elif replace:
                PI3D.Output[z0:z1, r0:r1, c0:c1] = P
        elif len(P.shape) == 4:
            if accumulate:
                # Separate loop name so the patch index `i` is not shadowed.
                for ch in range(P.shape[1]):
                    PI3D.Output[z0:z1, ch, r0:r1, c0:c1] += np.multiply(P[:, ch, :, :], PI3D.W)
            elif replace:
                PI3D.Output[z0:z1, :, r0:r1, c0:c1] = P

    @staticmethod
    def getValidOutput():
        """Return the output cropped to the original image extent.

        In 'accumulate' mode the cropped output is divided by the cropped
        weight sum and returned as a new array; ``PI3D.Output`` itself is
        left untouched, so the function can be called repeatedly with the
        same result. In any other mode the cropped view of ``PI3D.Output``
        is returned as is.
        """
        margin = PI3D.Margin
        planes = slice(margin, margin + PI3D.NZ)
        rows = slice(margin, margin + PI3D.NR)
        cols = slice(margin, margin + PI3D.NC)
        accumulate = PI3D.Mode == 'accumulate'
        if len(PI3D.Output.shape) == 3:
            valid = PI3D.Output[planes, rows, cols]
            if accumulate:
                return np.divide(valid, PI3D.Count[planes, rows, cols])
            return valid
        if len(PI3D.Output.shape) == 4:
            valid = PI3D.Output[planes, :, rows, cols]
            if accumulate:
                # Insert a length-1 channel axis so Count broadcasts over channels.
                return np.divide(valid, PI3D.Count[planes, np.newaxis, rows, cols])
            return valid

    @staticmethod
    def demo():
        """Round-trip a random volume through the partition and show the error.

        Relies on ``cat`` and ``imshow``, which this class does not define;
        presumably they come from the ``toolbox.imtools`` star import.
        """
        I = np.random.rand(128, 128, 128)
        PI3D.setup(I, 64, 4, 'accumulate')

        nChannels = 2
        PI3D.createOutput(nChannels)

        for i in range(PI3D.NumPatches):
            P = PI3D.getPatch(i)
            Q = np.zeros((P.shape[0], nChannels, P.shape[1], P.shape[2]))
            for j in range(nChannels):
                Q[:, j, :, :] = P
            PI3D.patchOutput(i, Q)

        J = PI3D.getValidOutput()
        J = J[:, 0, :, :]

        D = np.abs(I - J)
        print(np.max(D))

        pI = I[64, :, :]
        pJ = J[64, :, :]
        pD = D[64, :, :]

        K = cat(1, cat(1, pI, pJ), pD)
        imshow(K)
| 305 |
