蒟蒻刚学kdt,20ptsMLE求助

P4148 简单题

hanhoudedidue @ 2024-12-26 20:34:24

代码:

#include<bits/stdc++.h>
// #define int long long
#define ls(x) (t[x].ls)
#define rs(x) (t[x].rs)
#define mid ((l+r)>>1)
#define lowbit(x) ((x)&(-x))
using namespace std;
const int N=2e5+2;
bool ST;
int n,m,T;

int tot,K=2,cnt;
struct node{
    int ls,rs,sum,mn[2],mx[2],val;
    int p[2];
}t[N];
int wt[N],newd,rt[35];
bool cmp(int x,int y){
    return t[wt[x]].p[newd]<t[wt[y]].p[newd];
}
// struct KDT{
    inline void pushup(int x){
        t[x].sum=t[ls(x)].sum+t[rs(x)].sum+t[x].val;
        for(int i=0;i<K;i++){
            t[x].mn[i]=t[x].mx[i]=t[x].p[i];
            if(t[x].ls){
                t[x].mn[i]=min(t[x].mn[i],t[ls(x)].mn[i]);
                t[x].mx[i]=max(t[x].mx[i],t[ls(x)].mx[i]);
            }
            if(t[x].rs){
                t[x].mn[i]=min(t[x].mn[i],t[rs(x)].mn[i]);
                t[x].mx[i]=max(t[x].mx[i],t[rs(x)].mx[i]);
            }
        }
    }
    inline void build(int &x,int l,int r,int d){
        if(l>r) return;
        newd=d;
        nth_element(wt+l+1,wt+mid+1,wt+r+1,cmp);
        x=wt[mid];
        build(ls(x),l,mid-1,(d+1)%K);
        build(rs(x),mid+1,r,(d+1)%K);
        pushup(x);
    }
    inline void rebuild(int &x){
        if(!x) return ;
        wt[++tot]=x;
        rebuild(ls(x));
        rebuild(rs(x));
        x=0; 
    }
    //<=K
    inline void insert(int sum,int x[]){
        t[++cnt].val=sum;
        for(int i=0;i<K;i++)
            t[cnt].p[i]=t[cnt].mn[i]=t[cnt].mx[i]=x[i];
        wt[tot=1]=cnt;
        for(int i=1;i<=30;i++){
            if(!rt[i]){
                build(rt[i],1,tot,0);
                break;
            }
            else rebuild(rt[i]); 
        }
    }
    inline int query(int x,int lx[],int rx[]){
        if(!x) return 0;
        bool ok1=1;//1:all in,2:all no in
        bool ok=1;
        int res=0;
        for(int i=0;i<K;i++){
            ok&=(t[x].p[i]>=lx[i]&&t[x].p[i]<=rx[i]);
            if(!(t[x].mn[i]>=lx[i]&&t[x].mx[i]<=rx[i])) ok1=0;
            if(lx[i]>t[x].mx[i]||rx[i]<t[x].mn[i]) return 0;
        }
        if(ok1) return t[x].sum;
        if(ok) res+=t[x].val;
        return query(ls(x),lx,rx)+query(rs(x),lx,rx)+res; 
    }
// }kdt;
int lstans;
bool ED;
signed main(){
    // cerr<<(&ED-&ST)/1048576.0<<'\n';
    ios::sync_with_stdio(0);
    cin.tie(0),cout.tie(0);
    cin>>n;
    // for(int i=1;i<N;i++) wt[i]=i;
    int opt,x,y,X[2],Y[2],A;
    while(1){
        cin>>opt;
        if(opt==1){
            for(int i=0;i<K;i++){
                cin>>X[i];
                X[i]^=lstans;
            }
            cin>>A;A^=lstans;
            // kdt.
            insert(A,X);
        }
        else if(opt==2){
            for(int i=0;i<K;i++) {
                cin>>X[i];
                X[i]^=lstans;   
            }
            for(int i=0;i<K;i++){
                cin>>Y[i];
                Y[i]^=lstans;
            }
            int res=0;
            for(int i=1;i<=30;i++)
                res+=
                // kdt.
                query(rt[i],X,Y);
            lstans=res;
            cout<<res<<'\n';
        }
        else break;
    }
    return 0;
}

调了一晚上了,还是找不出来


by hanhoudedidue @ 2024-12-26 20:38:44

ok了 给同学看了一眼就出来了,cmp写错了


|