torch.sort()和torch.argsort()简要介绍
定义1torch.sort(a,dim,descending)用法1输入a,在dim维进行排序,descending控制是否降序,默认为False。输出排序后的值以及对应值在原a中的下标,示例1import torcha = torch.tensor([[10,2,3],[4,6,5],[7,8,9]])print(a)>>tensor([[10,2,3],[ 4,6,5],[ 7,8
·
定义1
torch.sort(a,dim,descending)
用法1
输入a,在dim维进行排序,descending控制是否降序,默认为False。
输出排序后的值以及对应值在原a中的下标,
示例1
import torch
a = torch.tensor([[10,2,3],[4,6,5],[7,8,9]])
print(a)
>>tensor([[10, 2, 3],
[ 4, 6, 5],
[ 7, 8, 9]])
在dim=0默认升序
torch.sort(a,0)
>>torch.return_types.sort(
values=tensor([[ 4, 2, 3],
[ 7, 6, 5],
[10, 8, 9]]),
indices=tensor([[1, 0, 0],
[2, 1, 1],
[0, 2, 2]]))
在dim=1降序
torch.sort(a,1,descending=True)
>>torch.return_types.sort(
values=tensor([[10, 3, 2],
[ 6, 5, 4],
[ 9, 8, 7]]),
indices=tensor([[0, 2, 1],
[1, 2, 0],
[2, 1, 0]]))
定义2
torch.argsort()
用法2
返回排序后的值所对应原a的下标,即torch.sort()返回的indices
示例2
将输入a在 dim=0降序排列
torch.argsort(a,0,descending=True)
>>tensor([[0, 2, 2],
[2, 1, 1],
[1, 0, 0]])
更多推荐
已为社区贡献3条内容
所有评论(0)