VQ-VAE讲解和实战
简单来说，VQ-VAE 就是在潜在变量处进行向量量化：把编码器输出的每个潜在向量替换为码本中距离最近的码字，再通过 codebook_loss（让码字靠近编码器输出）和 commitment_loss（让编码器输出靠近所选码字）使潜在向量和码本互相靠近。
·
VQ-VAE(Vector Quantized Variational Autoencoder)
VQ-VAE（向量量化变分自编码器）是由 DeepMind 在 2017 年提出的一种无监督生成模型，主要用于离散表征学习。它结合了变分自编码器（VAE）和向量量化（VQ）技术，能够将连续的高维数据（如图像、音频）压缩为离散的符号（token），同时保持重建质量。VQ-VAE 是许多现代生成模型（如 VQ-VAE-2、VQGAN）的基础。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
BATCH_SIZE: int = 128          # samples per mini-batch
EPOCHS: int = 20               # full passes over the training set
LEARNING_RATE: float = 0.0002  # Adam step size
NUM_EMBEDDINGS: int = 128      # codebook size (K)
EMBEDDING_DIM: int = 64        # dimension of each code vector (D)
BETA: float = 0.25             # weight of the commitment loss

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Load MNIST (swap in torchvision.datasets.CIFAR10 to work with colour images).
# ToTensor yields pixels in [0, 1]; Normalize with mean=0.5 / std=0.5 then maps
# them to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),  # single-channel MNIST
    ]
)
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# VQ-VAE model definition
class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantisation layer of a VQ-VAE.

    Each D-dimensional vector of the encoder output is replaced by the closest
    (squared Euclidean distance) entry of a learnable codebook holding
    ``num_embeddings`` vectors.

    Args:
        num_embeddings: codebook size K.
        embedding_dim: code vector dimension D; must equal the channel
            dimension of the encoder output.
        beta: weight of the commitment loss term.
    """

    def __init__(self, num_embeddings, embedding_dim, beta):
        super().__init__()
        # Codebook: K learnable vectors of dimension D, stored as a [K, D] weight.
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)
        self.beta = beta

    def forward(self, z_e):
        """Quantise ``z_e`` and compute the VQ loss.

        Args:
            z_e: encoder output of shape [B, D, H, W].

        Returns:
            z_q: quantised tensor of shape [B, D, H, W]. Its forward value is
                the selected code vectors; its gradient is passed straight
                through to ``z_e`` (straight-through estimator).
            indices: int64 tensor of shape [B*H*W] with the selected code
                index per spatial position, ordered (b, h, w).
            vq_loss: scalar ``codebook_loss + beta * commitment_loss``.
        """
        B, D, H, W = z_e.shape
        # Channels last so every row is one D-dim vector: [B*H*W, D].
        z_e_flat = z_e.permute(0, 2, 3, 1).reshape(-1, D)
        codebook = self.codebook.weight  # [K, D]

        # Squared distances via ||z||^2 + ||e||^2 - 2 z.e  ->  [B*H*W, K],
        # avoiding an explicit [B*H*W, K, D] difference tensor.
        distances = (
            torch.sum(z_e_flat ** 2, dim=1, keepdim=True)
            + torch.sum(codebook ** 2, dim=1)
            - 2 * torch.matmul(z_e_flat, codebook.t())
        )

        # Nearest code vector for every position.
        indices = torch.argmin(distances, dim=1)  # [B*H*W]
        z_q = self.codebook(indices)  # [B*H*W, D]
        z_q = z_q.view(B, H, W, D).permute(0, 3, 1, 2)  # back to [B, D, H, W]

        # Codebook loss ||sg[z_e] - e||^2: the encoder output is detached, so
        # the gradient only moves the selected code vectors toward z_e.
        codebook_loss = torch.mean((z_e.detach() - z_q) ** 2)
        # Commitment loss ||z_e - sg[e]||^2: the codes are detached, so the
        # gradient only pulls the encoder output toward its chosen code.
        # beta weights THIS term (van den Oord et al., 2017, Eq. 3).
        commitment_loss = torch.mean((z_e - z_q.detach()) ** 2)
        vq_loss = codebook_loss + self.beta * commitment_loss

        # Straight-through estimator: forward value equals z_q, while the
        # backward pass copies the decoder's gradient onto z_e unchanged
        # (argmin itself is not differentiable).
        z_q = z_e + (z_q - z_e).detach()
        return z_q, indices, vq_loss
class VQVAE(nn.Module):
    """Convolutional VQ-VAE: encoder -> vector quantiser -> decoder.

    The encoder halves height and width twice (two stride-2 convolutions, e.g.
    28x28 -> 7x7) and projects to ``embedding_dim`` channels. The decoder
    mirrors it back to the input resolution (for H, W divisible by 4) and ends
    in a Tanh, so reconstructions lie in [-1, 1].

    Args:
        num_embeddings: codebook size K.
        embedding_dim: code vector dimension D (encoder output channels).
        beta: commitment loss weight forwarded to the quantiser.
        in_channels: number of image channels, e.g. 1 for MNIST or 3 for RGB
            data such as CIFAR-10. Defaults to 1 (the original behaviour).
    """

    def __init__(self, num_embeddings, embedding_dim, beta, in_channels=1):
        super().__init__()
        # Encoder: two 4x4 stride-2 convs downsample by 4 overall; the final
        # 3x3 conv keeps the resolution and outputs one D-dim vector per cell.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, embedding_dim, 3, padding=1),
        )
        # Vector quantisation layer (codebook lookup + VQ loss).
        self.vq = VectorQuantizer(num_embeddings, embedding_dim, beta)
        # Decoder: two 4x4 stride-2 transposed convs upsample by 4 overall;
        # the last conv maps back to in_channels and Tanh bounds the output.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embedding_dim, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, in_channels, 3, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode, quantise and decode ``x``.

        Args:
            x: images of shape [B, in_channels, H, W]; presumably normalised
                to [-1, 1] to match the Tanh output (confirm against the data
                pipeline).

        Returns:
            (x_recon, vq_loss): the reconstruction, which has the shape of
            ``x`` when H and W are divisible by 4, and the scalar VQ loss.
        """
        z_e = self.encoder(x)  # [B, D, H/4, W/4]
        z_q, _indices, vq_loss = self.vq(z_e)  # code indices are not needed here
        x_recon = self.decoder(z_q)
        return x_recon, vq_loss
