Mercurial > repos > perssond > coreograph
comparison UNetCoreograph.py @ 0:99308601eaa6 draft
"planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
author | perssond |
---|---|
date | Wed, 19 May 2021 21:34:38 +0000 |
parents | |
children | 57f1260ca94e |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:99308601eaa6 |
---|---|
1 import numpy as np | |
2 from scipy import misc as sm | |
3 import shutil | |
4 import scipy.io as sio | |
5 import os | |
6 import skimage.exposure as sk | |
7 import cv2 | |
8 import argparse | |
9 import pytiff | |
10 import tifffile | |
11 import tensorflow as tf | |
12 from skimage.morphology import * | |
13 from skimage.exposure import rescale_intensity | |
14 from skimage.segmentation import chan_vese, find_boundaries, morphological_chan_vese | |
15 from skimage.measure import regionprops,label, find_contours | |
16 from skimage.transform import resize | |
17 from skimage.filters import gaussian | |
18 from skimage.feature import peak_local_max,blob_log | |
19 from skimage.color import label2rgb | |
20 import skimage.io as skio | |
21 from skimage import img_as_bool | |
22 from skimage.draw import circle_perimeter | |
23 from scipy.ndimage.filters import uniform_filter | |
24 from scipy.ndimage import gaussian_laplace | |
25 from os.path import * | |
26 from os import listdir, makedirs, remove | |
27 | |
28 | |
29 | |
30 import sys | |
31 from typing import Any | |
32 | |
33 #sys.path.insert(0, 'C:\\Users\\Public\\Documents\\ImageScience') | |
34 from toolbox.imtools import * | |
35 from toolbox.ftools import * | |
36 from toolbox.PartitionOfImage import PI2D | |
37 | |
38 | |
def concat3(lst):
    """Concatenate a list of tensors along axis 3.

    In this file it joins NHWC tensors channel-wise (the data placeholder is
    declared with shape [None, imSize, imSize, nChannels]).
    """
    channel_axis = 3
    return tf.concat(values=lst, axis=channel_axis)
