VQ-VAE(Vector Quantized Variational Autoencoder)

VQ-VAE(向量量化变分自编码器)是由DeepMind在2017年提出的一种自监督学习模型,主要用于离散表征学习。它结合了变分自编码器(VAE)和向量量化(VQ)技术,能够将连续的高维数据(如图像、音频)压缩为离散的符号(token),同时保持重建质量。VQ-VAE是许多现代生成模型(如VQ-VAE-2、VQ-GAN)的基础。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# 超参数
BATCH_SIZE = 128
EPOCHS = 20
LEARNING_RATE = 2e-4
NUM_EMBEDDINGS = 128  # 码本大小(K)
EMBEDDING_DIM = 64  # 码字维度(D)
BETA = 0.25  # Commitment Loss权重

# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 加载MNIST数据集(替换为CIFAR-10可处理彩色图像)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # MNIST单通道
])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)


# 定义VQ-VAE模型
class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, beta):
        super().__init__()
        # codebook 码本的大小和维度
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)
        self.beta = beta

    def forward(self, z_e):
        # z_e形状:    [B, D, H, W]
        # z_e       [128,64,7,7]
        B, D, H, W = z_e.shape
        z_e_flat = z_e.permute(0, 2, 3, 1).reshape(-1, D)  # [B*H*W, D]

        # z_e_flat shape 6272 64

        # 计算与码本的距离
        # distance.shape 6272 128
        distances = (
                torch.sum(z_e_flat ** 2, dim=1, keepdim=True) +
                torch.sum(self.codebook.weight ** 2, dim=1) -
                2 * torch.matmul(z_e_flat, self.codebook.weight.t())
        )  # [B*H*W, K]


        # 找到最近码字
        indices = torch.argmin(distances, dim=1)  # [B*H*W]
        z_q = self.codebook(indices)  # [B*H*W, D]

        # 重构为原始形状
        z_q = z_q.view(B, H, W, D).permute(0, 3, 1, 2)  # [B, D, H, W]

        # 计算VQ损失
        # commitment_loss 让编码器适应码本
        commitment_loss = torch.mean((z_e.detach() - z_q) ** 2)
        # codebook_loss 让码本适应编码器
        codebook_loss = torch.mean((z_e - z_q.detach()) ** 2)
        vq_loss = codebook_loss + self.beta * commitment_loss

        # Straight-Through梯度近似
        z_q = z_e + (z_q - z_e).detach()

        return z_q, indices, vq_loss


class VQVAE(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, beta):
        super().__init__()
        # 编码器
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 4, stride=2, padding=1),  # MNIST: 1通道
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            # embedding_dim 64
            nn.Conv2d(64, embedding_dim, 3, padding=1)
        )
        # 向量量化层
        self.vq = VectorQuantizer(num_embeddings, embedding_dim, beta)
        # 解码器
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embedding_dim, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 1, 3, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        # z_e shape [5,64,7,7]
        z_e = self.encoder(x)  # 编码

        z_q, indices, vq_loss = self.vq(z_e)  # 量化
        x_recon = self.decoder(z_q)  # 解码
        return x_recon, vq_loss


# 初始化模型和优化器
model = VQVAE(NUM_EMBEDDINGS, EMBEDDING_DIM, BETA).to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)


# 训练函数
def train(model, dataloader, optimizer, epochs):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, (x, _) in enumerate(dataloader):
            # 这里的返回值是 (data,target)
            # 所以使用(x,)进行接收
            # x shape [128,1,28,28]
            x = x.to(device)
            optimizer.zero_grad()

            # 前向传播
            x_recon, vq_loss = model(x)

            # 计算重构损失
            recon_loss = torch.mean((x_recon - x) ** 2)
            loss = recon_loss + vq_loss

            # 反向传播
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {total_loss / len(dataloader):.4f}")


# 可视化重建结果
def visualize_reconstruction(model, dataloader, num_samples=5):
    model.eval()
    with torch.no_grad():
        x, _ = next(iter(dataloader))
        x = x[:num_samples].to(device)
        x_recon, _ = model(x)

        # 反归一化
        # x shape [5,1,28,28]
        x = (x + 1) / 2  # [-1,1] -> [0,1]
        x_recon = (x_recon + 1) / 2


        # 绘制对比图
        # 这里是现在画布的长度和高度
        plt.figure(figsize=(10, 4))
        for i in range(num_samples):
            # 原始图像
            plt.subplot(2, num_samples, i + 1)
            # 这里是 [28,28]
            plt.imshow(x[i].squeeze().cpu(), cmap='gray')
            plt.title("Original")
            plt.axis('off')

            # 重建图像
            plt.subplot(2, num_samples, i + num_samples + 1)
            plt.imshow(x_recon[i].squeeze().cpu(), cmap='gray')
            plt.title("Reconstructed")
            plt.axis('off')
        plt.show()

# model_dict=torch.load("model.pth")
# model.load_state_dict(model_dict)

# 训练并可视化
train(model, train_loader, optimizer, EPOCHS)

# torch.save(model.state_dict(),"model.pth")

visualize_reconstruction(model, train_loader)

实验结果

Logo

欢迎加入我们的广州开发者社区,与优秀的开发者共同成长!

更多推荐