悬关,求调

P4148 简单题

Zxx132536 @ 2024-02-19 11:05:36

#include <bits/stdc++.h>
// Balance factor for partial rebuilding: notbalance() reports a subtree as
// unbalanced when either child holds more than alpha * (subtree size) nodes.
const double alpha=0.75;
using namespace std;
// Capacity of the point buffer and node pool. Assumes the number of inserted
// points stays below this bound — NOTE(review): confirm against problem limits.
const int N=2e5+100;
// A 2-D point: dim[0] is the x coordinate, dim[1] the y coordinate, and val
// the weight summed by query().
struct Point
{
 int dim[2],val;
}order[N];  int cnt;
// order: 1-indexed scratch buffer; slap() flattens a subtree into it and
// build() rebuilds a balanced subtree from it.
// NOTE(review): cnt is not referenced anywhere in the visible code.
// One k-d tree node stored in the static pool t[]; index 0 means "no node"
// and t[0] is left zeroed so sum/size of an empty child read as 0.
struct kd_tree
{
 int ls,rs,sum,size;   // child indices, total val in subtree, node count in subtree
 int Min[2],Max[2];    // bounding box of all points in this subtree, per dimension
 Point p;              // the point stored at this node
}t[N];
// tot: highest node index handed out so far; root: index of the tree root;
// now: splitting dimension read by cmp() (set by build() before nth_element).
int tot,root,now;
// Recycled node indices: slap() pushes the nodes of a subtree being rebuilt,
// build()/Insert() pop from here before allocating a fresh index from tot.
int top,Stack[N];
// Orders two points by their coordinate on the current split axis
// (the global `now`); used as the nth_element comparator in build().
bool cmp(Point x,Point y)
{
    const int axis=now;
    return y.dim[axis]>x.dim[axis];
}
// Recomputes node u's aggregates from its own point and its children:
// the per-dimension bounding box (Min/Max), the weight sum and the node count.
// A child index of 0 means "absent" and contributes nothing to the box;
// t[0] is assumed zeroed so it adds 0 to sum and size.
void update(int u)
{
    auto &node=t[u];
    const int kids[2]={node.ls,node.rs};
    for(int axis=0;axis<2;++axis)
    {
        int lo=node.p.dim[axis];
        int hi=lo;
        for(int c:kids)
        {
            if(c==0) continue;
            lo=std::min(lo,t[c].Min[axis]);
            hi=std::max(hi,t[c].Max[axis]);
        }
        node.Min[axis]=lo;
        node.Max[axis]=hi;
    }
    node.sum=t[node.ls].sum+t[node.rs].sum+node.p.val;
    node.size=t[node.ls].size+t[node.rs].size+1;
}
// Flattens the subtree rooted at u, in order, into order[num+1 .. num+size(u)].
// Every visited node index is pushed onto Stack so build() can reuse it.
void slap(int u,int num)
{
    if(u==0) return;
    const int left=t[u].ls;
    slap(left,num);
    Stack[++top]=u;
    // u's in-order position: after `num` earlier points and its whole left subtree.
    const int pos=num+t[left].size+1;
    order[pos]=t[u].p;
    slap(t[u].rs,pos);
}
// Builds a balanced k-d subtree over order[l..r] (1-indexed, inclusive) and
// returns its root index, or 0 for an empty range. Node slots are taken from
// the recycle Stack first, otherwise a fresh index is allocated from tot.
// d is the splitting dimension at this level; children alternate (d^1).
int build(int l,int r,int d)
{
    if(l>r) return 0;
    int u;
    if(top) u=Stack[top--];
    else    u=++tot;
    const int mid=(l+r)>>1;
    now=d; // cmp() reads the split dimension from the global `now`
    // Partition ONLY the current range [l, r]. Starting at order+1 would
    // reshuffle points outside this range, including points already consumed
    // by the left sibling subtree. Later subtrees would then receive duplicated
    // points and lose others, and each call would cost O(r) instead of O(r-l).
    nth_element(order+l,order+mid,order+r+1,cmp);
    t[u].p=order[mid];
    t[u].ls=build(l,mid-1,d^1);
    t[u].rs=build(mid+1,r,d^1);
    update(u);
    return u;
}
// Returns true when node u's subtree needs rebuilding, i.e. when either
// child holds more than alpha * size(u) of its nodes.
bool notbalance(int u)
{
    const double limit=alpha*t[u].size;
    return t[t[u].ls].size>limit || t[t[u].rs].size>limit;
}
// Inserts point `now` into the subtree whose root index is stored in u.
// u is passed by reference so that a newly created node, or the new root of a
// rebuilt subtree, is written straight back into the parent's child slot.
// Points with dim[d] <= the node's dim[d] go left, others go right; after the
// recursive step the node is refreshed and rebuilt if it became unbalanced.
void Insert(int &u,Point now,int d)
{
    if(u)
    {
        int &child=(now.dim[d]<=t[u].p.dim[d]) ? t[u].ls : t[u].rs;
        Insert(child,now,d^1);
        update(u);
        if(notbalance(u))
        {
            slap(u,0);
            u=build(1,t[u].size,d);
        }
        return;
    }
    // Empty slot: reuse a recycled index if available, else allocate a new one.
    u=top ? Stack[top--] : ++tot;
    t[u].ls=0;
    t[u].rs=0;
    t[u].p=now;
    update(u);
}
// Returns the total val of points in the subtree rooted at u that lie inside
// the inclusive rectangle [x1, x2] x [y1, y2]. The subtree bounding box is
// used to accept the whole subtree at once, or to prune it entirely.
int query(int u,int x1,int y1,int x2,int y2)
{
    if(u==0) return 0;
    const auto &nd=t[u];
    // Bounding box fully inside the query rectangle: take the whole subtree sum.
    if(x1<=nd.Min[0] && nd.Max[0]<=x2 && y1<=nd.Min[1] && nd.Max[1]<=y2)
        return nd.sum;
    // Bounding box disjoint from the query rectangle: nothing to collect.
    if(nd.Max[0]<x1 || nd.Min[0]>x2 || nd.Max[1]<y1 || nd.Min[1]>y2)
        return 0;
    const int px=nd.p.dim[0];
    const int py=nd.p.dim[1];
    int res=(x1<=px && px<=x2 && y1<=py && py<=y2) ? nd.p.val : 0;
    res+=query(nd.ls,x1,y1,x2,y2);
    res+=query(nd.rs,x1,y1,x2,y2);
    return res;
}
// Reads the operation stream:
//   1 x y val      -> insert a weighted point
//   2 x1 y1 x2 y2  -> print the weight sum inside the rectangle
//   3              -> stop
// Every input number is XOR-ed with the previous query answer (`ans`).
int main()
{
    int n; // leading size value; consumed but not otherwise used by this program
    if(scanf("%d",&n)!=1) return 0;
    int ans=0;
    for(;;)
    {
        int opt;
        // Stop on EOF or malformed input instead of looping forever with an
        // uninitialised opt.
        if(scanf("%d",&opt)!=1) break;
        if(opt==1)
        {
            int x,y,val;
            if(scanf("%d%d%d",&x,&y,&val)!=3) break;
            Point p;
            p.dim[0]=x^ans;
            p.dim[1]=y^ans;
            p.val=val^ans;
            Insert(root,p,0);
        }
        else if(opt==2)
        {
            int x1,y1,x2,y2;
            if(scanf("%d%d%d%d",&x1,&y1,&x2,&y2)!=4) break;
            x1^=ans;    x2^=ans;    y1^=ans;    y2^=ans;
            ans=query(root,x1,y1,x2,y2);
            printf("%d\n",ans);
        }
        else if(opt==3) break;
    }
    return 0;
}

by ZYLZPP @ 2024-04-03 12:51:10

nth_element(order+1,order+mid,order+r+1,cmp);

不是从 `order+1` 开始，而是从 `order+l` 开始：`nth_element` 只能在当前区间 [l, r] 内划分，否则会打乱已经分配给其他子树的点，导致点被重复使用或丢失。

改为

nth_element(order+l,order+mid,order+r+1,cmp);

|