用Python代码可视化理解Self-Attention和Transformer Encoder

在深度学习领域,Transformer架构已经成为自然语言处理任务的事实标准。然而,对于许多学习者来说,Self-Attention机制和Transformer Encoder的工作原理仍然显得抽象难懂。本文将带你通过Python代码实现和可视化,从零开始构建这些核心组件,让抽象的概念变得直观可见。

1. 环境准备与基础概念

在开始编码之前,我们需要确保环境配置正确并理解一些基础概念。首先安装必要的Python库:

pip install numpy matplotlib torch

Self-Attention是Transformer的核心机制,它允许模型在处理序列数据时,动态地关注输入序列的不同部分。与传统的RNN或CNN不同,Self-Attention能够直接建模序列中任意两个位置之间的关系,无论它们相距多远。

理解Self-Attention需要掌握三个关键向量:

  • Query(查询向量) :表示当前正在处理的位置
  • Key(键向量) :表示序列中所有位置的标识
  • Value(值向量) :包含每个位置的实际信息

这些向量通过以下步骤相互作用:

  1. 计算Query与所有Key的点积
  2. 缩放点积结果
  3. 应用softmax函数获得注意力权重
  4. 用权重对Value进行加权求和

2. 实现基础Self-Attention

让我们从实现基础的Scaled Dot-Product Attention开始。首先定义输入序列和必要的参数:

import numpy as np
import matplotlib.pyplot as plt

# 定义输入序列 (序列长度=3, 嵌入维度=4)
X = np.array([
    [1.0, 0.5, 0.2, 0.1],  # 第一个词向量
    [0.7, 0.6, 0.3, 0.2],  # 第二个词向量
    [0.4, 0.3, 0.2, 0.1]   # 第三个词向量
])

# 定义权重矩阵 (嵌入维度=4, 注意力维度=3)
W_Q = np.random.randn(4, 3) * 0.1
W_K = np.random.randn(4, 3) * 0.1
W_V = np.random.randn(4, 3) * 0.1

接下来实现Self-Attention的计算过程:

def scaled_dot_product_attention(X, W_Q, W_K, W_V):
    # 计算Q, K, V矩阵
    Q = np.dot(X, W_Q)
    K = np.dot(X, W_K)
    V = np.dot(X, W_V)
    
    # 计算注意力分数
    attention_scores = np.dot(Q, K.T)
    
    # 缩放
    d_k = K.shape[-1]
    attention_scores = attention_scores / np.sqrt(d_k)
    
    # 应用softmax
    attention_weights = np.exp(attention_scores) / np.sum(np.exp(attention_scores), axis=-1, keepdims=True)
    
    # 加权求和
    output = np.dot(attention_weights, V)
    
    return output, attention_weights

output, attention_weights = scaled_dot_product_attention(X, W_Q, W_K, W_V)

为了更直观地理解这个过程,我们可以可视化注意力权重:

def plot_attention_weights(weights):
    fig, ax = plt.subplots()
    im = ax.imshow(weights, cmap='viridis')
    
    # 设置坐标轴标签
    ax.set_xticks(np.arange(len(weights)))
    ax.set_yticks(np.arange(len(weights)))
    ax.set_xticklabels(["Token 1", "Token 2", "Token 3"])
    ax.set_yticklabels(["Token 1", "Token 2", "Token 3"])
    
    # 添加颜色条
    plt.colorbar(im)
    plt.title("Attention Weights")
    plt.show()

plot_attention_weights(attention_weights)

这个热图展示了每个token对其他token的关注程度。对角线通常较强,因为token会关注自身,但也会关注其他相关token。

3. 多头注意力机制

单头注意力只能学习一种关注模式,而多头注意力允许模型同时关注不同位置的不同方面。实现多头注意力需要:

  1. 将Q、K、V投影到多个子空间
  2. 在每个子空间独立计算注意力
  3. 拼接所有头的输出
  4. 通过线性变换得到最终输出
