不抄Transformer,我们自己造轮子!彻底搞懂注意力机制的计算流程

一、为什么还要自己实现注意力?

很多人学完Transformer后,对Attention的理解还停留在论文公式层面。但面试一问"QKV具体怎么计算的",就支支吾吾说不清楚。

关键问题:注意力机制到底在做什么?

  • 输入一个查询Q,它想知道自己跟谁最相关

  • K是"标签",V是"内容",通过Q和K的匹配度,从V中提取信息

  • 最终Q被升级成"带上下文信息的Q"

举个例子:你查"苹果"这个词(Q),词典里有32个词条(K),每个词条有64维特征(V)。Attention算出你查的"苹果"跟这32个词条的相似度,然后按照相似度加权组合这32个词条的特征,得到一个增强版的"苹果"表示。

这就是Attention的本质:用查询去检索信息,加权聚合出更有价值的表示。

二、我们的任务

已知:

  • V:32个单词,每个单词64维特征 → [1, 32, 64]

  • K:32个单词的"索引标记" → [1, 1, 32]

  • Q:查询张量,比如查"我" → [1, 1, 32]

要做什么?

  1. 计算Q与32个单词的相关性(注意力权重分布)

  2. 用权重加权V,得到一个增强版Q

  3. 输出:增强后的Q([1, 1, 32])和注意力权重([1, 32]

三、类结构设计

MyAttn
├── __init__(self, query_size, key_size, value_size1, value_size2, output_size)
│   ├── self.attn: Linear(32+32 → 32)        # 计算注意力得分
│   └── self.attn_combine: Linear(32+64 → 32) # 融合Q和V
│
└── forward(self, Q, K, V)
    ├── 1. 计算注意力权重
    │   ├── 拼接Q和K → [1, 64]
    │   ├── 线性层 → [1, 32]  (得分)
    │   └── Softmax → [1, 32] (概率分布)
    ├── 2. 加权聚合V
    │   └── 权重 × V → [1, 1, 64]
    ├── 3. 融合Q和聚合后的V
    │   ├── 拼接 → [1, 1, 96]
    │   └── 线性层 → [1, 1, 32]
    └── 4. 返回输出和注意力权重

四、完整代码逐行解析

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyAttn(nn.Module):
    def __init__(self, query_size, key_size, value_size1, value_size2, output_size):
        super().__init__()
        
        # 记录维度参数(方便调试和扩展)
        self.query_size = query_size
        self.key_size = key_size
        self.value_size1 = value_size1  # 序列长度(单词数)
        self.value_size2 = value_size2  # 词向量维度
        self.output_size = output_size
        
        # 【核心1】注意力得分计算层
        # 输入:Q拼接K (32+32=64维)
        # 输出:32维(每个单词一个得分)
        self.attn = nn.Linear(query_size + key_size, value_size1)
        
        # 【核心2】注意力融合层
        # 输入:Q拼接加权后的V (32+64=96维)
        # 输出:32维(最终表示)
        self.attn_combine = nn.Linear(query_size + value_size2, output_size)

为什么设计两个线性层?

  • attn:负责"相关性打分",输入Q和K,输出跟V序列长度一致的得分向量

  • attn_combine:负责"融合升级",把原始Q和从V中提取的信息合并,得到增强版Q

    def forward(self, Q, K, V):
        # === 阶段1:计算注意力权重 ===
        # 为什么要用Q[0]?因为假设batch_size=1,取出具体数据方便计算
        # Q[0]: [1, 32], K[0]: [1, 32] → cat → [1, 64]
        qk_cat = torch.cat((Q[0], K[0]), dim=-1)
        
        # 线性变换得到得分:每个单词一个分数
        # [1, 64] → [1, 32]
        attn_scores = self.attn(qk_cat)
        
        # Softmax归一化为概率分布
        # [1, 32] → [1, 32] (所有概率和为1)
        attn_weights = F.softmax(attn_scores, dim=-1)

Softmax的作用

  • 将原始得分(可正可负)转换为概率分布(0-1之间,和为1)

  • 让"相关"的单词权重高,"不相关"的权重低

  • dim=-1表示在最后一个维度(32个单词)上做归一化

        # === 阶段2:用权重加权聚合V ===
        # 扩展维度:准备做矩阵乘法
        # [1, 32] → [1, 1, 32]
        attn_weights_expanded = attn_weights.unsqueeze(0)
        
        # 批量矩阵乘法:权重 × V
        # [1, 1, 32] @ [1, 32, 64] = [1, 1, 64]
        attn_applied = torch.bmm(attn_weights_expanded, V)

bmm的作用

  • bmm = batch matrix multiplication(批量矩阵乘法)

  • 权重矩阵 [1, 1, 32] × V矩阵 [1, 32, 64]

  • 本质:用32个权重系数,对32个64维向量做加权求和

  • 结果:一个64维的"加权平均向量"

        # === 阶段3:融合Q和提取的信息 ===
        # 拼接:原始Q + 从V中提取的信息
        # [1, 1, 32] + [1, 1, 64] = [1, 1, 96]
        output_cat = torch.cat((Q, attn_applied), dim=-1)
        
        # 降维:融合后映射到目标维度
        # [1, 1, 96] → [1, 1, 32]
        output = self.attn_combine(output_cat)
        
        return output, attn_weights

为什么要融合Q和attn_applied?

  • attn_applied是从V中提取的"外部信息"

  • 原始Q包含"自身信息"

  • 融合两者得到"自我+上下文"的增强表示

  • 类似于ResNet的残差思想:输出 = Q + 从V中提取的信息

五、测试代码

if __name__ == '__main__':
    # 维度配置
    query_size, key_size, value_size1, value_size2, output_size = 32, 32, 32, 64, 32
    
    # 构造输入
    Q = torch.randn(1, 1, query_size)    # [1, 1, 32]
    K = torch.randn(1, 1, key_size)      # [1, 1, 32]
    V = torch.randn(1, value_size1, value_size2)  # [1, 32, 64]
    
    # 前向传播
    my_attn = MyAttn(query_size, key_size, value_size1, value_size2, output_size)
    output, attn_weights = my_attn(Q, K, V)
    
    print(f"输出形状: {output.shape}")        # [1, 1, 32]
    print(f"注意力权重形状: {attn_weights.shape}")  # [1, 32]

六、维度变化速查表

步骤 操作 输入形状 输出形状 说明
1 拼接Q,K Q:[1,32], K:[1,32] [1,64] 准备打分
2 线性层attn [1,64] [1,32] 每个单词一个得分
3 Softmax [1,32] [1,32] 归一化为概率
4 unsqueeze [1,32] [1,1,32] 匹配batch维度
5 bmm [1,1,32] × [1,32,64] [1,1,64] 加权聚合V
6 拼接Q [1,1,32] + [1,1,64] [1,1,96] 融合信息
7 线性层combine [1,1,96] [1,1,32] 降维输出

七、核心问题Q&A

Q1:为什么用Q[0]而不是直接用Q?

A:因为假设batch_size=1,Q[0]取出第一维数据,从[1,1,32]变成[1,32],去掉冗余的batch维度,让计算更直观。但实际项目中建议直接用cat((Q,K), dim=-1),保持通用性。

Q2:attn和attn_combine的区别?

A:

  • attn:计算"注意力权重",输入Q和K,输出32维权重向量

  • attn_combine:融合"Q + 加权后的V",输出最终的增强表示

Q3:为什么要用bmm?

A:bmm专门用于批量矩阵乘法,高效且自动处理batch维度。这里用权重向量[1,1,32]去加权V的32个词向量[1,32,64],得到加权和[1,1,64]

Q4:注意力权重为什么要用Softmax?

A:Softmax将得分映射为概率分布,确保所有权重都在0-1之间且和为1。这样"相关"的单词获得高权重,"不相关"的获得低权重,符合注意力机制"聚焦重要信息"的思想。

八、总结

注意力机制的三步走

  1. 打分:用Q和K计算相关性得分

  2. 加权:Softmax归一化为权重,加权聚合V

  3. 融合:将聚合结果与原始Q融合,得到增强表示

代码设计的核心思想

  • 用线性层attn学习"如何打分"

  • 用线性层attn_combine学习"如何融合"

  • 所有参数通过反向传播自动学习

什么时候用Attention?

  • 需要从大量信息中提取关键内容时

  • 需要建模长距离依赖时

  • 需要动态分配计算资源时

更多推荐