【习题选讲】虚树

虚树,没那么玄乎,针对的是“多次树上询问一些点”的题目。比如下面这道题,题面中就有这样的表述:sigma(ki)<=500000 那么我们的算法一定不和 n 有关,而是和 k 有关啦!

虚树的时间复杂度就是 O(sigma(ki))!下面是练习时间~

[SDOI2011]-消耗战


传送门

很容易想到树形DP,但那样就是 O(nm),太大了!肿么办 ಠ_ಠ 上虚树!

每次把询问点及它们的 LCA 摘出来重新建一棵树,也就是虚树,虚树上树形 DP,单次 O(ki)。

注意,本题特殊性在于,一条链若是 根->a->b->c,那么只要保留 根->a 这段路径就可以了!这个很明显吧

code :

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#include <bits/stdc++.h>
#define rep(i, x, y) for (int i = x; i <= y; i++)
using namespace std;

const int N = 1e6 + 10;
typedef long long ll;
int n, m, a[N], idx;
int dep[N], son[N], fa[N], tp[N], sz[N], dfn[N];
int lnk[N], nxt[N << 1], to[N << 1], val[N << 1], cnt;
int stk[N], top;
ll mn[N];
vector<int> v[N];

bool cmp(int x, int y) { return dfn[x] < dfn[y]; }

void add(int x, int y, int z) {
to[++cnt] = y, nxt[cnt] = lnk[x], lnk[x] = cnt, val[cnt] = z;
}

void addedge(int x, int y) { v[x].push_back(y); }

void dfs1(int x, int father) {
fa[x] = father;
dfn[x] = ++idx;
dep[x] = dep[father] + 1;
sz[x] = 1;
for (int i = lnk[x]; i; i = nxt[i]) {
int y = to[i];
if (y == father) continue;
mn[y] = min(mn[x], (ll)val[i]);
dfs1(y, x);
sz[x] += sz[y];
if (sz[y] > sz[son[x]]) son[x] = y;
}
}

void dfs2(int x, int top) {
tp[x] = top;
if (!son[x]) return;
dfs2(son[x], top);
for (int i = lnk[x]; i; i = nxt[i]) {
int y = to[i];
if (y != fa[x] && y != son[x]) dfs2(y, y);
}
}

int LCA(int x, int y) {
while (tp[x] != tp[y]) {
if (dep[tp[x]] < dep[tp[y]]) swap(x, y);
x = fa[tp[x]];
}
if (dep[x] > dep[y]) swap(x, y);
return x;
}

void insert(int x) {
if (top == 1) { stk[++top] = x; return; }
int lca = LCA(x, stk[top]);
if (lca == stk[top]) return; // 直接 return 的原因是 一条链上只要存除根结点以外最上面的那个点!!!
while (top > 1 && dfn[stk[top - 1]] >= dfn[lca])
addedge(stk[top - 1], stk[top]), --top;
if (lca != stk[top])
addedge(lca, stk[top]), stk[top] = lca;
stk[++top] = x;
}

ll DP(int x) {
if (v[x].size() == 0) return mn[x];
ll sum = 0;
for (int i = 0; i < v[x].size(); i++)
sum += DP(v[x][i]);
v[x].clear();
return min(mn[x], sum);
}

int main() {
cin >> n;
rep(i, 1, n - 1) {
int x, y, z; scanf("%d%d%d", &x, &y, &z);
add(x, y, z), add(y, x, z);
}
mn[1] = (1ll << 60);
dfs1(1, 0);
dfs2(1, 1);

cin >> m;
while (m--) {
int k; scanf("%d", &k);
rep(i, 1, k) scanf("%d", &a[i]);
sort(a + 1, a + k + 1, cmp);
stk[top = 1] = 1;
rep(i, 1, k) insert(a[i]);
while (top > 1) {
addedge(stk[top - 1], stk[top]);
--top;
}
printf("%lld\n", DP(1));
}
return 0;
}

[HEOI2014]-大工程


传送门

树形DP。一开始的思路是 求代价 sigma(边ab的val ✖️ a子树中选中节点数 ✖️ b子树中选中节点数),这样应该也可以?。看 hzwer 求代价和的方法是不断合并,设 f[x] 表示在遍历 y 子树前遍历的选中节点到 x 的距离和,反正搞一搞就出来了。最大最小代价也挺好求的。

code :

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
// bzoj_3611
#include <bits/stdc++.h>
#define rep(i, x, y) for (int i = x; i <= y; i++)
using namespace std;

