用C++ 实现追赶法(Thomas算法)
用 C++ 实现追赶法(Thomas 算法)
一、什么是追赶法?
追赶法是专门用来解三对角线性方程组的高斯消元法。什么是三对角?就是只有三条对角线上有数,其余全是 0:
upper →
┌ ┐
│ m₀ u₀ 0 0 0 │
lower │ l₁ m₁ u₁ 0 0 │
↓ │ 0 l₂ m₂ u₂ 0 │
│ 0 0 l₃ m₃ u₃│
│ 0 0 0 l₄ m₄│
└ ┘
main_diag →
把这个结构存成四个数组:
lower = [*, l1, l2, l3, l4] // 下对角线(i 行对应 M[i-1] 的系数)
main_diag = [m0, m1, m2, m3, m4] // 主对角线(i 行对应 M[i] 的系数)
upper = [u0, u1, u2, u3, * ] // 上对角线(i 行对应 M[i+1] 的系数)
B = [b0, b1, b2, b3, b4] // 右端项
第 i 行方程就是:lower[i] * M[i-1] + main_diag[i] * M[i] + upper[i] * M[i+1] = B[i]
对于边界行(i=0 和 i=n-1),通常 main_diag = 1,其余为 0,直接确定边界值。
二、手算演示——跟着增广矩阵一步步消
拿一个具体例子,自然样条边界条件(两端已知 M0=0, M3=0),矩阵自己预先算好了:
lower = [0, 1, 2, 0]
main_diag = [1, 6, 6, 1]
upper = [0, 2, 1, 0]
B = [0, 14, -6, 0]
步骤 0:写出增广矩阵 [A | B]
M0 M1 M2 M3 | B
┌ ┬ ┐
│ 1 0 0 0 │ 0 │ 行0(边界,M0 已确定)
│ 1 6 2 0 │ 14 │ 行1
│ 0 2 6 1 │ -6 │ 行2
│ 0 0 0 1 │ 0 │ 行3(边界,M3 已确定)
└ ┴ ┘
步骤 1:行0 已经是标准式,直接用
行0 是 1·M0 + 0·M1 = 0,不用消。记 c_prime[0] = 0,d_prime[0] = 0。
消元后矩阵(行0 不动):
M0 M1 M2 M3 | B
┌ ┬ ┐
│ 1 0 0 0 │ 0 │ ← 完成
│ 1 6 2 0 │ 14 │
│ 0 2 6 1 │ -6 │
│ 0 0 0 1 │ 0 │
└ ┴ ┘
步骤 2:消行1 —— 干掉 M0
行1 是 1·M0 + 6·M1 + 2·M2 = 14。
利用行0 的关系 M0 = 0,代进去:6·M1 + 2·M2 = 14。
把 M1 系数归一化(除以 6 = main_diag[1] - lower[1]*c_prime[0]):
消元后行1: M1 + (1/3)·M2 = 7/3
即: c_prime[1] = 1/3, d_prime[1] = 7/3
矩阵更新:
M0 M1 M2 M3 | B
┌ ┬ ┐
│ 1 0 0 0 │ 0 │ ← 完成
│ 0 1 1/3 0 │ 7/3 │ ← 完成(M0 被消掉,M1 系数归一)
│ 0 2 6 1 │ -6 │
│ 0 0 0 1 │ 0 │
└ ┴ ┘
步骤 3:消行2 —— 干掉 M1
行2 是 2·M1 + 6·M2 + 1·M3 = -6。
利用行1 的关系 M1 = 7/3 - (1/3)·M2,代入行2 消掉 M1。
等价的做法(代码里就是这么写的):
w = main_diag[2] - lower[2] * c_prime[1]
= 6 - 2 * (1/3) = 16/3
新系数 = upper[2] / w = 1 / (16/3) = 3/16
新右端 = (B[2] - lower[2] * d_prime[1]) / w
= (-6 - 2 * 7/3) / (16/3)
= (-32/3) / (16/3) = -2
消元后行2: M2 + (3/16)·M3 = -2
即: c_prime[2] = 3/16, d_prime[2] = -2
矩阵更新:
M0 M1 M2 M3 | B
┌ ┬ ┐
│ 1 0 0 0 │ 0 │ ← 完成
│ 0 1 1/3 0 │ 7/3 │ ← 完成
│ 0 0 1 3/16 │ -2 │ ← 完成(M1 被消掉,M2 系数归一)
│ 0 0 0 1 │ 0 │
└ ┴ ┘
步骤 4:行3 边界,直接得到答案
行3 是 1·M3 = 0,已经是标准式。c_prime[3] = 0,d_prime[3] = 0。
M0 M1 M2 M3 | B
┌ ┬ ┐
│ 1 0 0 0 │ 0 │
│ 0 1 1/3 0 │ 7/3 │
│ 0 0 1 3/16 │ -2 │
│ 0 0 0 1 │ 0 │ ← 完成
└ ┴ ┘
消元结束。矩阵变成了上三角,每行形如 M[i] + c_prime[i]·M[i+1] = d_prime[i]。
步骤 5:赶——从下往上回代
M3 = d_prime[3] = 0 // 最后一行直接出答案
M2 = d_prime[2] - c_prime[2] * M3 = -2 - (3/16)*0 = -2
M1 = d_prime[1] - c_prime[1] * M2 = 7/3 - (1/3)*(-2) = 3
M0 = d_prime[0] - c_prime[0] * M1 = 0 - 0*3 = 0
结果:M = [0, 3, -2, 0],验算一下和直接解方程组一致。
三、套路总结——记住这两步就够了
整个算法就两个循环,对应手算的两个阶段:
追(前向消元) 赶(后向回代)
───────────── ─────────────
i=1→n-1 逐行消: i=n-2→0 倒着解:
1. 算消元后的新主对角元 M[n-1] = d_prime[n-1]
w = main[i] - lower[i] * c'[i-1]
M[i] = d_prime[i]
2. 取倒数 m = 1/w - c_prime[i] * M[i+1]
3. c'[i] = upper[i] * m
(归一化后的上对角元)
4. d'[i] = (B[i] - lower[i] * d'[i-1]) * m
(归一化后的右端项)
为什么 lower[i] * c_prime[i-1]? 上一行消元后,M[i] 在上一行的系数就是 c_prime[i-1](因为上一行变成 M[i-1] + c'[i-1]·M[i] = d'[i-1]),乘以本行 M[i-1] 的系数 lower[i],就是从上一行带来的"牵连",扣掉它本行就独立了。
为什么 B[i] 要减去 lower[i] * d_prime[i-1]? 同样的牵连,右端项也得跟着减。
四、C++ 代码——直接抄
4.1 通用模板
// 输入: lower, main_diag, upper, B (四个数组,长度 n)
// 输出: M (长度 n)
std::vector<double> c_prime(n, 0.0); // c_prime[0] = 0
std::vector<double> d_prime(n, 0.0); // d_prime[0] = 0
std::vector<double> M(n, 0.0);
// ===== 追 =====
for (int i = 1; i < n; ++i) {
double w = main_diag[i] - lower[i] * c_prime[i - 1];
double m = 1.0 / w; // 取倒数,后续乘
c_prime[i] = upper[i] * m;
d_prime[i] = (B[i] - lower[i] * d_prime[i - 1]) * m;
}
// ===== 赶 =====
M[n - 1] = d_prime[n - 1];
for (int i = n - 2; i >= 0; --i) {
M[i] = d_prime[i] - c_prime[i] * M[i + 1];
}
就这 10 行。任何时候碰到三对角方程组,把四个数组填好,这段代码直接贴进去就能用。
4.2 三次样条中的完整用法
来自 spline1d.cpp,多了构建矩阵和反推系数:
// === 1. 构建三对角矩阵 ===
std::vector<double> h(n - 1);
for (int i = 0; i < n - 1; ++i)
h[i] = s[i + 1] - s[i];
std::vector<double> lower(n, 0.0), main_diag(n, 1.0), upper(n, 0.0), B(n, 0.0);
for (int i = 1; i < n - 1; ++i) {
lower[i] = h[i - 1];
main_diag[i] = 2.0 * (h[i - 1] + h[i]);
upper[i] = h[i];
B[i] = 6.0 * ((y[i+1] - y[i]) / h[i] - (y[i] - y[i-1]) / h[i-1]);
}
// === 2. 追赶法 ===
std::vector<double> c_prime(n, 0.0), d_prime(n, 0.0), M(n, 0.0);
for (int i = 1; i < n; ++i) {
double m = 1.0 / (main_diag[i] - lower[i] * c_prime[i - 1]);
c_prime[i] = upper[i] * m;
d_prime[i] = (B[i] - d_prime[i - 1] * lower[i]) * m;
}
M[n - 1] = d_prime[n - 1];
for (int i = n - 2; i >= 0; --i)
M[i] = d_prime[i] - c_prime[i] * M[i + 1];
// === 3. 反推三次多项式系数 ===
for (int i = 0; i < n - 1; ++i) {
a_[i] = (M[i+1] - M[i]) / (6.0 * h[i]);
b_[i] = M[i] * 0.5;
c_[i] = (y[i+1] - y[i]) / h[i] - h[i] * (2.0 * M[i] + M[i+1]) / 6.0;
d_[i] = y[i];
}
五、一句话记住
追就是一行一行往下消,每行扣掉上一行带来的牵连,然后把主对角元归一化;赶就是从最后一行倒着往上代。代码就两个 for 循环,10 行。
更多推荐
所有评论(0)