奇怪,WA 了,已经开了 long long。

SP1825 FTOUR2 - Free tour II

Loser_Syx @ 2025-01-10 19:54:21

#include <iostream>
#include <cassert>
#include <queue>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <bitset>
#include <random>
#include <ctime>
#include <map>
#include <set>
using namespace std;
#define int long long
#define pii pair<int, int>
#define eb emplace_back
#define F first
#define S second
#define test(x) cout << "Test: " << (x) << '\n'
#define lowbit(x) (x & -x)
#define debug puts("qwq");
#define open(x) freopen(#x".in", "r", stdin);freopen(#x".out", "w", stdout);
#define close fclose(stdin);fclose(stdout);
// Small stdio-based helpers: integer input/output plus in-place min/max.
namespace FastIO {
    /// Reads one optionally signed decimal integer from stdin and returns it.
    /// Characters before the first digit are skipped; a '-' among them makes
    /// the result negative. If input ends first, the value parsed so far
    /// (0 when no digit was seen) is returned instead of looping forever.
    template <typename T = int>
    inline T read() {
        T s = 0, w = 1;
        // Keep getchar()'s integer result: EOF does not fit in a char, and a
        // negative char must not be handed to isdigit().
        int c = getchar();
        while (c != EOF && !isdigit(c)) {
            if (c == '-') w = -1;
            c = getchar();
        }
        while (c != EOF && isdigit(c)) s = s * 10 + (c - '0'), c = getchar();
        return s * w;
    }
    /// Reads one integer from stdin into `s` (same parsing rules as read()).
    template <typename T>
    inline void read(T &s) {
        s = read<T>();
    }
    /// Reads several integers in order, one per argument.
    template <typename T, typename... Arp> inline void read(T &x, Arp &...arp) {
        read(x), read(arp...);
    }
    /// Prints `x` in decimal followed by `ch` (newline by default).
    template <typename T>
    inline void write(T x, char ch = '\n') {
        // Digits are taken from the value as-is and negated one by one, so the
        // most negative value of T is printed correctly (-x would overflow).
        const bool negative = x < 0;
        char stk[40];
        int top = 0;
        do {
            int digit = x % 10;
            stk[top++] = static_cast<char>('0' + (negative ? -digit : digit));
            x /= 10;
        } while (x);
        if (negative) putchar('-');
        while (top) putchar(stk[--top]);
        putchar(ch);
    }
    /// x = max(x, y)
    template <typename T>
    inline void smax(T &x, T y) {
        if (x < y) x = y;
    }
    /// x = min(x, y)
    template <typename T>
    inline void smin(T &x, T y) {
        if (x > y) x = y;
    }
    /// Terminates the program with exit status 0.
    void quit() {
        exit(0);
    }
} using namespace FastIO;
// N / M: array capacities; L: number of levels in the ancestor table f;
// inf: large constant that is not referenced in the visible code.
const int N = 2e5 + 19, M = 1e6 + 19, L = 21, inf = 1e18;
// Shared state of the routines below:
//   dep  - weighted distance from vertex 1 (set in dfs)
//   val  - running sum of c[] along the path from vertex 1 (set in dfs)
//   f    - f[u][i] is the 2^i-th ancestor of u, 0 when it does not exist
//   vis  - set to 1 for vertices already chosen as a root by solve()
//   dfn/o/tot - preorder index in the tree built by solve(), its inverse, and the counter (lose)
//   siz/mxsiz - component size / largest piece after removing a vertex (getrt, solve)
//   n, k, m - values read first in main; ans - best value found by calc
//   c    - 1 for the m vertices listed in the input, 0 otherwise
//   cnt, lg - not referenced in the visible code
int cnt[M], val[N], dep[N], f[N][L], vis[N], dfn[N], o[N], siz[N], mxsiz[N], tot, lg[N], n, k, ans, m, c[N];
// g: weighted adjacency lists filled in main; G: child lists of the tree returned by solve().
vector<pii> g[N]; vector<int> G[N];
// Fenwick tree holding prefix maxima over keys 0..N-2 (stored 1-based).
// Every cell starts at 0, so a query never returns less than 0.
struct BIT {
    int t[N];
    // Raises the stored maximum for key x to at least s.
    void add(int x, int s) {
        for (int i = x + 1; i < N; i += i & -i) {
            if (t[i] < s) t[i] = s;
        }
    }
    // Resets to 0 every cell that add(x, ...) may have written.
    void clear(int x) {
        for (int i = x + 1; i < N; i += i & -i) {
            t[i] = 0;
        }
    }
    // Largest value stored for any key in [0, x], or 0 if there is none above 0.
    int query(int x) {
        int best = 0;
        for (int i = x + 1; i != 0; i -= i & -i) {
            if (best < t[i]) best = t[i];
        }
        return best;
    }
} t;
void dfs(int u, int fa) {
    f[u][0] = fa;
    for (int i = 1; i < L; ++i) f[u][i] = f[f[u][i-1]][i-1];
    for (auto [v, w] : g[u]) {
        if (v == fa) continue;
        dep[v] = dep[u] + w; val[v] = val[u] + c[v]; dfs(v, u);
    }
}
// Lowest common ancestor of x and y by binary lifting over f.
int LCA(int x, int y) {
    // Number of edges between a vertex and the root, recovered from the jump
    // table: f[v][i] is 0 exactly when v has no 2^i-th ancestor (main calls
    // dfs(1, 0), so 0 acts as the "no parent" value; this assumes vertex ids
    // start at 1). dep[] cannot be used for this: it is a sum of edge weights,
    // not a level, so lifting by a dep difference lands on the wrong vertex.
    auto level = [](int v) {
        int d = 0;
        for (int i = L - 1; i >= 0; --i) {
            if (f[v][i]) d += 1 << i, v = f[v][i];
        }
        return d;
    };
    int lx = level(x), ly = level(y);
    if (lx > ly) swap(x, y), swap(lx, ly);
    // Bring the deeper vertex y up to x's level.
    for (int s = ly - lx, i = 0; i < L; ++i) if (s >> i & 1) y = f[y][i];
    if (x == y) return x;
    // Lift both while their ancestors differ; they end just below the LCA.
    for (int i = L - 1; i >= 0; --i) if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
    return f[x][0];
}
// Sum of edge weights on the tree path between x and y.
int dis(int x, int y) {
    const int top = LCA(x, y);
    return dep[x] + dep[y] - 2 * dep[top];
}
int get(int x, int y) { return val[x] + val[y] - 2 * val[LCA(x, y)]; }
// Scratch state for getrt/solve: sumsiz is the size of the component being split,
// rt/num are the best root found so far and its largest-piece size.
// `now` is not referenced in the visible code.
int now, sumsiz, rt, num;
// Walks the not-yet-removed component containing u, computing siz[] and, for each
// vertex, the largest piece left if it were removed (mxsiz). The vertex with the
// smallest such piece is left in rt, with that size in num.
// Expects sumsiz and num to be set by the caller.
void getrt(int u, int fa) {
    siz[u] = 1;
    mxsiz[u] = 1;
    for (const auto &edge : g[u]) {
        const int nxt = edge.first;
        if (nxt == fa || vis[nxt]) continue;
        getrt(nxt, u);
        siz[u] += siz[nxt];
        if (mxsiz[u] < siz[nxt]) mxsiz[u] = siz[nxt];
    }
    // The piece on the parent's side of u.
    const int above = sumsiz - siz[u];
    if (mxsiz[u] < above) mxsiz[u] = above;
    if (mxsiz[u] < num) {
        num = mxsiz[u];
        rt = u;
    }
}
int solve(int x) {
    sumsiz = siz[x];
    num = 1e9;
    getrt(x, 0); x = rt; vis[x] = 1;
    int nowsiz = sumsiz;
    for (auto [v, w] : g[x]) {
        if (vis[v]) continue;
        if (siz[x] < siz[v]) siz[v] = nowsiz - siz[x];
        int ret = solve(v); G[x].eb(ret);
    } siz[x] = nowsiz;
    return x;
} void lose(int u) {
    dfn[u] = ++tot; o[tot] = u;
    for (int v : G[u]) {
        lose(v);
    }
} void calc(int u) {
    for (int v : G[u]) calc(v);
    for (int v : G[u]) {
        for (int i = dfn[v]; i <= dfn[v] + siz[v] - 1; ++i) {
            int c = get(u, o[i]), d = dis(u, o[i]); if (c <= k) {
                smax(ans, d + t.query(k - c));
            }
        }
        for (int i = dfn[v]; i <= dfn[v] + siz[v] - 1; ++i) {
            int c = get(u, o[i]), d = dis(u, o[i]); if (c <= k) {
                t.add(c, d);
            }
        }
    }
    for (int v : G[u]) {
        for (int i = dfn[v]; i <= dfn[v] + siz[v] - 1; ++i) {
            int c = get(u, o[i]); if (c <= k) {
                t.clear(c);
            }
        }
    }
}
// Reads n, k, m, the m marked vertices and the n-1 weighted edges, then prints ans.
signed main() {
    read(n, k, m);
    for (int left = m; left > 0; --left) {
        const int node = read();
        c[node] = 1;
    }
    for (int e = 1; e < n; ++e) {
        int u, v, w;
        read(u, v, w);
        g[u].emplace_back(v, w);
        g[v].emplace_back(u, w);
    }
    siz[1] = n;  // solve() expects the component size in siz[] of its argument
    dfs(1, 0);
    const int root = solve(1);
    lose(root);
    calc(root);
    write(ans);
    return 0;
}

