apply 函数是nn.Module 中实现的, 递归地调用self.children() 去处理自己以及子模块。

该方法会将fn递归的应用于模块的每一个子模块(.children()的结果)及其自身。典型的用法是,对一个model的参数进行初始化。

from torch import nn
import torch
@torch.no_grad()  ##装饰器
def init_weights(m):
    print(m)
    if type(m) == nn.Linear:
        m.weight.data.fill_(1.0)
        m.bias.data.fill_(0)


model = nn.Sequential(
    nn.Linear(2, 2),
)
model.apply(init_weights)
print(list(model.parameters()))

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