Mercurial > repos > perssond > unmicst
comparison UnMicst.py @ 0:6bec4fef6b2e draft
"planemo upload for repository https://github.com/ohsu-comp-bio/unmicst commit 73e4cae15f2d7cdc86719e77470eb00af4b6ebb7-dirty"
| author | perssond |
|---|---|
| date | Fri, 12 Mar 2021 00:17:29 +0000 |
| parents | |
| children | |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:6bec4fef6b2e |
|---|---|
| 1 import numpy as np | |
| 2 from scipy import misc | |
| 3 import tensorflow.compat.v1 as tf | |
| 4 import shutil | |
| 5 import scipy.io as sio | |
| 6 import os, fnmatch, glob | |
| 7 import skimage.exposure as sk | |
| 8 import skimage.io | |
| 9 import argparse | |
| 10 import czifile | |
| 11 from nd2reader import ND2Reader | |
| 12 import tifffile | |
| 13 import sys | |
| 14 tf.disable_v2_behavior() | |
| 15 #sys.path.insert(0, 'C:\\Users\\Public\\Documents\\ImageScience') | |
| 16 | |
| 17 from toolbox.imtools import * | |
| 18 from toolbox.ftools import * | |
| 19 from toolbox.PartitionOfImage import PI2D | |
| 20 from toolbox import GPUselect | |
| 21 | |
def concat3(lst):
    """Join a list of NHWC tensors along axis 3, the channel axis."""
    return tf.concat(values=lst, axis=3)
