线段树(Segment Tree)是一种用来维护区间的数据结构。

与树状数组相比,线段树可以实现时间复杂度在 O(logn)O(\log n) 级别的区间修改,还可以同时支持多种操作(加、乘、最值等)。

操作列表

  • 上传(pushup)
  • 建树(build)
  • 下放懒标记(pushdown)
  • 区间查询(query)
  • 区间修改(modify)

通用操作

存储线段树

线段树是一个典型的二叉树,因此我们可以使用一个数组来存储线段树。

分析:很容易就知道线段树的深度为 logn\lceil\log n\rceil ,可得线段树的节点个数为 2logn+112^{\left\lceil\log{n}\right\rceil+1}-1,粗略估计开大小为 4n4n 的数组即可(可以使用位运算写成 n << 2)。

struct node {
    int l, r;
    long long s, d;

    node() {
        l = r = s = d = 0;
    }
    node(int _l, int _r) {
        l = _l;
        r = _r;
        s = d = 0;
    }
} tr[100005 << 2];
变量名 用途
l 区间的左端点
r 区间的右端点
s 区间和
d 懒标记

上传(pushup)

之所以把上传放在建树前面说,是因为建树的时候要用到它。

/**
 * 上传信息
 * @param u 父节点下标
 */
inline void pushup(int u) {
    tr[u].s = tr[u << 1].s + tr[u << 1 | 1].s;
}

将两个子节点所代表的区间的和相加即为父区间的和。

建树(build)

/**
 * 建立线段树
 * @param u 根节点下标
 * @param l 左端点
 * @param r 右端点
 */
void build(int u, int l, int r) {
    tr[u] = node(l, r);
    if (l == r) {
        tr[u].s = a[l];
        return;
    }
    int mid = l + r >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
    pushup(u);
}

先初始化当前区间,接下来分两种情况:

  1. 若当前区间长度等于 1  (l=r)1\ \ (l = r) ,则直接将当前区间的区间和赋值为 a[l] 即可。
  2. 若当前区间长度大于 1  (l<r)1\ \ (l < r) ,则将区间平均分成两部分(即从 (l+r)/2\lfloor(l+r)/2\rfloor 处断开分为两个区间,可写作 l + r >> 1),继续向下递归建立左右子树即可。

需要注意的是两个子区间没有交集,因此左子树的左端点是 ll 、右端点是 midmid ,右子树的左端点是 mid+1mid+1 、右端点是 rr

区间查询(query)

/**
 * 区间查询
 * @param u 父节点
 * @param l 左端点
 * @param r 右端点
 */
long long query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) {  // 被包含直接返回当前区间和
        return tr[u].s;
    }
    int mid = tr[u].l + tr[u].r >> 1;
    long long s = 0;
    pushdown(u);                                // 下放懒标记
    if (l <= mid) s += query(u << 1, l, r);     // 和左侧有交集
    if (r > mid) s += query(u << 1 | 1, l, r);  // 和右侧有交集
    return s;
}
  1. 如果这个区间被包含,直接返回该区间的和。
  2. 如果和左儿子区间有交集,则继续向左儿子区间递归查询。
  3. 如果和右儿子区间有交集,则继续向右儿子区间递归查询。

需要注意的是在递归查询左右儿子区间之前要先下放懒标记(pushdown),否则会出问题。

区间加

本部分以 洛谷 P3372 【模板】线段树 1 为例子来简述一下线段树区间加的实现。

下放懒标记(pushdown)

/**
 * 下放懒标记
 * @param u 父节点下标
 */
inline void pushdown(int u) {
    if (!tr[u].d) return;
    // 处理左子树
    tr[u << 1].d += tr[u].d;
    tr[u << 1].s += (tr[u << 1].r - tr[u << 1].l + 1) * tr[u].d;
    // 处理右子树
    tr[u << 1 | 1].d += tr[u].d;
    tr[u << 1 | 1].s += (tr[u << 1 | 1].r - tr[u << 1 | 1].l + 1) * tr[u].d;
    // 清除懒标记
    tr[u].d = 0;
}

