引言

上一节介绍了基于平均场假设的变分推断与广义EM算法的关系,本节将介绍通过随机梯度的思想实现变分推断

回顾:基于平均场假设的变分推断

基于平均场假设的变分推断通常称为经典变分推断(Classical Variational Inference)。其核心自然是 平均场假设:将隐变量 Z \mathcal Z Z的概率分布 Q ( Z ) \mathcal Q(\mathcal Z) Q(Z)看做 M \mathcal M M个独立的子概率分布
Q ( Z ) = ∏ i = 1 M Q i ( Z ( i ) ) \mathcal Q(\mathcal Z) = \prod_{i=1}^{\mathcal M} \mathcal Q_i(\mathcal Z^{(i)}) Q(Z)=i=1MQi(Z(i))
其迭代过程的思想是坐标上升法(Coordinate Ascent):

  • 求解 Q j ( Z ( j ) ) \mathcal Q_j(\mathcal Z^{(j)}) Qj(Z(j)),固定 Q j ( Z ( j ) ) \mathcal Q_j(\mathcal Z^{(j)}) Qj(Z(j))外的所有分布,并将求解出的 Q ^ i ( Z ( i ) ) \hat {\mathcal Q}_i(\mathcal Z^{(i)}) Q^i(Z(i))替换原始的 Q j ( Z ( j ) ) \mathcal Q_j(\mathcal Z^{(j)}) Qj(Z(j))
    Q ^ j ( Z ( j ) ) = arg ⁡ max ⁡ Q j ( Z ( j ) ) { − K L [ ϕ ^ ( X , Z ( j ) ) ∣ ∣ Q j ( Z ( j ) ) ] } Q ( Z ) = Q 1 ( Z ( 1 ) ) × ⋯ × Q ^ j ( Z ( j ) ) × ⋯ × Q M ( Z ( M ) ) \hat {\mathcal Q}_j (\mathcal Z^{(j)}) = \mathop{\arg\max}\limits_{\mathcal Q_j(\mathcal Z^{(j)})} \left\{-\mathcal K\mathcal L \left[\hat \phi (\mathcal X,\mathcal Z^{(j)}) || \mathcal Q_j(\mathcal Z^{(j)})\right]\right\} \\ \mathcal Q(\mathcal Z) = \mathcal Q_1(\mathcal Z^{(1)}) \times \cdots \times \hat {\mathcal Q}_j(\mathcal Z^{(j)}) \times \cdots\times \mathcal Q_{\mathcal M}(\mathcal Z^{(\mathcal M)}) Q^j(Z(j))=Qj(Z(j))argmax{KL[ϕ^(X,Z(j))∣∣Qj(Z(j))]}Q(Z)=Q1(Z(1))××Q^j(Z(j))××QM(Z(M))
  • 重复上述步骤,最终第一次迭代结果得到如下形式:
    Q ( Z ) = Q ^ 1 ( Z ( 1 ) ) × ⋯ × Q ^ M ( Z ( M ) ) \mathcal Q(\mathcal Z) = \hat {\mathcal Q}_1(\mathcal Z^{(1)}) \times \cdots \times \hat {\mathcal Q}_{\mathcal M}(\mathcal Z^{(\mathcal M)}) Q(Z)=Q^1(Z(1))××Q^M(Z(M))
  • 继续迭代,直到 Q ( Z ) \mathcal Q(\mathcal Z) Q(Z)结果稳定且收敛。

经典变分推断的问题

虽然通过坐标上升法能够近似求解隐变量 Z \mathcal Z Z的最优后验概率分布 P ( Z ∣ X ) P(\mathcal Z \mid \mathcal X) P(ZX),但 经典变分推断 的问题也是显而易见的:平均场假设这个假设本身过于苛刻

平均场假设要保证隐变量各分组之间相互独立。而隐变量本身就是基于真实情况人为定义的变量
实际情况中,定义的隐变量满足平均场假设是极为困难的,因此,经典变分推断基本无法使用于真实任务

至此,我们在近似求解后验概率分布 P ( Z ∣ X ) P(\mathcal Z \mid \mathcal X) P(ZX),就需要对 P ( Z ∣ X ) P(\mathcal Z \mid \mathcal X) P(ZX)整体进行求解
本节将从梯度角度对 P ( Z ∣ X ) P(\mathcal Z \mid \mathcal X) P(ZX)进行求解。

