求助

P3384 【模板】重链剖分/树链剖分

liuxy1234 @ 2023-05-02 12:32:06

#include <bits/stdc++.h>
#define int long long
using namespace std;

// Reads one decimal integer from stdin using getchar().
// Any non-digit characters before the number are skipped; a '-' seen while
// skipping makes the result negative. Returns 0 if end of input is reached
// before a digit. (main() below reads with cin instead of this helper.)
inline int read()
{
    int x = 0, f = 1;
    // Keep getchar()'s result as int so EOF stays distinguishable from a real
    // character; narrowing it to char made the skip loop spin forever at EOF.
    // `register` is dropped: it is ill-formed since C++17.
    int c = getchar();
    while(c != EOF && (c < '0' || c > '9'))
    {
        if(c == '-')f = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9')
    {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * f;
}

// One segment-tree node covering positions [l, r] of the HLD order.
//   sum  - segment total (reduced mod `mod`), NOT counting this node's own
//          pending `lazy`; an exact-match query adds lazy * length on top.
//   lazy - per-position addition recorded here but not yet pushed to children.
struct segment_tree
{
    int l;
    int r;
    int sum;
    int lazy;
} s[500010];

// Per-vertex heavy-light decomposition data, indexed by vertex id.
struct node
{
    int fa;    // parent vertex
    int top;   // topmost vertex of the heavy chain this vertex lies on
    int pos;   // new index after renumbering so that every subtree and every
               // heavy chain occupies consecutive positions
    int wson;  // heavy child (child with the largest subtree), 0 if none
    int size;  // number of vertices in this subtree
    int dep;   // depth
    int val;   // vertex weight
} t[100010];

int n;    // number of vertices
int m;    // number of operations
int r;    // root vertex
int mod;  // modulus applied to every sum

// Adjacency lists stored as singly linked lists inside one edge array.
// h[u] is the index of the most recently added edge leaving u (0 = none),
// and e[i].nxt links to the edge added before it from the same vertex.
struct edge
{
    int v;    // target vertex
    int nxt;  // index of the next edge in the same list, 0 terminates
} e[200010];

int h[100010]; // head edge index per vertex
int cnt;       // number of edges added so far; edges are numbered from 1

// Adds the directed edge u -> v at the front of u's list.
void addedge(int u, int v)
{
    const int id = ++cnt;
    e[id] = edge{v, h[u]};
    h[u] = id;
}

// Vertex weights laid out in HLD order: dfs2 stores t[x].val at p[t[x].pos].
int p[100010] = {};

// Moves node id's pending addition into both children's lazies and folds it
// into id's own sum (a node's sum never includes its own lazy), then clears
// id's lazy. Must only be called on internal nodes.
void pushdown(int id)
{
    auto &cur = s[id];
    const int pending = cur.lazy;
    for(int child = id * 2; child <= id * 2 + 1; child++)
    {
        s[child].lazy = (s[child].lazy + pending) % mod;
    }
    cur.sum = (cur.sum + pending * (cur.r - cur.l + 1)) % mod;
    cur.lazy = 0;
}

// Recomputes node id's sum from its two children.
// A child's full contribution is its stored sum plus its own pending lazy
// times its length, because a node's sum excludes its own lazy.
// (Previously the left child's stored sum was left out entirely.)
void update(int id)
{
    const auto &lc = s[id * 2];
    const auto &rc = s[id * 2 + 1];
    s[id].sum = (lc.sum + lc.lazy * (lc.r - lc.l + 1)
               + rc.sum + rc.lazy * (rc.r - rc.l + 1)) % mod;
}

// Builds node q over positions [l, r] from p[], clearing every lazy.
// Leaves take p[l] reduced mod `mod`; internal nodes are summed by update().
void buildtree(int q, int l, int r)
{
    s[q].lazy = 0;
    s[q].l = l;
    s[q].r = r;
    if(l != r)
    {
        const int mid = (l + r) / 2;
        buildtree(q * 2, l, mid);
        buildtree(q * 2 + 1, mid + 1, r);
        update(q);
        return;
    }
    s[q].sum = p[l] % mod;
}

// Adds c to every position in [l, r], which must lie inside node q's segment
// (s[q].l <= l <= r <= s[q].r).
// If q covers exactly [l, r], only q's lazy is bumped. Otherwise the pending
// lazy is pushed down, the overlapping children are updated, and q's sum is
// recomputed.
void add(int q, int l, int r, int c)
{
    // Reduce the delta into [0, mod): C++ % keeps the sign of a negative c,
    // which would otherwise leave negative values in lazy/sum.
    c %= mod;
    if(c < 0) c += mod;
    if(s[q].l == l && s[q].r == r)
    {
        s[q].lazy = (s[q].lazy + c) % mod;
        return;
    }
    pushdown(q);
    const int mid = s[q * 2].r;  // left child covers [s[q].l, mid], right child [mid + 1, s[q].r]
    if(l <= mid) add(q * 2, l, min(mid, r), c);
    if(r > mid) add(q * 2 + 1, max(mid + 1, l), r, c);
    update(q);
}

int query(int q, int l, int r)
{
    if(l == s[q].l && r == s[q].r)
    {
        return (s[q].sum + s[q].lazy * (s[q].r - s[q].l + 1)) % mod;
    }
    pushdown(q);
    int ans = 0;
    if(l >= s[q * 2].l && l <= s[q * 2].r)ans += query(l, min(s[q * 2].r, r), q * 2);
    ans %= mod;
    if(r <= s[q * 2 + 1].r && r >= s[q * 2 + 1].l)ans += query(max(s[q * 2 + 1].l, l), r, q * 2 + 1);
    update(q);
    return ans % mod;
} 

// First HLD pass. For vertex x (parent fa, depth dep) and its whole subtree,
// records parent, depth, subtree size and heavy child (the first child with
// the largest subtree; 0 for a leaf).
void dfs1(int x, int fa, int dep)
{
    t[x].fa = fa;
    t[x].dep = dep;
    t[x].size = 1;
    int heavy = 0;
    int heavy_size = -1;
    for(int i = h[x]; i != 0; i = e[i].nxt)
    {
        const int child = e[i].v;
        if(child == fa)
        {
            continue;
        }
        dfs1(child, x, dep + 1);
        t[x].size += t[child].size;
        if(t[child].size > heavy_size)
        {
            heavy_size = t[child].size;
            heavy = child;
        }
    }
    t[x].wson = heavy;
}

// Last position handed out by dfs2 (it pre-increments, so positions start at 1).
int pos = 0;

// Second HLD pass. Assigns each vertex its chain top and position, and copies
// its weight into p[] at that position. The heavy child is visited first, so
// each heavy chain gets consecutive positions; as a preorder walk, each
// subtree does too.
void dfs2(int x, int fa)
{
    t[x].top = (t[fa].wson == x) ? t[fa].top : x;
    t[x].pos = ++pos;
    p[t[x].pos] = t[x].val;
    const int heavy = t[x].wson;
    if(heavy)
    {
        dfs2(heavy, x);
    }
    for(int i = h[x]; i; i = e[i].nxt)
    {
        const int child = e[i].v;
        if(child != fa && child != heavy)
        {
            dfs2(child, x);
        }
    }
}

// Adds c to every vertex weight on the tree path between x and y.
// The endpoint whose chain top is deeper repeatedly has its chain segment
// updated and jumps above that chain, until both endpoints share one chain.
void addroute(int x, int y, int c)
{
    for(; t[x].top != t[y].top; x = t[t[x].top].fa)
    {
        if(t[t[x].top].dep < t[t[y].top].dep) swap(x, y);
        add(1, t[t[x].top].pos, t[x].pos, c);
    }
    // Same chain now: the shallower endpoint has the smaller position.
    if(t[x].dep > t[y].dep) swap(x, y);
    add(1, t[x].pos, t[y].pos, c);
}

// Returns the sum (mod `mod`) of vertex weights on the tree path between x and y,
// climbing heavy chains the same way addroute does.
int queryroute(int x, int y)
{
    int total = 0;
    for(; t[x].top != t[y].top; x = t[t[x].top].fa)
    {
        if(t[t[x].top].dep < t[t[y].top].dep) swap(x, y);
        total = (total + query(1, t[t[x].top].pos, t[x].pos)) % mod;
    }
    // Same chain now: the shallower endpoint has the smaller position.
    if(t[x].dep > t[y].dep) swap(x, y);
    return (total + query(1, t[x].pos, t[y].pos)) % mod;
}

// Adds c to every vertex weight in x's subtree, which occupies the
// consecutive positions [pos, pos + size - 1].
void addtree(int x, int c)
{
    const int first = t[x].pos;
    const int last = first + t[x].size - 1;
    add(1, first, last, c);
}

// Returns the sum (mod `mod`) of vertex weights in x's subtree, i.e. of
// positions [pos, pos + size - 1].
int querytree(int x)
{
    const auto &v = t[x];
    return query(1, v.pos, v.pos + v.size - 1) % mod;
}

// Input: n m r mod, then n vertex weights, then n - 1 undirected edges, then
// m operations:
//   1 x y z : add z to every vertex on the path x -- y
//   2 x y   : print the path sum x -- y (mod `mod`)
//   3 x z   : add z to every vertex in x's subtree
//   4 x     : print the subtree sum of x (mod `mod`)
signed main()
{
    // Only iostreams are used for I/O here, so unsyncing from C stdio is safe.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    cin >> n >> m >> r >> mod;
    for(int i = 1;i <= n;i++)
    {
        cin >> t[i].val;
    }
    for(int i = 1;i < n;i++)
    {
        int u, v;
        cin >> u >> v;
        addedge(u, v);
        addedge(v, u);
    }
    // Decompose first: dfs2 fills p[] in HLD order, and the segment tree has
    // to be built from that order (building earlier read an all-zero p[]).
    // 0 is the "no parent" sentinel, matching the dfs2 call.
    dfs1(r, 0, 1);
    dfs2(r, 0);
    buildtree(1, 1, n);
    while(m--)
    {
        int op, x, y, z;
        cin >> op >> x;
        if(op == 1)
        {
            cin >> y >> z;
            addroute(x, y, z);
        }
        else if(op == 2)
        {
            cin >> y;
            cout << queryroute(x, y) << "\n";
        }
        else if(op == 3)
        {
            cin >> z;
            addtree(x, z);
        }
        else if(op == 4)
        {
            cout << querytree(x) << "\n";
        }
    }
    return 0;
}

|