gym103447C 的简单做法

rsy_

2024-11-16 14:57:15

Personal

这个做法可以过 (当然仅仅是可以过,但是觉得不好卡),管解抽象的 dp + dsu on tree 太困难拼劲全力无法战胜。

模拟赛写的牛逼代码,感觉非常简洁。

显然我们可以设 dp_i 表示 i 子树的答案。

我们发现只要子树第一次选择的颜色相同,那么就把这些最后可以选择的点暴力存起来。

然后转移的时候就是子树 dp 值的 sum - 出现次数最大的颜色的出现次数 + 1,这些颜色使用一个 bitset 存储,得到常数很小的 \mathcal O(\frac{n^3}{w}) 做法。

#include <bits/stdc++.h>
#define lb(x) (x&-x)
#define L(i,j,k) for(int i=(j);i<=(k);++i)
#define R(i,j,k) for(int i=(j);i>=(k);--i)
#define swap(a,b) (a^=b^=a^=b)

using namespace std;
using i64 = long long;

typedef pair<int, int> pii;
typedef long long ll;
typedef unsigned long long ull;
void chmin(int &x, int c) {
  x = min(x, c);
}
void chmax(int &x, int c) {
  x = max(x, c);
}

const int maxn = 1e5 + 10, mod = 998244353;
vector<int> g[maxn];
int N, val[maxn], dp[maxn];
bitset<maxn> dfs (int u) {
    bitset<maxn> S, S2;
    S2.reset(), S.reset();
    if (g[u].size() == 0) {
        S[val[u]] = 1;
        dp[u] = 1;
        return S;
    }
    S.reset();
    vector<bitset<maxn>> G;
    for (int v : g[u])
        G.push_back(dfs (v)), dp[u] += dp[v];
    for (auto x : G) S2 |= x;
    int mx = 0;
    for (int i = 1; i <= N; i ++ ) {
        if (S2[i]) {
            int cnt = 0;
            for (auto x : G) cnt += x[i];
            if (cnt - 1 > mx) {
                mx = cnt - 1, S.reset();
            }
            if (cnt - 1 == mx) {
                S[i] = 1;
            }
        }
    }
    dp[u] -= mx;
    if (mx == 0) {
        return S2;
    }
    return S;
} 

void solve() {
    cin >> N;
    L (i, 2, N) {
        int x;
        cin >> x;
        g[x].push_back(i);
    }
    L (i, 1, N) {
        cin >> val[i];
    }
    dfs (1);
    cout << dp[1] << '\n';
}

signed main() {
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
  int T = 1;
  while (T--)solve();
  return 0;
}

TLE on test 9。

大胆猜测 res 的 sz 不会太多,于是换成 vector 进行暴力。

#pragma GCC optimize(1, 2, 3, "Ofast")
#include <bits/stdc++.h>
#define lb(x) (x&-x)
#define L(i,j,k) for(int i=(j);i<=(k);++i)
#define R(i,j,k) for(int i=(j);i>=(k);--i)
#define swap(a,b) (a^=b^=a^=b)

using namespace std;
using i64 = long long;

typedef pair<int, int> pii;
typedef long long ll;
typedef unsigned long long ull;
void chmin(int &x, int c) {
  x = min(x, c);
}
void chmax(int &x, int c) {
  x = max(x, c);
}

const int maxn = 1e5 + 10, mod = 998244353;
vector<int> g[maxn];
int N, val[maxn], dp[maxn], cnt[maxn];
vector<int> res[maxn];
void dfs (int u) {
    if (g[u].size() == 0) {
        res[u].push_back(val[u]);
        dp[u] = 1; return ;
    }
    for (int v : g[u])
        dfs (v), dp[u] += dp[v];
    vector<int> t1; t1.clear();
    int mx = 0;
    for (int v : g[u]) for (int x : res[v]) t1.push_back(x), cnt[x] = 0;
    for (int v : g[u]) for (int x : res[v]) cnt[x] ++, chmax (mx, cnt[x]);
    for (int v : g[u]) for (int x : res[v]) if (cnt[x] == mx) res[u].push_back(x);
    sort (res[u].begin(), res[u].end());
    res[u].erase(unique(res[u].begin(), res[u].end()), res[u].end());
    dp[u] -= mx - 1;
    if (mx == 1) {
        sort (t1.begin(), t1.end());
        t1.erase(unique(t1.begin(), t1.end()), t1.end());
        res[u] = t1;
    }
} 

void solve() {
    cin >> N;
    L (i, 2, N) {
        int x;
        cin >> x;
        g[x].push_back(i);
    }
    L (i, 1, N) {
        cin >> val[i];
    }
    dfs (1);
    cout << dp[1] << '\n';
}