这部分代码其实很简单。

将左、右子树的懒标记加上父节点的懒标记,区间和加上 (rl+1)×d(r - l + 1)\times dr,lr, l 分别表示儿子区间的左、右端点,dd表示父节点的懒标记),最后清空父节点的懒标记即可。

区间修改(modify)

/**
 * 区间修改
 * @param u 父节点下标
 * @param l 左端点
 * @param r 右端点
 * @param d 增加的值
 */
void modify(int u, int l, int r, int d) {
    if (tr[u].l >= l && tr[u].r <= r) {  // 被包含直接修改
        tr[u].d += d;
        tr[u].s += (tr[u].r - tr[u].l + 1) * d;
        return;
    }
    int mid = tr[u].l + tr[u].r >> 1;
    pushdown(u);                               // 下放懒标记
    if (l <= mid) modify(u << 1, l, r, d);     // 和左侧有交集
    if (r > mid) modify(u << 1 | 1, l, r, d);  // 和右侧有交集
    pushup(u);                                 // 上传新信息
}

区间修改和区间查询的实现相似。

  1. 如果当前区间被包含,直接添加懒标记并修改区间和。
  2. 如果和左儿子区间有交集,则继续向左儿子区间递归修改。
  3. 如果和右儿子区间有交集,则继续向右儿子区间递归修改。

需要注意的是在递归修改左右儿子区间之前要先下放懒标记(pushdown),修改完成以后要上传新信息(pushup),否则会出问题。

区间加、乘

本部分以 洛谷 P3373 【模板】线段树 2 为例子来简述一下线段树区间加、乘的实现。

在编写之前,结构体中需要先添加一个乘法的懒标记 x ,并将其赋初值为 11 ,修改之后的结构体如下所示。

struct node {
    int l, r;
    long long s, d, x;

    node() {
        l = r = s = d = 0;
        x = 1;
    }
    node(int _l, int _r) {
        l = _l, r = _r;
        s = d = 0;
        x = 1;
    }
} tr[100005 << 2];

下放懒标记(pushdown)

/**
 * 下放懒标记
 * @param u 父节点下标
 * @attention 先乘后加
 */
void pushdown(int u) {
    // 左子树
    tr[u << 1].s = ((tr[u << 1].s * tr[u].x) + (tr[u << 1].r - tr[u << 1].l + 1) * tr[u].d) % p;
    tr[u << 1].x = tr[u << 1].x * tr[u].x % p;
    tr[u << 1].d = (tr[u << 1].d * tr[u].x + tr[u].d) % p;
    // 右子树
    tr[u << 1 | 1].s = ((tr[u << 1 | 1].s * tr[u].x) + (tr[u << 1 | 1].r - tr[u << 1 | 1].l + 1) * tr[u].d) % p;
    tr[u << 1 | 1].x = tr[u << 1 | 1].x * tr[u].x % p;
    tr[u << 1 | 1].d = (tr[u << 1 | 1].d * tr[u].x + tr[u].d) % p;
    // 清除懒标记
    tr[u].d = 0;
    tr[u].x = 1;
}

此处遵循先乘后加的原则,先修改区间和,再修改乘法懒标记,最后修改加法懒标记,不要忘记 mod p\bmod\ p

注意:此处清除懒标记的时候,乘法懒标记应修改为 11

区间修改(modify)

/**
 * 区间修改
 * @details 修改区间 [l, r] 中的每一个数
 * @param u 父节点下标
 * @param l 左端点
 * @param r 右端点
 * @param x 乘上的数
 * @param d 增加的值
 */
