#25 线段树合并求调

P1600 [NOIP2016 提高组] 天天爱跑步

ShiRoZeTsu @ 2023-09-05 20:52:57

从测试点 #6 开始就过不去了，不知道为什么。

还有，我为什么要用线段树合并写这道题啊()

#include <iostream>
#include <cstdio>
using namespace std;
const int maxn = 600000 + 5;  // bound used for every per-vertex / per-edge array

// Forward-star adjacency list: head[u] is the index of the last edge added
// out of u, e[i].nxt links to the edge added before it, and index 0 ends
// the chain.
struct edge {
    int to;   // endpoint of this directed edge
    int nxt;  // previous edge out of the same vertex (0 = none)
} e[maxn<<1];

// tot starts at 1, so the first stored edge gets index 2 and 0 stays free
// to act as the end-of-list sentinel.
int tot = 1;
int head[maxn];

// Append the directed edge u -> v to u's list.
void add(int u, int v) {
    const int id = ++tot;
    e[id].nxt = head[u];
    e[id].to = v;
    head[u] = id;
}

int n, m;          // vertex count and number of (u, v) paths read in main
int dep[maxn];     // depth of each vertex; dfs1 gives the root depth 1
int ans[maxn];     // per-vertex answer accumulated in dfs2
int w[maxn];       // per-vertex value read in main and used in dfs2's queries
int st[25][maxn];  // st[k][u] = 2^k-th ancestor of u (0 when it does not exist)

// Recursively root the subtree of u under parent fa: record u's depth and
// fill its binary-lifting row st[*][u] from ancestors already visited.
void dfs1(int u, int fa) {
    st[0][u] = fa;
    dep[u] = dep[fa] + 1;
    // Only levels whose jump stays within the path to the root are filled.
    for(int k = 1; (1 << k) <= dep[u]; ++k) {
        const int half = st[k-1][u];
        st[k][u] = st[k-1][half];
    }
    for(int id = head[u]; id != 0; id = e[id].nxt) {
        const int to = e[id].to;
        if(to == fa) continue;
        dfs1(to, u);
    }
}

// Lowest common ancestor of x and y via binary lifting over st[][] / dep[].
int lca(int x, int y) {
    if(dep[x] < dep[y]) swap(x, y);
    // Lift the deeper vertex to y's depth; once equal, no further jump fits.
    for(int k = 22; k >= 0; --k) {
        if(dep[x] - (1 << k) >= dep[y]) x = st[k][x];
    }
    if(x == y) return x;
    // Lift both while their ancestors differ; they end as children of the LCA.
    for(int k = 22; k >= 0; --k) {
        if((1 << k) > dep[x]) continue;
        if(st[k][x] == st[k][y]) continue;
        x = st[k][x];
        y = st[k][y];
    }
    return st[0][x];
}

#define mid ((l + r) >> 1)

// Dynamically allocated segment trees over the index range [l, r] supplied
// by the caller, supporting point add, point query and tree merging.
// Node 0 is the null node; t[1..cnt] are nodes handed out by newnode().
//
// Fix: the previous version stored recycled nodes in a separate `stk[maxn]`.
// merge() frees one node for every node the two trees share, and over a
// whole dfs2 that total is bounded by the number of allocated nodes (up to
// maxn<<5), not by maxn. Nothing is popped during the merges, so
// `stk[++top]` ran past the end of the array and overwrote the live nodes of
// t[] stored right after it. Recycled nodes are now chained through their
// own `ls` field (an intrusive free list), which needs no extra storage and
// cannot overflow.
struct seg {
    int cnt;         // highest node index ever handed out by newnode()
    int top;         // head of the free list of recycled nodes (0 = empty)
    int root[maxn];  // root[u] = root node of vertex u's tree (0 = empty)

    struct node {
        int ls, rs;  // children (0 = none); ls doubles as the free-list link
        int v;       // accumulated value, only maintained at leaves
    } t[maxn<<5];

    // Return a node with all fields zero, reusing a recycled one if possible.
    int newnode() {
        if(top) {
            int x = top;
            top = t[x].ls;  // pop the free list
            t[x].ls = 0;    // rs and v were already cleared by del()
            return x;
        }
        return ++cnt;
    }

    // Clear node x and push it onto the free list.
    void del(int x) {
        t[x].rs = t[x].v = 0;
        t[x].ls = top;
        top = x;
    }

    // Add val at position pos in the tree rooted at o (creating nodes as needed).
    void modify(int& o, int l, int r, int pos, int val) {
        if(!o) o = newnode();
        if(l == r) {
            t[o].v += val;
            return;
        }
        if(pos <= mid) modify(t[o].ls, l, mid, pos, val);
        else modify(t[o].rs, mid+1, r, pos, val);
    }

    // Merge tree p into tree o and return the merged root. Nodes of p that
    // overlap o are freed; non-overlapping subtrees of p are adopted as-is.
    int merge(int o, int p, int l, int r) {
        if(!o || !p) return o+p;
        if(l == r) {
            t[o].v += t[p].v;
            del(p);
            return o;
        }
        // Children of p are read here, before del(p) reuses its ls field.
        t[o].ls = merge(t[o].ls, t[p].ls, l, mid);
        t[o].rs = merge(t[o].rs, t[p].rs, mid+1, r);
        del(p);
        return o;
    }

    // Value stored at position pos in the tree rooted at o (0 if absent).
    int query(int o, int l, int r, int pos) {
        if(!o) return 0;
        if(l == r) return t[o].v;
        if(pos <= mid) return query(t[o].ls, l, mid, pos);
        else return query(t[o].rs, mid+1, r, pos);
    }
} a, b;

// Post-order pass: merge every child's trees into u's, then read u's answer.
// Tree `a` holds the keys dep[s] that main adds for each path; tree `b` holds
// the keys n - (dep[s] - dep[lca]) + dep[lca]. u queries a at dep[u] + w[u]
// and b at dep[u] + n - w[u].
void dfs2(int u, int fa) {
    const int hi = n << 1;
    for(int id = head[u]; id; id = e[id].nxt) {
        const int v = e[id].to;
        if(v != fa) {
            dfs2(v, u);
            a.root[u] = a.merge(a.root[u], a.root[v], 1, hi);
            b.root[u] = b.merge(b.root[u], b.root[v], 1, hi);
        }
    }
    const int up = a.query(a.root[u], 1, hi, dep[u] + w[u]);
    const int down = b.query(b.root[u], 1, hi, dep[u] + n - w[u]);
    ans[u] += up + down;
}

int main() {
    scanf("%d %d", &n, &m);
    for(int i = 1; i < n; i++) {
        int u, v;
        scanf("%d %d", &u, &v);
        add(u, v); add(v, u);
    }
    dfs1(1, 0);

    for(int i = 1; i <= n; i++)
        scanf("%d", &w[i]);
    for(int i = 1; i <= m; i++) {
        int u, v;
        scanf("%d %d", &u, &v);
        int zx = lca(u, v);
        a.modify(a.root[u], 1, n<<1, dep[u], 1);
        b.modify(b.root[v], 1, n<<1, n-(dep[u]-dep[zx])+dep[zx], 1);
        a.modify(a.root[zx], 1, n<<1, dep[u], -1);
        b.modify(b.root[st[0][zx]], 1, n<<1, n-(dep[u]-dep[zx])+dep[zx], -1);
    }
    dfs2(1, 0);

    for(int i = 1; i <= n; i++)
        printf("%d ", ans[i]);
    return 0;
}

|