Mercurial > repos > perssond > unmicst
comparison UnMicst.py @ 0:6bec4fef6b2e draft
"planemo upload for repository https://github.com/ohsu-comp-bio/unmicst commit 73e4cae15f2d7cdc86719e77470eb00af4b6ebb7-dirty"
author | perssond |
---|---|
date | Fri, 12 Mar 2021 00:17:29 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:6bec4fef6b2e |
---|---|
1 import numpy as np | |
2 from scipy import misc | |
3 import tensorflow.compat.v1 as tf | |
4 import shutil | |
5 import scipy.io as sio | |
6 import os, fnmatch, glob | |
7 import skimage.exposure as sk | |
8 import skimage.io | |
9 import argparse | |
10 import czifile | |
11 from nd2reader import ND2Reader | |
12 import tifffile | |
13 import sys | |
14 tf.disable_v2_behavior() | |
15 #sys.path.insert(0, 'C:\\Users\\Public\\Documents\\ImageScience') | |
16 | |
17 from toolbox.imtools import * | |
18 from toolbox.ftools import * | |
19 from toolbox.PartitionOfImage import PI2D | |
20 from toolbox import GPUselect | |
21 | |
def concat3(lst):
    """Join a list of tensors along axis 3.

    Axis 3 is the channel axis of the NHWC tensors this file builds (the data
    placeholder is shaped ``[None, imSize, imSize, nChannels]``), so this is
    the U-Net skip-connection concatenation.
    """
    channelAxis = 3
    return tf.concat(values=lst, axis=channelAxis)
