告别数学恐惧!用Python从零实现Gibbs采样,搞定高维分布采样(附完整代码)

很多数据科学初学者在面对复杂的概率模型时,常常被Gibbs采样背后的数学理论吓退。但今天我要告诉你一个秘密: 你完全可以在不理解马尔可夫链数学证明的情况下,通过代码直观掌握Gibbs采样的精髓 。我们将从一个简单的二维混合高斯分布出发,用不到100行Python代码实现完整的Gibbs采样器,并可视化采样过程。

1. 准备工作:理解问题场景

假设你正在处理一个用户行为分析项目,需要从以下联合概率分布中采样:

import numpy as np
from scipy.stats import multivariate_normal
import matplotlib.pyplot as plt

# 定义两个高斯分布的参数
mu1 = np.array([0, 0])
cov1 = np.array([[1.5, 0.2], [0.2, 1.4]])

mu2 = np.array([2, 3])
cov2 = np.array([[1.1, 0.4], [0.4, 1.3]])

# 混合权重
w1, w2 = 0.5, 0.5

def target_distribution(x, y):
    pos = np.array([x, y])
    return w1 * multivariate_normal.pdf(pos, mean=mu1, cov=cov1) + \
           w2 * multivariate_normal.pdf(pos, mean=mu2, cov=cov2)

这个分布有两个明显的峰值区域,传统采样方法难以处理。Gibbs采样的核心思想是: 在高维空间中,沿着每个坐标轴方向轮流采样 ,逐步逼近目标分布。

2. Gibbs采样器实现

Gibbs采样需要知道每个变量的条件分布。对于我们的二维高斯混合模型,条件分布计算如下:

def conditional_distribution(x_given, dim):
    """计算给定一个维度时另一个维度的条件分布"""
    # 计算每个高斯分量在当前x_given下的权重
    if dim == 0:  # 给定y,采样x
        py1 = multivariate_normal.pdf(x_given, mean=mu1[1], cov=cov1[1,1])
        py2 = multivariate_normal.pdf(x_given, mean=mu2[1], cov=cov2[1,1])
        w = w1 * py1 / (w1 * py1 + w2 * py2)
        
        # 计算混合高斯条件分布的参数
        cond_mu = w * (mu1[0] + cov1[0,1]/cov1[1,1]*(x_given - mu1[1])) + \
                 (1-w) * (mu2[0] + cov2[0,1]/cov2[1,1]*(x_given - mu2[1]))
        cond_var = w * (cov1[0,0] - cov1[0,1]**2/cov1[1,1]) + \
                  (1-w) * (cov2[0,0] - cov2[0,1]**2/cov2[1,1])
    else:  # 给定x,采样y
        px1 = multivariate_normal.pdf(x_given, mean=mu1[0], cov=cov1[0,0])
        px2 = multivariate_normal.pdf(x_given, mean=mu2[0], cov=cov2[0,0])
        w = w1 * px1 / (w1 * px1 + w2 * px2)
        
        cond_mu = w * (mu1[1] + cov1[1,0]/cov1[0,0]*(x_given - mu1[0])) + \
                 (1-w) * (mu2[1] + cov2[1,0]/cov2[0,0]*(x_given - mu2[0]))
        cond_var = w * (cov1[1,1] - cov1[1,0]**2/cov1[0,0]) + \
                  (1-w) * (cov2[1,1] - cov2[1,0]**2/cov2[0,0])
    
    return cond_mu, cond_var

现在我们可以实现Gibbs采样器了:

def gibbs_sampling(initial_point, n_samples, burn_in):
    samples = np.zeros((n_samples + burn_in, 2))
    samples[0] = initial_point
    
    for i in range(1, n_samples + burn_in):
        # 交替采样每个维度
        dim = i % 2
        
        # 根据另一个维度的当前值计算条件分布
        x_given = samples[i-1, 1-dim]
        cond_mu, cond_var = conditional_distribution(x_given, dim)
        
        # 从条件分布中采样
        samples[i, dim] = np.random.normal(cond_mu, np.sqrt(cond_var))
        samples[i, 1-dim] = x_given  # 保持另一个维度不变
    
    return samples[burn_in:]  # 丢弃burn-in期的样本

3. 采样过程可视化

让我们运行采样器并观察采样路径:

# 运行Gibbs采样
np.random.seed(42)
initial_point = np.array([-2, -2])  # 故意从一个不好的初始点开始
samples = gibbs_sampling(initial_point, n_samples=2000, burn_in=500)

# 绘制采样路径
plt.figure(figsize=(12, 6))
x = np.linspace(-4, 5, 100)
y = np.linspace(-4, 6, 100)
X, Y = np.meshgrid(x, y)
Z = np.zeros_like(X)

for i in range(len(x)):
    for j in range(len(y)):
        Z[j, i] = target_distribution(x[i], y[j])

