28pts求调

P3384 【模板】重链剖分/树链剖分

xiaozhuo @ 2023-09-19 12:40:43

#include<bits/stdc++.h>
using namespace std;
#define ll long long
// Global problem state. n = vertex count, root = tree root, m = number of
// operations, p = modulus (all read in main). Per-vertex arrays (1..n):
//   top[u]  - head of the heavy chain containing u (set in dfs2)
//   dep[u]  - depth of u below root; the root keeps depth 0 (set in dfs1)
//   size[u] - number of vertices in u's subtree (set in dfs1)
//   son[u]  - heavy child of u, 0 when u has no children (set in dfs1)
//   fa[u]   - parent of u, 0 for the root (set in dfs1)
//   id[u]   - position of u in the heavy-child-first DFS order (set in dfs2)
//   rev[i]  - vertex at DFS position i, the inverse of id (set in dfs2)
//   w[u]    - initial weight of vertex u (read in main)
// NOTE(review): with `using namespace std;` under C++17, an unqualified use of
// `size` is ambiguous with std::size; refer to this array as `::size`.
int n, root, m, p, top[100010], dep[100010], size[100010], son[100010], fa[100010], rev[100010], id[100010], w[100010];
// head/cnt: forward-star adjacency list (see add); tot: DFS timestamp counter used by dfs2.
int head[100010], cnt, tot;
// One node of the lazy segment tree built over DFS positions.
struct Tree
{
    ll tag;  // pending addition not yet pushed down to the children
    ll sum;  // sum of the values in this node's range, kept modulo p
};
Tree tr[400010];
// Edge record of the forward-star adjacency list.
struct Node
{
    int to;    // endpoint of this edge
    int next;  // index of the next edge leaving the same vertex; 0 ends the list
};
Node e[200010];
// Appends the directed edge u -> v to u's adjacency list (edges are 1-indexed,
// so index 0 can terminate a list).
void add(int u, int v)
{
    ++cnt;
    e[cnt].to = v;
    e[cnt].next = head[u];
    head[u] = cnt;
}
void dfs1(int u, int f)
{
    size[u] = 1;
    for(int i = head[u];i;i = e[i].next)
    {
        int y = e[i].to;
        if(y == f) continue;
        dep[y] = dep[u] + 1;
        fa[y] = u;
        dfs1(y, u);
        size[u] += size[y];
        if(size[y] > size[son[u]]) son[u] = y;
    }
}
// Second pass of heavy-light decomposition. It assigns DFS positions so that
// each heavy chain occupies consecutive ids, and records the chain head t in
// top[]. The heavy child is visited first and continues the current chain;
// every light child starts a new chain headed by itself.
void dfs2(int u, int t)
{
    top[u] = t;
    id[u] = ++tot;
    rev[id[u]] = u;
    if (son[u] == 0) return;  // leaf: nothing below
    dfs2(son[u], t);
    for (int edge = head[u]; edge != 0; edge = e[edge].next)
    {
        const int child = e[edge].to;
        if (child != fa[u] && child != son[u]) dfs2(child, child);
    }
}
void pushup(int rt)
{
    tr[rt].sum = (tr[rt * 2].sum + tr[rt * 2 + 1].sum % p) % p;
}
void pushdown(int rt, int len)
{
    if(tr[rt].tag)
    {
        tr[rt * 2].sum = (tr[rt * 2].sum + tr[rt].tag % p * (len - (len >> 1))) % p;
        tr[rt * 2 + 1].sum = (tr[rt * 2 + 1].sum + tr[rt].tag % p * (len >> 1)) % p;
        tr[rt * 2].tag += tr[rt].tag;
        tr[rt * 2 + 1].tag += tr[rt].tag;
        tr[rt].tag = 0;
    }
}
// Builds the segment tree node rt covering DFS positions [l, r].
// Leaves take the initial weight of the vertex at that position (mod p).
void build(int l, int r, int rt)
{
    if (l < r)
    {
        const int mid = (l + r) >> 1;
        build(l, mid, rt * 2);
        build(mid + 1, r, rt * 2 + 1);
        pushup(rt);
    }
    else
    {
        tr[rt].sum = w[rev[l]] % p;
    }
}
// Returns the sum over positions [L, R] modulo p. The node rt covers [l, r],
// and pending tags are pushed down along the way.
ll query(int L, int R, int l, int r, int rt)
{
    if (L <= l && r <= R) return tr[rt].sum % p;
    pushdown(rt, r - l + 1);
    const int mid = (l + r) >> 1;
    ll total = 0;
    if (L <= mid) total = (total + query(L, R, l, mid, rt * 2) % p) % p;
    if (mid < R) total = (total + query(L, R, mid + 1, r, rt * 2 + 1) % p) % p;
    return total % p;
}
void update(int L, int R, int l, int r, int rt, int k)
{
    if(L <= l && R >= r)
    {
        tr[rt].sum = (tr[rt].sum + k * (r - l + 1)) % p;
        tr[rt].tag += k;
        return;
    }
    int mid = (l + r) >> 1;
    if(R > mid) update(L, R, mid + 1, r, rt * 2 + 1, k);
    if(L <= mid) update(L, R, l, mid, rt * 2, k);
    pushup(rt);
}

