
Abstract

Vision Transformers (ViTs) have proven effective on a wide range of vision tasks, but shrinking them to mobile-friendly sizes leads to a significant drop in performance. Developing lightweight vision Transformers has therefore become an important research direction. This article introduces CloFormer, a lightweight vision Transformer that exploits context-aware local enhancement. CloFormer examines the relationship between the globally shared weights used by ordinary convolution and the token-specific, context-aware weights found in attention, and proposes an effective and intuitive module for capturing high-frequency local information. In CloFormer, we introduce AttnConv, a convolution operator in the style of attention: it uses shared weights to aggregate local information and carefully designed context-aware weights to enhance the local features. Combining AttnConv with vanilla attention that uses pooling to reduce FLOPs lets the model perceive both high-frequency and low-frequency information. Extensive experiments on image classification, object detection, and semantic segmentation verify the superiority of CloFormer.

1. CloFormer

As shown in Figure 2, CloFormer has four stages and is built mainly from Clo Blocks, ConvFFNs, and a convolution stem. The input image first passes through the convolution stem to obtain tokens, features are then extracted by a sequence of Clo Blocks and ConvFFNs, and finally global average pooling and a fully connected layer produce the prediction.
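To make the pipeline concrete, here is a minimal, runnable Paddle sketch of the four-stage layout. Everything in it is invented for illustration (the ToyCloFormer name, the toy channel widths, and the plain convolutions that stand in for the Clo Block and ConvFFN); the real blocks are implemented in 2.3.3.

import paddle
import paddle.nn as nn

class ToyCloFormer(nn.Layer):
    """Structural sketch only: stem -> four stages -> global average pooling -> classifier head."""
    def __init__(self, dims=(16, 32, 64, 128), num_classes=10):
        super().__init__()
        self.stem = nn.Conv2D(3, dims[0], 4, stride=4)   # convolution stem, 4x spatial reduction
        stages = []
        for i in range(4):
            out = dims[min(i + 1, 3)]
            stages.append(nn.Sequential(
                nn.Conv2D(dims[i], dims[i], 3, padding=1, groups=dims[i]),       # stand-in for the Clo Block
                nn.Conv2D(dims[i], out, 3, padding=1, stride=2 if i < 3 else 1)  # stand-in for the ConvFFN that downsamples between stages
            ))
        self.stages = nn.LayerList(stages)
        self.head = nn.Linear(dims[-1], num_classes)

    def forward(self, x):
        x = self.stem(x)
        for stage in self.stages:
            x = stage(x)
        return self.head(x.mean(axis=[2, 3]))            # global average pooling + fully connected layer

print(ToyCloFormer()(paddle.randn([1, 3, 224, 224])).shape)  # [1, 10]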

1.1 Clo Block

The Clo Block is the key component of CloFormer. Each Clo Block consists of a local branch and a global branch.

  1. Global branch. The global branch uses conventional attention, but downsamples K and V to reduce the amount of computation and thus captures low-frequency global information (a code sketch of both branches follows the formulas in this list). The formula is:

$$X_{\text{global}} = \operatorname{Attention}\left(Q_g,\ \operatorname{Pool}(K_g),\ \operatorname{Pool}(V_g)\right)$$

  2. Local branch (AttnConv). In the local branch, the paper introduces AttnConv, a carefully designed yet simple and effective convolution operator in the style of attention. AttnConv fuses shared weights and context-aware weights to aggregate high-frequency local information. Concretely, AttnConv first extracts a local representation with a depthwise convolution (DWconv), whose weights are shared, and then enhances the local features with context-aware weights. Unlike Non-Local-style methods for generating context-aware weights, AttnConv uses a gating mechanism, which introduces stronger non-linearity than commonly used attention. It also applies depthwise convolutions to the Query and Key to aggregate local information, computes the Hadamard product of Q and K, and passes the result through a series of linear and non-linear transformations to produce context-aware weights in the range [-1, 1]. Notably, because all of its operations are built on convolution, AttnConv retains the translation equivariance of convolution. The formulas are:

$$
\begin{aligned}
Q, K, V &= \mathrm{FC}(X_{in}) \\
V_s &= \mathrm{DWconv}(V) \\
Q_l &= \mathrm{DWconv}(Q) \\
K_l &= \mathrm{DWconv}(K) \\
\mathrm{Attn}_t &= \mathrm{FC}\left(\operatorname{Swish}\left(\mathrm{FC}\left(Q_l \odot K_l\right)\right)\right) \\
\mathrm{Attn} &= \operatorname{Tanh}\!\left(\frac{\mathrm{Attn}_t}{\sqrt{d}}\right) \\
X_{\text{local}} &= \mathrm{Attn} \odot V_s
\end{aligned}
$$
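Below is a minimal, shape-only Paddle sketch of the two branches, written directly from the formulas above rather than taken from the released code. The sizes (channel width d, kernel size, pooling ratio r) are illustrative, 1x1 convolutions play the role of the per-token FC layers, and the Q/K/V projections of the global branch are omitted for brevity.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

B, d, H, W, ks, r = 1, 32, 14, 14, 5, 2                  # illustrative sizes; r is the K/V pooling ratio
x = paddle.randn([B, d, H, W])

# Global branch: vanilla attention with average-pooled K and V (low-frequency information).
q = x.flatten(2).transpose([0, 2, 1])                    # (B, H*W, d); Q/K/V projections omitted
kv = F.avg_pool2d(x, r).flatten(2).transpose([0, 2, 1])  # (B, H*W/r^2, d)
attn_g = F.softmax(q @ kv.transpose([0, 2, 1]) * d ** -0.5, axis=-1)
x_global = attn_g @ kv                                   # (B, H*W, d)

# Local branch (AttnConv): shared-weight DWconv plus gated, context-aware weights in [-1, 1].
fc_qkv = nn.Conv2D(d, 3 * d, 1)
dw_q = nn.Conv2D(d, d, ks, padding=ks // 2, groups=d)
dw_k = nn.Conv2D(d, d, ks, padding=ks // 2, groups=d)
dw_v = nn.Conv2D(d, d, ks, padding=ks // 2, groups=d)
fc1, fc2 = nn.Conv2D(d, d, 1), nn.Conv2D(d, d, 1)

q_in, k_in, v_in = paddle.chunk(fc_qkv(x), 3, axis=1)
q_l, k_l, v_s = dw_q(q_in), dw_k(k_in), dw_v(v_in)       # shared-weight local aggregation
attn_l = paddle.tanh(fc2(nn.Swish()(fc1(q_l * k_l))) / d ** 0.5)  # context-aware weights, Tanh(... / sqrt(d))
x_local = attn_l * v_s                                   # (B, d, H, W)

print(x_global.shape, x_local.shape)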

Finally, the global and local features are concatenated, and a fully connected layer produces the final output:

$$
\begin{aligned}
X_t &= \operatorname{Concat}\left(X_{\text{local}}, X_{\text{global}}\right) \\
X_{out} &= \mathrm{FC}(X_t)
\end{aligned}
$$

1.2 ConvFFN

To bring local information into the FFN, the paper replaces the commonly used FFN with a ConvFFN. The main difference is that ConvFFN applies a depthwise convolution (DWconv) after the GELU activation, which lets it aggregate local information. Thanks to the DWconv, downsampling can be done directly inside the ConvFFN without introducing a Patch Merging module. CloFormer uses two kinds of ConvFFN: an in-stage ConvFFN with a plain skip connection, and a cross-stage ConvFFN that connects two stages and performs the downsampling.
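As a quick illustration of the cross-stage variant (all sizes here are placeholders; the full ConvFFN is implemented in 2.3.3), giving the depthwise convolution a stride of 2 is enough to downsample inside the FFN:

import paddle
import paddle.nn as nn

dim, hidden, stride = 32, 128, 2                 # illustrative sizes; stride=2 is the cross-stage variant
ffn = nn.Sequential(
    nn.Conv2D(dim, hidden, 1),                                              # expansion (1x1 conv acts as the FC)
    nn.GELU(),
    nn.Conv2D(hidden, hidden, 5, padding=2, stride=stride, groups=hidden),  # DWconv after GELU, also downsamples
    nn.Conv2D(hidden, dim, 1),                                              # projection back
)
x = paddle.randn([1, dim, 28, 28])
print(ffn(x).shape)                              # [1, 32, 14, 14]: downsampled inside the FFN, no Patch Merging module needed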

2. Code Reproduction

2.1 Download and import the required libraries

!pip install einops-0.3.0-py3-none-any.whl
!pip install paddlex
%matplotlib inline
import paddle
import numpy as np
import matplotlib.pyplot as plt
from paddle.vision.datasets import Cifar10
from paddle.vision.transforms import Transpose
from paddle.io import Dataset, DataLoader
from paddle import nn
import paddle.nn.functional as F
import paddle.vision.transforms as transforms
import os
from matplotlib.pyplot import figure
import paddlex
import itertools
from einops import rearrange

2.2 Create the dataset

train_tfm = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ColorJitter(brightness=0.2,contrast=0.2, saturation=0.2),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(20),
    paddlex.transforms.MixupImage(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

test_tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
paddle.vision.set_image_backend('cv2')
# Use the CIFAR-10 dataset
train_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='train', transform = train_tfm, )
val_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='test',transform = test_tfm)
print("train_dataset: %d" % len(train_dataset))
print("val_dataset: %d" % len(val_dataset))
train_dataset: 50000
val_dataset: 10000
batch_size=256
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)

2.3 Model creation

2.3.1 Label smoothing
class LabelSmoothingCrossEntropy(nn.Layer):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):

        confidence = 1. - self.smoothing
        log_probs = F.log_softmax(pred, axis=-1)
        idx = paddle.stack([paddle.arange(log_probs.shape[0]), target], axis=1)
        nll_loss = paddle.gather_nd(-log_probs, index=idx)
        smooth_loss = paddle.mean(-log_probs, axis=-1)
        loss = confidence * nll_loss + self.smoothing * smooth_loss

        return loss.mean()
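A quick sanity check on random data (the shapes are arbitrary): with smoothing set to 0 the loss should coincide with Paddle's built-in cross entropy.

logits = paddle.randn([4, 10])
targets = paddle.randint(0, 10, [4])
print(float(LabelSmoothingCrossEntropy(smoothing=0.1)(logits, targets)))
# With smoothing=0 this reduces to the standard cross entropy:
print(float(LabelSmoothingCrossEntropy(smoothing=0.0)(logits, targets)),
      float(F.cross_entropy(logits, targets)))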
2.3.2 DropPath
def drop_path(x, drop_prob=0.0, training=False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = paddle.to_tensor(1 - drop_prob)
    shape = (paddle.shape(x)[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
    random_tensor = paddle.floor(random_tensor)  # binarize
    output = x.divide(keep_prob) * random_tensor
    return output


class DropPath(nn.Layer):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
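A quick behavioural check (illustrative): DropPath is the identity in eval mode, while in train mode roughly drop_prob of the samples in a batch are zeroed and the survivors are rescaled by 1 / keep_prob.

dp = DropPath(0.5)
x = paddle.ones([8, 4])
dp.eval()
print(dp(x))    # identical to x at inference time
dp.train()
print(dp(x))    # about half of the 8 rows zeroed, the rest scaled to 2.0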
2.3.3 Creating the CloFormer model
class stem(nn.Layer):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.conv1 = nn.Conv2D(in_dim, out_dim // 2, 3, padding=1, stride=2, bias_attr=False)

        self.conv2 = nn.Conv2D(out_dim // 2, out_dim, 3, padding=1, stride=2, bias_attr=False)

        self.conv3 = nn.Conv2D(out_dim, out_dim, 3, padding=1, bias_attr=False)

        self.conv4 = nn.Conv2D(out_dim, out_dim, 3, padding=1, bias_attr=False)

        self.conv5 = nn.Conv2D(out_dim, out_dim, 1, bias_attr=False)

        self.gelu = nn.GELU()
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.gelu(x)
        x = self.conv2(x)
        x = self.gelu(x)
        x = self.conv3(x)
        x = self.gelu(x)
        x = self.conv4(x)
        x = self.gelu(x)
        x = self.conv5(x)
        return x
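A quick shape check with a toy channel width: the stem reduces the spatial resolution by a factor of 4 before the first stage.

x = paddle.randn([1, 3, 224, 224])
print(stem(3, 32)(x).shape)    # [1, 32, 56, 56]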
class CloBlock(nn.Layer):
    def __init__(self, global_dim, local_dim, kernel_size, pool_size, head, qk_scale=None, drop_path_rate=0.0):
        super().__init__()
        self.global_dim = global_dim
        self.local_dim = local_dim
        self.head = head

        self.norm = nn.LayerNorm(global_dim + local_dim)
        
        # global branch
        self.global_head = int(self.head * self.global_dim / (self.global_dim + self.local_dim))
        self.fc1 = nn.Linear(global_dim, global_dim * 3)
        self.pool1 = nn.AvgPool2D(pool_size)
        self.pool2 = nn.AvgPool2D(pool_size)
        self.qk_scale = qk_scale or global_dim ** -0.5
        self.softmax = nn.Softmax(axis=-1)

        # local branch
        self.local_head = int(self.head * self.local_dim / (self.global_dim + self.local_dim))
        self.fc2 = nn.Linear(local_dim, local_dim * 3)
        self.qconv = nn.Conv2D(local_dim // self.local_head, local_dim // self.local_head, kernel_size,
                padding=kernel_size//2, groups=local_dim // self.local_head)
        self.kconv = nn.Conv2D(local_dim // self.local_head, local_dim // self.local_head, kernel_size,
                padding=kernel_size//2, groups=local_dim // self.local_head)
        self.vconv = nn.Conv2D(local_dim // self.local_head, local_dim // self.local_head, kernel_size,
                padding=kernel_size//2, groups=local_dim // self.local_head)
        self.fc3 = nn.Conv2D(local_dim // self.local_head, local_dim // self.local_head, 1)
        self.swish = nn.Swish()
        self.fc4 = nn.Conv2D(local_dim // self.local_head, local_dim // self.local_head, 1)
        self.tanh = nn.Tanh()

        # fuse
        self.fc5 = nn.Conv2D(global_dim + local_dim, global_dim + local_dim, 1)
        self.drop_path = DropPath(drop_path_rate)

    def forward(self, x):
        identity = x

        B, C, H, W = x.shape

        x = rearrange(x, 'b c h w->b (h w) c')
        x = self.norm(x)
        x_local, x_global = paddle.split(x, [self.local_dim, self.global_dim], axis=-1)

        # global branch
        global_qkv = self.fc1(x_global)
        global_qkv = rearrange(global_qkv, 'b n (m h c)->m b h n c', m=3, h=self.global_head)
        global_q, global_k, global_v = global_qkv[0], global_qkv[1], global_qkv[2]
        global_k = rearrange(global_k, 'b m (h w) c->b (m c) h w', h=H, w=W)
        global_k = self.pool1(global_k)
        global_k = rearrange(global_k, 'b (m c) h w->b m (h w) c', m=self.global_head)
        global_v = rearrange(global_v, 'b m (h w) c->b (m c) h w', h=H, w=W)
        global_v = self.pool1(global_v)
        global_v = rearrange(global_v, 'b (m c) h w->b m (h w) c', m=self.global_head)
        attn = global_q @ global_k.transpose([0, 1, 3, 2]) * self.qk_scale
        attn = self.softmax(attn)
        x_global = attn @ global_v
        x_global = rearrange(x_global, 'b m (h w) c-> b (m c) h w', h=H, w=W)

        # local branch
        local_qkv = self.fc2(x_local)
        local_qkv = rearrange(local_qkv, 'b (h w) (m n c)->m (b n) c h w', m=3, h=H, w=W, n=self.local_head)
        local_q, local_k, local_v = local_qkv[0], local_qkv[1], local_qkv[2]
        local_q = self.qconv(local_q)
        local_k = self.kconv(local_k)
        local_v = self.vconv(local_v)
        attn = local_q * local_k
        attn = self.fc4(self.swish(self.fc3(attn)))
        attn = self.tanh(attn / (self.local_dim ** 0.5))  # divide by sqrt(d), matching Tanh(Attn_t / sqrt(d)) in the formula
        x_local = attn * local_v
        x_local = rearrange(x_local, '(b n) c h w->b (n c) h w', b=B)

        # Fuse
        x = paddle.concat([x_local, x_global], axis=1)
        x = self.fc5(x)
        out = identity + self.drop_path(x)
        return out
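A quick shape check with made-up hyper-parameters (the real ones are set by the model builders below): a CloBlock preserves both the resolution and the channel count, which must equal global_dim + local_dim.

blk = CloBlock(global_dim=8, local_dim=24, kernel_size=3, pool_size=8, head=4)
x = paddle.randn([2, 32, 56, 56])     # 32 = global_dim + local_dim
print(blk(x).shape)                   # [2, 32, 56, 56]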
class ConvFFN(nn.Layer):
    def __init__(self, in_dim, out_dim, kernel_size, stride, exp_ratio=4, drop_path_rate=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(in_dim)
        self.fc1 = nn.Conv2D(in_dim, int(exp_ratio * in_dim), 1)
        self.gelu = nn.GELU()
        self.dwconv1 = nn.Conv2D(int(exp_ratio * in_dim), int(exp_ratio * in_dim), kernel_size, padding=kernel_size//2, stride=stride, groups=int(exp_ratio * in_dim))
        self.fc2 = nn.Conv2D(int(exp_ratio * in_dim), out_dim, 1)
        self.drop_path = DropPath(drop_path_rate)

        self.downsample = stride>1
        if self.downsample:
            self.dwconv2 = nn.Conv2D(in_dim, in_dim, kernel_size, padding=kernel_size//2, stride=stride, groups=in_dim)
            self.norm2 = nn.BatchNorm2D(in_dim)
            self.fc3 = nn.Conv2D(in_dim, out_dim, 1)
        
    def forward(self, x):
        
        if self.downsample:
            identity = self.fc3(self.norm2(self.dwconv2(x)))
        else:
            identity = x

        x = rearrange(x, 'b c h w->b h w c')
        x = self.norm1(x)
        x = rearrange(x, 'b h w c->b c h w')

        x = self.fc1(x)
        x = self.gelu(x)
        x = self.dwconv1(x)
        x = self.fc2(x)

        out = identity + self.drop_path(x)
        return out
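A quick shape check (illustrative sizes): with stride=1 the ConvFFN keeps the resolution, while with stride=2 it halves the resolution and can change the channel count, which is how CloFormer moves between stages.

x = paddle.randn([2, 32, 56, 56])
print(ConvFFN(32, 32, kernel_size=5, stride=1)(x).shape)   # [2, 32, 56, 56]
print(ConvFFN(32, 64, kernel_size=5, stride=2)(x).shape)   # [2, 64, 28, 28]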
class CloFormer(nn.Layer):
    def __init__(self, global_dim, local_dim, heads, in_dim=3, num_classes=1000, depths=[2, 2, 6, 2], attnconv_ks=[3, 5, 7, 9],
                pool_size=[8, 4, 2, 1], convffn_ks=5, convffn_ratio=4, drop_path_rate=0.0):
        super().__init__()

        dprs = [x.item() for x in paddle.linspace(0, drop_path_rate, sum(depths))]

        self.stem = stem(in_dim, global_dim[0] + local_dim[0])

        for i in range(len(depths)):
            layers = []
            dpr = dprs[sum(depths[:i]):sum(depths[:i + 1])]
            for j in range(depths[i]):
                if j < depths[i] - 1 or i == len(depths) - 1:
                    layers.append(
                        nn.Sequential(
                            CloBlock(global_dim[i], local_dim[i], attnconv_ks[i], pool_size[i], heads[i], drop_path_rate=dpr[j]),  # keyword so dpr[j] is the drop-path rate, not qk_scale
                            ConvFFN(global_dim[i] + local_dim[i], global_dim[i] + local_dim[i], convffn_ks, 1, convffn_ratio, dpr[j])
                        )
                    )
                else:
                    layers.append(
                        nn.Sequential(
                            CloBlock(global_dim[i], local_dim[i], attnconv_ks[i], pool_size[i], heads[i], drop_path_rate=dpr[j]),  # keyword so dpr[j] is the drop-path rate, not qk_scale
                            ConvFFN(global_dim[i] + local_dim[i], global_dim[i + 1] + local_dim[i + 1], convffn_ks, 2, convffn_ratio, dpr[j])
                        )
                    )

            self.__setattr__(f'stage{i}', nn.LayerList(layers))
        
        self.norm = nn.LayerNorm(global_dim[-1] + local_dim[-1])
        
        self.head = nn.Linear(global_dim[-1] + local_dim[-1], num_classes)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        tn = nn.initializer.TruncatedNormal(std=.02)
        ones = nn.initializer.Constant(1.0)
        zeros = nn.initializer.Constant(0.0)
        if isinstance(m, (nn.Conv2D, nn.Linear)):
            tn(m.weight)
            if m.bias is not None:
                zeros(m.bias)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2D)):
            zeros(m.bias)
            ones(m.weight)

    def forward_feature(self, x):
        for blk in self.stage0:
            x = blk(x) 
        
        for blk in self.stage1:
            x = blk(x) 

        for blk in self.stage2:
            x = blk(x) 

        for blk in self.stage3:
            x = blk(x) 

        x = rearrange(x, 'b c h w-> b h w c')
        x = self.norm(x)
        return x

    def forward(self, x):
        x = self.stem(x)
        x = self.forward_feature(x)
        x = paddle.mean(x, axis=[1, 2])
        x = self.head(x)
        return x

num_classes=10

def cloformer_xxs():
    global_dim = [8, 32, 64, 192]
    local_dim = [24, 32, 64, 64]
    heads = [4, 4, 8, 16]

    model = CloFormer(global_dim, local_dim, heads, num_classes=num_classes, drop_path_rate=0.0)

    return model


def cloformer_xs():
    global_dim = [16, 48, 80, 240]
    local_dim = [32, 48, 80, 112]
    heads = [3, 6, 10, 22]

    model = CloFormer(global_dim, local_dim, heads, num_classes=num_classes, drop_path_rate=0.06)

    return model


def cloformer_s():
    global_dim = [16, 64, 112, 336]
    local_dim = [48, 64, 112, 112]
    heads = [4, 8, 14, 28]

    model = CloFormer(global_dim, local_dim, heads, num_classes=num_classes, drop_path_rate=0.06)

    return model

2.3.4 Model parameters
model = cloformer_xxs()
paddle.summary(model, (1, 3, 224, 224))

model = cloformer_xs()
paddle.summary(model, (1, 3, 224, 224))

model = cloformer_s()
paddle.summary(model, (1, 3, 224, 224))

2.4 Training

learning_rate = 0.001
n_epochs = 100
paddle.seed(42)
np.random.seed(42)
work_path = 'work/model'

# CloFormer-XXS
model = cloformer_xxs()

criterion = LabelSmoothingCrossEntropy()

scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate, T_max=50000 // batch_size * n_epochs, verbose=False)
optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=scheduler, weight_decay=1e-5)

gate = 0.0
threshold = 0.0
best_acc = 0.0
val_acc = 0.0
loss_record = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}   # for recording loss
acc_record = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}      # for recording accuracy

loss_iter = 0
acc_iter = 0

for epoch in range(n_epochs):
    # ---------- Training ----------
    model.train()
    train_num = 0.0
    train_loss = 0.0

    val_num = 0.0
    val_loss = 0.0
    accuracy_manager = paddle.metric.Accuracy()
    val_accuracy_manager = paddle.metric.Accuracy()
    print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))
    for batch_id, data in enumerate(train_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)

        logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        accuracy_manager.update(acc)
        if batch_id % 10 == 0:
            loss_record['train']['loss'].append(loss.numpy())
            loss_record['train']['iter'].append(loss_iter)
            loss_iter += 1

        loss.backward()

        optimizer.step()
        scheduler.step()
        optimizer.clear_grad()
        
        train_loss += loss
        train_num += len(y_data)

    total_train_loss = (train_loss / train_num) * batch_size
    train_acc = accuracy_manager.accumulate()
    acc_record['train']['acc'].append(train_acc)
    acc_record['train']['iter'].append(acc_iter)
    acc_iter += 1
    # Print the information.
    print("#===epoch: {}, train loss is: {}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss.numpy(), train_acc*100))

    # ---------- Validation ----------
    model.eval()

    for batch_id, data in enumerate(val_loader):

        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)
        with paddle.no_grad():
          logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        val_accuracy_manager.update(acc)

        val_loss += loss
        val_num += len(y_data)

    total_val_loss = (val_loss / val_num) * batch_size
    loss_record['val']['loss'].append(total_val_loss.numpy())
    loss_record['val']['iter'].append(loss_iter)
    val_acc = val_accuracy_manager.accumulate()
    acc_record['val']['acc'].append(val_acc)
    acc_record['val']['iter'].append(acc_iter)
    
    print("#===epoch: {}, val loss is: {}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss.numpy(), val_acc*100))

    # ===================save====================
    if val_acc > best_acc:
        best_acc = val_acc
        paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))

print(best_acc)
paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))

2.5 Results analysis

def plot_learning_curve(record, title='loss', ylabel='CE Loss'):
    ''' Plot the learning curve of the model '''
    maxtrain = max(map(float, record['train'][title]))
    maxval = max(map(float, record['val'][title]))
    ymax = max(maxtrain, maxval) * 1.1
    mintrain = min(map(float, record['train'][title]))
    minval = min(map(float, record['val'][title]))
    ymin = min(mintrain, minval) * 0.9

    total_steps = len(record['train'][title])
    x_1 = list(map(int, record['train']['iter']))
    x_2 = list(map(int, record['val']['iter']))
    figure(figsize=(10, 6))
    plt.plot(x_1, record['train'][title], c='tab:red', label='train')
    plt.plot(x_2, record['val'][title], c='tab:cyan', label='val')
    plt.ylim(ymin, ymax)
    plt.xlabel('Training steps')
    plt.ylabel(ylabel)
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()
plot_learning_curve(loss_record, title='loss', ylabel='CE Loss')

[Figure: learning curves of the training and validation loss]

plot_learning_curve(acc_record, title='acc', ylabel='Accuracy')

[Figure: learning curves of the training and validation accuracy]

import time
work_path = 'work/model'
model = cloformer_xxs()
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
aa = time.time()
for batch_id, data in enumerate(val_loader):

    x_data, y_data = data
    labels = paddle.unsqueeze(y_data, axis=1)
    with paddle.no_grad():
        logits = model(x_data)
bb = time.time()
print("Throughout:{}".format(int(len(val_dataset)//(bb - aa))))
Throughout:856
def get_cifar10_labels(labels):  
    """返回CIFAR10数据集的文本标签。"""
    text_labels = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck']
    return [text_labels[int(i)] for i in labels]
def show_images(imgs, num_rows, num_cols, pred=None, gt=None, scale=1.5):  
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if paddle.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if pred or gt:
            ax.set_title("pt: " + pred[i] + "\ngt: " + gt[i])
    return axes
work_path = 'work/model'
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
model = cloformer_xxs()
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
logits = model(X)
y_pred = paddle.argmax(logits, -1)
X = paddle.transpose(X, [0, 2, 3, 1])
axes = show_images(X.reshape((18, 224, 224, 3)), 1, 18, pred=get_cifar10_labels(y_pred), gt=get_cifar10_labels(y))
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).  (printed once per image, 18 times in total)

[Figure: model predictions (pt) vs. ground truth (gt) for 18 validation images]

!pip install interpretdl
import interpretdl as it
work_path = 'work/model'
model = cloformer_xxs()
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
lime = it.LIMECVInterpreter(model)
lime_weights = lime.interpret(X.numpy()[3], interpret_class=y.numpy()[3], batch_size=100, num_samples=10000, visual=True)
100%|██████████| 10000/10000 [00:50<00:00, 198.19it/s]

50<00:00, 198.19it/s]

[Figure: LIME visualization highlighting the image regions that drive the prediction]

Summary

This paper proposes CloFormer, a lightweight vision Transformer with a context-aware local enhancement mechanism, and develops a novel way of modeling local perception. CloFormer achieves competitive performance among models of similar FLOPs and size. In particular, the carefully designed AttnConv exploits shared weights together with context-aware weights to extract high-frequency local feature representations effectively, and a two-branch structure mixes high- and low-frequency information. Extensive experiments show that CloFormer is a lightweight and efficient vision backbone that outperforms many existing SOTA methods.

References

  1. Rethinking Local Perception in Lightweight Vision Transformer
  2. Plug-and-Play Series | Tsinghua's latest efficient mobile network architecture CloFormer: a perfect fusion of attention and convolution! (Chinese blog post)

This article is a repost of the original project.