by Hoks @ 2025-01-10 20:04:15

@Loser_Syx 呜呜呜怎么背着我卷


by Loser_Syx @ 2025-01-10 20:04:34

@Hoks 你别急,我的 LCA 写成奶龙了。


by Loser_Syx @ 2025-01-10 20:06:51

update:

#include <iostream>
#include <cassert>
#include <queue>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <bitset>
#include <random>
#include <ctime>
#include <map>
#include <set>
using namespace std;
#define int long long
#define pii pair<int, int>
#define eb emplace_back
#define F first
#define S second
#define test(x) cout << "Test: " << (x) << '\n'
#define lowbit(x) (x & -x)
#define debug puts("qwq");
#define open(x) freopen(#x".in", "r", stdin);freopen(#x".out", "w", stdout);
#define close fclose(stdin);fclose(stdout);
// Small stdio-based helpers: integer input/output plus in-place min/max.
namespace FastIO {
    /// Reads one optionally signed decimal integer from stdin and returns it.
    /// Characters before the first digit are skipped; a '-' among them makes
    /// the result negative. If input ends first, the value parsed so far
    /// (0 when no digit was seen) is returned instead of looping forever.
    template <typename T = int>
    inline T read() {
        T s = 0, w = 1;
        // Keep getchar()'s integer result: EOF does not fit in a char, and a
        // negative char must not be handed to isdigit().
        int c = getchar();
        while (c != EOF && !isdigit(c)) {
            if (c == '-') w = -1;
            c = getchar();
        }
        while (c != EOF && isdigit(c)) s = s * 10 + (c - '0'), c = getchar();
        return s * w;
    }
    /// Reads one integer from stdin into `s` (same parsing rules as read()).
    template <typename T>
    inline void read(T &s) {
        s = read<T>();
    }
    /// Reads several integers in order, one per argument.
    template <typename T, typename... Arp> inline void read(T &x, Arp &...arp) {
        read(x), read(arp...);
    }
    /// Prints `x` in decimal followed by `ch` (newline by default).
    template <typename T>
    inline void write(T x, char ch = '\n') {
        // Digits are taken from the value as-is and negated one by one, so the
        // most negative value of T is printed correctly (-x would overflow).
        const bool negative = x < 0;
        char stk[40];
        int top = 0;
        do {
            int digit = x % 10;
            stk[top++] = static_cast<char>('0' + (negative ? -digit : digit));
            x /= 10;
        } while (x);
        if (negative) putchar('-');
        while (top) putchar(stk[--top]);
        putchar(ch);
    }
    /// x = max(x, y)
    template <typename T>
    inline void smax(T &x, T y) {
        if (x < y) x = y;
    }
    /// x = min(x, y)
    template <typename T>
    inline void smin(T &x, T y) {
        if (x > y) x = y;
    }
    /// Terminates the program with exit status 0.
    void quit() {
        exit(0);
    }
} using namespace FastIO;
// N / M: array capacities; L: number of levels in the ancestor table f;
// inf: large constant that is not referenced in the visible code.
const int N = 2e5 + 19, M = 1e6 + 19, L = 21, inf = 1e18;
// Shared state of the routines below:
//   D    - number of edges from vertex 1 (set in dfs, used by LCA)
//   dep  - weighted distance from vertex 1 (set in dfs)
//   val  - running sum of c[] along the path from vertex 1 (set in dfs)
//   f    - f[u][i] is the 2^i-th ancestor of u, 0 when it does not exist
//   vis  - set to 1 for vertices already chosen as a root by solve()
//   dfn/o/tot - preorder index in the tree built by solve(), its inverse, and the counter (lose)
//   siz/mxsiz - component size / largest piece after removing a vertex (getrt, solve)
//   n, k, m - values read first in main; ans - best value found by calc
//   c    - 1 for the m vertices listed in the input, 0 otherwise
//   cnt, lg - not referenced in the visible code
int D[N], cnt[M], val[N], dep[N], f[N][L], vis[N], dfn[N], o[N], siz[N], mxsiz[N], tot, lg[N], n, k, ans, m, c[N];
// g: weighted adjacency lists filled in main; G: child lists of the tree returned by solve().
vector<pii> g[N]; vector<int> G[N];
// Fenwick tree holding prefix maxima over keys 0..N-2 (stored 1-based).
// Every cell starts at 0, so a query never returns less than 0.
struct BIT {
    int t[N];
    // Raises the stored maximum for key x to at least s.
    void add(int x, int s) {
        for (int i = x + 1; i < N; i += i & -i) {
            if (t[i] < s) t[i] = s;
        }
    }
    // Resets to 0 every cell that add(x, ...) may have written.
    void clear(int x) {
        for (int i = x + 1; i < N; i += i & -i) {
            t[i] = 0;
        }
    }
    // Largest value stored for any key in [0, x], or 0 if there is none above 0.
    int query(int x) {
        int best = 0;
        for (int i = x + 1; i != 0; i -= i & -i) {
            if (best < t[i]) best = t[i];
        }
        return best;
    }
} t;
void dfs(int u, int fa) {
    f[u][0] = fa;
    for (int i = 1; i < L; ++i) f[u][i] = f[f[u][i-1]][i-1];
    for (auto [v, w] : g[u]) {
        if (v == fa) continue;
        D[v] = D[u] + 1; dep[v] = dep[u] + w; val[v] = val[u] + c[v]; dfs(v, u);
    }
}
// Lowest common ancestor of x and y by binary lifting over f, using the edge-count level D.
int LCA(int x, int y) {
    if (D[x] > D[y]) swap(x, y);
    // Bring the deeper vertex y up to x's level, one set bit of the gap at a time.
    const int gap = D[y] - D[x];
    for (int i = 0; i < L; ++i) {
        if ((gap >> i) & 1) y = f[y][i];
    }
    if (x == y) return x;
    // Lift both while their ancestors differ; they end just below the LCA.
    for (int i = L - 1; i >= 0; --i) {
        if (f[x][i] != f[y][i]) {
            x = f[x][i];
            y = f[y][i];
        }
    }
    return f[x][0];
}
// Sum of edge weights on the tree path between x and y.
int dis(int x, int y) {
    const int top = LCA(x, y);
    return dep[x] + dep[y] - 2 * dep[top];
}
// Number of marked vertices (c[] == 1) on the tree path between x and y,
// both endpoints included.
int get(int x, int y) {
    // Look the LCA up once; each lookup is a full binary-lifting walk.
    const int top = LCA(x, y);
    // The val differences count marks strictly below the LCA on each side,
    // so the LCA's own mark is added back once.
    return val[x] + val[y] - 2 * val[top] + c[top];
}
// Scratch state for getrt/solve: sumsiz is the size of the component being split,
// rt/num are the best root found so far and its largest-piece size.
// `now` is not referenced in the visible code.
int now, sumsiz, rt, num;
// Walks the not-yet-removed component containing u, computing siz[] and, for each
// vertex, the largest piece left if it were removed (mxsiz). The vertex with the
// smallest such piece is left in rt, with that size in num.
// Expects sumsiz and num to be set by the caller.
void getrt(int u, int fa) {
    siz[u] = 1;
    mxsiz[u] = 1;
    for (const auto &edge : g[u]) {
        const int nxt = edge.first;
        if (nxt == fa || vis[nxt]) continue;
        getrt(nxt, u);
        siz[u] += siz[nxt];
        if (mxsiz[u] < siz[nxt]) mxsiz[u] = siz[nxt];
    }
    // The piece on the parent's side of u.
    const int above = sumsiz - siz[u];
    if (mxsiz[u] < above) mxsiz[u] = above;
    if (mxsiz[u] < num) {
        num = mxsiz[u];
        rt = u;
    }
}
int solve(int x) {
    sumsiz = siz[x];
    num = 1e9;
    getrt(x, 0); x = rt; vis[x] = 1;
    int nowsiz = sumsiz;
    for (auto [v, w] : g[x]) {
        if (vis[v]) continue;
        if (siz[x] < siz[v]) siz[v] = nowsiz - siz[x];
        int ret = solve(v); G[x].eb(ret);
    } siz[x] = nowsiz;
    return x;
} void lose(int u) {
    dfn[u] = ++tot; o[tot] = u;
    for (int v : G[u]) {
        lose(v);
    }
} void calc(int u) {
    for (int v : G[u]) calc(v);
    for (int v : G[u]) {
        for (int i = dfn[v]; i <= dfn[v] + siz[v] - 1; ++i) {
            int c = get(u, o[i]), d = dis(u, o[i]); if (c <= k) {
                smax(ans, d + t.query(k - c));
            }
        }
        for (int i = dfn[v]; i <= dfn[v] + siz[v] - 1; ++i) {
            int c = get(u, o[i]), d = dis(u, o[i]); if (c <= k) {
                t.add(c, d);
            }
        }
    }
    for (int v : G[u]) {
        for (int i = dfn[v]; i <= dfn[v] + siz[v] - 1; ++i) {
            int c = get(u, o[i]); if (c <= k) {
                t.clear(c);
            }
        }
    }
}
// Reads n, k, m, the m marked vertices and the n-1 weighted edges, then prints ans.
signed main() {
    read(n, k, m);
    for (int left = m; left > 0; --left) {
        const int node = read();
        c[node] = 1;
    }
    for (int e = 1; e < n; ++e) {
        int u, v, w;
        read(u, v, w);
        g[u].emplace_back(v, w);
        g[v].emplace_back(u, w);
    }
    siz[1] = n;  // solve() expects the component size in siz[] of its argument
    dfs(1, 0);
    const int root = solve(1);
    lose(root);
    calc(root);
    write(ans);
    return 0;
}

