大佬求助,37pts过了#1, #2, #3, #11,真不知道哪里写挂了555

P3384 【模板】重链剖分/树链剖分

xinhuo2005 @ 2024-07-13 15:49:54

#include<bits/stdc++.h>
#define ll long long
using namespace std;

const int maxn = 2e5 + 10;
vector<int> e[maxn];
int a[maxn], w[maxn];
ll seg[maxn << 2], tag[maxn << 2];
int fa[maxn], siz[maxn], son[maxn], dep[maxn];
int dfn[maxn], top[maxn], rk[maxn];
int mod;

inline int read()
{
    int x = 0, f = 1;
    char c = getchar();
    while (c < '0' || c > '9')
    {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9')
    {
        x = (x << 1) + (x << 3) + (c ^ '0');
        c = getchar();
    }
    return x * f;
}

void add(int u, int  v)
{
    e[u].emplace_back(v);
    e[v].emplace_back(u);
}

inline int ls(int p) {return p << 1;}
inline int rs(int p) {return p << 1 | 1;}

void push_up(int p)
{
    seg[p] = (seg[ls(p)] + seg[rs(p)]) % mod;
}

void push_down(int l, int r, int p)
{
    int mid = (l + r) >> 1;

    seg[ls(p)] = (seg[ls(p)] + tag[p] * (mid - l + 1ll)) % mod;
    seg[rs(p)] = (seg[rs(p)] + tag[p] * (r - mid)) % mod;

    tag[ls(p)] = (tag[ls(p)] + tag[p]) % mod;
    tag[rs(p)] = (tag[rs(p)] + tag[p]) % mod;
    tag[p] = 0;
}

void bulid(int l, int r, int p)
{
    if (l == r) 
    {
        seg[p] = a[l] % mod; return;    
    }
    int mid = (l + r) >> 1;
    bulid(l, mid, ls(p)); bulid(mid + 1, r, rs(p));
    push_up(p);
}

void update_add(int s, int t, int l, int r, int x, int p)
{
    x %= mod;
    if (s <= l && r <= t)
    {
        seg[p] = ((r - l + 1) * x + seg[p]) % mod;
        tag[p] = (tag[p] + x) % mod;
        return ;
    }

    push_down(l, r, p);
    int mid = (l + r) >> 1;
    if (s <= mid) update_add(s, t, l, mid, x, ls(p));
    if (mid < t) update_add(s, t, mid + 1, r, x, rs(p));
    push_up(p);
}

ll getSum(int s, int t, int l, int r, int p)
{
    ll res = 0ll;
    if (s <= l && r <= t) return seg[p] % mod;

    push_down(l, r, p);
    int mid = (l + r) >> 1;

    if (s <= mid) res = (res + getSum(s, t, l, mid, ls(p))) % mod;
    if (mid < t) res = (res + getSum(s, t, mid + 1, r, rs(p))) % mod;
    return res;
}

void dfs1(int u, int fno)
{
    siz[u] = 1;
    fa[u] = fno;
    son[u] = 0;
    int res = 1;
    for (auto v : e[u])
    {
        if (v == fno) continue;
        dep[v] = dep[u] + 1;
        dfs1(v, u); 
        if (son[u] == 0 || res < siz[v])
        {
            son[u] = v;
            res = siz[v];
        }
        siz[u] += siz[v];
    }
}

int n, m, root, cnt;

void dfs2(int u, int t)
{
    top[u] = t;
    cnt ++;
    dfn[u] = cnt;
    rk[cnt] = u;
    a[cnt] = w[u];

    if (son[u] == 0) return ;
    dfs2(son[u], t);
    for (auto v : e[u])
    {
        if (v == fa[u] || v == son[u]) continue;
        dfs2(v, v);
    }
}

void opt1(int x, int y, int z)
{
    while (top[x] != top[y])
    {
        if (dep[x] < dep[y]) swap(x, y);
        update_add(dfn[top[x]], dfn[x], 1, n, z, 1);
        x = fa[top[x]];
    }
    if (dep[x] < dep[y]) swap(x, y);
    update_add(dfn[y], dfn[x], 1, n, z, 1);
}

ll opt2(int x, int y)
{
    ll res = 0ll;
    while (top[x] != top[y])
    {
        if (dep[x] < dep[y]) swap(x, y);
        res = (res + getSum(dfn[top[x]], dfn[x], 1, n, 1)) % mod;
        x = fa[top[x]];
    }
    if (dep[x] < dep[y]) swap(x, y);
    res = (res + getSum(dfn[y], dfn[x], 1, n, 1)) % mod;
    return res;
}

void opt3(int x, int z)
{
    update_add(dfn[x], dfn[x] + siz[x] - 1, 1, n, z, 1);
}

ll opt4(int x)
{
    return getSum(dfn[x], dfn[x] + siz[x] - 1, 1, n, 1);
}

int main()
{
    n = read(), m = read(), root = read(), mod = read();
    for (int i = 1; i <= n; i++) w[i] = read();
    int u, v;
    for (int i = 1; i < n; i++)
    {
        u = read(), v = read();
        add(u, v);
    }

    dfs1(root, root);
    dfs2(root, root);

    bulid(1, n, 1);
    int opt, x, y, z;
    while (m--)
    {
        opt = read();
        if (opt == 1)
        {
            x = read(), y = read(), z = read();
            opt1(x, y, z);
        }
        else if (opt == 2)
        {
            x = read(), y = read();
            printf("%lld\n", opt2(x, y));
        }
        else if (opt == 3)
        {
            x = read(), z = read();
            opt3(x, z);
        }
        else 
        {
            x = read();
            printf("%lld\n", opt4(x));
        }
    }
    return 0;
}

by wby_1234 @ 2024-07-13 16:22:43

将“ seg[p] = (seg[ls(p)] + seg[rs(p)]) % mod;” 改为“ seg[p] = (seg[ls(p)] + seg[rs(p)]) % mod+1;”即可。


by xinhuo2005 @ 2024-07-13 16:41:56

@wby_1234 大佬不对,改了还是没过


by wby_1234 @ 2024-07-13 17:50:15

我过了


by LWT223355 @ 2024-08-01 09:41:30

@xinhuo2005 教练~ stO Orz


|