基于TransU-Net的遥感图像语义分割与分类,遥感建筑物数据集,基于Pytorch框架,针对不同城市建筑物精准提取

在这里插入图片描述
1
在这里插入图片描述
1
在这里插入图片描述
基于TransU-Net的遥感图像语义分割与分类是一个非常有趣且具有挑战性的任务。下面给出一个代码示例,包括数据集准备、模型定义、训练和预测等步骤。
代码示例,仅供参考。

项目结构

首先,确保你的项目结构类似于以下内容:

transunet_project/
├── data/
│   ├── train_images/
│   ├── train_labels/
│   ├── val_images/
│   ├── val_labels/
│   ├── test_images/
│   └── test_labels/
├── src/
│   ├── my_dataset.py
│   ├── predict.py
│   ├── train.py
│   ├── transforms.py
│   └── transunet.py
├── predictions/
└── save_weights/

数据集准备

假设你已经有了一个包含遥感图像及其对应标签的数据集,并且已经按照上述结构组织好。

定义数据集类 (my_dataset.py)

import os
from torch.utils.data import Dataset
from PIL import Image
import numpy as np

class RemoteSensingDataset(Dataset):
    """Paired image / segmentation-mask dataset read from two directories.

    Every entry of ``image_dir`` is treated as one sample. Its mask is looked
    up in ``label_dir`` under the same file name, except that a ``.jpg``
    suffix is swapped for ``.png``.

    Args:
        image_dir: Directory containing the input images.
        label_dir: Directory containing the label masks.
        transform: Optional callable applied to the RGB image only.
        label_transform: Optional callable applied to the single-channel
            mask. When omitted and ``transform`` is given, the mask is resized
            (nearest neighbour) to the transformed image's size and returned
            as an ``(H, W)`` int64 tensor of raw pixel values.
    """

    def __init__(self, image_dir, label_dir, transform=None, label_transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        self.label_transform = label_transform
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping reproducible between runs.
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        """Return the number of entries found in ``image_dir``."""
        return len(self.images)

    @staticmethod
    def _label_name(image_name):
        """Return the mask file name for ``image_name``.

        Only a trailing ``.jpg`` extension is replaced; other names are
        returned unchanged.
        """
        stem, ext = os.path.splitext(image_name)
        return stem + '.png' if ext == '.jpg' else image_name

    @staticmethod
    def _label_to_tensor(label, image):
        """Resize ``label`` to match ``image`` and convert it to an int64 tensor."""
        import torch  # local import: this module's import block does not provide torch

        if torch.is_tensor(image):
            target_size = (image.shape[-1], image.shape[-2])  # PIL expects (width, height)
        elif isinstance(image, Image.Image):
            target_size = image.size
        else:
            target_size = label.size
        if label.size != target_size:
            # Nearest neighbour keeps label values intact; interpolating
            # would invent values that are not valid classes.
            label = label.resize(target_size, Image.NEAREST)
        return torch.from_numpy(np.array(label, dtype=np.int64))

    def __getitem__(self, index):
        """Load and return the ``(image, label)`` pair at ``index``."""
        image_name = self.images[index]
        img_path = os.path.join(self.image_dir, image_name)
        label_path = os.path.join(self.label_dir, self._label_name(image_name))
        image = Image.open(img_path).convert("RGB")
        label = Image.open(label_path).convert("L")

        if self.transform is not None:
            image = self.transform(image)

        if self.label_transform is not None:
            label = self.label_transform(label)
        elif self.transform is not None:
            # The image transform (e.g. a 3-channel Normalize) must not be
            # applied to the 1-channel mask, so the mask gets its own path.
            # NOTE(review): pixel values are kept as-is; masks stored as 0/255
            # need remapping to class indices via `label_transform` - confirm
            # against the dataset.
            label = self._label_to_tensor(label, image)

        return image, label

在这里插入图片描述

数据增强与预处理 (transforms.py)

from torchvision import transforms

def get_transform():
    """Build the image preprocessing pipeline.

    Returns:
        A ``transforms.Compose`` that resizes to 256x256, converts to a
        tensor and normalizes three channels with the mean/std given below.
    """
    resize = transforms.Resize((256, 256))
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    return transforms.Compose([resize, to_tensor, normalize])

在这里插入图片描述

模型定义 (transunet.py)

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import SwinConfig, SwinModel


class TransUNet(nn.Module):
    """Swin Transformer encoder followed by a transposed-convolution decoder.

    The encoder is a randomly initialised Hugging Face ``SwinModel``; no
    pretrained weights are loaded. Its token sequence is folded back into a
    2D feature map, upsampled by the decoder, and finally interpolated so the
    logits have the same spatial size as the input.

    Args:
        num_classes: Number of output channels (one logit map per class).
    """

    def __init__(self, num_classes=2):
        super(TransUNet, self).__init__()
        # `transformers` exposes Swin as SwinModel + SwinConfig. embed_dim=128
        # with four stages gives 128 * 2**3 = 1024 channels in the last stage,
        # which is what the first decoder layer expects.
        config = SwinConfig(
            embed_dim=128,
            depths=[2, 2, 18, 2],
            num_heads=[4, 8, 16, 32],
        )
        self.swin = SwinModel(config)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, num_classes, kernel_size=1)
        )

    def _grid_size(self, pixels):
        """Return the encoder's output length along one axis of ``pixels`` pixels.

        NOTE(review): assumes SwinModel pads to a multiple of the patch size
        and halves (rounding up) between stages - confirm against the
        installed transformers version.
        """
        config = self.swin.config
        size = -(-pixels // config.patch_size)  # ceiling division
        for _ in range(len(config.depths) - 1):
            size = (size + 1) // 2
        return size

    def forward(self, x):
        """Return per-pixel class logits of shape ``(B, num_classes, H, W)``.

        Args:
            x: Image batch of shape ``(B, 3, H, W)``.

        Raises:
            RuntimeError: If the encoder's token count does not match the
                feature-map size computed from ``H`` and ``W``.
        """
        height, width = x.shape[-2:]
        tokens = self.swin(pixel_values=x).last_hidden_state  # (B, L, C)
        batch, num_tokens, channels = tokens.shape
        grid_h, grid_w = self._grid_size(height), self._grid_size(width)
        if grid_h * grid_w != num_tokens:
            raise RuntimeError(
                f'Cannot fold {num_tokens} encoder tokens into a '
                f'{grid_h}x{grid_w} feature map'
            )
        # The decoder is convolutional, so fold the sequence back to (B, C, h, w).
        features = tokens.transpose(1, 2).reshape(batch, channels, grid_h, grid_w)
        logits = self.decoder(features)
        # The decoder upsamples x16 while the encoder downsamples x32; resize
        # so the logits line up pixel-for-pixel with the input and its mask.
        return F.interpolate(logits, size=(height, width), mode='bilinear', align_corners=False)

训练脚本 (train.py)

import argparse
import os
import time
from datetime import datetime

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

from my_dataset import RemoteSensingDataset
from transunet import TransUNet
from transforms import get_transform

def parse_args():
    """Parse the training script's command-line options.

    Returns:
        argparse.Namespace with ``model``, ``epochs`` and ``batch_size``.
    """
    cli = argparse.ArgumentParser(description='Train a segmentation model')
    options = (
        ('--model', str, 'transunet', 'model name'),
        ('--epochs', int, 100, 'number of epochs'),
        ('--batch-size', int, 4, 'batch size'),
    )
    for flag, value_type, default, help_text in options:
        cli.add_argument(flag, type=value_type, default=default, help=help_text)
    return cli.parse_args()

def create_model(args):
    """Instantiate the network selected by ``args.model``.

    Raises:
        ValueError: If ``args.model`` is not ``'transunet'``.
    """
    if args.model != 'transunet':
        raise ValueError('Invalid model name')
    return TransUNet(num_classes=2)

def _to_class_indices(labels):
    """Return ``labels`` as the (N, H, W) int64 tensor CrossEntropyLoss expects."""
    if labels.dim() == 4 and labels.size(1) == 1:
        labels = labels.squeeze(1)
    return labels.long()


def main():
    """Train the selected model and keep the best checkpoint.

    After every epoch the model is evaluated on the validation set; the
    weights are written to ``save_weights/<model>/best_model.pth`` whenever
    the average validation loss improves.
    """
    args = parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_dataset = RemoteSensingDataset(
        image_dir='./data/train_images',
        label_dir='./data/train_labels',
        transform=get_transform()
    )
    val_dataset = RemoteSensingDataset(
        image_dir='./data/val_images',
        label_dir='./data/val_labels',
        transform=get_transform()
    )

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)

    model = create_model(args).to(device)
    # This script imports `torch` but not `torch.nn as nn`, so the loss is
    # reached through the package attribute.
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # torch.save does not create missing directories.
    checkpoint_dir = os.path.join('save_weights', args.model)
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pth')

    best_val_loss = float('inf')
    for epoch in range(args.epochs):
        model.train()
        running_loss = 0.0
        for images, labels in tqdm(train_loader):
            images = images.to(device)
            labels = _to_class_indices(labels.to(device))
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # max(..., 1) avoids a ZeroDivisionError when a loader yields no batches.
        avg_train_loss = running_loss / max(len(train_loader), 1)
        print(f'Epoch [{epoch + 1}/{args.epochs}], Train Loss: {avg_train_loss:.4f}')

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device)
                labels = _to_class_indices(labels.to(device))
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

        avg_val_loss = val_loss / max(len(val_loader), 1)
        print(f'Epoch [{epoch + 1}/{args.epochs}], Val Loss: {avg_val_loss:.4f}')

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), checkpoint_path)


