torch.load 出现 AttributeError: Can't get attribute 'Net' on module '__main__'

问题

最近,将已经训练好的模型保存下来后,通过torch.load(model_path)方法读取时,发现没办法正常运行,抛出如下错误:

 AttributeError: Can't get attribute 'Net' on module '__main__'

我直接好家伙,骂骂咧咧去搜为啥。


报错原因: torch.load()方法所在文件,找不到 Net 类。好兄弟,这是关键!

pycharm中一路点方法下去看,点了两层到一个 _serialization.py的文件的load()方法

在这里插入图片描述

看到Return 是一个类对象,就可以知道,load过程实际上需要new一个新的模型对象,那就需要加载模型对应的类了。

所以我们需要导入下模型对应的类,否则就会抛出上述异常。

解决方案

因为我实在test.py文件中调用,所以直接在该文件头部,导入对应的model类文件即可。

from *.py import NET
  • *.py 表示你的model类所在的文件
  • NET 表示模型对应的类

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