随机梯度变分推断的求解过程

回顾变分推断的推导过程,基于隐变量 Z \mathcal Z Z最优近似分布 Q ^ ( Z ) \hat {\mathcal Q}(\mathcal Z) Q^(Z) 可进行如下表示:
Q ^ ( Z ) = arg ⁡ max ⁡ Q ( Z ) L [ Q ( Z ) ] ⇒ Q ^ ( Z ) ≈ P ( Z ∣ X ) L [ Q ( Z ) ] = ∫ Z Q ( Z ) ⋅ log ⁡ [ P ( X , Z ) Q ( Z ) ] d Z \hat {\mathcal Q}(\mathcal Z) = \mathop{\arg\max}\limits_{\mathcal Q(\mathcal Z)} \mathcal L[\mathcal Q(\mathcal Z)] \Rightarrow \hat {\mathcal Q}(\mathcal Z) \approx P(\mathcal Z \mid \mathcal X) \\ \mathcal L[\mathcal Q(\mathcal Z)] = \int_{\mathcal Z} \mathcal Q(\mathcal Z) \cdot \log \left[\frac{P(\mathcal X,\mathcal Z)}{\mathcal Q(\mathcal Z)}\right] d \mathcal Z Q^(Z)=Q(Z)argmaxL[Q(Z)]Q^(Z)P(ZX)L[Q(Z)]=ZQ(Z)log[Q(Z)P(X,Z)]dZ
既然是 通过调整 Q ( Z ) \mathcal Q(\mathcal Z) Q(Z)的最值,使得 L [ Q ( Z ) ] \mathcal L[\mathcal Q(\mathcal Z)] L[Q(Z)]达到最大,因此可以尝试使用 梯度上升法(Gradient Ascent) 进行求解。

这里需要进行一些假设
既然要求解最优的 Q ( Z ) \mathcal Q(\mathcal Z) Q(Z),根据梯度上升法,自然要求解 Q ( Z ) \mathcal Q(\mathcal Z) Q(Z)的梯度。

