Mercurial > repos > perssond > unmicst
comparison batchUnMicst.py @ 0:6bec4fef6b2e draft
"planemo upload for repository https://github.com/ohsu-comp-bio/unmicst commit 73e4cae15f2d7cdc86719e77470eb00af4b6ebb7-dirty"
author | perssond |
---|---|
date | Fri, 12 Mar 2021 00:17:29 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:6bec4fef6b2e |
---|---|
1 import numpy as np | |
2 from scipy import misc | |
3 import tensorflow as tf | |
4 import shutil | |
5 import scipy.io as sio | |
6 import os, fnmatch, PIL, glob | |
7 import skimage.exposure as sk | |
8 import argparse | |
9 | |
10 import sys | |
11 | |
12 sys.path.insert(0, 'C:\\Users\\Public\\Documents\\ImageScience') | |
13 from toolbox.imtools import * | |
14 from toolbox.ftools import * | |
15 from toolbox.PartitionOfImage import PI2D | |
16 | |
17 | |
def concat3(lst):
    """Concatenate a list of tensors along axis 3.

    In this file axis 3 is the channel axis of the [batch, height, width, channels]
    tensors fed through the network (see the 'data' placeholder in UNet2D.setup).
    """
    return tf.concat(values=lst, axis=3)
20 | |
21 | |
class UNet2D:
    """2-D U-Net built with the TensorFlow 1.x graph API, with train / deploy / inference helpers.

    All state lives in class attributes, and every method is a static method invoked on the
    class itself (e.g. ``UNet2D.setup(...)``). Nothing in this file instantiates the class.
    """

    hp = None  # hyper-parameters (dict written by setup)
    nn = None  # network output: softmax of the final layer over the last (class) axis
    tfTraining = None  # bool placeholder: training or not (to handle batch norm)
    tfData = None  # input data placeholder
    Session = None  # session opened by singleImageInferenceSetup
    DatasetMean = 0  # normalization mean loaded by singleImageInferenceSetup
    DatasetStDev = 0  # normalization st dev loaded by singleImageInferenceSetup

    @staticmethod
    def setupWithHP(hp):
        """Build the network from a hyper-parameter dict, such as the one ``train`` saves to 'hp.data'.

        Args:
            hp: dict with the keys produced by ``setup``: 'imSize', 'nChannels', 'nClasses',
                'nOut0', 'featMapsFact', 'downSampFact', 'ks', 'nExtraConvs', 'stdDev0',
                'nLayers' and 'batchSize'.
        """
        UNet2D.setup(hp['imSize'],
                     hp['nChannels'],
                     hp['nClasses'],
                     hp['nOut0'],
                     hp['featMapsFact'],
                     hp['downSampFact'],
                     hp['ks'],
                     hp['nExtraConvs'],
                     hp['stdDev0'],
                     hp['nLayers'],
                     hp['batchSize'])

    @staticmethod
    def setup(imSize, nChannels, nClasses, nOut0, featMapsFact, downSampFact, kernelSize, nExtraConvs, stdDev0,
              nDownSampLayers, batchSize):
        """Build the U-Net graph in the default TensorFlow graph.

        Sets ``UNet2D.hp``, the ``tfTraining`` / ``tfData`` placeholders and ``UNet2D.nn``.
        The network output is a per-pixel softmax over ``nClasses``.

        Architecture:
        - Each downsampling level applies a ks x ks conv and ``nExtraConvs`` further convs
          (each preceded by a relu).
        - A 1x1 shortcut conv of the level input is added to that result.
        - The sum goes through relu, batch norm, and max-pooling by ``downSampFact``.
        - A relu conv sits at the bottom.
        - Each upsampling level applies a relu transposed conv, concatenates the result with
          the matching downsampling input, then applies 1 + ``nExtraConvs`` relu convs.
        - A final 1x1 conv maps to ``nClasses``.
        - All weights are drawn from a truncated normal with stddev ``stdDev0``.

        The transposed convolutions get an explicit output shape whose batch dimension is
        ``batchSize``, so the graph must be fed batches of exactly that size.

        Args:
            imSize: side length of the square input patches.
            nChannels: number of input channels.
            nClasses: number of output classes.
            nOut0: number of feature maps at the first level.
            featMapsFact: factor by which the feature-map count grows per level.
            downSampFact: pooling / upsampling factor per level.
            kernelSize: convolution kernel size (stored as hp['ks']).
            nExtraConvs: number of additional convolutions per level.
            stdDev0: stddev of the truncated-normal weight initialization.
            nDownSampLayers: number of down/up-sampling levels (stored as hp['nLayers']).
            batchSize: fixed batch size of the graph.
        """
        UNet2D.hp = {'imSize': imSize,
                     'nClasses': nClasses,
                     'nChannels': nChannels,
                     'nExtraConvs': nExtraConvs,
                     'nLayers': nDownSampLayers,
                     'featMapsFact': featMapsFact,
                     'downSampFact': downSampFact,
                     'ks': kernelSize,
                     'nOut0': nOut0,
                     'stdDev0': stdDev0,
                     'batchSize': batchSize}

        # nOutX[k] is the channel count entering level k: [nChannels, nOut0, nOut0*f, nOut0*f^2, ...]
        nOutX = [UNet2D.hp['nChannels'], UNet2D.hp['nOut0']]
        dsfX = []
        for i in range(UNet2D.hp['nLayers']):
            nOutX.append(nOutX[-1] * UNet2D.hp['featMapsFact'])
            dsfX.append(UNet2D.hp['downSampFact'])

        # --------------------------------------------------
        # downsampling layer
        # --------------------------------------------------

        with tf.name_scope('placeholders'):
            UNet2D.tfTraining = tf.placeholder(tf.bool, name='training')
            UNet2D.tfData = tf.placeholder("float", shape=[None, UNet2D.hp['imSize'], UNet2D.hp['imSize'],
                                                           UNet2D.hp['nChannels']], name='data')

        def down_samp_layer(data, index):
            with tf.name_scope('ld%d' % index):
                ldXWeights1 = tf.Variable(
                    tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index], nOutX[index + 1]],
                                        stddev=stdDev0), name='kernel1')
                ldXWeightsExtra = []
                for i in range(nExtraConvs):
                    ldXWeightsExtra.append(tf.Variable(
                        tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index + 1], nOutX[index + 1]],
                                            stddev=stdDev0), name='kernelExtra%d' % i))

                c00 = tf.nn.conv2d(data, ldXWeights1, strides=[1, 1, 1, 1], padding='SAME')
                for i in range(nExtraConvs):
                    c00 = tf.nn.conv2d(tf.nn.relu(c00), ldXWeightsExtra[i], strides=[1, 1, 1, 1], padding='SAME')

                # 1x1 conv so the level input can be added to the conv stack (residual shortcut)
                ldXWeightsShortcut = tf.Variable(
                    tf.truncated_normal([1, 1, nOutX[index], nOutX[index + 1]], stddev=stdDev0),
                    name='shortcutWeights')
                shortcut = tf.nn.conv2d(data, ldXWeightsShortcut, strides=[1, 1, 1, 1], padding='SAME')

                bn = tf.layers.batch_normalization(tf.nn.relu(c00 + shortcut), training=UNet2D.tfTraining)

                return tf.nn.max_pool(bn, ksize=[1, dsfX[index], dsfX[index], 1],
                                      strides=[1, dsfX[index], dsfX[index], 1], padding='SAME', name='maxpool')

        # --------------------------------------------------
        # bottom layer
        # --------------------------------------------------

        with tf.name_scope('lb'):
            lbWeights1 = tf.Variable(tf.truncated_normal(
                [UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[UNet2D.hp['nLayers']], nOutX[UNet2D.hp['nLayers'] + 1]],
                stddev=stdDev0), name='kernel1')

            def lb(hidden):
                return tf.nn.relu(tf.nn.conv2d(hidden, lbWeights1, strides=[1, 1, 1, 1], padding='SAME'),
                                  name='conv')

        # --------------------------------------------------
        # downsampling
        # --------------------------------------------------

        with tf.name_scope('downsampling'):
            # dsX[k] is the input of level k; these are also the skip connections
            dsX = []
            dsX.append(UNet2D.tfData)

            for i in range(UNet2D.hp['nLayers']):
                dsX.append(down_samp_layer(dsX[i], i))

        b = lb(dsX[UNet2D.hp['nLayers']])

        # --------------------------------------------------
        # upsampling layer
        # --------------------------------------------------

        def up_samp_layer(data, index):
            with tf.name_scope('lu%d' % index):
                luXWeights1 = tf.Variable(
                    tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index + 1], nOutX[index + 2]],
                                        stddev=stdDev0), name='kernel1')
                luXWeights2 = tf.Variable(tf.truncated_normal(
                    [UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index] + nOutX[index + 1], nOutX[index + 1]],
                    stddev=stdDev0), name='kernel2')
                luXWeightsExtra = []
                for i in range(nExtraConvs):
                    luXWeightsExtra.append(tf.Variable(
                        tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index + 1], nOutX[index + 1]],
                                            stddev=stdDev0), name='kernel2Extra%d' % i))

                # spatial size at this level: imSize divided by the factors of all coarser-to-finer steps above it
                outSize = UNet2D.hp['imSize']
                for i in range(index):
                    outSize /= dsfX[i]
                outSize = int(outSize)

                outputShape = [UNet2D.hp['batchSize'], outSize, outSize, nOutX[index + 1]]
                us = tf.nn.relu(
                    tf.nn.conv2d_transpose(data, luXWeights1, outputShape, strides=[1, dsfX[index], dsfX[index], 1],
                                           padding='SAME'), name='conv1')
                cc = concat3([dsX[index], us])
                cv = tf.nn.relu(tf.nn.conv2d(cc, luXWeights2, strides=[1, 1, 1, 1], padding='SAME'), name='conv2')
                for i in range(nExtraConvs):
                    cv = tf.nn.relu(tf.nn.conv2d(cv, luXWeightsExtra[i], strides=[1, 1, 1, 1], padding='SAME'),
                                    name='conv2Extra%d' % i)
                return cv

        # --------------------------------------------------
        # final (top) layer
        # --------------------------------------------------

        with tf.name_scope('lt'):
            ltWeights1 = tf.Variable(tf.truncated_normal([1, 1, nOutX[1], nClasses], stddev=stdDev0), name='kernel')

            def lt(hidden):
                return tf.nn.conv2d(hidden, ltWeights1, strides=[1, 1, 1, 1], padding='SAME', name='conv')

        # --------------------------------------------------
        # upsampling
        # --------------------------------------------------

        with tf.name_scope('upsampling'):
            usX = []
            usX.append(b)

            for i in range(UNet2D.hp['nLayers']):
                usX.append(up_samp_layer(usX[i], UNet2D.hp['nLayers'] - 1 - i))

        t = lt(usX[UNet2D.hp['nLayers']])

        sm = tf.nn.softmax(t, -1)
        UNet2D.nn = sm

    @staticmethod
    def _loadSamples(imPath, indices, datasetMean, datasetStDev, imSize, nChannels, nClasses):
        """Load and normalize the image / annotation pairs for the given sample indices.

        For each index the image '<imPath>/I%05d_Img.tif' is normalized with the dataset mean /
        st dev into channel 0 of ``data``. The annotation '<imPath>/I%05d_Ant.tif' is turned into
        one mask per class in ``labels``, where annotation value ``i + 1`` marks class ``i``.

        Returns:
            (data, labels) with shapes (len(indices), imSize, imSize, nChannels) and
            (len(indices), imSize, imSize, nClasses).
        """
        data = np.zeros((len(indices), imSize, imSize, nChannels))
        labels = np.zeros((len(indices), imSize, imSize, nClasses))
        for iSample, index in enumerate(indices):
            im = im2double(tifread('%s/I%05d_Img.tif' % (imPath, index)))
            data[iSample, :, :, 0] = (im - datasetMean) / datasetStDev
            im = tifread('%s/I%05d_Ant.tif' % (imPath, index))
            for i in range(nClasses):
                labels[iSample, :, :, i] = (im == i + 1)
        return data, labels

    @staticmethod
    def train(imPath, logPath, modelPath, pmPath, nTrain, nValid, nTest, restoreVariables, nSteps, gpuIndex,
              testPMIndex):
        """Train the network built by ``setup``; write checkpoints, summaries and test previews.

        Sample indices 0 .. nTrain+nValid+nTest-1 are shuffled and split into training,
        validation and test sets (see ``_loadSamples`` for the file layout). Each step trains on
        one random training batch and scores one random validation batch. Every 100 steps the
        checkpoint is saved if the mean per-class validation error improved.

        Args:
            imPath: folder with the 'I%05d_Img.tif' / 'I%05d_Ant.tif' pairs.
            logPath: TensorBoard log folder; any existing folder at this path is deleted first.
            modelPath: folder for 'model.ckpt', 'datasetMean.data', 'datasetStDev.data' and 'hp.data'.
            pmPath: folder receiving the 'I%05d.png' test previews (input | prediction | ground truth).
            nTrain, nValid, nTest: number of samples in each split.
            restoreVariables: if true, resume from the checkpoint and the stored mean / st dev in
                ``modelPath`` instead of initializing and recomputing them.
            nSteps: number of optimization steps.
            gpuIndex: value written to CUDA_VISIBLE_DEVICES.
            testPMIndex: class channel shown in the test previews.

        NOTE(review): nTrain and nValid are assumed to be >= batchSize; otherwise the batch
        fill below raises IndexError.
        """
        os.environ['CUDA_VISIBLE_DEVICES'] = '%d' % gpuIndex

        outLogPath = logPath
        trainWriterPath = pathjoin(logPath, 'Train')
        validWriterPath = pathjoin(logPath, 'Valid')
        outModelPath = pathjoin(modelPath, 'model.ckpt')
        outPMPath = pmPath

        batchSize = UNet2D.hp['batchSize']
        imSize = UNet2D.hp['imSize']
        nChannels = UNet2D.hp['nChannels']
        nClasses = UNet2D.hp['nClasses']

        # --------------------------------------------------
        # data
        # --------------------------------------------------

        print('loading data, computing mean / st dev')
        if not os.path.exists(modelPath):
            os.makedirs(modelPath)
        if restoreVariables:
            datasetMean = loadData(pathjoin(modelPath, 'datasetMean.data'))
            datasetStDev = loadData(pathjoin(modelPath, 'datasetStDev.data'))
        else:
            # average of the per-image mean / st dev over every sample (all three splits)
            datasetMean = 0
            datasetStDev = 0
            for iSample in range(nTrain + nValid + nTest):
                I = im2double(tifread('%s/I%05d_Img.tif' % (imPath, iSample)))
                datasetMean += np.mean(I)
                datasetStDev += np.std(I)
            datasetMean /= (nTrain + nValid + nTest)
            datasetStDev /= (nTrain + nValid + nTest)
            saveData(datasetMean, pathjoin(modelPath, 'datasetMean.data'))
            saveData(datasetStDev, pathjoin(modelPath, 'datasetStDev.data'))

        perm = np.arange(nTrain + nValid + nTest)
        np.random.shuffle(perm)

        Train, LTrain = UNet2D._loadSamples(imPath, perm[:nTrain], datasetMean, datasetStDev,
                                            imSize, nChannels, nClasses)
        Valid, LValid = UNet2D._loadSamples(imPath, perm[nTrain:nTrain + nValid], datasetMean, datasetStDev,
                                            imSize, nChannels, nClasses)
        Test, LTest = UNet2D._loadSamples(imPath, perm[nTrain + nValid:nTrain + nValid + nTest], datasetMean,
                                          datasetStDev, imSize, nChannels, nClasses)

        # --------------------------------------------------
        # optimization
        # --------------------------------------------------

        tfLabels = tf.placeholder("float", shape=[None, imSize, imSize, nClasses], name='labels')

        globalStep = tf.Variable(0, trainable=False)
        learningRate0 = 0.01
        decaySteps = 1000
        decayRate = 0.95
        learningRate = tf.train.exponential_decay(learningRate0, globalStep, decaySteps, decayRate, staircase=True)

        with tf.name_scope('optim'):
            loss = tf.reduce_mean(-tf.reduce_sum(tf.multiply(tfLabels, tf.log(UNet2D.nn)), 3))
            # batch-norm moving statistics are updated through UPDATE_OPS, so run them with each step
            updateOps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            optimizer = tf.train.MomentumOptimizer(learningRate, 0.9)
            with tf.control_dependencies(updateOps):
                optOp = optimizer.minimize(loss, global_step=globalStep)

        with tf.name_scope('eval'):
            # per-class error: 1 - (correctly predicted pixels of the class) / (labelled pixels of the class)
            error = []
            for iClass in range(nClasses):
                labels0 = tf.reshape(tf.to_int32(tf.slice(tfLabels, [0, 0, 0, iClass], [-1, -1, -1, 1])),
                                     [batchSize, imSize, imSize])
                predict0 = tf.reshape(tf.to_int32(tf.equal(tf.argmax(UNet2D.nn, 3), iClass)),
                                      [batchSize, imSize, imSize])
                correct = tf.multiply(labels0, predict0)
                nCorrect0 = tf.reduce_sum(correct)
                nLabels0 = tf.reduce_sum(labels0)
                error.append(1 - tf.to_float(nCorrect0) / tf.to_float(nLabels0))
            errors = tf.tuple(error)

        # --------------------------------------------------
        # inspection
        # --------------------------------------------------

        with tf.name_scope('scalars'):
            tf.summary.scalar('avg_cross_entropy', loss)
            for iClass in range(nClasses):
                tf.summary.scalar('avg_pixel_error_%d' % iClass, error[iClass])
            tf.summary.scalar('learning_rate', learningRate)
        with tf.name_scope('images'):
            split0 = tf.slice(UNet2D.nn, [0, 0, 0, 0], [-1, -1, -1, 1])
            split1 = tf.slice(UNet2D.nn, [0, 0, 0, 1], [-1, -1, -1, 1])
            if nClasses > 2:
                split2 = tf.slice(UNet2D.nn, [0, 0, 0, 2], [-1, -1, -1, 1])
            tf.summary.image('pm0', split0)
            tf.summary.image('pm1', split1)
            if nClasses > 2:
                tf.summary.image('pm2', split2)
        merged = tf.summary.merge_all()

        # --------------------------------------------------
        # session
        # --------------------------------------------------

        saver = tf.train.Saver()
        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True))  # config parameter needed to save variables when using GPU
        trainWriter = None
        validWriter = None
        try:
            if os.path.exists(outLogPath):
                shutil.rmtree(outLogPath)
            trainWriter = tf.summary.FileWriter(trainWriterPath, sess.graph)
            validWriter = tf.summary.FileWriter(validWriterPath, sess.graph)

            if restoreVariables:
                saver.restore(sess, outModelPath)
                print("Model restored.")
            else:
                sess.run(tf.global_variables_initializer())

            # --------------------------------------------------
            # train
            # --------------------------------------------------

            batchData = np.zeros((batchSize, imSize, imSize, nChannels))
            batchLabels = np.zeros((batchSize, imSize, imSize, nClasses))
            for i in range(nSteps):
                # train

                perm = np.arange(nTrain)
                np.random.shuffle(perm)

                for j in range(batchSize):
                    batchData[j, :, :, :] = Train[perm[j], :, :, :]
                    batchLabels[j, :, :, :] = LTrain[perm[j], :, :, :]

                summary, _ = sess.run([merged, optOp],
                                      feed_dict={UNet2D.tfData: batchData, tfLabels: batchLabels,
                                                 UNet2D.tfTraining: 1})
                trainWriter.add_summary(summary, i)

                # validation

                perm = np.arange(nValid)
                np.random.shuffle(perm)

                for j in range(batchSize):
                    batchData[j, :, :, :] = Valid[perm[j], :, :, :]
                    batchLabels[j, :, :, :] = LValid[perm[j], :, :, :]

                summary, es = sess.run([merged, errors],
                                       feed_dict={UNet2D.tfData: batchData, tfLabels: batchLabels,
                                                  UNet2D.tfTraining: 0})
                validWriter.add_summary(summary, i)

                # A class with no labelled pixels in this batch has error 0/0 = NaN. np.mean would
                # turn e into NaN, and since NaN < lowestError is always False the checkpoint would
                # silently be skipped. Average only over the classes that are present.
                e = np.nanmean(es)
                print('step %05d, e: %f' % (i, e))

                if i == 0:
                    if restoreVariables:
                        lowestError = e
                    else:
                        lowestError = np.inf

                if np.mod(i, 100) == 0 and e < lowestError:
                    lowestError = e
                    print("Model saved in file: %s" % saver.save(sess, outModelPath))

            # --------------------------------------------------
            # test
            # --------------------------------------------------

            if not os.path.exists(outPMPath):
                os.makedirs(outPMPath)

            for i in range(nTest):
                j = np.mod(i, batchSize)

                batchData[j, :, :, :] = Test[i, :, :, :]
                batchLabels[j, :, :, :] = LTest[i, :, :, :]

                # run when the batch is full or the test set is exhausted; only rows 0..j are
                # fresh, so only those are written out
                if j == batchSize - 1 or i == nTest - 1:

                    output = sess.run(UNet2D.nn,
                                      feed_dict={UNet2D.tfData: batchData, tfLabels: batchLabels,
                                                 UNet2D.tfTraining: 0})

                    for k in range(j + 1):
                        pm = output[k, :, :, testPMIndex]
                        gt = batchLabels[k, :, :, testPMIndex]
                        im = np.sqrt(normalize(batchData[k, :, :, 0]))
                        imwrite(np.uint8(255 * np.concatenate((im, np.concatenate((pm, gt), axis=1)), axis=1)),
                                '%s/I%05d.png' % (outPMPath, i - j + k + 1))

            # --------------------------------------------------
            # save hyper-parameters
            # --------------------------------------------------

            saveData(UNet2D.hp, pathjoin(modelPath, 'hp.data'))
        finally:
            # clean-up: release the writers and the session even if training fails part-way
            for writer in (trainWriter, validWriter):
                if writer is not None:
                    writer.close()
            sess.close()

    @staticmethod
    def deploy(imPath, nImages, modelPath, pmPath, gpuIndex, pmIndex):
        """Run a trained model on '<imPath>/I%05d_Img.tif' for indices 0 .. nImages-1.

        Rebuilds the network from 'hp.data' in ``modelPath``, restores 'model.ckpt', and
        normalizes inputs with the stored dataset mean / st dev. For each image (numbered from 1)
        it writes two files to ``pmPath``:
        - 'I%05d_Im.png': square root of the toolbox-normalized input.
        - 'I%05d_PM.png': probability map of class ``pmIndex``.
        """
        os.environ['CUDA_VISIBLE_DEVICES'] = '%d' % gpuIndex

        variablesPath = pathjoin(modelPath, 'model.ckpt')
        outPMPath = pmPath

        hp = loadData(pathjoin(modelPath, 'hp.data'))
        UNet2D.setupWithHP(hp)

        batchSize = UNet2D.hp['batchSize']
        imSize = UNet2D.hp['imSize']
        nChannels = UNet2D.hp['nChannels']

        # --------------------------------------------------
        # data
        # --------------------------------------------------

        Data = np.zeros((nImages, imSize, imSize, nChannels))

        datasetMean = loadData(pathjoin(modelPath, 'datasetMean.data'))
        datasetStDev = loadData(pathjoin(modelPath, 'datasetStDev.data'))

        for iSample in range(0, nImages):
            path = '%s/I%05d_Img.tif' % (imPath, iSample)
            im = im2double(tifread(path))
            Data[iSample, :, :, 0] = (im - datasetMean) / datasetStDev

        # --------------------------------------------------
        # session
        # --------------------------------------------------

        saver = tf.train.Saver()
        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True))  # config parameter needed to save variables when using GPU
        try:
            saver.restore(sess, variablesPath)
            print("Model restored.")

            # --------------------------------------------------
            # deploy
            # --------------------------------------------------

            batchData = np.zeros((batchSize, imSize, imSize, nChannels))

            if not os.path.exists(outPMPath):
                os.makedirs(outPMPath)

            for i in range(nImages):
                print(i, nImages)

                j = np.mod(i, batchSize)

                batchData[j, :, :, :] = Data[i, :, :, :]

                # run when the batch is full or the images are exhausted; only rows 0..j are fresh
                if j == batchSize - 1 or i == nImages - 1:

                    output = sess.run(UNet2D.nn, feed_dict={UNet2D.tfData: batchData, UNet2D.tfTraining: 0})

                    for k in range(j + 1):
                        pm = output[k, :, :, pmIndex]
                        im = np.sqrt(normalize(batchData[k, :, :, 0]))
                        imwrite(np.uint8(255 * im), '%s/I%05d_Im.png' % (outPMPath, i - j + k + 1))
                        imwrite(np.uint8(255 * pm), '%s/I%05d_PM.png' % (outPMPath, i - j + k + 1))
        finally:
            # clean-up: close the session even if inference fails part-way
            sess.close()

    @staticmethod
    def singleImageInferenceSetup(modelPath, gpuIndex):
        """Prepare ``singleImageInference``.

        Rebuilds the network from 'hp.data' in ``modelPath`` and loads the stored dataset
        mean / st dev into the class attributes. It then opens ``UNet2D.Session`` and restores
        'model.ckpt' into it. Call ``singleImageInferenceCleanup`` when done.
        """
        os.environ['CUDA_VISIBLE_DEVICES'] = '%d' % gpuIndex

        variablesPath = pathjoin(modelPath, 'model.ckpt')

        hp = loadData(pathjoin(modelPath, 'hp.data'))
        UNet2D.setupWithHP(hp)

        UNet2D.DatasetMean = loadData(pathjoin(modelPath, 'datasetMean.data'))
        UNet2D.DatasetStDev = loadData(pathjoin(modelPath, 'datasetStDev.data'))
        print(UNet2D.DatasetMean)
        print(UNet2D.DatasetStDev)

        # --------------------------------------------------
        # session
        # --------------------------------------------------

        saver = tf.train.Saver()
        UNet2D.Session = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True))  # config parameter needed to save variables when using GPU

        saver.restore(UNet2D.Session, variablesPath)
        print("Model restored.")

    @staticmethod
    def singleImageInferenceCleanup():
        """Close the session opened by ``singleImageInferenceSetup``.

        Safe to call more than once or without a prior setup. A missing session is ignored,
        and the reference is cleared after closing.
        """
        if UNet2D.Session is not None:
            UNet2D.Session.close()
            UNet2D.Session = None

    @staticmethod
    def singleImageInference(image, mode, pmIndex):
        """Compute the class-``pmIndex`` probability map of a whole image, patch by patch.

        Requires ``singleImageInferenceSetup`` to have been called. ``PI2D`` splits ``image``
        into imSize x imSize patches. ``mode`` and int(imSize / 8) are passed to ``PI2D.setup``;
        the latter is presumably the patch margin/overlap (see toolbox.PartitionOfImage).
        Patches are normalized with the stored dataset mean / st dev and run through the network
        in batches of ``batchSize``.

        Returns:
            the result of ``PI2D.getValidOutput()``. This is presumably the stitched map at the
            input image's size (TODO confirm in toolbox.PartitionOfImage).
        """
        print('Inference...')

        batchSize = UNet2D.hp['batchSize']
        imSize = UNet2D.hp['imSize']
        nChannels = UNet2D.hp['nChannels']

        PI2D.setup(image, imSize, int(imSize / 8), mode)
        PI2D.createOutput(nChannels)

        batchData = np.zeros((batchSize, imSize, imSize, nChannels))
        for i in range(PI2D.NumPatches):
            j = np.mod(i, batchSize)
            batchData[j, :, :, 0] = (PI2D.getPatch(i) - UNet2D.DatasetMean) / UNet2D.DatasetStDev
            # run when the batch is full or the patches are exhausted; only rows 0..j are fresh
            if j == batchSize - 1 or i == PI2D.NumPatches - 1:
                output = UNet2D.Session.run(UNet2D.nn, feed_dict={UNet2D.tfData: batchData, UNet2D.tfTraining: 0})
                for k in range(j + 1):
                    pm = output[k, :, :, pmIndex]
                    PI2D.patchOutput(i - j + k, pm)

        return PI2D.getValidOutput()
