【根号分治/启发式合并】CF375D Tree and Queries

Ice_lift

2024-07-18 10:36:04

Personal

首先,我们发现同一个点的查询只需知道子树中每种颜色的数量就能一起得到答案,所以我们不妨将询问离线,一起考虑。

res_{u, col} 表示 u 子树下每种颜色的数量,显然 res_{u, col} = \sum res_{v, col},这个东西进行启发式合并可以做到 \text{O}(n \log^{2} n) 实现。

由于是次数统计,我们不妨考虑根号分治。

对于出现次数 \le \sqrt{n} 的颜色,维护 cnt_{u, t} 表示 u 子树中出现 t 次的颜色数量。在合并过程中,该值也可以维护出来,维护复杂度 \text{O}(n\log n)。统计答案时,只需统计出现 k \sim \sqrt{n} 次的点的数量即可,时间复杂度至多 \text{O}(\sqrt{n})

对于出现次数 > \sqrt{n} 的颜色,这样的颜色最多有 \sqrt{n} 种,于是将所有这样的颜色存储进一个 set 中,统计答案时至多将所有这样的颜色扫一遍,时间复杂度 \text{O}(\log n\times \sqrt{n})

最终答案即为两部分的和。

时间复杂度为 \text{O}(n\log^2 n + n \sqrt{n})

PS:此题作者试了很多种方法,很多都因为 10^5 \times 317 的空间爆炸,这种方法是目前唯一通过的方法,如果您有更好的实现方法,欢迎与我探讨。

Code

#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 1;
const int M = 317 + 1;
int n, m, b;
int c[N], ct[N], tot;
vector<int> g[N];
int a[N];
vector<int > q[N], pp[N];
set<int> p[N];
int ans[N], cnt[N][M];

map<int, int>  dfs (int u, int fa) {
    map<int, int> res;
    res[c[u]] ++;
    if(res[c[u]] <= b) cnt[u][res[c[u]]] ++;
    else p[u].insert(c[u]);
    for (auto v : g[u]) {
        if(v == fa)  continue;
        map<int, int> res2 = dfs(v, u);
        if(res.size() < res2.size()) swap(res, res2), swap(cnt[u], cnt[v]), swap(p[u], p[v]);
        for (auto x : res2) {
            if(res[x.first] <= b) cnt[u][res[x.first]] --;
            res[x.first] += x.second;
            if(res[x.first] <= b) cnt[u][res[x.first]] ++;
            else p[u].insert(x.first);
        }
        res2.clear();
    }
    for (auto x : q[u]) {
        int k = a[x], id = x, xx = 0;
        if(k <= b) for (int i = k; i <= b; i ++) xx += cnt[u][i];
        for (auto y : p[u]) xx += (res[y] >= k);
        ans[id] = xx;
    }
    return res;
}

signed main() {
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    cin >> n >> m;
    b = sqrt(n);
    for (int i = 1; i <= n; i ++) cin >> c[i], ct[c[i]] ++;
    for (int i = 1; i < n; i ++) {
        int u, v;
        cin >> u >> v;
        g[u].push_back(v), g[v].push_back(u);
    }
    for (int i = 1, u; i <= m; i ++) {
        cin >> u >> a[i];
        q[u].push_back(i);
    }
    dfs(1, 0);
    for (int i = 1; i <= m; i ++) cout << ans[i] << endl;
    return 0;
}