Understanding dim in PyTorch
Definition of dim
dim stands for dimension (axis).
import torch
x = torch.randn(2, 3, 3)
print(x)
print(x.size())
print(x.dim())
Output:
tensor([[[-1.6943, -2.1487, 1.2332],
[-0.2261, -0.1596, 1.5513],
[ 2.0383, -0.6982, -2.1481]],
[[ 0.4201, -2.7373, 0.2424],
[-1.1152, 1.3682, -1.8322],
[ 0.1957, -0.2920, 0.1845]]])
torch.Size([2, 3, 3])
3
This is hard to read as printed, but it becomes clearer once the brackets are formatted:
[
[
[-1.6943, -2.1487, 1.2332],
[-0.2261, -0.1596, 1.5513],
[ 2.0383, -0.6982, -2.1481]
],
[
[ 0.4201, -2.7373, 0.2424],
[-1.1152, 1.3682, -1.8322],
[ 0.1957, -0.2920, 0.1845]
]
]
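With this layout the dimensions (2, 3, 3) are obvious: they run from the outermost brackets to the innermost. x.dim() = 3 means x has three dimensions, dim = 0, 1, 2. dim 0 corresponds to the 2 in x.size() = (2, 3, 3), dim 1 corresponds to the first 3, and dim 2 corresponds to the last 3.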
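The following is a minimal sketch that prints the length of each dim, including the negative index for each one. It uses torch.arange instead of torch.randn so the values are reproducible. That substitution is my assumption for illustration and is not from the original post.

```python
import torch

# torch.arange instead of torch.randn, so the values are reproducible
x = torch.arange(18).reshape(2, 3, 3)

print(x.dim())   # 3
print(x.size())  # torch.Size([2, 3, 3])

# x.size(d) is the length of dimension d; a negative d counts from the end
for d in range(x.dim()):
    print(d, x.size(d), x.size(d - x.dim()))
```

The loop prints 0 2 2, 1 3 3 and 2 3 3. Dim 0 has length 2, and dims 1 and 2 each have length 3. Dim -1 refers to the same axis as dim 2.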
Understanding dim
With dim = 0, the elements along that dimension are x[0] and x[1], and each one is a (3, 3) tensor.
That is:
import torch
x = torch.randn(2, 3, 3)
print(x)
# iterating over x walks along dim 0
for i in x:
    print(i)
    print(i.size())
Output:
tensor([[[-1.4251, -0.8321, 1.0230],
[ 0.2008, 0.5929, -0.7696],
[-0.3721, -1.0837, -0.6642]],
[[-0.5337, 0.7808, 0.4419],
[-0.4683, 0.3847, 0.0747],
[ 1.0156, -0.4933, 1.5340]]])
tensor(
[
[-1.4251, -0.8321, 1.0230],
[ 0.2008, 0.5929, -0.7696],
[-0.3721, -1.0837, -0.6642]
]
)
torch.Size([3, 3])
tensor(
[
[-0.5337, 0.7808, 0.4419],
[-0.4683, 0.3847, 0.0747],
[ 1.0156, -0.4933, 1.5340]
]
)
torch.Size([3, 3])
So dim = 0 amounts to removing the dim-0 dimension of x. What remains are the slices x[i], each of size (3, 3).
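The sketch below checks the same idea without random values. Iterating over a tensor returns the same slices as torch.unbind(x, dim=0). If you unbind along dim 1 instead, the middle dimension is removed. The arange-based tensor is an illustrative assumption, not taken from the post.

```python
import torch

x = torch.arange(18).reshape(2, 3, 3)

# Iterating over x walks along dim 0: each item is x[i] with shape (3, 3)
items = [t for t in x]
# torch.unbind(x, dim=0) returns the same slices as a tuple
slices0 = torch.unbind(x, dim=0)
print(len(slices0), slices0[0].shape)  # 2 torch.Size([3, 3])
print(all(torch.equal(a, b) for a, b in zip(items, slices0)))  # True

# Unbinding along dim 1 removes the middle dimension: 3 slices of shape (2, 3)
slices1 = torch.unbind(x, dim=1)
print(len(slices1), slices1[0].shape)  # 3 torch.Size([2, 3])
print(torch.equal(slices1[0], x[:, 0, :]))  # True
```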
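Verification
- torch.argmax(tensor)
Returns the index of the largest value in the tensor. When a dim is given, it compares slices of the same shape against each other.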
Example:
>>> x = torch.tensor([1, 5, 8, 4, 6])
>>> torch.argmax(x)
tensor(2)
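The example above is one-dimensional. For a multi-dimensional tensor with no dim argument, torch.argmax returns an index into the flattened tensor. The small 2x3 tensor below is made up for illustration. The divmod step converts the flat index back to (row, column), assuming row-major order.

```python
import torch

x = torch.tensor([[1, 5, 8],
                  [4, 9, 6]])

# Without dim, argmax treats x as flattened: 9 is element 4 of [1, 5, 8, 4, 9, 6]
flat = torch.argmax(x)
print(flat)  # tensor(4)

# Convert the flat index back to (row, col) using the number of columns
row, col = divmod(flat.item(), x.size(1))
print(row, col)     # 1 1
print(x[row, col])  # tensor(9)
```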
import torch
x = torch.randn(2, 3, 3)
print(x)
print('='*50, end='\n\n')
# each i is one slice along dim 0
for i in x:
    print(i)
    print(i.size())
print('='*50, end='\n\n')
print(x.size())
print(x.dim())
print('='*50, end='\n\n')
y = torch.argmax(x, dim=0)
print(y)
print(y.size())
Output:
tensor(
[
[
[-1.3918, 0.0620, -0.4111],
[ 1.9623, -1.3399, -0.4673],
[-0.0185, -1.9024, 0.1340]
],
[
[ 0.7135, -0.5290, -0.7656],
[ 0.2642, 0.5956, -0.0718],
[-0.7465, -0.8098, -0.0874]
]
]
)
==================================================
tensor([[-1.3918, 0.0620, -0.4111],
[ 1.9623, -1.3399, -0.4673],
[-0.0185, -1.9024, 0.1340]])
torch.Size([3, 3])
tensor([[ 0.7135, -0.5290, -0.7656],
[ 0.2642, 0.5956, -0.0718],
[-0.7465, -0.8098, -0.0874]])
torch.Size([3, 3])
==================================================
torch.Size([2, 3, 3])
3
==================================================
tensor([[1, 0, 0],
[0, 1, 1],
[0, 1, 0]])
torch.Size([3, 3])
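To make this result reproducible, the sketch below types in the values printed above instead of calling torch.randn. Dim 0 has only two slices, so argmax returns 1 exactly where x[1] is larger than x[0]. The final comparison checks that equivalence. It applies to this data, which has no ties.

```python
import torch

# The values printed above, typed in so the result is reproducible
x = torch.tensor([
    [[-1.3918,  0.0620, -0.4111],
     [ 1.9623, -1.3399, -0.4673],
     [-0.0185, -1.9024,  0.1340]],
    [[ 0.7135, -0.5290, -0.7656],
     [ 0.2642,  0.5956, -0.0718],
     [-0.7465, -0.8098, -0.0874]],
])

y = torch.argmax(x, dim=0)
print(y.tolist())  # [[1, 0, 0], [0, 1, 1], [0, 1, 0]]
print(y.size())    # torch.Size([3, 3])

# With only two slices along dim 0, argmax is 1 exactly where x[1] > x[0]
print(torch.equal(y, (x[1] > x[0]).long()))  # True
```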
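Analysis
Why is y[0] = [1, 0, 0]?
There are two possible interpretations:
- It compares [-1.3918, 0.0620, -0.4111] with [ 0.7135, -0.5290, -0.7656] element by element:
[-1.3918, 0.7135]: 0.7135 is larger, so it returns 1
[0.0620, -0.5290]: 0.0620 is larger, so it returns 0
[-0.4111, -0.7656]: -0.4111 is larger, so it returns 0
- It compares each column within x[i], which would give a 2x3 output. For example, x[0] has the rows [-1.3918, 0.0620, -0.4111], [ 1.9623, -1.3399, -0.4673], [-0.0185, -1.9024, 0.1340], and running torch.argmax on each of its columns gives [1, 0, 2].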
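The second interpretation is not what dim=0 computes, but it does match another setting. The column-wise comparison inside each x[i] is what dim=1 computes, and the row-wise comparison is what dim=2 computes. The sketch below reuses the printed values to show this.

```python
import torch

x = torch.tensor([
    [[-1.3918,  0.0620, -0.4111],
     [ 1.9623, -1.3399, -0.4673],
     [-0.0185, -1.9024,  0.1340]],
    [[ 0.7135, -0.5290, -0.7656],
     [ 0.2642,  0.5956, -0.0718],
     [-0.7465, -0.8098, -0.0874]],
])

# Column-wise argmax inside each x[i] is argmax over dim 1: (2, 3, 3) loses its middle 3
y1 = torch.argmax(x, dim=1)
print(y1.tolist())  # [[1, 0, 2], [0, 1, 1]]
print(y1.size())    # torch.Size([2, 3])

# Row-wise argmax inside each x[i] is argmax over dim 2 (the last dim)
y2 = torch.argmax(x, dim=2)
print(y2.tolist())  # [[1, 0, 2], [0, 1, 2]]
```

The first row of y1 is the [1, 0, 2] from the second interpretation. That result belongs to dim=1, not dim=0.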
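If we remove the dim = 0 part, x' consists of
[ [-1.3918, 0.0620, -0.4111], [ 1.9623, -1.3399, -0.4673], [-0.0185, -1.9024, 0.1340] ] and [ [ 0.7135, -0.5290, -0.7656], [ 0.2642, 0.5956, -0.0718], [-0.7465, -0.8098, -0.0874] ],
that is, two tensors of size (3, 3). This shows why the second interpretation does not apply. argmax compares the two tensors against each other. The second interpretation instead compares values inside each tensor separately and then combines the two results.
- Summary: argmax compares the slices left after removing the specified dimension, element by element. Here (2, 3, 3) -> (3, 3), so the result has the original size with the specified dim removed.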
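A short sketch can check the size rule from the summary for every dim. It also uses the keepdim=True argument of torch.argmax, which keeps the reduced dimension with length 1 instead of dropping it. Random values are fine here because only the shapes are checked.

```python
import torch

x = torch.randn(2, 3, 3)

for d in range(x.dim()):
    y = torch.argmax(x, dim=d)
    # The result has the size of x with dimension d removed
    expected = tuple(x.shape[:d]) + tuple(x.shape[d + 1:])
    print(d, tuple(y.shape), tuple(y.shape) == expected)

# keepdim=True keeps the reduced dimension with length 1 instead of removing it
print(torch.argmax(x, dim=0, keepdim=True).shape)  # torch.Size([1, 3, 3])
```

The loop prints 0 (3, 3) True, 1 (2, 3) True and 2 (2, 3) True.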
With only two dimensions it may be easier to see:
import torch
x = torch.randn(2, 3)
print(x)
y = torch.argmax(x, dim=0)
print(y)
print(y.size())
Output:
tensor(
[
[ 0.0251, -0.3640, 0.1965],
[ 0.6902, 0.9846, 0.2035]
]
)
tensor([1, 1, 1])
torch.Size([3])
After removing dim = 0, the comparison is between [ 0.0251, -0.3640, 0.1965] and [ 0.6902, 0.9846, 0.2035], and the size goes from (2, 3) -> (3).
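Here is the same two-dimensional case as a reproducible sketch. It uses the printed values instead of torch.randn and shows dim=1 alongside dim=0 for contrast.

```python
import torch

# The 2x3 values from the output above
x = torch.tensor([[0.0251, -0.3640, 0.1965],
                  [0.6902,  0.9846, 0.2035]])

# dim=0: compare the two rows element by element -> one index per column
print(torch.argmax(x, dim=0).tolist())  # [1, 1, 1]
# dim=1: compare values within each row -> one index per row
print(torch.argmax(x, dim=1).tolist())  # [2, 1]
```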
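Now come back to the three-dimensional example above:
[ [-1.3918, 0.0620, -0.4111], [ 1.9623, -1.3399, -0.4673], [-0.0185, -1.9024, 0.1340] ], [ [ 0.7135, -0.5290, -0.7656], [ 0.2642, 0.5956, -0.0718], [-0.7465, -0.8098, -0.0874] ]
For the first row, comparing these two is equivalent to calling torch.argmax(..., dim=0) on the following tensor:
[ [-1.3918, 0.0620, -0.4111], [ 0.7135, -0.5290, -0.7656] ]
which gives y[0] = [1, 0, 0].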
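The sketch below checks this row-by-row view for every row j, not only the first. In each case, y[j] equals argmax over dim 0 of the (2, 3) slice x[:, j, :]. It also includes a plain-Python loop that picks, for each position (j, k), the index i of the largest x[i][j][k]. The loop is written only for illustration and is not how PyTorch implements argmax.

```python
import torch

x = torch.tensor([
    [[-1.3918,  0.0620, -0.4111],
     [ 1.9623, -1.3399, -0.4673],
     [-0.0185, -1.9024,  0.1340]],
    [[ 0.7135, -0.5290, -0.7656],
     [ 0.2642,  0.5956, -0.0718],
     [-0.7465, -0.8098, -0.0874]],
])
y = torch.argmax(x, dim=0)

# The tensor compared for the first row: x[:, 0, :], shape (2, 3)
first = x[:, 0, :]
print(torch.argmax(first, dim=0).tolist())  # [1, 0, 0]

# The same holds for every row j
print(all(torch.equal(torch.argmax(x[:, j, :], dim=0), y[j]) for j in range(x.size(1))))  # True

# Plain-Python equivalent: for each (j, k), pick the i whose x[i][j][k] is largest
vals = x.tolist()
manual = [[max(range(len(vals)), key=lambda i: vals[i][j][k]) for k in range(3)]
          for j in range(3)]
print(manual == y.tolist())  # True
```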