Home 线段树和树状数组
Post
Cancel

线段树和树状数组

线段树

线段树是用来维护区间信息的数据结构.

线段树可以在 $O(\log{n})$ 的时间复杂度内实现单点修改、区间修改、区间查询(区间求和,求区间最大值,求区间最小值)等操作。

线段树将每个长度不为 1 的区间划分成左右两个区间递归求解,把整个线段划分为一个树形结构,通过合并左右两区间信息来求得该区间的信息。这种数据结构可以方便地进行大部分的区间操作。

有个大小为 5 的数组 a={10,11,12,13,14},要将其转化为线段树,有以下做法:设线段树的根节点编号为 ,用数组 d 来保存线段树,d[i] 用来保存线段树上编号为 i 的节点的值(这里每个节点所维护的值就是这个节点所表示的区间总和)。

线段树

因此, d[i] 的左孩子就是 d[i*2],右孩子就是 d[i*2+1],父节点就是 d[i/2]

如果 d[i] 表示的是区间 [s,t], 那么左孩子 d[i*2] 表示的是区间 [s,(s+t)/2],右孩子 d[i*2+1] 表示的是区间 [(s+t)/2+1,t]

在实现时,我们考虑递归建树。设当前的根节点为 p,如果根节点管辖的区间长度已经是 1,则可以直接根据 a 数组上相应位置的值初始化该节点。否则我们将该区间从中点处分割为两个子区间,分别进入左右子节点递归建树,最后合并两个子节点的信息。

1
2
3
4
5
6
7
8
9
10
11
12
13
void build(int s, int t, int p) {
  // 对 [s,t] 区间建立线段树,当前根的编号为 p
  if (s == t) {
    d[p] = a[s];
    return;
  }
  int m = s + ((t - s) >> 1);
  // 移位运算符的优先级小于加减法,所以加上括号
  // 如果写成 (s + t) >> 1 可能会超出 int 范围
  build(s, m, p * 2), build(m + 1, t, p * 2 + 1);
  // 递归对左右区间建树
  d[p] = d[p * 2] + d[(p * 2) + 1];
}

如果采用上述的堆式存储, 共有 n 个叶子结点, 那么线段树需要的存储空间是 $2^{\lceil\log{n}+1\rceil}$.

推导: 线段树的深度是 $O(\lceil\log{n}\rceil)$ (这里说的是深度不是高度, 而高度, 或者说层数, 应该是深度+1), 而且还是一棵完全二叉树. 我们暂且将它补成一棵满二叉树, 也就是添加上一些没有实际用处的叶子结点, 让叶子结点个数变为 $2^{\lceil\log{n}\rceil}$, 那么此时这棵满二叉树的总结点数就是 $2^{\lceil\log{n}+1\rceil}-1$. 为了方便可以直接把数组长度设为 $4n$, 因为 $\dfrac{2^{\lceil\log{n}+1\rceil}-1}{n}$ 的最大值在 $n=2^x+1,x\in N^+$ 时取到, 此时的结点个数为 $2^{\lceil\log{n}+1\rceil}-1=2^{x+1+1}-1=4\cdot 2^x-1=4(n-1)-1=4n-5$.

区间查询

包括求区间和、区间最值. 每个区间都必定能够拆分为几个子区间, 如果是求区间和就是不断二分找到这些子区间, 并将子区间的和加起来; 如果是求区间最值就是将各个子区间的最值进行比较, 最终的最值就是查询区间的最值.

以查询区间和为例, 代码如下:

1
2
3
4
5
6
7
8
9
10
11
int getsum(int l, int r, int s, int t, int p) {
  // [l, r] 为查询区间, [s, t] 为当前节点包含的区间, p 为当前节点的编号
  if (l <= s && t <= r)
    return d[p];  // 当前区间为询问区间的子集时直接返回当前区间的和
  int m = s + ((t - s) >> 1), sum = 0;
  if (l <= m) sum += getsum(l, r, s, m, p * 2);
  // 如果左儿子代表的区间 [s, m] 与询问区间有交集, 则递归查询左儿子
  if (r > m) sum += getsum(l, r, m + 1, t, p * 2 + 1);
  // 如果右儿子代表的区间 [m + 1, t] 与询问区间有交集, 则递归查询右儿子
  return sum;
}

区间修改与懒惰标记

如果要求修改区间 [l,r],把所有包含在区间 [l,r] 中的节点都遍历一次、修改一次,时间复杂度无法承受。我们这里要引入一个叫做「懒惰标记」的东西。

