1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122
| #include <bits/stdc++.h> #define rep(i, x, y) for (int i = x; i <= y; i++) #define fi first #define se second using namespace std;
// Capacity limits: N bounds the element count, M bounds the number of trie
// nodes (nodes created by insert() plus the extra nodes created by merge()).
constexpr int N = 300005;
constexpr int M = 8000005;
constexpr int inf = 0x3f3f3f3f;

using ll = long long;
using pii = pair<int, int>;

// n: number of input values; m: top bit index the trie walks down from.
int n, m;
// cnt: number of DSU components still separate; idx: last allocated trie node;
// rt: trie root; cas: round counter (incremented once per outer iteration).
int cnt, idx, rt, cas;
int a[N];      // input values, 1-indexed
int fa[N];     // DSU parent links
int ch[M][2];  // trie children; index 0 acts as the null node
int mx[M][2];  // per node: largest and second-largest distinct component id (-1 = empty slot)
ll ans;        // accumulated weight of the edges chosen so far
pii cur[N];    // per component root: best (weight, partner component) found this round
// Returns the DSU representative of x and compresses the path: every node
// visited on the way ends up pointing directly at the root.
// Written iteratively on purpose: unite() links roots without rank/size, so a
// parent chain can grow to the number of elements, and a recursive version
// would need one stack frame per link on such a chain.
int getfa(int x) {
    int root = x;
    while (fa[root] != root) root = fa[root];
    // Second pass: redirect each node on the path straight to the root.
    while (fa[x] != root) {
        const int next = fa[x];
        fa[x] = root;
        x = next;
    }
    return root;
}
void upd(int x, int y) { if (!x || !y) return; if (mx[x][0] == mx[y][0]) mx[x][1] = max(mx[x][1], mx[y][1]); if (mx[x][0] < mx[y][0]) mx[x][1] = max(mx[x][0], mx[y][1]); if (mx[x][0] > mx[y][0]) mx[x][1] = max(mx[x][1], mx[y][0]); mx[x][0] = max(mx[x][0], mx[y][0]); }
// Allocates the next trie node: no children and both component-id slots empty (-1).
int newnode() {
    const int node = ++idx;
    ch[node][0] = 0;
    ch[node][1] = 0;
    mx[node][0] = -1;
    mx[node][1] = -1;
    return node;
}
int merge(int x, int y) { if (!x || !y) return x | y; int t = newnode(); ch[t][0] = merge(ch[x][0], ch[y][0]); ch[t][1] = merge(ch[x][1], ch[y][1]); upd(t, x), upd(t, y); return t; }
// Post-order pass over the trie rooted at x; dep counts down and the pass stops
// below depth 0. After both children are processed, x's 0-child is replaced by
// the merge of its two children, so following the 0-branch from any node
// reaches every value stored beneath that node. The 1-child is left untouched.
void dfs(int x, int dep) {
    if (dep < 0) return;
    const int zeroChild = ch[x][0];
    const int oneChild = ch[x][1];
    for (int child : {zeroChild, oneChild}) {
        if (child != 0) dfs(child, dep - 1);
    }
    ch[x][0] = merge(ch[x][0], ch[x][1]);
}
// Joins the DSU sets of x and y by hanging x's root under y's root.
// Returns true when the two were in different sets (a link was made),
// false when they already shared a root.
bool unite(int x, int y) {
    const int rootX = getfa(x);
    const int rootY = getfa(y);
    if (rootX != rootY) {
        fa[rootX] = rootY;
        return true;
    }
    return false;
}
void insert(int &x, int dep, int val, int id) { if (!x) x = ++idx; if (dep < 0) { if (mx[x][0] != id) { if (mx[x][0] < id) mx[x][1] = mx[x][0], mx[x][0] = id; else if (mx[x][1] < id) mx[x][1] = id; } return; } int c = ((val >> dep) & 1); insert(ch[x][c], dep - 1, val, id); mx[x][0] = mx[x][1] = -1; upd(x, ch[x][0]), upd(x, ch[x][1]); }
// Greedy bitwise walk down the trie. main() calls it after dfs(), which has
// folded every node's 1-subtree into its 0-child. At each bit where `val` has
// a 1, the walk takes the 1-branch if that branch records a component other
// than `id`, and adds that bit to the result. Otherwise it takes the 0-branch.
// Returns {accumulated bits, component id at the reached leaf other than `id`};
// that id comes from the leaf's second slot when its top slot equals `id`, and
// may therefore be -1 (empty).
pii query(int x, int dep, int val, int id) {
    int gained = 0;
    for (int d = dep; d >= 0; --d) {
        const int zeroChild = ch[x][0];
        const int oneChild = ch[x][1];
        bool takeOne = false;
        if ((val >> d) & 1) {
            // The 1-branch only helps if it holds some component besides `id`.
            takeOne = oneChild && !(mx[oneChild][0] == id && mx[oneChild][1] == -1);
        }
        if (takeOne) {
            gained += 1 << d;
            x = oneChild;
        } else {
            x = zeroChild;
        }
    }
    const int partner = (mx[x][0] == id) ? mx[x][1] : mx[x][0];
    return make_pair(gained, partner);
}
int main() { cin >> n >> m; rep(i, 1, n) { scanf("%d", &a[i]); fa[i] = i; } cnt = n; while (cnt > 1) { ++cas; rep(i, 0, idx) { ch[i][0] = ch[i][1] = 0; mx[i][0] = mx[i][1] = -1; } idx = rt = 0; rep(i, 1, n) { insert(rt, m, a[i], getfa(i)); cur[i] = make_pair(-1, -1); } dfs(rt, m); rep(i, 1, n) { int f = getfa(i); pii t = query(rt, m, a[i], f); cur[f] = max(cur[f], t); } rep(i, 1, n) { if (fa[i] == i && cur[i].se > 0 && unite(i, cur[i].se)) { cnt--, ans += cur[i].fi; } } } printf("%lld\n", ans); return 0; }
|