torch.nn.Parameter是继承自torch.Tensor的子类,其主要作用是作为nn.Module中的可训练参数使用。它与torch.Tensor的区别就是nn.Parameter会自动被认为是module的可训练参数,即加入到parameter()这个迭代器中去;而module中非nn.Parameter()的普通tensor是不在parameter中的。

torch.nn.parameter.Parameter(data=None, requires_grad=True)
或
torch.nn.Parameter(data=None, requires_grad=True)

这两种写法都一样

      nn.Parameter可以看作是一个类型转换函数,将一个不可训练的类型 Tensor 转换成可以训练的类型 parameter ,并将这个 parameter 绑定到这个module 里面(net.parameter() 中就有这个绑定的 parameter,所以在参数优化的时候可以进行优化),所以经过类型转换这个变量就变成了模型的一部分,成为了模型中根据训练可以改动的参数。使用这个函数的目的也是想让某些变量在学习的过程中不断的修改其值以达到最优化。

     nn.Parameter()添加的参数会被添加到Parameters列表中,会被送入优化器中随训练一起学习更新   

      在nn.Module类中,pytorch也是使用nn.Parameter来对每一个module的参数进行初始化的

 

但是如果 nn.Parameter(requires_grad=False) 那么这个参数虽然绑定到模型里了,但是还是不可训练的,只是为了模型完整性这样写(例如magiclayout CVPR2021)

requires_grad默认值为True,表示可训练,False表示不可训练。

这样写还有一个好处就是,这个参数会随着模型的被移到cuda上,即如果执行过model.cuda(), 那么这个参数也就被移到了cuda上了

举例

import torch
from torch import nn

class MyModule(nn.Module):
    def __init__(self, input_size, output_size):
        super(MyModule, self).__init__()
        self.test = torch.rand(input_size, output_size)
        self.linear = nn.Linear(input_size, output_size)
    def forward(self, x):
        return self.linear(x)

model = MyModule(4, 2)
print(list(model.named_parameters()))

import torch
from torch import nn

class MyModule(nn.Module):
    def __init__(self, input_size, output_size):
        super(MyModule, self).__init__()
        self.test = nn.Parameter(torch.rand(input_size, output_size))
        self.linear = nn.Linear(input_size, output_size)
    def forward(self, x):
        return self.linear(x)

model = MyModule(4, 2)
print(list(model.named_parameters()))

也可以在外面,通过register_parameter()注册

import torch
from torch import nn

class MyModule(nn.Module):
    def __init__(self, input_size, output_size):
        super(MyModule, self).__init__()
        self.linear = nn.Linear(input_size, output_size)
    def forward(self, x):
        return self.linear(x)

model = MyModule(4, 2)
my_test = nn.Parameter(torch.rand(4, 2))
model.register_parameter('test',my_test)
print(list(model.named_parameters()))

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