37pts求调试,悬赏关注

P3384 【模板】重链剖分/树链剖分

halehu @ 2023-07-21 23:32:03

已经步步取模了还是错。。。

记录

#include<bits/stdc++.h>
#define LL long long
using namespace std;
const int N = 1e6 + 5;
LL n,m,root,mod,tr[N],add[N],w1[N],w2[N],head[N],tot;
LL f[N],son[N],sz[N],id[N],cnt,top[N],dep[N];
struct edge{
    LL nxt,to;
}e[N];
void Add(LL u,LL v){
    e[tot].to = v,e[tot].nxt = head[u],head[u] = tot ++;
}
void dfs1(LL u,LL fa){
    f[u] = fa;
    sz[u] ++;
    dep[u] = dep[fa] + 1;
    for(int i=head[u];i!=-1;i=e[i].nxt){
        LL v = e[i].to;
        if(v == fa) continue;
        dfs1(v,u);
        if(!son[u] || sz[v] > sz[son[u]])
           son[u] = v;
        sz[u] += sz[v];
    }
}
void dfs2(LL u,LL topu){
    id[u] = ++ cnt;
    w2[cnt] = w1[u];
    top[u] = topu;
    if(!son[u]) return;
    dfs2(son[u],topu);
    for(int i=head[u];i!=-1;i=e[i].nxt){
        LL v = e[i].to;
        if(v == f[u] || v == son[u]) continue;
        dfs2(v,v);
    }
}
void pushdown(LL p,LL len){
    if(!add[p]) return;
    add[p*2] = (add[p*2] + add[p]) % mod;
    add[p*2 + 1] = (add[p*2 + 1] + add[p]) % mod;
    tr[p*2] = (tr[p*2] + add[p] * (len - len / 2) % mod) % mod;
    tr[p*2 + 1] = (tr[p*2 + 1] + add[p] * len / 2 % mod) % mod;
    add[p] = 0;
}
void update_tree(LL x,LL y,LL l,LL r,LL p,LL val){
    if(x > r || y < l) return;
    if(x <= l && r <= y){
        tr[p] = (tr[p] + (r - l + 1) * val % mod) % mod;
        add[p] = (add[p] + val) % mod;
        return; 
    }
    LL mid = (l + r) >> 1;
    pushdown(p,r - l + 1);
    update_tree(x,y,l,mid,p*2,val);
    update_tree(x,y,mid+1,r,p*2+1,val);
    tr[p] = (tr[p*2] + tr[p*2 + 1]) % mod;
}
LL query_tree(LL x,LL y,LL l,LL r,LL p){
    if(x > r || y < l) return 0;
    if(x <= l && y >= r) return tr[p] % mod;
    LL mid = (l + r) >> 1;
    pushdown(p,r - l + 1);
    return (query_tree(x,y,l,mid,p*2) + query_tree(x,y,mid+1,r,p*2 + 1)) % mod;
}
void update(LL x,LL y,LL val){
    while(top[x] != top[y]){
        if(dep[top[x]] < dep[top[y]]) swap(x,y);
        update_tree(id[top[x]],id[x],1,n,1,val);
        x = f[top[x]];
    }
    if(dep[x] > dep[y]) swap(x,y);
    update_tree(id[x],id[y],1,n,1,val);
}
LL query(LL x,LL y){
    LL res = 0;
    while(top[x] != top[y]){
        if(dep[top[x]] < dep[top[y]]) swap(x,y);
        res = (res + query_tree(id[top[x]],id[x],1,n,1)) % mod;
        x = f[top[x]];
    }
    if(dep[x] > dep[y]) swap(x,y);
    res = (res + query_tree(id[x],id[y],1,n,1)) % mod;
    return res;
}
void build(LL l,LL r,LL p){
    if(l == r){
        tr[p] = w2[l] % mod;
        return;
    }
    LL mid = (l + r) >> 1;
    build(l,mid,p*2);
    build(mid+1,r,p*2+1);
    tr[p] = (tr[p*2] + tr[p*2 + 1]) % mod;
}
int main(){
    memset(head,-1,sizeof head);
    scanf("%lld%lld%lld%lld",&n,&m,&root,&mod);
    for(int i=1;i<=n;i++) scanf("%lld",&w1[i]),w1[i] = w1[i] % mod;
    for(int i=1;i<n;i++){
        LL a,b;
        scanf("%lld%lld",&a,&b);
        Add(a,b),Add(b,a);
    }
    dfs1(root,0);
    dfs2(root,root);
    build(1,n,1);
    for(int i=1;i<=m;i++){
        LL op,x,y,z;
        scanf("%lld%lld",&op,&x);
        if(op == 1){
            scanf("%lld%lld",&y,&z);
            update(x,y,z);
        }
        else if(op == 2){
            scanf("%lld",&y);
            printf("%lld\n",query(x,y) % mod);
        }
        else if(op == 3){
            scanf("%lld",&z);
            update_tree(id[x],id[x] + sz[x] - 1,1,n,1,z);
        }
        else if(op == 4)
            printf("%lld\n",query_tree(id[x],id[x] + sz[x] - 1,1,n,1) % mod);
    }
}

by _XHY20180718_ @ 2023-07-22 00:50:02

@halehu

pushdown中算区间长度为啥不打括号?