Q ( Z ) \mathcal Q(\mathcal Z) Q(Z)本身是一个分布,也可以看作成一个概率模型。而概率模型本身可以看作是关于该模型参数的一个函数。因此:定义概率模型 Q ( Z ) \mathcal Q(\mathcal Z) Q(Z)的模型参数为 ϕ \phi ϕ,最终将求解 Q ( Z ) \mathcal Q(\mathcal Z) Q(Z)梯度转化为求解模型参数 ϕ \phi ϕ的梯度
Q ( Z ∣ ϕ ) \mathcal Q(\mathcal Z \mid \phi) Q(Zϕ)写法是保留之前对概率模型的表达。例如 P ( X ∣ θ ) P(\mathcal X \mid \theta) P(Xθ),对应的 L [ Q ( Z ) ] \mathcal L[\mathcal Q(\mathcal Z)] L[Q(Z)]公式也需要进行修改。
Q ( Z ) → Q ( Z ∣ ϕ ) L [ Q ( Z ) ] = ∫ Z ∣ ϕ Q ( Z ∣ ϕ ) ⋅ log ⁡ [ P ( X , Z ) Q ( Z ∣ ϕ ) ] d Z = E Q ( Z ∣ ϕ ) [ log ⁡ P ( X , Z ) − log ⁡ Q ( Z ∣ ϕ ) ] = L ( ϕ ) \mathcal Q(\mathcal Z) \to \mathcal Q(\mathcal Z \mid \phi) \\ \begin{aligned} \mathcal L[\mathcal Q(\mathcal Z)] & = \int_{\mathcal Z \mid \phi} \mathcal Q(\mathcal Z \mid \phi) \cdot \log \left[\frac{P(\mathcal X,\mathcal Z)}{\mathcal Q(\mathcal Z \mid \phi)}\right] d\mathcal Z \\ & = \mathbb E_{\mathcal Q(\mathcal Z \mid \phi)} \left[\log P(\mathcal X,\mathcal Z) - \log \mathcal Q(\mathcal Z \mid \phi)\right] \\ & = \mathcal L(\phi) \end{aligned} Q(Z)Q(Zϕ)L[Q(Z)]=ZϕQ(Zϕ)log[Q(Zϕ)P(X,Z)]dZ=EQ(Zϕ)[logP(X,Z)logQ(Zϕ)]=L(ϕ)
与此同时, L [ Q ( Z ) ] \mathcal L[\mathcal Q(\mathcal Z)] L[Q(Z)]中的变量由 Q ( Z ) \mathcal Q(\mathcal Z) Q(Z)变为 ϕ \phi ϕ L ( ϕ ) \mathcal L(\phi) L(ϕ)。从而将求解最优 Q ^ ( Z ) \hat {\mathcal Q}(\mathcal Z) Q^(Z)转化为求解最优参数 ϕ ^ \hat \phi ϕ^
ϕ ^ = arg ⁡ max ⁡ ϕ L ( ϕ ) \hat \phi = \mathop{\arg\max}\limits_{\phi} \mathcal L(\phi) ϕ^=ϕargmaxL(ϕ)
梯度 ∇ ϕ L ( ϕ ) \nabla_{\phi}\mathcal L(\phi) ϕL(ϕ)进行表示:
∇ ϕ L ( ϕ ) = ∇ ϕ ∫ Z ∣ ϕ Q ( Z ∣ ϕ ) ⋅ log ⁡ [ P ( X , Z ) Q ( Z ∣ ϕ ) ] d Z = ∇ ϕ ∫ Z ∣ ϕ Q ( Z ∣ ϕ ) ⋅ [ log ⁡ P ( X , Z ) − log ⁡ Q ( Z ∣ ϕ ) ] d Z \begin{aligned} \nabla_{\phi}\mathcal L(\phi) & = \nabla_{\phi} \int_{\mathcal Z \mid \phi} \mathcal Q(\mathcal Z \mid \phi) \cdot \log \left[\frac{P(\mathcal X,\mathcal Z)}{\mathcal Q(\mathcal Z \mid \phi)}\right] d\mathcal Z \\ & = \nabla_{\phi} \int_{\mathcal Z \mid \phi} \mathcal Q(\mathcal Z \mid \phi) \cdot \left[ \log P(\mathcal X,\mathcal Z) - \log \mathcal Q(\mathcal Z \mid \phi)\right] d\mathcal Z \end{aligned} ϕL(ϕ)=ϕZϕQ(Zϕ)log[Q(Zϕ)P(X,Z)]dZ=ϕZϕQ(Zϕ)[logP(X,Z)logQ(Zϕ)]dZ
根据牛顿-莱布尼兹公式,将积分号 ∫ \int 与梯度 ∇ \nabla 进行交换
乘法求导~
∫ Z ∣ ϕ ∇ ϕ Q ( Z ∣ ϕ ) ⋅ [ log ⁡ P ( X , Z ) − log ⁡ Q ( Z ∣ ϕ ) ] d Z + ∫ Z ∣ ϕ Q ( Z ∣ ϕ ) ⋅ ∇ ϕ [ log ⁡ P ( X , Z ) − log ⁡ Q ( Z ∣ ϕ ) ] d Z \int_{\mathcal Z \mid \phi} \nabla_{\phi} \mathcal Q(\mathcal Z \mid \phi) \cdot \left[ \log P(\mathcal X,\mathcal Z) - \log \mathcal Q(\mathcal Z \mid \phi)\right]d\mathcal Z + \int_{\mathcal Z \mid \phi} \mathcal Q(\mathcal Z \mid \phi) \cdot \nabla_{\phi}\left[ \log P(\mathcal X,\mathcal Z) - \log \mathcal Q(\mathcal Z \mid \phi)\right] d\mathcal Z ZϕϕQ(Zϕ)[logP(X,Z)logQ(Zϕ)]dZ+ZϕQ(Zϕ)ϕ[logP(X,Z)logQ(Zϕ)]dZ