void modify(int u, int l, int r, long long x, long long d) {
    // 被包含直接修改
    if (tr[u].l >= l && tr[u].r <= r) {
        tr[u].s = ((tr[u].s * x) + (tr[u].r - tr[u].l + 1) * d) % p;
        tr[u].x = tr[u].x * x % p;
        tr[u].d = (tr[u].d * x + d) % p;
        return;
    }
    int mid = tr[u].l + tr[u].r >> 1;
    pushdown(u);                                  // 下放懒标记
    if (l <= mid) modify(u << 1, l, r, x, d);     // 和左侧有交集
    if (r > mid) modify(u << 1 | 1, l, r, x, d);  // 和右侧有交集
    pushup(u);                                    // 上传新信息
}

大体上和加法的修改函数一样,而在修改时与下放懒标记做法相同,遵循先乘后加的原则。

调用的时候若只需要使用乘法部分,加数设置为 00 即可。若只需要使用加法部分,乘数设置为 11 即可。

全部代码

到这里基本操作就说完了,下面是全部的 AC 代码。

区间加

#include <bits/stdc++.h>

using namespace std;

/**
 * 线段树节点
 */
struct node {
    int l, r;
    long long s, d;

    node() {
        l = r = s = d = 0;
    }
    node(int _l, int _r) {
        l = _l;
        r = _r;
        s = d = 0;
    }
} tr[100005 << 2];
int n, m, op, x, y, k, a[100005];

/**
 * 上传区间和
 * @param u 父节点下标
 */
void pushup(int u) {
    tr[u].s = tr[u << 1].s + tr[u << 1 | 1].s;
}

/**
 * 下放懒标记
 * @param u 父节点下标
 */
void pushdown(int u) {
    if (!tr[u].d) return;
    // 处理左子树
    tr[u << 1].d += tr[u].d;
    tr[u << 1].s += (tr[u << 1].r - tr[u << 1].l + 1) * tr[u].d;
    // 处理右子树
    tr[u << 1 | 1].d += tr[u].d;
    tr[u << 1 | 1].s += (tr[u << 1 | 1].r - tr[u << 1 | 1].l + 1) * tr[u].d;
    // 清除懒标记
    tr[u].d = 0;
}

/**
 * 建立线段树
 * @param u 根节点下标
 * @param l 左端点
 * @param r 右端点
 */
void build(int u, int l, int r) {
    tr[u] = node(l, r);
    if (l == r) {
        tr[u].s = a[l];
        return;
    }
    int mid = l + r >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
    pushup(u);
}

/**
 * 区间修改
 * @param u 父节点下标
 * @param l 左端点
 * @param r 右端点
 * @param d 增加的值
 */
void modify(int u, int l, int r, int d) {
    if (tr[u].l >= l && tr[u].r <= r) {  // 被包含直接修改
        tr[u].d += d;
        tr[u].s += (tr[u].r - tr[u].l + 1) * d;
        return;
    }
    int mid = tr[u].l + tr[u].r >> 1;
    pushdown(u);                               // 下放懒标记
    if (l <= mid) modify(u << 1, l, r, d);     // 和左侧有交集
    if (r > mid) modify(u << 1 | 1, l, r, d);  // 和右侧有交集
    pushup(u);                                 // 上传新信息
}

/**
 * 区间查询
 * @param u 父节点
 * @param l 左端点
 * @param r 右端点
 */
long long query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) {  // 被包含直接返回
        return tr[u].s;
    }
    int mid = tr[u].l + tr[u].r >> 1;
    long long s = 0;
    pushdown(u);                                // 下放懒标记
    if (l <= mid) s += query(u << 1, l, r);     // 和左侧有交集
    if (r > mid) s += query(u << 1 | 1, l, r);  // 和右侧有交集
    return s;
}

int main() {
    cin >> n >> m;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
    }
    build(1, 1, n);
    for (int i = 0; i < m; i++) {
        cin >> op >> x >> y;
        if (op == 1) {
            cin >> k;
            modify(1, x, y, k);
        }
        else if (op == 2) {
            cout << query(1, x, y) << endl;
        }
    }
    return 0;
}

