用PyTorch可视化拆解:CNN与MLP的本质联系与差异

在咖啡厅里,我常看到初学者对着厚厚的教材皱眉——那些关于卷积神经网络(CNN)和多层感知机(MLP)关系的数学公式,就像天书般令人困惑。直到有天,我随手在Jupyter里画了几行代码,突然发现:原来这两个看似不同的结构,本质上是同一枚硬币的两面。本文将带您用PyTorch和Matplotlib,通过 可视化计算过程 来直观理解这个深度学习中的重要概念。

1. 环境准备与基础概念速览

1.1 快速搭建实验环境

我们先准备好实验所需的工具链。推荐使用Google Colab或本地Jupyter环境,确保已安装最新版PyTorch:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

print(f"PyTorch版本: {torch.__version__}")
print(f"CUDA可用: {torch.cuda.is_available()}")

1.2 CNN与MLP的简明定义

  • CNN(卷积神经网络) :通过局部感受野和权值共享处理网格状数据(如图像)的神经网络
  • MLP(多层感知机) :全连接网络,每个神经元都与上一层的所有神经元相连

关键疑问 :为什么说MLP是CNN的特例?让我们用代码来验证这个命题。

2. 从代码角度看CNN的"退化"过程

2.1 构建等尺寸卷积核的CNN

假设我们有一张3x3的灰度图像,用CNN处理时故意将卷积核也设为3x3:

# 模拟3x3输入图像
input_img = torch.tensor([[1,2,3],
                         [4,5,6],
                         [7,8,9]], dtype=torch.float32).unsqueeze(0).unsqueeze(0)

# 定义3x3卷积核(与输入同尺寸)
conv_layer = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=0, bias=False)
with torch.no_grad():
    conv_layer.weight.data = torch.ones_like(conv_layer.weight) * 0.1  # 统一权重方便观察

# 执行卷积操作
output = conv_layer(input_img)
print(f"卷积输出: {output.squeeze()}")

此时卷积操作实际上是在进行 全局加权求和 ——这与MLP的全连接操作已经非常相似。

2.2 可视化计算过程

让我们把计算过程画出来:

def visualize_operation(input_tensor, weight_tensor, operation_type):
    fig, ax = plt.subplots(1, 2, figsize=(10,4))
    
    # 显示输入和权重
    ax[0].imshow(input_tensor.squeeze(), cmap='viridis')
    ax[0].set_title('Input Image')
    
    # 显示权重分布
    ax[1].imshow(weight_tensor.squeeze(), cmap='plasma')
    ax[1].set_title(f'{operation_type} Weights')
    
    plt.tight_layout()
    plt.show()

visualize_operation(input_img, conv_layer.weight.data, 'Convolution')

当卷积核与输入同尺寸时,每个输出像素都是所有输入像素的加权和——这正是全连接层的计算特性。

3. MLP的卷积视角解读

3.1 用1x1卷积实现MLP

在PyTorch中,我们可以用1x1卷积来模拟MLP的全连接操作:

# 将3x3图像展平为9维向量
flatten_input = input_img.view(1, 1, -1)  # 形状变为[1,1,9]

# 定义等效的"全连接层"(实际是1x1卷积)
mlp_layer = nn.Conv1d(1, 1, kernel_size=1, bias=False)
with torch.no_grad():
    mlp_layer.weight.data = torch.ones_like(mlp_layer.weight) * 0.1

# 执行"全连接"操作
mlp_output = mlp_layer(flatten_input)
print(f"MLP输出: {mlp_output.squeeze()}")

3.2 计算过程的数学等价性

让我们对比两种操作的数学表达式:

操作类型 计算公式 输出形状
等尺寸CNN $output = \sum_{i=1}^{3}\sum_{j=1}^{3} w_{ij}x_{ij}$ 标量
展平MLP $output = \sum_{k=1}^{9} w_kx_k$ 标量

关键发现 :当CNN的卷积核覆盖整个输入区域时,其计算过程与MLP完全相同。

4. 为什么图像处理不用"退化版CNN"

4.1 空间信息丢失问题

用代码演示使用全尺寸卷积核处理真实图像的问题:

from PIL import Image

# 加载测试图像
img = Image.open('test_image.jpg').convert('L').resize((224,224))
img_tensor = torch.from_numpy(np.array(img)).float().unsqueeze(0).unsqueeze(0)

# 定义全尺寸卷积(实际不可行)
try:
    full_conv = nn.Conv2d(1, 1, kernel_size=224, stride=1, padding=0)
    output = full_conv(img_tensor)
except Exception as e:
    print(f"错误: {e}")

实际问题

  1. 参数量爆炸(224x224的卷积核有50,176个参数)
  2. 无法捕捉局部特征
  3. 计算复杂度呈指数增长

4.2 局部感受野的优势对比

通过表格对比两种方式的特性:

特性 全尺寸卷积(MLP式) 标准CNN
参数量 $O(n^2)$ $O(k^2)$ (k<<n)
空间信息 完全丢失 保留局部关系
计算效率 极低
平移不变性
适用场景 小规模结构化数据 图像/视频等网格数据
# 演示标准CNN处理图像的效果
normal_conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)
output = normal_conv(img_tensor)

plt.figure(figsize=(12,4))
plt.subplot(1,2,1)
plt.title("原始图像")
plt.imshow(img_tensor.squeeze(), cmap='gray')

plt.subplot(1,2,2)
plt.title("3x3卷积结果")
plt.imshow(output.detach().squeeze(), cmap='gray')
plt.show()

5. 进阶理解:网络结构中的灵活转换

5.1 ResNet中的MLP与CNN混合

在现代架构中,常常能看到两者的混合使用。例如ResNet中的瓶颈结构:

class Bottleneck(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=1)  # 1x1卷积(类似MLP)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1) # 标准卷积
        self.conv3 = nn.Conv2d(64, 256, kernel_size=1)  # 1x1卷积
        
    def forward(self, x):
        return self.conv3(self.conv2(self.conv1(x)))

设计要点

  • 1x1卷积用于降维/升维(类似MLP的功能)
  • 3x3卷积捕捉空间特征
  • 两者配合实现高效计算

5.2 Vision Transformer中的特殊案例

有趣的是,Vision Transformer (ViT) 的处理方式:

# 模拟ViT的patch嵌入层
image = torch.randn(1, 3, 224, 224)
patch_size = 16
num_patches = (224 // patch_size) ** 2

# 将图像分割为16x16的patch并展平
patches = image.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
patches = patches.contiguous().view(1, num_patches, -1)  # 形状[1, 196, 768]

# 线性投影(本质是MLP)
projection = nn.Linear(patch_size*patch_size*3, 768)
embedded = projection(patches)

这种处理实际上是将局部区域先展平再用MLP处理,是另一种空间信息利用方式。

Logo

免费领 200 小时云算力,进群参与显卡、AI PC 幸运抽奖

更多推荐