用 C++ 实现追赶法(Thomas 算法)

一、什么是追赶法?

追赶法是专门用来解三对角线性方程组的高斯消元法。什么是三对角?就是只有三条对角线上有数,其余全是 0:

         upper →
         ┌                     ┐
         │ m₀  u₀   0    0   0 │
  lower  │ l₁  m₁   u₁   0   0 │
    ↓    │ 0   l₂   m₂   u₂  0 │
         │ 0   0    l₃   m₃  u₃│
         │ 0   0    0    l₄  m₄│
         └                     ┘
           main_diag →

把这个结构存成四个数组:

lower     = [*,  l1, l2, l3, l4]   // 下对角线(i 行对应 M[i-1] 的系数)
main_diag = [m0, m1, m2, m3, m4]   // 主对角线(i 行对应 M[i] 的系数)
upper     = [u0, u1, u2, u3, * ]   // 上对角线(i 行对应 M[i+1] 的系数)
B         = [b0, b1, b2, b3, b4]   // 右端项

第 i 行方程就是:lower[i] * M[i-1] + main_diag[i] * M[i] + upper[i] * M[i+1] = B[i]

对于边界行(i=0 和 i=n-1),通常 main_diag = 1,其余为 0,直接确定边界值。


二、手算演示——跟着增广矩阵一步步消

拿一个具体例子,自然样条边界条件(两端已知 M0=0, M3=0),矩阵自己预先算好了:

lower     = [0,  1,  2,  0]
main_diag = [1,  6,  6,  1]
upper     = [0,  2,  1,  0]
B         = [0, 14, -6,  0]

步骤 0:写出增广矩阵 [A | B]

  M0  M1  M2  M3  |  B
┌                   ┬    ┐
│ 1   0   0   0    │  0  │  行0(边界,M0 已确定)
│ 1   6   2   0    │ 14  │  行1
│ 0   2   6   1    │ -6  │  行2
│ 0   0   0   1    │  0  │  行3(边界,M3 已确定)
└                   ┴    ┘

步骤 1:行0 已经是标准式,直接用

行0 是 1·M0 + 0·M1 = 0,不用消。记 c_prime[0] = 0d_prime[0] = 0

消元后矩阵(行0 不动):

  M0  M1  M2  M3  |   B
┌                   ┬      ┐
│ 1   0   0   0    │   0   │ ← 完成
│ 1   6   2   0    │  14   │
│ 0   2   6   1    │  -6   │
│ 0   0   0   1    │   0   │
└                   ┴      ┘

步骤 2:消行1 —— 干掉 M0

行1 是 1·M0 + 6·M1 + 2·M2 = 14

利用行0 的关系 M0 = 0,代进去:6·M1 + 2·M2 = 14

把 M1 系数归一化(除以 6 = main_diag[1] - lower[1]*c_prime[0]):

消元后行1:  M1 + (1/3)·M2 = 7/3

即: c_prime[1] = 1/3,  d_prime[1] = 7/3

矩阵更新:

  M0  M1  M2  M3  |   B
┌                   ┬      ┐
│ 1   0    0   0   │   0   │ ← 完成
│ 0   1   1/3  0   │  7/3  │ ← 完成(M0 被消掉,M1 系数归一)
│ 0   2    6   1   │  -6   │
│ 0   0    0   1   │   0   │
└                   ┴      ┘

步骤 3:消行2 —— 干掉 M1

行2 是 2·M1 + 6·M2 + 1·M3 = -6

利用行1 的关系 M1 = 7/3 - (1/3)·M2,代入行2 消掉 M1。

等价的做法(代码里就是这么写的):

w = main_diag[2] - lower[2] * c_prime[1]
  = 6 - 2 * (1/3) = 16/3

新系数 = upper[2] / w = 1 / (16/3) = 3/16
新右端 = (B[2] - lower[2] * d_prime[1]) / w
       = (-6 - 2 * 7/3) / (16/3)
       = (-32/3) / (16/3) = -2
消元后行2:  M2 + (3/16)·M3 = -2

即: c_prime[2] = 3/16,  d_prime[2] = -2

矩阵更新:

  M0  M1  M2   M3  |   B
┌                    ┬      ┐
│ 1   0   0    0    │   0   │ ← 完成
│ 0   1  1/3   0    │  7/3  │ ← 完成
│ 0   0   1   3/16  │  -2   │ ← 完成(M1 被消掉,M2 系数归一)
│ 0   0   0    1    │   0   │
└                    ┴      ┘

步骤 4:行3 边界,直接得到答案

行3 是 1·M3 = 0,已经是标准式。c_prime[3] = 0d_prime[3] = 0

  M0  M1  M2  M3  |   B
┌                   ┬      ┐
│ 1   0   0   0    │   0   │
│ 0   1  1/3  0    │  7/3  │
│ 0   0   1  3/16  │  -2   │
│ 0   0   0   1    │   0   │ ← 完成
└                   ┴      ┘