if __name__ == '__main__':
    main()

预测脚本 (predict.py)

import argparse
import os
import time
from datetime import datetime

import torch
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

from my_dataset import RemoteSensingDataset
from transunet import TransUNet
from transforms import get_transform

def parse_args():
    """Parse the prediction script's command-line options.

    Returns:
        argparse.Namespace with a single ``model`` attribute.
    """
    cli = argparse.ArgumentParser(description='Predict using a segmentation model')
    cli.add_argument('--model', default='transunet', type=str, help='model name')
    return cli.parse_args()

def create_model(args):
    """Build the network named by ``args.model`` for inference.

    Raises:
        ValueError: If ``args.model`` is anything other than ``'transunet'``.
    """
    requested = args.model
    if requested == 'transunet':
        model = TransUNet(num_classes=2)
        return model
    raise ValueError('Invalid model name')

def main():
    """Run the trained model over the test set and save one mask per image.

    Loads ``save_weights/<model>/best_model.pth``, predicts a class-index map
    for every sample in ``./data/test_images`` and writes it to
    ``./predictions/<image stem>.png``.
    """
    # Imported locally because this script's import block provides neither name.
    import numpy as np
    from torch.utils.data import DataLoader

    args = parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    test_dataset = RemoteSensingDataset(
        image_dir='./data/test_images',
        label_dir='./data/test_labels',
        transform=get_transform()
    )

    model = create_model(args).to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    state_dict = torch.load(f'save_weights/{args.model}/best_model.pth', map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    output_dir = './predictions'
    os.makedirs(output_dir, exist_ok=True)

    # batch_size=1 with shuffle=False keeps batch i aligned with
    # test_dataset.images[i], so each mask can be named after its source
    # image. A timestamp with one-second resolution would make predictions
    # produced within the same second overwrite each other.
    samples = zip(test_dataset.images, test_loader)
    with torch.no_grad():
        for image_name, (images, _labels) in tqdm(samples, total=len(test_loader)):
            images = images.to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)[0].cpu().numpy()

            # Save the prediction as raw class indices (0 .. num_classes - 1).
            predicted_img = Image.fromarray(predicted.astype(np.uint8))
            stem = os.path.splitext(image_name)[0]
            predicted_img.save(os.path.join(output_dir, f'{stem}.png'))


if __name__ == '__main__':
    main()

以上代码提供了一个完整的基于TransU-Net的遥感图像语义分割与分类的实现,包括数据集准备、模型定义、训练和预测等步骤。你可以根据自己的需求进行调整和优化。
以上文字及代码仅供参考学习使用。

Logo

免费领 100 小时云算力,进群参与显卡、AI PC 幸运抽奖

更多推荐