516 PI2D.setup(image, imSize, int(imSize / 8), mode) | |
517 PI2D.createOutput(nChannels) | |
518 | |
519 batchData = np.zeros((batchSize, imSize, imSize, nChannels)) | |
520 for i in range(PI2D.NumPatches): | |
521 j = np.mod(i, batchSize) | |
522 batchData[j, :, :, 0] = (PI2D.getPatch(i) - UNet2D.DatasetMean) / UNet2D.DatasetStDev | |
523 if j == batchSize - 1 or i == PI2D.NumPatches - 1: | |
524 output = UNet2D.Session.run(UNet2D.nn, feed_dict={UNet2D.tfData: batchData, UNet2D.tfTraining: 0}) | |
525 for k in range(j + 1): | |
526 pm = output[k, :, :, pmIndex] | |
527 PI2D.patchOutput(i - j + k, pm) | |
528 # PI2D.patchOutput(i-j+k,normalize(imgradmag(PI2D.getPatch(i-j+k),1))) | |
529 | |
530 return PI2D.getValidOutput() | |
531 | |
532 | |
def scaleToUnitMax(image):
    """Return ``image`` divided by its maximum, so the brightest pixel becomes 1.

    An image whose maximum is 0 is returned as an all-zero float array. This avoids the
    NaN / inf that a division by zero would produce.
    """
    image = np.asarray(image)
    peak = np.max(image)
    if peak == 0:
        return np.zeros(image.shape)
    return image / peak


