树剖求调玄关

P3384 【模板】重链剖分/树链剖分

Kevinx @ 2023-12-31 21:07:26

样例都没过(悲

#include<bits/stdc++.h>
#define ll long long
#define ls(p) p<<1
#define rs(p) p<<1|1
#define Mid ll mid = (l + r) >> 1
using namespace std;
const int N = 1e5 + 20;
ll n, m, r, p, a[N], b[N];
ll cnt = 1, num = 0, h[N<<1], fa[N<<1],dep[N<<1], sz[N<<1], son[N<<1], top[N<<1], id[N<<1];
struct node{
    ll v, nxt;
}e[N<<1];

struct TR{
    ll l, r, sum, tag;
}tree[N<<2];

void add_edge(ll u, ll v) {
    e[++cnt] = (node){v, h[u]};
    h[u] = cnt;
}
void dfs1(ll u, ll f, ll deep) {
    dep[u] = deep;
    fa[u] = f;
    sz[u] = 1;
    ll maxsz = -1;
    for(int i = h[u]; i; i = e[i].nxt) {
        ll v = e[i].v;
        if(v == f) continue;
        dfs1(v, u, deep + 1);
        sz[u] += sz[v];
        if(sz[v] > maxsz) maxsz = sz[v], son[u] = v;
    }
}
void dfs2(ll u, ll ttop) {
    id[u] = ++ num;
    a[num] = b[u];
    top[u] = ttop;
    if(!son[u]) return;
    dfs2(son[u], ttop);
    for(int i = h[u]; i; i = e[i].nxt) {
        ll v = e[i].v;
        if(son[u] == v || v == fa[u]) continue;
        dfs2(v, v);
    }
}
void build(ll p, ll l, ll r) {
    tree[p].l = l, tree[p].r = r;
    if(l == r) {
        tree[p].sum = a[l] % p;
        return ;
    }
    Mid;
    build(ls(p), l, mid);
    build(rs(p), mid + 1, r);
    tree[p].sum = (tree[ls(p)].sum + tree[rs(p)].sum) % p;
}
void up(ll p) {
    tree[p].sum = (tree[ls(p)].sum + tree[rs(p)].sum) % p;
}
void down(ll p, ll l, ll r) {
    ll k = tree[p].tag % p;
    tree[ls(p)].sum = (k * (tree[ls(p)].r - tree[ls(p)].l + 1) + tree[ls(p)].sum) % p;
    tree[rs(p)].sum = (k * (tree[rs(p)].r - tree[rs(p)].l + 1) + tree[rs(p)].sum) % p;
    tree[ls(p)].tag = (k + tree[ls(p)].tag) % p;
    tree[rs(p)].tag = (k + tree[rs(p)].tag) % p;
    tree[p].tag = 0;
}

void add(ll p, ll l, ll r, ll x, ll y, ll k) {
    if(x <= l && r <= y) {
//  cout << p << endl;;
        tree[p].tag = (tree[p].tag + k) % p;
        tree[p].sum = (tree[p].sum + k * (l - r + 1)) % p;
        return ;
    }

    down(p, l, r);
    Mid;
    if(x <= mid) add(ls(p), l, mid, x, y, k);
    if(y >  mid) add(rs(p), mid + 1, r, x, y, k);
    up(p);
    return ;
}
ll query(ll p, ll l, ll r, ll x, ll y) {
    if(x <= l && r <= y) {
        return tree[p].sum % p;
    }
//  cout << p; 
    down(p, l, r);
    Mid;
    ll res = 0;
    if(x <= mid) res += query(ls(p), l, mid, x, y);
    if(y >  mid) res += query(rs(p), mid + 1, r, x, y);
    up(p);
    return res % p;
}

void add_path(ll x, ll y, ll k) {
    k %= p;
    while(top[x] != top[y]) {
        if(dep[top[x]] < dep[top[y]]) swap(x, y);
        add(1, 1, n, id[top[x]], id[x], k);
        x = fa[top[x]];
    }
    if(dep[x] > dep[y]) swap(x, y);
    add(1, 1, n, id[x], id[y], k);
}

ll query_path(ll x, ll y) {
    ll res = 0;
    while(top[x] != top[y]) {
        if(dep[top[x]] < dep[top[y]]) swap(x, y);
        res = (res + query(1, 1, n, id[top[x]], id[x])) % p;
        x = fa[top[x]];
    }
    if(dep[x] > dep[y]) swap(x, y);
    res = (res + query(1, 1, n, id[x], id[y])) % p;
    return res;
}

int main (){
    scanf("%lld%lld%lld%lld", &n, &m, &r, &p);
//  cout << 111;
    for(int i = 1; i <= n; i++) {
//      cout << i;
        scanf("%lld", &b[i]);
    }
    for(int i = 1; i < n; i++) {
        ll u, v;
        scanf("%lld%lld", &u, &v); 
        add_edge(u, v);
        add_edge(v, u);
    }
    dfs1(r, 0, 1);
    dfs2(r, r);
    build(1, 1, n);
//  cout << tree[1].sum << endl;
    for(int i  = 1; i <= m; i++){
        ll op, x, y, k;
//      cout << 1;
        scanf("%lld", &op);
        if(op == 1) {
            scanf("%lld%lld%lld", &x, &y, &k);

            add_path(x, y, k);
        }
        if(op == 2) {
            scanf("%lld%lld", &x, &y);
            printf("%lld\n", query_path(x, y));
        }
        if(op == 3) {
            scanf("%lld%lld", &x, &k);
//          cout << id[x] << " " << id[x] + sz[x] - 1 << endl;
            add(1, 1, n, id[x], id[x] + sz[x] - 1, k);
        }
//      cout << 1999;
        if(op == 4) {
//          cout << 1999;
            scanf("%lld", &x);
//          cout << id[x] << " " << id[x] + sz[x] - 1 << endl;
            printf("%lld\n", query(1, 1, n, id[x], id[x] + sz[x] - 1));
        }
    }
    return 0;
} 

by Genshineer @ 2023-12-31 21:35:42

@Kevinx 你模数p和线段树节点p是不是重名了


by Kevinx @ 2023-12-31 22:16:08

@Genshineer 改了,但是28pts

只AC了#2、#3、#11

#include<bits/stdc++.h>
#define ll long long
#define ls(p) p<<1
#define rs(p) p<<1|1
#define Mid ll mid = (l + r) >> 1
using namespace std;
const int N = 1e5 + 20;
ll n, m, r, mod, a[N], b[N];
ll cnt = 1, num = 0, h[N<<1], fa[N<<1],dep[N<<1], sz[N<<1], son[N<<1], top[N<<1], id[N<<1];
struct node{
    ll v, nxt;
}e[N<<1];

struct TR{
    ll l, r, sum, tag;
}tree[N<<2];

void add_edge(ll u, ll v) {
    e[++cnt] = (node){v, h[u]};
    h[u] = cnt;
}
void dfs1(ll u, ll f, ll deep) {
    dep[u] = deep;
    fa[u] = f;
    sz[u] = 1;
    ll maxsz = -1;
    for(int i = h[u]; i; i = e[i].nxt) {
        ll v = e[i].v;
        if(v == f) continue;
        dfs1(v, u, deep + 1);
        sz[u] += sz[v];
        if(sz[v] > maxsz) maxsz = sz[v], son[u] = v;
    }
}
void dfs2(ll u, ll ttop) {
    id[u] = ++ num;
    a[num] = b[u];
    top[u] = ttop;
    if(!son[u]) return;
    dfs2(son[u], ttop);
    for(int i = h[u]; i; i = e[i].nxt) {
        ll v = e[i].v;
        if(son[u] == v || v == fa[u]) continue;
        dfs2(v, v);
    }
}
void build(ll p, ll l, ll r) {
    tree[p].l = l, tree[p].r = r;
    if(l == r) {
        tree[p].sum = a[l]% mod;
        return ;
    }
    Mid;
    build(ls(p), l, mid);
    build(rs(p), mid + 1, r);
    tree[p].sum = (tree[ls(p)].sum + tree[rs(p)].sum)% mod;
}
void up(ll p) {
    tree[p].sum = (tree[ls(p)].sum + tree[rs(p)].sum)% mod;
}
void down(ll p, ll l, ll r) {
    ll k = tree[p].tag % mod;
    tree[ls(p)].sum = (k * (tree[ls(p)].r - tree[ls(p)].l + 1) + tree[ls(p)].sum)% mod;
    tree[rs(p)].sum = (k * (tree[rs(p)].r - tree[rs(p)].l + 1) + tree[rs(p)].sum)% mod;
    tree[ls(p)].tag = (k + tree[ls(p)].tag)% mod;
    tree[rs(p)].tag = (k + tree[rs(p)].tag)% mod;
    tree[p].tag = 0;
}

void add(ll p, ll l, ll r, ll x, ll y, ll k) {
    if(x <= l && r <= y) {
//  cout << p << endl;;
        tree[p].tag = (tree[p].tag + k)% mod;
        tree[p].sum = (tree[p].sum + k * (l - r + 1))% mod;
        return ;
    }

    down(p, l, r);
    Mid;
    if(x <= mid) add(ls(p), l, mid, x, y, k);
    if(y >  mid) add(rs(p), mid + 1, r, x, y, k);
    up(p);
    return ;
}
ll query(ll p, ll l, ll r, ll x, ll y) {
    if(x <= l && r <= y) {
        return tree[p].sum% mod;
    }
//  cout << p; 
    down(p, l, r);
    Mid;
    ll res = 0;
    if(x <= mid) res += query(ls(p), l, mid, x, y);
    if(y >  mid) res += query(rs(p), mid + 1, r, x, y);
    up(p);
    return res% mod;
}

void add_path(ll x, ll y, ll k) {
    k %= mod;
    while(top[x] != top[y]) {
        if(dep[top[x]] < dep[top[y]]) swap(x, y);
        add(1, 1, n, id[top[x]], id[x], k);
        x = fa[top[x]];
    }
    if(dep[x] > dep[y]) swap(x, y);
    add(1, 1, n, id[x], id[y], k);
}

ll query_path(ll x, ll y) {
    ll res = 0;
    while(top[x] != top[y]) {
        if(dep[top[x]] < dep[top[y]]) swap(x, y);
        res = (res + query(1, 1, n, id[top[x]], id[x]))% mod;
        x = fa[top[x]];
    }
    if(dep[x] > dep[y]) swap(x, y);
    res = (res + query(1, 1, n, id[x], id[y]))% mod;
    return res;
}

int main (){
    scanf("%lld%lld%lld%lld", &n, &m, &r, &mod);
//  cout << 111;
    for(int i = 1; i <= n; i++) {
//      cout << i;
        scanf("%lld", &b[i]);
    }
    for(int i = 1; i < n; i++) {
        ll u, v;
        scanf("%lld%lld", &u, &v); 
        add_edge(u, v);
        add_edge(v, u);
    }
    dfs1(r, 0, 1);
    dfs2(r, r);
    build(1, 1, n);
//  cout << tree[1].sum << endl;
    for(int i  = 1; i <= m; i++){
        ll op, x, y, k;
//      cout << 1;
        scanf("%lld", &op);
        if(op == 1) {
            scanf("%lld%lld%lld", &x, &y, &k);

            add_path(x, y, k);
        }
        if(op == 2) {
            scanf("%lld%lld", &x, &y);
            printf("%lld\n", query_path(x, y));
        }
        if(op == 3) {
            scanf("%lld%lld", &x, &k);
//          cout << id[x] << " " << id[x] + sz[x] - 1 << endl;
            add(1, 1, n, id[x], id[x] + sz[x] - 1, k);
        }
//      cout << 1999;
        if(op == 4) {
//          cout << 1999;
            scanf("%lld", &x);
//          cout << id[x] << " " << id[x] + sz[x] - 1 << endl;
            printf("%lld\n", query(1, 1, n, id[x], id[x] + sz[x] - 1));
        }
    }
    return 0;
} 

by Genshineer @ 2023-12-31 22:25:40

@Kevinx add()函数

k * (l - r + 1)

你认真的吗


by Kevinx @ 2023-12-31 22:31:47

@Genshineer vocal,谢谢dalao%%%%


|