消元结束。矩阵变成了上三角,每行形如 M[i] + c_prime[i]·M[i+1] = d_prime[i]

步骤 5:赶——从下往上回代

M3 = d_prime[3] = 0                                    // 最后一行直接出答案

M2 = d_prime[2] - c_prime[2] * M3 = -2 - (3/16)*0 = -2

M1 = d_prime[1] - c_prime[1] * M2 = 7/3 - (1/3)*(-2) = 3

M0 = d_prime[0] - c_prime[0] * M1 = 0 - 0*3 = 0

结果:M = [0, 3, -2, 0],验算一下和直接解方程组一致。


三、套路总结——记住这两步就够了

整个算法就两个循环,对应手算的两个阶段:

追(前向消元)                         赶(后向回代)
─────────────                         ─────────────
i=1→n-1 逐行消:                      i=n-2→0 倒着解:

1. 算消元后的新主对角元                 M[n-1] = d_prime[n-1]
   w = main[i] - lower[i] * c'[i-1]
                                      M[i] = d_prime[i]
2. 取倒数 m = 1/w                            - c_prime[i] * M[i+1]

3. c'[i] = upper[i] * m
   (归一化后的上对角元)

4. d'[i] = (B[i] - lower[i] * d'[i-1]) * m
   (归一化后的右端项)

为什么 lower[i] * c_prime[i-1] 上一行消元后,M[i] 在上一行的系数就是 c_prime[i-1](因为上一行变成 M[i-1] + c'[i-1]·M[i] = d'[i-1]),乘以本行 M[i-1] 的系数 lower[i],就是从上一行带来的"牵连",扣掉它本行就独立了。

为什么 B[i] 要减去 lower[i] * d_prime[i-1] 同样的牵连,右端项也得跟着减。


四、C++ 代码——直接抄

4.1 通用模板

// 输入: lower, main_diag, upper, B (四个数组,长度 n)
// 输出: M (长度 n)

std::vector<double> c_prime(n, 0.0);  // c_prime[0] = 0
std::vector<double> d_prime(n, 0.0);  // d_prime[0] = 0
std::vector<double> M(n, 0.0);

// ===== 追 =====
for (int i = 1; i < n; ++i) {
    double w = main_diag[i] - lower[i] * c_prime[i - 1];
    double m = 1.0 / w;                          // 取倒数,后续乘
    c_prime[i] = upper[i] * m;
    d_prime[i] = (B[i] - lower[i] * d_prime[i - 1]) * m;
}

// ===== 赶 =====
M[n - 1] = d_prime[n - 1];
for (int i = n - 2; i >= 0; --i) {
    M[i] = d_prime[i] - c_prime[i] * M[i + 1];
}

就这 10 行。任何时候碰到三对角方程组,把四个数组填好,这段代码直接贴进去就能用。

4.2 三次样条中的完整用法

来自 spline1d.cpp,多了构建矩阵和反推系数:

// === 1. 构建三对角矩阵 ===
std::vector<double> h(n - 1);
for (int i = 0; i < n - 1; ++i)
    h[i] = s[i + 1] - s[i];

std::vector<double> lower(n, 0.0), main_diag(n, 1.0), upper(n, 0.0), B(n, 0.0);
for (int i = 1; i < n - 1; ++i) {
    lower[i]     = h[i - 1];
    main_diag[i] = 2.0 * (h[i - 1] + h[i]);
    upper[i]     = h[i];
    B[i]         = 6.0 * ((y[i+1] - y[i]) / h[i] - (y[i] - y[i-1]) / h[i-1]);
}

// === 2. 追赶法 ===
std::vector<double> c_prime(n, 0.0), d_prime(n, 0.0), M(n, 0.0);
for (int i = 1; i < n; ++i) {
    double m = 1.0 / (main_diag[i] - lower[i] * c_prime[i - 1]);
    c_prime[i] = upper[i] * m;
    d_prime[i] = (B[i] - d_prime[i - 1] * lower[i]) * m;
}
M[n - 1] = d_prime[n - 1];
for (int i = n - 2; i >= 0; --i)
    M[i] = d_prime[i] - c_prime[i] * M[i + 1];

// === 3. 反推三次多项式系数 ===
for (int i = 0; i < n - 1; ++i) {
    a_[i] = (M[i+1] - M[i]) / (6.0 * h[i]);
    b_[i] = M[i] * 0.5;
    c_[i] = (y[i+1] - y[i]) / h[i] - h[i] * (2.0 * M[i] + M[i+1]) / 6.0;
    d_[i] = y[i];
}

五、一句话记住

就是一行一行往下消,每行扣掉上一行带来的牵连,然后把主对角元归一化;就是从最后一行倒着往上代。代码就两个 for 循环,10 行。

更多推荐