if __name__ == '__main__':
    # For every '<imagePath>/exemplar*' sample, write contour and nuclei probability maps of one
    # channel of each registered (or, with --TMA, dearrayed) image into '<sample>/prob_maps'.
    parser = argparse.ArgumentParser()
    parser.add_argument("imagePath", help="path to the .tif file")
    parser.add_argument("--channel", help="channel to perform inference on", type=int, default=0)
    parser.add_argument("--TMA", help="specify if TMA", action="store_true")
    parser.add_argument("--scalingFactor", help="factor by which to increase/decrease image size by", type=float,
                        default=1)
    parser.add_argument("--modelPath", help="folder containing the trained model",
                        default='D:\\LSP\\UNet\\tonsil20x1bin1chan\\TFModel - 3class 16 kernels 5ks 2 layers')
    parser.add_argument("--GPU", help="GPU index to run inference on", type=int, default=1)
    args = parser.parse_args()

    modelPath = args.modelPath
    imagePath = args.imagePath
    dapiChannel = args.channel
    dsFactor = args.scalingFactor

    UNet2D.singleImageInferenceSetup(modelPath, args.GPU)
    try:
        sampleList = glob.glob(os.path.join(imagePath, 'exemplar*'))
        for iSample in sampleList:
            if args.TMA:
                # every dearrayed core except the TMA map itself
                fileList = [x for x in glob.glob(os.path.join(iSample, 'dearray', '*.tif'))
                            if os.path.basename(x) != 'TMA_MAP.tif']
                print(iSample)
            else:
                fileList = glob.glob(os.path.join(iSample, 'registration', '*ome.tif'))
            print(fileList)
            for iFile in fileList:
                fileName = os.path.basename(iFile)
                fileNamePrefix = fileName.split(os.extsep, 1)
                # NOTE(review): tifffile is not imported explicitly in this file; presumably it is
                # re-exported by the toolbox star imports — verify.
                I = tifffile.imread(iFile, key=dapiChannel)
                rawI = I
                hsize = int((float(I.shape[0]) * float(dsFactor)))
                vsize = int((float(I.shape[1]) * float(dsFactor)))
                I = resize(I, (hsize, vsize))
                I = im2double(sk.rescale_intensity(I, in_range=(np.min(I), np.max(I)), out_range=(0, 0.983)))
                # the un-resized input, scaled to a maximum of 1, is stored next to the contour map
                rawI = scaleToUnitMax(im2double(rawI))
                outputPath = os.path.join(iSample, 'prob_maps')
                os.makedirs(outputPath, exist_ok=True)

                # class 1 -> contours; written together with the raw image as a 2-plane stack
                K = np.zeros((2, rawI.shape[0], rawI.shape[1]))
                contours = UNet2D.singleImageInference(I, 'accumulate', 1)
                contours = resize(contours, (rawI.shape[0], rawI.shape[1]))
                K[1, :, :] = rawI
                K[0, :, :] = contours
                tifwrite(np.uint8(255 * K),
                         os.path.join(outputPath,
                                      fileNamePrefix[0] + '_ContoursPM_' + str(dapiChannel + 1) + '.tif'))
                del K

                # class 2 -> nuclei
                K = np.zeros((1, rawI.shape[0], rawI.shape[1]))
                nuclei = UNet2D.singleImageInference(I, 'accumulate', 2)
                nuclei = resize(nuclei, (rawI.shape[0], rawI.shape[1]))
                K[0, :, :] = nuclei
                tifwrite(np.uint8(255 * K),
                         os.path.join(outputPath,
                                      fileNamePrefix[0] + '_NucleiPM_' + str(dapiChannel + 1) + '.tif'))
                del K
    finally:
        # always release the TensorFlow session, even if a file fails to process
        UNet2D.singleImageInferenceCleanup()