声明:此篇文章是个人学习笔记,并非教程,所以内容可能不够严谨。可作参考,但不保证绝对正确。如果你发现我的文章有什么错误,非常欢迎指正,谢谢哦


torch.mean参数简单介绍

#求所有元素的平均值:input是要处理的张量;返回值是1个数,张量形式
torch.mean(input, *, dtype=None) → Tensor 

#沿张量中某个维度求平均值:input是要处理的张量,dim是想求的维度,keepdim是否保留长度为1的维度;返回值是张量形式,维数默认和原张量一样。
torch.mean(input, dim, keepdim=False, *, dtype=None, out=None) → Tensor

实验验证函数和参数作用

数据集

以batch=2,channel=2,hight=4, width=4的图片为例,即维度为[2, 2, 4, 4]的张量:
batch1:(红色、橙色为两个不同的通道,下同)
在这里插入图片描述
batch2:
在这里插入图片描述

以下用代码实验:
data = torch.tensor(
        [[[ [9.0, 0, 7, 6],
            [3, 2, 6, 8],
            [7, 5, 4, 4],
            [4, 8, 3, 5]],

         [  [3, 8, 7, 2],
            [9, 6, 1, 2],
            [2, 0, 8, 0],
            [2, 9, 8, 4]]],

         [[ [6, 1, 5, 6],
            [2, 3, 4, 8],
            [5, 3, 3, 3],
            [4, 1, 8, 4]],

          [ [3, 6, 5, 4],
            [4, 9, 8, 5],
            [7, 1, 5, 4],
            [4, 4, 8, 6]]]
         ]
)
total_mean = torch.mean(data)
print(total_mean)
print(total_mean.size())
#输出结果为:
#tensor(4.7031),也就是所有元素的平均值
#torch.Size([])
mean_data = torch.mean(data, dim=0, keepdim=True)
print('dim=0', mean_data)
print(mean_data.size())
#dim=0 tensor([[[[7.5000, 0.5000, 6.0000, 6.0000],
#         [2.5000, 2.5000, 5.0000, 8.0000],
#         [6.0000, 4.0000, 3.5000, 3.5000],
#         [4.0000, 4.5000, 5.5000, 4.5000]],
#
#        [[3.0000, 7.0000, 6.0000, 3.0000],
#         [6.5000, 7.5000, 4.5000, 3.5000],
#         [4.5000, 0.5000, 6.5000, 2.0000],
#         [3.0000, 6.5000, 8.0000, 5.0000]]]])
#torch.Size([1, 2, 4, 4])
#即对不同batch求平均值变为1个batch
mean_data = torch.mean(data, dim=1, keepdim=True)
print('dim=1', mean_data)
print(mean_data.size())
#dim=1 tensor([[[[6.0000, 4.0000, 7.0000, 4.0000],
#          [6.0000, 4.0000, 3.5000, 5.0000],
#          [4.5000, 2.5000, 6.0000, 2.0000],
#          [3.0000, 8.5000, 5.5000, 4.5000]]],
#
#        [[[4.5000, 3.5000, 5.0000, 5.0000],
#          [3.0000, 6.0000, 6.0000, 6.5000],
#          [6.0000, 2.0000, 4.0000, 3.5000],
#          [4.0000, 2.5000, 8.0000, 5.0000]]]])
#torch.Size([2, 1, 4, 4])
#即对不同通道求平均值变为1个通道
mean_data = torch.mean(data, dim=2, keepdim=True)
print('dim=2', mean_data)
print(mean_data.size())
#dim=2 tensor([[[[5.7500, 3.7500, 5.0000, 5.7500]],
#
#         [[4.0000, 5.7500, 6.0000, 2.0000]]],
#
#        [[[4.2500, 2.0000, 5.0000, 5.2500]],
#
#         [[4.5000, 5.0000, 6.5000, 4.7500]]]])
#torch.Size([2, 2, 1, 4])
#即对图片不同行求平均值变为1行
mean_data = torch.mean(data, dim=3, keepdim=True)
print('dim=3', mean_data)
print(mean_data.size())
#dim=3 tensor([[[[5.5000],
#          [4.7500],
#          [5.0000],
#          [5.0000]],
#
#         [[5.0000],
#          [4.5000],
#          [2.5000],
#          [5.7500]]],
#
#        [[[4.5000],
#          [4.2500],
#          [3.5000],
#          [4.2500]],
#
#         [[4.5000],
#          [6.5000],
#          [4.2500],
#          [5.5000]]]])
#torch.Size([2, 2, 4, 1])
#即对图片不同列求平均值变为1列

上述实验可验证:
dim指明的维度对应张量维度的索引,即0对应最外层batch,1对应channel,2对应hight, 3对应width。

mean_data = torch.mean(data, dim=3, keepdim=False)
print('dim=3', mean_data)
print(mean_data.size())
#dim=3 tensor([[[5.5000, 4.7500, 5.0000, 5.0000],
#         [5.0000, 4.5000, 2.5000, 5.7500]],
#
#        [[4.5000, 4.2500, 3.5000, 4.2500],
#         [4.5000, 6.5000, 4.2500, 5.5000]]])
#torch.Size([2, 2, 4])

上述实验可验证:
keepdim表示是否保留长度为1的维度,默认情况下为False。也就是说,由于对列求平均值后,列数变为1了,当keepdim为False时,就把这个维度去除,原来的四维张量变为三维张量。

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