41 | |
class UNet2D:
    """2-D U-Net built with the TensorFlow 1.x graph API.

    All state lives in class attributes and every method is called on the
    class itself (e.g. ``UNet2D.setup(...)``); none of the methods take
    ``self``. Flows used in this file:

    * training: ``UNet2D.setup(...)`` then ``UNet2D.train(...)``;
    * batch deployment: ``UNet2D.deploy(...)``;
    * tiled single-image use: ``singleImageInferenceSetup`` ->
      ``singleImageInference`` -> ``singleImageInferenceCleanup``.
    """
    hp = None # hyper-parameters (dict built by setup)
    nn = None # network output: per-pixel softmax tensor (set by setup)
    tfTraining = None # if training or not (to handle batch norm)
    tfData = None # data placeholder
    Session = None # tf.Session used by singleImageInference (created by singleImageInferenceSetup)
    DatasetMean = 0 # normalisation mean (loaded by singleImageInferenceSetup)
    DatasetStDev = 0 # normalisation st. dev. (loaded by singleImageInferenceSetup)

    def setupWithHP(hp):
        """Build the network from a hyper-parameter dict.

        ``hp`` uses the same keys as ``UNet2D.hp`` (that dict is what ``train``
        saves to ``hp.data``); the values are forwarded positionally to
        ``setup``.
        """
        UNet2D.setup(hp['imSize'],
                     hp['nChannels'],
                     hp['nClasses'],
                     hp['nOut0'],
                     hp['featMapsFact'],
                     hp['downSampFact'],
                     hp['ks'],
                     hp['nExtraConvs'],
                     hp['stdDev0'],
                     hp['nLayers'],
                     hp['batchSize'])

    def setup(imSize,nChannels,nClasses,nOut0,featMapsFact,downSampFact,kernelSize,nExtraConvs,stdDev0,nDownSampLayers,batchSize):
        """Build the U-Net graph in the current default TF graph.

        Stores the hyper-parameters in ``UNet2D.hp``, creates the
        ``UNet2D.tfTraining`` / ``UNet2D.tfData`` placeholders and sets
        ``UNet2D.nn`` to the softmax (over the last axis) of the
        ``nClasses``-channel output.

        Args:
            imSize: side length of the square input patches.
            nChannels: number of input channels.
            nClasses: number of output classes.
            nOut0: feature maps produced by the first down-sampling block.
            featMapsFact: feature-map multiplier applied at each level.
            downSampFact: max-pool / transposed-conv stride at each level.
            kernelSize: convolution kernel size.
            nExtraConvs: extra convolutions per block.
            stdDev0: st. dev. of the truncated-normal weight initialiser.
            nDownSampLayers: number of down-sampling levels.
            batchSize: batch size baked into the transposed convolutions.

        NOTE(review): ``tf.layers.batch_normalization`` names its variables by
        creation order (``batch_normalization``, ``batch_normalization_1``,
        ...); presumably restoring a saved checkpoint depends on this exact
        construction order -- confirm before reordering anything here.
        """
        UNet2D.hp = {'imSize':imSize,
                     'nClasses':nClasses,
                     'nChannels':nChannels,
                     'nExtraConvs':nExtraConvs,
                     'nLayers':nDownSampLayers,
                     'featMapsFact':featMapsFact,
                     'downSampFact':downSampFact,
                     'ks':kernelSize,
                     'nOut0':nOut0,
                     'stdDev0':stdDev0,
                     'batchSize':batchSize}

        # nOutX[k] is the channel count entering level k (nOutX[0] = input
        # channels); dsfX[k] is the pooling / upsampling stride of level k.
        nOutX = [UNet2D.hp['nChannels'],UNet2D.hp['nOut0']]
        dsfX = []
        for i in range(UNet2D.hp['nLayers']):
            nOutX.append(nOutX[-1]*UNet2D.hp['featMapsFact'])
            dsfX.append(UNet2D.hp['downSampFact'])


        # --------------------------------------------------
        # downsampling layer
        # --------------------------------------------------

        with tf.name_scope('placeholders'):
            UNet2D.tfTraining = tf.placeholder(tf.bool, name='training')
            UNet2D.tfData = tf.placeholder("float", shape=[None,UNet2D.hp['imSize'],UNet2D.hp['imSize'],UNet2D.hp['nChannels']],name='data')

        def down_samp_layer(data,index):
            """Encoder block ``index``: conv (+ extra convs) plus a 1x1 shortcut,
            ReLU, batch norm, then max-pool by ``dsfX[index]``."""
            with tf.name_scope('ld%d' % index):
                ldXWeights1 = tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index], nOutX[index+1]], stddev=stdDev0),name='kernel1')
                ldXWeightsExtra = []
                for i in range(nExtraConvs):
                    ldXWeightsExtra.append(tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index+1], nOutX[index+1]], stddev=stdDev0),name='kernelExtra%d' % i))

                c00 = tf.nn.conv2d(data, ldXWeights1, strides=[1, 1, 1, 1], padding='SAME')
                for i in range(nExtraConvs):
                    c00 = tf.nn.conv2d(tf.nn.relu(c00), ldXWeightsExtra[i], strides=[1, 1, 1, 1], padding='SAME')

                # 1x1 projection of the block input, added to the conv path (residual shortcut).
                ldXWeightsShortcut = tf.Variable(tf.truncated_normal([1, 1, nOutX[index], nOutX[index+1]], stddev=stdDev0),name='shortcutWeights')
                shortcut = tf.nn.conv2d(data, ldXWeightsShortcut, strides=[1, 1, 1, 1], padding='SAME')

                # Batch-norm statistics switch on the tfTraining placeholder.
                bn = tf.layers.batch_normalization(tf.nn.relu(c00+shortcut), training=UNet2D.tfTraining)

                return tf.nn.max_pool(bn, ksize=[1, dsfX[index], dsfX[index], 1], strides=[1, dsfX[index], dsfX[index], 1], padding='SAME',name='maxpool')

        # --------------------------------------------------
        # bottom layer
        # --------------------------------------------------

        with tf.name_scope('lb'):
            lbWeights1 = tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[UNet2D.hp['nLayers']], nOutX[UNet2D.hp['nLayers']+1]], stddev=stdDev0),name='kernel1')
            def lb(hidden):
                """Bottleneck: one ReLU convolution at the coarsest level."""
                return tf.nn.relu(tf.nn.conv2d(hidden, lbWeights1, strides=[1, 1, 1, 1], padding='SAME'),name='conv')

        # --------------------------------------------------
        # downsampling
        # --------------------------------------------------

        with tf.name_scope('downsampling'):
            # dsX[k] is the input of encoder level k; dsX[0] is the raw data.
            # These tensors are reused as skip connections in up_samp_layer.
            dsX = []
            dsX.append(UNet2D.tfData)

            for i in range(UNet2D.hp['nLayers']):
                dsX.append(down_samp_layer(dsX[i],i))

            b = lb(dsX[UNet2D.hp['nLayers']])

        # --------------------------------------------------
        # upsampling layer
        # --------------------------------------------------

        def up_samp_layer(data,index):
            """Decoder block ``index``: stride-``dsfX[index]`` transposed conv,
            concatenation with the encoder tensor ``dsX[index]``, then ReLU convs."""
            with tf.name_scope('lu%d' % index):
                luXWeights1 = tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index+1], nOutX[index+2]], stddev=stdDev0),name='kernel1')
                luXWeights2 = tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index]+nOutX[index+1], nOutX[index+1]], stddev=stdDev0),name='kernel2')
                luXWeightsExtra = []
                for i in range(nExtraConvs):
                    luXWeightsExtra.append(tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index+1], nOutX[index+1]], stddev=stdDev0),name='kernel2Extra%d' % i))

                # Spatial size at this level = imSize divided by all strides above it.
                outSize = UNet2D.hp['imSize']
                for i in range(index):
                    outSize /= dsfX[i]
                outSize = int(outSize)

                # NOTE(review): the batch dimension is fixed to hp['batchSize'], so the
                # graph presumably only accepts batches of exactly that size; every
                # caller in this class feeds a full batchSize array.
                outputShape = [UNet2D.hp['batchSize'],outSize,outSize,nOutX[index+1]]
                us = tf.nn.relu(tf.nn.conv2d_transpose(data, luXWeights1, outputShape, strides=[1, dsfX[index], dsfX[index], 1], padding='SAME'),name='conv1')
                cc = concat3([dsX[index],us])
                cv = tf.nn.relu(tf.nn.conv2d(cc, luXWeights2, strides=[1, 1, 1, 1], padding='SAME'),name='conv2')
                for i in range(nExtraConvs):
                    cv = tf.nn.relu(tf.nn.conv2d(cv, luXWeightsExtra[i], strides=[1, 1, 1, 1], padding='SAME'),name='conv2Extra%d' % i)
                return cv

        # --------------------------------------------------
        # final (top) layer
        # --------------------------------------------------

        with tf.name_scope('lt'):
            ltWeights1 = tf.Variable(tf.truncated_normal([1, 1, nOutX[1], nClasses], stddev=stdDev0),name='kernel')
            def lt(hidden):
                """1x1 convolution to nClasses logits."""
                return tf.nn.conv2d(hidden, ltWeights1, strides=[1, 1, 1, 1], padding='SAME',name='conv')


        # --------------------------------------------------
        # upsampling
        # --------------------------------------------------

        with tf.name_scope('upsampling'):
            usX = []
            usX.append(b)

            # Decoder levels are visited from the coarsest (nLayers-1) down to 0.
            for i in range(UNet2D.hp['nLayers']):
                usX.append(up_samp_layer(usX[i],UNet2D.hp['nLayers']-1-i))

            t = lt(usX[UNet2D.hp['nLayers']])


        # NOTE: this local `sm` shadows the module-level `scipy.misc as sm`
        # import only inside setup.
        sm = tf.nn.softmax(t,-1)
        UNet2D.nn = sm


    def train(imPath,logPath,modelPath,pmPath,nTrain,nValid,nTest,restoreVariables,nSteps,gpuIndex,testPMIndex):
        """Train the network built by ``setup``; write logs, checkpoint and test PNGs.

        ``UNet2D.setup`` must have been called first (this reads ``UNet2D.hp``
        and ``UNet2D.nn``). Samples are read from ``imPath`` as
        ``I%05d_Img.tif`` / ``I%05d_Ant.tif`` pairs with indices
        0 .. nTrain+nValid+nTest-1 and randomly split into train / valid / test.

        Args:
            imPath: directory holding the image / annotation TIFF pairs.
            logPath: TensorBoard directory; deleted if it exists, then used for
                the ``Train`` and ``Valid`` summary writers.
            modelPath: directory for ``model.ckpt``, ``hp.data`` and
                ``datasetMean.data`` / ``datasetStDev.data`` (created if missing).
            pmPath: output directory for test-set PNGs.
            nTrain, nValid, nTest: number of samples in each split.
            restoreVariables: if true, reload mean / st-dev and the checkpoint
                from ``modelPath`` instead of computing / initialising them.
            nSteps: number of training iterations (one batch each).
            gpuIndex: value written to ``CUDA_VISIBLE_DEVICES``.
            testPMIndex: class channel written out for the test images.
        """
        os.environ['CUDA_VISIBLE_DEVICES']= '%d' % gpuIndex

        outLogPath = logPath
        trainWriterPath = pathjoin(logPath,'Train')
        validWriterPath = pathjoin(logPath,'Valid')
        outModelPath = pathjoin(modelPath,'model.ckpt')
        outPMPath = pmPath

        batchSize = UNet2D.hp['batchSize']
        imSize = UNet2D.hp['imSize']
        nChannels = UNet2D.hp['nChannels']
        nClasses = UNet2D.hp['nClasses']

        # --------------------------------------------------
        # data
        # --------------------------------------------------

        Train = np.zeros((nTrain,imSize,imSize,nChannels))
        Valid = np.zeros((nValid,imSize,imSize,nChannels))
        Test = np.zeros((nTest,imSize,imSize,nChannels))
        LTrain = np.zeros((nTrain,imSize,imSize,nClasses))
        LValid = np.zeros((nValid,imSize,imSize,nClasses))
        LTest = np.zeros((nTest,imSize,imSize,nClasses))

        print('loading data, computing mean / st dev')
        if not os.path.exists(modelPath):
            os.makedirs(modelPath)
        if restoreVariables:
            datasetMean = loadData(pathjoin(modelPath,'datasetMean.data'))
            datasetStDev = loadData(pathjoin(modelPath,'datasetStDev.data'))
        else:
            # Scalar normalisation constants: the average of per-image means and
            # the average of per-image st. devs over all samples.
            datasetMean = 0
            datasetStDev = 0
            for iSample in range(nTrain+nValid+nTest):
                I = im2double(tifread('%s/I%05d_Img.tif' % (imPath,iSample)))
                datasetMean += np.mean(I)
                datasetStDev += np.std(I)
            datasetMean /= (nTrain+nValid+nTest)
            datasetStDev /= (nTrain+nValid+nTest)
            saveData(datasetMean, pathjoin(modelPath,'datasetMean.data'))
            saveData(datasetStDev, pathjoin(modelPath,'datasetStDev.data'))

        # Random split: perm[0:nTrain] train, next nValid valid, rest test.
        perm = np.arange(nTrain+nValid+nTest)
        np.random.shuffle(perm)

        # Only input channel 0 is filled. Annotation value i+1 becomes one-hot class channel i.
        for iSample in range(0, nTrain):
            path = '%s/I%05d_Img.tif' % (imPath,perm[iSample])
            im = im2double(tifread(path))
            Train[iSample,:,:,0] = (im-datasetMean)/datasetStDev
            path = '%s/I%05d_Ant.tif' % (imPath,perm[iSample])
            im = tifread(path)
            for i in range(nClasses):
                LTrain[iSample,:,:,i] = (im == i+1)

        for iSample in range(0, nValid):
            path = '%s/I%05d_Img.tif' % (imPath,perm[nTrain+iSample])
            im = im2double(tifread(path))
            Valid[iSample,:,:,0] = (im-datasetMean)/datasetStDev
            path = '%s/I%05d_Ant.tif' % (imPath,perm[nTrain+iSample])
            im = tifread(path)
            for i in range(nClasses):
                LValid[iSample,:,:,i] = (im == i+1)

        for iSample in range(0, nTest):
            path = '%s/I%05d_Img.tif' % (imPath,perm[nTrain+nValid+iSample])
            im = im2double(tifread(path))
            Test[iSample,:,:,0] = (im-datasetMean)/datasetStDev
            path = '%s/I%05d_Ant.tif' % (imPath,perm[nTrain+nValid+iSample])
            im = tifread(path)
            for i in range(nClasses):
                LTest[iSample,:,:,i] = (im == i+1)

        # --------------------------------------------------
        # optimization
        # --------------------------------------------------

        tfLabels = tf.placeholder("float", shape=[None,imSize,imSize,nClasses],name='labels')

        # Learning rate 0.01, multiplied by 0.95 every 1000 steps (staircase).
        globalStep = tf.Variable(0,trainable=False)
        learningRate0 = 0.01
        decaySteps = 1000
        decayRate = 0.95
        learningRate = tf.train.exponential_decay(learningRate0,globalStep,decaySteps,decayRate,staircase=True)

        with tf.name_scope('optim'):
            # Mean per-pixel cross-entropy. NOTE(review): tf.log of a softmax
            # value that underflows to 0 gives -inf (NaN once multiplied by a
            # 0 label); a clipped or logits-based loss would be more stable.
            loss = tf.reduce_mean(-tf.reduce_sum(tf.multiply(tfLabels,tf.log(UNet2D.nn)),3))
            # Run batch-norm moving-average updates together with each optimiser step.
            updateOps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            # optimizer = tf.train.MomentumOptimizer(1e-3,0.9)
            optimizer = tf.train.MomentumOptimizer(learningRate,0.9)
            # optimizer = tf.train.GradientDescentOptimizer(learningRate)
            with tf.control_dependencies(updateOps):
                optOp = optimizer.minimize(loss,global_step=globalStep)

        with tf.name_scope('eval'):
            # Per-class error = 1 - (correctly predicted labelled pixels / labelled pixels).
            # NOTE(review): this is NaN when a class has no labelled pixels in the batch.
            error = []
            for iClass in range(nClasses):
                labels0 = tf.reshape(tf.to_int32(tf.slice(tfLabels,[0,0,0,iClass],[-1,-1,-1,1])),[batchSize,imSize,imSize])
                predict0 = tf.reshape(tf.to_int32(tf.equal(tf.argmax(UNet2D.nn,3),iClass)),[batchSize,imSize,imSize])
                correct = tf.multiply(labels0,predict0)
                nCorrect0 = tf.reduce_sum(correct)
                nLabels0 = tf.reduce_sum(labels0)
                error.append(1-tf.to_float(nCorrect0)/tf.to_float(nLabels0))
            errors = tf.tuple(error)

        # --------------------------------------------------
        # inspection
        # --------------------------------------------------

        with tf.name_scope('scalars'):
            tf.summary.scalar('avg_cross_entropy', loss)
            for iClass in range(nClasses):
                tf.summary.scalar('avg_pixel_error_%d' % iClass, error[iClass])
            tf.summary.scalar('learning_rate', learningRate)
        with tf.name_scope('images'):
            split0 = tf.slice(UNet2D.nn,[0,0,0,0],[-1,-1,-1,1])
            split1 = tf.slice(UNet2D.nn,[0,0,0,1],[-1,-1,-1,1])
            if nClasses > 2:
                split2 = tf.slice(UNet2D.nn,[0,0,0,2],[-1,-1,-1,1])
            tf.summary.image('pm0',split0)
            tf.summary.image('pm1',split1)
            if nClasses > 2:
                tf.summary.image('pm2',split2)
        merged = tf.summary.merge_all()


        # --------------------------------------------------
        # session
        # --------------------------------------------------

        saver = tf.train.Saver()
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) # config parameter needed to save variables when using GPU

        if os.path.exists(outLogPath):
            shutil.rmtree(outLogPath)
        trainWriter = tf.summary.FileWriter(trainWriterPath, sess.graph)
        validWriter = tf.summary.FileWriter(validWriterPath, sess.graph)

        if restoreVariables:
            saver.restore(sess, outModelPath)
            print("Model restored.")
        else:
            sess.run(tf.global_variables_initializer())

        # --------------------------------------------------
        # train
        # --------------------------------------------------

        # Each step takes the first batchSize entries of a fresh random
        # permutation (so nTrain and nValid must be >= batchSize).
        batchData = np.zeros((batchSize,imSize,imSize,nChannels))
        batchLabels = np.zeros((batchSize,imSize,imSize,nClasses))
        for i in range(nSteps):
            # train

            perm = np.arange(nTrain)
            np.random.shuffle(perm)

            for j in range(batchSize):
                batchData[j,:,:,:] = Train[perm[j],:,:,:]
                batchLabels[j,:,:,:] = LTrain[perm[j],:,:,:]

            summary,_ = sess.run([merged,optOp],feed_dict={UNet2D.tfData: batchData, tfLabels: batchLabels, UNet2D.tfTraining: 1})
            trainWriter.add_summary(summary, i)

            # validation

            perm = np.arange(nValid)
            np.random.shuffle(perm)

            for j in range(batchSize):
                batchData[j,:,:,:] = Valid[perm[j],:,:,:]
                batchLabels[j,:,:,:] = LValid[perm[j],:,:,:]

            summary, es = sess.run([merged, errors],feed_dict={UNet2D.tfData: batchData, tfLabels: batchLabels, UNet2D.tfTraining: 0})
            validWriter.add_summary(summary, i)

            e = np.mean(es)
            print('step %05d, e: %f' % (i,e))

            # When resuming, the first validation error is the baseline to beat.
            # Otherwise, start at inf so step 0 always saves.
            if i == 0:
                if restoreVariables:
                    lowestError = e
                else:
                    lowestError = np.inf

            # Checkpoint only every 100 steps, and only if validation error improved.
            if np.mod(i,100) == 0 and e < lowestError:
                lowestError = e
                print("Model saved in file: %s" % saver.save(sess, outModelPath))


        # --------------------------------------------------
        # test
        # --------------------------------------------------

        if not os.path.exists(outPMPath):
            os.makedirs(outPMPath)

        for i in range(nTest):
            j = np.mod(i,batchSize)

            batchData[j,:,:,:] = Test[i,:,:,:]
            batchLabels[j,:,:,:] = LTest[i,:,:,:]

            # Run when the batch is full or on the last sample. In a final partial
            # batch, rows after j still hold stale data and are not written out.
            if j == batchSize-1 or i == nTest-1:

                output = sess.run(UNet2D.nn,feed_dict={UNet2D.tfData: batchData, tfLabels: batchLabels, UNet2D.tfTraining: 0})

                for k in range(j+1):
                    pm = output[k,:,:,testPMIndex]
                    gt = batchLabels[k,:,:,testPMIndex]
                    im = np.sqrt(normalize(batchData[k,:,:,0]))
                    # Horizontal strip: input | probability map | ground truth; 1-based file index.
                    imwrite(np.uint8(255*np.concatenate((im,np.concatenate((pm,gt),axis=1)),axis=1)),'%s/I%05d.png' % (outPMPath,i-j+k+1))


        # --------------------------------------------------
        # save hyper-parameters, clean-up
        # --------------------------------------------------

        saveData(UNet2D.hp,pathjoin(modelPath,'hp.data'))

        trainWriter.close()
        validWriter.close()
        sess.close()

    def deploy(imPath,nImages,modelPath,pmPath,gpuIndex,pmIndex):
        """Run a trained model over ``I%05d_Img.tif`` inputs and write PNGs.

        Builds the graph from ``modelPath/hp.data`` and restores
        ``modelPath/model.ckpt``.

        Args:
            imPath: directory with ``I%05d_Img.tif`` inputs, indices 0 .. nImages-1.
            nImages: number of images to process.
            modelPath: directory with ``hp.data``, ``model.ckpt`` and
                ``datasetMean.data`` / ``datasetStDev.data``.
            pmPath: output directory (created if missing). It receives
                ``I%05d_Im.png`` (sqrt of the toolbox-``normalize``d input) and
                ``I%05d_PM.png`` (class ``pmIndex`` probability), both with
                1-based indices.
            gpuIndex: value written to ``CUDA_VISIBLE_DEVICES``.
            pmIndex: class channel of the softmax output to save.
        """
        os.environ['CUDA_VISIBLE_DEVICES']= '%d' % gpuIndex
        variablesPath = pathjoin(modelPath,'model.ckpt')
        outPMPath = pmPath

        hp = loadData(pathjoin(modelPath,'hp.data'))
        UNet2D.setupWithHP(hp)

        batchSize = UNet2D.hp['batchSize']
        imSize = UNet2D.hp['imSize']
        nChannels = UNet2D.hp['nChannels']
        nClasses = UNet2D.hp['nClasses']

        # --------------------------------------------------
        # data
        # --------------------------------------------------

        Data = np.zeros((nImages,imSize,imSize,nChannels))

        datasetMean = loadData(pathjoin(modelPath,'datasetMean.data'))
        datasetStDev = loadData(pathjoin(modelPath,'datasetStDev.data'))

        # Same normalisation as train(); only input channel 0 is filled.
        for iSample in range(0, nImages):
            path = '%s/I%05d_Img.tif' % (imPath,iSample)
            im = im2double(tifread(path))
            Data[iSample,:,:,0] = (im-datasetMean)/datasetStDev

        # --------------------------------------------------
        # session
        # --------------------------------------------------

        saver = tf.train.Saver()
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) # config parameter needed to save variables when using GPU

        saver.restore(sess, variablesPath)
        print("Model restored.")

        # --------------------------------------------------
        # deploy
        # --------------------------------------------------

        batchData = np.zeros((batchSize,imSize,imSize,nChannels))

        if not os.path.exists(outPMPath):
            os.makedirs(outPMPath)

        for i in range(nImages):
            print(i,nImages)

            j = np.mod(i,batchSize)

            batchData[j,:,:,:] = Data[i,:,:,:]

            # Run when the batch is full or on the last image; only rows 0..j are written.
            if j == batchSize-1 or i == nImages-1:

                output = sess.run(UNet2D.nn,feed_dict={UNet2D.tfData: batchData, UNet2D.tfTraining: 0})

                for k in range(j+1):
                    pm = output[k,:,:,pmIndex]
                    im = np.sqrt(normalize(batchData[k,:,:,0]))
                    # imwrite(np.uint8(255*np.concatenate((im,pm),axis=1)),'%s/I%05d.png' % (outPMPath,i-j+k+1))
                    imwrite(np.uint8(255*im),'%s/I%05d_Im.png' % (outPMPath,i-j+k+1))
                    imwrite(np.uint8(255*pm),'%s/I%05d_PM.png' % (outPMPath,i-j+k+1))


        # --------------------------------------------------
        # clean-up
        # --------------------------------------------------

        sess.close()

    def singleImageInferenceSetup(modelPath,gpuIndex):
        """Build the graph and open a restored session for ``singleImageInference``.

        Sets ``CUDA_VISIBLE_DEVICES`` to ``gpuIndex``, builds the network from
        ``modelPath/hp.data`` and loads ``datasetMean.data`` /
        ``datasetStDev.data`` into ``UNet2D.DatasetMean`` / ``UNet2D.DatasetStDev``.
        It then creates ``UNet2D.Session`` and restores ``modelPath/model.ckpt``
        into it. Pair with ``singleImageInferenceCleanup``.
        """
        os.environ['CUDA_VISIBLE_DEVICES']= '%d' % gpuIndex
        variablesPath = pathjoin(modelPath,'model.ckpt')
        hp = loadData(pathjoin(modelPath,'hp.data'))
        UNet2D.setupWithHP(hp)

        UNet2D.DatasetMean =loadData(pathjoin(modelPath,'datasetMean.data'))
        UNet2D.DatasetStDev = loadData(pathjoin(modelPath,'datasetStDev.data'))
        print(UNet2D.DatasetMean)
        print(UNet2D.DatasetStDev)

        # --------------------------------------------------
        # session
        # --------------------------------------------------

        saver = tf.train.Saver()
        UNet2D.Session = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) # config parameter needed to save variables when using GPU
        #UNet2D.Session = tf.Session(config=tf.ConfigProto(device_count={'GPU': 0}))
        saver.restore(UNet2D.Session, variablesPath)
        print("Model restored.")

    def singleImageInferenceCleanup():
        """Close the session opened by ``singleImageInferenceSetup``."""
        UNet2D.Session.close()

    def singleImageInference(image,mode,pmIndex):
        """Run the network over ``image`` patch by patch and return the stitched map.

        Requires ``singleImageInferenceSetup`` to have been called.

        Args:
            image: input image handed to ``PI2D.setup``. Presumably 2-D, since
                each patch is written into input channel 0 -- confirm.
            mode: combination mode forwarded to ``PI2D.setup`` (getProbMaps
                passes 'accumulate').
            pmIndex: class channel of the softmax output to keep.

        Returns:
            ``PI2D.getValidOutput()`` after every patch output has been written.
        """
        print('Inference...')

        batchSize = UNet2D.hp['batchSize']
        imSize = UNet2D.hp['imSize']
        nChannels = UNet2D.hp['nChannels']

        # Patches are imSize x imSize. The third argument (imSize/8) is presumably
        # the patch overlap/margin; its meaning is defined in PI2D.
        PI2D.setup(image,imSize,int(imSize/8),mode)
        PI2D.createOutput(nChannels)

        batchData = np.zeros((batchSize,imSize,imSize,nChannels))
        for i in range(PI2D.NumPatches):
            j = np.mod(i,batchSize)
            # Same normalisation as training; only input channel 0 is filled.
            batchData[j,:,:,0] = (PI2D.getPatch(i)-UNet2D.DatasetMean)/UNet2D.DatasetStDev
            # Run when the batch is full or on the last patch. Rows after j in a
            # final partial batch hold stale patches and are ignored.
            if j == batchSize-1 or i == PI2D.NumPatches-1:
                output = UNet2D.Session.run(UNet2D.nn,feed_dict={UNet2D.tfData: batchData, UNet2D.tfTraining: 0})
                for k in range(j+1):
                    pm = output[k,:,:,pmIndex]
                    PI2D.patchOutput(i-j+k,pm)
                    # PI2D.patchOutput(i-j+k,normalize(imgradmag(PI2D.getPatch(i-j+k),1)))

        return PI2D.getValidOutput()
