【记犯的一次低级错误】


完整错误信息如下:

RuntimeError: Error(s) in loading state_dict for DataParallel:
	size mismatch for module.lstm_block.lstm.weight_ih_l0: copying a param with shape torch.Size([1024, 500]) from checkpoint, the shape in current model is torch.Size([1024, 2000]).

错哪了:

错误意思大概是加载state_dict时,参数不匹配。百度基本都说是PyTorch版本环境不一致、torch.nn.DataParallel()关键字不匹配等,提出的解决办法是把strict参数赋False,如下:

checkpoint_file = os.path.join(args.checkpoint, args.test+'.pth.tar')
checkpoint = torch.load(checkpoint_file) 
model.load_state_dict(checkpoint['state_dict'],False) # 修改处

但这招在我这行不通,思来想去最后在Google找到答案,不得不说stack overflow还是牛。实际原因是一个很低级的错误:就是模型初始化的一个参数错了。

这是我的训练代码:

win_width     = 5        # 样本长度,单位为秒
time_steps    = win_width * sample_rate
num_variables = 2

model = LSTMFCN(time_steps, num_variables)

然而预测代码是这样的:

win_width     = 20        # 样本长度,单位为秒
time_steps    = win_width * sample_rate
num_variables = 2

model = LSTMFCN(time_steps, num_variables)

看看人家是怎么说的

Stack Overflow:Size Mismatch Runtime Error When Trying to Load a PyTorch Model
在这里插入图片描述

最终解决办法:

把参数改成和训练时一样的就OK啦。

Reference

[1] Stack Overflow:Size Mismatch Runtime Error When Trying to Load a PyTorch Model
[2] CSDN:pytorch加载模型报错RuntimeError:Error(s) in loading state_dict for DataParallel

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