Loser_Syx @ 2025-01-10 19:54:21
#include <iostream>
#include <cassert>
#include <queue>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <bitset>
#include <random>
#include <ctime>
#include <map>
#include <set>
using namespace std;
#define int long long
#define pii pair<int, int>
#define eb emplace_back
#define F first
#define S second
#define test(x) cout << "Test: " << (x) << '\n'
#define lowbit(x) (x & -x)
#define debug puts("qwq");
#define open(x) freopen(#x".in", "r", stdin);freopen(#x".out", "w", stdout);
#define close fclose(stdin);fclose(stdout);
namespace FastIO {
template <typename T = int>
inline T read() {
T s = 0, w = 1;
char c = getchar();
while (!isdigit(c)) {
if (c == '-') w = -1;
c = getchar();
}
while (isdigit(c)) s = (s << 1) + (s << 3) + (c ^ 48), c = getchar();
return s * w;
}
template <typename T>
inline void read(T &s) {
s = 0;
int w = 1;
char c = getchar();
while (!isdigit(c)) {
if (c == '-') w = -1;
c = getchar();
}
while (isdigit(c)) s = (s << 1) + (s << 3) + (c ^ 48), c = getchar();
s = s * w;
}
template <typename T, typename... Arp> inline void read(T &x, Arp &...arp) {
read(x), read(arp...);
}
template <typename T>
inline void write(T x, char ch = '\n') {
if (x < 0) x = -x, putchar('-');
static char stk[25];
int top = 0;
do {
stk[top++] = x % 10 + '0', x /= 10;
} while (x);
while (top) putchar(stk[--top]);
putchar(ch);
return;
}
template <typename T>
inline void smax(T &x, T y) {
if (x < y) x = y;
}
template <typename T>
inline void smin(T &x, T y) {
if (x > y) x = y;
}
void quit() {
exit(0);
}
} using namespace FastIO;
const int N = 2e5 + 19, M = 1e6 + 19, L = 21, inf = 1e18;
int cnt[M], val[N], dep[N], f[N][L], vis[N], dfn[N], o[N], siz[N], mxsiz[N], tot, lg[N], n, k, ans, m, c[N];
vector<pii> g[N]; vector<int> G[N];
struct BIT {
int t[N];
void add(int x, int s) {
++x;
while (x < N) {
smax(t[x], s);
x += lowbit(x);
}
}
void clear(int x) {
++x;
while (x < N) {
t[x] = 0; x += lowbit(x);
}
}
int query(int x) {
++x; int ans = 0;
while (x) {
smax(ans, t[x]);
x -= lowbit(x);
} return ans;
}
} t;
void dfs(int u, int fa) {
f[u][0] = fa;
for (int i = 1; i < L; ++i) f[u][i] = f[f[u][i-1]][i-1];
for (auto [v, w] : g[u]) {
if (v == fa) continue;
dep[v] = dep[u] + w; val[v] = val[u] + c[v]; dfs(v, u);
}
}
int LCA(int x, int y) {
if (dep[x] > dep[y]) swap(x, y);
for (int s = dep[y] - dep[x], i = 0; i < L; ++i) if (s >> i & 1) y = f[y][i];
if (x == y) return x;
for (int i = L-1; ~i; --i) if (f[x][i] ^ f[y][i]) x = f[x][i], y = f[y][i];
return f[x][0];
}
int dis(int x, int y) { return dep[x] + dep[y] - 2 * dep[LCA(x, y)]; }
int get(int x, int y) { return val[x] + val[y] - 2 * val[LCA(x, y)]; }
int now, sumsiz, rt, num;
void getrt(int u, int fa) {
siz[u] = mxsiz[u] = 1;
for (auto [v, w] : g[u]) {
if (v == fa || vis[v]) continue;
getrt(v, u);
siz[u] += siz[v]; smax(mxsiz[u], siz[v]);
} smax(mxsiz[u], sumsiz - siz[u]);
if (mxsiz[u] < num) {
rt = u; num = mxsiz[u];
}
}
int solve(int x) {
sumsiz = siz[x];
num = 1e9;
getrt(x, 0); x = rt; vis[x] = 1;
int nowsiz = sumsiz;
for (auto [v, w] : g[x]) {
if (vis[v]) continue;
if (siz[x] < siz[v]) siz[v] = nowsiz - siz[x];
int ret = solve(v); G[x].eb(ret);
} siz[x] = nowsiz;
return x;
} void lose(int u) {
dfn[u] = ++tot; o[tot] = u;
for (int v : G[u]) {
lose(v);
}
} void calc(int u) {
for (int v : G[u]) calc(v);
for (int v : G[u]) {
for (int i = dfn[v]; i <= dfn[v] + siz[v] - 1; ++i) {
int c = get(u, o[i]), d = dis(u, o[i]); if (c <= k) {
smax(ans, d + t.query(k - c));
}
}
for (int i = dfn[v]; i <= dfn[v] + siz[v] - 1; ++i) {
int c = get(u, o[i]), d = dis(u, o[i]); if (c <= k) {
t.add(c, d);
}
}
}
for (int v : G[u]) {
for (int i = dfn[v]; i <= dfn[v] + siz[v] - 1; ++i) {
int c = get(u, o[i]); if (c <= k) {
t.clear(c);
}
}
}
}
signed main() {
read(n, k, m); for (int i = 1; i <= m; ++i) c[read()] = 1;
for (int i = 1, u, v, w; i < n; ++i) {
read(u, v, w);
g[u].eb(v, w); g[v].eb(u, w);
} siz[1] = n; dfs(1, 0); int rt = solve(1); lose(rt); calc(rt); write(ans);
return 0;
}
by Hoks @ 2025-01-10 20:04:15
@Loser_Syx 呜呜呜怎么背着我卷
by Loser_Syx @ 2025-01-10 20:04:34
@Hoks 你别急,我的 LCA 写成奶龙了。
by Loser_Syx @ 2025-01-10 20:06:51
update:
#include <iostream>
#include <cassert>
#include <queue>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <bitset>
#include <random>
#include <ctime>
#include <map>
#include <set>
using namespace std;
#define int long long
#define pii pair<int, int>
#define eb emplace_back
#define F first
#define S second
#define test(x) cout << "Test: " << (x) << '\n'
#define lowbit(x) (x & -x)
#define debug puts("qwq");
#define open(x) freopen(#x".in", "r", stdin);freopen(#x".out", "w", stdout);
#define close fclose(stdin);fclose(stdout);
namespace FastIO {
template <typename T = int>
inline T read() {
T s = 0, w = 1;
char c = getchar();
while (!isdigit(c)) {
if (c == '-') w = -1;
c = getchar();
}
while (isdigit(c)) s = (s << 1) + (s << 3) + (c ^ 48), c = getchar();
return s * w;
}
template <typename T>
inline void read(T &s) {
s = 0;
int w = 1;
char c = getchar();
while (!isdigit(c)) {
if (c == '-') w = -1;
c = getchar();
}
while (isdigit(c)) s = (s << 1) + (s << 3) + (c ^ 48), c = getchar();
s = s * w;
}
template <typename T, typename... Arp> inline void read(T &x, Arp &...arp) {
read(x), read(arp...);
}
template <typename T>
inline void write(T x, char ch = '\n') {
if (x < 0) x = -x, putchar('-');
static char stk[25];
int top = 0;
do {
stk[top++] = x % 10 + '0', x /= 10;
} while (x);
while (top) putchar(stk[--top]);
putchar(ch);
return;
}
template <typename T>
inline void smax(T &x, T y) {
if (x < y) x = y;
}
template <typename T>
inline void smin(T &x, T y) {
if (x > y) x = y;
}
void quit() {
exit(0);
}
} using namespace FastIO;
const int N = 2e5 + 19, M = 1e6 + 19, L = 21, inf = 1e18;
int D[N], cnt[M], val[N], dep[N], f[N][L], vis[N], dfn[N], o[N], siz[N], mxsiz[N], tot, lg[N], n, k, ans, m, c[N];
vector<pii> g[N]; vector<int> G[N];
struct BIT {
int t[N];
void add(int x, int s) {
++x;
while (x < N) {
smax(t[x], s);
x += lowbit(x);
}
}
void clear(int x) {
++x;
while (x < N) {
t[x] = 0; x += lowbit(x);
}
}
int query(int x) {
++x; int ans = 0;
while (x) {
smax(ans, t[x]);
x -= lowbit(x);
} return ans;
}
} t;
void dfs(int u, int fa) {
f[u][0] = fa;
for (int i = 1; i < L; ++i) f[u][i] = f[f[u][i-1]][i-1];
for (auto [v, w] : g[u]) {
if (v == fa) continue;
D[v] = D[u] + 1; dep[v] = dep[u] + w; val[v] = val[u] + c[v]; dfs(v, u);
}
}
int LCA(int x, int y) {
if (D[x] > D[y]) swap(x, y);
for (int s = D[y] - D[x], i = 0; i < L; ++i) if (s >> i & 1) y = f[y][i];
if (x == y) return x;
for (int i = L-1; ~i; --i) if (f[x][i] ^ f[y][i]) x = f[x][i], y = f[y][i];
return f[x][0];
}
int dis(int x, int y) { return dep[x] + dep[y] - 2 * dep[LCA(x, y)]; }
int get(int x, int y) { return val[x] + val[y] - 2 * val[LCA(x, y)] + c[LCA(x, y)]; }
int now, sumsiz, rt, num;
void getrt(int u, int fa) {
siz[u] = mxsiz[u] = 1;
for (auto [v, w] : g[u]) {
if (v == fa || vis[v]) continue;
getrt(v, u);
siz[u] += siz[v]; smax(mxsiz[u], siz[v]);
} smax(mxsiz[u], sumsiz - siz[u]);
if (mxsiz[u] < num) {
rt = u; num = mxsiz[u];
}
}
int solve(int x) {
sumsiz = siz[x];
num = 1e9;
getrt(x, 0); x = rt; vis[x] = 1;
int nowsiz = sumsiz;
for (auto [v, w] : g[x]) {
if (vis[v]) continue;
if (siz[x] < siz[v]) siz[v] = nowsiz - siz[x];
int ret = solve(v); G[x].eb(ret);
} siz[x] = nowsiz;
return x;
} void lose(int u) {
dfn[u] = ++tot; o[tot] = u;
for (int v : G[u]) {
lose(v);
}
} void calc(int u) {
for (int v : G[u]) calc(v);
for (int v : G[u]) {
for (int i = dfn[v]; i <= dfn[v] + siz[v] - 1; ++i) {
int c = get(u, o[i]), d = dis(u, o[i]); if (c <= k) {
smax(ans, d + t.query(k - c));
}
}
for (int i = dfn[v]; i <= dfn[v] + siz[v] - 1; ++i) {
int c = get(u, o[i]), d = dis(u, o[i]); if (c <= k) {
t.add(c, d);
}
}
}
for (int v : G[u]) {
for (int i = dfn[v]; i <= dfn[v] + siz[v] - 1; ++i) {
int c = get(u, o[i]); if (c <= k) {
t.clear(c);
}
}
}
}
signed main() {
read(n, k, m); for (int i = 1; i <= m; ++i) c[read()] = 1;
for (int i = 1, u, v, w; i < n; ++i) {
read(u, v, w);
g[u].eb(v, w); g[v].eb(u, w);
} siz[1] = n; dfs(1, 0); int rt = solve(1); lose(rt); calc(rt); write(ans);
return 0;
}
by Loser_Syx @ 2025-01-10 20:17:11
一个问题是我 1 是黑点没考虑,但还是 WA。
by Loser_Syx @ 2025-01-10 20:25:26
奶龙了,我合并时候把 u
的黑色算了两遍。
现在 T 了。
#include <iostream>
#include <cassert>
#include <queue>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <bitset>
#include <random>
#include <ctime>
#include <map>
#include <set>
using namespace std;
#define int long long
#define pii pair<int, int>
#define eb emplace_back
#define F first
#define S second
#define test(x) cout << "Test: " << (x) << '\n'
#define lowbit(x) (x & -x)
#define debug puts("qwq");
#define open(x) freopen(#x".in", "r", stdin);freopen(#x".out", "w", stdout);
#define close fclose(stdin);fclose(stdout);
namespace FastIO {
template <typename T = int>
inline T read() {
T s = 0, w = 1;
char c = getchar();
while (!isdigit(c)) {
if (c == '-') w = -1;
c = getchar();
}
while (isdigit(c)) s = (s << 1) + (s << 3) + (c ^ 48), c = getchar();
return s * w;
}
template <typename T>
inline void read(T &s) {
s = 0;
int w = 1;
char c = getchar();
while (!isdigit(c)) {
if (c == '-') w = -1;
c = getchar();
}
while (isdigit(c)) s = (s << 1) + (s << 3) + (c ^ 48), c = getchar();
s = s * w;
}
template <typename T, typename... Arp> inline void read(T &x, Arp &...arp) {
read(x), read(arp...);
}
template <typename T>
inline void write(T x, char ch = '\n') {
if (x < 0) x = -x, putchar('-');
static char stk[25];
int top = 0;
do {
stk[top++] = x % 10 + '0', x /= 10;
} while (x);
while (top) putchar(stk[--top]);
putchar(ch);
return;
}
template <typename T>
inline void smax(T &x, T y) {
if (x < y) x = y;
}
template <typename T>
inline void smin(T &x, T y) {
if (x > y) x = y;
}
void quit() {
exit(0);
}
} using namespace FastIO;
const int N = 2e5 + 19, M = 1e6 + 19, L = 21, inf = 1e18;
int D[N], cnt[M], val[N], dep[N], f[N][L], vis[N], dfn[N], o[N], siz[N], mxsiz[N], tot, lg[N], n, k, ans, m, c[N], A[N], B[N];
vector<pii> g[N]; vector<int> G[N];
struct BIT {
int t[N];
void add(int x, int s) {
++x;
while (x < N) {
smax(t[x], s);
x += lowbit(x);
}
}
void clear(int x) {
++x;
while (x < N) {
t[x] = 0; x += lowbit(x);
}
}
int query(int x) {
++x; int ans = 0;
while (x) {
smax(ans, t[x]);
x -= lowbit(x);
} return ans;
}
} t;
void dfs(int u, int fa) {
f[u][0] = fa; val[u] = val[fa] + c[u];
for (int i = 1; i < L; ++i) f[u][i] = f[f[u][i-1]][i-1];
for (auto [v, w] : g[u]) {
if (v == fa) continue;
D[v] = D[u] + 1; dep[v] = dep[u] + w; dfs(v, u);
}
}
int LCA(int x, int y) {
if (D[x] > D[y]) swap(x, y);
for (int s = D[y] - D[x], i = 0; i < L; ++i) if (s >> i & 1) y = f[y][i];
if (x == y) return x;
for (int i = L-1; ~i; --i) if (f[x][i] ^ f[y][i]) x = f[x][i], y = f[y][i];
return f[x][0];
}
int dis(int x, int y) { return dep[x] + dep[y] - 2 * dep[LCA(x, y)]; }
int get(int x, int y) {
int L = LCA(x, y);
return val[x] + val[y] - 2 * val[L] + c[L];
}
int now, sumsiz, rt, num;
void getrt(int u, int fa) {
siz[u] = mxsiz[u] = 1;
for (auto [v, w] : g[u]) {
if (v == fa || vis[v]) continue;
getrt(v, u);
siz[u] += siz[v]; smax(mxsiz[u], siz[v]);
} smax(mxsiz[u], sumsiz - siz[u]);
if (mxsiz[u] < num) {
rt = u; num = mxsiz[u];
}
}
int solve(int x) {
sumsiz = siz[x];
num = 1e9;
getrt(x, 0); x = rt; vis[x] = 1;
int nowsiz = sumsiz;
for (auto [v, w] : g[x]) {
if (vis[v]) continue;
if (siz[x] < siz[v]) siz[v] = nowsiz - siz[x];
int ret = solve(v); G[x].eb(ret);
} siz[x] = nowsiz;
return x;
} void lose(int u) {
dfn[u] = ++tot; o[tot] = u;
for (int v : G[u]) {
lose(v);
}
} void calc(int u) {
for (int v : G[u]) calc(v);
for (int v : G[u]) {
for (int i = dfn[v]; i <= dfn[v] + siz[v] - 1; ++i) {
A[o[i]] = get(u, o[i]); B[o[i]] = dis(u, o[i]);
int e = A[o[i]] - c[u], d = B[o[i]]; if (e <= k) {
smax(ans, d + t.query(k - e));
}
}
for (int i = dfn[v]; i <= dfn[v] + siz[v] - 1; ++i) {
int c = A[o[i]], d = B[o[i]]; if (c <= k) {
t.add(c, d);
}
}
}
for (int v : G[u]) {
for (int i = dfn[v]; i <= dfn[v] + siz[v] - 1; ++i) {
int c = A[o[i]]; if (c <= k) {
t.clear(c);
}
}
}
}
signed main() {
read(n, k, m); for (int i = 1; i <= m; ++i) c[read()] = 1;
for (int i = 1, u, v, w; i < n; ++i) {
read(u, v, w);
g[u].eb(v, w); g[v].eb(u, w);
} siz[1] = n; dfs(1, 0); int rt = solve(1); lose(rt); calc(rt); write(ans);
return 0;
}
by loser_seele @ 2025-01-10 20:25:29
糖果。