观察第二项 ∫ Z ∣ ϕ Q ( Z ∣ ϕ ) ⋅ ∇ ϕ [ log ⁡ P ( X , Z ) − log ⁡ Q ( Z ∣ ϕ ) ] d Z \int_{\mathcal Z \mid \phi} \mathcal Q(\mathcal Z \mid \phi) \cdot \nabla_{\phi}\left[ \log P(\mathcal X,\mathcal Z) - \log \mathcal Q(\mathcal Z \mid \phi)\right] d\mathcal Z ZϕQ(Zϕ)ϕ[logP(X,Z)logQ(Zϕ)]dZ

  • 由于 ϕ \phi ϕ概率模型 Q ( Z ∣ ϕ ) \mathcal Q(\mathcal Z \mid \phi) Q(Zϕ)的模型参数,而 P ( X , Z ) P(\mathcal X,\mathcal Z) P(X,Z) X , Z \mathcal X,\mathcal Z X,Z的联合概率分布,因此与 ϕ \phi ϕ无关。因此第二项可变化为
    − ∫ Z ∣ ϕ Q ( Z ∣ ϕ ) ⋅ ∇ ϕ log ⁡ Q ( Z ∣ ϕ ) d Z = − ∫ Z ∣ ϕ 1 Q ( Z ∣ ϕ ) ⋅ Q ( Z ∣ ϕ ) ⋅ ∇ ϕ Q ( Z ∣ ϕ ) d Z = − ∫ Z ∣ ϕ ∇ ϕ Q ( Z ∣ ϕ ) d Z \begin{aligned} & - \int_{\mathcal Z \mid \phi} \mathcal Q(\mathcal Z \mid \phi) \cdot \nabla_{\phi} \log \mathcal Q(\mathcal Z \mid \phi) d\mathcal Z \\ & = -\int_{\mathcal Z \mid \phi} \frac{1}{\mathcal Q(\mathcal Z \mid \phi)} \cdot \mathcal Q(\mathcal Z \mid \phi) \cdot \nabla_{\phi} \mathcal Q(\mathcal Z \mid \phi)d\mathcal Z \\ & = - \int_{\mathcal Z \mid \phi} \nabla_{\phi} \mathcal Q(\mathcal Z \mid \phi)d\mathcal Z \end{aligned} ZϕQ(Zϕ)ϕlogQ(Zϕ)dZ=ZϕQ(Zϕ)1Q(Zϕ)ϕQ(Zϕ)dZ=ZϕϕQ(Zϕ)dZ
  • 再次使用牛顿-莱布尼兹公式,将梯度符号 ∇ \nabla 还原位置:
    − ∇ ϕ ∫ Z ∣ ϕ Q ( Z ∣ ϕ ) d Z - \nabla_{\phi} \int_{\mathcal Z \mid \phi} \mathcal Q(\mathcal Z \mid \phi) d\mathcal Z ϕZϕQ(Zϕ)dZ
  • 根据概率密度积分 ∫ Z ∣ ϕ Q ( Z ∣ ϕ ) d Z = 1 \int_{\mathcal Z \mid \phi} \mathcal Q(\mathcal Z \mid \phi) d\mathcal Z = 1 ZϕQ(Zϕ)dZ=1,第二项相当于对常数1求偏导,最后结果为0。即:
    第二项被完整地消掉了~
    ∫ Z ∣ ϕ Q ( Z ∣ ϕ ) ⋅ ∇ ϕ [ log ⁡ P ( X , Z ) − log ⁡ Q ( Z ∣ ϕ ) ] d Z = − ∇ ϕ 1 = 0 \int_{\mathcal Z \mid \phi} \mathcal Q(\mathcal Z \mid \phi) \cdot \nabla_{\phi}\left[ \log P(\mathcal X,\mathcal Z) - \log \mathcal Q(\mathcal Z \mid \phi)\right] d\mathcal Z = -\nabla_{\phi} 1 = 0 ZϕQ(Zϕ)ϕ[logP(X,Z)logQ(Zϕ)]dZ=ϕ1=0

