别再暴力求解了!用Python实现线段树,轻松搞定LeetCode区间查询难题
用Python线段树征服LeetCode区间查询问题的实战指南
在算法面试和竞赛中,区间查询问题就像一位不速之客——它总在你最紧张的时候出现,用O(n)的暴力解法羞辱你的代码。但今天,我们要用Python的线段树给这位"客人"一点颜色看看。
1. 从LeetCode实战认识线段树的威力
LeetCode 307题"区域和检索 - 数组可修改"是个典型的拦路虎。题目要求我们实现一个类,能够高效处理两种操作:更新数组中的某个元素,以及查询数组中某个区间的和。面对这样的问题,未经训练的开发者往往会陷入暴力求解的陷阱。
暴力解法的致命缺陷 :
class NumArray:
def __init__(self, nums):
self.nums = nums.copy()
def update(self, index, val):
self.nums[index] = val
def sumRange(self, left, right):
return sum(self.nums[left:right+1])
这个解法虽然简单直接,但每次查询都需要O(n)时间。当操作次数达到10^5量级时,这种解法在LeetCode上必然会超时。
而线段树解法则能将两种操作的时间复杂度都优化到O(logn):
class SegmentTreeNode:
def __init__(self, l, r):
self.l = l
self.r = r
self.left = None
self.right = None
self.sum = 0
class NumArray:
def __init__(self, nums):
def build(l, r):
node = SegmentTreeNode(l, r)
if l == r:
node.sum = nums[l]
return node
mid = (l + r) // 2
node.left = build(l, mid)
node.right = build(mid+1, r)
node.sum = node.left.sum + node.right.sum
return node
self.root = build(0, len(nums)-1)
def update(self, index, val):
def update_val(node):
if node.l == node.r:
node.sum = val
return
if index <= node.left.r:
update_val(node.left)
else:
update_val(node.right)
node.sum = node.left.sum + node.right.sum
update_val(self.root)
def sumRange(self, left, right):
def query(node, l, r):
if node.r < l or node.l > r:
return 0
if l <= node.l and node.r <= r:
return node.sum
return query(node.left, l, r) + query(node.right, l, r)
return query(self.root, left, right)
2. 线段树的构建艺术
线段树的核心思想是分治——将整个区间不断二分,直到每个子区间只包含一个元素。这种结构天然适合用递归来实现。
构建过程的三个关键点 :
- 节点设计 :每个节点需要存储区间范围[l, r]和该区间的聚合值(如和、最大值等)
- 递归终止条件 :当区间长度为1时,即为叶子节点
- 合并子节点信息 :父节点的值由其左右子节点的值聚合而来
优化存储的实用技巧 :
- 对于静态线段树(区间不变化),可以用数组模拟二叉树
- 对于n个元素的数组,线段树的空间复杂度为O(4n)
- 预先分配数组空间比动态创建节点更高效
class ArraySegmentTree:
def __init__(self, data):
self.n = len(data)
self.size = 1
while self.size < self.n:
self.size <<= 1
self.tree = [0] * (2 * self.size)
# 初始化叶子节点
for i in range(self.n):
self.tree[self.size + i] = data[i]
# 构建内部节点
for i in range(self.size - 1, 0, -1):
self.tree[i] = self.tree[2*i] + self.tree[2*i+1]
3. 区间查询的优雅实现
线段树的查询操作体现了分治思想的精髓。当查询区间完全覆盖当前节点区间时,直接返回节点值;当没有重叠时,返回空值;部分重叠时则继续向左右子树查询。
查询优化的关键点 :
- 利用区间覆盖判断避免不必要的递归
- 合并左右子树的查询结果
- 对于求和查询,空区间返回0;对于最大值查询,返回-∞
def query_sum(self, l, r):
res = 0
l += self.size
r += self.size
while l <= r:
if l % 2 == 1:
res += self.tree[l]
l += 1
if r % 2 == 0:
res += self.tree[r]
r -= 1
l = l // 2
r = r // 2
return res
性能对比 :
| 操作类型 | 暴力解法 | 线段树解法 |
|---|---|---|
| 单点更新 | O(1) | O(logn) |
| 区间查询 | O(n) | O(logn) |
| 空间占用 | O(n) | O(4n) |
4. 延迟更新:线段树的进阶技巧
当面对区间更新操作时,朴素实现需要对每个受影响的节点进行更新,时间复杂度退化为O(n)。延迟标记(Lazy Propagation)技术解决了这一痛点。
延迟标记的工作原理 :
- 当更新完全覆盖某个节点区间时,先更新该节点并打上标记
- 只有当需要访问该节点的子节点时,才将标记下传
- 标记下传后清除当前节点的标记
class LazySegmentTree:
def __init__(self, data):
self.n = len(data)
self.size = 1
while self.size < self.n:
self.size <<= 1
self.tree = [0] * (2 * self.size)
self.lazy = [0] * (2 * self.size)
for i in range(self.n):
self.tree[self.size + i] = data[i]
for i in range(self.size - 1, 0, -1):
self.tree[i] = self.tree[2*i] + self.tree[2*i+1]
def push(self, node, node_l, node_r):
if self.lazy[node] != 0:
mid = (node_l + node_r) // 2
# 更新左子节点
self.tree[2*node] += self.lazy[node] * (mid - node_l + 1)
self.lazy[2*node] += self.lazy[node]
# 更新右子节点
self.tree[2*node+1] += self.lazy[node] * (node_r - mid)
self.lazy[2*node+1] += self.lazy[node]
# 清除当前节点标记
self.lazy[node] = 0
def range_add(self, l, r, val):
self._range_add(1, 0, self.size-1, l, r, val)
def _range_add(self, node, node_l, node_r, l, r, val):
if r < node_l or node_r < l:
return
if l <= node_l and node_r <= r:
self.tree[node] += val * (node_r - node_l + 1)
self.lazy[node] += val
return
self.push(node, node_l, node_r)
mid = (node_l + node_r) // 2
self._range_add(2*node, node_l, mid, l, r, val)
self._range_add(2*node+1, mid+1, node_r, l, r, val)
self.tree[node] = self.tree[2*node] + self.tree[2*node+1]
5. 线段树在LeetCode中的实战应用
线段树不仅能解决区间求和问题,还能高效处理各种区间统计问题。以下是几个典型应用场景:
区间最值问题(LeetCode 239) :
class MaxSegmentTree:
def __init__(self, data):
self.n = len(data)
self.size = 1
while self.size < self.n:
self.size <<= 1
self.tree = [-float('inf')] * (2 * self.size)
for i in range(self.n):
self.tree[self.size + i] = data[i]
for i in range(self.size - 1, 0, -1):
self.tree[i] = max(self.tree[2*i], self.tree[2*i+1])
def query_max(self, l, r):
res = -float('inf')
l += self.size
r += self.size
while l <= r:
if l % 2 == 1:
res = max(res, self.tree[l])
l += 1
if r % 2 == 0:
res = max(res, self.tree[r])
r -= 1
l = l // 2
r = r // 2
return res
区间染色问题(LeetCode 699) :
class ColorSegmentTree:
def __init__(self, size):
self.size = 1
while self.size < size:
self.size <<= 1
self.tree = [0] * (2 * self.size)
self.lazy = [0] * (2 * self.size)
def push(self, node, node_l, node_r):
if self.lazy[node] != 0:
self.tree[2*node] = self.lazy[node]
self.lazy[2*node] = self.lazy[node]
self.tree[2*node+1] = self.lazy[node]
self.lazy[2*node+1] = self.lazy[node]
self.lazy[node] = 0
def range_set(self, l, r, val):
self._range_set(1, 0, self.size-1, l, r, val)
def _range_set(self, node, node_l, node_r, l, r, val):
if r < node_l or node_r < l:
return
if l <= node_l and node_r <= r:
self.tree[node] = val
self.lazy[node] = val
return
self.push(node, node_l, node_r)
mid = (node_l + node_r) // 2
self._range_set(2*node, node_l, mid, l, r, val)
self._range_set(2*node+1, mid+1, node_r, l, r, val)
self.tree[node] = max(self.tree[2*node], self.tree[2*node+1])
def query_max(self, l, r):
return self._query_max(1, 0, self.size-1, l, r)
def _query_max(self, node, node_l, node_r, l, r):
if r < node_l or node_r < l:
return -float('inf')
if l <= node_l and node_r <= r:
return self.tree[node]
self.push(node, node_l, node_r)
mid = (node_l + node_r) // 2
return max(self._query_max(2*node, node_l, mid, l, r),
self._query_max(2*node+1, mid+1, node_r, l, r))
6. 线段树的扩展应用与优化
动态开点线段树 :当区间范围很大但实际使用的点很少时(如[1, 1e9]),传统的线段树会浪费大量空间。动态开点线段树只在需要时才创建节点。
class DynamicSegmentTreeNode:
def __init__(self, l, r):
self.l = l
self.r = r
self.left = None
self.right = None
self.sum = 0
self.lazy = 0
class DynamicSegmentTree:
def __init__(self):
self.root = DynamicSegmentTreeNode(0, int(1e9))
def push(self, node):
if node.lazy != 0 and node.l != node.r:
mid = (node.l + node.r) // 2
if not node.left:
node.left = DynamicSegmentTreeNode(node.l, mid)
if not node.right:
node.right = DynamicSegmentTreeNode(mid+1, node.r)
node.left.sum += node.lazy * (mid - node.l + 1)
node.left.lazy += node.lazy
node.right.sum += node.lazy * (node.r - mid)
node.right.lazy += node.lazy
node.lazy = 0
def range_add(self, l, r, val):
self._range_add(self.root, l, r, val)
def _range_add(self, node, l, r, val):
if node.r < l or node.l > r:
return
if l <= node.l and node.r <= r:
node.sum += val * (node.r - node.l + 1)
node.lazy += val
return
self.push(node)
self._range_add(node.left, l, r, val)
self._range_add(node.right, l, r, val)
node.sum = (node.left.sum if node.left else 0) + (node.right.sum if node.right else 0)
二维线段树 :处理二维平面上的区间查询问题,通过嵌套线段树实现。虽然时间复杂度较高(O(log²n)),但在某些场景下仍然有用武之地。
线段树与其他数据结构的结合 :
- 线段树+离散化:处理坐标范围很大但实际点很少的问题
- 线段树+二分查找:快速定位满足特定条件的区间
- 线段树+扫描线:解决矩形面积并、周长并等问题
7. 线段树的常见陷阱与调试技巧
初学者常犯的错误 :
- 区间边界处理不当,导致无限递归或错误结果
- 忘记下传延迟标记,导致查询结果错误
- 空间分配不足,导致数组越界
- 聚合函数选择不当(如用求和线段树处理最大值问题)
调试线段树的实用方法 :
- 打印整棵树的结构,验证构建是否正确
- 对每个更新操作后检查相关节点的值
- 使用小规模测试用例手动验证
- 比较暴力解法和线段树解法的结果
def print_tree(tree):
for i in range(1, len(tree.tree)):
print(f"Node {i}: [{tree.tree[i].l}, {tree.tree[i].r}] = {tree.tree[i].sum}", end="")
if tree.lazy[i] != 0:
print(f" (lazy: {tree.lazy[i]})")
else:
print()
8. 线段树的性能优化实战
优化构建过程 :
- 使用迭代而非递归构建,避免递归深度过大
- 预先计算所有需要的节点,减少动态分配的开销
- 对于特定问题,可以省略不必要的节点属性
内存优化技巧 :
- 使用更紧凑的数据结构存储节点
- 对于布尔值属性,使用位运算压缩存储
- 在C++/Rust等语言中,可以使用内存池技术
查询优化 :
- 对于频繁查询的热点区间,可以缓存结果
- 根据查询模式调整树的结构(如偏向左侧或右侧)
- 使用非递归实现减少函数调用开销
class OptimizedSegmentTree:
def __init__(self, data):
self.n = len(data)
self.size = 1
while self.size < self.n:
self.size <<= 1
# 使用单个数组存储值和延迟标记
self.tree = [(0, 0)] * (2 * self.size) # (value, lazy)
# 初始化叶子节点
for i in range(self.n):
self.tree[self.size + i] = (data[i], 0)
# 构建内部节点
for i in range(self.size - 1, 0, -1):
self.tree[i] = (self.tree[2*i][0] + self.tree[2*i+1][0], 0)
def push(self, node, node_len):
val, lazy = self.tree[node]
if lazy != 0:
self.tree[2*node] = (self.tree[2*node][0] + lazy * (node_len // 2), self.tree[2*node][1] + lazy)
self.tree[2*node+1] = (self.tree[2*node+1][0] + lazy * (node_len // 2), self.tree[2*node+1][1] + lazy)
self.tree[node] = (val, 0)
def range_add(self, l, r, val):
l += self.size
r += self.size
l_len = 1
r_len = 1
while l <= r:
if l % 2 == 1:
self.tree[l] = (self.tree[l][0] + val * l_len, self.tree[l][1] + val)
l += 1
if r % 2 == 0:
self.tree[r] = (self.tree[r][0] + val * r_len, self.tree[r][1] + val)
r -= 1
l //= 2
r //= 2
l_len *= 2
r_len *= 2
def query_sum(self, l, r):
res = 0
l += self.size
r += self.size
# 需要先下传所有标记
self._push_all(l)
self._push_all(r)
while l <= r:
if l % 2 == 1:
res += self.tree[l][0]
l += 1
if r % 2 == 0:
res += self.tree[r][0]
r -= 1
l //= 2
r //= 2
return res
def _push_all(self, idx):
k = self.size
while idx > 1:
idx >>= 1
k >>= 1
self.push(idx, k)
更多推荐
所有评论(0)