deepfillv2的动机

​ 结合了几乎所有的目前先进的图像修复技术,基于部分卷积提出了门控卷积,结合了CA中的注意力机制,根据 Adversarial Edge图像修复中的边缘信息先验提出了用户可交互的草图先验信息。基于spectral-normallized GAN 提出了 SN-PatchGAN 鉴别器,本文所用的损失函数只有l1 重建损失和 SN-PatchGAN损失.

1. Gated Convolution

为了介绍门控卷积,得先提提部分卷积,对于分类、分割等任务,网络的输入像素是全部有效的,而对于修复任务,孔洞区域的像素是无效像素,如果将其当成和其他区域的像素一样处理,那么必然会造成修复结果的模糊,颜色不一致等情况,基于这种原因,部分卷积(partial convolution)被提出。它的实现机制在我上一篇Image Inpainting for Irregular Holes Using有被提到。它的目的在于,使得卷积的结果尽量只依赖与有效像素。部分卷积有效提高了非规则掩模上的图像修复质量。但是仍然还存在一些问题:

  1. 在跟新mask时,它启发式地将所有空间位置分类为有效或无效。无论前一层的过滤范围覆盖了多少像素,下一层的掩码都将被设置为1(例如,1个有效的像素和9个有效的像素被当作相同的来更新当前的掩码),这样显得不太合理。
  2. 如果模型是需要与用户进行交互的,那么用户输入的稀疏草图掩模作为条件通道。在这种情况下,应该认为这些像素位置是有效的还是无效的?如何正确更新下一层的mask
  3. 对于部分卷积来说,如果网络加深到一定程度那么mask最终会被全部更新为1(即全部都是有效像素),本文的作者提出应该让网络自动学习最优的掩码,网络将软掩码值分配给每个空间位置
  4. 部分卷积中,每个层中的所有通道都共享同一个掩码mask,这限制了灵活性。本质上,部分卷积可以被视为不可学习的单通道特征硬门控。

部分卷积与门控卷积的图示区别如下图:

在这里插入图片描述

基于上述部分卷积的一些问题,本文作者提出了门控卷积。取代了部分卷积的硬门控的掩码mask更新规则,门控卷积从数据中自动学习软掩码mask.更新的数学表达如下:

在这里插入图片描述

这里的I是特征图, σ \sigma σ是sigmoid()函数, ϕ \phi ϕ是激活函数,可以是ReLU、ELU、LeakyReLU。实际就是对I分别做两次卷积,然后其中一个卷积用sigmoid()函数,将其值全部限制在0-1之间,然后与另外一个卷积得到的特征图进行逐像素的相乘。

门控卷积的代码实现非常简单,如下:

#1.门控卷积的模块
class Gated_Conv(nn.Module):
    def __init__(self,in_ch,out_ch,ksize=3,stride=1,rate=1,activation=nn.ELU):
        super(Gated_Conv, self).__init__()
        padding=int(rate*(ksize-1)/2)
        #通过卷积将通道数变成输出两倍,其中一半用来做门控,学习
        self.conv=nn.Conv2d(in_ch,2*out_ch,kernel_size=ksize,stride=stride,padding=padding,dilation=rate)
        self.activation=activation
    def forward(self,x):
        raw=self.conv(x)
        x1=raw.split(int(raw.shape[1]/2),dim=1)#将特征图分成两半,其中一半是做学习
        gate=torch.sigmoid(x1[0])#将值限制在0-1之间
        out=self.activation(x1[1])*gate
        return out

2. SN-PatchGAN

对于孔洞单一为矩形的,local GAN 使用提升了修复结果,但是对于自由形式孔洞区域,这种局部鉴别器显然不太适用。基于 global and local GANs、MarkovianGAN、perceptual loss 和spectral-normalized loss.。作者提出了简单高效的SN-PatchGAN,可以应对自由形式的空洞破损。网络结构如下图所示:

在这里插入图片描述

