Mercurial > repos > perssond > unmicst
comparison batchUNet2DtCycif.py @ 1:74fe58ff55a5 draft default tip
planemo upload for repository https://github.com/HMS-IDAC/UnMicst commit e14f76a8803cab0013c6dbe809bc81d7667f2ab9
author | goeckslab |
---|---|
date | Wed, 07 Sep 2022 23:10:14 +0000 |
parents | 6bec4fef6b2e |
children |
comparison
equal
deleted
inserted
replaced
0:6bec4fef6b2e | 1:74fe58ff55a5 |
---|---|
1 import numpy as np | |
2 from scipy import misc | |
3 import tensorflow as tf | |
4 import shutil | |
5 import scipy.io as sio | |
6 import os,fnmatch,glob | |
7 import skimage.exposure as sk | |
8 | |
9 import sys | |
10 sys.path.insert(0, 'C:\\Users\\Clarence\\Documents\\UNet code\\ImageScience') | |
11 from toolbox.imtools import * | |
12 from toolbox.ftools import * | |
13 from toolbox.PartitionOfImage import PI2D | |
14 | |
15 | |
def concat3(lst):
    """Join a list of tensors along axis 3.

    Axis 3 is the last axis of the 4-D tensors UNet2D builds (its data
    placeholder is ``[None, imSize, imSize, nChannels]``), so this stacks
    feature maps channel-wise for the U-Net skip connections.
    """
    return tf.concat(values=lst, axis=3)
