参考博客:
https://blog.csdn.net/weixin_40522801/article/details/106563354
https://blog.csdn.net/yangwangnndd/article/details/100207686

函数定义:

load_state_dict(state_dict, strict=True)

作用
使用 state_dict 反序列化模型参数字典。用来加载模型参数。将 state_dict 中的 parameters 和 buffers 复制到此 module 及其子节点中。
概况:给模型对象加载训练好的模型参数,即加载模型参数

关于state_dict
在PyTorch中,一个torch.nn.Module模型中的可学习参数(比如weights和biases),模型的参数通过model.parameters()获取。而state_dict就是一个简单的Python dictionary,其功能是将每层与层的参数张量之间一一映射。
注意,只有包含了可学习参数(卷积层、线性层等)的层和已注册的命令(registered buffers,比如batchnorm的running_mean)才有模型的state_dict入口。优化方法目标(torch.optim)也有state_dict,其中包含的是关于优化器状态的信息和使用到的超参数。
因为state_dict目标是Python dictionaries,所以它们可以很轻松地实现保存、更新、变化和再存储,从而给PyTorch模型和优化器增加了大量的模块化(modularity)。

使用示例:

model.load_state_dict(torch.load('pose_dekr_hrnetw32_coco.pth'), strict=True)

官方函数说明

Copies parameters and buffers from state_dict into this module
and its descendants. If strict is True, then the keys of
state_dict must exactly match the keys returned by this
module’s state_dict() function.
从函数接收的参数state_dict中将参数和缓冲拷贝到当前这个模块及其子模块中.
如果函数接受的参数strict是True,那么state_dict的关键字必须确切地严格地和
该模块的state_dict()函数返回的关键字相匹配.

Parameters 参数

state_dict (dict) – a dict containing parameters and persistent buffers.
state_dict (字典类型) – 一个包含参数和持续性缓冲的字典
往往是pytorch模型pth文件
strict (布尔类型, 可选) – 该参数用来指明是否需要强制严格匹配,:state_dict中的关键字是否需要和该模块的state_dict()方法返回的关键字强制严格匹配.默认值是True

Returns 返回

返回类型:NamedTuple with missing_keys and unexpected_keys fields
missing_keys is a list of str containing the missing keys
missing_keys是一个字符串的列表,该列表包含了所有缺失的关键字.
unexpected_keys is a list of str containing the unexpected keys
unexpected_keys是一个字符串的列表,该列表包含了意料之外的关键字,
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