树剖模板：测试点 #8、#9、#10 TLE，其余 AC

P3384 【模板】重链剖分/树链剖分

AK_heaven @ 2024-04-07 22:20:12

悬关,求调

#include <bits/stdc++.h>
// Promote every `int` in the rest of the file to `long long` (presumably to give
// intermediate sums/products headroom before they are reduced mod pi).
#define int long long
// rep(a, b, c): ascending loop; c runs from a up to b inclusive.
#define rep(a, b, c) for(int c = a; c <= b; c++)
// rer(a, b, c): descending loop; c runs from a down to b inclusive.
#define rer(a, b, c) for(int c = a; c >= b; c--)
// rG(u): walk u's adjacency list. Binds i (edge index; 0 ends the list) and
// v (the edge's endpoint). Relies on the global arrays h and e.
#define rG(u) for(int i = h[u], v = e[i].v; i; i = e[i].last, v = e[i].v)
// Left / right child of segment-tree node p (heap numbering: 2p and 2p+1).
#define ls (p<<1)
#define rs (p<<1|1)

using namespace std;

// Capacity of every vertex-indexed array (1e6 + 10).
const int maxn = 1000010;

// Heavy-light decomposition data.
int top[maxn];     // head of the heavy chain containing each vertex
int dfn[maxn];     // DFS-order position of each vertex
int rnk[maxn];     // inverse of dfn: the vertex at a given DFS position
int n, m, pi, rt;  // vertex count, operation count, modulus, root
int sz[maxn];      // subtree size
int w[maxn];       // initial vertex weights

// Adjacency list and rooted-tree shape.
int h[maxn];       // first edge index of each vertex's list (0 = empty)
int tt;            // number of edges stored so far
int dep[maxn];     // depth (root = 1; 0 means not yet visited)
int cnt;           // DFS timestamp counter
int hson[maxn];    // heavy child (-1 when there are no children)
int fa[maxn];      // parent

// One directed edge in the static edge pool `e`; lists are threaded via `last`.
struct Edge {
    int v;     // endpoint the edge points to
    int last;  // index of the previous edge in the same list (0 terminates)
} e[maxn << 1];

// Prepend a directed edge u -> v to u's adjacency list (edge ids start at 1).
void add(int u, int v) {
    int id = ++tt;
    e[id].v = v;
    e[id].last = h[u];
    h[u] = id;
}

// Segment-tree node covering DFS positions [l, r]: `sm` is the range sum and
// `tag` is a pending lazy add not yet pushed to the children.
struct TT {
    int l;
    int r;
    int sm;
    int tag;
};

struct Seg {
    TT tr[maxn];
    void pushdown(int p) {
        if(tr[p].tag) {
            tr[ls].sm += tr[p].tag * (tr[ls].r - tr[ls].l + 1) % pi;
            tr[ls].sm %= pi;
            tr[rs].sm += tr[p].tag * (tr[rs].r - tr[rs].l + 1) % pi;
            tr[rs].sm %= pi;
            tr[ls].tag += tr[p].tag;
            tr[rs].tag += tr[p].tag;
            tr[p].tag = 0;
        }
    }

    void pushup(int p) {
        tr[p].sm = tr[ls].sm + tr[rs].sm;
        tr[p].sm %= pi;
    }

    void build(int p, int l, int r) {
        tr[p] = {l, r, w[rnk[l]], 0};
        if(l == r) return ;
        int mid = (l + r) >> 1;
        build(ls, l, mid); build(rs, mid+1, r);
        pushup(p);
    }

    void update(int p, int x, int y, int k) {
        if(x <= tr[p].l && tr[p].r <= y) {
            tr[p].sm += k * (tr[p].r - tr[p].l + 1) % pi;
            tr[p].sm %= pi;
            tr[p].tag += k;
            return ;
        }
        pushdown(p);
        int mid = (tr[p].l + tr[p].r) >> 1;
        if(x <= mid) update(ls, x, y, k);
        if(y > mid) update(rs, x, y, k);
        pushup(p);
    }

    int query(int p, int x, int y) {
        if(x <= tr[p].l && tr[p].r <= y)
          return tr[p].sm;
        int mid = (tr[p].l + tr[p].r) >> 1, ret = 0;
        pushdown(p);
        if(x <= mid) ret = (ret + query(ls, x, y)) % pi;
        if(y > mid) ret = (ret + query(rs, x, y)) % pi;
        return ret;
    }
}st;