| 24 | |
| 25 | |
class UNet2D:
    """2-D U-Net built as a TensorFlow compat.v1 graph.

    All state lives in class attributes and every method is static, so the
    class acts as a namespace. Call ``UNet2D.setup(...)`` or
    ``UNet2D.setupWithHP(hp)`` to build the graph. Then use ``train``,
    ``deploy``, or the ``singleImageInference*`` functions.
    """
    hp = None  # hyper-parameters (dict written by setup)
    nn = None  # network output: softmax over the class axis (last axis)
    tfTraining = None  # if training or not (to handle batch norm)
    tfData = None  # data placeholder
    Session = None  # session opened by singleImageInferenceSetup
    DatasetMean = 0  # intensity normalization used by singleImageInference
    DatasetStDev = 0

    @staticmethod
    def setupWithHP(hp):
        """Build the network from a hyper-parameter dict (the format ``setup`` stores in ``UNet2D.hp``)."""
        UNet2D.setup(hp['imSize'],
                     hp['nChannels'],
                     hp['nClasses'],
                     hp['nOut0'],
                     hp['featMapsFact'],
                     hp['downSampFact'],
                     hp['ks'],
                     hp['nExtraConvs'],
                     hp['stdDev0'],
                     hp['nLayers'],
                     hp['batchSize'])

    @staticmethod
    def setup(imSize, nChannels, nClasses, nOut0, featMapsFact, downSampFact, kernelSize, nExtraConvs, stdDev0,
              nDownSampLayers, batchSize):
        """Build the U-Net graph and publish its placeholders/output on the class.

        Args:
            imSize: side length of the square input patches.
            nChannels: number of input channels.
            nClasses: number of output channels of the final softmax.
            nOut0: feature maps produced by the first down-sampling block.
            featMapsFact: multiplicative growth of feature maps per level.
            downSampFact: max-pool / transposed-convolution stride per level.
            kernelSize: convolution kernel size.
            nExtraConvs: number of extra convolutions per block.
            stdDev0: std. dev. of the truncated-normal weight initializer.
            nDownSampLayers: number of down-sampling levels.
            batchSize: batch size baked into the transposed convolutions'
                output shape, so every run must feed exactly this many patches.

        Sets ``UNet2D.hp``, ``UNet2D.tfTraining``, ``UNet2D.tfData`` and ``UNet2D.nn``.
        """
        UNet2D.hp = {'imSize': imSize,
                     'nClasses': nClasses,
                     'nChannels': nChannels,
                     'nExtraConvs': nExtraConvs,
                     'nLayers': nDownSampLayers,
                     'featMapsFact': featMapsFact,
                     'downSampFact': downSampFact,
                     'ks': kernelSize,
                     'nOut0': nOut0,
                     'stdDev0': stdDev0,
                     'batchSize': batchSize}

        # nOutX[k]: number of feature maps entering level k (nOutX[0] = input channels).
        # dsfX[k]: down-sampling factor applied after level k.
        nOutX = [UNet2D.hp['nChannels'], UNet2D.hp['nOut0']]
        dsfX = []
        for i in range(UNet2D.hp['nLayers']):
            nOutX.append(nOutX[-1] * UNet2D.hp['featMapsFact'])
            dsfX.append(UNet2D.hp['downSampFact'])

        # --------------------------------------------------
        # downsampling layer
        # --------------------------------------------------

        with tf.name_scope('placeholders'):
            UNet2D.tfTraining = tf.placeholder(tf.bool, name='training')
            UNet2D.tfData = tf.placeholder("float", shape=[None, UNet2D.hp['imSize'], UNet2D.hp['imSize'],
                                                           UNet2D.hp['nChannels']], name='data')

        def down_samp_layer(data, index):
            # Conv (+ extra convs) with a 1x1 shortcut, batch norm, then max-pool.
            with tf.name_scope('ld%d' % index):
                ldXWeights1 = tf.Variable(
                    tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index], nOutX[index + 1]],
                                        stddev=stdDev0), name='kernel1')
                ldXWeightsExtra = []
                for i in range(nExtraConvs):
                    ldXWeightsExtra.append(tf.Variable(
                        tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index + 1], nOutX[index + 1]],
                                            stddev=stdDev0), name='kernelExtra%d' % i))

                c00 = tf.nn.conv2d(data, ldXWeights1, strides=[1, 1, 1, 1], padding='SAME')
                for i in range(nExtraConvs):
                    c00 = tf.nn.conv2d(tf.nn.relu(c00), ldXWeightsExtra[i], strides=[1, 1, 1, 1], padding='SAME')

                ldXWeightsShortcut = tf.Variable(
                    tf.truncated_normal([1, 1, nOutX[index], nOutX[index + 1]], stddev=stdDev0),
                    name='shortcutWeights')
                shortcut = tf.nn.conv2d(data, ldXWeightsShortcut, strides=[1, 1, 1, 1], padding='SAME')

                bn = tf.layers.batch_normalization(tf.nn.relu(c00 + shortcut), training=UNet2D.tfTraining)

                return tf.nn.max_pool(bn, ksize=[1, dsfX[index], dsfX[index], 1],
                                      strides=[1, dsfX[index], dsfX[index], 1], padding='SAME', name='maxpool')

        # --------------------------------------------------
        # bottom layer
        # --------------------------------------------------

        with tf.name_scope('lb'):
            lbWeights1 = tf.Variable(tf.truncated_normal(
                [UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[UNet2D.hp['nLayers']], nOutX[UNet2D.hp['nLayers'] + 1]],
                stddev=stdDev0), name='kernel1')

            def lb(hidden):
                return tf.nn.relu(tf.nn.conv2d(hidden, lbWeights1, strides=[1, 1, 1, 1], padding='SAME'),
                                  name='conv')

        # --------------------------------------------------
        # downsampling
        # --------------------------------------------------

        with tf.name_scope('downsampling'):
            dsX = []
            dsX.append(UNet2D.tfData)

            for i in range(UNet2D.hp['nLayers']):
                dsX.append(down_samp_layer(dsX[i], i))

            b = lb(dsX[UNet2D.hp['nLayers']])

        # --------------------------------------------------
        # upsampling layer
        # --------------------------------------------------

        def up_samp_layer(data, index):
            # Transposed conv back to level `index`, concatenate the skip connection, then conv(s).
            with tf.name_scope('lu%d' % index):
                luXWeights1 = tf.Variable(
                    tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index + 1], nOutX[index + 2]],
                                        stddev=stdDev0), name='kernel1')
                luXWeights2 = tf.Variable(tf.truncated_normal(
                    [UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index] + nOutX[index + 1], nOutX[index + 1]],
                    stddev=stdDev0), name='kernel2')
                luXWeightsExtra = []
                for i in range(nExtraConvs):
                    luXWeightsExtra.append(tf.Variable(
                        tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index + 1], nOutX[index + 1]],
                                            stddev=stdDev0), name='kernel2Extra%d' % i))

                # Spatial size at level `index` (imSize divided by all earlier down-sampling factors).
                outSize = UNet2D.hp['imSize']
                for i in range(index):
                    outSize /= dsfX[i]
                outSize = int(outSize)

                outputShape = [UNet2D.hp['batchSize'], outSize, outSize, nOutX[index + 1]]
                us = tf.nn.relu(
                    tf.nn.conv2d_transpose(data, luXWeights1, outputShape, strides=[1, dsfX[index], dsfX[index], 1],
                                           padding='SAME'), name='conv1')
                cc = concat3([dsX[index], us])
                cv = tf.nn.relu(tf.nn.conv2d(cc, luXWeights2, strides=[1, 1, 1, 1], padding='SAME'), name='conv2')
                for i in range(nExtraConvs):
                    cv = tf.nn.relu(tf.nn.conv2d(cv, luXWeightsExtra[i], strides=[1, 1, 1, 1], padding='SAME'),
                                    name='conv2Extra%d' % i)
                return cv

        # --------------------------------------------------
        # final (top) layer
        # --------------------------------------------------

        with tf.name_scope('lt'):
            ltWeights1 = tf.Variable(tf.truncated_normal([1, 1, nOutX[1], nClasses], stddev=stdDev0), name='kernel')

            def lt(hidden):
                return tf.nn.conv2d(hidden, ltWeights1, strides=[1, 1, 1, 1], padding='SAME', name='conv')

        # --------------------------------------------------
        # upsampling
        # --------------------------------------------------

        with tf.name_scope('upsampling'):
            usX = []
            usX.append(b)

            for i in range(UNet2D.hp['nLayers']):
                usX.append(up_samp_layer(usX[i], UNet2D.hp['nLayers'] - 1 - i))

            t = lt(usX[UNet2D.hp['nLayers']])

        sm = tf.nn.softmax(t, -1)
        UNet2D.nn = sm

    @staticmethod
    def _loadSplit(imPath, indices, imSize, nChannels, nClasses, datasetMean, datasetStDev):
        """Load normalized images and one-hot label maps for the given sample indices.

        Sample ``idx`` is read from ``<imPath>/I%05d_Img.tif`` (channel 0 of the
        returned data, normalized by mean/st.dev.) and ``<imPath>/I%05d_Ant.tif``.
        Label channel ``c`` is 1 where the annotation equals ``c + 1``.

        Returns:
            (data, labels) arrays of shape (len(indices), imSize, imSize, nChannels)
            and (len(indices), imSize, imSize, nClasses).
        """
        data = np.zeros((len(indices), imSize, imSize, nChannels))
        labels = np.zeros((len(indices), imSize, imSize, nClasses))
        for iSample, idx in enumerate(indices):
            im = im2double(tifread('%s/I%05d_Img.tif' % (imPath, idx)))
            data[iSample, :, :, 0] = (im - datasetMean) / datasetStDev
            ant = tifread('%s/I%05d_Ant.tif' % (imPath, idx))
            for i in range(nClasses):
                labels[iSample, :, :, i] = (ant == i + 1)
        return data, labels

    @staticmethod
    def train(imPath, logPath, modelPath, pmPath, nTrain, nValid, nTest, restoreVariables, nSteps, gpuIndex,
              testPMIndex):
        """Train the network built by ``setup`` and write test-set previews.

        Samples ``0 .. nTrain+nValid+nTest-1`` are read from ``imPath`` (see
        ``_loadSplit``) and randomly split into train / validation / test sets.
        Images are normalized by the dataset mean and st.dev. These are the
        average of the per-image mean/std, computed here and saved to
        ``modelPath``, or reloaded from there when ``restoreVariables`` is true.

        Each step trains on one random batch and evaluates one random
        validation batch. ``<modelPath>/model.ckpt`` is saved every 100 steps
        if the validation error improved. Afterwards, test predictions for
        class ``testPMIndex`` are written to ``pmPath`` as PNGs (image |
        prediction | ground truth), and the hyper-parameters are written to
        ``<modelPath>/hp.data``.

        ``logPath`` is deleted if present; TensorBoard summaries are written to
        its ``Train`` and ``Valid`` subfolders. ``nTrain`` and ``nValid`` must be
        at least the batch size.
        """
        os.environ['CUDA_VISIBLE_DEVICES'] = '%d' % gpuIndex

        outLogPath = logPath
        trainWriterPath = pathjoin(logPath, 'Train')
        validWriterPath = pathjoin(logPath, 'Valid')
        outModelPath = pathjoin(modelPath, 'model.ckpt')
        outPMPath = pmPath

        batchSize = UNet2D.hp['batchSize']
        imSize = UNet2D.hp['imSize']
        nChannels = UNet2D.hp['nChannels']
        nClasses = UNet2D.hp['nClasses']
        nSamples = nTrain + nValid + nTest

        # --------------------------------------------------
        # data
        # --------------------------------------------------

        print('loading data, computing mean / st dev')
        if not os.path.exists(modelPath):
            os.makedirs(modelPath)
        if restoreVariables:
            datasetMean = loadData(pathjoin(modelPath, 'datasetMean.data'))
            datasetStDev = loadData(pathjoin(modelPath, 'datasetStDev.data'))
        else:
            datasetMean = 0
            datasetStDev = 0
            for iSample in range(nSamples):
                I = im2double(tifread('%s/I%05d_Img.tif' % (imPath, iSample)))
                datasetMean += np.mean(I)
                datasetStDev += np.std(I)
            datasetMean /= nSamples
            datasetStDev /= nSamples
            saveData(datasetMean, pathjoin(modelPath, 'datasetMean.data'))
            saveData(datasetStDev, pathjoin(modelPath, 'datasetStDev.data'))

        perm = np.arange(nSamples)
        np.random.shuffle(perm)

        splitArgs = (imSize, nChannels, nClasses, datasetMean, datasetStDev)
        Train, LTrain = UNet2D._loadSplit(imPath, perm[:nTrain], *splitArgs)
        Valid, LValid = UNet2D._loadSplit(imPath, perm[nTrain:nTrain + nValid], *splitArgs)
        Test, LTest = UNet2D._loadSplit(imPath, perm[nTrain + nValid:], *splitArgs)

        # --------------------------------------------------
        # optimization
        # --------------------------------------------------

        tfLabels = tf.placeholder("float", shape=[None, imSize, imSize, nClasses], name='labels')

        globalStep = tf.Variable(0, trainable=False)
        learningRate0 = 0.01
        decaySteps = 1000
        decayRate = 0.95
        learningRate = tf.train.exponential_decay(learningRate0, globalStep, decaySteps, decayRate, staircase=True)

        with tf.name_scope('optim'):
            # Clip before the log: a softmax that underflows to exactly 0 would
            # otherwise give 0 * log(0) = NaN and poison the loss.
            loss = tf.reduce_mean(
                -tf.reduce_sum(tf.multiply(tfLabels, tf.log(tf.clip_by_value(UNet2D.nn, 1e-10, 1.0))), 3))
            updateOps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # batch-norm moving statistics
            optimizer = tf.train.MomentumOptimizer(learningRate, 0.9)
            with tf.control_dependencies(updateOps):
                optOp = optimizer.minimize(loss, global_step=globalStep)

        with tf.name_scope('eval'):
            # Per-class error: 1 - (#pixels of class c predicted as c) / (#pixels of class c).
            error = []
            for iClass in range(nClasses):
                labels0 = tf.reshape(tf.to_int32(tf.slice(tfLabels, [0, 0, 0, iClass], [-1, -1, -1, 1])),
                                     [batchSize, imSize, imSize])
                predict0 = tf.reshape(tf.to_int32(tf.equal(tf.argmax(UNet2D.nn, 3), iClass)),
                                      [batchSize, imSize, imSize])
                correct = tf.multiply(labels0, predict0)
                nCorrect0 = tf.reduce_sum(correct)
                nLabels0 = tf.reduce_sum(labels0)
                error.append(1 - tf.to_float(nCorrect0) / tf.to_float(nLabels0))
            errors = tf.tuple(error)

        # --------------------------------------------------
        # inspection
        # --------------------------------------------------

        with tf.name_scope('scalars'):
            tf.summary.scalar('avg_cross_entropy', loss)
            for iClass in range(nClasses):
                tf.summary.scalar('avg_pixel_error_%d' % iClass, error[iClass])
            tf.summary.scalar('learning_rate', learningRate)
        with tf.name_scope('images'):
            split0 = tf.slice(UNet2D.nn, [0, 0, 0, 0], [-1, -1, -1, 1])
            split1 = tf.slice(UNet2D.nn, [0, 0, 0, 1], [-1, -1, -1, 1])
            if nClasses > 2:
                split2 = tf.slice(UNet2D.nn, [0, 0, 0, 2], [-1, -1, -1, 1])
            tf.summary.image('pm0', split0)
            tf.summary.image('pm1', split1)
            if nClasses > 2:
                tf.summary.image('pm2', split2)
        merged = tf.summary.merge_all()

        # --------------------------------------------------
        # session
        # --------------------------------------------------

        saver = tf.train.Saver()
        # allow_soft_placement: config parameter needed to save variables when using GPU.
        # The context manager guarantees the session is closed even on error.
        with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            if os.path.exists(outLogPath):
                shutil.rmtree(outLogPath)
            trainWriter = tf.summary.FileWriter(trainWriterPath, sess.graph)
            validWriter = tf.summary.FileWriter(validWriterPath, sess.graph)
            try:
                if restoreVariables:
                    saver.restore(sess, outModelPath)
                    print("Model restored.")
                else:
                    sess.run(tf.global_variables_initializer())

                # --------------------------------------------------
                # train
                # --------------------------------------------------

                batchData = np.zeros((batchSize, imSize, imSize, nChannels))
                batchLabels = np.zeros((batchSize, imSize, imSize, nClasses))
                for i in range(nSteps):
                    # train on a random batch
                    order = np.arange(nTrain)
                    np.random.shuffle(order)
                    batchData[:] = Train[order[:batchSize]]
                    batchLabels[:] = LTrain[order[:batchSize]]

                    summary, _ = sess.run([merged, optOp],
                                          feed_dict={UNet2D.tfData: batchData, tfLabels: batchLabels,
                                                     UNet2D.tfTraining: 1})
                    trainWriter.add_summary(summary, i)

                    # validate on a random batch
                    order = np.arange(nValid)
                    np.random.shuffle(order)
                    batchData[:] = Valid[order[:batchSize]]
                    batchLabels[:] = LValid[order[:batchSize]]

                    summary, es = sess.run([merged, errors],
                                           feed_dict={UNet2D.tfData: batchData, tfLabels: batchLabels,
                                                      UNet2D.tfTraining: 0})
                    validWriter.add_summary(summary, i)

                    e = np.mean(es)
                    print('step %05d, e: %f' % (i, e))

                    if i == 0:
                        # When resuming, only save once we beat the restored model's first error.
                        lowestError = e if restoreVariables else np.inf

                    if np.mod(i, 100) == 0 and e < lowestError:
                        lowestError = e
                        print("Model saved in file: %s" % saver.save(sess, outModelPath))

                # --------------------------------------------------
                # test
                # --------------------------------------------------

                if not os.path.exists(outPMPath):
                    os.makedirs(outPMPath)

                for i in range(nTest):
                    j = np.mod(i, batchSize)

                    batchData[j, :, :, :] = Test[i, :, :, :]
                    batchLabels[j, :, :, :] = LTest[i, :, :, :]

                    # Run once the batch is full or the test set is exhausted; a partial
                    # final batch still feeds batchSize rows, but only the first j+1 are used.
                    if j == batchSize - 1 or i == nTest - 1:
                        output = sess.run(UNet2D.nn,
                                          feed_dict={UNet2D.tfData: batchData, tfLabels: batchLabels,
                                                     UNet2D.tfTraining: 0})

                        for k in range(j + 1):
                            pm = output[k, :, :, testPMIndex]
                            gt = batchLabels[k, :, :, testPMIndex]
                            im = np.sqrt(normalize(batchData[k, :, :, 0]))
                            imwrite(np.uint8(255 * np.concatenate((im, np.concatenate((pm, gt), axis=1)), axis=1)),
                                    '%s/I%05d.png' % (outPMPath, i - j + k + 1))

                # --------------------------------------------------
                # save hyper-parameters
                # --------------------------------------------------

                saveData(UNet2D.hp, pathjoin(modelPath, 'hp.data'))
            finally:
                trainWriter.close()
                validWriter.close()

    @staticmethod
    def deploy(imPath, nImages, modelPath, pmPath, gpuIndex, pmIndex):
        """Run a saved model on ``<imPath>/I%05d_Img.tif`` for indices 0..nImages-1.

        Builds the graph from ``<modelPath>/hp.data`` and restores
        ``<modelPath>/model.ckpt``. Images are normalized with the saved
        dataset mean and st.dev. For each image, ``pmPath`` receives
        ``I%05d_Im.png`` (the input) and ``I%05d_PM.png`` (output channel
        ``pmIndex``), numbered from 1.
        """
        os.environ['CUDA_VISIBLE_DEVICES'] = '%d' % gpuIndex

        variablesPath = pathjoin(modelPath, 'model.ckpt')
        outPMPath = pmPath

        hp = loadData(pathjoin(modelPath, 'hp.data'))
        UNet2D.setupWithHP(hp)

        batchSize = UNet2D.hp['batchSize']
        imSize = UNet2D.hp['imSize']
        nChannels = UNet2D.hp['nChannels']

        # --------------------------------------------------
        # data
        # --------------------------------------------------

        Data = np.zeros((nImages, imSize, imSize, nChannels))

        datasetMean = loadData(pathjoin(modelPath, 'datasetMean.data'))
        datasetStDev = loadData(pathjoin(modelPath, 'datasetStDev.data'))

        for iSample in range(0, nImages):
            path = '%s/I%05d_Img.tif' % (imPath, iSample)
            im = im2double(tifread(path))
            Data[iSample, :, :, 0] = (im - datasetMean) / datasetStDev

        # --------------------------------------------------
        # session
        # --------------------------------------------------

        saver = tf.train.Saver()
        # allow_soft_placement: config parameter needed to save variables when using GPU
        with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            saver.restore(sess, variablesPath)
            print("Model restored.")

            # --------------------------------------------------
            # deploy
            # --------------------------------------------------

            batchData = np.zeros((batchSize, imSize, imSize, nChannels))

            if not os.path.exists(outPMPath):
                os.makedirs(outPMPath)

            for i in range(nImages):
                print(i, nImages)

                j = np.mod(i, batchSize)

                batchData[j, :, :, :] = Data[i, :, :, :]

                if j == batchSize - 1 or i == nImages - 1:

                    output = sess.run(UNet2D.nn, feed_dict={UNet2D.tfData: batchData, UNet2D.tfTraining: 0})

                    for k in range(j + 1):
                        pm = output[k, :, :, pmIndex]
                        im = np.sqrt(normalize(batchData[k, :, :, 0]))
                        imwrite(np.uint8(255 * im), '%s/I%05d_Im.png' % (outPMPath, i - j + k + 1))
                        imwrite(np.uint8(255 * pm), '%s/I%05d_PM.png' % (outPMPath, i - j + k + 1))

    @staticmethod
    def singleImageInferenceSetup(modelPath, gpuIndex, mean, std):
        """Build the graph from ``modelPath`` and open ``UNet2D.Session`` with restored weights.

        ``mean`` / ``std`` override the saved dataset statistics unless they are
        -1, in which case the values saved in ``modelPath`` are loaded.
        ``gpuIndex`` is accepted for API compatibility but not used here; the
        caller sets ``CUDA_VISIBLE_DEVICES``.
        """
        variablesPath = pathjoin(modelPath, 'model.ckpt')

        hp = loadData(pathjoin(modelPath, 'hp.data'))
        UNet2D.setupWithHP(hp)
        if mean == -1:
            UNet2D.DatasetMean = loadData(pathjoin(modelPath, 'datasetMean.data'))
        else:
            UNet2D.DatasetMean = mean

        if std == -1:
            UNet2D.DatasetStDev = loadData(pathjoin(modelPath, 'datasetStDev.data'))
        else:
            UNet2D.DatasetStDev = std
        print(UNet2D.DatasetMean)
        print(UNet2D.DatasetStDev)

        # --------------------------------------------------
        # session
        # --------------------------------------------------

        saver = tf.train.Saver()
        # NOTE: unlike train/deploy, this session is created without allow_soft_placement.
        UNet2D.Session = tf.Session(config=tf.ConfigProto())

        saver.restore(UNet2D.Session, variablesPath)
        print("Model restored.")

    @staticmethod
    def singleImageInferenceCleanup():
        """Close the session opened by ``singleImageInferenceSetup``; safe to call more than once."""
        if UNet2D.Session is not None:
            UNet2D.Session.close()
            UNet2D.Session = None

    @staticmethod
    def singleImageInference(image, mode, pmIndex):
        """Run the network patch-wise over ``image`` and return output channel ``pmIndex``.

        The image is partitioned into ``imSize`` x ``imSize`` patches by the
        toolbox ``PI2D`` helper. ``imSize / 8`` and ``mode`` (e.g.
        ``'accumulate'``) are passed to ``PI2D.setup``; presumably they set the
        patch overlap and the stitching mode (see PI2D). Each patch is
        normalized with ``DatasetMean`` / ``DatasetStDev`` and fed in batches
        of ``batchSize``. The stitched result comes from ``PI2D.getValidOutput()``.
        Only input channel 0 is filled, so this assumes a single-channel model.
        """
        print('Inference...')

        batchSize = UNet2D.hp['batchSize']
        imSize = UNet2D.hp['imSize']
        nChannels = UNet2D.hp['nChannels']

        PI2D.setup(image, imSize, int(imSize / 8), mode)
        PI2D.createOutput(nChannels)

        batchData = np.zeros((batchSize, imSize, imSize, nChannels))
        for i in range(PI2D.NumPatches):
            j = np.mod(i, batchSize)
            batchData[j, :, :, 0] = (PI2D.getPatch(i) - UNet2D.DatasetMean) / UNet2D.DatasetStDev
            # A partial final batch still feeds batchSize rows; only the first j+1 outputs are used.
            if j == batchSize - 1 or i == PI2D.NumPatches - 1:
                output = UNet2D.Session.run(UNet2D.nn, feed_dict={UNet2D.tfData: batchData, UNet2D.tfTraining: 0})
                for k in range(j + 1):
                    pm = output[k, :, :, pmIndex]
                    PI2D.patchOutput(i - j + k, pm)

        return PI2D.getValidOutput()