525 | |
526 | |
def identifyNumChan(path):
    """Return the number of full-resolution channel pages in a TIFF.

    Pages are scanned in order. The index of the first page whose shape
    differs from page 0 is returned; in a pyramidal OME-TIFF this is presumably
    where the sub-resolution levels start, so it equals the channel count. If
    every page has the same shape, the total page count is returned.

    Args:
        path: path to the TIFF file.

    Returns:
        Number of channel pages (>= 1).
    """
    # Context manager ensures the file handle is released (it was previously leaked).
    with tifffile.TiffFile(path) as tiff:
        pages = tiff.pages
        shape = pages[0].shape
        for i, page in enumerate(pages):
            if page.shape != shape:
                return i
        # No sub-resolution pages found: every page is a channel.
        return len(pages)
542 | |
def getProbMaps(I,dsFactor,modelPath,gpuIndex=1):
    """Run the core U-Net on a downsampled copy of ``I`` and return its probability map.

    Args:
        I: image (one channel of the TMA) to segment.
        dsFactor: number of successive 2x downsamplings applied before
            inference. This is a count, unlike the scale factor of the same
            name in the ``__main__`` script.
        modelPath: directory with the trained model, passed to
            ``UNet2D.singleImageInferenceSetup``.
        gpuIndex: value written to ``CUDA_VISIBLE_DEVICES`` by the setup call.
            Defaults to 1, the value that used to be hard-coded here.

    Returns:
        The class-1 probability map returned by ``UNet2D.singleImageInference``
        in 'accumulate' mode.
    """
    UNet2D.singleImageInferenceSetup(modelPath, gpuIndex)
    try:
        for _ in range(dsFactor):
            hsize = int(float(I.shape[0]) * 0.5)
            vsize = int(float(I.shape[1]) * 0.5)
            # NOTE(review): cv2.resize's third positional parameter is `dst`, not
            # `interpolation`, so INTER_NEAREST is presumably not applied here.
            # Left unchanged so the probability maps stay identical.
            I = cv2.resize(I,(vsize,hsize),cv2.INTER_NEAREST)
        I = im2double(I)
        # Stretch intensities to [0, 0.983] before inference.
        I = im2double(sk.rescale_intensity(I, in_range=(np.min(I), np.max(I)), out_range=(0, 0.983)))
        probMaps = UNet2D.singleImageInference(I,'accumulate',1)
    finally:
        # Release the TF session even if inference raises.
        UNet2D.singleImageInferenceCleanup()
    return probMaps
