基于TransU-Net的遥感图像语义分割与分类,遥感建筑物数据集,基于Pytorch框架,针对不同城市建筑物精准提取
train.py。
·
基于TransU-Net的遥感图像语义分割与分类,遥感建筑物数据集,基于Pytorch框架,针对不同城市建筑物精准提取
文章目录

1
1
基于TransU-Net的遥感图像语义分割与分类是一个非常有趣且具有挑战性的任务。下面给出一套代码示例,包括数据集准备、模型定义、训练和预测等步骤。
代码示例,仅供参考。
项目结构
首先,确保你的项目结构类似于以下内容:
transunet_project/
├── data/
│ ├── train_images/
│ ├── train_labels/
│ ├── val_images/
│ ├── val_labels/
│ ├── test_images/
│ └── test_labels/
├── src/
│ ├── my_dataset.py
│ ├── predict.py
│ ├── train.py
│ ├── transforms.py
│ └── transunet.py
├── predictions/
└── save_weights/
数据集准备
假设你已经有了一个包含遥感图像及其对应标签的数据集,并且已经按照上述结构组织好。标签应为与影像同名的单通道 PNG 掩膜(影像扩展名 .jpg 对应标签扩展名 .png),像素值为类别索引(例如 0 表示背景、1 表示建筑物);若掩膜以 0/255 存储,需要先转换为 0/1。
定义数据集类 (my_dataset.py)
import os
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
class RemoteSensingDataset(Dataset):
    """Paired remote-sensing image / segmentation-mask dataset.

    Images are read from ``image_dir`` and converted to RGB. The mask for
    ``name.jpg`` is read from ``label_dir`` as ``name.png``; any other file
    name is looked up unchanged. Masks are converted to single-channel
    ("L") mode.

    Args:
        image_dir: Directory containing the input images.
        label_dir: Directory containing the masks.
        transform: Optional callable applied to the RGB image only. It is not
            applied to the mask: image transforms such as bilinear resizing or
            3-channel normalization would corrupt or reject class-id masks.
        label_transform: Optional callable applied to the mask. When omitted
            and ``transform`` is given, the mask is resized with
            nearest-neighbour interpolation to the spatial size of the
            transformed image and returned as an ``int64`` NumPy array of
            class indices. ``DataLoader``'s default collate turns that array
            into a ``LongTensor``, which is the target type
            ``nn.CrossEntropyLoss`` expects.
    """

    def __init__(self, image_dir, label_dir, transform=None, label_transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        self.label_transform = label_transform
        # os.listdir returns entries in arbitrary order; sort so the sample
        # order (and anything derived from it) is reproducible.
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _label_name(image_name):
        """Return the mask file name paired with ``image_name`` (``.jpg`` -> ``.png``)."""
        stem, ext = os.path.splitext(image_name)
        # Only swap the extension; str.replace would also rewrite a '.jpg'
        # appearing elsewhere in the name.
        return stem + '.png' if ext == '.jpg' else image_name

    @staticmethod
    def _mask_to_indices(label, image):
        """Resize ``label`` to ``image``'s spatial size (nearest) and return int64 class ids."""
        if hasattr(image, 'shape'):
            # Assumes a channels-first (C, H, W) tensor, as ToTensor produces.
            height, width = image.shape[-2:]
        else:
            # The transform returned a PIL image, whose size is (width, height).
            width, height = image.size
        if label.size != (width, height):
            # NEAREST keeps every pixel a valid class id. Bilinear resizing
            # (e.g. the default of torchvision's Resize) would blend them.
            label = label.resize((int(width), int(height)), Image.NEAREST)
        # NOTE(review): masks stored as 0/255 must be mapped to 0/1 before use
        # with a 2-class CrossEntropyLoss. Confirm the dataset's encoding.
        return np.asarray(label, dtype=np.int64)

    def __getitem__(self, index):
        name = self.images[index]
        img_path = os.path.join(self.image_dir, name)
        label_path = os.path.join(self.label_dir, self._label_name(name))
        # convert() loads the pixels into a new image, so the files can be
        # closed right away instead of leaking handles.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        with Image.open(label_path) as lbl:
            label = lbl.convert("L")
        if self.transform is not None:
            image = self.transform(image)
        if self.label_transform is not None:
            label = self.label_transform(label)
        elif self.transform is not None:
            label = self._mask_to_indices(label, image)
        return image, label

数据增强与预处理 (transforms.py)
from torchvision import transforms
def get_transform(size=(256, 256), mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Build the image preprocessing pipeline: resize, convert to tensor, normalize.

    Args:
        size: Output size passed to ``transforms.Resize``; a (height, width)
            pair by default.
        mean: Per-channel mean for ``transforms.Normalize``. The defaults are
            the ImageNet statistics.
        std: Per-channel standard deviation for ``transforms.Normalize``.

    Returns:
        A ``transforms.Compose`` intended for RGB PIL images only.
        ``Resize`` interpolates bilinearly by default, and ``Normalize``
        expects as many channels as ``mean`` has. Neither is suitable for
        class-id masks.
    """
    return transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(mean), std=list(std)),
    ])

模型定义 (transunet.py)
import torch
import torch.nn as nn
from transformers import SwinTransformer
class TransUNet(nn.Module):
    """Transformer encoder followed by a transposed-convolution decoder.

    Only the final encoder feature map is decoded. Despite the name, there
    are no U-Net-style skip connections between encoder and decoder.

    NOTE(review): confirm the backbone import. Hugging Face ``transformers``
    appears to expose ``SwinModel`` rather than ``SwinTransformer`` (``timm``
    has a ``SwinTransformer``), and their output layouts differ.

    Args:
        num_classes: Number of output channels (classes) in the logits.
        encoder: Optional backbone module. Defaults to ``SwinTransformer()``.
            Its output must have ``encoder_channels`` channels and use one of
            these layouts:
            - NCHW;
            - NHWC;
            - token layout ``(B, L, C)`` with a square, row-major token grid;
            - an object whose ``last_hidden_state`` attribute holds one of
              the above.
        encoder_channels: Channel count of the encoder output (1024 by default).
    """

    def __init__(self, num_classes=2, encoder=None, encoder_channels=1024):
        super().__init__()
        self.swin = SwinTransformer() if encoder is None else encoder
        self.encoder_channels = encoder_channels
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(encoder_channels, 512, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, num_classes, kernel_size=1),
        )

    def _to_feature_map(self, features):
        """Return the encoder output as an NCHW tensor with ``encoder_channels`` channels.

        Raises:
            ValueError: if the output cannot be interpreted as such a map.
        """
        features = getattr(features, 'last_hidden_state', features)
        channels = self.encoder_channels
        if features.dim() == 3 and features.shape[-1] == channels:
            # Token layout (B, L, C); assumes a square, row-major token grid.
            batch, length, _ = features.shape
            side = int(round(length ** 0.5))
            if side * side != length:
                raise ValueError(
                    f'cannot reshape {length} tokens into a square feature map'
                )
            return features.transpose(1, 2).reshape(batch, channels, side, side)
        if features.dim() == 4:
            if features.shape[1] == channels:
                return features
            if features.shape[-1] == channels:
                return features.permute(0, 3, 1, 2).contiguous()
        raise ValueError(
            f'expected encoder features with {channels} channels, '
            f'got shape {tuple(features.shape)}'
        )

    def forward(self, x):
        """Return per-pixel class logits ``(B, num_classes, H, W)``; ``x`` is assumed NCHW."""
        features = self._to_feature_map(self.swin(x))
        logits = self.decoder(features)
        # The decoder always upsamples by 16x, which matches the input only
        # if the encoder downsamples by exactly 16. Resize so the logits line
        # up with per-pixel labels of the input's size.
        if logits.shape[-2:] != x.shape[-2:]:
            logits = nn.functional.interpolate(
                logits, size=tuple(x.shape[-2:]), mode='bilinear', align_corners=False
            )
        return logits
训练脚本 (train.py)
import argparse
import os
import time
from datetime import datetime
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from my_dataset import RemoteSensingDataset
from transunet import TransUNet
from transforms import get_transform
def parse_args():
    """Parse the training command line.

    Options: ``--model`` (str, default 'transunet'), ``--epochs``
    (int, default 100) and ``--batch-size`` (int, default 4, stored as
    ``batch_size``).
    """
    cli = argparse.ArgumentParser(description='Train a segmentation model')
    option_specs = (
        ('--model', str, 'transunet', 'model name'),
        ('--epochs', int, 100, 'number of epochs'),
        ('--batch-size', int, 4, 'batch size'),
    )
    for flag, value_type, fallback, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback, help=help_text)
    return cli.parse_args()
def create_model(args):
    """Instantiate the network named by ``args.model``.

    Only 'transunet' is supported; it is built with two output classes.

    Raises:
        ValueError: for any other model name.
    """
    if args.model != 'transunet':
        raise ValueError('Invalid model name')
    return TransUNet(num_classes=2)
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``; return the mean batch loss."""
    model.train()
    running_loss = 0.0
    for images, labels in tqdm(loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(loader)


def _evaluate(model, loader, criterion, device):
    """Return the mean batch loss of ``model`` over ``loader``, without gradients."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            total_loss += criterion(model(images), labels).item()
    return total_loss / len(loader)


def main():
    """Train the selected model; keep the checkpoint with the lowest validation loss.

    Reads data from ``./data/{train,val}_images`` and ``./data/{train,val}_labels``
    (relative to the working directory). Writes
    ``save_weights/<model>/best_model.pth`` whenever the validation loss improves.

    Raises:
        ValueError: if the training or validation image directory is empty.
    """
    args = parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_dataset = RemoteSensingDataset(
        image_dir='./data/train_images',
        label_dir='./data/train_labels',
        transform=get_transform(),
    )
    val_dataset = RemoteSensingDataset(
        image_dir='./data/val_images',
        label_dir='./data/val_labels',
        transform=get_transform(),
    )
    # Fail with a clear message instead of a ZeroDivisionError when averaging.
    if len(train_dataset) == 0 or len(val_dataset) == 0:
        raise ValueError('training and validation image directories must not be empty')
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
    model = create_model(args).to(device)
    # This script never imports ``torch.nn as nn``, so reference it via ``torch``.
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # torch.save does not create missing directories.
    save_dir = os.path.join('save_weights', args.model)
    os.makedirs(save_dir, exist_ok=True)
    checkpoint_path = os.path.join(save_dir, 'best_model.pth')

    best_val_loss = float('inf')
    for epoch in range(args.epochs):
        avg_train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        print(f'Epoch [{epoch + 1}/{args.epochs}], Train Loss: {avg_train_loss:.4f}')
        avg_val_loss = _evaluate(model, val_loader, criterion, device)
        print(f'Epoch [{epoch + 1}/{args.epochs}], Val Loss: {avg_val_loss:.4f}')
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), checkpoint_path)


if __name__ == '__main__':
    main()
预测脚本 (predict.py)
import argparse
import os
import time
from datetime import datetime
import torch
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
from my_dataset import RemoteSensingDataset
from transunet import TransUNet
from transforms import get_transform
def parse_args():
    """Parse the prediction command line (``--model``, default 'transunet')."""
    parser_options = {'description': 'Predict using a segmentation model'}
    cli = argparse.ArgumentParser(**parser_options)
    cli.add_argument('--model', default='transunet', type=str, help='model name')
    namespace = cli.parse_args()
    return namespace
def create_model(args):
    """Build the network requested by ``args.model``.

    'transunet' yields a two-class ``TransUNet``.

    Raises:
        ValueError: for any other name.
    """
    requested = args.model
    if requested == 'transunet':
        network = TransUNet(num_classes=2)
        return network
    raise ValueError('Invalid model name')
def main():
    """Predict a class-index mask for every image in ``./data/test_images``.

    Loads ``save_weights/<model>/best_model.pth`` and writes one PNG per input
    image to ``./predictions/<image stem>.png``. Pixel values are the argmax
    class indices (``uint8``) at the spatial size produced by
    ``get_transform()``, not at the original image size.
    """
    # Both names are used below but are not imported at the top of this
    # script; importing them here keeps the function runnable.
    import numpy as np
    from torch.utils.data import DataLoader

    args = parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): RemoteSensingDataset opens the matching label file for
    # every image, so ./data/test_labels must exist even though the labels
    # are not used here.
    test_dataset = RemoteSensingDataset(
        image_dir='./data/test_images',
        label_dir='./data/test_labels',
        transform=get_transform(),
    )
    model = create_model(args).to(device)
    # map_location lets a checkpoint written on a GPU load on a CPU-only host.
    state_dict = torch.load(f'save_weights/{args.model}/best_model.pth', map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    output_dir = './predictions'
    os.makedirs(output_dir, exist_ok=True)
    with torch.no_grad():
        # With batch_size=1 and shuffle=False, batch i is test_dataset.images[i].
        for index, (images, _labels) in enumerate(tqdm(test_loader)):
            logits = model(images.to(device))
            predicted = logits.argmax(dim=1)[0].cpu().numpy()
            stem = os.path.splitext(test_dataset.images[index])[0]
            # Naming by source image keeps outputs unique and traceable.
            # Per-second timestamps collide and overwrite earlier masks.
            # Values are raw class ids (0/1 for two classes), so the PNG
            # looks almost black; scale it for visual inspection if needed.
            out_path = os.path.join(output_dir, f'{stem}.png')
            Image.fromarray(predicted.astype(np.uint8)).save(out_path)


if __name__ == '__main__':
    main()
以上代码提供了一个基于TransU-Net的遥感图像语义分割与分类的基础实现,包括数据集准备、模型定义、训练和预测等步骤。注意:该模型只解码编码器的最后一层特征,没有 U-Net 式的跳跃连接;编码器的导入方式与输出格式需根据实际使用的 Swin Transformer 实现进行核对。你可以根据自己的需求进行调整和优化。
以上文字及代码仅供参考学习使用。
更多推荐

所有评论(0)