const int N = 2e6 + 10, inf = 0x7fffffff;
typedef long long ll;
int n, Q, idx, top;
ll sz[N], son[N], fa[N], tp[N], id[N], dep[N];
ll h[N], v[N], stk[N], siz[N], f[N], mn[N], mx[N];
int to[N << 1], nxt[N << 1], lnk[N], cnt;
vector<int> g[N], val[N];
ll tot, ans1, ans2;

void add(int x, int y) {
to[++cnt] = y, nxt[cnt] = lnk[x], lnk[x] = cnt;
}

void dfs1(int x, int fat) {
sz[x] = 1;
fa[x] = fat;
dep[x] = dep[fat] + 1;
int mx = -1;
for (int i = lnk[x]; i; i = nxt[i]) {
int y = to[i];
if (y == fat) continue;
dfs1(y, x);
sz[x] += sz[y];
if (mx < sz[y])
mx = sz[y], son[x] = y;
}
}

void dfs2(int x, int rt) {
id[x] = ++idx;
tp[x] = rt;
if (son[x]) dfs2(son[x], rt);
for (int i = lnk[x]; i; i = nxt[i]) {
int y = to[i];
if (y != fa[x] && y != son[x]) dfs2(y, y);
}
}

bool cmp(int a, int b) { return id[a] < id[b]; }

void addedge(int x, int y) {
g[x].push_back(y);
val[x].push_back(dep[y] - dep[x]);
}

int LCA(int x, int y) {
while (tp[x] != tp[y]) {
if (dep[tp[x]] < dep[tp[y]]) swap(x, y);
x = fa[tp[x]];
}
if (dep[x] > dep[y]) swap(x, y);
return x;
}

void insert(int x) {
int lca = LCA(x, stk[top]);
if (lca == stk[top]) { stk[++top] = x; return; }
while (top > 1 && id[stk[top - 1]] >= id[lca])
addedge(stk[top - 1], stk[top]), --top;
if (lca != stk[top])
addedge(lca, stk[top]), stk[top] = lca;
stk[++top] = x;
}

void dfs(int x) {
f[x] = 0;
siz[x] = v[x];
mn[x] = v[x] ? 0 : inf;
mx[x] = v[x] ? 0 : -inf;
for (int i = 0; i < g[x].size(); i++) {
int y = g[x][i], z = val[x][i];
dfs(y);
tot += (f[x] + siz[x] * z) * siz[y] + f[y] * siz[x];
siz[x] += siz[y];
f[x] += f[y] + siz[y] * z;
ans1 = min(ans1, mn[x] + mn[y] + z);
ans2 = max(ans2, mx[x] + mx[y] + z);
mn[x] = min(mn[x], mn[y] + z);
mx[x] = max(mx[x], mx[y] + z);
}
g[x].clear(), val[x].clear();
}

int main() {
cin >> n;
rep(i, 1, n - 1) {
int x, y; cin >> x >> y;
add(x, y), add(y, x);
}
dfs1(1, 0);
dfs2(1, 1);
cin >> Q;
while (Q--) {
int k; cin >> k;
rep(i, 1, k) cin >> h[i], v[h[i]] = 1;
if (k == 1) { puts("0 0 0"); continue; }
sort(h + 1, h + k + 1, cmp);
top = 0;
if (!v[1]) stk[++top] = 1;
rep(i, 1, k) insert(h[i]);
while (top > 1) {
addedge(stk[top - 1], stk[top]);
--top;
}
tot = 0, ans1 = inf, ans2 = -inf;
dfs(1);
printf("%lld %lld %lld\n", tot, ans1, ans2);
rep(i, 1, k) v[h[i]] = 0;
}
return 0;
}

[HNOI2014]-世界树


传送门

hzwer大神写得好

真的,太难调了orz。。。两遍 dfs 求出每个虚树上的点对应的管理点,虚树上每条边 ab 对应在树上的区域涵盖的点,它们的管理点不是 a 的管理点就是 b 的管理点。确定范围,其实可以倍增跳到 ab 的中点。size 的加加减减要注意,很容易出错 ಥ_ಥ

code :

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#include <bits/stdc++.h>
#define rep(i, x, y) for (int i = x; i <= y; i++)
using namespace std;

