K-D Tree MLE20 求调

P4148 简单题

sto_clx_orz @ 2023-08-29 20:36:19

#include<bits/stdc++.h>
using namespace std;

class KDT
{
    class point
    {
        vector<int> w,w_,_w;
        int k,size;
        point *ls,*rs;
        public:
        point()
        {
            w.resize(0);
            w_.resize(0);
            _w.resize(0);
            k=size=0;
            ls=rs=NULL;
        }
        point(int k_)
        {
            k=k_;
            size=0;
            w.resize(k);
            w_.resize(k);
            _w.resize(k);
            ls=rs=NULL;
        }
        void build(int k_,int x,int l,int r,vector<vector<int>>&a)
        {
            if(r<l)return;
            size=r-l+1;
            k=k_;
            w.resize(k),w_.resize(k);
            int mid=(l+r)>>1;
            nth_element(a.begin()+l,a.begin()+mid,a.begin()+r+1,[&x,&k_](vector<int> a,vector<int> b){return a[x%k_]<b[x%k_];});
            ls=new point(k);
            rs=new point(k);
            w=a[mid];
            w_=w;
            _w=w;
            ls->build(k,x+1,l,mid-1,a);
            for(int i(0);i<k;++i)w_[i]=max(w_[i],ls->w_[i]);
            for(int i(0);i<k;++i)_w[i]=min(_w[i],ls->_w[i]);
            rs->build(k,x+1,mid+1,r,a);
            for(int i(0);i<k;++i)w_[i]=max(w_[i],rs->w_[i]);
            for(int i(0);i<k;++i)_w[i]=min(_w[i],rs->_w[i]);
        }
        int query(vector<int>a,vector<int>b)
        {
            int o=1,ans=0;
            for(int i(0);i<k;++i)o&=(w_[i]<a[i]);
            if(o)return 0;
            o=1;
            for(int i(0);i<k;++i)o&=(_w[i]>b[i]);
            if(o)return 0;
            o=1;
            for(int i=0;i<k;i++)o&=(w_[i]<=b[i]&&_w[i]>=a[i]);
            if(o)return size;
            o=1;
            for(int i(0);i<k;++i)o&=(w[i]<=b[i]&&w[i]>=a[i]);
            if(o)++ans;
            if(ls!=NULL)ans+=ls->query(a,b);
            if(rs!=NULL)ans+=rs->query(a,b);
            return ans;
        }
    };
    int k;
    vector<vector<int>>a,aa;
    point root;
    public:
    KDT()
    {
        k=0;
        a.resize(0,vector<int>(0));
        aa.resize(0,vector<int>(0));
    }
    void set(int n_,int k_)
    {
        k=k_;
    }
    KDT(int k_)
    {
        k=k_;
    }
    void build()
    {
        root.build(k,0,0,a.size()-1,a);
    }
    void insert(vector<int>x)
    {
        aa.push_back(x);
        if(aa.size()>=sqrt(aa.size()+a.size()))
        {
            for(auto i:aa)a.push_back(i);
            aa.clear();
            build();
        }
    }
    int query(vector<int>a,vector<int>b)
    {
        int ans=0;
        ans+=root.query(a,b);
        for(auto w:aa)
        {
            int o=1;
            for(int i(0);i<k;++i)
                o&=(w[i]<=b[i]&&w[i]>=a[i]);
            ans+=o;
        }
        return ans;
    }
};

int n,t,k=3;

int main()
{
    ios::sync_with_stdio(false),cin.tie(0);
    cin>>t;
    while(t--)
    {
        cin>>n;
        KDT x(k);
        for(int i(1);i<=n;++i)
        {
            int type;
            cin>>type;
            if(type==1)
            {
                vector<int>a(k);
                for(int j(0);j<k;++j)cin>>a[j];
                x.insert(a);
            }
            else
            {
                vector<int>a(k),b(k);
                for(int j(0);j<k;++j)cin>>a[j];
                for(int j(0);j<k;++j)cin>>b[j];
                cout<<x.query(a,b)<<"\n";
            }
        }
    }
}

(c++17以下可能会CE)


|