XJOI200807 题解

A


这题就是典型的部分分提醒正解。

01 那部分可以引导想到 bitset,但 $O(q n / 32)$ 的显然不对,于是想到 $O(q 2^k / 32)$ 的,即 $f[i, s]$ 表示在第 i 个所在集合 s 中有没有值“扩散”到第 i 个,然后 or 和 and 操作能分别代替 max 和 min 操作,就做完了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#include <bits/stdc++.h>
#define rep(i, x, y) for (int i = x; i <= y; i++)
using namespace std;

const int N = 1e5 + 10;
int n, K, Q, cnt;
int a[13][N], b[13][N], t[13];
bitset<4096> B[N];

int main() {
cin >> n >> K >> Q;
rep(i, 1, K) rep(j, 1, n) scanf("%d", &a[i][j]);
rep(j, 1, n) {
rep(i, 1, K) t[i] = a[i][j];
sort(t + 1, t + K + 1);
rep(i, 1, K) b[i][j] = t[i];
}
rep(s, 0, (1 << K) - 1)
rep(i, 1, K) if ((s >> (i - 1)) & 1) B[i][s] = 1;
cnt = K;
while (Q--) {
int op, x, y; scanf("%d%d%d", &op, &x, &y);
if (op == 1) {
B[++cnt] = (B[x] | B[y]);
} else if (op == 2) {
B[++cnt] = (B[x] & B[y]);
} else {
for (int j = K; j; --j) {
int s = 0;
rep(i, 1, K) if (a[i][y] >= b[j][y]) s |= (1 << (i - 1));
if (B[x][s]) { printf("%d\n", b[j][y]); break; }
}
}
}
return 0;
}

B


显然要枚举最终颜色 $i$。问题抽象成有 $s + 1$ 个点 $0 ~ s$,当前在 $a_i$,每步可以向左向右或不动,问走到 $s$ 的期望步数,其中在 $0$ 点和 $s$ 点时都是 0。

题解,等差数列那里还是比较妙的!其实期望题多是从概念出发,我还是概念不清啊。

C


看起来就很套路的题。想到转化就好不少= =

这个平方很难搞啊!考虑它的实际意义:连通块路径数期望的两倍。根据期望的线性性质,我们可以把大期望化成小期望,分别考虑每条路径的贡献。(套路!有碰到过好几次啊啊啊)

设 $p(i, x, y)$ 表示 $x$、$y$ 在 $i$ 天后还在同一连通块的概率,显然前 $i - 1$ 天 $x$ 到 $y$ 的路径上的边都没被砍,因此

设 $E(i)$ 表示第 $i$ 天的路径数期望,那么

枚举 $z = dist(x, y)$ 就能把组合数拆开,就可以化成卷积形式了。设 $g(z)$ 表示 $dist(x, y) = z$ 的 $(x, y)$ 对数(点分治 + FFT),那么

卷积形式,FFT 就做完了。

调到去世,发现是三处sb错误 :)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#include <bits/stdc++.h>
#define rep(i, x, y) for (int i = x; i <= y; i++)
using namespace std;

const int N = 4e5 + 10, mod = 998244353, G1 = 3, G2 = (mod + 1) / G1;
typedef long long ll;
ll n, l, lim;
ll f[N][20], rev[N];
ll fac[N], inv[N], F[N], G[N], A[N], B[N];
ll sz[N], rt, T, mark[N], len[N], tot, bin[N], mxdep, num[N], mx[N];
vector<int> nxt[N];

ll quick_pow(ll a, ll b) {
ll ret = 1;
for (; b; b >>= 1) {
if (b & 1) ret = ret * a % mod;
a = a * a % mod;
} return ret;
}

void init(int n) {
fac[0] = inv[0] = inv[1] = 1;
rep(i, 1, n) fac[i] = fac[i - 1] * i % mod;
rep(i, 2, n) inv[i] = (mod - mod / i) * inv[mod % i] % mod;
rep(i, 2, n) inv[i] = inv[i] * inv[i - 1] % mod;
}

void get_rev(int n) {
l = 0, lim = 1;
while (lim <= n) l++, lim <<= 1;
rep(i, 1, lim - 1) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (l - 1));
}

void NTT(ll a[], int op) {
rep(i, 0, lim - 1)
if (i < rev[i]) swap(a[i], a[rev[i]]);
for (int mid = 1; mid < lim; mid <<= 1) {
ll W = quick_pow(op == 1 ? G1 : G2, (mod - 1) / (mid << 1));
for (int j = 0; j < lim; j += (mid << 1)) {
ll w = 1;
for (int k = 0; k < mid; k++, w = w * W % mod) {
ll x = a[j + k], y = w * a[j + k + mid] % mod;
a[j + k] = (x + y) % mod, a[j + k + mid] = (x - y + mod) % mod;
}
}
}
if (op == -1) {
ll Inv = quick_pow(lim, mod - 2);
rep(i, 0, lim - 1) (a[i] *= Inv) %= mod;
}
}

void getrt(int x, int fa) {
sz[x] = 1, mx[x] = 0;
for (int i = 0; i < nxt[x].size(); i++) {
int y = nxt[x][i];
if (y == fa || mark[y]) continue;
getrt(y, x);
sz[x] += sz[y], mx[x] = max(mx[x], sz[y]);
}
mx[x] = max(mx[x], T - sz[x]);
if (mx[x] < mx[rt]) rt = x;
}

void getdis(int x, int fa, ll dep) {
++bin[dep];
mxdep = max(mxdep, dep);
for (int i = 0; i < nxt[x].size(); i++) {
int y = nxt[x][i];
if (y == fa || mark[y]) continue;
getdis(y, x, dep + 1);
}
}

void calc(int op) {
get_rev(mxdep << 1);
rep(i, 0, mxdep) A[i] = B[i] = bin[i], bin[i] = 0;
rep(i, mxdep + 1, lim - 1) A[i] = B[i] = 0;
NTT(A, 1), NTT(B, 1);
rep(i, 0, lim - 1) A[i] = (A[i] * B[i]) % mod;
NTT(A, -1);
rep(i, 0, min(n, mxdep << 1)) num[i] = (num[i] + op * A[i]) % mod;
}

void solve(int x) {
mark[x] = 1, mxdep = 0, getdis(x, 0, 0), calc(1);
for (int i = 0; i < nxt[x].size(); i++) {
int y = nxt[x][i];
if (mark[y]) continue;
mxdep = 0, getdis(y, x, 1), calc(-1); // 容斥出一定经过 x 的路径数
}
for (int i = 0; i < nxt[x].size(); i++) {
int y = nxt[x][i];
if (mark[y]) continue;
T = sz[y], mx[rt = 0] = 1e9, getrt(y, x), solve(rt);
}
}

int main() {
cin >> n;
init(n);
rep(i, 1, n - 1) {
int x, y; scanf("%d%d", &x, &y);
nxt[x].push_back(y), nxt[y].push_back(x);
}
T = n, mx[rt = 0] = 1e9, getrt(1, 0), solve(rt);
get_rev(n << 1);
rep(i, 0, n - 1) num[i] = (num[i] + mod) % mod;
rep(i, 0, n - 1) A[i] = num[i] * fac[n - 1 - i] % mod, B[i] = inv[i];
rep(i, n, lim - 1) A[i] = B[i] = 0;
NTT(A, 1), NTT(B, 1);
rep(i, 0, lim - 1) A[i] = (A[i] * B[i]) % mod;
NTT(A, -1);
rep(i, 0, n - 1) printf("%lld ", A[n - i - 1] * fac[n - i - 1] % mod * inv[n - 1] % mod);
puts("");
return 0;
}