torch.bmm 是 PyTorch 中的一个函数,用于执行批矩阵乘法(batch matrix multiplication)操作。

它的输入是三维张量,形状为 (batch_size, n, m) 和 (batch_size, m, p):
其中 n 是第一个矩阵的列数,m 是两个矩阵共享的维度,p 是第二个矩阵的列数。

torch.bmm 将批中的每对矩阵相乘,返回一个新的三维张量,形状为 (batch_size, n, p)。

例如,假设我们有两个批次的矩阵 A 和 B,维度分别为 (2, 3, 4) 和 (2, 4, 5)。我们可以使用 torch.bmm 将它们相乘
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