网络的输入包括:破损图片、孔洞掩码mask、用户指导的先验草图信息。网络的输出是3D的feature map.而不是传统鉴别器输出的了一个打分标量。网络堆叠了6个卷积为kernel size为5,stride=2去捕获Markovian patches的特征统计信息。值得注意的是输出特征图的每一个元素的感受野都是包含了整个输入图。因此全局鉴别器也就不需要了。同时也采用了spectral normalizetion (借鉴的是SN-GANs)来进一步稳定GAN的训练。为了鉴别出真图还是假图,采用了hinge loss作为目标函数,对于生成器G:
l o s s G = − E z − p z ( z ) [ D s n ( G ( z ) ) ] loss_G=-E_{z-p_z(z)}[D^{sn}(G(z))] lossG=Ezpz(z)[Dsn(G(z))]
对于鉴别器:
l o s s D = E x − P d a t a ( x ) [ R e L U ( 1 − D s n ( x ) ) ] + E z − p z ( z ) [ R e L U ( 1 + D s n ( G ( z ) ) ) ] loss_D=E_{x-P_{data}(x)}[ReLU(1-D^{sn}(x))]+E_{z-p_z(z)}[ReLU(1+D^{sn}(G(z)))] lossD=ExPdata(x)[ReLU(1Dsn(x))]+Ezpz(z)[ReLU(1+Dsn(G(z)))]
这里的 D s n D^{sn} Dsn代表spectral-normalized discriminator ,G是修复网络。

鉴别器网络结构实现如下:

#1.
class SpectralNorm(nn.Module):
    '''
    spectral normalization,modified from https://github.com/christiancosgrove/pytorch-spectral-normalization-gan/blob/master/spectral_normalization.py

    '''
    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))

        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False


    def _make_params(self):
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)


    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)

#2.SN卷积层实现
    class SN_Conv(nn.Module):
    def __init__(self,in_ch,out_ch,ksize=3,stride=1,rate=1,activation=nn.LeakyReLU()):
        super(SN_Conv,self).__init__()
        padding = int(rate * (ksize - 1) / 2)
        conv = nn.Conv2d(in_ch,out_ch, kernel_size=ksize, stride=stride, padding=padding, dilation=rate)
        self.snconv = SpectralNorm(conv)
        self.activation = activation
    def forward(self,x):
        x1 = self.snconv(x)
        if self.activation is not None:
            x1 = self.activation(x1)
        return x1

    
#3.sn鉴别器网络    
class SNDiscriminator(nn.Module):
    def __init__(self,in_ch=5,cnum=64):
        super(SNDiscriminator,self).__init__()

        disconv_layer = OrderedDict()
        disconv_layer['conv1'] = SN_Conv(in_ch=in_ch,out_ch=cnum,ksize=5,stride=2)
        disconv_layer['conv2'] = SN_Conv(in_ch=cnum, out_ch=2*cnum, ksize=5, stride=2)
        disconv_layer['conv3'] = SN_Conv(in_ch=2*cnum, out_ch=4*cnum, ksize=5, stride=2)
        disconv_layer['conv4'] = SN_Conv(in_ch=4 * cnum, out_ch=4 * cnum, ksize=5, stride=2)
        disconv_layer['conv5'] = SN_Conv(in_ch=4 * cnum, out_ch=4 * cnum, ksize=5, stride=2)
        disconv_layer['conv6'] = SN_Conv(in_ch=4 * cnum, out_ch=4 * cnum, ksize=5, stride=2)
        self.dislayer = nn.Sequential(disconv_layer)

    def forward(self,x):
        x1 = self.dislayer(x)
        #print(x1.shape)
        out = x1.view(x1.shape[0],-1)
        return out

3. inpainting Network Architecture

在这里插入图片描述

整个修复网络分为两个阶段(粗阶段和细化阶段),卷积部分都采用了门控卷积:

#1.粗阶段,输入是5通道(破损图片3,掩码mask,用户指导草图),输出为3通道
class CoarseNet(nn.Module):
    def __init__(self,in_ch=5,cnum=48):
        super(CoarseNet,self).__init__()
        self.conv1 = Gated_Conv(in_ch=in_ch,out_ch=cnum,ksize=5)
        self.conv2_down = Gated_Conv(in_ch=cnum,out_ch=2*cnum,stride=2)
        self.conv3 = Gated_Conv(in_ch=2*cnum,out_ch=2*cnum)
        self.conv4_down = Gated_Conv(in_ch=2*cnum,out_ch=4*cnum,stride=2)
        self.conv5 = Gated_Conv(in_ch=4*cnum,out_ch=4*cnum)
        self.conv6 = Gated_Conv(in_ch=4*cnum,out_ch=4*cnum)

        self.conv7 = Gated_Conv(in_ch=4*cnum,out_ch=4*cnum,rate=2)
        self.conv8 = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum, rate=4)
        self.conv9 = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum, rate=8)
        self.conv10 = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum, rate=16)

        self.conv11 = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum)
        self.conv12 = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum)

        self.conv13_up = Gated_Deconv(in_ch=4*cnum,out_ch=2*cnum)
        self.conv14 = Gated_Conv(in_ch=2*cnum,out_ch=2*cnum)
        self.conv15_up = Gated_Deconv(in_ch=2*cnum,out_ch=cnum)
        self.conv16 = Gated_Conv(in_ch=cnum,out_ch=cnum//2)

        self.conv17 = nn.Conv2d(in_channels=cnum//2,out_channels=3,kernel_size=3,stride=1,padding=1)


    def forward(self,x):
        x1 = self.conv1(x)
        x2 = self.conv2_down(x1)
        x3 = self.conv3(x2)
        x4 = self.conv4_down(x3)
        x5 = self.conv5(x4)
        x6 = self.conv6(x5)
        x7 = self.conv7(x6)
        x8 = self.conv8(x7)
        x9 = self.conv9(x8)
        x10 = self.conv10(x9)
        x11 = self.conv11(x10)
        x12 = self.conv12(x11)
        x13 = self.conv13_up(x12)
        x14 = self.conv14(x13)
        x15 = self.conv15_up(x14)
        x16 = self.conv16(x15)
        x17 = self.conv17(x16)
        x_stage1 = F.tanh(x17)
        return x_stage1


#2,细化阶段的输入为粗阶段的输出结果,该阶段有两个分支(卷积分支和注意力机制分支)
class RefineNet(nn.Module):
    def __init__(self,in_ch=3,cnum=48):
        super(RefineNet,self).__init__()
        #1.conv branch
        xconv_layer = OrderedDict()
        xconv_layer['xconv1'] = Gated_Conv(in_ch=in_ch,out_ch=cnum,ksize=5)
        xconv_layer['xconv2_down'] = Gated_Conv(in_ch=cnum,out_ch=cnum,stride=2)
        xconv_layer['xconv3'] =  Gated_Conv(in_ch=cnum,out_ch=2*cnum)
        xconv_layer['xconv4_down'] = Gated_Conv(in_ch=2*cnum,out_ch=2*cnum,stride=2)
        xconv_layer['xconv5'] = Gated_Conv(in_ch=2*cnum,out_ch=4*cnum)
        xconv_layer['xconv6'] = Gated_Conv(in_ch=4*cnum,out_ch=4*cnum)

        xconv_layer['xconv7_atrous']  = Gated_Conv(in_ch=4*cnum,out_ch=4*cnum,rate=2)
        xconv_layer['xconv8_atrous'] = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum, rate=4)
        xconv_layer['xconv9_atrous'] = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum, rate=8)
        xconv_layer['xconv10_atrous'] = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum, rate=16)

        self.xlayer = nn.Sequential(xconv_layer)

        #2.attention brach
        pmconv_layer1 = OrderedDict()
        pmconv_layer1['pmconv1'] = Gated_Conv(in_ch=in_ch,out_ch=cnum,ksize=5)
        pmconv_layer1['pmconv2_down'] = Gated_Conv(in_ch=cnum,out_ch=cnum,stride=2)
        pmconv_layer1['pmconv3'] = Gated_Conv(in_ch=cnum,out_ch=2*cnum)
        pmconv_layer1['pmconv4_down'] = Gated_Conv(in_ch=2*cnum, out_ch=4*cnum, stride=2)
        pmconv_layer1['pmconv5'] = Gated_Conv(in_ch=4*cnum,out_ch=4*cnum)
        pmconv_layer1['pmconv6'] = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum,activation=nn.ReLU())
        self.pmlayer1 = nn.Sequential(pmconv_layer1)

        self.CA = Contextual_Attention(rate=2)

        pmconv_layer2 = OrderedDict()
        pmconv_layer2['pmconv9'] = Gated_Conv(in_ch=4*cnum,out_ch=4*cnum)
        pmconv_layer2['pmconv10'] = Gated_Conv(in_ch=4*cnum,out_ch=4*cnum)
        self.pmlayer2 = nn.Sequential(pmconv_layer2)

        #confluent branch
        allconv_layer = OrderedDict()
        allconv_layer['allconv11'] = Gated_Conv(in_ch=8*cnum,out_ch=4*cnum)
        allconv_layer['allconv12'] = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum)
        allconv_layer['allconv13_up'] = Gated_Deconv(in_ch=4 * cnum, out_ch=2 * cnum)
        allconv_layer['allconv14'] = Gated_Conv(in_ch=2 * cnum, out_ch=2 * cnum)
        allconv_layer['allconv15_up'] = Gated_Deconv(in_ch=2 * cnum, out_ch=cnum)
        allconv_layer['allconv16'] = Gated_Conv(in_ch=cnum, out_ch=cnum//2)
        allconv_layer['allconv17'] = nn.Conv2d(in_channels=cnum//2,out_channels=3,kernel_size=3,padding=1)
        allconv_layer['tanh'] = nn.Tanh()
        self.colayer = nn.Sequential(allconv_layer)

    def forward(self, xin, mask):

        x1 = self.xlayer(xin)
        x_hallu = x1

        x2 = self.pmlayer1(xin)
        mask_s = self.resize_mask_like(mask,x2)
        x3,offset_flow = self.CA(x2,x2,mask_s)
        x4 = self.pmlayer2(x3)
        pm = x4

        x5 = torch.cat((x_hallu,pm),dim=1)
        x6 = self.colayer(x5)
        x_stage2 = x6

        return x_stage2,offset_flow

    def resize_mask_like(self,mask,x):
        sizeh = x.shape[2]
        sizew = x.shape[3]
        return down_sample(mask,size=(sizeh,sizew))


#3.完整的修复网络
class CAGenerator(nn.Module):
    def __init__(self,in_ch=5,cnum=48,):
        super(CAGenerator,self).__init__()
        self.stage_1 = CoarseNet(in_ch=in_ch,cnum=cnum)
        self.stage_2 = RefineNet(in_ch=3,cnum=cnum)

    def forward(self,xin,mask):
        stage1_out = self.stage_1(xin)
        stage2_in = stage1_out * mask + xin[:,0:3,:,:] * (1. - mask)
        stage2_out,offset_flow = self.stage_2(stage2_in,mask)

        return stage1_out,stage2_out,offset_flow

4.总结

作者提出了一种基于端到端生成网络的新型自由形式图像修复系统,该网络具有门控卷积,并经过逐像素l1损失和SN-patchGAN训练。证明门控卷积改善了修复的质量。

参考文献

1.Free-Form Image Inpainting with Gated Convolution(ICCV2019)

2.https://github.com/KeyKy/generative-inpainting-2.0-pytorch

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