【PyTorch教程】保姆级实战教程【七】
第6章 - 优化技巧实训操作手册 1. 正则化技术:Dropout和Batch normalization 1.1 Dropout Dropout是一种防止神经网络过拟合的技术。它在训练期间随机“关闭”一些神经元,使其在前向传播和反向传播中都不工作。 实操步骤: 在你的模型中的适当位置插入Dropout层。选择一个dropout率,例如0.5
·
第6章 - 优化技巧实训操作手册
1. 正则化技术:Dropout和Batch normalization
1.1 Dropout
Dropout是一种防止神经网络过拟合的技术。它在训练期间随机“关闭”一些神经元,使其在前向传播和反向传播中都不工作。
实操步骤:
- 在你的模型中的适当位置插入Dropout层。
- 选择一个dropout率,例如0.5,表示每次前向传播时都随机关闭50%的神经元。
import torch.nn as nn
class ModelWithDropout(nn.Module):
def __init__(self):
super(ModelWithDropout, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.dropout = nn.Dropout(0.5)
self.fc2 = nn.Linear(5, 1)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.dropout(x)
x = self.fc2(x)
return x
1.2 Batch normalization
Batch normalization可以使深度网络的每一层都进行归一化处理,从而加速训练。
实操步骤:
- 在你的模型的适当位置插入Batch normalization层。
- 确保Batch normalization的输入特征数量与前一层的输出特征数量相匹配。
class ModelWithBatchNorm(nn.Module):
def __init__(self):
super(ModelWithBatchNorm, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.bn1 = nn.BatchNorm1d(5)
self.fc2 = nn.Linear(5, 1)
def forward(self, x):
x = self.bn1(torch.relu(self.fc1(x)))
x = self.fc2(x)
return x
2. 超参数调优技巧
2.1 网格搜索
网格搜索是一种穷举搜索方法,用于找到最佳的超参数组合。
实操步骤:
- 定义要搜索的超参数的可能值。
- 使用每种组合训练模型,并选择性能最好的组合。
注意:由于网格搜索的计算成本可能很高,建议先在小型数据集上进行。
3. 学习率调度
3.1 学习率衰减
随着训练的进行,减小学习率可以帮助模型收敛。
实操步骤:
- 定义一个优化器。
- 使用一个学习率调度器,例如每10个epoch后将学习率乘以0.1。
import torch.optim as optim
optimizer = optim.SGD(model.parameters(), lr=0.01)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
4. 权重初始化策略
4.1 Xavier/Glorot初始化
适用于Sigmoid和tanh激活函数。
实操步骤:
- 使用Xavier初始化方法初始化你的模型的权重。
nn.init.xavier_uniform_(model.fc1.weight)
4.2 He初始化
适用于ReLU激活函数。
实操步骤:
- 使用He初始化方法初始化你的模型的权重。
nn.init.kaiming_uniform_(model.fc1.weight)
实战项目:优化第4章的CNN模型(服装图像分类)
目标:利用本章学到的优化技巧,提高第4章CNN模型在FashionMNIST数据集上的性能。
1. 数据准备
首先,我们要加载FashionMNIST数据集,并对其进行适当的预处理。
from torchvision import datasets, transforms
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
2. 修改CNN模型
我们将在原始模型的基础上添加Dropout和Batch normalization。
import torch.nn as nn
class OptimizedCNN(nn.Module):
def __init__(self):
super(OptimizedCNN, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(32)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(64)
self.fc1 = nn.Linear(64 * 7 * 7, 128)
self.dropout = nn.Dropout(0.5)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.bn1(torch.relu(self.conv1(x)))
x = nn.MaxPool2d(2)(x)
x = self.bn2(torch.relu(self.conv2(x)))
x = nn.MaxPool2d(2)(x)
x = x.view(x.size(0), -1)
x = torch.relu(self.fc1(x))
x = self.dropout(x)
x = self.fc2(x)
return x
model = OptimizedCNN()
3. 初始化策略
使用He初始化方法对模型进行初始化。
def weights_init(m):
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
nn.init.kaiming_uniform_(m.weight)
model.apply(weights_init)
4. 定义损失函数、优化器和学习率调度器
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)
5. 训练模型
训练模型时,我们还将调整学习率。
epochs = 10
for epoch in range(epochs):
total_loss = 0.0
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
total_loss += loss.item()
average_loss = total_loss / len(train_loader)
scheduler.step(average_loss)
print(f"Epoch {epoch+1}/{epochs}, Average Loss: {average_loss:.4f}")
6. 评估模型
利用FashionMNIST的测试数据集评估模型。
test_dataset = datasets.FashionMNIST(root='./data', train=False, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)
model.eval()
correct = 0
with torch.no_grad():
for data, target in test_loader:
output = model(data)
pred = output.argmax(dim=1)
correct += pred.eq(target).sum().item()
accuracy = correct / len(test_loader.dataset)
print(f"Accuracy: {accuracy:.4f}")
更多推荐
已为社区贡献15条内容
所有评论(0)