【PyTorch】repeat_interleave()函数详解

函数原型

torch.repeat_interleave(input, repeats, dim=None) → Tensor

详解

重复张量的元素
输入参数

  1. input (类型:torch.Tensor):输入张量
  2. repeats(类型:int或torch.Tensor):每个元素的重复次数。repeats参数会被广播来适应输入张量的维度
  3. dim(类型:int)需要重复的维度。默认情况下,将把输入张量展平(flatten)为向量,然后将每个元素重复repeats次,并返回重复后的张量。

举例:

>>> x = torch.tensor([1, 2, 3])
>>> x.repeat_interleave(2)
tensor([1, 1, 2, 2, 3, 3])
# 传入多维张量,默认`展平`
>>> y = torch.tensor([[1, 2], [3, 4]])
>>> torch.repeat_interleave(y, 2)
tensor([1, 1, 2, 2, 3, 3, 4, 4])
# 指定维度
>>> torch.repeat_interleave(y,3,0)
tensor([[1, 2],
        [1, 2],
        [1, 2],
        [3, 4],
        [3, 4],
        [3, 4]])
>>> torch.repeat_interleave(y, 3, dim=1)
tensor([[1, 1, 1, 2, 2, 2],
        [3, 3, 3, 4, 4, 4]])
# 指定不同元素重复不同次数
>>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)
tensor([[1, 2],
        [3, 4],
        [3, 4]])
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