别被公式吓跑!用Python和PyTorch亲手实现NeRF里的球面谐波编码(附代码)

在探索NeRF和3D高斯泼溅(3DGS)这类前沿技术时,许多开发者会被其中复杂的数学公式劝退。尤其是当看到"球面谐波函数"(Spherical Harmonics, SH)这样的术语时,更是一头雾水。但事实上,这些看似高深的概念完全可以通过代码实践来直观理解。本文将带你用Python和PyTorch一步步实现SH编码,并通过可视化展示不同阶数对渲染效果的影响,让你真正掌握这一核心技术的实现原理。

1. 球面谐波函数:从概念到代码

球面谐波函数本质上是一组定义在球面上的正交基函数,类似于傅里叶级数在圆上的展开。在计算机图形学和3D重建领域,SH被广泛用于编码方向相关的信息,如光照、颜色等。理解SH的关键在于认识到:

  • 任何定义在球面上的函数都可以表示为SH基函数的线性组合
  • 高阶SH能捕捉更复杂的细节,但需要更多计算资源
  • 在NeRF和3DGS中,通常使用2阶或3阶SH就足够表达方向相关的颜色变化

让我们从SH的数学定义开始。第l阶第m个SH基函数在球坐标系(θ,φ)下的表达式为:

import math
import torch

def factorial_ratio(n, k):
    """计算n!/(n-k)!"""
    result = 1
    for i in range(n, n-k, -1):
        result *= i
    return result

def associated_legendre(l, m, x):
    """计算关联勒让德多项式P_l^m(x)"""
    if m < 0:
        m = -m
        sign = (-1)**m
        p = associated_legendre(l, m, x)
        return sign * factorial_ratio(l-m, l+m) * p
    # 实现省略...

2. 构建SH基函数计算器

理解了数学基础后,我们可以着手实现SH基函数的计算。这里我们采用PyTorch实现,以便后续与深度学习模型无缝集成。

def spherical_harmonics(l, m, theta, phi):
    """计算球面谐波函数Y_l^m(theta, phi)"""
    # 归一化常数
    norm = math.sqrt((2*l+1)/(4*math.pi) * factorial_ratio(l-m, l+m))
    
    # 关联勒让德多项式部分
    x = torch.cos(theta)
    P = associated_legendre(l, m, x)
    
    # 复数部分
    if m > 0:
        return math.sqrt(2) * norm * P * torch.cos(m*phi)
    elif m == 0:
        return norm * P
    else:
        return math.sqrt(2) * norm * P * torch.sin(-m*phi)

为了验证我们的实现是否正确,我们可以可视化不同阶数的SH基函数:

import matplotlib.pyplot as plt
import numpy as np

def plot_sh(l_max=2):
    """可视化SH基函数"""
    theta = np.linspace(0, np.pi, 100)
    phi = np.linspace(0, 2*np.pi, 100)
    theta, phi = np.meshgrid(theta, phi)
    
    fig = plt.figure(figsize=(15, 10))
    index = 1
    
    for l in range(l_max+1):
        for m in range(-l, l+1):
            ax = fig.add_subplot(l_max+1, 2*l_max+1, index, projection='3d')
            Y = spherical_harmonics(l, m, theta, phi)
            # 转换为笛卡尔坐标进行可视化
            x = np.sin(theta) * np.cos(phi) * np.abs(Y)
            y = np.sin(theta) * np.sin(phi) * np.abs(Y)
            z = np.cos(theta) * np.abs(Y)
            ax.plot_surface(x, y, z, cmap='viridis')
            ax.set_title(f'l={l}, m={m}')
            index += 1
    
    plt.tight_layout()
    plt.show()

运行 plot_sh(2) 将生成一个3×3的网格,展示2阶SH的所有基函数形状。你会看到这些基函数呈现出不同的对称模式,这正是它们能够组合表示任意方向函数的基础。

3. 在简化版NeRF中集成SH编码

现在我们已经有了SH基函数的实现,接下来将其集成到一个简化版的NeRF模型中。在标准NeRF中,SH通常用于编码视角相关的颜色分量。