// Adds k (reduced mod p) to every vertex on the tree path x..y.
// It repeatedly lifts the endpoint whose chain head is deeper and updates that
// chain segment [id[top], id[x]], until both ends share a chain.
void changel(int x, int y, int k)
{
    k %= p;
    for (; top[x] != top[y]; x = fa[top[x]])
    {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        update(id[top[x]], id[x], 1, n, 1, k);
    }
    const bool xDeeper = dep[x] > dep[y];
    const int upper = xDeeper ? y : x;
    const int lower = xDeeper ? x : y;
    update(id[upper], id[lower], 1, n, 1, k);
}
// Returns the sum of vertex values on the tree path x..y, modulo p.
// It uses the same chain-climbing scheme as changel, but queries instead of updating.
ll ql(int x, int y)
{
    ll acc = 0;
    for (; top[x] != top[y]; x = fa[top[x]])
    {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        acc = (acc + query(id[top[x]], id[x], 1, n, 1)) % p;
    }
    const bool xDeeper = dep[x] > dep[y];
    const int upper = xDeeper ? y : x;
    const int lower = xDeeper ? x : y;
    acc = (acc + query(id[upper], id[lower], 1, n, 1)) % p;
    return acc;
}
void changes(int x, int k)
{
    k %= p;
    update(id[x], id[x] + size[x] - 1, 1, n, 1, k);
}
// Returns the sum of vertex values in the subtree of x, modulo p.
// The subtree occupies the contiguous DFS positions [id[x], id[x] + size[x] - 1].
// `::size` avoids the C++17 ambiguity with std::size under `using namespace std;`.
ll qs(int x)
{
    return query(id[x], id[x] + ::size[x] - 1, 1, n, 1);
}
int main()
{
    cin >> n >> m >> root >> p;
    for(int i = 1;i <= n;i ++) cin >> w[i];
    for(int i = 1;i < n;i ++)
    {
        int u, v;
        cin >> u >> v;
        add(u, v), add(v, u);
    }
    dfs1(root, 0);
    dfs2(root, root);
    build(1, n, 1);
    while(m --)
    {
        int op, x, y, z;
        cin >> op;
        if(op == 1)
        {
            cin >> x >> y >> z;
            changel(x, y, z);
        }
        if(op == 2)
        {
            cin >> x >> y;
            cout << ql(x, y) << endl;
        }
        if(op == 3)
        {
            cin >> x >> z;
            changes(x, z);
        }
        if(op == 4)
        {
            cin >> x;
            cout << qs(x) << endl;
        }
    }
    return 0;
}

by xiaozhuo @ 2023-09-20 14:54:55

呜呜呜,update忘写pushdown了,我是蠢驴。害我又重写一遍,然后又调了半天才发现,太久没写线段树了受不了了,此贴结


by liuye20100123 @ 2024-08-25 23:14:49

@xiaozhuo 太感谢啦,我也错这儿了,调了半天,呜呜


|