559 | |
def coreSegmenterOutput(I,probMap,initialmask,preBlur,findCenter):
    """Segment one TMA core at 10% scale with morphological Chan-Vese.

    Args:
        I: crop of a single core.
        probMap: kept for interface compatibility; not used.
        initialmask: seed mask for the crop. It is resized to the working grid
            and dilated with ``disk(15)`` before seeding the contour.
        preBlur: kept for interface compatibility; not used.
        findCenter: when truthy, keep only the labelled object whose centroid
            is closest to the crop centre, then apply a binary closing with
            ``disk(3)``.

    Returns:
        Boolean mask on the 10%-scale working grid.
    """
    # Work at 10% of the input size. NOTE(review): cv2.resize's third positional
    # parameter is `dst`, not `interpolation`, so INTER_CUBIC / INTER_NEAREST
    # below are presumably not applied. Kept as-is to preserve outputs.
    hsize = int(float(I.shape[0]) * 0.1)
    vsize = int(float(I.shape[1]) * 0.1)
    nucGF = cv2.resize(I,(vsize,hsize),cv2.INTER_CUBIC)

    # Seed mask on the same grid, grown by a radius-15 disk.
    hsize = nucGF.shape[0]
    vsize = nucGF.shape[1]
    initialmask = cv2.resize(initialmask,(vsize,hsize),cv2.INTER_NEAREST)
    initialmask = dilation(initialmask,disk(15)) > 0

    # Light smoothing, then scale so the maximum is 1.
    nucGF = gaussian(nucGF,0.7)
    nucGF = nucGF/np.amax(nucGF)

    nuclearMask = morphological_chan_vese(nucGF, 100, init_level_set=initialmask, smoothing=10, lambda1=1.001, lambda2=1)

    # Discard objects smaller than 0.5% of the crop area.
    TMAmask = remove_small_objects(nuclearMask > 0, nuclearMask.shape[0]*nuclearMask.shape[1]*0.005)
    TMAlabel = label(TMAmask)

    # Keep only the object closest to the centre.
    if findCenter:
        regions = regionprops(TMAlabel)
        if regions:
            centreRow = nucGF.shape[0]/2
            centreCol = nucGF.shape[1]/2
            # min() keeps the first region on ties, matching the former strict-< scan.
            closest = min(regions, key=lambda r: np.sqrt((r.centroid[0]-centreRow)**2 + (r.centroid[1]-centreCol)**2))
            # binary_closing comes from the explicit skimage.morphology star import;
            # the former `morphology.binary_closing` relied on a name not imported here.
            TMAmask = binary_closing(TMAlabel == closest.label, disk(3))
        # With no surviving objects the (empty) mask is returned unchanged.

    return TMAmask