// Sum of vertex weights on the tree path x -> y, reduced mod pi
// (consistent with querysum2, which also returns a reduced value).
int querysum1(int x, int y) {
    int ret = 0, fx = top[x], fy = top[y];
    while(fx != fy) {
        // Consume the chain segment of whichever endpoint has the deeper chain head.
        if(dep[fx] >= dep[fy])
          ret += st.query(1, dfn[fx], dfn[x]), x = fa[fx];
        else
          ret += st.query(1, dfn[fy], dfn[y]), y = fa[fy];
        ret %= pi;
        fx = top[x];
        fy = top[y];
    }
    // Same chain now: the shallower endpoint has the smaller DFS position.
    if(dep[x] < dep[y])
      ret += st.query(1, dfn[x], dfn[y]);
    else
      ret += st.query(1, dfn[y], dfn[x]);
    return ret % pi;
}

// First DFS pass: fills sz, dep, fa and picks each vertex's heavy child
// (the child with the largest subtree; -1 when u has no children).
// Requires dep[root] to be set non-zero beforehand, since dep == 0 marks "unvisited".
void dfs1(int u) {
    sz[u] = 1;
    hson[u] = -1;
    rG(u) {
        if(!dep[v]) {
            dep[v] = dep[u] + 1;
            fa[v] = u;
            dfs1(v);
            sz[u] += sz[v];
            // Compare subtree SIZES. The original compared sz[v] with the vertex id
            // hson[u], producing wrong heavy children, long chain walks, and TLE.
            if(hson[u] == -1 || sz[v] > sz[hson[u]])
              hson[u] = v;
        }
    }
}

void dfs2(int u, int t) {
    top[u] = t;
    cnt++;
    dfn[u] = cnt;
    rnk[cnt] = u;
    if(hson[u] == -1)
      return ;
    dfs2(hson[u], t);
    rG(u) {
        if(v != hson[u] && v != fa[u])
          dfs2(v, v);
    }
}

void modify1(int x, int y, int k) {
    int fx = top[x], fy = top[y];
    while(fx != fy) {
        if(dep[fx] >= dep[fy])
          st.update(1, dfn[fx], dfn[x], k), x = fa[fx];
        else 
          st.update(1, dfn[fy], dfn[y], k), y = fa[fy];
        fx = top[x];
        fy = top[y];
    }
    if(dep[x] < dep[y])
      st.update(1, dfn[x], dfn[y], k);
    else 
      st.update(1, dfn[y], dfn[x], k);
}

// Add k to every vertex in x's subtree, which is the contiguous DFS range
// starting at dfn[x] with length sz[x].
void modify2(int x, int k) {
    int lo = dfn[x];
    int hi = lo + sz[x] - 1;
    st.update(1, lo, hi, k);
}

int querysum2(int x) {
    return st.query(1, dfn[x], dfn[x] + sz[x] - 1) % pi;
}

void buildT() {
    dep[rt] = 1;
    dfs1(rt);
    dfs2(rt, rt);
    st.build(1, 1, n);
}

signed main() {
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);
    cin >> n >> m >> rt >> pi;
    rep(1, n, i)
      cin >> w[i];
    rep(1, n-1, i) {
        int u, v; cin >> u >> v;
        add(u, v); add(v, u);
    }
    buildT();
    rep(1, m, i) {
        int op, x, y, z;
        cin >> op >> x;
        if(op == 1) {
            cin >> y >> z;
            modify1(x, y, z);
        }
        if(op == 2) {
            cin >> y;
            cout << querysum1(x, y) % pi << '\n';
        }
        if(op == 3) {
            cin >> y;
            modify2(x, y);
        }
        if(op == 4) {
            cout << querysum2(x) % pi<< '\n';
        }
    }
    return 0;
}

by Nagasaki_Soyo @ 2024-04-07 22:23:04

@AK_heaven dfs1 里选重儿子时把 sz[v] 和结点编号 hson[u] 比较了，应比较子树大小：if(hson[u] == -1 || sz[v] >= sz[hson[u]])


by AK_heaven @ 2024-04-10 21:29:34

@Nagasaki_Soyo 已关


|