class MultiHeadAttention:
    def __init__(self, d_model=4, num_heads=2):
        self.d_model = d_model
        self.num_heads = num_heads
        self.depth = d_model // num_heads
        
        # 初始化权重矩阵
        self.W_Q = np.random.randn(d_model, d_model) * 0.1
        self.W_K = np.random.randn(d_model, d_model) * 0.1
        self.W_V = np.random.randn(d_model, d_model) * 0.1
        self.W_O = np.random.randn(d_model, d_model) * 0.1
    
    def split_heads(self, x):
        # 将最后一个维度分割为(num_heads, depth)
        batch_size = x.shape[0]
        return x.reshape(batch_size, -1, self.num_heads, self.depth).transpose(0, 2, 1, 3)
    
    def __call__(self, X):
        batch_size = X.shape[0]
        
        # 线性变换
        Q = np.dot(X, self.W_Q)
        K = np.dot(X, self.W_K)
        V = np.dot(X, self.W_V)
        
        # 分割多头
        Q = self.split_heads(Q)
        K = self.split_heads(K)
        V = self.split_heads(V)
        
        # 计算缩放点积注意力
        attention_scores = np.matmul(Q, K.transpose(0, 1, 3, 2)) / np.sqrt(self.depth)
        attention_weights = np.exp(attention_scores) / np.sum(np.exp(attention_scores), axis=-1, keepdims=True)
        scaled_attention = np.matmul(attention_weights, V)
        
        # 拼接多头
        scaled_attention = scaled_attention.transpose(0, 2, 1, 3)
        concat_attention = scaled_attention.reshape(batch_size, -1, self.d_model)
        
        # 最终线性变换
        output = np.dot(concat_attention, self.W_O)
        
        return output, attention_weights

# 使用多头注意力
multi_head_attn = MultiHeadAttention()
output, attention_weights = multi_head_attn(X)

# 可视化第一个头的注意力权重
plot_attention_weights(attention_weights[0])

多头注意力的优势在于它能够并行学习不同的关注模式。例如,在处理自然语言时,一个头可能关注语法关系,另一个头可能关注语义关系。

4. Transformer Encoder实现

完整的Transformer Encoder层包含以下组件:

  1. 多头自注意力机制
  2. 残差连接和层归一化
  3. 前馈神经网络
  4. 再次残差连接和层归一化

让我们实现这些组件:

class LayerNormalization:
    def __init__(self, d_model, eps=1e-6):
        self.gamma = np.ones(d_model)
        self.beta = np.zeros(d_model)
        self.eps = eps
    
    def __call__(self, x):
        mean = np.mean(x, axis=-1, keepdims=True)
        std = np.std(x, axis=-1, keepdims=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

class FeedForwardNetwork:
    def __init__(self, d_model=4, d_ff=8):
        self.W1 = np.random.randn(d_model, d_ff) * 0.1
        self.b1 = np.zeros(d_ff)
        self.W2 = np.random.randn(d_ff, d_model) * 0.1
        self.b2 = np.zeros(d_model)
    
    def __call__(self, x):
        return np.dot(np.maximum(0, np.dot(x, self.W1) + self.b1), self.W2) + self.b2

class TransformerEncoderLayer:
    def __init__(self, d_model=4, num_heads=2):
        self.multi_head_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForwardNetwork(d_model)
        self.layernorm1 = LayerNormalization(d_model)
        self.layernorm2 = LayerNormalization(d_model)
    
    def __call__(self, x):
        # 多头注意力
        attn_output, _ = self.multi_head_attn(x)
        
        # 残差连接和层归一化
        x = self.layernorm1(x + attn_output)
        
        # 前馈网络
        ffn_output = self.ffn(x)
        
        # 再次残差连接和层归一化
        return self.layernorm2(x + ffn_output)

# 使用Transformer Encoder层
encoder_layer = TransformerEncoderLayer()
encoder_output = encoder_layer(X)

为了理解为什么Transformer使用Layer Normalization而不是Batch Normalization,我们可以对比两者的效果:

def compare_normalization(X):
    # Batch Normalization
    batch_mean = np.mean(X, axis=0)
    batch_std = np.std(X, axis=0)
    batch_norm = (X - batch_mean) / (batch_std + 1e-6)
    
    # Layer Normalization
    layer_mean = np.mean(X, axis=-1, keepdims=True)
    layer_std = np.std(X, axis=-1, keepdims=True)
    layer_norm = (X - layer_mean) / (layer_std + 1e-6)
    
    # 可视化对比
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    
    ax1.imshow(batch_norm, cmap='viridis')
    ax1.set_title("Batch Normalization")
    ax1.set_xticks(np.arange(X.shape[1]))
    ax1.set_yticks(np.arange(X.shape[0]))
    
    ax2.imshow(layer_norm, cmap='viridis')
    ax2.set_title("Layer Normalization")
    ax2.set_xticks(np.arange(X.shape[1]))
    ax2.set_yticks(np.arange(X.shape[0]))
    
    plt.show()

compare_normalization(X)

Batch Normalization对每个特征维度在batch上进行归一化,而Layer Normalization对每个样本的所有特征进行归一化。对于序列数据,Layer Normalization更加稳定,因为它不受batch size和序列长度变化的影响。

5. 完整Transformer Encoder可视化

现在我们将所有组件组合起来,构建一个完整的Transformer Encoder,并通过可视化理解其工作原理:

class TransformerEncoder:
    def __init__(self, num_layers=2, d_model=4, num_heads=2):
        self.layers = [TransformerEncoderLayer(d_model, num_heads) for _ in range(num_layers)]
    
    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

# 创建并运行Transformer Encoder
encoder = TransformerEncoder()
encoder_output = encoder(X)

# 可视化输入和输出的变化
def plot_input_output(input_seq, output_seq):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    
    ax1.imshow(input_seq, cmap='viridis')
    ax1.set_title("Input Sequence")
    ax1.set_xticks(np.arange(input_seq.shape[1]))
    ax1.set_yticks(np.arange(input_seq.shape[0]))
    
    ax2.imshow(output_seq, cmap='viridis')
    ax2.set_title("Encoder Output")
    ax2.set_xticks(np.arange(output_seq.shape[1]))
    ax2.set_yticks(np.arange(output_seq.shape[0]))
    
    plt.show()

plot_input_output(X, encoder_output)

通过对比输入和输出的可视化,我们可以看到Transformer Encoder如何转换输入序列。每个位置的输出现在都包含了整个序列的上下文信息,这正是Self-Attention机制的核心价值。

6. 实际应用与扩展

理解了Transformer Encoder的基本原理后,我们可以将其应用于实际任务。例如,构建一个简单的文本分类器:

class TransformerClassifier:
    def __init__(self, vocab_size=100, d_model=4, num_heads=2, num_classes=2):
        self.embedding = np.random.randn(vocab_size, d_model) * 0.1
        self.encoder = TransformerEncoder(d_model=d_model, num_heads=num_heads)
        self.classifier = np.random.randn(d_model, num_classes) * 0.1
    
    def __call__(self, input_ids):
        # 嵌入层
        x = self.embedding[input_ids]
        
        # Transformer Encoder
        x = self.encoder(x)
        
        # 平均池化
        x = np.mean(x, axis=0)
        
        # 分类层
        return np.dot(x, self.classifier)

# 示例使用
model = TransformerClassifier()
input_ids = np.array([1, 5, 3])  # 假设的token ID序列
logits = model(input_ids)
print("Classification logits:", logits)

这个简单的分类器展示了如何将Transformer Encoder应用于下游任务。在实际应用中,我们通常会使用更大的模型和更复杂的架构,但基本原理是相同的。

为了进一步理解Transformer的威力,我们可以比较不同架构在处理长距离依赖时的表现:

def compare_architectures(sequence_length=10):
    # 创建一个简单的序列任务:识别序列中是否有特定的模式
    X = np.random.randn(sequence_length, 4)
    # 在序列的开始和结束位置添加特殊模式
    X[0, :2] = [1, -1]
    X[-1, :2] = [1, -1]
    
    # 定义不同架构
    class RNNModel:
        def __init__(self, hidden_size=4):
            self.W_hh = np.random.randn(hidden_size, hidden_size) * 0.1
            self.W_xh = np.random.randn(4, hidden_size) * 0.1
            self.W_hy = np.random.randn(hidden_size, 1) * 0.1
        
        def __call__(self, x):
            h = np.zeros(self.W_hh.shape[0])
            for t in range(x.shape[0]):
                h = np.tanh(np.dot(h, self.W_hh) + np.dot(x[t], self.W_xh))
            return np.dot(h, self.W_hy)
    
    class TransformerModel:
        def __init__(self, d_model=4, num_heads=2):
            self.encoder = TransformerEncoder(d_model=d_model, num_heads=num_heads)
            self.W_hy = np.random.randn(d_model, 1) * 0.1
        
        def __call__(self, x):
            x = self.encoder(x)
            x = np.mean(x, axis=0)  # 平均池化
            return np.dot(x, self.W_hy)
    
    # 测试模型
    rnn = RNNModel()
    transformer = TransformerModel()
    
    print("RNN output:", rnn(X))
    print("Transformer output:", transformer(X))

compare_architectures()

在这个简单的对比中,Transformer能够更好地捕捉序列两端的模式,而RNN由于顺序处理的特性,可能会在长序列中丢失早期信息。

更多推荐