至此, ∇ ϕ L ( ϕ ) \nabla_{\phi} \mathcal L(\phi) ϕL(ϕ)可表示为:
只剩下了第一项~
∇ ϕ L ( ϕ ) = ∫ Z ∣ ϕ ∇ ϕ Q ( Z ∣ ϕ ) ⋅ [ log ⁡ P ( X , Z ) − log ⁡ Q ( Z ∣ ϕ ) ] d Z \nabla_{\phi} \mathcal L(\phi) = \int_{\mathcal Z \mid \phi} \nabla_{\phi} \mathcal Q(\mathcal Z \mid \phi) \cdot \left[ \log P(\mathcal X,\mathcal Z) - \log \mathcal Q(\mathcal Z \mid \phi)\right]d\mathcal Z ϕL(ϕ)=ZϕϕQ(Zϕ)[logP(X,Z)logQ(Zϕ)]dZ
观察: ∇ ϕ Q ( Z ∣ ϕ ) \nabla_{\phi}\mathcal Q(\mathcal Z \mid \phi) ϕQ(Zϕ)它并不是概率分布,而是概率分布的梯度。因此没有办法将上式写成期望形式
但是这里通过技巧 Q ( Z ∣ ϕ ) \mathcal Q(\mathcal Z \mid \phi) Q(Zϕ)还原出来
可以自己反过来推一下~
∇ ϕ Q ( Z ∣ ϕ ) = Q ( Z ∣ ϕ ) ⋅ ∇ ϕ log ⁡ Q ( Z ∣ ϕ ) \nabla_{\phi}\mathcal Q(\mathcal Z \mid \phi) = \mathcal Q(\mathcal Z \mid \phi) \cdot \nabla_{\phi} \log \mathcal Q(\mathcal Z \mid \phi) ϕQ(Zϕ)=Q(Zϕ)ϕlogQ(Zϕ)
将上式带入, ∇ ϕ L ( ϕ ) \nabla_{\phi} \mathcal L(\phi) ϕL(ϕ)可以表示为:
∫ Z ∣ ϕ Q ( Z ∣ ϕ ) ⋅ ∇ ϕ log ⁡ Q ( Z ∣ ϕ ) ⋅ [ log ⁡ P ( X , Z ) − log ⁡ Q ( Z ∣ ϕ ) ] d Z \int_{\mathcal Z \mid \phi} \mathcal Q(\mathcal Z \mid \phi) \cdot \nabla_{\phi} \log \mathcal Q(\mathcal Z \mid \phi) \cdot \left[ \log P(\mathcal X,\mathcal Z) - \log \mathcal Q(\mathcal Z \mid \phi)\right] d\mathcal Z ZϕQ(Zϕ)ϕlogQ(Zϕ)[logP(X,Z)logQ(Zϕ)]dZ
可以将上述积分看作 Q ( Z ∣ ϕ ) \mathcal Q(\mathcal Z \mid \phi) Q(Zϕ)分布的期望形式
∇ ϕ L ( ϕ ) = E Q ( Z ∣ ϕ ) { ∇ ϕ log ⁡ Q ( Z ∣ ϕ ) ⋅ [ log ⁡ P ( X , Z ) − log ⁡ Q ( Z ∣ ϕ ) ] } \nabla_{\phi} \mathcal L(\phi) =\mathbb E_{\mathcal Q(\mathcal Z \mid \phi)}\left\{\nabla_{\phi} \log \mathcal Q(\mathcal Z \mid \phi) \cdot [\log P(\mathcal X,\mathcal Z) - \log \mathcal Q(\mathcal Z \mid \phi)]\right\} ϕL(ϕ)=EQ(Zϕ){ϕlogQ(Zϕ)[logP(X,Z)logQ(Zϕ)]}
至此,将梯度 ∇ ϕ L ( ϕ ) \nabla_{\phi}\mathcal L(\phi) ϕL(ϕ)使用期望形式表示出来。后续可以使用蒙特卡洛采样方法对该期望进行近似求解

至此,每求解一个 ∇ ϕ L ( ϕ ) \nabla_{\phi} \mathcal L(\phi) ϕL(ϕ),都可以对 Q ( Z ∣ ϕ ) \mathcal Q(\mathcal Z \mid \phi) Q(Zϕ)概率分布的模型参数 ϕ \phi ϕ 更新一次,以此类推。
最终可以近似得到概率模型 Q ( Z ∣ ϕ ) \mathcal Q(\mathcal Z \mid \phi) Q(Zϕ)的最优模型参数 ϕ ^ \hat \phi ϕ^,从而求解概率模型 Q ( Z ∣ ϕ ^ ) \mathcal Q(\mathcal Z \mid \hat \phi) Q(Zϕ^)

下一节将介绍 随机梯度变分推断的问题及其他衍生方法

相关参考:
机器学习-变分推断4(随机梯度变分推断-SGVI-1)

更多推荐