logits是网络的输出,logits.shape=(batch_size, w, h, 21),21类语义标签。

pred_classes = tf.expand_dims(tf.argmax(logits, axis=3, output_type=tf.int32), axis=3)#shape=(?, ?, ?, 1)

我们用numpy解释argmax和expand_dims这两个函数:

import numpy as np

#(2, 2, 4, 3)
x = np.array([[[[31,20,10],
	        [20,43,30],
	        [40,10,62],
	        [40,60,76]],
	       [[10,72,20],
	        [81,30,40],
	        [97,50,70],
	        [40,50,68]]],

	      [[[10,22,10],
	        [20,30,81],
	        [40,10,62],
	        [40,65,30]],
	       [[10,72,20],
	        [81,30,40],
	        [97,50,70],
	        [40,50,68]]]])

#(2, 2, 4)
y1 = np.argmax(x, axis=3)
'''[[[0 1 2 2]
     [1 0 0 2]]

    [[1 2 2 1]
     [1 0 0 2]]]'''

#(2, 2, 4, 1)
y2 = np.expand_dims(y1, axis=3)
'''[[[[0]
      [1]
      [2]
      [2]]

     [[1]
      [0]
      [0]
      [2]]]


    [[[1]
      [2]
      [2]
      [1]]

     [[1]
      [0]
      [0]
      [2]]]]'''

y2的值是每个一维列表的最大值的下标,如第一个值为0,是因为[31,20,10]中最大元素31的下标是0。

batch_size为1时的网络输出:

[[[31,20,10],
  [20,43,30],
  [40,10,62],
  [40,60,76]],
 [[10,72,20],
  [81,30,40],
  [97,50,70],
  [40,50,68]]]

注意图像大小是2×4,而不是4×3或3×4。所以[31,20,10]是第一个像素被归类为第0类、第1类、第2类的概率。因为31最大,所以该像素的语义标签被归类为0。这样就可以解释y2:

[[[0]
  [1]
  [2]
  [2]]

 [[1]
  [0]
  [0]
  [2]]]

batch_size=1时,它是指一个2×4的图像的第(0,0)个像素标签为0、第(0,1)个像素标签为1、... 、第(1,3)个像素标签为2。

Logo

瓜分20万奖金 获得内推名额 丰厚实物奖励 易参与易上手

更多推荐