24 | |
25 | |
class UNet2D:
    """2D U-Net segmentation network built with the TensorFlow 1.x graph API.

    The class is used as a namespace: all state lives in class attributes and
    every method is called on the class itself (``UNet2D.setup(...)``,
    ``UNet2D.train(...)``); none of the methods takes ``self``.
    """
    hp = None  # hyper-parameters
    nn = None  # network
    tfTraining = None  # if training or not (to handle batch norm)
    tfData = None  # data placeholder
    Session = None  # tf.Session opened by singleImageInferenceSetup and used by singleImageInference
    DatasetMean = 0  # value subtracted from every patch in singleImageInference
    DatasetStDev = 0  # divisor applied to every patch in singleImageInference

    def setupWithHP(hp):
        """Build the network from a hyper-parameter dict.

        Args:
            hp: dict with the keys 'imSize', 'nChannels', 'nClasses', 'nOut0',
                'featMapsFact', 'downSampFact', 'ks', 'nExtraConvs', 'stdDev0',
                'nLayers' and 'batchSize' -- the same keys ``setup`` stores in
                ``UNet2D.hp`` (and ``train`` saves to 'hp.data').
        """
        UNet2D.setup(hp['imSize'],
                     hp['nChannels'],
                     hp['nClasses'],
                     hp['nOut0'],
                     hp['featMapsFact'],
                     hp['downSampFact'],
                     hp['ks'],
                     hp['nExtraConvs'],
                     hp['stdDev0'],
                     hp['nLayers'],
                     hp['batchSize'])

    def setup(imSize, nChannels, nClasses, nOut0, featMapsFact, downSampFact, kernelSize, nExtraConvs, stdDev0,
              nDownSampLayers, batchSize):
        """Build the U-Net graph and record its hyper-parameters in ``UNet2D.hp``.

        Creates the ``UNet2D.tfTraining`` (bool) and ``UNet2D.tfData`` (float,
        ``[None, imSize, imSize, nChannels]``) placeholders and sets
        ``UNet2D.nn`` to the softmax over the last axis (``nClasses`` channels).

        Args:
            imSize: side length of the square input patches.
            nChannels: number of input channels.
            nClasses: number of output classes (softmax channels).
            nOut0: feature maps produced by the first down-sampling layer.
            featMapsFact: factor by which the feature-map count grows per layer.
            downSampFact: max-pool / transposed-convolution stride of every layer.
            kernelSize: spatial size of the convolution kernels.
            nExtraConvs: extra convolutions inside each down/up-sampling layer.
            stdDev0: stddev of the truncated-normal weight initializer.
            nDownSampLayers: number of down-sampling (and up-sampling) layers.
            batchSize: batch dimension baked into the transposed convolutions
                (see ``outputShape`` below); every caller in this class feeds
                batches of exactly this size.

        NOTE: tf.train.Saver matches checkpoint entries by variable name, and
        the names here come from the name scopes below, so reordering or
        renaming scopes/variables would likely break restoring saved models.
        """
        UNet2D.hp = {'imSize': imSize,
                     'nClasses': nClasses,
                     'nChannels': nChannels,
                     'nExtraConvs': nExtraConvs,
                     'nLayers': nDownSampLayers,
                     'featMapsFact': featMapsFact,
                     'downSampFact': downSampFact,
                     'ks': kernelSize,
                     'nOut0': nOut0,
                     'stdDev0': stdDev0,
                     'batchSize': batchSize}

        # nOutX[k] is the channel count of depth k: the input channels, then nOut0,
        # then multiplied by featMapsFact per layer. dsfX[k] is the stride of layer k.
        nOutX = [UNet2D.hp['nChannels'], UNet2D.hp['nOut0']]
        dsfX = []
        for i in range(UNet2D.hp['nLayers']):
            nOutX.append(nOutX[-1] * UNet2D.hp['featMapsFact'])
            dsfX.append(UNet2D.hp['downSampFact'])

        # --------------------------------------------------
        # downsampling layer
        # --------------------------------------------------

        with tf.name_scope('placeholders'):
            UNet2D.tfTraining = tf.placeholder(tf.bool, name='training')
            UNet2D.tfData = tf.placeholder("float", shape=[None, UNet2D.hp['imSize'], UNet2D.hp['imSize'],
                                                           UNet2D.hp['nChannels']], name='data')

        def down_samp_layer(data, index):
            # Conv (+ nExtraConvs relu/conv pairs) plus a 1x1 shortcut conv, then
            # relu -> batch norm -> max-pool with stride dsfX[index].
            with tf.name_scope('ld%d' % index):
                ldXWeights1 = tf.Variable(
                    tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index], nOutX[index + 1]],
                                        stddev=stdDev0), name='kernel1')
                ldXWeightsExtra = []
                for i in range(nExtraConvs):
                    ldXWeightsExtra.append(tf.Variable(
                        tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index + 1], nOutX[index + 1]],
                                            stddev=stdDev0), name='kernelExtra%d' % i))

                c00 = tf.nn.conv2d(data, ldXWeights1, strides=[1, 1, 1, 1], padding='SAME')
                for i in range(nExtraConvs):
                    c00 = tf.nn.conv2d(tf.nn.relu(c00), ldXWeightsExtra[i], strides=[1, 1, 1, 1], padding='SAME')

                # 1x1 projection so the residual sum has nOutX[index + 1] channels.
                ldXWeightsShortcut = tf.Variable(
                    tf.truncated_normal([1, 1, nOutX[index], nOutX[index + 1]], stddev=stdDev0), name='shortcutWeights')
                shortcut = tf.nn.conv2d(data, ldXWeightsShortcut, strides=[1, 1, 1, 1], padding='SAME')

                # tfTraining switches batch norm between batch and moving statistics.
                bn = tf.layers.batch_normalization(tf.nn.relu(c00 + shortcut), training=UNet2D.tfTraining)

                return tf.nn.max_pool(bn, ksize=[1, dsfX[index], dsfX[index], 1],
                                      strides=[1, dsfX[index], dsfX[index], 1], padding='SAME', name='maxpool')

        # --------------------------------------------------
        # bottom layer
        # --------------------------------------------------

        with tf.name_scope('lb'):
            lbWeights1 = tf.Variable(tf.truncated_normal(
                [UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[UNet2D.hp['nLayers']], nOutX[UNet2D.hp['nLayers'] + 1]],
                stddev=stdDev0), name='kernel1')

            def lb(hidden):
                return tf.nn.relu(tf.nn.conv2d(hidden, lbWeights1, strides=[1, 1, 1, 1], padding='SAME'), name='conv')

        # --------------------------------------------------
        # downsampling
        # --------------------------------------------------

        with tf.name_scope('downsampling'):
            # dsX[0] is the input placeholder; dsX[k] (k >= 1) is the output of layer k - 1.
            dsX = []
            dsX.append(UNet2D.tfData)

            for i in range(UNet2D.hp['nLayers']):
                dsX.append(down_samp_layer(dsX[i], i))

            b = lb(dsX[UNet2D.hp['nLayers']])

        # --------------------------------------------------
        # upsampling layer
        # --------------------------------------------------

        def up_samp_layer(data, index):
            # Transposed conv back to depth `index`, concatenate with the skip tensor
            # dsX[index] (dsX is the closure list built above), then conv(s) + relu.
            with tf.name_scope('lu%d' % index):
                luXWeights1 = tf.Variable(
                    tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index + 1], nOutX[index + 2]],
                                        stddev=stdDev0), name='kernel1')
                luXWeights2 = tf.Variable(tf.truncated_normal(
                    [UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index] + nOutX[index + 1], nOutX[index + 1]],
                    stddev=stdDev0), name='kernel2')
                luXWeightsExtra = []
                for i in range(nExtraConvs):
                    luXWeightsExtra.append(tf.Variable(
                        tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index + 1], nOutX[index + 1]],
                                            stddev=stdDev0), name='kernel2Extra%d' % i))

                # Spatial size at depth `index`: imSize divided by the strides of the layers above it.
                outSize = UNet2D.hp['imSize']
                for i in range(index):
                    outSize /= dsfX[i]
                outSize = int(outSize)

                # conv2d_transpose needs a static output shape, which fixes the batch size.
                outputShape = [UNet2D.hp['batchSize'], outSize, outSize, nOutX[index + 1]]
                us = tf.nn.relu(
                    tf.nn.conv2d_transpose(data, luXWeights1, outputShape, strides=[1, dsfX[index], dsfX[index], 1],
                                           padding='SAME'), name='conv1')
                cc = concat3([dsX[index], us])
                cv = tf.nn.relu(tf.nn.conv2d(cc, luXWeights2, strides=[1, 1, 1, 1], padding='SAME'), name='conv2')
                for i in range(nExtraConvs):
                    cv = tf.nn.relu(tf.nn.conv2d(cv, luXWeightsExtra[i], strides=[1, 1, 1, 1], padding='SAME'),
                                    name='conv2Extra%d' % i)
                return cv

        # --------------------------------------------------
        # final (top) layer
        # --------------------------------------------------

        with tf.name_scope('lt'):
            # 1x1 conv mapping nOutX[1] feature maps to nClasses logits.
            ltWeights1 = tf.Variable(tf.truncated_normal([1, 1, nOutX[1], nClasses], stddev=stdDev0), name='kernel')

            def lt(hidden):
                return tf.nn.conv2d(hidden, ltWeights1, strides=[1, 1, 1, 1], padding='SAME', name='conv')

        # --------------------------------------------------
        # upsampling
        # --------------------------------------------------

        with tf.name_scope('upsampling'):
            usX = []
            usX.append(b)

            # Walk back up from the deepest layer (index nLayers - 1) to depth 0.
            for i in range(UNet2D.hp['nLayers']):
                usX.append(up_samp_layer(usX[i], UNet2D.hp['nLayers'] - 1 - i))

            t = lt(usX[UNet2D.hp['nLayers']])

        sm = tf.nn.softmax(t, -1)
        UNet2D.nn = sm

    def train(imPath, logPath, modelPath, pmPath, nTrain, nValid, nTest, restoreVariables, nSteps, gpuIndex,
              testPMIndex):
        """Train the network built by ``setup``/``setupWithHP`` and write test outputs.

        Requires ``UNet2D.setup`` (or ``setupWithHP``) to have been called first,
        because it reads ``UNet2D.hp`` and ``UNet2D.nn``.

        Args:
            imPath: directory holding ``I%05d_Img.tif`` images and matching
                ``I%05d_Ant.tif`` annotations, indexed 0 .. nTrain+nValid+nTest-1.
                Annotation value ``c + 1`` marks class ``c``.
            logPath: TensorBoard summary directory. WARNING: it is deleted
                (``shutil.rmtree``) if it already exists.
            modelPath: directory for 'model.ckpt', 'datasetMean.data',
                'datasetStDev.data' and 'hp.data' (created if missing).
            pmPath: directory for the ``I%05d.png`` test visualisations
                (image | probability map | ground truth).
            nTrain, nValid, nTest: number of samples per split (assigned by a
                random permutation). Each step uses the first ``batchSize``
                entries of a fresh permutation, so nTrain and nValid must be at
                least ``batchSize``.
            restoreVariables: if true, resume from the saved checkpoint and
                the saved mean/st-dev instead of recomputing them.
            nSteps: number of optimisation steps.
            gpuIndex: value written to CUDA_VISIBLE_DEVICES.
            testPMIndex: class channel visualised for the test set.
        """
        os.environ['CUDA_VISIBLE_DEVICES'] = '%d' % gpuIndex

        outLogPath = logPath
        trainWriterPath = pathjoin(logPath, 'Train')
        validWriterPath = pathjoin(logPath, 'Valid')
        outModelPath = pathjoin(modelPath, 'model.ckpt')
        outPMPath = pmPath

        batchSize = UNet2D.hp['batchSize']
        imSize = UNet2D.hp['imSize']
        nChannels = UNet2D.hp['nChannels']
        nClasses = UNet2D.hp['nClasses']

        # --------------------------------------------------
        # data
        # --------------------------------------------------

        Train = np.zeros((nTrain, imSize, imSize, nChannels))
        Valid = np.zeros((nValid, imSize, imSize, nChannels))
        Test = np.zeros((nTest, imSize, imSize, nChannels))
        LTrain = np.zeros((nTrain, imSize, imSize, nClasses))
        LValid = np.zeros((nValid, imSize, imSize, nClasses))
        LTest = np.zeros((nTest, imSize, imSize, nClasses))

        print('loading data, computing mean / st dev')
        if not os.path.exists(modelPath):
            os.makedirs(modelPath)
        if restoreVariables:
            datasetMean = loadData(pathjoin(modelPath, 'datasetMean.data'))
            datasetStDev = loadData(pathjoin(modelPath, 'datasetStDev.data'))
        else:
            # Mean of per-image means and mean of per-image standard deviations
            # (not the st-dev of the pooled pixels).
            datasetMean = 0
            datasetStDev = 0
            for iSample in range(nTrain + nValid + nTest):
                I = im2double(tifread('%s/I%05d_Img.tif' % (imPath, iSample)))
                datasetMean += np.mean(I)
                datasetStDev += np.std(I)
            datasetMean /= (nTrain + nValid + nTest)
            datasetStDev /= (nTrain + nValid + nTest)
            saveData(datasetMean, pathjoin(modelPath, 'datasetMean.data'))
            saveData(datasetStDev, pathjoin(modelPath, 'datasetStDev.data'))

        perm = np.arange(nTrain + nValid + nTest)
        np.random.shuffle(perm)

        # Only channel 0 of each sample is filled -- assumes nChannels == 1 (TODO confirm).
        # Labels are one-hot: channel i is 1 where the annotation equals i + 1.
        for iSample in range(0, nTrain):
            path = '%s/I%05d_Img.tif' % (imPath, perm[iSample])
            im = im2double(tifread(path))
            Train[iSample, :, :, 0] = (im - datasetMean) / datasetStDev
            path = '%s/I%05d_Ant.tif' % (imPath, perm[iSample])
            im = tifread(path)
            for i in range(nClasses):
                LTrain[iSample, :, :, i] = (im == i + 1)

        for iSample in range(0, nValid):
            path = '%s/I%05d_Img.tif' % (imPath, perm[nTrain + iSample])
            im = im2double(tifread(path))
            Valid[iSample, :, :, 0] = (im - datasetMean) / datasetStDev
            path = '%s/I%05d_Ant.tif' % (imPath, perm[nTrain + iSample])
            im = tifread(path)
            for i in range(nClasses):
                LValid[iSample, :, :, i] = (im == i + 1)

        for iSample in range(0, nTest):
            path = '%s/I%05d_Img.tif' % (imPath, perm[nTrain + nValid + iSample])
            im = im2double(tifread(path))
            Test[iSample, :, :, 0] = (im - datasetMean) / datasetStDev
            path = '%s/I%05d_Ant.tif' % (imPath, perm[nTrain + nValid + iSample])
            im = tifread(path)
            for i in range(nClasses):
                LTest[iSample, :, :, i] = (im == i + 1)

        # --------------------------------------------------
        # optimization
        # --------------------------------------------------

        tfLabels = tf.placeholder("float", shape=[None, imSize, imSize, nClasses], name='labels')

        # Learning rate 0.01 multiplied by 0.95 every 1000 steps (staircase).
        globalStep = tf.Variable(0, trainable=False)
        learningRate0 = 0.01
        decaySteps = 1000
        decayRate = 0.95
        learningRate = tf.train.exponential_decay(learningRate0, globalStep, decaySteps, decayRate, staircase=True)

        with tf.name_scope('optim'):
            # Pixel-averaged cross-entropy between one-hot labels and the softmax output.
            loss = tf.reduce_mean(-tf.reduce_sum(tf.multiply(tfLabels, tf.log(UNet2D.nn)), 3))
            # Batch-norm moving statistics are updated through UPDATE_OPS, so the
            # optimizer step is made to depend on them.
            updateOps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            # optimizer = tf.train.MomentumOptimizer(1e-3,0.9)
            optimizer = tf.train.MomentumOptimizer(learningRate, 0.9)
            # optimizer = tf.train.GradientDescentOptimizer(learningRate)
            with tf.control_dependencies(updateOps):
                optOp = optimizer.minimize(loss, global_step=globalStep)

        with tf.name_scope('eval'):
            # Per-class error = 1 - (correctly predicted labelled pixels / labelled pixels).
            # NOTE(review): this is NaN when a class has no labelled pixels in the batch.
            error = []
            for iClass in range(nClasses):
                labels0 = tf.reshape(tf.to_int32(tf.slice(tfLabels, [0, 0, 0, iClass], [-1, -1, -1, 1])),
                                     [batchSize, imSize, imSize])
                predict0 = tf.reshape(tf.to_int32(tf.equal(tf.argmax(UNet2D.nn, 3), iClass)),
                                      [batchSize, imSize, imSize])
                correct = tf.multiply(labels0, predict0)
                nCorrect0 = tf.reduce_sum(correct)
                nLabels0 = tf.reduce_sum(labels0)
                error.append(1 - tf.to_float(nCorrect0) / tf.to_float(nLabels0))
            errors = tf.tuple(error)

        # --------------------------------------------------
        # inspection
        # --------------------------------------------------

        with tf.name_scope('scalars'):
            tf.summary.scalar('avg_cross_entropy', loss)
            for iClass in range(nClasses):
                tf.summary.scalar('avg_pixel_error_%d' % iClass, error[iClass])
            tf.summary.scalar('learning_rate', learningRate)
        with tf.name_scope('images'):
            split0 = tf.slice(UNet2D.nn, [0, 0, 0, 0], [-1, -1, -1, 1])
            split1 = tf.slice(UNet2D.nn, [0, 0, 0, 1], [-1, -1, -1, 1])
            if nClasses > 2:
                split2 = tf.slice(UNet2D.nn, [0, 0, 0, 2], [-1, -1, -1, 1])
            tf.summary.image('pm0', split0)
            tf.summary.image('pm1', split1)
            if nClasses > 2:
                tf.summary.image('pm2', split2)
        merged = tf.summary.merge_all()

        # --------------------------------------------------
        # session
        # --------------------------------------------------

        saver = tf.train.Saver()
        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True))  # config parameter needed to save variables when using GPU

        if os.path.exists(outLogPath):
            shutil.rmtree(outLogPath)
        trainWriter = tf.summary.FileWriter(trainWriterPath, sess.graph)
        validWriter = tf.summary.FileWriter(validWriterPath, sess.graph)

        if restoreVariables:
            saver.restore(sess, outModelPath)
            print("Model restored.")
        else:
            sess.run(tf.global_variables_initializer())

        # --------------------------------------------------
        # train
        # --------------------------------------------------

        batchData = np.zeros((batchSize, imSize, imSize, nChannels))
        batchLabels = np.zeros((batchSize, imSize, imSize, nClasses))
        for i in range(nSteps):
            # train

            perm = np.arange(nTrain)
            np.random.shuffle(perm)

            for j in range(batchSize):
                batchData[j, :, :, :] = Train[perm[j], :, :, :]
                batchLabels[j, :, :, :] = LTrain[perm[j], :, :, :]

            summary, _ = sess.run([merged, optOp],
                                  feed_dict={UNet2D.tfData: batchData, tfLabels: batchLabels, UNet2D.tfTraining: 1})
            trainWriter.add_summary(summary, i)

            # validation

            perm = np.arange(nValid)
            np.random.shuffle(perm)

            for j in range(batchSize):
                batchData[j, :, :, :] = Valid[perm[j], :, :, :]
                batchLabels[j, :, :, :] = LValid[perm[j], :, :, :]

            summary, es = sess.run([merged, errors],
                                   feed_dict={UNet2D.tfData: batchData, tfLabels: batchLabels, UNet2D.tfTraining: 0})
            validWriter.add_summary(summary, i)

            e = np.mean(es)
            print('step %05d, e: %f' % (i, e))

            # When resuming, the first validation error is the baseline, so the
            # restored checkpoint is only overwritten by a strictly better one.
            if i == 0:
                if restoreVariables:
                    lowestError = e
                else:
                    lowestError = np.inf

            # Checkpoint every 100 steps, but only if validation error improved.
            if np.mod(i, 100) == 0 and e < lowestError:
                lowestError = e
                print("Model saved in file: %s" % saver.save(sess, outModelPath))

        # --------------------------------------------------
        # test
        # --------------------------------------------------

        if not os.path.exists(outPMPath):
            os.makedirs(outPMPath)

        # Fill fixed-size batches; the last batch may be partial, in which case only
        # its first j + 1 rows are written out (the rest are stale data).
        for i in range(nTest):
            j = np.mod(i, batchSize)

            batchData[j, :, :, :] = Test[i, :, :, :]
            batchLabels[j, :, :, :] = LTest[i, :, :, :]

            if j == batchSize - 1 or i == nTest - 1:

                output = sess.run(UNet2D.nn,
                                  feed_dict={UNet2D.tfData: batchData, tfLabels: batchLabels, UNet2D.tfTraining: 0})

                for k in range(j + 1):
                    pm = output[k, :, :, testPMIndex]
                    gt = batchLabels[k, :, :, testPMIndex]
                    im = np.sqrt(normalize(batchData[k, :, :, 0]))
                    # Side-by-side PNG: image | probability map | ground truth (1-based file index).
                    imwrite(np.uint8(255 * np.concatenate((im, np.concatenate((pm, gt), axis=1)), axis=1)),
                            '%s/I%05d.png' % (outPMPath, i - j + k + 1))

        # --------------------------------------------------
        # save hyper-parameters, clean-up
        # --------------------------------------------------

        saveData(UNet2D.hp, pathjoin(modelPath, 'hp.data'))

        trainWriter.close()
        validWriter.close()
        sess.close()

    def deploy(imPath, nImages, modelPath, pmPath, gpuIndex, pmIndex):
        """Run a saved model on a numbered image set and write PNG results.

        Builds the graph from ``modelPath/hp.data``, restores
        ``modelPath/model.ckpt`` and normalises every input with the saved
        dataset mean / st-dev.

        Args:
            imPath: directory holding ``I%05d_Img.tif`` for indices 0 .. nImages-1.
            nImages: number of images to process.
            modelPath: directory produced by ``train``.
            pmPath: output directory (created if missing). For each input it
                receives ``I%05d_Im.png`` and ``I%05d_PM.png`` with a 1-based index.
            gpuIndex: value written to CUDA_VISIBLE_DEVICES.
            pmIndex: class channel written as the probability map.
        """
        os.environ['CUDA_VISIBLE_DEVICES'] = '%d' % gpuIndex

        variablesPath = pathjoin(modelPath, 'model.ckpt')
        outPMPath = pmPath

        hp = loadData(pathjoin(modelPath, 'hp.data'))
        UNet2D.setupWithHP(hp)

        batchSize = UNet2D.hp['batchSize']
        imSize = UNet2D.hp['imSize']
        nChannels = UNet2D.hp['nChannels']
        nClasses = UNet2D.hp['nClasses']

        # --------------------------------------------------
        # data
        # --------------------------------------------------

        Data = np.zeros((nImages, imSize, imSize, nChannels))

        datasetMean = loadData(pathjoin(modelPath, 'datasetMean.data'))
        datasetStDev = loadData(pathjoin(modelPath, 'datasetStDev.data'))

        for iSample in range(0, nImages):
            path = '%s/I%05d_Img.tif' % (imPath, iSample)
            im = im2double(tifread(path))
            Data[iSample, :, :, 0] = (im - datasetMean) / datasetStDev

        # --------------------------------------------------
        # session
        # --------------------------------------------------

        saver = tf.train.Saver()
        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True))  # config parameter needed to save variables when using GPU

        saver.restore(sess, variablesPath)
        print("Model restored.")

        # --------------------------------------------------
        # deploy
        # --------------------------------------------------

        batchData = np.zeros((batchSize, imSize, imSize, nChannels))

        if not os.path.exists(outPMPath):
            os.makedirs(outPMPath)

        # Same fixed-size batching as the test loop in train: a partial last batch
        # is run in full, but only its first j + 1 results are written.
        for i in range(nImages):
            print(i, nImages)

            j = np.mod(i, batchSize)

            batchData[j, :, :, :] = Data[i, :, :, :]

            if j == batchSize - 1 or i == nImages - 1:

                output = sess.run(UNet2D.nn, feed_dict={UNet2D.tfData: batchData, UNet2D.tfTraining: 0})

                for k in range(j + 1):
                    pm = output[k, :, :, pmIndex]
                    im = np.sqrt(normalize(batchData[k, :, :, 0]))
                    # imwrite(np.uint8(255*np.concatenate((im,pm),axis=1)),'%s/I%05d.png' % (outPMPath,i-j+k+1))
                    imwrite(np.uint8(255 * im), '%s/I%05d_Im.png' % (outPMPath, i - j + k + 1))
                    imwrite(np.uint8(255 * pm), '%s/I%05d_PM.png' % (outPMPath, i - j + k + 1))

        # --------------------------------------------------
        # clean-up
        # --------------------------------------------------

        sess.close()

    def singleImageInferenceSetup(modelPath, gpuIndex,mean,std):
        """Build the graph from a saved model and open the shared inference session.

        Sets ``UNet2D.DatasetMean`` / ``UNet2D.DatasetStDev`` and
        ``UNet2D.Session``; call ``singleImageInferenceCleanup`` to close it.

        Args:
            modelPath: directory with 'hp.data' and 'model.ckpt' (plus
                'datasetMean.data' / 'datasetStDev.data' when needed).
            gpuIndex: NOTE(review): not used here; the ``__main__`` block sets
                CUDA_VISIBLE_DEVICES itself before calling this.
            mean: normalisation mean, or -1 to load it from modelPath.
            std: normalisation st-dev, or -1 to load it from modelPath.
        """
        variablesPath = pathjoin(modelPath, 'model.ckpt')

        hp = loadData(pathjoin(modelPath, 'hp.data'))
        UNet2D.setupWithHP(hp)
        if mean ==-1:
            UNet2D.DatasetMean = loadData(pathjoin(modelPath, 'datasetMean.data'))
        else:
            UNet2D.DatasetMean = mean

        if std == -1:
            UNet2D.DatasetStDev = loadData(pathjoin(modelPath, 'datasetStDev.data'))
        else:
            UNet2D.DatasetStDev = std
        print(UNet2D.DatasetMean)
        print(UNet2D.DatasetStDev)

        # --------------------------------------------------
        # session
        # --------------------------------------------------

        saver = tf.train.Saver()
        UNet2D.Session = tf.Session(config=tf.ConfigProto())
        # allow_soft_placement=True)) # config parameter needed to save variables when using GPU

        saver.restore(UNet2D.Session, variablesPath)
        print("Model restored.")

    def singleImageInferenceCleanup():
        """Close the session opened by ``singleImageInferenceSetup``."""
        UNet2D.Session.close()

    def singleImageInference(image, mode, pmIndex):
        """Run the network over a whole image tile by tile and return one class map.

        The image is split into ``imSize`` x ``imSize`` patches by ``PI2D``
        (toolbox.PartitionOfImage); ``imSize / 8`` and ``mode`` are passed to
        ``PI2D.setup`` -- their exact meaning is defined there (the __main__
        block uses mode 'accumulate'). Each patch is normalised with
        ``UNet2D.DatasetMean`` / ``UNet2D.DatasetStDev`` before inference.

        Args:
            image: 2-D image; presumably already rescaled like the __main__
                block does -- TODO confirm expected range.
            mode: forwarded to ``PI2D.setup``.
            pmIndex: class channel of the softmax output to keep.

        Returns:
            Whatever ``PI2D.getValidOutput()`` assembles from the patch outputs.
        """
        print('Inference...')

        batchSize = UNet2D.hp['batchSize']
        imSize = UNet2D.hp['imSize']
        nChannels = UNet2D.hp['nChannels']

        PI2D.setup(image, imSize, int(imSize / 8), mode)
        PI2D.createOutput(nChannels)

        # Batches always have batchSize rows (the graph requires it); in a partial
        # last batch only the first j + 1 rows are fresh and only they are used.
        batchData = np.zeros((batchSize, imSize, imSize, nChannels))
        for i in range(PI2D.NumPatches):
            j = np.mod(i, batchSize)
            batchData[j, :, :, 0] = (PI2D.getPatch(i) - UNet2D.DatasetMean) / UNet2D.DatasetStDev
            if j == batchSize - 1 or i == PI2D.NumPatches - 1:
                output = UNet2D.Session.run(UNet2D.nn, feed_dict={UNet2D.tfData: batchData, UNet2D.tfTraining: 0})
                for k in range(j + 1):
                    pm = output[k, :, :, pmIndex]
                    PI2D.patchOutput(i - j + k, pm)
                    # PI2D.patchOutput(i-j+k,normalize(imgradmag(PI2D.getPatch(i-j+k),1)))

        return PI2D.getValidOutput()