它的工作方式是这样的: 例如我们需要对区间 [l,r] 进行加 a 的操作, 并且 [l,r] 可以被拆分为 [l,l][l+1,r] 两个子区间. 我们先只对 [l,l][l+1,r] 更新懒惰标记, 即只把这两个区间对应的结点的 tag 值修改为区间的变化量. 「懒惰」就体现在, 对于区间 [l+1,r] 的所有子区间, 我们是不会在这个时候去更新它们的值的, 因为目前不确定它的那些子区间是否会被用到. 如果下一次查询操作需要查询到 [l+1,r] 的某个子区间时, 我们在搜索到这个子区间的过程中一定会经过 [l+1,r], 因此会发现它的 tag 值需要被更新. 这个时候, 就会进行「标记下放」操作, 也就是清除 [l+1,r] 的标记, 给它的子区间更新标记, 注意这里也只用更新到目标子区间那一层.

区间修改 (加上某个值):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
void update(int l, int r, int c, int s, int t, int p) {
  // [l, r] 为修改区间, c 为被修改的元素的变化量, [s, t] 为当前节点包含的区间, p
  // 为当前节点的编号
  if (l <= s && t <= r) {
    d[p] += (t - s + 1) * c, b[p] += c;
    return;
  }  // 当前区间为修改区间的子集时直接修改当前节点的值,然后打标记,结束修改
  int m = s + ((t - s) >> 1);
  if (b[p]) {
    // 如果当前节点的懒标记非空,则更新当前节点两个子节点的值和懒标记值
    d[p * 2] += b[p] * (m - s + 1), d[p * 2 + 1] += b[p] * (t - m);
    b[p * 2] += b[p], b[p * 2 + 1] += b[p];  // 将标记下传给子节点
    b[p] = 0;                                // 清空当前节点的标记
  }
  if (l <= m) update(l, r, c, s, m, p * 2);
  if (r > m) update(l, r, c, m + 1, t, p * 2 + 1);
  d[p] = d[p * 2] + d[p * 2 + 1];
}

区间查询:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
int getsum(int l, int r, int s, int t, int p) {
  // [l, r] 为查询区间, [s, t] 为当前节点包含的区间, p 为当前节点的编号
  if (l <= s && t <= r) return d[p];
  // 当前区间为询问区间的子集时直接返回当前区间的和
  int m = s + ((t - s) >> 1);
  if (b[p]) {
    // 如果当前节点的懒标记非空,则更新当前节点两个子节点的值和懒标记值
    d[p * 2] += b[p] * (m - s + 1), d[p * 2 + 1] += b[p] * (t - m);
    b[p * 2] += b[p], b[p * 2 + 1] += b[p];  // 将标记下传给子节点
    b[p] = 0;                                // 清空当前节点的标记
  }
  int sum = 0;
  if (l <= m) sum = getsum(l, r, s, m, p * 2);
  if (r > m) sum += getsum(l, r, m + 1, t, p * 2 + 1);
  return sum;
}

区间修改 (修改为某个值):

和刚才的区别就是, 存储线段树结点的数组 d 和懒标记数组 b 都是直接等于而不是加上相应的值. 这里需要注意的是, 如果懒标记 b 被修改为 0 了, 那还怎么确定这个懒标记是否被修改过? 因此, 又引入一个 v 数组记录懒标记是否被修改. 实际上, 在刚才的情形中, 如果加上的值有正有负, 也需要进行这样的记录. 总之, 根据具体题目要求而定.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
void update(int l, int r, int c, int s, int t, int p) {
  if (l <= s && t <= r) {
    d[p] = (t - s + 1) * c, b[p] = c;
    return;
  }
  int m = s + ((t - s) >> 1);
  // 额外数组储存是否修改值
  if (v[p]) {
    d[p * 2] = b[p] * (m - s + 1), d[p * 2 + 1] = b[p] * (t - m);
    b[p * 2] = b[p * 2 + 1] = b[p];
    v[p * 2] = v[p * 2 + 1] = 1;
    v[p] = 0;
  }
  if (l <= m) update(l, r, c, s, m, p * 2);
  if (r > m) update(l, r, c, m + 1, t, p * 2 + 1);
  d[p] = d[p * 2] + d[p * 2 + 1];
}

