ShiRoZeTsu @ 2023-09-05 20:52:57
从第 #6 个测试点开始就过不去了，不知道为什么。
还有，我为什么要用线段树合并来写这道题啊()
#include <iostream>
#include <cstdio>
using namespace std;
// Capacity for every per-vertex array; the edge pool below gets twice as many
// slots because each undirected tree edge is stored in both directions.
constexpr int maxn = 600005;

// One adjacency-list record: the vertex it points to and the index of the next
// record leaving the same source vertex (0 terminates the list).
struct edge {
    int to;
    int nxt;
};
edge e[2 * maxn];

// Index of the most recently used edge slot. Starting at 1 means the first
// stored edge gets index 2; index 0 is never used and marks end-of-list.
int tot = 1;
// head[u]: index of the first edge record leaving u, or 0 if u has none.
int head[maxn];
// Prepend the directed edge u -> v to u's adjacency list.
void add(int u, int v) {
    const int id = ++tot;  // claim the next free edge slot
    e[id].to = v;
    e[id].nxt = head[u];   // link to the previous list head
    head[u] = id;
}
int n;             // number of tree vertices (vertices are 1..n)
int m;             // number of (start, end) paths read in main
int dep[maxn];     // depth of each vertex: vertex 1 has depth 1, sentinel vertex 0 has depth 0
int ans[maxn];     // ans[j]: number of paths through j whose distance from their start to j equals w[j]
int w[maxn];       // per-vertex value read from input, matched against that distance in dfs2
int st[25][maxn];  // binary lifting: st[k][u] is u's 2^k-th ancestor (0 above the root / unfilled)
// Root the tree: record each vertex's depth and fill its binary-lifting row.
// Parents are visited before children, so every ancestor row this vertex
// reads from st[][] is already complete.
void dfs1(int node, int parent) {
    dep[node] = dep[parent] + 1;
    st[0][node] = parent;
    // Only levels whose jump does not overshoot depth 0 are filled.
    for (int k = 1; (1 << k) <= dep[node]; ++k) {
        const int halfway = st[k - 1][node];
        st[k][node] = st[k - 1][halfway];
    }
    for (int id = head[node]; id != 0; id = e[id].nxt) {
        const int child = e[id].to;
        if (child == parent) continue;  // don't walk back up the tree edge
        dfs1(child, node);
    }
}
// Lowest common ancestor of x and y using dep[] and the st[][] lifting table.
int lca(int x, int y) {
    if (dep[x] < dep[y]) swap(x, y);  // make x the deeper vertex
    // Lift x by exactly the depth gap, one set bit at a time (high bits first).
    const int gap = dep[x] - dep[y];
    for (int k = 22; k >= 0; --k)
        if ((gap >> k) & 1) x = st[k][x];
    if (x == y) return x;  // y was an ancestor of x
    // Raise both in lockstep while their ancestors still differ.
    for (int k = 22; k >= 0; --k) {
        if ((1 << k) > dep[x]) continue;
        if (st[k][x] == st[k][y]) continue;
        x = st[k][x];
        y = st[k][y];
    }
    return st[0][x];
}
// Pool-backed segment tree over a key range [l, r] (the callers use [1, 2n]),
// holding one tree per vertex in root[]. Supports point add, point query and a
// destructive merge that consumes the second tree.
struct seg {
    int cnt, top;               // cnt: highest node id handed out; top: entries in the free list
    int root[maxn], stk[maxn];  // root[u]: tree of vertex u (0 = empty); stk[1..top]: recycled ids
    struct node {
        int ls, rs, v;          // child ids (0 = none) and, at leaves, the accumulated count
    } t[maxn<<5];

    // Return an unused, zeroed node id, preferring a recycled one.
    int newnode() {
        if (top) return stk[top--];
        return ++cnt;
    }