# Build the model on the selected device, plus an Adam optimiser over all of
# its parameters (encoder, codebook and decoder).
model = VQVAE(
    num_embeddings=NUM_EMBEDDINGS,
    embedding_dim=EMBEDDING_DIM,
    beta=BETA,
).to(device)
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
# Training loop
def train(model, dataloader, optimizer, epochs):
    """Train ``model`` on reconstruction MSE plus the model's VQ loss.

    Args:
        model: module whose forward returns ``(x_recon, vq_loss)``.
        dataloader: re-iterable source of ``(x, target)`` batches; the target
            is ignored because training is pure reconstruction.
        optimizer: optimiser over ``model``'s parameters.
        epochs: number of passes over ``dataloader``.

    Returns:
        list[float]: the mean per-batch loss of every epoch (the value that is
        printed); NaN for an epoch that yielded no batches.
    """
    model.train()
    # Move batches to wherever the model lives rather than relying on a
    # module-level ``device`` that may not match it.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        # Each batch is (data, target); only the data is used.
        for x, _ in dataloader:
            x = x.to(target_device)
            optimizer.zero_grad()
            x_recon, vq_loss = model(x)
            recon_loss = torch.mean((x_recon - x) ** 2)
            loss = recon_loss + vq_loss
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        # Guard against an empty loader instead of dividing by zero.
        avg_loss = total_loss / num_batches if num_batches else float("nan")
        epoch_losses.append(avg_loss)
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {avg_loss:.4f}")
    return epoch_losses
# Visualise reconstructions
def visualize_reconstruction(model, dataloader, num_samples=5):
    """Plot original images (top row) above their reconstructions (bottom row).

    Takes the first ``num_samples`` images of the first batch from
    ``dataloader``, reconstructs them with gradients disabled and shows both
    rows with matplotlib. Leaves ``model`` in eval mode.

    Args:
        model: module whose forward returns ``(x_recon, vq_loss)``.
        dataloader: iterable of ``(x, target)`` batches with images of shape
            [B, C, H, W]; assumed normalised to [-1, 1].
        num_samples: number of images to show; capped at the batch size.
    """

    def _as_image(t):
        # [1, H, W] -> [H, W] (grayscale) or [C, H, W] -> [H, W, C],
        # the layouts plt.imshow accepts.
        t = t.cpu()
        return t[0] if t.size(0) == 1 else t.permute(1, 2, 0)

    model.eval()
    # Use the model's own device rather than a module-level global.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")
    with torch.no_grad():
        x, _ = next(iter(dataloader))
        x = x[:num_samples].to(target_device)
        x_recon, _ = model(x)

    # The batch may contain fewer images than requested; plotting
    # num_samples columns would then index past the end of x.
    num_samples = x.size(0)

    # Map [-1, 1] back to [0, 1] for display.
    x = (x + 1) / 2
    x_recon = (x_recon + 1) / 2

    plt.figure(figsize=(10, 4))  # canvas width x height in inches
    for i in range(num_samples):
        # Top row: original image.
        plt.subplot(2, num_samples, i + 1)
        plt.imshow(_as_image(x[i]), cmap='gray')
        plt.title("Original")
        plt.axis('off')
        # Bottom row: reconstruction.
        plt.subplot(2, num_samples, i + num_samples + 1)
        plt.imshow(_as_image(x_recon[i]), cmap='gray')
        plt.title("Reconstructed")
        plt.axis('off')
    plt.show()
# Train and visualise. Guarded so that importing this file (e.g. to reuse
# VectorQuantizer / VQVAE) does not start training as a side effect; in a
# script run or a Jupyter cell __name__ is "__main__", so behaviour is unchanged.
if __name__ == "__main__":
    # To start from previously saved weights, uncomment:
    # model_dict = torch.load("model.pth")
    # model.load_state_dict(model_dict)
    train(model, train_loader, optimizer, EPOCHS)
    # To save the trained weights, uncomment:
    # torch.save(model.state_dict(), "model.pth")
    visualize_reconstruction(model, train_loader)
实验结果

更多推荐


所有评论(0)