plt.contour(X, Y, Z, levels=20, cmap='Blues')
plt.plot(samples[:100, 0], samples[:100, 1], 'r-', alpha=0.5, lw=1)
plt.scatter(samples[:100, 0], samples[:100, 1], c='red', s=20, alpha=0.7)
plt.title('Gibbs采样路径 (前100次迭代)')
plt.xlabel('x')
plt.ylabel('y')
plt.colorbar()
plt.show()

你会看到采样点如何从初始位置逐渐"爬向"高概率区域,最终在两个峰值之间来回跳跃。

4. 实际应用中的关键技巧

4.1 初始值和burn-in期选择

Gibbs采样对初始值不敏感,但选择合适的burn-in期很重要。一个实用的方法是:

  1. 运行多次短链,观察收敛情况
  2. 计算 Gelman-Rubin统计量 评估收敛
  3. 保守估计:丢弃前10-20%的样本
def check_convergence():
    # 运行4条独立链
    chains = []
    for _ in range(4):
        initial = np.random.uniform(-3, -1, size=2)
        chains.append(gibbs_sampling(initial, 1000, 0))
    
    # 计算链内和链间方差
    chain_means = np.mean(chains, axis=1)
    chain_vars = np.var(chains, axis=1, ddof=1)
    
    W = np.mean(chain_vars)  # 链内方差
    B = len(chains[0]) * np.var(chain_means, ddof=1)  # 链间方差
    
    R_hat = np.sqrt((W + B) / W)
    print(f"Gelman-Rubin统计量: {R_hat:.3f}")
    
    if R_hat < 1.1:
        print("链已收敛")
    else:
        print("链未收敛,需要更长的burn-in期")

check_convergence()

4.2 处理高维情况

当维度增加时,Gibbs采样效率会下降。这时可以考虑:

  • 块Gibbs采样 :将相关变量一起采样
  • 与Metropolis-Hastings结合 :对某些维度使用MH采样
  • 参数化技巧 :寻找更有效的参数表示

4.3 诊断采样质量

除了目视检查,还可以:

  1. 计算 自相关函数 评估样本独立性
  2. 比较 边际分布 与理论值
  3. 使用 迹图 检查采样是否卡住
def plot_trace(samples):
    plt.figure(figsize=(12, 4))
    plt.subplot(121)
    plt.plot(samples[:, 0])
    plt.title('x维度的迹图')
    
    plt.subplot(122)
    plt.plot(samples[:, 1])
    plt.title('y维度的迹图')
    plt.show()

plot_trace(samples[:500])  # 检查前500个样本

5. 完整代码与扩展应用

以下是完整的Gibbs采样实现,可直接用于你的项目:

import numpy as np
from scipy.stats import multivariate_normal
import matplotlib.pyplot as plt

class GibbsSampler:
    def __init__(self, target_dist, cond_dist_func, dim=2):
        self.target_dist = target_dist
        self.cond_dist_func = cond_dist_func
        self.dim = dim
    
    def sample(self, initial_point, n_samples, burn_in=0, thin=1):
        samples = np.zeros((n_samples * thin + burn_in, self.dim))
        samples[0] = initial_point
        
        for i in range(1, n_samples * thin + burn_in):
            dim = i % self.dim
            given_values = np.delete(samples[i-1], dim)
            
            cond_mu, cond_var = self.cond_dist_func(given_values, dim)
            samples[i, dim] = np.random.normal(cond_mu, np.sqrt(cond_var))
            samples[i, np.delete(np.arange(self.dim), dim)] = given_values
        
        return samples[burn_in::thin]

# 使用示例
if __name__ == "__main__":
    # 定义目标分布和条件分布函数 (如前所述)
    # ...
    
    sampler = GibbsSampler(target_distribution, conditional_distribution)
    samples = sampler.sample(initial_point=np.array([-2, -2]), 
                           n_samples=2000, 
                           burn_in=500)
    
    # 绘制结果
    plt.figure(figsize=(8, 8))
    plt.contour(X, Y, Z, levels=20, cmap='Blues')
    plt.scatter(samples[:, 0], samples[:, 1], c='red', s=5, alpha=0.3)
    plt.title('Gibbs采样结果')
    plt.show()

这个实现可以轻松扩展到更高维度,只需提供相应的条件分布函数即可。Gibbs采样在以下场景特别有用:

  • 贝叶斯统计 :从后验分布中采样
  • 主题建模 :LDA等模型参数估计
  • 图像处理 :马尔可夫随机场采样
  • 金融建模 :复杂联合分布模拟

记住,Gibbs采样的威力在于它能处理那些难以直接采样的高维分布,而实现它所需的代码出奇地简单。当你下次遇到复杂的联合分布时,不妨试试这个"轮流采样"的巧妙方法。

更多推荐