int getsum(int l, int r, int s, int t, int p) {
  if (l <= s && t <= r) return d[p];
  int m = s + ((t - s) >> 1);
  if (v[p]) {
    d[p * 2] = b[p] * (m - s + 1), d[p * 2 + 1] = b[p] * (t - m);
    b[p * 2] = b[p * 2 + 1] = b[p];
    v[p * 2] = v[p * 2 + 1] = 1;
    v[p] = 0;
  }
  int sum = 0;
  if (l <= m) sum = getsum(l, r, s, m, p * 2);
  if (r > m) sum += getsum(l, r, m + 1, t, p * 2 + 1);
  return sum;
}

一些优化

总结一些线段树的优化:

  • 在叶子节点处无需下放懒惰标记,所以懒惰标记可以不下传到叶子节点。
  • 下放懒惰标记可以写一个专门的函数 pushdown,从儿子节点更新当前节点也可以写一个专门的函数 maintain(或者对称地用 pushup),降低代码编写难度。
  • 标记永久化:如果确定懒惰标记不会在中途被加到溢出(即超过了该类型数据所能表示的最大范围),那么就可以将标记永久化。标记永久化可以避免下传懒惰标记,只需在进行询问时把标记的影响加到答案当中,从而降低程序常数。具体如何处理与题目特性相关,需结合题目来写。这也是树套树和可持久化数据结构中会用到的一种技巧。

树状数组

树状数组

树状数组的结构和线段树有些类似:用一个大节点表示一些小节点的信息,进行查询的时候只需要查询一些大节点而不是所有的小节点。

树状数组有趣的地方在于, 由它的结点编号可以推出它管理的是原数组 a 的哪些元素, 而线段树没有这个特性, 还需要对每个结点所表示的区间加以记录.

编号为 i 的结点 $c_i$ 所代表的区间就是 [i-lowbit(i)+1,i]. 因此, 这个结点所代表区间的长度就是 lowbit(i).

树状数组最重要的函数就是 lowbit.

1
#define lowbit(x) ((x)&(-x))

O(n) 建树

1
2
3
4
5
6
7
void init() {
  for (int i = 1; i <= n; ++i) {
    t[i] += a[i];
    int j = i + lowbit(i);
    if (j <= n) t[j] += t[i];
  }
}

单点修改

只需要更新所有的上级:

例如结点 1011000 的区间长度为 8, 这 8 个数也包含在结点 1100000 中, 因为结点 1100000 表示的区间是 [1000001,1100000], 而 1011000 表示的区间是 [1010000,1011000], 是它的子区间 (其实是因为 1100000=lowbit(1011000)).

1
2
3
4
5
6
void add(int x, int k) {
  while (x <= n) {  // 不能越界
    c[x] = c[x] + k;
    x = x + lowbit(x);
  }
}

前缀求和:

1
2
3
4
5
6
7
8
int getsum(int x) {  // a[1]..a[x]的和
  int ans = 0;
  while (x >= 1) {
    ans = ans + c[x];
    x = x - lowbit(x);
  }
  return ans;
}

区间求和

众所周知, a 的差分数组 b 求出的长度为 k 的前缀和其实就是 a 的元素 a[k], 但是怎么求出 $\sum_{i=1}^k{a[i]}$ 呢?

手动推导一下:

\[\begin{eqnarray} \label{eq} \sum_{i=1}^k{a[i]}&=&\sum_{i=1}^k{\sum_{j=1}^i{b[i]}} \nonumber \\ ~&=&k\cdot b[1]+(k-1)\cdot b[2]+\dots+1\cdot b[k] \nonumber \\ ~&=&\sum_{i=1}^k(k-i+1)\cdot b[i] \nonumber \\ ~&=&(k+1)\cdot \sum_{i=1}^kb[i]-\sum_{i=1}^ki\cdot b[i] \end{eqnarray}\]

因此只需要用两个树状数组分别维护 $b[i]$ 和 $i\cdot b[i]$ 即可.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
int t1[MAXN], t2[MAXN], n;

inline int lowbit(int x) { return x & (-x); }

void add(int k, int v) {
  int v1 = k * v;
  while (k <= n) {
    t1[k] += v, t2[k] += v1;
    k += lowbit(k);
  }
}

int getsum(int *t, int k) {
  int ret = 0;
  while (k) {
    ret += t[k];
    k -= lowbit(k);
  }
  return ret;
}

void add1(int l, int r, int v) {
  add(l, v), add(r + 1, -v);  // 将区间加差分为两个前缀加
}

long long getsum1(int l, int r) {
  return (r + 1ll) * getsum(t1, r) - 1ll * l * getsum(t1, l - 1) -
         (getsum(t2, r) - getsum(t2, l - 1));
}
This post is licensed under CC BY 4.0 by the author.