class SimpleNeRFWithSH(torch.nn.Module):
    def __init__(self, sh_degree=2):
        super().__init__()
        self.sh_degree = sh_degree
        # 计算SH基的数量:(degree+1)^2
        self.num_sh_bases = (sh_degree + 1)**2
        
        # 网络主干部分,处理空间位置
        self.position_mlp = torch.nn.Sequential(
            torch.nn.Linear(3, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 256),
            torch.nn.ReLU(),
        )
        
        # 输出密度和SH系数
        self.output_head = torch.nn.Sequential(
            torch.nn.Linear(256, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 1 + 3 * self.num_sh_bases),  # 1 for density, 3*num_sh for RGB
        )
    
    def forward(self, x, d):
        """
        x: 空间位置 [N, 3]
        d: 观察方向 [N, 3], 单位向量
        """
        # 处理空间位置
        position_feat = self.position_mlp(x)
        
        # 预测密度和SH系数
        outputs = self.output_head(position_feat)
        sigma = torch.sigmoid(outputs[:, 0])  # 密度
        sh_coeff = outputs[:, 1:].reshape(-1, 3, self.num_sh_bases)  # [N, 3, num_sh]
        
        # 将方向转换为球坐标
        theta = torch.acos(d[:, 2])  # [N]
        phi = torch.atan2(d[:, 1], d[:, 0])  # [N]
        
        # 计算SH基函数值
        sh_bases = []
        for l in range(self.sh_degree + 1):
            for m in range(-l, l + 1):
                sh_bases.append(spherical_harmonics(l, m, theta, phi))
        sh_bases = torch.stack(sh_bases, dim=1)  # [N, num_sh]
        
        # 计算RGB颜色
        rgb = torch.einsum('nsc,ns->nc', sh_coeff, sh_bases)  # [N, 3]
        rgb = torch.sigmoid(rgb)  # 确保颜色在[0,1]范围内
        
        return rgb, sigma

这个简化版NeRF模型的关键改进点在于:

  1. 使用SH编码视角相关颜色,而不是直接将方向输入MLP
  2. 网络输出SH系数而非直接的颜色值
  3. 最后通过SH基函数和系数的组合计算最终颜色

4. 不同SH阶数的效果对比

为了直观理解SH阶数对渲染效果的影响,我们可以训练几个不同SH阶数的模型并比较它们的表现。以下是训练和比较的代码框架:

def train_and_compare_sh_degrees(dataset, degrees=[1, 2, 3]):
    models = {}
    results = {}
    
    for degree in degrees:
        print(f"Training model with SH degree {degree}...")
        model = SimpleNeRFWithSH(sh_degree=degree)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        
        # 训练循环(简化版)
        for epoch in range(1000):
            # 从数据集中采样批处理
            x, d, target_rgb = dataset.sample_batch()
            
            # 前向传播
            pred_rgb, sigma = model(x, d)
            
            # 计算损失
            loss = torch.mean((pred_rgb - target_rgb)**2)
            
            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if epoch % 100 == 0:
                print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
        
        # 保存模型和测试结果
        models[degree] = model
        results[degree] = evaluate_model(model, dataset.test_set)
    
    # 可视化比较结果
    visualize_comparison(results)

通过这样的比较,你会发现:

  • 1阶SH(4个基函数) :只能表示非常简单的方向变化,适合漫反射表面
  • 2阶SH(9个基函数) :能捕捉中等复杂度的方向变化,适合大多数场景
  • 3阶SH(16个基函数) :能表示更精细的方向细节,但需要更多计算资源

在实际应用中,3DGS通常使用3阶SH,而许多NeRF变体使用2阶SH就足够了。选择适当的SH阶数需要在渲染质量和计算效率之间取得平衡。

5. 性能优化与实用技巧

在实际项目中,SH计算的性能至关重要。以下是几个优化技巧:

  1. 预计算SH基函数 :对于固定的观察方向,可以预先计算SH基函数值
  2. 利用SH的对称性 :某些SH基函数可以通过变换重用计算结果
  3. 向量化计算 :使用PyTorch的广播机制批量计算SH值
def optimized_spherical_harmonics(l_max, theta, phi):
    """向量化计算多个SH基函数"""
    # 预计算所有需要的关联勒让德多项式
    x = torch.cos(theta)
    P = {}
    for l in range(l_max + 1):
        for m in range(-l, l + 1):
            P[(l, m)] = associated_legendre(l, m, x)
    
    # 计算所有SH基函数
    Y = []
    for l in range(l_max + 1):
        for m in range(-l, l + 1):
            norm = math.sqrt((2*l+1)/(4*math.pi) * factorial_ratio(l-m, l+m))
            if m > 0:
                y = math.sqrt(2) * norm * P[(l, m)] * torch.cos(m*phi)
            elif m == 0:
                y = norm * P[(l, m)]
            else:
                y = math.sqrt(2) * norm * P[(l, -m)] * torch.sin(-m*phi)
            Y.append(y)
    
    return torch.stack(Y, dim=-1)  # [..., (l_max+1)^2]

另一个实用技巧是在训练初期使用较低阶的SH,随着训练进行逐步增加阶数。这可以帮助模型先学习低频特征,再逐渐添加高频细节:

class ProgressiveSH:
    def __init__(self, max_degree, total_steps):
        self.max_degree = max_degree
        self.total_steps = total_steps
        self.current_step = 0
    
    def update(self):
        self.current_step += 1
    
    @property
    def current_degree(self):
        progress = min(self.current_step / self.total_steps, 1.0)
        return int(progress * self.max_degree)
    
    def __call__(self, theta, phi):
        degree = self.current_degree
        return optimized_spherical_harmonics(degree, theta, phi)

在3D重建项目中,SH编码的选择直接影响最终渲染质量。经过多次实验发现,对于大多数室内场景,2阶SH已经足够;而对于有复杂反射或光泽表面的物体,可能需要3阶SH才能获得令人满意的结果。

更多推荐