torch.load 出现 AttributeError: Can‘t get attribute ‘Net‘ on module ‘__main__‘
torch.load 出现 AttributeError: Can't get attribute 'Net' on module '__main__'问题解决方案问题最近,将已经训练好的模型保存下来后,通过torch.load(model_path)方法读取时,发现没办法正常运行,抛出如下错误:AttributeError: Can't get attribute 'Net' on module
文章共403字 · 阅读需要大约2分钟
一键AI生成摘要,助你高效阅读
问答
·
问题
最近,将已经训练好的模型保存下来后,通过torch.load(model_path)方法读取时,发现没办法正常运行,抛出如下错误:
AttributeError: Can't get attribute 'Net' on module '__main__'
我直接好家伙,骂骂咧咧去搜为啥。
报错原因: torch.load()方法所在文件,找不到 Net 类。好兄弟,这是关键!
pycharm中一路点方法下去看,点了两层到一个 _serialization.py的文件的load()方法
看到Return 是一个类对象,就可以知道,load过程实际上需要new一个新的模型对象,那就需要加载模型对应的类了。
所以我们需要导入下模型对应的类,否则就会抛出上述异常。
解决方案
因为我实在test.py文件中调用,所以直接在该文件头部,导入对应的model类文件即可。
from *.py import NET
*.py
表示你的model类所在的文件NET
表示模型对应的类
更多推荐
已为社区贡献4条内容
所有评论(0)