| 539 | |
| 540 | |
| 541 if __name__ == '__main__': | |
| 542 parser = argparse.ArgumentParser() | |
| 543 parser.add_argument("imagePath", help="path to the .tif file") | |
| 544 parser.add_argument("--model", help="type of model. For example, nuclei vs cytoplasm",default = 'nucleiDAPI') | |
| 545 parser.add_argument("--outputPath", help="output path of probability map") | |
| 546 parser.add_argument("--channel", help="channel to perform inference on", type=int, default=0) | |
| 547 parser.add_argument("--classOrder", help="background, contours, foreground", type = int, nargs = '+', default=-1) | |
| 548 parser.add_argument("--mean", help="mean intensity of input image. Use -1 to use model", type=float, default=-1) | |
| 549 parser.add_argument("--std", help="mean standard deviation of input image. Use -1 to use model", type=float, default=-1) | |
| 550 parser.add_argument("--scalingFactor", help="factor by which to increase/decrease image size by", type=float, | |
| 551 default=1) | |
| 552 parser.add_argument("--stackOutput", help="save probability maps as separate files", action='store_true') | |
| 553 parser.add_argument("--GPU", help="explicitly select GPU", type=int, default = -1) | |
| 554 parser.add_argument("--outlier", | |
| 555 help="map percentile intensity to max when rescaling intensity values. Max intensity as default", | |
| 556 type=float, default=-1) | |
| 557 args = parser.parse_args() | |
| 558 | |
| 559 logPath = '' | |
| 560 scriptPath = os.path.dirname(os.path.realpath(__file__)) | |
| 561 modelPath = os.path.join(scriptPath, 'models', args.model) | |
| 562 # modelPath = os.path.join(scriptPath, 'models/cytoplasmINcell') | |
| 563 # modelPath = os.path.join(scriptPath, 'cytoplasmZeissNikon') | |
| 564 pmPath = '' | |
| 565 | |
| 566 if os.system('nvidia-smi') == 0: | |
| 567 if args.GPU == -1: | |
| 568 print("automatically choosing GPU") | |
| 569 GPU = GPUselect.pick_gpu_lowest_memory() | |
| 570 else: | |
| 571 GPU = args.GPU | |
| 572 print('Using GPU ' + str(GPU)) | |
| 573 | |
| 574 else: | |
| 575 if sys.platform == 'win32': # only 1 gpu on windows | |
| 576 if args.GPU==-1: | |
| 577 GPU = 0 | |
| 578 else: | |
| 579 GPU = args.GPU | |
| 580 print('Using GPU ' + str(GPU)) | |
| 581 else: | |
| 582 GPU=0 | |
| 583 print('Using CPU') | |
| 584 os.environ['CUDA_VISIBLE_DEVICES'] = '%d' % GPU | |
| 585 UNet2D.singleImageInferenceSetup(modelPath, GPU,args.mean,args.std) | |
| 586 nClass = UNet2D.hp['nClasses'] | |
| 587 imagePath = args.imagePath | |
| 588 dapiChannel = args.channel | |
| 589 dsFactor = args.scalingFactor | |
| 590 parentFolder = os.path.dirname(os.path.dirname(imagePath)) | |
| 591 fileName = os.path.basename(imagePath) | |
| 592 fileNamePrefix = fileName.split(os.extsep, 1) | |
| 593 print(fileName) | |
| 594 fileType = fileNamePrefix[1] | |
| 595 | |
| 596 if fileType=='ome.tif' or fileType == 'btf' : | |
| 597 I = skio.imread(imagePath, img_num=dapiChannel,plugin='tifffile') | |
| 598 elif fileType == 'tif' : | |
| 599 I = tifffile.imread(imagePath, key=dapiChannel) | |
| 600 elif fileType == 'czi': | |
| 601 with czifile.CziFile(imagePath) as czi: | |
| 602 image = czi.asarray() | |
| 603 I = image[0, 0, dapiChannel, 0, 0, :, :, 0] | |
| 604 elif fileType == 'nd2': | |
| 605 with ND2Reader(imagePath) as fullStack: | |
| 606 I = fullStack[dapiChannel] | |
| 607 | |
| 608 if args.classOrder == -1: | |
| 609 args.classOrder = range(nClass) | |
| 610 | |
| 611 rawI = I | |
| 612 print(type(I)) | |
| 613 hsize = int((float(I.shape[0]) * float(dsFactor))) | |
| 614 vsize = int((float(I.shape[1]) * float(dsFactor))) | |
| 615 I = resize(I, (hsize, vsize)) | |
| 616 if args.outlier == -1: | |
| 617 maxLimit = np.max(I) | |
| 618 else: | |
| 619 maxLimit = np.percentile(I, args.outlier) | |
| 620 I = im2double(sk.rescale_intensity(I, in_range=(np.min(I), maxLimit), out_range=(0, 0.983))) | |
| 621 rawI = im2double(rawI) / np.max(im2double(rawI)) | |
| 622 if not args.outputPath: | |
| 623 args.outputPath = parentFolder + '//probability_maps' | |
| 624 | |
| 625 if not os.path.exists(args.outputPath): | |
| 626 os.makedirs(args.outputPath) | |
| 627 | |
| 628 append_kwargs = { | |
| 629 'bigtiff': True, | |
| 630 'metadata': None, | |
| 631 'append': True, | |
| 632 } | |
| 633 save_kwargs = { | |
| 634 'bigtiff': True, | |
| 635 'metadata': None, | |
| 636 'append': False, | |
| 637 } | |
| 638 if args.stackOutput: | |
| 639 slice=0 | |
| 640 for iClass in args.classOrder[::-1]: | |
| 641 PM = np.uint8(255*UNet2D.singleImageInference(I, 'accumulate', iClass)) # backwards in order to align with ilastik... | |
| 642 PM = resize(PM, (rawI.shape[0], rawI.shape[1])) | |
| 643 if slice==0: | |
| 644 skimage.io.imsave(args.outputPath + '//' + fileNamePrefix[0] + '_Probabilities_' + str(dapiChannel) + '.tif', np.uint8(255 * PM),**save_kwargs) | |
| 645 else: | |
| 646 skimage.io.imsave(args.outputPath + '//' + fileNamePrefix[0] + '_Probabilities_' + str(dapiChannel) + '.tif',np.uint8(255 * PM),**append_kwargs) | |
| 647 if slice==1: | |
| 648 save_kwargs['append'] = False | |
| 649 skimage.io.imsave(args.outputPath + '//' + fileNamePrefix[0] + '_Preview_' + str(dapiChannel) + '.tif', np.uint8(255 * PM), **save_kwargs) | |
| 650 skimage.io.imsave(args.outputPath + '//' + fileNamePrefix[0] + '_Preview_' + str(dapiChannel) + '.tif', np.uint8(255 * rawI), **append_kwargs) | |
| 651 slice = slice + 1 | |
| 652 | |
| 653 else: | |
| 654 contours = np.uint8(255*UNet2D.singleImageInference(I, 'accumulate', args.classOrder[1])) | |
| 655 hsize = int((float(I.shape[0]) * float(1 / dsFactor))) | |
| 656 vsize = int((float(I.shape[1]) * float(1 / dsFactor))) | |
| 657 contours = resize(contours, (rawI.shape[0], rawI.shape[1])) | |
| 658 skimage.io.imsave(args.outputPath + '//' + fileNamePrefix[0] + '_ContoursPM_' + str(dapiChannel) + '.tif',np.uint8(255 * contours),**save_kwargs) | |
| 659 skimage.io.imsave(args.outputPath + '//' + fileNamePrefix[0] + '_ContoursPM_' + str(dapiChannel) + '.tif',np.uint8(255 * rawI), **append_kwargs) | |
| 660 del contours | |
| 661 nuclei = np.uint8(255*UNet2D.singleImageInference(I, 'accumulate', args.classOrder[2])) | |
| 662 nuclei = resize(nuclei, (rawI.shape[0], rawI.shape[1])) | |
| 663 skimage.io.imsave(args.outputPath + '//' + fileNamePrefix[0] + '_NucleiPM_' + str(dapiChannel) + '.tif',np.uint8(255 * nuclei), **save_kwargs) | |
| 664 del nuclei | |
| 665 UNet2D.singleImageInferenceCleanup() | |
| 666 | |
| 667 #aligned output files to reflect ilastik | |
| 668 #outputting all classes as single file | |
| 669 #handles multiple formats including tif, ome.tif, nd2, czi | |
| 670 #selectable models (human nuclei, mouse nuclei, cytoplasm) | |
| 671 | |
| 672 #added legacy function to save output files | |
| 673 #append save function to reduce memory footprint | |
| 674 #added --classOrder parameter to specify which class is background, contours, and nuclei respectively |