617 | |
def overlayOutline(outline,img):
    """Display ``img`` next to a copy whose ``outline > 0`` pixels are set to [1, 0, 0].

    A 3-channel copy of ``img`` is built; ``img`` itself is not modified. The
    value [1, 0, 0] renders as red for images scaled to [0, 1] -- confirm for
    integer inputs. The grey copy and the overlay are passed to ``imshowpair``.
    """
    grey = img.copy()
    # Three identical copies along a new trailing axis (same as np.stack(..., axis=-1)).
    rgb = np.repeat(grey[..., np.newaxis], 3, axis=-1)
    rgb[outline > 0] = [1, 0, 0]
    imshowpair(grey, rgb)
623 | |
def imshowpair(A,B):
    """Show ``A`` (Purples colormap) with ``B`` (Greens, 50% alpha) drawn on top.

    NOTE(review): ``plt`` is not imported explicitly in this file; presumably it
    comes from one of the ``toolbox`` star imports -- confirm.
    """
    layers = [(A, dict(cmap='Purples')), (B, dict(cmap='Greens', alpha=0.5))]
    for picture, opts in layers:
        plt.imshow(picture, **opts)
    plt.show()
628 | |
629 | |
if __name__ == '__main__':
    # --------------------------------------------------
    # command line
    # --------------------------------------------------
    # NOTE(review): --maskPath and --useGrid are parsed but not used below.
    parser = argparse.ArgumentParser()
    parser.add_argument("--imagePath")
    parser.add_argument("--outputPath")
    parser.add_argument("--maskPath")
    parser.add_argument("--downsampleFactor", type=int, default=5)
    parser.add_argument("--channel", type=int, default=0)
    parser.add_argument("--buffer", type=float, default=2)
    parser.add_argument("--outputChan", type=int, nargs='+', default=[-1])
    parser.add_argument("--sensitivity", type=float, default=0.3)
    parser.add_argument("--useGrid", action='store_true')
    parser.add_argument("--cluster", action='store_true')
    args = parser.parse_args()

    outputPath = args.outputPath
    imagePath = args.imagePath
    sensitivity = args.sensitivity
    # The trained model is read from a 'model' directory next to this script.
    scriptPath = os.path.dirname(os.path.realpath(__file__))
    modelPath = os.path.join(scriptPath, 'model')
    maskOutputPath = os.path.join(outputPath, 'masks')
    if not os.path.exists(maskOutputPath):
        os.makedirs(maskOutputPath)

    # --------------------------------------------------
    # load image, work out output channels
    # --------------------------------------------------
    channel = args.channel
    # Scale from full resolution to the working grid: 2**-downsampleFactor.
    dsFactor = 1/(2**args.downsampleFactor)
    I = skio.imread(imagePath, img_num=channel)

    imagesub = resize(I,(int(float(I.shape[0]) * dsFactor),int(float(I.shape[1]) * dsFactor)))
    numChan = identifyNumChan(imagePath)

    # --outputChan: -1 -> all channels, one value -> that channel only,
    # two values -> inclusive range [first, second].
    outputChan = args.outputChan
    if len(outputChan) == 1:
        if outputChan[0] == -1:
            outputChan = [0, numChan-1]
        else:
            outputChan.append(outputChan[0])

    # --------------------------------------------------
    # core probability map and size estimates
    # --------------------------------------------------
    classProbs = getProbMaps(I,args.downsampleFactor,modelPath)
    preMask = gaussian(np.uint8(classProbs*255),1) > 0.8

    P = regionprops(label(preMask),cache=False)
    print(str(len(P)) + ' cores detected!')
    if not P:
        # np.percentile below fails on an empty list; stop with a clear message instead.
        print('No cores detected. Try adjusting the downsample factor')
        sys.exit(255)
    # label() numbers regions 1..N in order, so these are the per-region areas.
    area = [ele.area for ele in P]
    medArea = np.median(area)
    maxArea = np.percentile(area,99)
    preMask = remove_small_objects(preMask,0.2*medArea)
    coreRad = round(np.sqrt(medArea/np.pi))
    estCoreDiam = round(np.sqrt(maxArea/np.pi)*1.2*args.buffer)

    # --------------------------------------------------
    # split touching cores: LoG blob seeds + watershed
    # --------------------------------------------------
    fgFiltered = blob_log(preMask,coreRad*0.6,threshold=sensitivity)
    Imax = np.zeros(preMask.shape,dtype=np.uint8)
    for iSpot in range(fgFiltered.shape[0]):
        yi = np.uint32(round(fgFiltered[iSpot, 0]))
        xi = np.uint32(round(fgFiltered[iSpot, 1]))
        Imax[yi, xi] = 1
    Imax = Imax*preMask
    Idist = distance_transform_edt(1-Imax)
    markers = label(Imax)
    coreLabel = watershed(Idist,markers,watershed_line=True,mask=preMask)
    P = regionprops(coreLabel)
    # Centroids and diameters converted back to full-resolution pixels.
    centroids = np.array([ele.centroid for ele in P])/dsFactor
    numCores = len(centroids)
    estCoreDiamX = np.ones(numCores)*estCoreDiam/dsFactor
    estCoreDiamY = np.ones(numCores)*estCoreDiam/dsFactor

    # Logical `and`: the former `numCores ==0 & args.cluster` parsed as
    # `numCores == (0 & args.cluster)`, which ignored --cluster entirely.
    if numCores == 0 and args.cluster:
        print('No cores detected. Try adjusting the downsample factor')
        sys.exit(255)

    singleMaskTMA = np.zeros(imagesub.shape)
    maskTMA = np.zeros(imagesub.shape)
    bbox = [None] * numCores

    x = np.zeros(numCores)
    xLim = np.zeros(numCores)
    y = np.zeros(numCores)
    yLim = np.zeros(numCores)

    # --------------------------------------------------
    # segment each core
    # --------------------------------------------------
    for iCore in range(numCores):
        # Square box of side estCoreDiam centred on the core, clipped to the
        # image (lower bound 1, as before).
        x[iCore] = centroids[iCore,1] - estCoreDiamX[iCore]/2
        xLim[iCore] = x[iCore]+estCoreDiamX[iCore]
        if xLim[iCore] > I.shape[1]:
            xLim[iCore] = I.shape[1]
        if x[iCore] < 1:
            x[iCore] = 1

        y[iCore] = centroids[iCore,0] - estCoreDiamY[iCore]/2
        yLim[iCore] = y[iCore] + estCoreDiamY[iCore]
        if yLim[iCore] > I.shape[0]:
            yLim[iCore] = I.shape[0]
        if y[iCore] < 1:
            y[iCore] = 1

        bbox[iCore] = [round(x[iCore]), round(y[iCore]), round(xLim[iCore]), round(yLim[iCore])]

        # Append every requested channel's crop to <iCore+1>.tif.
        # NOTE(review): the `-1` on the end bounds drops the last row/column of
        # each box; kept as-is to preserve outputs.
        for iChan in range(outputChan[0],outputChan[1]+1):
            with pytiff.Tiff(imagePath, "r", encoding='utf-8') as handle:
                handle.set_page(iChan)
                coreStack = handle[np.uint32(bbox[iCore][1]):np.uint32(bbox[iCore][3]-1), np.uint32(bbox[iCore][0]):np.uint32(bbox[iCore][2]-1)]
            skio.imsave(outputPath + os.path.sep + str(iCore+1) + '.tif',coreStack,append=True)

        with pytiff.Tiff(imagePath, "r", encoding='utf-8') as handle:
            handle.set_page(args.channel)
            coreSlice = handle[np.uint32(bbox[iCore][1]):np.uint32(bbox[iCore][3]-1), np.uint32(bbox[iCore][0]):np.uint32(bbox[iCore][2]-1)]

        # Seed mask and probability crop from the working grid, resized to the crop.
        core = (coreLabel == (iCore+1))
        initialmask = core[np.uint32(y[iCore]*dsFactor):np.uint32(yLim[iCore]*dsFactor),np.uint32(x[iCore]*dsFactor):np.uint32(xLim[iCore]*dsFactor)]
        initialmask = resize(initialmask,size(coreSlice),cv2.INTER_NEAREST)

        singleProbMap = classProbs[np.uint32(y[iCore]*dsFactor):np.uint32(yLim[iCore]*dsFactor),np.uint32(x[iCore]*dsFactor):np.uint32(xLim[iCore]*dsFactor)]
        singleProbMap = resize(np.uint8(255*singleProbMap),size(coreSlice),cv2.INTER_NEAREST)
        TMAmask = coreSegmenterOutput(coreSlice,singleProbMap,initialmask,coreRad/20,False)
        # Fall back to the whole box when segmentation finds nothing.
        if np.sum(TMAmask) == 0:
            TMAmask = np.ones(TMAmask.shape)
        vsize = int(float(coreSlice.shape[0]))
        hsize = int(float(coreSlice.shape[1]))
        masksub = resize(resize(TMAmask,(vsize,hsize),cv2.INTER_NEAREST),(int(float(coreSlice.shape[0])*dsFactor),int(float(coreSlice.shape[1])*dsFactor)),cv2.INTER_NEAREST)
        singleMaskTMA[int(y[iCore]*dsFactor):int(y[iCore]*dsFactor)+masksub.shape[0],int(x[iCore]*dsFactor):int(x[iCore]*dsFactor)+masksub.shape[1]] = masksub
        # NOTE(review): singleMaskTMA is never reset, so each iteration re-adds all
        # cores placed so far and maskTMA holds stacked counts, not a 0/1 union.
        # Left unchanged to preserve TMA_MAP output; confirm the intent.
        maskTMA = maskTMA + resize(singleMaskTMA,maskTMA.shape,cv2.INTER_NEAREST)
        cv2.putText(imagesub, str(iCore+1), (int(P[iCore].centroid[1]),int(P[iCore].centroid[0])), 0, 0.5, (np.amax(imagesub), np.amax(imagesub), np.amax(imagesub)), 1, cv2.LINE_AA)

        skio.imsave(maskOutputPath + os.path.sep + str(iCore+1) + '_mask.tif',np.uint8(TMAmask))
        print('Segmented core ' + str(iCore+1))

    # --------------------------------------------------
    # overview map with core outlines and numbers
    # --------------------------------------------------
    boundaries = find_boundaries(maskTMA)
    imagesub = imagesub/np.percentile(imagesub,99.9)
    imagesub[boundaries == 1] = 1
    skio.imsave(outputPath + os.path.sep + 'TMA_MAP.tif',np.uint8(imagesub*255))
    print('Segmented all cores!')
799 | |
800 | |
801 #restore GPU to 0 | |
802 #image load using tifffile |