使用Vision Transformer(ViT)训练CIFAR-10
CIFAR-10的图像大小为32×32像素,而下文的ViT模型按224×224的输入(image_size=224,patch_size=16)进行配置,因此需要先把图像放大到224×224像素。
·
1.导入相应的包
!pip -q install vit_pytorch linformer
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from linformer import Linformer
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
import torchvision
2.定义数据处理操作
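CIFAR-10的图像大小为32×32像素,需要放大到224×224像素:训练集先缩放到256×256,再随机裁剪出正方形区域(面积占64%~100%)并缩放到224×224,同时随机水平翻转,作为数据增强;测试集直接缩放到224×224。两者最后都转换为张量,并用ImageNet的均值和标准差做归一化。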
# Training pipeline: upscale to 256x256, take a random square crop covering
# 64%-100% of the image area and resize it to 224x224, flip horizontally at
# random, then convert to a tensor normalized with the commonly used ImageNet
# per-channel mean/std.
transforms_train = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.RandomResizedCrop(224, scale=(0.64, 1.0), ratio=(1.0, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Evaluation pipeline: deterministic resize straight to 224x224 (no
# augmentation), with the same tensor conversion and normalization.
transforms_test = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
3.导入数据
# Training split: downloaded into ./cifar10 on first use, reshuffled every epoch.
train_data = torchvision.datasets.CIFAR10(
    root="./cifar10",
    train=True,
    download=True,
    transform=transforms_train,
)
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)

# Test split: read from the same directory with download=False, so the files
# must already be present there (presumably fetched by the download above).
# Kept in a fixed order.
test_data = torchvision.datasets.CIFAR10(
    root='./cifar10',
    train=False,
    download=False,
    transform=transforms_test,
)
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=128,
    shuffle=False,
    num_workers=2,
)

# Number of samples in each split, for normalizing per-epoch totals.
train_data_size, test_data_size = len(train_data), len(test_data)
4.定义ViT模型
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def pair(t):
    """Return ``t`` itself if it is already a tuple, else the 2-tuple ``(t, t)``.

    Lets size arguments be given either as a single value (square) or as an
    explicit ``(height, width)`` tuple.
    """
    if isinstance(t, tuple):
        return t
    return (t, t)
# classes
class FeedForward(nn.Module):
    """Pre-norm MLP sub-block: LayerNorm -> Linear -> GELU -> Linear.

    Maps the last axis ``dim -> hidden_dim -> dim``. Dropout is applied after
    the GELU and after the output projection. No residual connection is added
    here.

    Args:
        dim: size of the input/output feature axis.
        hidden_dim: width of the hidden layer.
        dropout: dropout probability for both dropout layers.
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        # Layer order fixes the state_dict keys (net.0 ... net.5); keep it stable.
        layers = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out
class Attention(nn.Module):
    """Multi-head self-attention with its own input LayerNorm (pre-norm).

    Args:
        dim: size of the input/output feature axis.
        heads: number of attention heads.
        dim_head: width of each head; the heads together span
            ``heads * dim_head`` features.
        dropout: dropout on the attention weights and after the output
            projection.

    With a single head whose width equals ``dim``, the output projection is
    skipped and ``to_out`` is an ``nn.Identity``. No residual is added here.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = heads * dim_head
        # Equivalent to `not (heads == 1 and dim_head == dim)`.
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5  # 1/sqrt(dim_head) logit scaling
        # Registration order below determines state_dict order; keep it.
        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        if needs_projection:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout),
            )
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        # The einops pattern 'b n (h d)' implies x is rank 3: (batch, tokens, dim).
        x = self.norm(x)
        q, k, v = (
            rearrange(t, 'b n (h d) -> b h n d', h=self.heads)
            for t in self.to_qkv(x).chunk(3, dim=-1)
        )
        scores = (q @ k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))
        mixed = weights @ v
        return self.to_out(rearrange(mixed, 'b h n d -> b n (h d)'))
class Transformer(nn.Module):
    """Stack of ``depth`` residual blocks followed by a final LayerNorm.

    Each block applies ``Attention`` and then ``FeedForward``. Both add their
    output back onto their input (residual connection).

    Args:
        dim: feature size carried through the stack.
        depth: number of blocks.
        heads, dim_head: forwarded to every ``Attention``.
        mlp_dim: hidden width of every ``FeedForward``.
        dropout: forwarded to both sub-layers.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.0):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        # One [attention, mlp] pair per block, built in order so that the
        # parameter initialization order and state_dict keys are unchanged.
        self.layers = nn.ModuleList(
            nn.ModuleList([
                Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout),
                FeedForward(dim, mlp_dim, dropout=dropout),
            ])
            for _ in range(depth)
        )

    def forward(self, x):
        for attention, mlp in self.layers:
            x = x + attention(x)
            x = x + mlp(x)
        return self.norm(x)
class ViT(nn.Module):
    """Vision Transformer image classifier.

    The image is cut into non-overlapping ``patch_size`` patches. Each patch
    is flattened, LayerNorm-ed, linearly embedded to ``dim`` and LayerNorm-ed
    again. A learnable class token is prepended and learnable positional
    embeddings are added. The sequence then runs through a ``Transformer``.
    The pooled token is mapped to ``num_classes`` logits.

    Args:
        image_size: int or (height, width); must be divisible by patch_size.
        patch_size: int or (height, width) of each patch.
        num_classes: number of output logits.
        dim: token embedding size.
        depth, heads, dim_head, mlp_dim, dropout: passed to ``Transformer``.
        pool: 'cls' uses the class token; 'mean' averages every token
            (the class token included).
        channels: number of input image channels.
        emb_dropout: dropout applied after adding positional embeddings.

    Raises:
        AssertionError: if the image size is not divisible by the patch size,
            or ``pool`` is not 'cls' / 'mean'.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3, dim_head=64, dropout=0.0, emb_dropout=0.0):
        super().__init__()
        img_h, img_w = pair(image_size)
        patch_h, patch_w = pair(patch_size)
        assert img_h % patch_h == 0 and img_w % patch_w == 0, 'Image dimensions must be divisible by the patch size.'

        grid_h, grid_w = img_h // patch_h, img_w // patch_w
        num_patches = grid_h * grid_w
        patch_dim = channels * patch_h * patch_w  # flattened values per patch
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # Submodules/parameters are created in the original order so random
        # initialization and state_dict layout stay the same.
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_h, p2=patch_w),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )
        # One extra position for the prepended class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        self.pool = pool
        self.to_latent = nn.Identity()
        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        # The einops pattern 'b c (h p1) (w p2)' implies img is (batch, channels, H, W).
        tokens = self.to_patch_embedding(img)
        batch, num_tokens, _ = tokens.shape

        cls = repeat(self.cls_token, '1 1 d -> b 1 d', b=batch)
        tokens = torch.cat((cls, tokens), dim=1)
        tokens = tokens + self.pos_embedding[:, :num_tokens + 1]
        tokens = self.dropout(tokens)
        tokens = self.transformer(tokens)

        if self.pool == 'mean':
            pooled = tokens.mean(dim=1)
        else:
            pooled = tokens[:, 0]  # class-token output
        return self.mlp_head(self.to_latent(pooled))
5.实例化模型并转移到GPU上(若没有可用的GPU则使用CPU)
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 224x224 inputs split into 16x16 patches (14 * 14 = 196 patch tokens),
# 10 output logits, 6 blocks of 16-head attention over 1024-dim tokens.
model = ViT(
    image_size=224,
    patch_size=16,
    num_classes=10,
    dim=1024,
    depth=6,
    heads=16,
    mlp_dim=2048,
    dropout=0.1,
    emb_dropout=0.1,
).to(device)
6.选取损失函数和优化器
# Hyper-parameters.
LR = 0.001   # SGD learning rate
epoch = 100  # number of passes over the training set

# Cross-entropy on the model's logits, placed on the same device as the model.
loss_func = nn.CrossEntropyLoss().to(device)

# SGD with momentum and L2 weight decay over every model parameter.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=LR,
    momentum=0.9,
    weight_decay=0.001,
)
7.训练与测试(每轮训练后在测试集上评估)
# Training / evaluation loop. Relies on names defined earlier in the script:
# model, train_loader, test_loader, loss_func, optimizer, device and epoch.


def _train_one_epoch(net, loader, criterion, opt, dev, step):
    """Train ``net`` for one pass over ``loader``.

    Args:
        net: model to train (switched to train mode here).
        loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss function; assumed to return the batch-mean loss
            (the default reduction of ``nn.CrossEntropyLoss``).
        opt: optimizer updating ``net``'s parameters.
        dev: device the batches are moved to.
        step: global batch counter carried across epochs (drives the
            progress print every 100 batches).

    Returns:
        ``(step, mean_loss, accuracy)``: the updated counter, the per-sample
        average loss over the epoch and the fraction of correct predictions
        (both ``nan`` if the loader yields no samples).
    """
    net.train()
    # Running totals stay on the device to avoid a host sync per batch.
    loss_sum = torch.zeros((), device=dev)
    correct = torch.zeros((), dtype=torch.long, device=dev)
    seen = 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(dev), labels.to(dev)
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        opt.zero_grad()
        loss.backward()
        opt.step()
        # detach(): summing the raw ``loss`` (``total += loss``) would chain
        # every batch's autograd graph together for the whole epoch.
        loss_sum += loss.detach() * labels.size(0)
        correct += (outputs.argmax(1) == labels).sum()
        seen += labels.size(0)
        step += 1
        if step % 100 == 0:
            print("训练次数:{}, Loss: {}".format(step, loss.item()))
    if seen == 0:
        return step, float("nan"), float("nan")
    return step, loss_sum.item() / seen, correct.item() / seen


def _evaluate(net, loader, criterion, dev):
    """Return ``(mean_loss, accuracy)`` of ``net`` over ``loader`` without gradients.

    ``criterion`` is assumed to return the batch-mean loss. Both values are
    ``nan`` if the loader yields no samples.
    """
    net.eval()
    loss_sum = torch.zeros((), device=dev)
    correct = torch.zeros((), dtype=torch.long, device=dev)
    seen = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(dev), labels.to(dev)
            outputs = net(inputs)
            loss_sum += criterion(outputs, labels) * labels.size(0)
            correct += (outputs.argmax(1) == labels).sum()
            seen += labels.size(0)
    if seen == 0:
        return float("nan"), float("nan")
    return loss_sum.item() / seen, correct.item() / seen


total_train_step = 0
for i in range(epoch):
    print(f'---------------第{i + 1}轮训练---------------')
    total_train_step, train_loss, train_acc = _train_one_epoch(
        model, train_loader, loss_func, optimizer, device, total_train_step
    )
    # Per-sample mean loss over the epoch (previously a sum of batch means).
    print("整体训练集上的Loss: {}".format(train_loss))
    print("整体训练集上的正确率: {}".format(train_acc))

    test_loss, test_acc = _evaluate(model, test_loader, loss_func, device)
    print("整体测试集上的Loss: {}".format(test_loss))
    print("整体测试集上的正确率: {}".format(test_acc))

    # Checkpoint after every epoch (overwriting the file), so an interrupted
    # run still leaves the latest model on disk.
    # NOTE(review): this pickles the whole nn.Module; saving model.state_dict()
    # is more portable but would change what torch.load('./model') returns.
    torch.save(model, './model')
初学,欢迎大家指出错误。
更多推荐
已为社区贡献1条内容
所有评论(0)