539 | |
540 | |
if __name__ == '__main__':
    # Command-line entry point: segment one channel of one image and write the
    # class probability maps as (multi-page) BigTIFF files.
    parser = argparse.ArgumentParser()
    parser.add_argument("imagePath", help="path to the .tif file")
    parser.add_argument("--model", help="type of model. For example, nuclei vs cytoplasm", default='nucleiDAPI')
    parser.add_argument("--outputPath", help="output path of probability map")
    parser.add_argument("--channel", help="channel to perform inference on", type=int, default=0)
    parser.add_argument("--classOrder", help="background, contours, foreground", type=int, nargs='+', default=-1)
    parser.add_argument("--mean", help="mean intensity of input image. Use -1 to use model", type=float, default=-1)
    parser.add_argument("--std", help="mean standard deviation of input image. Use -1 to use model", type=float,
                        default=-1)
    parser.add_argument("--scalingFactor", help="factor by which to increase/decrease image size by", type=float,
                        default=1)
    parser.add_argument("--stackOutput",
                        help="save the probability maps of all classes as slices of a single stacked file",
                        action='store_true')
    parser.add_argument("--GPU", help="explicitly select GPU", type=int, default=-1)
    parser.add_argument("--outlier",
                        help="map percentile intensity to max when rescaling intensity values. Max intensity as default",
                        type=float, default=-1)
    args = parser.parse_args()

    scriptPath = os.path.dirname(os.path.realpath(__file__))
    modelPath = os.path.join(scriptPath, 'models', args.model)

    imagePath = args.imagePath
    dapiChannel = args.channel
    dsFactor = args.scalingFactor
    parentFolder = os.path.dirname(os.path.dirname(imagePath))
    fileName = os.path.basename(imagePath)
    print(fileName)

    # Classify the input by suffix before loading the model so an unsupported
    # file fails fast. 'ome.tif' is tested before 'tif' because it is the longer
    # match; suffix matching (rather than splitting at the first dot) also keeps
    # names such as 'sample.v2.tif' working.
    fileType = next((ext for ext in ('ome.tif', 'btf', 'tif', 'czi', 'nd2')
                     if fileName.endswith(os.extsep + ext)), None)
    if fileType is None:
        parser.error('unsupported image format: %s (expected .tif, .ome.tif, .btf, .czi or .nd2)' % fileName)
    fileNamePrefix = fileName[:-len(os.extsep + fileType)]

    if os.system('nvidia-smi') == 0:
        if args.GPU == -1:
            print("automatically choosing GPU")
            GPU = GPUselect.pick_gpu_lowest_memory()
        else:
            GPU = args.GPU
        print('Using GPU ' + str(GPU))
    elif sys.platform == 'win32':  # only 1 gpu on windows
        GPU = 0 if args.GPU == -1 else args.GPU
        print('Using GPU ' + str(GPU))
    else:
        GPU = 0
        print('Using CPU')
    os.environ['CUDA_VISIBLE_DEVICES'] = '%d' % GPU
    UNet2D.singleImageInferenceSetup(modelPath, GPU, args.mean, args.std)
    nClass = UNet2D.hp['nClasses']

    # skimage.io is imported explicitly at the top of this file (unlike `skio`).
    if fileType in ('ome.tif', 'btf'):
        I = skimage.io.imread(imagePath, img_num=dapiChannel, plugin='tifffile')
    elif fileType == 'tif':
        I = tifffile.imread(imagePath, key=dapiChannel)
    elif fileType == 'czi':
        with czifile.CziFile(imagePath) as czi:
            image = czi.asarray()
            I = image[0, 0, dapiChannel, 0, 0, :, :, 0]
    else:  # 'nd2'
        with ND2Reader(imagePath) as fullStack:
            I = fullStack[dapiChannel]

    # Default class order is the model's own channel order 0 .. nClass-1.
    classOrder = range(nClass) if args.classOrder == -1 else args.classOrder

    rawI = I
    print(type(I))
    hsize = int(float(I.shape[0]) * float(dsFactor))
    vsize = int(float(I.shape[1]) * float(dsFactor))
    I = resize(I, (hsize, vsize))
    maxLimit = np.max(I) if args.outlier == -1 else np.percentile(I, args.outlier)
    I = im2double(sk.rescale_intensity(I, in_range=(np.min(I), maxLimit), out_range=(0, 0.983)))
    rawI = im2double(rawI)
    rawI = rawI / np.max(rawI)

    # os.path.join keeps the default relative when parentFolder is '' (e.g. for
    # 'registration/img.tif'); plain string concatenation produced
    # '//probability_maps', i.e. a directory at the filesystem root.
    outputPath = args.outputPath or os.path.join(parentFolder, 'probability_maps')
    os.makedirs(outputPath, exist_ok=True)

    outputStem = os.path.join(outputPath, fileNamePrefix)
    channelSuffix = '_' + str(dapiChannel) + '.tif'

    append_kwargs = {
        'bigtiff': True,
        'metadata': None,
        'append': True,
    }
    save_kwargs = {
        'bigtiff': True,
        'metadata': None,
        'append': False,
    }
    if args.stackOutput:
        probabilitiesFile = outputStem + '_Probabilities' + channelSuffix
        previewFile = outputStem + '_Preview' + channelSuffix
        # Classes are written backwards in order to align with ilastik; the first
        # slice creates the file and later slices are appended to it.
        for sliceIndex, iClass in enumerate(classOrder[::-1]):
            PM = np.uint8(255 * UNet2D.singleImageInference(I, 'accumulate', iClass))
            # Back to the raw image size; the second 255 scaling below assumes
            # resize returns floats in [0, 1] (skimage.transform.resize default) -- confirm.
            PM = resize(PM, (rawI.shape[0], rawI.shape[1]))
            sliceKwargs = save_kwargs if sliceIndex == 0 else append_kwargs
            skimage.io.imsave(probabilitiesFile, np.uint8(255 * PM), **sliceKwargs)
            if sliceIndex == 1:
                # Preview: the second written class map followed by the raw image.
                skimage.io.imsave(previewFile, np.uint8(255 * PM), **save_kwargs)
                skimage.io.imsave(previewFile, np.uint8(255 * rawI), **append_kwargs)

    else:
        contours = np.uint8(255 * UNet2D.singleImageInference(I, 'accumulate', classOrder[1]))
        contours = resize(contours, (rawI.shape[0], rawI.shape[1]))
        contoursFile = outputStem + '_ContoursPM' + channelSuffix
        skimage.io.imsave(contoursFile, np.uint8(255 * contours), **save_kwargs)
        skimage.io.imsave(contoursFile, np.uint8(255 * rawI), **append_kwargs)
        del contours  # free before computing the next full-size map
        nuclei = np.uint8(255 * UNet2D.singleImageInference(I, 'accumulate', classOrder[2]))
        nuclei = resize(nuclei, (rawI.shape[0], rawI.shape[1]))
        skimage.io.imsave(outputStem + '_NucleiPM' + channelSuffix, np.uint8(255 * nuclei), **save_kwargs)
        del nuclei
    UNet2D.singleImageInferenceCleanup()
