一、.view()函数简介

  • PyTorch中的.view()函数是一个用于改变张量形状的方法。它类似于NumPy中的.reshape()函数,可以通过重新排列张量的维度来改变其形状,而不改变张量的数据
  • 在深度学习中,.view()函数常用于调整输入数据的形状以适应模型的输入要求,或者在网络层之间传递数据时进行形状的转换。
  • .view()函数的语法如下,shape是一个整数元组,用于指定新的张量形状,新形状的元素个数必须与原形状的元素个数相同。函数返回一个具有指定形状的新张量,但与原始张量共享数据存储,因此它们指向相同的内存区域。
New_Tensor = Tensor.view(*shape)

二、.view()函数的使用方法

1. 改变形状

直接使用.view(new_shape)将张量修改为指定形状,保持新张量和原张量在数量上一致即可。

import torch

# 创建一个形状为(2, 3, 4)的张量
x = torch.randn(2, 3, 4)
print("原始张量形状:", x.shape)

# 使用.view()改变张量形状为(6, 4)
y = x.view(6, 4)
print("改变形状后的张量形状:", y.shape)

运行结果:
在这里插入图片描述

2. 使用-1展平张量

在PyTorch中,.view()函数可以接受一个特殊的参数 -1,用于自动计算张量在该维度上的大小。将某个维度的大小设置为 -1,可以使得该维度的大小根据其他维度的大小自动确定,以保持张量的元素总数不变。

import torch

# 创建一个形状为(2, 3, 4)的张量
x = torch.randn(2, 3, 4)
print("原始张量形状:", x.shape)

# 使用.view()展平张量为形状为(12,2)
y1 = x.view(-1,2)
print("展平后的张量y1形状:", y1.shape)

# 使用.view()展平张量为形状为(6,2,2)
y2 = x.view(-1,2,2)
print("展平后的张量y2形状:", y2.shape)

运行结果:
在这里插入图片描述

3. 调整维度

import torch

# 创建一个形状为(2, 3, 4)的张量
x = torch.randn(2, 3, 4)
print("原始张量形状:", x.shape)

# 使用.view()调整张量的维度为(2, 4, 3)
y = x.view(2, 4, 3)
print("调整维度后的张量形状:", y.shape)

运行结果:
在这里插入图片描述

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