18 | |
class UNet2D:
    """2D U-Net built in a TensorFlow 1.x graph, held entirely in class state.

    Every method is a static method operating on the class attributes below;
    callers use ``UNet2D.setup(...)`` / ``UNet2D.train(...)`` etc. directly.
    """

    hp = None  # hyper-parameters dict, filled by setup()
    nn = None  # network output: softmax over the class axis, set by setup()
    tfTraining = None  # bool placeholder: training or not (to handle batch norm)
    tfData = None  # data placeholder, shape [None, imSize, imSize, nChannels]
    Session = None  # session used by the singleImageInference* helpers
    DatasetMean = 0  # normalisation mean, loaded by singleImageInferenceSetup()
    DatasetStDev = 0  # normalisation st. dev., loaded by singleImageInferenceSetup()

    @staticmethod
    def setupWithHP(hp):
        """Build the network from a hyper-parameter dict.

        ``hp`` needs the keys 'imSize', 'nChannels', 'nClasses', 'nOut0',
        'featMapsFact', 'downSampFact', 'ks', 'nExtraConvs', 'stdDev0',
        'nLayers' and 'batchSize' -- the same dict setup() stores in
        UNet2D.hp and train() saves to ``hp.data``.
        """
        UNet2D.setup(hp['imSize'],
                     hp['nChannels'],
                     hp['nClasses'],
                     hp['nOut0'],
                     hp['featMapsFact'],
                     hp['downSampFact'],
                     hp['ks'],
                     hp['nExtraConvs'],
                     hp['stdDev0'],
                     hp['nLayers'],
                     hp['batchSize'])

    @staticmethod
    def setup(imSize,nChannels,nClasses,nOut0,featMapsFact,downSampFact,kernelSize,nExtraConvs,stdDev0,nDownSampLayers,batchSize):
        """Build the U-Net in the default TF graph and store it on UNet2D.

        Structure: ``nDownSampLayers`` downsampling blocks, a bottom
        convolution, ``nDownSampLayers`` upsampling blocks with skip
        connections, a 1x1 output convolution and a softmax over classes.

        Args:
            imSize: side length of the square input tiles.
            nChannels: number of input channels.
            nClasses: number of output classes.
            nOut0: feature maps produced by the first downsampling block.
            featMapsFact: factor by which the feature-map count grows per level.
            downSampFact: max-pool size / transposed-conv stride per level.
            kernelSize: convolution kernel side length.
            nExtraConvs: number of extra convolutions in each block.
            stdDev0: st. dev. of the truncated-normal weight initialiser.
            nDownSampLayers: number of down/upsampling levels.
            batchSize: batch size. It is baked into the transposed-conv
                output shape, so every batch fed to the network must contain
                exactly this many tiles.

        Side effects:
            Sets UNet2D.hp, UNet2D.tfTraining, UNet2D.tfData and UNet2D.nn.
        """
        UNet2D.hp = {'imSize':imSize,
                     'nClasses':nClasses,
                     'nChannels':nChannels,
                     'nExtraConvs':nExtraConvs,
                     'nLayers':nDownSampLayers,
                     'featMapsFact':featMapsFact,
                     'downSampFact':downSampFact,
                     'ks':kernelSize,
                     'nOut0':nOut0,
                     'stdDev0':stdDev0,
                     'batchSize':batchSize}

        # nOutX[k]: channels entering level k (nOutX[0] is the input channel
        # count); the list has nLayers+2 entries so the bottom layer can use
        # nOutX[nLayers+1]. dsfX[k]: downsampling factor of level k.
        nOutX = [UNet2D.hp['nChannels'],UNet2D.hp['nOut0']]
        dsfX = []
        for i in range(UNet2D.hp['nLayers']):
            nOutX.append(nOutX[-1]*UNet2D.hp['featMapsFact'])
            dsfX.append(UNet2D.hp['downSampFact'])

        # --------------------------------------------------
        # placeholders and downsampling layer
        # --------------------------------------------------

        with tf.name_scope('placeholders'):
            UNet2D.tfTraining = tf.placeholder(tf.bool, name='training')
            UNet2D.tfData = tf.placeholder("float", shape=[None,UNet2D.hp['imSize'],UNet2D.hp['imSize'],UNet2D.hp['nChannels']],name='data')

        def down_samp_layer(data,index):
            # conv (+ nExtraConvs convs with ReLU in between), plus a 1x1
            # shortcut conv of the input, ReLU, batch norm, then max-pool.
            with tf.name_scope('ld%d' % index):
                ldXWeights1 = tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index], nOutX[index+1]], stddev=stdDev0),name='kernel1')
                ldXWeightsExtra = []
                for i in range(nExtraConvs):
                    ldXWeightsExtra.append(tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index+1], nOutX[index+1]], stddev=stdDev0),name='kernelExtra%d' % i))

                c00 = tf.nn.conv2d(data, ldXWeights1, strides=[1, 1, 1, 1], padding='SAME')
                for i in range(nExtraConvs):
                    c00 = tf.nn.conv2d(tf.nn.relu(c00), ldXWeightsExtra[i], strides=[1, 1, 1, 1], padding='SAME')

                ldXWeightsShortcut = tf.Variable(tf.truncated_normal([1, 1, nOutX[index], nOutX[index+1]], stddev=stdDev0),name='shortcutWeights')
                shortcut = tf.nn.conv2d(data, ldXWeightsShortcut, strides=[1, 1, 1, 1], padding='SAME')

                # batch-norm statistics switch on the tfTraining placeholder
                bn = tf.layers.batch_normalization(tf.nn.relu(c00+shortcut), training=UNet2D.tfTraining)

                return tf.nn.max_pool(bn, ksize=[1, dsfX[index], dsfX[index], 1], strides=[1, dsfX[index], dsfX[index], 1], padding='SAME',name='maxpool')

        # --------------------------------------------------
        # bottom layer
        # --------------------------------------------------

        with tf.name_scope('lb'):
            lbWeights1 = tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[UNet2D.hp['nLayers']], nOutX[UNet2D.hp['nLayers']+1]], stddev=stdDev0),name='kernel1')
            def lb(hidden):
                return tf.nn.relu(tf.nn.conv2d(hidden, lbWeights1, strides=[1, 1, 1, 1], padding='SAME'),name='conv')

        # --------------------------------------------------
        # downsampling
        # --------------------------------------------------

        with tf.name_scope('downsampling'):
            # dsX[k] is the input of downsampling level k (dsX[0] is the raw
            # data); these are reused as skip connections when upsampling.
            dsX = []
            dsX.append(UNet2D.tfData)

            for i in range(UNet2D.hp['nLayers']):
                dsX.append(down_samp_layer(dsX[i],i))

            b = lb(dsX[UNet2D.hp['nLayers']])

        # --------------------------------------------------
        # upsampling layer
        # --------------------------------------------------

        def up_samp_layer(data,index):
            # transposed conv back to level `index` resolution, concat with the
            # skip tensor dsX[index], then conv (+ nExtraConvs convs), all ReLU.
            with tf.name_scope('lu%d' % index):
                luXWeights1 = tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index+1], nOutX[index+2]], stddev=stdDev0),name='kernel1')
                luXWeights2 = tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index]+nOutX[index+1], nOutX[index+1]], stddev=stdDev0),name='kernel2')
                luXWeightsExtra = []
                for i in range(nExtraConvs):
                    luXWeightsExtra.append(tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index+1], nOutX[index+1]], stddev=stdDev0),name='kernel2Extra%d' % i))

                # spatial size at level `index`
                outSize = UNet2D.hp['imSize']
                for i in range(index):
                    outSize /= dsfX[i]
                outSize = int(outSize)

                # conv2d_transpose needs a static output shape, which is why
                # the batch size is fixed for the whole network
                outputShape = [UNet2D.hp['batchSize'],outSize,outSize,nOutX[index+1]]
                us = tf.nn.relu(tf.nn.conv2d_transpose(data, luXWeights1, outputShape, strides=[1, dsfX[index], dsfX[index], 1], padding='SAME'),name='conv1')
                cc = concat3([dsX[index],us])
                cv = tf.nn.relu(tf.nn.conv2d(cc, luXWeights2, strides=[1, 1, 1, 1], padding='SAME'),name='conv2')
                for i in range(nExtraConvs):
                    cv = tf.nn.relu(tf.nn.conv2d(cv, luXWeightsExtra[i], strides=[1, 1, 1, 1], padding='SAME'),name='conv2Extra%d' % i)
                return cv

        # --------------------------------------------------
        # final (top) layer
        # --------------------------------------------------

        with tf.name_scope('lt'):
            ltWeights1 = tf.Variable(tf.truncated_normal([1, 1, nOutX[1], nClasses], stddev=stdDev0),name='kernel')
            def lt(hidden):
                return tf.nn.conv2d(hidden, ltWeights1, strides=[1, 1, 1, 1], padding='SAME',name='conv')

        # --------------------------------------------------
        # upsampling
        # --------------------------------------------------

        with tf.name_scope('upsampling'):
            usX = []
            usX.append(b)

            for i in range(UNet2D.hp['nLayers']):
                usX.append(up_samp_layer(usX[i],UNet2D.hp['nLayers']-1-i))

            t = lt(usX[UNet2D.hp['nLayers']])

        sm = tf.nn.softmax(t,-1)
        UNet2D.nn = sm

    @staticmethod
    def train(imPath,logPath,modelPath,pmPath,nTrain,nValid,nTest,restoreVariables,nSteps,gpuIndex,testPMIndex):
        """Train the network built by setup() on tiles from ``imPath``.

        ``imPath`` must contain ``I%05d_Img.tif`` images and matching
        ``I%05d_Ant.tif`` annotations (pixel value ``c+1`` marks class ``c``)
        for indices ``0 .. nTrain+nValid+nTest-1``. The indices are shuffled
        into train / validation / test splits.

        Args:
            imPath: directory with the image / annotation tiles.
            logPath: TensorBoard log directory; deleted first if it exists.
            modelPath: directory for the checkpoint, dataset mean / st. dev.
                and hyper-parameters (created if missing).
            pmPath: directory that receives test-set probability-map PNGs.
            nTrain, nValid, nTest: number of tiles in each split. nTrain and
                nValid must each be at least the batch size.
            restoreVariables: if true, resume from the checkpoint and reuse
                the saved mean / st. dev.
            nSteps: number of training iterations (one batch each).
            gpuIndex: value exported as CUDA_VISIBLE_DEVICES.
            testPMIndex: class channel written out for the test tiles.
        """
        os.environ['CUDA_VISIBLE_DEVICES']= '%d' % gpuIndex

        outLogPath = logPath
        trainWriterPath = pathjoin(logPath,'Train')
        validWriterPath = pathjoin(logPath,'Valid')
        outModelPath = pathjoin(modelPath,'model.ckpt')
        outPMPath = pmPath

        batchSize = UNet2D.hp['batchSize']
        imSize = UNet2D.hp['imSize']
        nChannels = UNet2D.hp['nChannels']
        nClasses = UNet2D.hp['nClasses']

        # --------------------------------------------------
        # data
        # --------------------------------------------------

        Train = np.zeros((nTrain,imSize,imSize,nChannels))
        Valid = np.zeros((nValid,imSize,imSize,nChannels))
        Test = np.zeros((nTest,imSize,imSize,nChannels))
        LTrain = np.zeros((nTrain,imSize,imSize,nClasses))
        LValid = np.zeros((nValid,imSize,imSize,nClasses))
        LTest = np.zeros((nTest,imSize,imSize,nClasses))

        print('loading data, computing mean / st dev')
        if not os.path.exists(modelPath):
            os.makedirs(modelPath)
        if restoreVariables:
            datasetMean = loadData(pathjoin(modelPath,'datasetMean.data'))
            datasetStDev = loadData(pathjoin(modelPath,'datasetStDev.data'))
        else:
            # average of the per-image means / st. devs. over all tiles
            datasetMean = 0
            datasetStDev = 0
            for iSample in range(nTrain+nValid+nTest):
                I = im2double(tifread('%s/I%05d_Img.tif' % (imPath,iSample)))
                datasetMean += np.mean(I)
                datasetStDev += np.std(I)
            datasetMean /= (nTrain+nValid+nTest)
            datasetStDev /= (nTrain+nValid+nTest)
            saveData(datasetMean, pathjoin(modelPath,'datasetMean.data'))
            saveData(datasetStDev, pathjoin(modelPath,'datasetStDev.data'))

        # random split of the tile indices: [train | valid | test]
        perm = np.arange(nTrain+nValid+nTest)
        np.random.shuffle(perm)

        # only channel 0 of the data arrays is filled; labels are one-hot
        for iSample in range(0, nTrain):
            path = '%s/I%05d_Img.tif' % (imPath,perm[iSample])
            im = im2double(tifread(path))
            Train[iSample,:,:,0] = (im-datasetMean)/datasetStDev
            path = '%s/I%05d_Ant.tif' % (imPath,perm[iSample])
            im = tifread(path)
            for i in range(nClasses):
                LTrain[iSample,:,:,i] = (im == i+1)

        for iSample in range(0, nValid):
            path = '%s/I%05d_Img.tif' % (imPath,perm[nTrain+iSample])
            im = im2double(tifread(path))
            Valid[iSample,:,:,0] = (im-datasetMean)/datasetStDev
            path = '%s/I%05d_Ant.tif' % (imPath,perm[nTrain+iSample])
            im = tifread(path)
            for i in range(nClasses):
                LValid[iSample,:,:,i] = (im == i+1)

        for iSample in range(0, nTest):
            path = '%s/I%05d_Img.tif' % (imPath,perm[nTrain+nValid+iSample])
            im = im2double(tifread(path))
            Test[iSample,:,:,0] = (im-datasetMean)/datasetStDev
            path = '%s/I%05d_Ant.tif' % (imPath,perm[nTrain+nValid+iSample])
            im = tifread(path)
            for i in range(nClasses):
                LTest[iSample,:,:,i] = (im == i+1)

        # --------------------------------------------------
        # optimization
        # --------------------------------------------------

        tfLabels = tf.placeholder("float", shape=[None,imSize,imSize,nClasses],name='labels')

        globalStep = tf.Variable(0,trainable=False)
        learningRate0 = 0.01
        decaySteps = 1000
        decayRate = 0.95
        learningRate = tf.train.exponential_decay(learningRate0,globalStep,decaySteps,decayRate,staircase=True)

        with tf.name_scope('optim'):
            # Clip before the log: a softmax output that underflows to exactly
            # 0 would give 0 * log(0) = NaN and poison loss and gradients.
            clippedNN = tf.clip_by_value(UNet2D.nn, 1e-10, 1.0)
            loss = tf.reduce_mean(-tf.reduce_sum(tf.multiply(tfLabels,tf.log(clippedNN)),3))
            # run the batch-norm moving-average updates with every step
            updateOps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            optimizer = tf.train.MomentumOptimizer(learningRate,0.9)
            with tf.control_dependencies(updateOps):
                optOp = optimizer.minimize(loss,global_step=globalStep)

        with tf.name_scope('eval'):
            # per-class error = 1 - (correctly predicted pixels / labelled
            # pixels); NaN for a class with no labelled pixels in the batch
            error = []
            for iClass in range(nClasses):
                labels0 = tf.reshape(tf.to_int32(tf.slice(tfLabels,[0,0,0,iClass],[-1,-1,-1,1])),[batchSize,imSize,imSize])
                predict0 = tf.reshape(tf.to_int32(tf.equal(tf.argmax(UNet2D.nn,3),iClass)),[batchSize,imSize,imSize])
                correct = tf.multiply(labels0,predict0)
                nCorrect0 = tf.reduce_sum(correct)
                nLabels0 = tf.reduce_sum(labels0)
                error.append(1-tf.to_float(nCorrect0)/tf.to_float(nLabels0))
            errors = tf.tuple(error)

        # --------------------------------------------------
        # inspection
        # --------------------------------------------------

        with tf.name_scope('scalars'):
            tf.summary.scalar('avg_cross_entropy', loss)
            for iClass in range(nClasses):
                tf.summary.scalar('avg_pixel_error_%d' % iClass, error[iClass])
            tf.summary.scalar('learning_rate', learningRate)
        with tf.name_scope('images'):
            split0 = tf.slice(UNet2D.nn,[0,0,0,0],[-1,-1,-1,1])
            split1 = tf.slice(UNet2D.nn,[0,0,0,1],[-1,-1,-1,1])
            if nClasses > 2:
                split2 = tf.slice(UNet2D.nn,[0,0,0,2],[-1,-1,-1,1])
            tf.summary.image('pm0',split0)
            tf.summary.image('pm1',split1)
            if nClasses > 2:
                tf.summary.image('pm2',split2)
        merged = tf.summary.merge_all()

        # --------------------------------------------------
        # session
        # --------------------------------------------------

        saver = tf.train.Saver()
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) # config parameter needed to save variables when using GPU

        if os.path.exists(outLogPath):
            shutil.rmtree(outLogPath)
        trainWriter = tf.summary.FileWriter(trainWriterPath, sess.graph)
        validWriter = tf.summary.FileWriter(validWriterPath, sess.graph)

        try:
            if restoreVariables:
                saver.restore(sess, outModelPath)
                print("Model restored.")
            else:
                sess.run(tf.global_variables_initializer())

            # --------------------------------------------------
            # train
            # --------------------------------------------------

            batchData = np.zeros((batchSize,imSize,imSize,nChannels))
            batchLabels = np.zeros((batchSize,imSize,imSize,nClasses))
            for i in range(nSteps):
                # train: random batch of distinct training tiles

                perm = np.arange(nTrain)
                np.random.shuffle(perm)

                for j in range(batchSize):
                    batchData[j,:,:,:] = Train[perm[j],:,:,:]
                    batchLabels[j,:,:,:] = LTrain[perm[j],:,:,:]

                summary,_ = sess.run([merged,optOp],feed_dict={UNet2D.tfData: batchData, tfLabels: batchLabels, UNet2D.tfTraining: 1})
                trainWriter.add_summary(summary, i)

                # validation: random batch of distinct validation tiles

                perm = np.arange(nValid)
                np.random.shuffle(perm)

                for j in range(batchSize):
                    batchData[j,:,:,:] = Valid[perm[j],:,:,:]
                    batchLabels[j,:,:,:] = LValid[perm[j],:,:,:]

                summary, es = sess.run([merged, errors],feed_dict={UNet2D.tfData: batchData, tfLabels: batchLabels, UNet2D.tfTraining: 0})
                validWriter.add_summary(summary, i)

                # Ignore classes absent from this batch (their error is NaN);
                # a plain mean would make e NaN and block checkpoint saving.
                e = np.nanmean(es)
                print('step %05d, e: %f' % (i,e))

                if i == 0:
                    if restoreVariables:
                        lowestError = e
                    else:
                        lowestError = np.inf

                # checkpoint at most every 100 steps, only on improvement
                if np.mod(i,100) == 0 and e < lowestError:
                    lowestError = e
                    print("Model saved in file: %s" % saver.save(sess, outModelPath))

            # --------------------------------------------------
            # test
            # --------------------------------------------------

            if not os.path.exists(outPMPath):
                os.makedirs(outPMPath)

            for i in range(nTest):
                j = np.mod(i,batchSize)

                batchData[j,:,:,:] = Test[i,:,:,:]
                batchLabels[j,:,:,:] = LTest[i,:,:,:]

                # run once the batch is full or the tiles are exhausted; only
                # the first j+1 rows are valid in a final partial batch
                if j == batchSize-1 or i == nTest-1:

                    output = sess.run(UNet2D.nn,feed_dict={UNet2D.tfData: batchData, tfLabels: batchLabels, UNet2D.tfTraining: 0})

                    for k in range(j+1):
                        pm = output[k,:,:,testPMIndex]
                        gt = batchLabels[k,:,:,testPMIndex]
                        im = np.sqrt(normalize(batchData[k,:,:,0]))
                        # side by side: image | probability map | ground truth
                        imwrite(np.uint8(255*np.concatenate((im,np.concatenate((pm,gt),axis=1)),axis=1)),'%s/I%05d.png' % (outPMPath,i-j+k+1))

            # --------------------------------------------------
            # save hyper-parameters
            # --------------------------------------------------

            saveData(UNet2D.hp,pathjoin(modelPath,'hp.data'))
        finally:
            # clean-up runs even if training or testing raises
            trainWriter.close()
            validWriter.close()
            sess.close()

    @staticmethod
    def deploy(imPath,nImages,modelPath,pmPath,gpuIndex,pmIndex):
        """Run a trained model over ``nImages`` tiles and write PNGs.

        Rebuilds the network from ``modelPath/hp.data`` and restores
        ``modelPath/model.ckpt``. It reads ``imPath/I%05d_Img.tif`` for
        indices ``0 .. nImages-1`` and normalises them with the saved dataset
        mean / st. dev. For each tile it writes ``I%05d_Im.png`` (image) and
        ``I%05d_PM.png`` (class ``pmIndex`` probability) into ``pmPath``,
        numbered from 1.
        """
        os.environ['CUDA_VISIBLE_DEVICES']= '%d' % gpuIndex

        variablesPath = pathjoin(modelPath,'model.ckpt')
        outPMPath = pmPath

        hp = loadData(pathjoin(modelPath,'hp.data'))
        UNet2D.setupWithHP(hp)

        batchSize = UNet2D.hp['batchSize']
        imSize = UNet2D.hp['imSize']
        nChannels = UNet2D.hp['nChannels']
        nClasses = UNet2D.hp['nClasses']

        # --------------------------------------------------
        # data
        # --------------------------------------------------

        Data = np.zeros((nImages,imSize,imSize,nChannels))

        datasetMean = loadData(pathjoin(modelPath,'datasetMean.data'))
        datasetStDev = loadData(pathjoin(modelPath,'datasetStDev.data'))

        for iSample in range(0, nImages):
            path = '%s/I%05d_Img.tif' % (imPath,iSample)
            im = im2double(tifread(path))
            Data[iSample,:,:,0] = (im-datasetMean)/datasetStDev

        # --------------------------------------------------
        # session
        # --------------------------------------------------

        saver = tf.train.Saver()
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) # config parameter needed to save variables when using GPU

        try:
            saver.restore(sess, variablesPath)
            print("Model restored.")

            # --------------------------------------------------
            # deploy
            # --------------------------------------------------

            batchData = np.zeros((batchSize,imSize,imSize,nChannels))

            if not os.path.exists(outPMPath):
                os.makedirs(outPMPath)

            for i in range(nImages):
                print(i,nImages)

                j = np.mod(i,batchSize)

                batchData[j,:,:,:] = Data[i,:,:,:]

                # the network only accepts full batches; rows past j in a
                # final partial batch are stale and ignored below
                if j == batchSize-1 or i == nImages-1:

                    output = sess.run(UNet2D.nn,feed_dict={UNet2D.tfData: batchData, UNet2D.tfTraining: 0})

                    for k in range(j+1):
                        pm = output[k,:,:,pmIndex]
                        im = np.sqrt(normalize(batchData[k,:,:,0]))
                        imwrite(np.uint8(255*im),'%s/I%05d_Im.png' % (outPMPath,i-j+k+1))
                        imwrite(np.uint8(255*pm),'%s/I%05d_PM.png' % (outPMPath,i-j+k+1))
        finally:
            # clean-up runs even if inference raises
            sess.close()

    @staticmethod
    def singleImageInferenceSetup(modelPath,gpuIndex):
        """Rebuild the trained network and open UNet2D.Session on it.

        Loads ``hp.data``, ``datasetMean.data`` and ``datasetStDev.data``
        from ``modelPath`` and restores ``model.ckpt``. ``gpuIndex`` is
        currently ignored: the CUDA_VISIBLE_DEVICES export is disabled here.
        Pair with singleImageInferenceCleanup().
        """
        variablesPath = pathjoin(modelPath,'model.ckpt')

        hp = loadData(pathjoin(modelPath,'hp.data'))
        UNet2D.setupWithHP(hp)

        UNet2D.DatasetMean = loadData(pathjoin(modelPath,'datasetMean.data'))
        UNet2D.DatasetStDev = loadData(pathjoin(modelPath,'datasetStDev.data'))
        print(UNet2D.DatasetMean)
        print(UNet2D.DatasetStDev)

        # --------------------------------------------------
        # session
        # --------------------------------------------------

        saver = tf.train.Saver()
        UNet2D.Session = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) # config parameter needed to save variables when using GPU

        saver.restore(UNet2D.Session, variablesPath)
        print("Model restored.")

    @staticmethod
    def singleImageInferenceCleanup():
        """Close the session opened by singleImageInferenceSetup()."""
        UNet2D.Session.close()

    @staticmethod
    def singleImageInference(image,mode,pmIndex):
        """Return the class-``pmIndex`` probability map for a whole image.

        The image is tiled by PI2D into imSize x imSize patches; each patch is
        normalised with the saved dataset mean / st. dev. and run through the
        network in batches. PI2D then reassembles the patch outputs.
        ``mode`` is passed through to PI2D.setup (e.g. 'accumulate'); its
        semantics are defined by PartitionOfImage.
        """
        print('Inference...')

        batchSize = UNet2D.hp['batchSize']
        imSize = UNet2D.hp['imSize']
        nChannels = UNet2D.hp['nChannels']

        # third argument is imSize/8 -- presumably the patch overlap margin;
        # confirm against PartitionOfImage.PI2D.setup
        PI2D.setup(image,imSize,int(imSize/8),mode)
        PI2D.createOutput(nChannels)

        batchData = np.zeros((batchSize,imSize,imSize,nChannels))
        for i in range(PI2D.NumPatches):
            j = np.mod(i,batchSize)
            batchData[j,:,:,0] = (PI2D.getPatch(i)-UNet2D.DatasetMean)/UNet2D.DatasetStDev
            if j == batchSize-1 or i == PI2D.NumPatches-1:
                output = UNet2D.Session.run(UNet2D.nn,feed_dict={UNet2D.tfData: batchData, UNet2D.tfTraining: 0})
                for k in range(j+1):
                    pm = output[k,:,:,pmIndex]
                    PI2D.patchOutput(i-j+k,pm)

        return PI2D.getValidOutput()