signed main() {
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0); 
  int T = 1;
  while (T--)solve();
  return 0;
}

结果也是 TLE on 9,发现会排序,把排序删掉,然后就变快了。

#include <bits/stdc++.h>
#define lb(x) (x&-x)
#define L(i,j,k) for(int i=(j);i<=(k);++i)
#define R(i,j,k) for(int i=(j);i>=(k);--i)
#define swap(a,b) (a^=b^=a^=b)

using namespace std;
using i64 = long long;

typedef pair<int, int> pii;
typedef long long ll;
typedef unsigned long long ull;
void chmin(int &x, int c) {
  x = min(x, c);
}
void chmax(int &x, int c) {
  x = max(x, c);
}

const int maxn = 1e5 + 10, mod = 998244353;
vector<int> g[maxn];
int N, val[maxn], dp[maxn], cnt[maxn];
vector<int> res[maxn];
void dfs (int u) {
    if (g[u].size() == 0) {
        res[u].push_back(val[u]);
        dp[u] = 1; return ;
    }
    for (int v : g[u])
        dfs (v), dp[u] += dp[v];
    vector<int> t1; t1.clear();
    int mx = 0;
    for (int v : g[u]) for (int x : res[v])
        if (cnt[x] != 100000 + u) t1.push_back(x), cnt[x] = 100000 + u;
    for (int v : g[u]) for (int x : res[v]) cnt[x] ++, chmax (mx, cnt[x]);
    for (int v : g[u]) {
        for (int x : res[v]) {
            if (cnt[x] == mx) {
                res[u].push_back(x), cnt[x] = 0;
            }
        }
        res[v].clear();
    } 
    for (int x : g[u]) res[x].clear();
    dp[u] -= mx - 100000 - u - 1;
    if (mx == 1) {
        res[u] = t1;
    }
} 

void solve() {
    cin >> N;
    L (i, 2, N) {
        int x;
        cin >> x;
        g[x].push_back(i);
    }
    L (i, 1, N) {
        cin >> val[i];
    }
    dfs (1);
    cout << dp[1] << '\n';
}

signed main() {
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
  int T = 1;
  while (T--)solve();
  return 0;
}

MLE on 28

这个随机数据下是一定不会 mle 的。

考虑用完之后就 clear 这个 vector,空间是会变小的,考虑出题人要怎么构造卡你?

无非就是很多很大的 vector,这样的话合并为完之后也是一个很大的 vector,于是我们设一个阈值 B 表示如果这个 vector 的 sz > B 那么就不加了,反正后面也有很多,实测 B 开到 600 就过了。

这样的话时间也是有保证的,顶天 \mathcal O(Bn)

#include <bits/stdc++.h>
#define lb(x) (x&-x)
#define L(i,j,k) for(int i=(j);i<=(k);++i)
#define R(i,j,k) for(int i=(j);i>=(k);--i)
#define swap(a,b) (a^=b^=a^=b)

using namespace std;
using i64 = long long;

typedef pair<int, int> pii;
typedef long long ll;
typedef unsigned long long ull;
void chmin(int &x, int c) {
  x = min(x, c);
}
void chmax(int &x, int c) {
  x = max(x, c);
}

const int maxn = 1e5 + 10, mod = 998244353;
vector<int> g[maxn];
int N, val[maxn], dp[maxn], cnt[maxn];
vector<int> res[maxn];
const int B = 600;
void dfs (int u) {
    if (g[u].size() == 0) {
        res[u].push_back(val[u]);
        dp[u] = 1; return ;
    }
    for (int v : g[u])
        dfs (v), dp[u] += dp[v];
    vector<int> t1; t1.clear();
    int mx = 0;
    for (int v : g[u]) for (int x : res[v])
        if (cnt[x] != 100000 + u) t1.push_back(x), cnt[x] = 100000 + u;
    for (int v : g[u]) for (int x : res[v]) cnt[x] ++, chmax (mx, cnt[x]);
    for (int v : g[u]) {
        for (int x : res[v]) {
            if (cnt[x] == mx) {
                res[u].push_back(x), cnt[x] = 0;
                if (res[u].size() >= B) break; 
            }
        }
        if (res[u].size() >= B) break;
        res[v].clear();
    } 
    for (int x : g[u]) res[x].clear();
    dp[u] -= mx - 100000 - u - 1;
    if (mx == 1) {
        res[u] = t1;
    }
} 

void solve() {
    cin >> N;
    L (i, 2, N) {
        int x;
        cin >> x;
        g[x].push_back(i);
    }
    L (i, 1, N) {
        cin >> val[i];
    }
    dfs (1);
    cout << dp[1] << '\n';
}

signed main() {
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
  int T = 1;
  while (T--)solve();
  return 0;
}

被卡了就把阈值变大。我是乱搞带式。