【学习笔记】吉司机线段树

吉司机线段树,又名 Segment Tree Beats!(Angel Beats?= =),一种支持 $O(nlog^2n)$ 区间取 $\min$/$\max$ 的势能线段树。

以取 $\min$ 为例,每个节点维护最大值和最大值出现次数和严格次大值(即不等于最大值),如果修改值小于当前区间次大值就往下递归,否则如果小于当前区间最大值就把最大值更换为修改值。这样复杂度是 $O(nlogn)$。

均摊分析证明:设线段树上一个节点的势能函数 $\varphi$ 为区间内本质不同的数字个数,每次往下递归必然使得 $\varphi - 1$,未被修改区间完全覆盖的节点 $\varphi$ 至多加一,总复杂度为 $O(nlogn)$。

主要思想就是这些,打标记啊处理啊和普通线段树不同的地方就是「给最大」「给次大」或「给历史最大」「给历史次大」要分开。

可以维护区间和。

实现没什么难点。

模板

要求支持:

  • 修改:区间加,区间取 $\min$
  • 查询:区间和,区间 $\max$,区间历史 $\max$

用「V」那题的做法没法维护区间和,所以只能上 Beats 啦

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#include <bits/stdc++.h>
#define rep(i, x, y) for (int i = x; i <= y; i++)
using namespace std;

typedef long long ll;
const int N = 5e5 + 10;
int n, m, a[N];

void rd(int &x) {
x = 0; int f = 1; char ch = getchar();
for (; !isdigit(ch); ch = getchar()) if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + ch - '0';
x *= f;
}

namespace SGTBeats {
#define inf 1e18
#define mid (l + r >> 1)
#define ls (x << 1)
#define rs (x << 1 | 1)

void cmax(ll &x, ll y) { x = max(x, y); }

struct atom {
int l, r;
ll mxa, se, cnt, mxb, sum;
// 下面四个都是加法懒标记,只是对象不同
ll add_a; // 区间最大值
ll add_b; // 区间历史最大值
ll add_a1; // 区间非最大值
ll add_b1; // 区间历史非最大值
} tr[N << 2], t;

void upd(int x) {
tr[x].sum = tr[ls].sum + tr[rs].sum;
tr[x].mxa = max(tr[ls].mxa, tr[rs].mxa);
tr[x].mxb = max(tr[ls].mxb, tr[rs].mxb);
if (tr[ls].mxa == tr[rs].mxa) {
tr[x].cnt = tr[ls].cnt + tr[rs].cnt;
tr[x].se = max(tr[ls].se, tr[rs].se);
} else {
tr[x].cnt = (tr[ls].mxa > tr[rs].mxa ? tr[ls].cnt : tr[rs].cnt);
tr[x].se = (tr[ls].mxa > tr[rs].mxa ? max(tr[rs].mxa, tr[ls].se) : max(tr[ls].mxa, tr[rs].se));
}
}
void modi(int x, ll add_a, ll add_b, ll add_a1, ll add_b1) {
t = tr[x];
t.sum += add_a * tr[x].cnt + add_a1 * (tr[x].r - tr[x].l + 1 - tr[x].cnt);
t.mxa += add_a;
cmax(t.mxb, tr[x].mxa + add_b);
t.add_a += add_a;
t.add_a1 += add_a1;
cmax(t.add_b, tr[x].add_a + add_b);
cmax(t.add_b1, tr[x].add_a1 + add_b1);
if (t.se != -inf) t.se += add_a1;
tr[x] = t;
}
void psd(int x) {
int mxn = max(tr[ls].mxa, tr[rs].mxa);
if (tr[ls].mxa == mxn)
modi(ls, tr[x].add_a, tr[x].add_b, tr[x].add_a1, tr[x].add_b1);
else
modi(ls, tr[x].add_a1, tr[x].add_b1, tr[x].add_a1, tr[x].add_b1);
if (tr[rs].mxa == mxn) // 这里如果写「>= tr[ls].mxa」就炸了,因为改了……
modi(rs, tr[x].add_a, tr[x].add_b, tr[x].add_a1, tr[x].add_b1);
else
modi(rs, tr[x].add_a1, tr[x].add_b1, tr[x].add_a1, tr[x].add_b1);
tr[x].add_a = tr[x].add_a1 = tr[x].add_b = tr[x].add_b1 = 0;
}
void build(int x, int l, int r) {
tr[x].l = l, tr[x].r = r;
if (l == r) {
tr[x].mxa = tr[x].mxb = tr[x].sum = a[l];
tr[x].se = -inf;
tr[x].cnt = 1;
return;
}
build(ls, l, mid), build(rs, mid + 1, r);
upd(x);
}
void modi_add(int x, int l, int r, int lx, int rx, int v) {
if (lx <= l && r <= rx) {
modi(x, v, v, v, v);
return;
}
psd(x);
if (lx <= mid) modi_add(ls, l, mid, lx, rx, v);
if (rx > mid) modi_add(rs, mid + 1, r, lx, rx, v);
upd(x);
}
void modi_min(int x, int l, int r, int lx, int rx, int v) {
if (lx > r || rx < l || v >= tr[x].mxa) return;
if (lx <= l && r <= rx && v > tr[x].se) {
modi(x, v - tr[x].mxa, v - tr[x].mxa, 0, 0);
return;
}
psd(x);
modi_min(ls, l, mid, lx, rx, v);
modi_min(rs, mid + 1, r, lx, rx, v);
upd(x);
}
ll qry(int x, int l, int r, int lx, int rx, int op) {
if (lx > r || rx < l) return op == 3 ? 0 : -inf;
if (lx <= l && r <= rx) return op == 3 ? tr[x].sum : op == 4 ? tr[x].mxa : tr[x].mxb;
psd(x);
if (op == 3) return qry(ls, l, mid, lx, rx, op) + qry(rs, mid + 1, r, lx, rx, op);
else return max(qry(ls, l, mid, lx, rx, op), qry(rs, mid + 1, r, lx, rx, op));
}
};
using namespace SGTBeats;

int main() {
rd(n), rd(m);
rep(i, 1, n) rd(a[i]);
build(1, 1, n);
int op, l, r, v;
while (m--) {
rd(op), rd(l), rd(r);
if (op == 1) {
rd(v), modi_add(1, 1, n, l, r, v);
} else if (op == 2) {
rd(v), modi_min(1, 1, n, l, r, v);
} else {
printf("%lld\n", qry(1, 1, n, l, r, op));
}
}
return 0;
}

习题:前进四 题解见「UR 19 题解」。

SNOI2020-区间和 题解见「SNOI2020 乱写」(可能咕了 QwQ)