505 | |
506 | |
if __name__ == '__main__':
    # Hard-coded paths from the original author's machine; edit before running.
    logPath = 'C://Users//Clarence//Documents//UNet code//TFLogs'
    modelPath = 'D:\\LSP\\UNet\\tonsil20x1bin1chan\\TFModel - 3class 16 kernels 5ks 2 layers'
    pmPath = 'C://Users//Clarence//Documents//UNet code//TFProbMaps'

    # NOTE(review): `tifffile` and `resize` are not imported explicitly in the
    # visible imports; presumably they arrive via `from toolbox.imtools import *`
    # -- confirm.
    UNet2D.singleImageInferenceSetup(modelPath, 0)
    try:
        imagePath = 'D:\\LSP\\cycif\\testsets'
        sampleList = glob.glob(imagePath + '//exemplar-001*')
        dapiChannel = 0
        dsFactor = 1
        for iSample in sampleList:
            fileList = glob.glob(iSample + '//registration//*.tif')
            print(fileList)
            for iFile in fileList:
                fileName = os.path.basename(iFile)
                fileNamePrefix = fileName.split(os.extsep, 1)
                # read TIFF page `dapiChannel` (the DAPI plane, per the name)
                I = tifffile.imread(iFile, key=dapiChannel)
                rawI = I
                # resize by dsFactor, then stretch intensities to [0, 0.983]
                hsize = int((float(I.shape[0])*float(dsFactor)))
                vsize = int((float(I.shape[1])*float(dsFactor)))
                I = resize(I,(hsize,vsize))
                I = im2double(sk.rescale_intensity(I, in_range=(np.min(I), np.max(I)), out_range=(0, 0.983)))
                # raw plane scaled to [0, 1]; an all-zero plane stays zero
                # instead of turning into NaN
                rawI = im2double(rawI)
                rawMax = np.max(rawI)
                if rawMax > 0:
                    rawI = rawI/rawMax
                outputPath = iSample + '//prob_maps'
                if not os.path.exists(outputPath):
                    os.makedirs(outputPath)

                # class-1 (contours) probability map stacked with the raw image,
                # resized back to the raw image size
                K = np.zeros((2,rawI.shape[0],rawI.shape[1]))
                contours = UNet2D.singleImageInference(I,'accumulate',1)
                contours = resize(contours, (rawI.shape[0], rawI.shape[1]))
                K[1,:,:] = rawI
                K[0,:,:] = contours
                tifwrite(np.uint8(255 * K),
                         outputPath + '//' + fileNamePrefix[0] + '_ContoursPM_' + str(dapiChannel + 1) + '.tif')
                del K

                # class-2 (nuclei) probability map
                K = np.zeros((1, rawI.shape[0], rawI.shape[1]))
                nuclei = UNet2D.singleImageInference(I,'accumulate',2)
                nuclei = resize(nuclei, (rawI.shape[0], rawI.shape[1]))
                K[0, :, :] = nuclei
                tifwrite(np.uint8(255 * K),
                         outputPath + '//' + fileNamePrefix[0] + '_NucleiPM_' + str(dapiChannel + 1) + '.tif')
                del K
    finally:
        # release the inference session even if a file fails to process
        UNet2D.singleImageInferenceCleanup()
553 |