求助 19pts 只过了1和11

P3384 【模板】重链剖分/树链剖分

YZYxx694 @ 2023-05-19 13:29:47

#include<bits/stdc++.h>

using namespace std;

using i64 = long long;

const int N = 1e5 + 10;

vector<int> e[N];

int n, tot, m, R, P;
int l[N], r[N], idx[N], a[N];
int c[N], son[N], sz[N], top[N];
int dep[N], f[N];

struct node{
    int val;
    int tag;
}seg[N * 4];

void update(int id)
{
    seg[id].val = (seg[id * 2].val + seg[id * 2 + 1].val) % P;
}

void build(int id, int l, int r)
{
    if(l == r)
    {
        seg[id].val = a[idx[l]] % P;
        return ;
    }
    int mid = (l + r) / 2;
    build(id * 2, l, mid);
    build(id * 2 + 1, mid + 1, r);
    update(id);
}

void pushdown(int id, int l, int r)
{
    int mid = (l + r) / 2;
    seg[id * 2].tag = (seg[id * 2].tag + seg[id].tag) % P;
    seg[id * 2].val = (seg[id * 2].val + seg[id].tag * (mid - l + 1) % P) % P;
    seg[id * 2 + 1].tag = (seg[id * 2 + 1].tag + seg[id].tag) % P;
    seg[id * 2 + 1].val = (seg[id * 2 + 1].val
     + seg[id].tag * (r - mid) % P) % P;
    seg[id].tag = 0;
}

void modify(int id, int l, int r, int ql, int qr, int k)
{
    if(ql == l && qr == r)
    {
        seg[id].tag = (seg[id].tag + k) % P;
        seg[id].val = (seg[id].val + k * (r - l + 1) % P) % P;
        return ;
    }
    int mid = (l + r) / 2;
    pushdown(id, l, r);
    if(qr <= mid)
    {
        modify(id * 2, l, mid, ql, qr, k);
    }
    else if(ql > mid)
        modify(id * 2 + 1, mid + 1, r, ql, qr, k);
    else
    {
        modify(id * 2, l, mid, ql, mid, k);
        modify(id * 2 + 1, mid + 1, r, mid + 1, qr, k);
    }
    update(id);
}

i64 query(int id, int l, int r, int ql, int qr)
{
    if(l == ql && r == qr)
    {
        return seg[id].val % P;
    }
    pushdown(id, l, r);
    int mid = (l + r) / 2;
    if(qr <= mid)
    {
        return query(id * 2, l, mid, ql, qr);
    }
    else if(ql > mid)
    {
        return query(id * 2 + 1, mid + 1, r, ql, qr);
    }
    else{
        return (query(id * 2, l, mid, ql, mid) + 
                query(id * 2, mid + 1, r, mid + 1, qr)) % P;
    }
}

//dfs1求子树大小,重儿子,父亲,深度
void dfs1(int u, int fa)
{
    dep[u] = dep[fa] + 1;
    sz[u] = 1;
    son[u] = -1;
    for(auto v : e[u])
    {
        if(v == fa)
            continue;
        f[v] = u;
        dfs1(v, u);
        sz[u] += sz[v];
        if(son[u] == -1 || sz[v] > sz[son[u]])
            son[u] = v;
    }
}

//dfs2,求dfs序,重链上链头的元素
void dfs2(int u, int topf)
{
    top[u] = topf;
    l[u] = ++tot;
    idx[tot] = u;
    if(son[u] != -1)
    {
        dfs2(son[u], topf);
    }
    for(auto v : e[u])
    {
        if(v == f[u] || v == son[u])
            continue;
        dfs2(v, v);
    }
    r[u] = tot;
}

signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    cin >> n >> m >> R >> P;
    for(int i = 1; i <= n; i++)
    {
        cin >> a[i];
    }
    for(int i = 1; i < n; i++)
    {
        int u, v;
        cin >> u >> v;
        e[u].push_back(v);
        e[v].push_back(u);
    }
    dfs1(R, 0);
    dfs2(R, R);
    build(1, 1, n);
    for(int i = 1; i <= m; i++)
    {
        int ty;
        cin >> ty;
        if(ty == 1)
        {
            int x, y, z;
            cin >> x >> y >> z;
            while(top[x] != top[y])
            {
                if(dep[top[x]] > dep[top[y]])
                {
                    modify(1, 1, n, l[top[x]], l[x], z);
                    x = f[top[x]];  
                }
                else{
                    modify(1, 1, n, l[top[y]], l[y], z);
                    y = f[top[y]];
                }
            }
            if(dep[x] < dep[y])
            {
                modify(1, 1, n, l[x], l[y], z);
            }
            else{
                modify(1, 1, n, l[y], l[x], z);
            }
        }
        else if(ty == 2)
        {
            int x, y;
            cin >> x >> y;
            i64 ans = 0;
            while(top[x] != top[y])
            {
                if(dep[top[x]] > dep[top[y]])
                {
                    ans = (ans + query(1, 1, n, l[top[x]], l[x])) % P;
                    x = f[top[x]];
                }
                else{
                    ans = (ans + query(1, 1, n, l[top[y]], l[y])) % P;
                    y = f[top[y]];
                }
            }
            if(dep[x] < dep[y])
            {
                ans = (ans + query(1, 1, n, l[x], l[y])) % P;
            }
            else{
                ans = (ans + query(1, 1, n, l[y], l[x])) % P;
            }
            cout << ans << endl;
        }
        else if(ty == 3)
        {
            int x, z;
            cin >> x >> z;
            modify(1, 1, n, l[x], r[x], z);
        }
        else{
            int x;
            cin >> x;
            cout << query(1, 1, n, l[x], r[x]) << endl;
        }
    }

    return 0;
}

|