蒟蒻求调代码

P4074 [WC2013] 糖果公园

KiDDOwithTopTree @ 2021-03-27 14:26:15

嘤嘤嘤,调不出来…

样例过了,20分…

#include<algorithm>
#include<iostream>
#include<cmath>
using namespace std;
#define int long long
const int N=1e6+10;
struct edge{
    int from,to;
    int nxt;
};
struct quest{
    int l,r;
    int tim;
    int lca;
    int pos;
};
struct node{
    int old_val,new_val;
    int pos;
};
edge e[N];
quest q[N];
node c[N];
int last[N];
int f[N][30],dep[N];
int rev[N],st[N],ed[N];
int blo[N],lm[N],rm[N];
int a[N],b[N],v[N],w[N],cnt[N],used[N],res[N];
int num,tot,block;
inline bool cmp(quest x,quest y){
    if(blo[x.l]!=blo[y.l])
        return blo[x.l]<blo[y.l];
    else if(blo[x.r]!=blo[y.r])
        return blo[x.r]<blo[y.r];
    else
        return x.tim<y.tim;
}
inline void add(int from,int to){
    tot++;
    e[tot].from=from;
    e[tot].to=to;
    e[tot].nxt=last[from];
    last[from]=tot;
}
void dfs(int u,int fa){
    rev[++num]=u;
    st[u]=num;
    dep[u]=dep[fa]+1;
    f[u][0]=fa;
    for(int i=1;i<=20;i++)
        f[u][i]=f[f[u][i-1]][i-1];
    for(int i=last[u];i;i=e[i].nxt)
        if(e[i].to!=fa)
            dfs(e[i].to,u);
    rev[++num]=u;
    ed[u]=num;
}
int get_lca(int x,int y){
    if(dep[x]<dep[y])
        swap(x,y);
    for(int i=20;i>=0;i--)
        if(dep[f[x][i]]>=dep[y])
            x=f[x][i];
    if(x==y)
        return x;
    for(int i=20;i>=0;i--)
        if(f[x][i]!=f[y][i])
            x=f[x][i],y=f[y][i];
    return f[x][0];
}
inline void add_ans(int x,int &ans){
    cnt[a[x]]++;
    ans+=v[a[x]]*w[cnt[a[x]]];
}
inline void del_ans(int x,int &ans){
    ans-=v[a[x]]*w[cnt[a[x]]];
    cnt[a[x]]--;
}
inline void change(int x,int &ans){
    used[x]?del_ans(x,ans):add_ans(x,ans);
    used[x]^=1;
}
signed main(){
    int n,m,k;
    cin>>n>>m>>k;
    for(int i=1;i<=m;i++)
        cin>>v[i];
    for(int i=1;i<=n;i++)
        cin>>w[i];
    int x,y;
    for(int i=1;i<=n-1;i++){
        cin>>x>>y;
        add(x,y);
        add(y,x);
    }
    dfs(1,0);
    for(int i=1;i<=n;i++){
        cin>>a[i];
        b[i]=a[i];
    }
    block=pow(2*n,1/3.0);
    for(int i=1;i<=block;i++){
        lm[i]=(i-1)*block*block+1;
        rm[i]=i*block*block;
    }
    rm[block]=2*n;
    for(int i=1;i<=block;i++)
        for(int j=lm[i];j<=rm[i];j++)
            blo[j]=i;
    int opt,pos,tim=0;
    for(int i=1;i<=k;i++){
        cin>>opt;
        if(!opt){
            ++tim;
            cin>>c[tim].pos>>c[tim].new_val;
            c[tim].old_val=b[c[tim].pos];
            b[c[tim].pos]=c[tim].new_val;
        }
        else{
            pos=i-tim;
            cin>>q[pos].l>>q[pos].r;
            q[pos].pos=pos;
            q[pos].tim=tim;
            q[pos].lca=get_lca(q[pos].l,q[pos].r);
            if(st[q[i].l]>st[q[pos].r])
                swap(st[q[pos].l],st[q[pos].r]);
            if(q[pos].lca==q[pos].l){
                q[pos].lca=0;
                q[pos].l=st[q[pos].l];
                q[pos].r=st[q[pos].r];
            }
            else{
                q[pos].l=ed[q[pos].l];
                q[pos].r=st[q[pos].r];
            }
        }
    }
    k-=tim;
    sort(q+1,q+k+1,cmp);
    int l=1,r=0;
    int ans=0;
    tim=0;
    for(int i=1;i<=k;i++){
        while(l<q[i].l)
            change(rev[l++],ans);
        while(r<q[i].r)
            change(rev[++r],ans);
        while(l>q[i].l)
            change(rev[--l],ans);
        while(r>q[i].r)
            change(rev[r--],ans);
        if(q[i].lca)
            change(q[i].lca,ans);
        while(tim<q[i].tim){
            tim++;
            pos=c[tim].pos;
            if(used[pos]){
                ans-=v[a[pos]]*w[cnt[a[pos]]];
                cnt[a[pos]]--;
            }
            a[pos]=c[tim].new_val;
            if(used[pos]){
                cnt[a[pos]]++;
                ans+=v[a[pos]]*w[cnt[a[pos]]];
            }
        }
        while(tim>q[i].tim){
            pos=c[tim].pos;
            if(used[pos]){
                ans-=v[a[pos]]*w[cnt[a[pos]]];
                cnt[a[pos]]--;
            }
            a[pos]=c[tim].old_val;
            if(used[pos]){
                cnt[a[pos]]++;
                ans+=v[a[pos]]*w[cnt[a[pos]]];
            }
            tim--;
        }
        res[q[i].pos]=ans;
        if(q[i].lca)
            change(q[i].lca,ans);
    }
    for(int i=1;i<=k;i++)
        cout<<res[i]<<'\n';
}

|