torch.max()

torch.max(input, dim, keepdim=False) → output tensors (max, max_indices)
输入参数:
input = 输入tensor
dim = 求最大值的维度
keepdim = 是否保持原维度大小输出
输出:
max = 指定维度求得的最大值
max_indices = 指定维度求得的最大值索引

下面以一个大小为(3, 2, 5)的张量为例:
dim = 0

import torch
x = torch.rand(3, 2, 5) # 生成随机数
print(x)
>>> tensor([
			[[0.2514, 0.7950, 0.9641, 0.0135, 0.2785],
         	[0.2575, 0.4410, 0.6829, 0.6668, 0.5850]],

        	[[0.4725, 0.2015, 0.3406, 0.6989, 0.3551],
        	[0.9674, 0.5781, 0.6250, 0.3404, 0.4238]],
        	
        	[[0.2377, 0.3673, 0.3647, 0.1027, 0.9024],
        	[0.0047, 0.0106, 0.4600, 0.6851, 0.7389]]])
        	
x_value_index = torch.max(x, dim=0, keepdim=True) # 最大值和对应索引
print(x_value_index)
>>> torch.return_types.max(
	values=tensor([[
					[0.4725, 0.7950, 0.9641, 0.6989, 0.9024],
					[0.9674, 0.5781, 0.6829, 0.6851, 0.7389]
					]]),
	indices=tensor([[
					[1, 0, 0, 1, 2],
	         		[1, 1, 0, 2, 2]
	         		]])
	         		)

x_value = torch.max(x, 2, keepdim=True)[0]  # 单独取出最大值
print(x_value)
>>> tensor([[[0.9641],[0.6829]],[[0.6989],[0.9674]],[[0.9024],[0.7389]]])

x_index = torch.max(x, 2, keepdim=True)[1]  #单独取出最大值索引
print(x_index)
>>> tensor([[[2],[2]], [[3],[0]],[[4],[4]]])

dim = 1

import torch
x = torch.rand(3, 2, 5)
print(x)
>>> tensor([
			[[0.5524, 0.1146, 0.4460, 0.4948, 0.7163],
             [0.5388, 0.2290, 0.4652, 0.3818, 0.4202]],

     		[[0.4045, 0.5833, 0.7844, 0.5605, 0.6278],
         	 [0.0335, 0.1204, 0.3604, 0.4386, 0.0286]],

       		[[0.9510, 0.7801, 0.2879, 0.0369, 0.8103],
         	 [0.9522, 0.7442, 0.5938, 0.1807, 0.2721]]])

x_value_index = torch.max(x, dim=1, keepdim=True)
print(x_value_index)
>>> torch.return_types.max(
values=tensor([
			[[0.5524, 0.2290, 0.4652, 0.4948, 0.7163]],
			[[0.4045, 0.5833, 0.7844, 0.5605, 0.6278]],
			[[0.9522, 0.7801, 0.5938, 0.1807, 0.8103]]
			]),
indices=tensor([
			[[0, 1, 1, 0, 0]],
       		[[0, 0, 0, 0, 0]],
			[[1, 0, 1, 1, 0]]
			])
			)

x_value = torch.max(x, 2, keepdim=True)[0]
print(x_value)
>>> tensor([[[0.7163],[0.5388]],[[0.7844],[0.4386]],[[0.9510],[0.9522]]])

x_index = torch.max(x, 2, keepdim=True)[1]
print(x_index)
>>> tensor([[[4],[0]],[[2],[3]],[[0],[0]]])

dim = 2

import torch
x = torch.rand(3, 2, 5)  #  生成随机数
print(x)
>>>tensor([
		 [[0.9249, 0.5676, 0.1035, 0.3701, 0.4501],
         [0.5440, 0.9992, 0.7398, 0.1513, 0.3889]],

         [[0.2020, 0.4533, 0.1103, 0.9006, 0.8098],
         [0.3390, 0.3230, 0.8531, 0.1718, 0.4343]],

         [[0.9874, 0.2138, 0.0301, 0.9558, 0.8844],
         [0.7317, 0.3344, 0.4552, 0.3196, 0.6343]]
         ])

x_value_index = torch.max(x, dim = 2, keepdim=True) # 取每一行的最大值
print(x_value_index)
>>> torch.return_types.max(
values=tensor([[[0.9249],[0.9992]],[[0.9006],[0.8531]],[[0.9874],[0.7317]]]),
indices=tensor([[[0],[1]],[[3],[2]],[[0],[0]]]))

x_value = torch.max(x, 2, keepdim=True)[0]
print(x_value)
>>>tensor([[[0.9249],[0.9992]],[[0.9006],[0.8531]],[[0.9874],[0.7317]]])

x_index = torch.max(x, 2, keepdim=True)[1]
print(x_index)
>>> tensor([[[0],[1]],[[3],[2]],[[0],[0]]])
Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