报错代码

Wh = torch.mm(h, self.w)

报错RuntimeError: self must be a matrix

原因:torch.mm()是两个矩阵相乘,即两个二维的张量相乘,维度超过二维,则会报错。
这两个tensor的维度是[16, 16, 29][29, 70]

>>> h.shape
torch.Size([16, 16, 29])
>>> self.w.shape
torch.Size([29, 70])

修改:使用torch.matmul()

Wh = torch.matmul(h, self.w)

>>>Wh.shape
torch.Size([16, 16, 70])
Logo

长江两岸老火锅,共聚山城开发者!We Want You!

更多推荐