const int N = 3e5 + 5;
int n, a[N], Q, k, m[N], b[N];
int to[N << 1], nxt[N << 1], lnk[N], cnt;
int f[N][22], dfn[N], idx, dep[N], sz[N];
int stk[N], top, q[N], id, ans[N], bel[N], siz[N];

void add(int x, int y) {
to[++cnt] = y, nxt[cnt] = lnk[x], lnk[x] = cnt;
}

bool cmp(int a, int b) { return dfn[a] < dfn[b]; }

void dfs(int x, int fa) {
dfn[x] = ++idx;
dep[x] = dep[fa] + 1;
sz[x] = 1;
rep(i, 1, 19)
f[x][i] = f[f[x][i - 1]][i - 1];
for (int i = lnk[x]; i; i = nxt[i]) {
int y = to[i];
if (y == fa) continue;
f[y][0] = x;
dfs(y, x);
sz[x] += sz[y];
}
}

int LCA(int x, int y) {
if (dep[x] < dep[y]) swap(x, y);
for (int i = 19; i >= 0; i--)
if (dep[f[x][i]] >= dep[y]) x = f[x][i];
if (x == y) return x;
for (int i = 19; i >= 0; i--)
if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
return f[x][0];
}

int dis(int x, int y) {
return dep[x] + dep[y] - 2 * dep[LCA(x, y)];
}

void solve(int a, int b) {
int x = b, mid = b;
for (int i = 19; i >= 0; i--)
if (dep[f[x][i]] > dep[a]) x = f[x][i];
siz[a] -= sz[x];
if (bel[a] == bel[b]) { ans[bel[a]] += sz[x] - sz[b]; return; }
for (int i = 19; i >= 0; i--) {
int tmp = f[mid][i];
if (dep[tmp] <= dep[a]) continue;
int t1 = dis(bel[a], tmp), t2 = dis(bel[b], tmp);
if (t1 > t2 || (t1 == t2 && bel[b] < bel[a])) mid = tmp;
}
ans[bel[a]] += sz[x] - sz[mid];
ans[bel[b]] += sz[mid] - sz[b];
}

void dfs1(int x) {
q[++id] = x;
siz[x] = sz[x];
for (int i = lnk[x]; i; i = nxt[i]) {
int y = to[i];
dfs1(y);
if (!bel[y]) continue;
int t1 = dis(bel[y], x), t2 = dis(bel[x], x);
if ((t1 == t2 && bel[y] < bel[x]) || t1 < t2 || !bel[x])
bel[x] = bel[y];
}
}

void dfs2(int x) {
for (int i = lnk[x]; i; i = nxt[i]) {
int y = to[i];
int t1 = dis(bel[x], y), t2 = dis(bel[y], y);
if ((t1 == t2 && bel[y] > bel[x]) || t1 < t2 || !bel[y])
bel[y] = bel[x];
dfs2(y);
}
}

void build() {
sort(m + 1, m + k + 1, cmp);
if (bel[1] != 1) stk[++top] = 1;
rep(i, 1, k) {
int x = m[i], lca = 0;
if (!top) { stk[++top] = x; continue; }
lca = LCA(stk[top], x);
if (lca == stk[top]) { stk[++top] = x; continue; }
while (top > 1 && dfn[stk[top - 1]] >= dfn[lca]) {
add(stk[top - 1], stk[top]); top--; // 不能写 stk[top--], 不然会 WA。。。不知道为啥
}
if (lca != stk[top]) {
add(lca, stk[top]); stk[top] = lca;
}
stk[++top] = x;
}
while (top > 1) add(stk[top - 1], stk[top]), top--;
}

int main() {
cin >> n;
rep(i, 1, n - 1) {
int x, y; scanf("%d%d", &x, &y);
add(x, y), add(y, x);
}
dfs(1, 0);
memset(lnk, 0, sizeof(lnk));

cin >> Q;
while (Q--) {
cin >> k;
top = id = cnt = 0;
rep(i, 1, k) scanf("%d", &m[i]), b[i] = m[i], bel[m[i]] = m[i];
build();

dfs1(1), dfs2(1);
rep(o, 1, id)
for (int i = lnk[q[o]]; i; i = nxt[i])
solve(q[o], to[i]);

rep(i, 1, id)
ans[bel[q[i]]] += siz[q[i]];

rep(i, 1, k) printf("%d ", ans[b[i]]); puts("");

rep(i, 1, id)
ans[q[i]] = bel[q[i]] = siz[q[i]] = lnk[q[i]] = 0;
}
return 0;
}