651 slice = slice + 1 | |
652 | |
653 else: | |
654 contours = np.uint8(255*UNet2D.singleImageInference(I, 'accumulate', args.classOrder[1])) | |
655 hsize = int((float(I.shape[0]) * float(1 / dsFactor))) | |
656 vsize = int((float(I.shape[1]) * float(1 / dsFactor))) | |
657 contours = resize(contours, (rawI.shape[0], rawI.shape[1])) | |
658 skimage.io.imsave(args.outputPath + '//' + fileNamePrefix[0] + '_ContoursPM_' + str(dapiChannel) + '.tif',np.uint8(255 * contours),**save_kwargs) | |
659 skimage.io.imsave(args.outputPath + '//' + fileNamePrefix[0] + '_ContoursPM_' + str(dapiChannel) + '.tif',np.uint8(255 * rawI), **append_kwargs) | |
660 del contours | |
661 nuclei = np.uint8(255*UNet2D.singleImageInference(I, 'accumulate', args.classOrder[2])) | |
662 nuclei = resize(nuclei, (rawI.shape[0], rawI.shape[1])) | |
663 skimage.io.imsave(args.outputPath + '//' + fileNamePrefix[0] + '_NucleiPM_' + str(dapiChannel) + '.tif',np.uint8(255 * nuclei), **save_kwargs) | |
664 del nuclei | |
665 UNet2D.singleImageInferenceCleanup() | |
666 | |
667 #aligned output files to reflect ilastik | |
668 #outputting all classes as single file | |
669 #handles multiple formats including tif, ome.tif, nd2, czi | |
670 #selectable models (human nuclei, mouse nuclei, cytoplasm) | |
671 | |
672 #added legacy function to save output files | |
673 #append save function to reduce memory footprint | |
674 #added --classOrder parameter to specify which class is background, contours, and nuclei respectively |