区间长度要向下取整的,不能直接乘。

tr[p*2 + 1] = (tr[p*2 + 1] + add[p] * (len / 2) % mod) % mod;

然后A了:

#include<bits/stdc++.h>
#define LL long long
using namespace std;
const int N = 1e6 + 5;
LL n,m,root,mod,tr[N],add[N],w1[N],w2[N],head[N],tot;
LL f[N],son[N],sz[N],id[N],cnt,top[N],dep[N];
struct edge{
    LL nxt,to;
}e[N];
void Add(LL u,LL v){
    e[tot].to = v,e[tot].nxt = head[u],head[u] = tot ++;
}
void dfs1(LL u,LL fa){
    f[u] = fa;
    sz[u] ++;
    dep[u] = dep[fa] + 1;
    for(int i=head[u];i!=-1;i=e[i].nxt){
        LL v = e[i].to;
        if(v == fa) continue;
        dfs1(v,u);
        if(!son[u] || sz[v] > sz[son[u]])
           son[u] = v;
        sz[u] += sz[v];
    }
}
void dfs2(LL u,LL topu){
    id[u] = ++ cnt;
    w2[cnt] = w1[u];
    top[u] = topu;
    if(!son[u]) return;
    dfs2(son[u],topu);
    for(int i=head[u];i!=-1;i=e[i].nxt){
        LL v = e[i].to;
        if(v == f[u] || v == son[u]) continue;
        dfs2(v,v);
    }
}
void pushdown(LL p,LL len){
    if(!add[p]) return;
    add[p*2] = (add[p*2] + add[p]) % mod;
    add[p*2 + 1] = (add[p*2 + 1] + add[p]) % mod;
    tr[p*2] = (tr[p*2] + add[p] * (len - len / 2) % mod) % mod;
    tr[p*2 + 1] = (tr[p*2 + 1] + add[p] * (len / 2) % mod) % mod;
    add[p] = 0;
}
void update_tree(LL x,LL y,LL l,LL r,LL p,LL val){
    if(x > r || y < l) return;
    if(x <= l && r <= y){
        tr[p] = (tr[p] + (r - l + 1) * val % mod) % mod;
        add[p] = (add[p] + val) % mod;
        return; 
    }
    LL mid = (l + r) >> 1;
    pushdown(p,r - l + 1);
    update_tree(x,y,l,mid,p*2,val);
    update_tree(x,y,mid+1,r,p*2+1,val);
    tr[p] = (tr[p*2] + tr[p*2 + 1]) % mod;
}
LL query_tree(LL x,LL y,LL l,LL r,LL p){
    if(x > r || y < l) return 0;
    if(x <= l && y >= r) return tr[p] % mod;
    LL mid = (l + r) >> 1;
    pushdown(p,r - l + 1);
    return (query_tree(x,y,l,mid,p*2) + query_tree(x,y,mid+1,r,p*2 + 1)) % mod;
}
void update(LL x,LL y,LL val){
    while(top[x] != top[y]){
        if(dep[top[x]] < dep[top[y]]) swap(x,y);
        update_tree(id[top[x]],id[x],1,n,1,val);
        x = f[top[x]];
    }
    if(dep[x] > dep[y]) swap(x,y);
    update_tree(id[x],id[y],1,n,1,val);
}
LL query(LL x,LL y){
    LL res = 0;
    while(top[x] != top[y]){
        if(dep[top[x]] < dep[top[y]]) swap(x,y);
        res = (res + query_tree(id[top[x]],id[x],1,n,1)) % mod;
        x = f[top[x]];
    }
    if(dep[x] > dep[y]) swap(x,y);
    res = (res + query_tree(id[x],id[y],1,n,1)) % mod;
    return res;
}
void build(LL l,LL r,LL p){
    if(l == r){
        tr[p] = w2[l] % mod;
        return;
    }
    LL mid = (l + r) >> 1;
    build(l,mid,p*2);
    build(mid+1,r,p*2+1);
    tr[p] = (tr[p*2] + tr[p*2 + 1]) % mod;
}
int main(){
    memset(head,-1,sizeof head);
    scanf("%lld%lld%lld%lld",&n,&m,&root,&mod);
    for(int i=1;i<=n;i++) scanf("%lld",&w1[i]),w1[i] = w1[i] % mod;
    for(int i=1;i<n;i++){
        LL a,b;
        scanf("%lld%lld",&a,&b);
        Add(a,b),Add(b,a);
    }
    dfs1(root,0);
    dfs2(root,root);
    build(1,n,1);
    for(int i=1;i<=m;i++){
        LL op,x,y,z;
        scanf("%lld%lld",&op,&x);
        if(op == 1){
            scanf("%lld%lld",&y,&z);
            update(x,y,z);
        }
        else if(op == 2){
            scanf("%lld",&y);
            printf("%lld\n",query(x,y) % mod);
        }
        else if(op == 3){
            scanf("%lld",&z);
            update_tree(id[x],id[x] + sz[x] - 1,1,n,1,z);
        }
        else if(op == 4)
            printf("%lld\n",query_tree(id[x],id[x] + sz[x] - 1,1,n,1) % mod);
    }
}

by halehu @ 2023-07-22 07:42:33

@XHY20180718

谢谢大佬,已关注


|