    // Zero node x and offer it for reuse.
    void del(int x) {
        t[x].ls = t[x].rs = t[x].v = 0;
        // BUG FIX: merges can release up to `cnt` nodes, which far exceeds the
        // maxn - 1 usable slots of stk. The former unconditional `stk[++top] = x`
        // wrote past the end of stk into t[], corrupting live tree nodes on large
        // inputs. When the free list is full, the node is simply abandoned; it
        // is already unlinked (merge returns o, never p) and zeroed above.
        if (top < maxn - 1) stk[++top] = x;
    }

    // Add val to the leaf for key pos in the tree rooted at o, creating nodes on demand.
    void modify(int& o, int l, int r, int pos, int val) {
        if (!o) o = newnode();
        if (l == r) {
            t[o].v += val;
            return;
        }
        const int mid = (l + r) >> 1;
        if (pos <= mid) modify(t[o].ls, l, mid, pos, val);
        else modify(t[o].rs, mid + 1, r, pos, val);
    }

    // Merge tree p into tree o (both covering [l, r]) and return the merged root.
    // p's nodes are either adopted into o or released through del().
    int merge(int o, int p, int l, int r) {
        if (!o || !p) return o + p;  // one side empty: reuse the other subtree unchanged
        if (l == r) {
            t[o].v += t[p].v;
            del(p);
            return o;
        }
        const int mid = (l + r) >> 1;
        t[o].ls = merge(t[o].ls, t[p].ls, l, mid);
        t[o].rs = merge(t[o].rs, t[p].rs, mid + 1, r);
        del(p);
        return o;
    }

    // Value stored at key pos in the tree rooted at o; 0 if that leaf was never created.
    int query(int o, int l, int r, int pos) {
        if (!o) return 0;
        if (l == r) return t[o].v;
        const int mid = (l + r) >> 1;
        if (pos <= mid) return query(t[o].ls, l, mid, pos);
        return query(t[o].rs, mid + 1, r, pos);
    }
} a, b;
// Post-order pass: fold each child's difference trees into u's trees, then
// read the two counters that apply to u itself.
void dfs2(int u, int fa) {
    const int hi = n << 1;  // both trees index keys in [1, 2n]
    for (int id = head[u]; id != 0; id = e[id].nxt) {
        const int child = e[id].to;
        if (child == fa) continue;
        dfs2(child, u);
        a.root[u] = a.merge(a.root[u], a.root[child], 1, hi);
        b.root[u] = b.merge(b.root[u], b.root[child], 1, hi);
    }
    // Tree a holds upward halves keyed by dep[start]: u matches when
    // dep[start] == dep[u] + w[u].
    const int upward = a.query(a.root[u], 1, hi, dep[u] + w[u]);
    // Tree b holds downward halves keyed by n - (dep[start] - dep[lca]) + dep[lca]
    // (see main): u matches when that key == dep[u] + n - w[u].
    const int downward = b.query(b.root[u], 1, hi, dep[u] + n - w[u]);
    ans[u] += upward;
    ans[u] += downward;
}
int main() {
scanf("%d %d", &n, &m);
for(int i = 1; i < n; i++) {
int u, v;
scanf("%d %d", &u, &v);
add(u, v); add(v, u);
}
dfs1(1, 0);
for(int i = 1; i <= n; i++)
scanf("%d", &w[i]);
for(int i = 1; i <= m; i++) {
int u, v;
scanf("%d %d", &u, &v);
int zx = lca(u, v);
a.modify(a.root[u], 1, n<<1, dep[u], 1);
b.modify(b.root[v], 1, n<<1, n-(dep[u]-dep[zx])+dep[zx], 1);
a.modify(a.root[zx], 1, n<<1, dep[u], -1);
b.modify(b.root[st[0][zx]], 1, n<<1, n-(dep[u]-dep[zx])+dep[zx], -1);
}
dfs2(1, 0);
for(int i = 1; i <= n; i++)
printf("%d ", ans[i]);
return 0;
}