• Attention的基本流程是,查询q与键值k相乘获得权重a,a与值v相乘获得注意力值。这篇博客讲的很清晰。
  • TopFormer使用多头注意力机制
    • 查询qq的每个头中特征图的每个元素有key_dim个特征
    • 键值kk和qq维度相同,为了相乘进行了转置
    • 值vv和LeViT一致,扩大了每个元素的维度,特征更多
class Attention(torch.nn.Module):
    def __init__(self, dim, key_dim, num_heads,
                 attn_ratio=4,
                 activation=None,
                 norm_cfg=dict(type='BN', requires_grad=True),):
        super().__init__() 
        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads # num_head key_dim
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * num_heads
        self.attn_ratio = attn_ratio

        self.to_q = Conv2d_BN(dim, nh_kd, 1, norm_cfg=norm_cfg)
        self.to_k = Conv2d_BN(dim, nh_kd, 1, norm_cfg=norm_cfg)
        self.to_v = Conv2d_BN(dim, self.dh, 1, norm_cfg=norm_cfg)

        self.proj = torch.nn.Sequential(activation(), Conv2d_BN(
            self.dh, dim, bn_weight_init=0, norm_cfg=norm_cfg))

    def forward(self, x):  # x (B,N,C)
        B, C, H, W = get_shape(x)
        
        # B*num_heads*hw*key_dim,每个像素有key_dim个特征
        qq = self.to_q(x).reshape(B, self.num_heads, self.key_dim, H * W).permute(0, 1, 3, 2)
        # B*num_heads*key_dim*hw
        kk = self.to_k(x).reshape(B, self.num_heads, self.key_dim, H * W)
        # B*num_heads*hw*d
        vv = self.to_v(x).reshape(B, self.num_heads, self.d, H * W).permute(0, 1, 3, 2)

        # B*num_heads*hw*hw
        attn = torch.matmul(qq, kk)
        # hw的每个元素与所有元素的权重,类似协方差矩阵
        attn = attn.softmax(dim=-1) # dim = k

        # B*num_heads*hw*d
        xx = torch.matmul(attn, vv)

        xx = xx.permute(0, 1, 3, 2).reshape(B, self.dh, H, W)
        xx = self.proj(xx)
        return xx
Logo

欢迎加入西安开发者社区!我们致力于为西安地区的开发者提供学习、合作和成长的机会。参与我们的活动,与专家分享最新技术趋势,解决挑战,探索创新。加入我们,共同打造技术社区!

更多推荐