告别数学恐惧!用Python从零实现Gibbs采样,搞定高维分布采样(附完整代码)
·
告别数学恐惧!用Python从零实现Gibbs采样,搞定高维分布采样(附完整代码)
很多数据科学初学者在面对复杂的概率模型时,常常被Gibbs采样背后的数学理论吓退。但今天我要告诉你一个秘密: 你完全可以在不理解马尔可夫链数学证明的情况下,通过代码直观掌握Gibbs采样的精髓 。我们将从一个简单的二维混合高斯分布出发,用不到100行Python代码实现完整的Gibbs采样器,并可视化采样过程。
1. 准备工作:理解问题场景
假设你正在处理一个用户行为分析项目,需要从以下联合概率分布中采样:
import numpy as np
from scipy.stats import multivariate_normal
import matplotlib.pyplot as plt
# 定义两个高斯分布的参数
mu1 = np.array([0, 0])
cov1 = np.array([[1.5, 0.2], [0.2, 1.4]])
mu2 = np.array([2, 3])
cov2 = np.array([[1.1, 0.4], [0.4, 1.3]])
# 混合权重
w1, w2 = 0.5, 0.5
def target_distribution(x, y):
pos = np.array([x, y])
return w1 * multivariate_normal.pdf(pos, mean=mu1, cov=cov1) + \
w2 * multivariate_normal.pdf(pos, mean=mu2, cov=cov2)
这个分布有两个明显的峰值区域,传统采样方法难以处理。Gibbs采样的核心思想是: 在高维空间中,沿着每个坐标轴方向轮流采样 ,逐步逼近目标分布。
2. Gibbs采样器实现
Gibbs采样需要知道每个变量的条件分布。对于我们的二维高斯混合模型,条件分布计算如下:
def conditional_distribution(x_given, dim):
"""计算给定一个维度时另一个维度的条件分布"""
# 计算每个高斯分量在当前x_given下的权重
if dim == 0: # 给定y,采样x
py1 = multivariate_normal.pdf(x_given, mean=mu1[1], cov=cov1[1,1])
py2 = multivariate_normal.pdf(x_given, mean=mu2[1], cov=cov2[1,1])
w = w1 * py1 / (w1 * py1 + w2 * py2)
# 计算混合高斯条件分布的参数
cond_mu = w * (mu1[0] + cov1[0,1]/cov1[1,1]*(x_given - mu1[1])) + \
(1-w) * (mu2[0] + cov2[0,1]/cov2[1,1]*(x_given - mu2[1]))
cond_var = w * (cov1[0,0] - cov1[0,1]**2/cov1[1,1]) + \
(1-w) * (cov2[0,0] - cov2[0,1]**2/cov2[1,1])
else: # 给定x,采样y
px1 = multivariate_normal.pdf(x_given, mean=mu1[0], cov=cov1[0,0])
px2 = multivariate_normal.pdf(x_given, mean=mu2[0], cov=cov2[0,0])
w = w1 * px1 / (w1 * px1 + w2 * px2)
cond_mu = w * (mu1[1] + cov1[1,0]/cov1[0,0]*(x_given - mu1[0])) + \
(1-w) * (mu2[1] + cov2[1,0]/cov2[0,0]*(x_given - mu2[0]))
cond_var = w * (cov1[1,1] - cov1[1,0]**2/cov1[0,0]) + \
(1-w) * (cov2[1,1] - cov2[1,0]**2/cov2[0,0])
return cond_mu, cond_var
现在我们可以实现Gibbs采样器了:
def gibbs_sampling(initial_point, n_samples, burn_in):
samples = np.zeros((n_samples + burn_in, 2))
samples[0] = initial_point
for i in range(1, n_samples + burn_in):
# 交替采样每个维度
dim = i % 2
# 根据另一个维度的当前值计算条件分布
x_given = samples[i-1, 1-dim]
cond_mu, cond_var = conditional_distribution(x_given, dim)
# 从条件分布中采样
samples[i, dim] = np.random.normal(cond_mu, np.sqrt(cond_var))
samples[i, 1-dim] = x_given # 保持另一个维度不变
return samples[burn_in:] # 丢弃burn-in期的样本
3. 采样过程可视化
让我们运行采样器并观察采样路径:
# 运行Gibbs采样
np.random.seed(42)
initial_point = np.array([-2, -2]) # 故意从一个不好的初始点开始
samples = gibbs_sampling(initial_point, n_samples=2000, burn_in=500)
# 绘制采样路径
plt.figure(figsize=(12, 6))
x = np.linspace(-4, 5, 100)
y = np.linspace(-4, 6, 100)
X, Y = np.meshgrid(x, y)
Z = np.zeros_like(X)
for i in range(len(x)):
for j in range(len(y)):
Z[j, i] = target_distribution(x[i], y[j])
plt.contour(X, Y, Z, levels=20, cmap='Blues')
plt.plot(samples[:100, 0], samples[:100, 1], 'r-', alpha=0.5, lw=1)
plt.scatter(samples[:100, 0], samples[:100, 1], c='red', s=20, alpha=0.7)
plt.title('Gibbs采样路径 (前100次迭代)')
plt.xlabel('x')
plt.ylabel('y')
plt.colorbar()
plt.show()
你会看到采样点如何从初始位置逐渐"爬向"高概率区域,最终在两个峰值之间来回跳跃。
4. 实际应用中的关键技巧
4.1 初始值和burn-in期选择
Gibbs采样对初始值不敏感,但选择合适的burn-in期很重要。一个实用的方法是:
- 运行多次短链,观察收敛情况
- 计算 Gelman-Rubin统计量 评估收敛
- 保守估计:丢弃前10-20%的样本
def check_convergence():
# 运行4条独立链
chains = []
for _ in range(4):
initial = np.random.uniform(-3, -1, size=2)
chains.append(gibbs_sampling(initial, 1000, 0))
# 计算链内和链间方差
chain_means = np.mean(chains, axis=1)
chain_vars = np.var(chains, axis=1, ddof=1)
W = np.mean(chain_vars) # 链内方差
B = len(chains[0]) * np.var(chain_means, ddof=1) # 链间方差
R_hat = np.sqrt((W + B) / W)
print(f"Gelman-Rubin统计量: {R_hat:.3f}")
if R_hat < 1.1:
print("链已收敛")
else:
print("链未收敛,需要更长的burn-in期")
check_convergence()
4.2 处理高维情况
当维度增加时,Gibbs采样效率会下降。这时可以考虑:
- 块Gibbs采样 :将相关变量一起采样
- 与Metropolis-Hastings结合 :对某些维度使用MH采样
- 参数化技巧 :寻找更有效的参数表示
4.3 诊断采样质量
除了目视检查,还可以:
- 计算 自相关函数 评估样本独立性
- 比较 边际分布 与理论值
- 使用 迹图 检查采样是否卡住
def plot_trace(samples):
plt.figure(figsize=(12, 4))
plt.subplot(121)
plt.plot(samples[:, 0])
plt.title('x维度的迹图')
plt.subplot(122)
plt.plot(samples[:, 1])
plt.title('y维度的迹图')
plt.show()
plot_trace(samples[:500]) # 检查前500个样本
5. 完整代码与扩展应用
以下是完整的Gibbs采样实现,可直接用于你的项目:
import numpy as np
from scipy.stats import multivariate_normal
import matplotlib.pyplot as plt
class GibbsSampler:
def __init__(self, target_dist, cond_dist_func, dim=2):
self.target_dist = target_dist
self.cond_dist_func = cond_dist_func
self.dim = dim
def sample(self, initial_point, n_samples, burn_in=0, thin=1):
samples = np.zeros((n_samples * thin + burn_in, self.dim))
samples[0] = initial_point
for i in range(1, n_samples * thin + burn_in):
dim = i % self.dim
given_values = np.delete(samples[i-1], dim)
cond_mu, cond_var = self.cond_dist_func(given_values, dim)
samples[i, dim] = np.random.normal(cond_mu, np.sqrt(cond_var))
samples[i, np.delete(np.arange(self.dim), dim)] = given_values
return samples[burn_in::thin]
# 使用示例
if __name__ == "__main__":
# 定义目标分布和条件分布函数 (如前所述)
# ...
sampler = GibbsSampler(target_distribution, conditional_distribution)
samples = sampler.sample(initial_point=np.array([-2, -2]),
n_samples=2000,
burn_in=500)
# 绘制结果
plt.figure(figsize=(8, 8))
plt.contour(X, Y, Z, levels=20, cmap='Blues')
plt.scatter(samples[:, 0], samples[:, 1], c='red', s=5, alpha=0.3)
plt.title('Gibbs采样结果')
plt.show()
这个实现可以轻松扩展到更高维度,只需提供相应的条件分布函数即可。Gibbs采样在以下场景特别有用:
- 贝叶斯统计 :从后验分布中采样
- 主题建模 :LDA等模型参数估计
- 图像处理 :马尔可夫随机场采样
- 金融建模 :复杂联合分布模拟
记住,Gibbs采样的威力在于它能处理那些难以直接采样的高维分布,而实现它所需的代码出奇地简单。当你下次遇到复杂的联合分布时,不妨试试这个"轮流采样"的巧妙方法。
更多推荐


所有评论(0)