by Loser_Syx @ 2025-01-10 20:17:11

一个问题是我没考虑 1 号点是黑点的情况,但改了之后还是 WA。


by Loser_Syx @ 2025-01-10 20:25:26

奶龙了,我合并时候把 u 的黑色算了两遍。

现在 T 了。

#include <iostream>
#include <cassert>
#include <queue>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <bitset>
#include <random>
#include <ctime>
#include <map>
#include <set>
using namespace std;
#define int long long
#define pii pair<int, int>
#define eb emplace_back
#define F first
#define S second
#define test(x) cout << "Test: " << (x) << '\n'
#define lowbit(x) (x & -x)
#define debug puts("qwq");
#define open(x) freopen(#x".in", "r", stdin);freopen(#x".out", "w", stdout);
#define close fclose(stdin);fclose(stdout);
// Small stdio-based helpers: integer input/output plus in-place min/max.
namespace FastIO {
    /// Reads one optionally signed decimal integer from stdin and returns it.
    /// Characters before the first digit are skipped; a '-' among them makes
    /// the result negative. If input ends first, the value parsed so far
    /// (0 when no digit was seen) is returned instead of looping forever.
    template <typename T = int>
    inline T read() {
        T s = 0, w = 1;
        // Keep getchar()'s integer result: EOF does not fit in a char, and a
        // negative char must not be handed to isdigit().
        int c = getchar();
        while (c != EOF && !isdigit(c)) {
            if (c == '-') w = -1;
            c = getchar();
        }
        while (c != EOF && isdigit(c)) s = s * 10 + (c - '0'), c = getchar();
        return s * w;
    }
    /// Reads one integer from stdin into `s` (same parsing rules as read()).
    template <typename T>
    inline void read(T &s) {
        s = read<T>();
    }
    /// Reads several integers in order, one per argument.
    template <typename T, typename... Arp> inline void read(T &x, Arp &...arp) {
        read(x), read(arp...);
    }
    /// Prints `x` in decimal followed by `ch` (newline by default).
    template <typename T>
    inline void write(T x, char ch = '\n') {
        // Digits are taken from the value as-is and negated one by one, so the
        // most negative value of T is printed correctly (-x would overflow).
        const bool negative = x < 0;
        char stk[40];
        int top = 0;
        do {
            int digit = x % 10;
            stk[top++] = static_cast<char>('0' + (negative ? -digit : digit));
            x /= 10;
        } while (x);
        if (negative) putchar('-');
        while (top) putchar(stk[--top]);
        putchar(ch);
    }
    /// x = max(x, y)
    template <typename T>
    inline void smax(T &x, T y) {
        if (x < y) x = y;
    }
    /// x = min(x, y)
    template <typename T>
    inline void smin(T &x, T y) {
        if (x > y) x = y;
    }
    /// Terminates the program with exit status 0.
    void quit() {
        exit(0);
    }
} using namespace FastIO;
// N / M: array capacities; L: number of levels in the ancestor table f;
// inf: large constant that is not referenced in the visible code.
const int N = 2e5 + 19, M = 1e6 + 19, L = 21, inf = 1e18;
// Shared state of the routines below:
//   D    - number of edges from vertex 1 (set in dfs, used by LCA)
//   dep  - weighted distance from vertex 1 (set in dfs)
//   val  - number of vertices with c[] == 1 on the path from vertex 1, inclusive (set in dfs)
//   f    - f[u][i] is the 2^i-th ancestor of u, 0 when it does not exist
//   vis  - set to 1 for vertices already chosen as a root by solve()
//   dfn/o/tot - preorder index in the tree built by solve(), its inverse, and the counter (lose)
//   siz/mxsiz - component size / largest piece after removing a vertex (getrt, solve)
//   n, k, m - values read first in main; ans - best value found by calc
//   c    - 1 for the m vertices listed in the input, 0 otherwise
//   A, B - per-vertex results of get()/dis() cached by calc for the vertex being processed
//   cnt, lg - not referenced in the visible code
int D[N], cnt[M], val[N], dep[N], f[N][L], vis[N], dfn[N], o[N], siz[N], mxsiz[N], tot, lg[N], n, k, ans, m, c[N], A[N], B[N];
// g: weighted adjacency lists filled in main; G: child lists of the tree returned by solve().
vector<pii> g[N]; vector<int> G[N];
// Fenwick tree holding prefix maxima over keys 0..N-2 (stored 1-based).
// Every cell starts at 0, so a query never returns less than 0.
struct BIT {
    int t[N];
    // Raises the stored maximum for key x to at least s.
    void add(int x, int s) {
        for (int i = x + 1; i < N; i += i & -i) {
            if (t[i] < s) t[i] = s;
        }
    }
    // Resets to 0 every cell that add(x, ...) may have written.
    void clear(int x) {
        for (int i = x + 1; i < N; i += i & -i) {
            t[i] = 0;
        }
    }
    // Largest value stored for any key in [0, x], or 0 if there is none above 0.
    int query(int x) {
        int best = 0;
        for (int i = x + 1; i != 0; i -= i & -i) {
            if (best < t[i]) best = t[i];
        }
        return best;
    }
} t;
void dfs(int u, int fa) {
    f[u][0] = fa; val[u] = val[fa] + c[u];
    for (int i = 1; i < L; ++i) f[u][i] = f[f[u][i-1]][i-1];
    for (auto [v, w] : g[u]) {
        if (v == fa) continue;
        D[v] = D[u] + 1; dep[v] = dep[u] + w; dfs(v, u);
    }
}
// Lowest common ancestor of x and y by binary lifting over f, using the edge-count level D.
int LCA(int x, int y) {
    if (D[x] > D[y]) swap(x, y);
    // Bring the deeper vertex y up to x's level, one set bit of the gap at a time.
    const int gap = D[y] - D[x];
    for (int i = 0; i < L; ++i) {
        if ((gap >> i) & 1) y = f[y][i];
    }
    if (x == y) return x;
    // Lift both while their ancestors differ; they end just below the LCA.
    for (int i = L - 1; i >= 0; --i) {
        if (f[x][i] != f[y][i]) {
            x = f[x][i];
            y = f[y][i];
        }
    }
    return f[x][0];
}
// Sum of edge weights on the tree path between x and y.
int dis(int x, int y) {
    const int top = LCA(x, y);
    return dep[x] + dep[y] - 2 * dep[top];
}
// Number of marked vertices (c[] == 1) on the tree path between x and y,
// both endpoints included.
int get(int x, int y) {
    const int top = LCA(x, y);
    // Marks on both sides of the LCA via val, then the LCA's own mark once.
    const int total = val[x] + val[y] - 2 * val[top];
    return total + c[top];
}
// Scratch state for getrt/solve: sumsiz is the size of the component being split,
// rt/num are the best root found so far and its largest-piece size.
// `now` is not referenced in the visible code.
int now, sumsiz, rt, num;
// Walks the not-yet-removed component containing u, computing siz[] and, for each
// vertex, the largest piece left if it were removed (mxsiz). The vertex with the
// smallest such piece is left in rt, with that size in num.
// Expects sumsiz and num to be set by the caller.
void getrt(int u, int fa) {
    siz[u] = 1;
    mxsiz[u] = 1;
    for (const auto &edge : g[u]) {
        const int nxt = edge.first;
        if (nxt == fa || vis[nxt]) continue;
        getrt(nxt, u);
        siz[u] += siz[nxt];
        if (mxsiz[u] < siz[nxt]) mxsiz[u] = siz[nxt];
    }
    // The piece on the parent's side of u.
    const int above = sumsiz - siz[u];
    if (mxsiz[u] < above) mxsiz[u] = above;
    if (mxsiz[u] < num) {
        num = mxsiz[u];
        rt = u;
    }
}
int solve(int x) {
    sumsiz = siz[x];
    num = 1e9;
    getrt(x, 0); x = rt; vis[x] = 1;
    int nowsiz = sumsiz;
    for (auto [v, w] : g[x]) {
        if (vis[v]) continue;
        if (siz[x] < siz[v]) siz[v] = nowsiz - siz[x];
        int ret = solve(v); G[x].eb(ret);
    } siz[x] = nowsiz;
    return x;
} void lose(int u) {
    dfn[u] = ++tot; o[tot] = u;
    for (int v : G[u]) {
        lose(v);
    }
} void calc(int u) {
    for (int v : G[u]) calc(v);
    for (int v : G[u]) {
        for (int i = dfn[v]; i <= dfn[v] + siz[v] - 1; ++i) {
            A[o[i]] = get(u, o[i]); B[o[i]] = dis(u, o[i]);
            int e = A[o[i]] - c[u], d = B[o[i]]; if (e <= k) {
                smax(ans, d + t.query(k - e));
            }
        }
        for (int i = dfn[v]; i <= dfn[v] + siz[v] - 1; ++i) {
            int c = A[o[i]], d = B[o[i]]; if (c <= k) {
                t.add(c, d);
            }
        }
    }
    for (int v : G[u]) {
        for (int i = dfn[v]; i <= dfn[v] + siz[v] - 1; ++i) {
            int c = A[o[i]]; if (c <= k) {
                t.clear(c);
            }
        }
    }
}
// Reads n, k, m, the m marked vertices and the n-1 weighted edges, then prints ans.
signed main() {
    read(n, k, m);
    for (int left = m; left > 0; --left) {
        const int node = read();
        c[node] = 1;
    }
    for (int e = 1; e < n; ++e) {
        int u, v, w;
        read(u, v, w);
        g[u].emplace_back(v, w);
        g[v].emplace_back(u, w);
    }
    siz[1] = n;  // solve() expects the component size in siz[] of its argument
    dfs(1, 0);
    const int root = solve(1);
    lose(root);
    calc(root);
    write(ans);
    return 0;
}

by loser_seele @ 2025-01-10 20:25:29

糖果。


|