Latent Diffusion Model: A Code Walkthrough

Author: Sijin Yu

Preface

github: https://github.com/CompVis/latent-diffusion.
Latent Diffusion is trained in two stages: the first stage trains the VAE (autoencoder), and the second stage trains the diffusion model. The code is organized as follows:
(Figure: overview of the code organization)

Stage 1: Training the AutoEncoder

AutoencoderKL

Location: latent-diffusion/ldm/models/autoencoder.py

This class implements a VAE-based autoencoder.

Methods:

  • init_from_ckpt(self, path, ignore_keys=list()). Loads the model/state dict from the given path. (code omitted)
  • encode(self, x). Takes x and returns a Gaussian distribution as a DiagonalGaussianDistribution object. (see the forward pass)
  • decode(self, z). Takes z and returns the decoded result as a torch.tensor. (see the forward pass)
  • forward(self, input, sample_posterior=True). Forward pass: encode, then decode. Returns the decoded torch.tensor and the Gaussian DiagonalGaussianDistribution produced by encode. (see below)
  • get_input(self, batch, k). Reshapes the input data into the expected format. (code omitted)
  • training_step(self, batch, batch_idx, optimizer_idx). Training step. (see below)
  • validation_step(self, batch, batch_idx). Validation step. (see below)
  • configure_optimizers(self). Builds and configures the optimizers. (code omitted)
  • get_last_layer(self). Returns the weights of the model's last layer. (code omitted)
  • log_images(self, batch, only_inputs=False, **kwargs). Logs generated images. (code omitted)
  • to_rgb(self, x). Logs generated segmentation maps. (code omitted)
Constructor
def __init__(self,
             ddconfig,               # config dict used to build the Encoder and Decoder
             lossconfig,             # config dict used to build the loss function
             embed_dim,              # embedding dim
             ckpt_path=None,         # path of a pretrained checkpoint to load
             ignore_keys=[],         # keys to ignore when loading the checkpoint
             image_key="image",      # key used to fetch images from the input batch
             colorize_nlabels=None,  # appears unused in practice
             monitor=None,           # quantity monitored during training
             ):
    super().__init__()
    self.image_key = image_key
    self.encoder = Encoder(**ddconfig)
    self.decoder = Decoder(**ddconfig)
    self.loss = instantiate_from_config(lossconfig)
    assert ddconfig["double_z"]
    self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)  # project the (doubled) z channels to the embedding space
    self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) # project the embedding back to z channels
    self.embed_dim = embed_dim
    if colorize_nlabels is not None:
        assert type(colorize_nlabels)==int
        self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
    if monitor is not None:
        self.monitor = monitor
    if ckpt_path is not None:
        self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

The constructor takes eight parameters:

  • ddconfig. Dict with the config used to build the Encoder and Decoder.
  • lossconfig. Dict with the config used to build the loss function.
    The underlying code is fairly involved; in short, lossconfig has two important keys:
    • 'target'. A string naming which loss class to use; torch.nn.CrossEntropyLoss, for example, would be a valid value.
    • 'params'. A dict (possibly empty, default dict()) of arguments used to construct that loss.
  • embed_dim. The embedding dimension.
  • The remaining parameters are largely unrelated to the model architecture.

Note these two lines of the constructor:

self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)

The code of Encoder and Decoder is covered in the Encoder & Decoder section below.

Back to AutoencoderKL.

Forward pass

The forward pass is implemented as follows:

def forward(self, input, sample_posterior=True):
    posterior = self.encode(input)
    if sample_posterior:
        z = posterior.sample()
    else:
        z = posterior.mode()
    dec = self.decode(z)
    return dec, posterior

The argument sample_posterior controls whether the latent variable z is drawn as a sample from the posterior or taken directly as its mean.

The forward pass relies on the encode and decode methods:

def encode(self, x):
    h = self.encoder(x)  # output has 2*z_channels channels
    moments = self.quant_conv(h)  # channels: 2*z_channels -> 2*embed_dim
    posterior = DiagonalGaussianDistribution(moments)
    return posterior

def decode(self, z):
    z = self.post_quant_conv(z) # channels: embed_dim -> z_channels
    dec = self.decoder(z)
    return dec

Note: why are the channel counts doubled on the encoder side? Because the encoder has to predict both a mean and a log-variance for each latent channel.
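As a quick illustration with made-up shapes: for embed_dim = 4 and a 32x32 latent grid, moments carries 8 channels, which DiagonalGaussianDistribution later splits into mean and log-variance:

import torch

# hypothetical shapes: batch 2, embed_dim 4, 32x32 latent grid
moments = torch.randn(2, 2 * 4, 32, 32)        # what quant_conv outputs
mean, logvar = torch.chunk(moments, 2, dim=1)  # each: [2, 4, 32, 32]
print(mean.shape, logvar.shape)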

Back to AutoencoderKL.

Training step
def training_step(self, batch, batch_idx, optimizer_idx):
    inputs = self.get_input(batch, self.image_key)
    reconstructions, posterior = self(inputs)

    if optimizer_idx == 0:
        # train encoder+decoder+logvar
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                        last_layer=self.get_last_layer(), split="train")
        self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
        return aeloss

    if optimizer_idx == 1:
        # train the discriminator
        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")

        self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
        return discloss
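The optimizer_idx argument exists because AutoencoderKL is a PyTorch Lightning module whose configure_optimizers returns two optimizers: one for the autoencoder itself and one for the discriminator that lives inside the loss module. Roughly (paraphrased from the repository; details such as the Adam betas may differ):

def configure_optimizers(self):
    lr = self.learning_rate
    # optimizer 0: encoder, decoder and the two 1x1 convs
    opt_ae = torch.optim.Adam(list(self.encoder.parameters())
                              + list(self.decoder.parameters())
                              + list(self.quant_conv.parameters())
                              + list(self.post_quant_conv.parameters()),
                              lr=lr, betas=(0.5, 0.9))
    # optimizer 1: the discriminator inside the loss
    opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                lr=lr, betas=(0.5, 0.9))
    return [opt_ae, opt_disc], []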

This step essentially just evaluates the loss; the key is to understand what self.loss actually is. Looking at its definition in the constructor:

self.loss = instantiate_from_config(lossconfig)

This function is fairly involved; roughly, it reads a class path from a string in a config dict and instantiates that class. Looking at the config file latent-diffusion/configs/autoencoder/autoencoder_kl_8x8x64.yaml, the loss is configured as:

lossconfig:
  target: ldm.modules.losses.LPIPSWithDiscriminator
  params:
    disc_start: 50001
    kl_weight: 0.000001
    disc_weight: 0.5

So self.loss here is an instance of the ldm.modules.losses.LPIPSWithDiscriminator class.
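For reference, instantiate_from_config (in ldm/util.py) is conceptually tiny; a simplified sketch of what it does (the real helper resolves the dotted path via a small get_obj_from_str utility):

import importlib

def instantiate_from_config(config):
    # config["target"] is a dotted path such as "ldm.modules.losses.LPIPSWithDiscriminator"
    module_name, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    # config.get("params", {}) supplies the constructor keyword arguments
    return cls(**config.get("params", dict()))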

See LPIPSWithDiscriminator below.

Back to AutoencoderKL.

Validation step
def validation_step(self, batch, batch_idx):
    inputs = self.get_input(batch, self.image_key)
    reconstructions, posterior = self(inputs)
    aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
                                    last_layer=self.get_last_layer(), split="val")

    discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")

    self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
    self.log_dict(log_dict_ae)
    self.log_dict(log_dict_disc)
    return self.log_dict

Back to AutoencoderKL.


Encoder & Decoder

Location: latent-diffusion/ldm/modules/diffusionmodules/model.py

The Encoder and Decoder code is very simple: both are classic convolutional networks, so we list the code without further commentary.

Back to AutoencoderKL.

Encoder
class Encoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 **ignore_kwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

Back to Encoder & Decoder.

Decoder
class Decoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 attn_type="vanilla", **ignorekwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,)+tuple(ch_mult)
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1,z_channels,curr_res,curr_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up) # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, z):
        #assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h

Back to Encoder & Decoder.

The code above uses the make_attn(in_channels, attn_type="vanilla") function, shown below.

Attention
def make_attn(in_channels, attn_type="vanilla"):
    assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    elif attn_type == "none":
        return nn.Identity(in_channels)
    else:
        return LinAttnBlock(in_channels)
class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)


    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b,c,h,w = q.shape
        q = q.reshape(b,c,h*w)
        q = q.permute(0,2,1)   # b,hw,c
        k = k.reshape(b,c,h*w) # b,c,hw
        w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b,c,h*w)
        w_ = w_.permute(0,2,1)   # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v,w_)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b,c,h,w)

        h_ = self.proj_out(h_)

        return x+h_
class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
        k = k.softmax(dim=-1)  
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        return self.to_out(out)
        
class LinAttnBlock(LinearAttention):
    """to match AttnBlock usage"""
    def __init__(self, in_channels):
        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
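As a side note, AttnBlock above is ordinary self-attention over the h*w spatial positions: 1x1 convolutions produce q, k, v, and the result is added back to the input. Per batch element, with $Q, K, V \in \mathbb R^{hw\times c}$, it computes

$$\mathrm{AttnBlock}(x) = x + W_{\mathrm{proj}}\Big(\mathrm{softmax}\big(QK^\top/\sqrt{c}\big)\,V\Big)$$

while LinAttnBlock is the linear-attention variant of the same idea.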

Back to Encoder & Decoder.


DiagonalGaussianDistribution

Location: latent-diffusion/ldm/modules/distributions/distributions.py

This class represents a diagonal Gaussian distribution.

It exposes four public methods:

  • sample(self). Returns a random sample drawn from the distribution. (see below)
  • kl(self, other=None). Computes the KL divergence to another Gaussian (the standard Gaussian by default). (see below)
  • nll(self, sample, dims=[1, 2, 3]). Computes the negative log-likelihood of a given sample. (see below)
  • mode(self). Returns the mean. (code omitted)
Constructor
def __init__(self, parameters, deterministic=False):
    self.parameters = parameters
    self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
    # split `parameters` into two halves along dim=1: the mean and the log-variance
    self.logvar = torch.clamp(self.logvar, -30.0, 20.0)  # clamp the log-variance to the range (-30.0, 20.0)
    self.deterministic = deterministic
    self.std = torch.exp(0.5 * self.logvar)  # log-variance -> standard deviation
    self.var = torch.exp(self.logvar)  # log-variance -> variance
    if self.deterministic:
        self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

Two parameters are passed in:

  • parameters. A torch.tensor holding the mean and log-variance, concatenated along the channel dimension.
  • deterministic=False. Whether the distribution is deterministic. If True, the standard deviation and variance are set to 0 and the distribution degenerates to its mean.
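A minimal usage sketch (the shapes here are made up for illustration):

import torch
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

# 8 channels = 4 mean channels + 4 log-variance channels
params = torch.randn(2, 8, 32, 32)
posterior = DiagonalGaussianDistribution(params)

z = posterior.sample()   # reparameterized sample, shape [2, 4, 32, 32]
kl = posterior.kl()      # KL to the standard Gaussian, shape [2]
nll = posterior.nll(z)   # negative log-likelihood of z, shape [2]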

Back to DiagonalGaussianDistribution.

Sampling
def sample(self):
    # draw a random sample from this distribution (reparameterization trick)
    x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
    return x

Back to DiagonalGaussianDistribution.

KL divergence

The class's KL method is shown below. The other argument is another Gaussian-distribution object; when it is None, the KL divergence to the standard Gaussian is computed.

def kl(self, other=None):
    if self.deterministic:
        return torch.Tensor([0.])
    else:
        if other is None:
            return 0.5 * torch.sum(torch.pow(self.mean, 2)
                                   + self.var - 1.0 - self.logvar,
                                   dim=[1, 2, 3])
        else:
            return 0.5 * torch.sum(
                torch.pow(self.mean - other.mean, 2) / other.var
                + self.var / other.var - 1.0 - self.logvar + other.logvar,
                dim=[1, 2, 3])
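The closed form summed over dims [1, 2, 3] is the usual KL divergence between two diagonal Gaussians (with the second one being $\mathcal N(0, \mathbf I)$ when other is None):

$$D_{\mathrm{KL}}\big(\mathcal N(\mu_1,\sigma_1^2)\,\|\,\mathcal N(\mu_2,\sigma_2^2)\big)=\frac{1}{2}\left(\frac{(\mu_1-\mu_2)^2}{\sigma_2^2}+\frac{\sigma_1^2}{\sigma_2^2}-1-\log\sigma_1^2+\log\sigma_2^2\right)$$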

Given means and log-variances directly, the KL divergence between two Gaussians can also be computed with the standalone helper below:

def normal_kl(mean1, logvar1, mean2, logvar2):
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )

Back to DiagonalGaussianDistribution.

Negative log-likelihood
def nll(self, sample, dims=[1,2,3]):
    if self.deterministic:
        return torch.Tensor([0.])
    logtwopi = np.log(2.0 * np.pi)
    return 0.5 * torch.sum(
        logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
        dim=dims)
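This is the elementwise Gaussian negative log-likelihood, summed over dims:

$$-\log\mathcal N(x;\mu,\sigma^2)=\frac{1}{2}\left(\log 2\pi+\log\sigma^2+\frac{(x-\mu)^2}{\sigma^2}\right)$$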

Back to DiagonalGaussianDistribution.


LPIPSWithDiscriminator

Location: latent-diffusion/ldm/modules/losses/contperceptual.py

This class computes the VAE's loss, which has four parts: (1) a pixel-level L1 loss between the real and reconstructed image, (2) a feature-level perceptual similarity loss between them, (3) the VAE's KL loss, and (4) the generator/discriminator (GAN) loss.

It has two methods:

  • calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None). Computes an adaptive weight that balances the reconstruction loss against the generator/discriminator loss. (code omitted)

  • forward(self, inputs, reconstructions, posteriors, optimizer_idx, global_step, last_layer=None, cond=None, split="train", weights=None). Forward pass, computes the loss. Selected parameters:

    • inputs. The real input images.
    • reconstructions. The images reconstructed by the VAE.
    • posteriors. The Gaussian (mean/variance) posterior predicted by the VAE's encoder.
    • optimizer_idx. A flag: 0 means optimize the generator, 1 means optimize the discriminator.

    (see below)

Back to AutoencoderKL.

Constructor
 def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
             disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
             perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
             disc_loss="hinge"):

    super().__init__()
    assert disc_loss in ["hinge", "vanilla"]
    self.kl_weight = kl_weight
    self.pixel_weight = pixelloss_weight
    self.perceptual_loss = LPIPS().eval()
    self.perceptual_weight = perceptual_weight
    # output log variance
    self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

    self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                             n_layers=disc_num_layers,
                                             use_actnorm=use_actnorm
                                             ).apply(weights_init)
    self.discriminator_iter_start = disc_start
    self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
    self.disc_factor = disc_factor
    self.discriminator_weight = disc_weight
    self.disc_conditional = disc_conditional

Parameters:

  • disc_start. The global step at which the discriminator loss starts being applied; it affects the weight of the GAN loss.
  • logvar_init. Initial value of the log-variance used to weigh the reconstruction loss against the regularization. (discussed in detail below)
  • kl_weight. Weight of the KL loss (the KL between the VAE's predicted Gaussian and the standard Gaussian, also viewed as the VAE's regularizer).
  • pixelloss_weight. Weight of the pixel loss (the L1 loss between real and generated images). Note this parameter is never actually used in the code.
  • disc_weight. Weight of the generator/discriminator loss (the discriminator tries to tell real images from generated ones; the generator tries to fool the discriminator). Together with disc_start above, it shapes the weight of the GAN loss.
  • perceptual_weight. Weight of the perceptual similarity loss (like the pixel loss it keeps the reconstruction close to the input, but it compares VGG features of the two images layer by layer instead of raw pixels).
  • disc_num_layers. Number of layers in the discriminator.
  • disc_in_channels. Number of input channels of the discriminator.
  • disc_factor. A factor controlling the GAN loss; together with disc_start and disc_weight it determines the final GAN-loss weight.
  • use_actnorm. Whether to use activation normalization (ActNorm) in the discriminator.
  • disc_conditional. Whether the discriminator is conditional.
  • disc_loss. Type of the discriminator loss.

In the constructor, the line self.perceptual_loss = LPIPS().eval() instantiates the LPIPS class, which measures the perceptual similarity of two images. See LPIPS below.

Back to LPIPSWithDiscriminator.

Forward pass
 def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
            global_step, last_layer=None, cond=None, split="train",
            weights=None):
    # rec_loss: L1 distance between the input image and the reconstruction
    rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
    if self.perceptual_weight > 0:
        # p_loss is the LPIPS loss, computed from the similarity of the two images' VGG features at several layers
        p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
        # weight it by self.perceptual_weight to balance the different terms
        # reconstruction loss = L1 distance + w * LPIPS loss
        rec_loss = rec_loss + self.perceptual_weight * p_loss
    # negative log-likelihood form of the reconstruction loss
    nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar  # explained below
    weighted_nll_loss = nll_loss
    if weights is not None:
        weighted_nll_loss = weights*nll_loss
    weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
    nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
    # KL between the posterior and the standard Gaussian
    kl_loss = posteriors.kl()
    kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

    # the losses below train the GAN part
    # optimizer_idx takes two values: 0 updates the generator, 1 updates the discriminator
    if optimizer_idx == 0:
        # update the generator
        if cond is None:  # cond indicates whether the discriminator is conditional
            assert not self.disc_conditional  # unconditional discriminator
            logits_fake = self.discriminator(reconstructions.contiguous())
        else:
            assert self.disc_conditional  # conditional discriminator
            logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
        # logits_fake is the discriminator's output
        # note the input here is `reconstructions`, i.e. fake data: we are training the generator,
        # whose goal is to fool the discriminator by pushing logits_fake up
        g_loss = -torch.mean(logits_fake)  # generator loss

        # below, the generator loss gets an adaptive weight
        # when disc_factor <= 0.0 the GAN part is disabled
        # the GAN part is only used while training the VAE, not while training the diffusion model
        if self.disc_factor > 0.0:
            try:
                d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
            except RuntimeError:
                assert not self.training
                d_weight = torch.tensor(0.0)
        else:
            d_weight = torch.tensor(0.0)

        disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
        # loss = reconstruction loss (weighted_nll_loss) + KL regularizer (kl_loss) + generator loss (g_loss)
        loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss

        log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
               "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
               "{}/rec_loss".format(split): rec_loss.detach().mean(),
               "{}/d_weight".format(split): d_weight.detach(),
               "{}/disc_factor".format(split): torch.tensor(disc_factor),
               "{}/g_loss".format(split): g_loss.detach().mean(),
               }
        return loss, log

    if optimizer_idx == 1:
        # update the discriminator
        if cond is None:  # as above: conditional or not
            logits_real = self.discriminator(inputs.contiguous().detach())
            logits_fake = self.discriminator(reconstructions.contiguous().detach())
        else:
            logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
            logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
        # as above, gate the discriminator loss by disc_factor
        disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
        # self.disc_loss defines how the discriminator is trained
        d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)  # explained below

        log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
               "{}/logits_real".format(split): logits_real.detach().mean(),
               "{}/logits_fake".format(split): logits_fake.detach().mean()
               }
        return d_loss, log
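adopt_weight comes from the taming-transformers code base. Judging from its usage here, it simply keeps the GAN weight at zero until global_step reaches disc_start; a sketch of that behaviour (paraphrased, see taming/modules/losses for the real helper):

def adopt_weight(weight, global_step, threshold=0, value=0.):
    # keep the GAN loss switched off (weight -> value, default 0) before `threshold` steps
    if global_step < threshold:
        weight = value
    return weight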

What exactly is the discriminator loss self.disc_loss?

First, look at how it is assigned in the constructor:

self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss

So there are two possible discriminator losses, both very simple:

def hinge_d_loss(logits_real, logits_fake):
 loss_real = torch.mean(F.relu(1. - logits_real))
 loss_fake = torch.mean(F.relu(1. + logits_fake))
 d_loss = 0.5 * (loss_real + loss_fake)
 return d_loss


def vanilla_d_loss(logits_real, logits_fake):
 d_loss = 0.5 * (
     torch.mean(torch.nn.functional.softplus(-logits_real)) +
     torch.mean(torch.nn.functional.softplus(logits_fake)))
 return d_loss

Why nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar ?

First, look at how self.logvar is declared in the constructor:

self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

It is a single learnable scalar. Turning the reconstruction error rec_loss into nll_loss this way lets the model estimate the uncertainty of the reconstruction: if reconstruction is inherently hard, the model can increase self.logvar to down-weight the reconstruction error, which makes training more robust to noisy data.

A natural question follows: why not simply use

nll_loss = rec_loss / torch.exp(self.logvar)

In other words, why add self.logvar at the end? The reason is that we do not want the model to inflate the uncertainty without limit. Without the extra term, the model could keep increasing self.logvar, claim that reconstruction is always hard, drive nll_loss towards 0, and end up optimizing only the regularizer. That is clearly undesirable, so the log-variance is added back as a penalty, forcing the model to trade the two effects off against each other.
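A one-line calculation makes this trade-off concrete. Treating the reconstruction error $r$ (rec_loss) as fixed and writing $s$ for self.logvar, the per-element objective is

$$f(s)=r\,e^{-s}+s,\qquad f'(s)=-r\,e^{-s}+1=0\;\Rightarrow\;s^*=\log r,$$

so the optimal log-variance settles at the (log of the) typical reconstruction error instead of growing without bound.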

Back to LPIPSWithDiscriminator.


LPIPS

Location: taming/modules/losses/lpips.py

LPIPS stands for Learned Perceptual Image Patch Similarity. The class inherits from torch.nn.Module and is used to compare how perceptually similar two images are.

Its main methods are:

  • load_from_pretrained(self, name="vgg_lpips"). Loads pretrained weights. (code omitted)
  • from_pretrained(cls, name="vgg_lpips"). Class method; also loads pretrained weights. (code omitted)
  • forward(self, input, target). Forward pass: input is the original image and target the generated one; returns their similarity across multiple feature scales. (see below)

Back to LPIPSWithDiscriminator.

Constructor
def __init__(self, use_dropout=True):
    super().__init__()
    self.scaling_layer = ScalingLayer()
    self.chns = [64, 128, 256, 512, 512]  # vgg16 feature channels
    self.net = vgg16(pretrained=True, requires_grad=False)
    self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
    self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
    self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
    self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
    self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
    self.load_from_pretrained()
    for param in self.parameters():
        param.requires_grad = False

Back to LPIPS.

Forward pass
def forward(self, input, target):
    in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
    outs0, outs1 = self.net(in0_input), self.net(in1_input)
    feats0, feats1, diffs = {}, {}, {}
    lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
    for kk in range(len(self.chns)):
        feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
        diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

    res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
    val = res[0]
    for l in range(1, len(self.chns)):
        val += res[l]
    return val
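The two helpers used here are small. Roughly (paraphrased from the original LPIPS code), normalize_tensor rescales each feature vector to unit L2 norm over the channel dimension, and spatial_average takes the mean over the spatial dimensions:

import torch

def normalize_tensor(x, eps=1e-10):
    # unit-normalize along the channel dimension
    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
    return x / (norm_factor + eps)

def spatial_average(x, keepdim=True):
    # mean over height and width
    return x.mean([2, 3], keepdim=keepdim)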

Back to LPIPS.


Stage 2: Training the Latent Diffusion Model

LatentDiffusion

Location: latent-diffusion/ldm/models/diffusion/ddpm.py

The LatentDiffusion class inherits from the classic pixel-space DDPM class, so it is strongly recommended to read the DDPM code first (covered below).

The LatentDiffusion class has the following methods:

  • __init__. Constructor.
  • register_schedule. Registers the noise schedule by calling register_schedule of the DDPM class. (code omitted)
  • make_cond_schedule. Called by register_schedule above; selects at which diffusion timesteps the conditioning is applied. (code omitted)
  • on_train_batch_start. Decorated with rank_zero_only and torch.no_grad(). Triggered only on the very first training batch; it sets a rescaling factor self.scale_factor that standardizes the latent space, which helps training stability and performance. (code omitted)
  • instantiate_first_stage. Instantiates the first-stage model (the autoencoder) and freezes its parameters. (code omitted)
  • instantiate_cond_stage. Instantiates the conditioning encoder (e.g. a CLIP text encoder) and freezes its parameters. (code omitted)
  • _get_denoise_row_from_list. Decodes the provided samples into images and arranges them in a grid for visualization. (code omitted)
  • get_first_stage_encoding. Obtains the latent variable z from the first-stage encoder. (code omitted; a sketch follows after this list)
  • get_learned_conditioning. Obtains the conditioning embedding from the conditioning encoder. (code omitted)
  • meshgrid. Builds a grid-coordinate tensor. Takes the image height h and width w and returns a torch.tensor of shape [h, w, 2] holding each pixel's $y$ and $x$ coordinates. (code omitted)
  • delta_border. Computes, for every pixel, the normalized distance to the image border. Takes h and w and returns a torch.tensor of shape [h, w, 1]. (code omitted)
  • get_weighting. Computes a weight for every image region: large in the center, small at the borders, based on each pixel's normalized distance to the border. (code omitted)
  • get_fold_unfold. Splits the image into patches and re-weights pixels according to the per-region weights. (code omitted)
  • get_input. Decorated with torch.no_grad(). Processes a batch into the final inputs, namely the image x and the conditioning c.
  • decode_first_stage. Decodes the latent representation z. (code omitted)
  • differentiable_decode_first_stage. A differentiable version of decode_first_stage, i.e. gradients can flow through it. (code omitted)
  • encode_first_stage. Decorated with torch.no_grad(). Encodes an image into z. (code omitted)
  • shared_step. Shares the timestep across a batch and runs the latent-diffusion step. (code omitted)
  • forward. Samples a timestep, runs the denoising objective, and returns the loss. (code omitted)
  • _rescale_annotations. Rescales bounding-box coordinates in an image. (code omitted)
  • apply_model. Splits the noisy image into patches, applies the model to each patch, and recombines them into a new image. (code omitted)
  • _predict_eps_from_xstart. Never called. (code omitted)
  • _prior_bpd. Computes the KL divergence between the distribution at the final diffusion step and the standard Gaussian. This KL term depends only on the encoder, cannot be optimized, and measures how well the model covers the input data. Never called. (code omitted)
  • p_losses. Same role as p_losses in DDPM. (code omitted)
  • p_mean_variance. Same role as p_mean_variance in DDPM. (code omitted)
  • p_sample. Same role as p_sample in DDPM. (code omitted)
  • p_sample_loop. Same role as p_sample_loop in DDPM. (code omitted)
  • progressive_denoising. Samples and generates the final image. (code omitted)
  • sample. Same role as sample in DDPM. (code omitted)
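As an example of how the two stages connect, get_first_stage_encoding is essentially the following (paraphrased from the repository; details may differ slightly): it draws z from the frozen autoencoder's posterior and rescales it by self.scale_factor.

def get_first_stage_encoding(self, encoder_posterior):
    # encoder_posterior is what AutoencoderKL.encode returned (see above)
    if isinstance(encoder_posterior, DiagonalGaussianDistribution):
        z = encoder_posterior.sample()
    elif isinstance(encoder_posterior, torch.Tensor):
        z = encoder_posterior
    else:
        raise NotImplementedError(f"unexpected encoder_posterior type {type(encoder_posterior)}")
    return self.scale_factor * z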

DDPM

Back to LatentDiffusion.

Location: latent-diffusion/ldm/models/diffusion/ddpm.py

Methods (arguments omitted):

  • __init__. Constructor. (see below)
  • register_schedule. Computes the DDPM quantities $\beta_t$, $\alpha_t$, etc., along with the parameters of the diffusion distributions. (see below)
  • ema_scope. Decorated with @contextmanager. Temporarily switches to the exponential-moving-average weights during training. (code omitted)
  • init_from_ckpt. Loads the model from a given checkpoint. (code omitted)
  • q_mean_variance. Computes the forward-process distribution $q(x_t|x_0)$ and returns its mean, variance, and log-variance. (see below)
  • predict_start_from_noise. Given the noisy image $x_t$, the timestep $t$, and the predicted noise $\hat\epsilon$, computes the predicted denoised image $\hat x_0$. (see below)
  • q_posterior. Computes the posterior $q(x_{t-1}|x_t, x_0)$ and returns its mean, variance, and log-variance. (see below)
  • p_mean_variance. Computes the reverse-process distribution $p(x_{t-1}|x_t)$ and returns its mean, variance, and log-variance. (see below)
  • p_sample. Reverse-process sampling: given $x_t$, samples $x_{t-1}$ from $p(x_{t-1}|x_t)$. (no grad) (see below)
  • p_sample_loop. Reverse-process sampling: repeatedly applies $p(x_{t-1}|x_t)$ to go from $x_T$ to $x_0$. (no grad) (see below)
  • sample. Reverse-process sampling: goes from $x_T$ to $x_0$ via $p(x_{t-1}|x_t)$. (no grad) (see below)
  • q_sample. Forward-process sampling: samples $x_t$ from $q(x_t|x_0)$. (see below)
  • get_loss. Computes the loss between the noise predicted by the UNet and the true noise. (see below)
  • p_losses. Computes the loss between predicted and true noise, including the variational lower bound (VLB) term. (see below)
  • forward. Forward pass: takes the original image $x_0$ and returns the diffusion loss. (see below)
  • get_input. Reshapes the input images into the expected format. (code omitted)
  • shared_step. Reads a batch of images and runs the forward pass to get the loss, sharing the timestep across the batch. (code omitted)
  • training_step. Training. (see below)
  • validation_step. Validation. (see below)
  • on_train_batch_end. Does nothing of interest and is never invoked explicitly in the code; it can be ignored.
  • _get_rows_from_list. Reshapes some samples for logging. (code omitted)
  • log_images. Logs generated images. (no grad) (code omitted)
  • configure_optimizers. Configures the optimizer (torch.optim.AdamW). (code omitted)
Constructor
def __init__(self,
             unet_config,
             timesteps=1000,
             beta_schedule="linear",
             loss_type="l2",
             ckpt_path=None,
             ignore_keys=[],
             load_only_unet=False,
             monitor="val/loss",
             use_ema=True,
             first_stage_key="image",
             image_size=256,
             channels=3,
             log_every_t=100,
             clip_denoised=True,
             linear_start=1e-4,
             linear_end=2e-2,
             cosine_s=8e-3,
             given_betas=None,
             original_elbo_weight=0.,
             v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
             l_simple_weight=1.,
             conditioning_key=None,
             parameterization="eps",  # all assuming fixed variance schedules
             scheduler_config=None,
             use_positional_encodings=False,
             learn_logvar=False,
             logvar_init=0.,
             ):
    super().__init__()
    assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
    self.parameterization = parameterization
    print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
    self.cond_stage_model = None
    self.clip_denoised = clip_denoised
    self.log_every_t = log_every_t
    self.first_stage_key = first_stage_key
    self.image_size = image_size  # try conv?
    self.channels = channels
    self.use_positional_encodings = use_positional_encodings
    self.model = DiffusionWrapper(unet_config, conditioning_key)
    count_params(self.model, verbose=True)
    self.use_ema = use_ema
    if self.use_ema:
        self.model_ema = LitEma(self.model)
        print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

    self.use_scheduler = scheduler_config is not None
    if self.use_scheduler:
        self.scheduler_config = scheduler_config

    self.v_posterior = v_posterior
    self.original_elbo_weight = original_elbo_weight
    self.l_simple_weight = l_simple_weight

    if monitor is not None:
        self.monitor = monitor
    if ckpt_path is not None:
        self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)

    self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
                           linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)

    self.loss_type = loss_type

    self.learn_logvar = learn_logvar
    self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
    if self.learn_logvar:
        self.logvar = nn.Parameter(self.logvar, requires_grad=True)

Parameters:

  • unet_config. Dict with the UNet configuration.
  • timesteps. Total number of diffusion timesteps. (default 1000)
  • beta_schedule. Schedule for the noise level $\beta_t$. (default 'linear', i.e. $\beta_t$ grows linearly)
  • loss_type. Type of loss used for the noise-prediction error. (default 'l2', mean squared error)
  • ckpt_path. Path of a checkpoint to load. (default None, nothing is loaded)
  • ignore_keys. Keys to ignore when loading a checkpoint. (default [], nothing is ignored)
  • load_only_unet. Whether to load only the UNet weights. (default False)
  • monitor. Metric used to monitor model quality during training. (default 'val/loss', the validation loss)
  • use_ema. Whether to smooth the model parameters with an exponential moving average (EMA). (default True)
  • first_stage_key. Key used by the first-stage model. (default 'image')
  • image_size. Image size. (default 256, i.e. 256x256 images)
  • channels. Number of image channels. (default 3)
  • log_every_t. During sampling, log an image every this many timesteps $t$. (default 100)
  • clip_denoised. Whether to clip the denoised values to the range $(-1.0, 1.0)$. (default True)
  • linear_start. Value of $\beta_0$. (default 1e-4)
  • linear_end. Value of $\beta_T$. (default 2e-2)
  • cosine_s. Only used with the cosine schedule for $\beta_t$; controls the cosine growth. (default 8e-3)
  • given_betas. A pre-specified sequence $[\beta_t]_{t=0}^T$. (default None)
  • original_elbo_weight. Weight of the original evidence lower bound (ELBO) term in the loss. (default 0.)
  • v_posterior. Weight $v$ used to choose the posterior variance: $\sigma_t=(1-v)\tilde\beta_t+v\beta_t$. (default 0.)
  • l_simple_weight. Weight of the simple loss. (default 1.)
  • conditioning_key. Key of the conditioning data when doing conditional generation. (default None)
  • parameterization. How the model is parameterized, i.e. whether the UNet predicts the original image or the noise. (default 'eps')
  • scheduler_config. Dict with the scheduler configuration. (default None)
  • use_positional_encodings. Whether to use positional encodings. (default False)
  • learn_logvar. Whether the log-variance is a learned parameter. (default False)
  • logvar_init. Initial value of the log-variance. (default 0.)

One line of the constructor deserves attention:

self.model = DiffusionWrapper(unet_config, conditioning_key)

This is the UNet. The DiffusionWrapper class handles both conditional and unconditional diffusion; its code is:

class DiffusionWrapper(pl.LightningModule):
 def __init__(self, diff_model_config, conditioning_key):
     super().__init__()
     self.diffusion_model = instantiate_from_config(diff_model_config)
     self.conditioning_key = conditioning_key
     assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']

 def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
     if self.conditioning_key is None:
         out = self.diffusion_model(x, t)
     elif self.conditioning_key == 'concat':
         xc = torch.cat([x] + c_concat, dim=1)
         out = self.diffusion_model(xc, t)
     elif self.conditioning_key == 'crossattn':
         cc = torch.cat(c_crossattn, 1)
         out = self.diffusion_model(x, t, context=cc)
     elif self.conditioning_key == 'hybrid':
         xc = torch.cat([x] + c_concat, dim=1)
         cc = torch.cat(c_crossattn, 1)
         out = self.diffusion_model(xc, t, context=cc)
     elif self.conditioning_key == 'adm':
         cc = c_crossattn[0]
         out = self.diffusion_model(x, t, y=cc)
     else:
         raise NotImplementedError()

     return out

Naturally, our attention returns to the instantiate_from_config function, so we look at the config file again. In latent-diffusion/configs/latent-diffusion/celebahq-ldm-vq-4.yaml we find:

unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel

So next we look at the UNetModel class. (see below)

The constructor also contains the following lines:

if self.use_ema:
 self.model_ema = LitEma(self.model)

This uses an exponential moving average (EMA) to smooth the model parameters. (code omitted)

Back to DDPM.

Registering the $\beta$ and $\alpha$ schedules
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                      linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    if exists(given_betas):
        betas = given_betas  # betas were given directly
    else:
        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                   cosine_s=cosine_s)  # compute the betas
    alphas = 1. - betas  # by definition, \alpha_t = 1 - \beta_t
    alphas_cumprod = np.cumprod(alphas, axis=0)  # \bar\alpha_t = \prod_{i<=t} \alpha_i
    # cumprod returns an array whose element i is the product of all elements up to and including i
    alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])  # drop the last element and prepend a 1

    timesteps, = betas.shape  # total number of timesteps
    self.num_timesteps = int(timesteps)
    self.linear_start = linear_start
    self.linear_end = linear_end
    assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

    to_torch = partial(torch.tensor, dtype=torch.float32)
    # to_torch(x) is equivalent to torch.tensor(x, dtype=torch.float32)

    # everything below precomputes quantities that the other methods look up directly
    # register_buffer is a torch.nn.Module method that adds a tensor to the module's buffers;
    # buffers are not treated as model parameters and receive no gradient updates
    # the DDPM \beta_t array:
    self.register_buffer('betas', to_torch(betas))
    # the DDPM \bar\alpha_t array:
    self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
    # kept around so formulas involving the previous timestep are easy to write:
    self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
    # alphas_cumprod_prev[t] is the fraction of the original signal remaining after step t-1

    # calculations for diffusion q(x_t | x_{t-1}) and others
    # \sqrt{\bar\alpha_t}:
    self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
    # \sqrt{1-\bar\alpha_t}:
    self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
    # \log(1-\bar\alpha_t):
    self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
    # \sqrt{1/\bar\alpha_t}:
    self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
    # \sqrt{1/\bar\alpha_t - 1}:
    self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

    # calculations for posterior q(x_{t-1} | x_t, x_0)
    posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
                1. - alphas_cumprod) + self.v_posterior * betas
    # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
    self.register_buffer('posterior_variance', to_torch(posterior_variance))
    # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
    self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
    self.register_buffer('posterior_mean_coef1', to_torch(
        betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
    self.register_buffer('posterior_mean_coef2', to_torch(
        (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

    if self.parameterization == "eps":
        lvlb_weights = self.betas ** 2 / (
                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
    elif self.parameterization == "x0":
        lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
    else:
        raise NotImplementedError("mu not supported")
    # TODO how to choose this term
    lvlb_weights[0] = lvlb_weights[1]
    self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
    assert not torch.isnan(self.lvlb_weights).all()
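make_beta_schedule lives in ldm/modules/diffusionmodules/util.py. For the default 'linear' schedule it is roughly the following (paraphrased; the cosine and sqrt variants are omitted): the betas are linearly spaced in square-root space between linear_start and linear_end and then squared.

import torch

def make_beta_schedule_linear(n_timestep, linear_start=1e-4, linear_end=2e-2):
    # linspace in sqrt-space, then square: betas grow from linear_start to linear_end
    betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5,
                           n_timestep, dtype=torch.float64) ** 2
    return betas.numpy()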

Quick links: DDPM | forward process (adding noise) | predicting the original image | posterior of the forward process (adding noise) | loss of the reverse process (denoising).

Forward process (adding noise)
def q_mean_variance(self, x_start, t):
    # compute q(x_t | x_0)
    mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
    variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
    log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
    return mean, variance, log_variance

This function computes the following DDPM formula:

$$q(x_t|x_0)=\mathcal N\big(x_t;\sqrt{\bar\alpha_t}\,x_0,\ (1-\bar\alpha_t)\mathbf I\big)$$

Inputs:

  • x_start. The original image $x_0$.
  • t. The timestep $t$.

Returns:

  • mean. The mean of $q(x_t|x_0)$.
  • variance. The variance of $q(x_t|x_0)$.
  • log_variance. The log-variance of $q(x_t|x_0)$.

Here (all defined in the register_schedule method above):

  • self.sqrt_alphas_cumprod holds the $\sqrt{\bar\alpha_t}$ array.
  • self.alphas_cumprod holds the $\bar\alpha_t$ array.
  • self.log_one_minus_alphas_cumprod holds the $\log(1-\bar\alpha_t)$ array.

extract_into_tensor(a, t, x_shape) picks the t-th element of array a and reshapes it into a torch.tensor broadcastable with x_shape.
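For completeness, extract_into_tensor is a tiny helper (roughly as defined in ldm/modules/diffusionmodules/util.py):

import torch

def extract_into_tensor(a, t, x_shape):
    # a: 1-D tensor indexed by timestep; t: [batch] tensor of timesteps
    b, *_ = t.shape
    out = a.gather(-1, t)
    # reshape to [batch, 1, 1, ...] so it broadcasts against a tensor of shape x_shape
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))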

Back to DDPM.

Predicting the original image
def predict_start_from_noise(self, x_t, t, noise):
    return (
            extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
    )

This function computes the following DDPM formula:

$$\hat x_0=\sqrt{\tfrac{1}{\bar\alpha_t}}\,x_t-\sqrt{\tfrac{1}{\bar\alpha_t}-1}\cdot\hat\epsilon$$

Inputs:

  • x_t. The noisy image $x_t$.
  • t. The timestep $t$.
  • noise. The predicted noise $\hat\epsilon$.

Output:

  • The predicted original image $\hat x_0$.

Here (defined in the register_schedule method above):

  • self.sqrt_recip_alphas_cumprod holds the $\sqrt{1/\bar\alpha_t}$ array.
  • self.sqrt_recipm1_alphas_cumprod holds the $\sqrt{1/\bar\alpha_t-1}$ array.

extract_into_tensor(a, t, x_shape) picks the t-th element of array a and reshapes it into a torch.tensor broadcastable with x_shape.

Back to DDPM.

Posterior of the forward process (adding noise)
def q_posterior(self, x_start, x_t, t):
    posterior_mean = (
            extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
    )
    posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
    posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
    return posterior_mean, posterior_variance, posterior_log_variance_clipped

This function computes the following DDPM formula:

$$q(x_{t-1}|x_t, x_0)=\mathcal N\big(x_{t-1};\tilde\mu_t(x_t,x_0),\ \tilde\beta_t\mathbf I\big)$$

where

$$\tilde\mu_t(x_t, x_0):=\frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}x_0+\frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}x_t$$

$$\tilde\beta_t:=\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_t$$

Inputs:

  • x_start. The original image $x_0$.
  • x_t. The noisy image $x_t$.
  • t. The timestep $t$.

Outputs:

  • posterior_mean. The mean of the posterior $q(x_{t-1}|x_t, x_0)$.
  • posterior_variance. Its variance.
  • posterior_log_variance_clipped. Its (clipped) log-variance.

Here (defined in the register_schedule method above):

  • self.posterior_mean_coef1 holds the $\frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}$ array.
  • self.posterior_mean_coef2 holds the $\frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}$ array.
  • self.posterior_variance holds the $\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_t$ array.
  • self.posterior_log_variance_clipped holds the $\log\!\left(\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_t\right)$ array.

extract_into_tensor(a, t, x_shape) picks the t-th element of array a and reshapes it into a torch.tensor broadcastable with x_shape.

Quick links: back to DDPM | reverse process (denoising)

Reverse process (denoising)
def p_mean_variance(self, x, t, clip_denoised: bool):
    model_out = self.model(x, t)
    if self.parameterization == "eps":
        # the model predicts the noise
        x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
    elif self.parameterization == "x0":
        # the model predicts the denoised image directly
        x_recon = model_out
    if clip_denoised:
        # optionally clip the image values to the range (-1.0, 1.0)
        x_recon.clamp_(-1., 1.)
    # compute the posterior q(x_{t-1} | x_t, \hat x_0) and use it as an estimate of p_{\theta}(x_{t-1}|x_t)
    model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
    return model_mean, posterior_variance, posterior_log_variance

This function computes the following DDPM formula:

$$p_{\theta}(x_{t-1}|x_t):=\mathcal N\big(x_{t-1};\mu_{\theta}(x_t, t),\ \Sigma_{\theta}(x_t, t)\big)$$

Inputs:

  • x. The noisy image at the current timestep, $x_t$.
  • t. The timestep $t$.
  • clip_denoised. Whether to clip image values to the range $(-1.0, 1.0)$.

Outputs:

  • model_mean. The predicted mean $\mu_{\theta}(x_t, t)$.
  • posterior_variance. The predicted variance $\Sigma_{\theta}(x_t, t)$.
  • posterior_log_variance. The predicted log-variance $\log\Sigma_{\theta}(x_t, t)$.

Here self.q_posterior approximates the reverse-process distribution with the posterior of the forward process; its definition is given above.

Quick link: back to DDPM

Sampling images
@torch.no_grad()
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
    b, *_, device = *x.shape, x.device
    model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
    noise = noise_like(x.shape, device, repeat_noise)  # standard Gaussian noise with the same shape as x
    # repeat_noise: if True, every sample in the batch shares one noise tensor; otherwise each sample gets its own
    # nonzero_mask: 0 at t=0 (no noise is added at the final step), 1 otherwise
    nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
    return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

p_sample draws one sample $x_{t-1}$ from the distribution $p_{\theta}(x_{t-1}|x_t)$.

Inputs:

  • x. The sample at the current timestep, $x_t$.
  • t. The current timestep $t$.
  • clip_denoised. Whether to clip the denoised values to the range $(-1.0, 1.0)$.
  • repeat_noise. Whether all samples in a batch reuse the same noise.

Output:

  • A batch of samples $x_{t-1}$ for the next (earlier) timestep, with the same shape as x.
@torch.no_grad()
def p_sample_loop(self, shape, return_intermediates=False):
    device = self.betas.device
    b = shape[0]
    img = torch.randn(shape, device=device)
    intermediates = [img]
    for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
        # i runs from T-1 down to 0
        # t = torch.full((b,), i, device=device, dtype=torch.long)
        img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
                            clip_denoised=self.clip_denoised)
        # img is x_i
        if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
            intermediates.append(img)  # store the intermediate images
    if return_intermediates:
        return img, intermediates
    return img  # x_0

Inputs:

  • shape. Shape of the images to generate.
  • return_intermediates. Whether to also return the intermediate images of the reverse process.

Outputs:

  • img. The generated image $x_0$.
  • intermediates. A list holding the intermediate images.
@torch.no_grad()
def sample(self, batch_size=16, return_intermediates=False):
    image_size = self.image_size
    channels = self.channels
    return self.p_sample_loop((batch_size, channels, image_size, image_size),
                              return_intermediates=return_intermediates)

Given batch_size, this directly generates a batch of samples $x_0$.

Back to DDPM.

Simulating the diffusion (forward) process
def q_sample(self, x_start, t, noise=None):
    noise = default(noise, lambda: torch.randn_like(x_start))
    # if noise is None, draw a standard Gaussian sample with the same shape as x_start; otherwise use it as-is
    return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

This function samples $x_t$ from the DDPM forward process

$$q(x_t|x_0)=\mathcal N\big(x_t;\sqrt{\bar\alpha_t}\,x_0,\ (1-\bar\alpha_t)\mathbf I\big)$$

i.e. it draws $x_t$ from $q(x_t|x_0)$ via the reparameterization trick.

Inputs:

  • x_start. The original image $x_0$.
  • t. The timestep $t$.
  • noise. The noise; if None, a standard Gaussian sample with the same shape as x_start is used.

Output:

  • A sample $x_t$ with the same shape as x_start.

Back to DDPM.

Noise-prediction loss
def get_loss(self, pred, target, mean=True):
    if self.loss_type == 'l1':
        loss = (target - pred).abs()
        if mean:
            loss = loss.mean()
    elif self.loss_type == 'l2':
        if mean:
            loss = torch.nn.functional.mse_loss(target, pred)
        else:
            loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
    else:
        raise NotImplementedError("unknown loss type '{loss_type}'")

    return loss

This function computes the noise-prediction loss. The code is self-explanatory.

Inputs:

  • pred. The predicted noise.
  • target. The true noise.
  • mean. Whether to average the loss into a scalar.

Output:

  • loss. The loss.

Back to DDPM.

Loss of the reverse process (denoising)
def p_losses(self, x_start, t, noise=None):
    # noise is the noise added at timestep t
    noise = default(noise, lambda: torch.randn_like(x_start))
    # x_noisy is obtained by running the forward (noising) process from x_0
    x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
    # model_out is the UNet's prediction of the noise added at timestep t
    model_out = self.model(x_noisy, t)

    loss_dict = {}
    if self.parameterization == "eps":
        target = noise
    elif self.parameterization == "x0":
        target = x_start
    else:
        raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")

    # loss of the predicted noise (or predicted image)
    loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
    log_prefix = 'train' if self.training else 'val'
    loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
    # simple loss = loss * weight
    loss_simple = loss.mean() * self.l_simple_weight
    # variational lower bound (VLB) loss = per-timestep weight * loss
    loss_vlb = (self.lvlb_weights[t] * loss).mean()
    loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
    # total loss = simple loss + original ELBO weight * VLB loss
    loss = loss_simple + self.original_elbo_weight * loss_vlb
    loss_dict.update({f'{log_prefix}/loss': loss})
    return loss, loss_dict

This function takes $x_0$ and $t$ as input, first runs the forward process $q(x_t|x_0)$ to obtain a noisy sample $x_t$, then lets the UNet predict the noise added at timestep $t$ and computes the prediction loss. The computation follows the DDPM objective

$$L=\mathbb E_{x_0\sim q(x_0),\,\epsilon\sim\mathcal N(0, \mathbf I)}\left[\frac{\beta_t^2}{2\tilde\beta_t\alpha_t(1-\bar\alpha_t)}\left\|\epsilon-\epsilon_{\theta}\big(\sqrt{\bar\alpha_t}\,x_0+\sqrt{1-\bar\alpha_t}\,\epsilon\big)\right\|^2\right]$$

Inputs:

  • x_start. The original image $x_0$.
  • t. The timestep $t$.
  • noise. The noise sample.

Outputs:

  • loss. The total loss.
  • loss_dict. A dict of losses (used for logging).

Here self.lvlb_weights holds the $\frac{\beta_t^2}{2\tilde\beta_t\alpha_t(1-\bar\alpha_t)}$ array, defined in the register_schedule method above.

Back to DDPM.

Forward pass
def forward(self, x, *args, **kwargs):
    # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
    # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
    # draw a random timestep uniformly from [0, T)
    t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
    # return the noise-prediction loss
    return self.p_losses(x, t, *args, **kwargs)

Input: the original image x.

Output: the loss and the loss dict.

Back to DDPM.

Training
def training_step(self, batch, batch_idx):
    loss, loss_dict = self.shared_step(batch)

    self.log_dict(loss_dict, prog_bar=True,
                  logger=True, on_step=True, on_epoch=True)

    self.log("global_step", self.global_step,
             prog_bar=True, logger=True, on_step=True, on_epoch=False)

    if self.use_scheduler:
        lr = self.optimizers().param_groups[0]['lr']
        self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)

    return loss

The training code is straightforward and needs no further explanation.

Back to DDPM.

Validation
@torch.no_grad()
def validation_step(self, batch, batch_idx):
    _, loss_dict_no_ema = self.shared_step(batch)
    with self.ema_scope():
        _, loss_dict_ema = self.shared_step(batch)
        loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
    self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
    self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)

The validation code is straightforward and needs no further explanation.

Back to DDPM.


UNetModel

Location: latent-diffusion/ldm/modules/diffusionmodules/openaimodel.py

This class implements the UNet. It has essentially only two methods: __init__ and forward. The whole architecture is built in the constructor, so for simplicity we look at the forward pass first.

Back to DDPM.

Forward pass
def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
    """
    Apply the model to an input batch.
    :param x: an [N x C x ...] Tensor of inputs.
    :param timesteps: a 1-D batch of timesteps.
    :param context: conditioning plugged in via crossattn
    :param y: an [N] Tensor of labels, if class-conditional.
    :return: an [N x C x ...] Tensor of outputs.
    """
    assert (y is not None) == (
        self.num_classes is not None
    ), "must specify y if and only if the model is class-conditional"
    hs = []  # stores each level's feature maps for the UNet skip connections
    # compute the timestep embedding
    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    emb = self.time_embed(t_emb)

    if self.num_classes is not None:
        assert y.shape == (x.shape[0],)
        # 计算类别embedding
        emb = emb + self.label_emb(y)

    h = x.type(self.dtype)
    for module in self.input_blocks:  # UNet的下采样过程
        h = module(h, emb, context)
        hs.append(h)
    h = self.middle_block(h, emb, context)
    for module in self.output_blocks:  # UNet的上采样过程
        h = th.cat([h, hs.pop()], dim=1)
        h = module(h, emb, context)
    h = h.type(x.dtype)
    if self.predict_codebook_ids:
        return self.id_predictor(h)
    else:
        return self.out(h)

This code is quite straightforward; let us go through it briefly:

Inputs:

  • x. The image input to the UNet, $x_t$.
  • timesteps. The timestep $t$.
  • context. The conditioning used for cross-attention.
  • y. The class condition of the image, i.e. its label.

Outputs:

  • The UNet output. In diffusion this can be a prediction of the original image or a prediction of the noise.

A few important pieces in this code:

  • The timestep_embedding function. It turns the given timesteps into a time embedding using a sinusoidal encoding. (click to jump)

  • self.time_embed. Maps the sinusoidal time embedding through a small two-layer MLP (Linear → SiLU → Linear) to the final time embedding, letting the model learn its own embedding. The code is as follows:

    self.time_embed = nn.Sequential(
        nn.Linear(model_channels, time_embed_dim),
        nn.SiLU(),
        nn.Linear(time_embed_dim, time_embed_dim),
    )
    
  • self.label_emb. Maps the class label y (an integer class index; nn.Embedding expects indices, not one-hot vectors) to a label embedding. The code is as follows:

    self.label_emb = nn.Embedding(num_classes, time_embed_dim)
    

The following three parts (downsampling, middle block, and upsampling) all use the TimestepEmbedSequential class, click to jump.

Downsampling

The following snippet is a fragment of the constructor.

self._feature_size = model_channels
input_block_chans = [model_channels]  # 存储下采样每一层的通道数
ch = model_channels
ds = 1
# channel_mult表示了每个下采样层的通道倍数
for level, mult in enumerate(channel_mult):
    # 对每个下采样层, 有num_res_blocks个ResBlock
    for _ in range(num_res_blocks):
        layers = [
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                out_channels=mult * model_channels,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            )
        ]
        ch = mult * model_channels
        # 这一分辨率是否需要attention
        if ds in attention_resolutions:
            if num_head_channels == -1:
                dim_head = ch // num_heads
            else:
                num_heads = ch // num_head_channels
                dim_head = num_head_channels
            if legacy:
                #num_heads = 1
                dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
            layers.append(
                AttentionBlock(
                    ch,
                    use_checkpoint=use_checkpoint,
                    num_heads=num_heads,
                    num_head_channels=dim_head,
                    use_new_attention_order=use_new_attention_order,
                ) if not use_spatial_transformer else SpatialTransformer(
                    ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                )
            )
        self.input_blocks.append(TimestepEmbedSequential(*layers))
        self._feature_size += ch
        input_block_chans.append(ch)
    # 是否下采样的最后一个级别
    if level != len(channel_mult) - 1:
        # 不是, 因此要做下采样
        out_ch = ch
        self.input_blocks.append(
            TimestepEmbedSequential(
                ResBlock(
                    ch,
                    time_embed_dim,
                    dropout,
                    out_channels=out_ch,
                    dims=dims,
                    use_checkpoint=use_checkpoint,
                    use_scale_shift_norm=use_scale_shift_norm,
                    down=True,
                )
                if resblock_updown  # resblock_updown表示是否使用ResBlock做上采样/下采样
                else Downsample(
                    ch, conv_resample, dims=dims, out_channels=out_ch
                )
            )
        )
        ch = out_ch
        input_block_chans.append(ch)
        ds *= 2  # 更新分辨率
        self._feature_size += ch
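
To make this bookkeeping concrete, here is a small stand-alone walk-through with assumed hyperparameters (model_channels=64, channel_mult=(1, 2, 4), num_res_blocks=2, attention_resolutions={2, 4}; these values are illustrative, not the repo's defaults):

# Hypothetical trace of the channel/resolution bookkeeping performed by the loop above.
model_channels = 64
channel_mult = (1, 2, 4)
num_res_blocks = 2
attention_resolutions = {2, 4}

ch, ds = model_channels, 1
input_block_chans = [ch]
for level, mult in enumerate(channel_mult):
    for _ in range(num_res_blocks):
        ch = mult * model_channels              # the ResBlock changes the channel count
        has_attn = ds in attention_resolutions  # attention only at selected resolutions
        input_block_chans.append(ch)
        print(f"level={level}: ResBlock -> ch={ch}, ds={ds}, attention={has_attn}")
    if level != len(channel_mult) - 1:          # downsample between levels
        input_block_chans.append(ch)
        ds *= 2
        print(f"level={level}: Downsample -> ch={ch}, ds={ds}")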

The classes involved here (ResBlock, AttentionBlock, SpatialTransformer, Downsample and TimestepEmbedSequential) are covered in the sections below.

Click to return to UNetModel.

Middle block

# 中间层: ResBlock -> AttentionBlock -> ResBlock
if num_head_channels == -1:
    dim_head = ch // num_heads
else:
    num_heads = ch // num_head_channels
    dim_head = num_head_channels
if legacy:
    #num_heads = 1
    dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
    ResBlock(
        ch,
        time_embed_dim,
        dropout,
        dims=dims,
        use_checkpoint=use_checkpoint,
        use_scale_shift_norm=use_scale_shift_norm,
    ),
    AttentionBlock(
        ch,
        use_checkpoint=use_checkpoint,
        num_heads=num_heads,
        num_head_channels=dim_head,
        use_new_attention_order=use_new_attention_order,
    ) if not use_spatial_transformer else SpatialTransformer(
                    ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                ),
    ResBlock(
        ch,
        time_embed_dim,
        dropout,
        dims=dims,
        use_checkpoint=use_checkpoint,
        use_scale_shift_norm=use_scale_shift_norm,
    ),
)
self._feature_size += ch

The classes involved here (ResBlock, AttentionBlock, SpatialTransformer, TimestepEmbedSequential) are the same as in the downsampling path.

Click to return to UNetModel.

Upsampling

The upsampling code closely mirrors the downsampling code (it is essentially the reverse process), so it is not explained in detail. Note that each ResBlock receives ch + ich input channels: ich is popped from input_block_chans and corresponds to the skip connection concatenated in the forward pass. The code is as follows:

self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
    # 将通道倒过来
    for i in range(num_res_blocks + 1):
        ich = input_block_chans.pop()
        layers = [
            ResBlock(
                ch + ich,
                time_embed_dim,
                dropout,
                out_channels=model_channels * mult,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            )
        ]
        ch = model_channels * mult
        if ds in attention_resolutions:
            if num_head_channels == -1:
                dim_head = ch // num_heads
            else:
                num_heads = ch // num_head_channels
                dim_head = num_head_channels
            if legacy:
                #num_heads = 1
                dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
            layers.append(
                AttentionBlock(
                    ch,
                    use_checkpoint=use_checkpoint,
                    num_heads=num_heads_upsample,
                    num_head_channels=dim_head,
                    use_new_attention_order=use_new_attention_order,
                ) if not use_spatial_transformer else SpatialTransformer(
                    ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                )
            )
        if level and i == num_res_blocks:
            out_ch = ch
            layers.append(
                ResBlock(
                    ch,
                    time_embed_dim,
                    dropout,
                    out_channels=out_ch,
                    dims=dims,
                    use_checkpoint=use_checkpoint,
                    use_scale_shift_norm=use_scale_shift_norm,
                    up=True,
                )
                if resblock_updown
                else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
            )
            ds //= 2
        self.output_blocks.append(TimestepEmbedSequential(*layers))
        self._feature_size += ch

self.out = nn.Sequential(
    normalization(ch),
    nn.SiLU(),
    zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
)
if self.predict_codebook_ids:
    self.id_predictor = nn.Sequential(
        normalization(ch),
        conv_nd(dims, model_channels, n_embed, 1),
        #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
    )

Click to return to UNetModel.


timestep_embedding function

Location: latent-diffusion/ldm/modules/diffusionmodules/util.py

def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    return embedding
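
For example (illustrative values), assuming the function is imported from the module above:

import torch
from ldm.modules.diffusionmodules.util import timestep_embedding

# Embed a batch of 4 timesteps into 128-dimensional sinusoidal vectors.
ts = torch.tensor([0, 10, 100, 999])
emb = timestep_embedding(ts, dim=128)
print(emb.shape)   # torch.Size([4, 128])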

Click to return to UNetModel.


TimestepEmbedSequential

Location: latent-diffusion/ldm/modules/diffusionmodules/openaimodel.py

TimestepEmbedSequential inherits from torch.nn.Sequential (and from TimestepBlock). It makes it convenient to pass the timestep embedding and the conditioning into the appropriate submodules. The code is as follows:

class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb, context=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context)
            else:
                x = layer(x)
        return x

The TimestepBlock class here is a very simple abstract class: click to jump.

The SpatialTransformer class performs cross-attention between the conditioning and the image features: click to jump.
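
To see the dispatch rule in action, here is a small toy example. ToyTimestepLayer is a hypothetical module defined only for this illustration; TimestepBlock and TimestepEmbedSequential are assumed to be importable from the repo module above.

import torch
import torch.nn as nn
from ldm.modules.diffusionmodules.openaimodel import TimestepBlock, TimestepEmbedSequential

class ToyTimestepLayer(TimestepBlock):
    """A made-up layer that consumes the timestep embedding as a second input."""
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, emb):
        return self.proj(x) + emb

seq = TimestepEmbedSequential(ToyTimestepLayer(8), nn.ReLU())
x, emb = torch.randn(2, 8), torch.randn(2, 8)
out = seq(x, emb)   # ToyTimestepLayer receives (x, emb); nn.ReLU only receives x
print(out.shape)    # torch.Size([2, 8])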

Quick links back: UNetModel | UNetModel downsampling | UNetModel middle block.


TimestepBlock

Location: latent-diffusion/ldm/modules/diffusionmodules/openaimodel.py

class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """

Click to return to TimestepEmbedSequential.


SpatialTransformer

Location: latent-diffusion/ldm/modules/attention.py

class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv2d(in_channels,
                                 inner_dim,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
                for d in range(depth)]
        )

        self.proj_out = zero_module(nn.Conv2d(inner_dim,
                                              in_channels,
                                              kernel_size=1,
                                              stride=1,
                                              padding=0))

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c')
        for block in self.transformer_blocks:
            x = block(x, context=context)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        x = self.proj_out(x)
        return x + x_in
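
A quick shape check (sizes are illustrative; the class is assumed to be importable from ldm/modules/attention.py): the spatial layout is preserved while attention runs over the flattened h*w tokens.

import torch
from ldm.modules.attention import SpatialTransformer

st = SpatialTransformer(in_channels=64, n_heads=4, d_head=16, context_dim=768)
x = torch.randn(2, 64, 32, 32)        # image-like feature map
context = torch.randn(2, 77, 768)     # e.g. a sequence of conditioning tokens
out = st(x, context=context)
print(out.shape)                      # torch.Size([2, 64, 32, 32]) -- same shape as the input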

Here, BasicTransformerBlock is a classic Transformer block: self-attention, cross-attention against the context, and a feed-forward layer, each with LayerNorm and a residual connection. The code is:

class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
        super().__init__()
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                                    heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x

Here, the code of the CrossAttention class is:

class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)
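
A shape check for CrossAttention (sizes are illustrative): the queries come from the image tokens, the keys and values from the context, and the output keeps the query shape.

import torch
from ldm.modules.attention import CrossAttention

attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
x = torch.randn(2, 4096, 320)       # a 64x64 feature map flattened to 4096 tokens
context = torch.randn(2, 77, 768)   # conditioning tokens
print(attn(x, context=context).shape)   # torch.Size([2, 4096, 320])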

The code of FeedForward is:

class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU()
        ) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)

Click to return to TimestepEmbedSequential.


ResBlock

Location: latent-diffusion/ldm/modules/diffusionmodules/openaimodel.py

This class implements a basic residual block conditioned on the timestep embedding. The code is fairly simple: when use_scale_shift_norm is enabled, the embedding is split into a scale and a shift that modulate the normalized feature map; otherwise the embedding is simply added to the features:

class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )


    def _forward(self, x, emb):
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
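
A small usage sketch (sizes are illustrative; the class is assumed to be importable from the module above): the block changes the channel count while being conditioned on the timestep embedding.

import torch
from ldm.modules.diffusionmodules.openaimodel import ResBlock

block = ResBlock(channels=64, emb_channels=512, dropout=0.0, out_channels=128)
x = torch.randn(2, 64, 32, 32)
emb = torch.randn(2, 512)      # timestep embedding
print(block(x, emb).shape)     # torch.Size([2, 128, 32, 32])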

Quick links back: UNetModel downsampling | UNetModel middle block.


AttentionBlock

Location: latent-diffusion/ldm/modules/diffusionmodules/openaimodel.py

This class implements a self-attention block over the flattened spatial positions. The code is as follows:

class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        return checkpoint(self._forward, (x,), self.parameters(), True)   # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
        #return pt_checkpoint(self._forward, x)  # pytorch

    def _forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)

Here, the QKVAttention class is defined as follows:

class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)

The QKVAttentionLegacy class is defined as follows:

class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
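
Both variants consume a packed qkv tensor and return the attended values; a shape check with illustrative sizes (the class is assumed to be importable from the module above):

import torch
from ldm.modules.diffusionmodules.openaimodel import QKVAttention

attn = QKVAttention(n_heads=4)
qkv = torch.randn(2, 3 * 4 * 16, 64)   # [N, 3*H*C, T]: 4 heads, 16 channels per head, 64 positions
print(attn(qkv).shape)                 # torch.Size([2, 64, 64]), i.e. [N, H*C, T]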

Quick links back: UNetModel downsampling | UNetModel middle block.

Downsample

Location: latent-diffusion/ldm/modules/diffusionmodules/openaimodel.py

This class implements the downsampling module used in the UNet. The code is as follows:

class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
            )
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
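
A quick usage sketch (sizes are illustrative): with use_conv=True, the strided 3x3 convolution halves the spatial resolution.

import torch
from ldm.modules.diffusionmodules.openaimodel import Downsample

down = Downsample(channels=64, use_conv=True, dims=2)
x = torch.randn(2, 64, 32, 32)
print(down(x).shape)   # torch.Size([2, 64, 16, 16])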

Click to return to UNetModel.