区间加、乘

#include <bits/stdc++.h>

using namespace std;

struct node {
    int l, r;
    long long s, d, x;

    node() {
        l = r = s = d = 0;
        x = 1;
    }
    node(int _l, int _r) {
        l = _l, r = _r;
        s = d = 0;
        x = 1;
    }
} tr[100005 << 2];
int n, m, p, op, x, y;
long long k, a[100005];

/**
 * 上传信息
 * @param u 父节点下标
 */
void pushup(int u) {
    tr[u].s = (tr[u << 1].s + tr[u << 1 | 1].s) % p;
}

/**
 * 下放懒标记
 * @param u 父节点下标
 * @attention 先乘后加
 */
void pushdown(int u) {
    // 左子树
    tr[u << 1].s = ((tr[u << 1].s * tr[u].x) + (tr[u << 1].r - tr[u << 1].l + 1) * tr[u].d) % p;
    tr[u << 1].x = tr[u << 1].x * tr[u].x % p;
    tr[u << 1].d = (tr[u << 1].d * tr[u].x + tr[u].d) % p;
    // 右子树
    tr[u << 1 | 1].s = ((tr[u << 1 | 1].s * tr[u].x) + (tr[u << 1 | 1].r - tr[u << 1 | 1].l + 1) * tr[u].d) % p;
    tr[u << 1 | 1].x = tr[u << 1 | 1].x * tr[u].x % p;
    tr[u << 1 | 1].d = (tr[u << 1 | 1].d * tr[u].x + tr[u].d) % p;
    // 清除懒标记
    tr[u].d = 0;
    tr[u].x = 1;
}

/**
 * 建立线段树
 * @param u 根节点下标
 * @param l 左端点
 * @param r 右端点
 */
void build(int u, int l, int r) {
    tr[u] = node(l, r);
    if (l == r) {
        tr[u].s = a[l] % p;
        return;
    }
    int mid = l + r >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
    pushup(u);
}

/**
 * 区间修改
 * @details 将区间 [l, r] 中的每一个数加上 d
 * @param u 父节点下标
 * @param l 左端点
 * @param r 右端点
 * @param x 乘上的数
 * @param d 增加的值
 */
void modify(int u, int l, int r, long long x, long long d) {
    // 被包含直接修改
    if (tr[u].l >= l && tr[u].r <= r) {
        tr[u].s = ((tr[u].s * x) + (tr[u].r - tr[u].l + 1) * d) % p;
        tr[u].x = tr[u].x * x % p;
        tr[u].d = (tr[u].d * x + d) % p;
        return;
    }
    int mid = tr[u].l + tr[u].r >> 1;
    pushdown(u);                                  // 下放懒标记
    if (l <= mid) modify(u << 1, l, r, x, d);     // 和左侧有交集
    if (r > mid) modify(u << 1 | 1, l, r, x, d);  // 和右侧有交集
    pushup(u);                                    // 上传新信息
}

/**
 * 区间查询
 * @param u 
 * @param l 
 * @param r 
 * @return int 
 */
long long query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) {  // 被包含直接返回
        return tr[u].s;
    }
    int mid = tr[u].l + tr[u].r >> 1;
    long long s = 0;
    pushdown(u);                                         // 下放懒标记
    if (l <= mid) s = query(u << 1, l, r);               // 和左侧有交集
    if (r > mid) s = (s + query(u << 1 | 1, l, r)) % p;  // 和右侧有交集
    return s;
}

int main() {
    cin >> n >> m >> p;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
    }
    build(1, 1, n);
    while (m--) {
        cin >> op >> x >> y;
        if (op == 1) {
            cin >> k;
            modify(1, x, y, k, 0);
        }
        else if (op == 2) {
            cin >> k;
            modify(1, x, y, 1, k);
        }
        else if (op == 3) {
            cout << query(1, x, y) % p << endl;
        }
    }
    return 0;
}