线段树模板求助Splay

P3372 【模板】线段树 1

STLvector @ 2024-09-23 16:52:46

只过了 #1 #2 #3 #9 四个点

目测是tag有没有处理的地方(用 Test_Tag 看了)

评测记录

#include <iostream>
#include <vector>
#if 1
    #define debug(format,...) fprintf(stderr,format,##__VA_ARGS__)
#else
    #define dedebug(format, ...)
#endif

class splay
{
private:
    const static int N=1e5+5;
    struct node
    {
        long long val,sum,lz;
        unsigned size;
    }tree[N];
    int s[N][2],f[N],root=0,tail=0;
    int newnode(int x)
    {
        tail++;
        tree[tail]={x,x,0,1};
        return tail;
    }
    int sid(int x){return s[f[x]][1]==x;}
    void pushup(int x)
    {
        tree[x].sum=tree[s[x][0]].sum+tree[x].val+tree[s[x][1]].sum;
        tree[x].size=tree[s[x][0]].size+1+tree[s[x][1]].size;
    }
    void execute(int x,long long tag)
    {
        if(x==0) return;
        tree[x].val+=tag;
        tree[x].sum+=tag*tree[x].size;
        tree[x].lz+=tag;
    }
    void pushdown(int x)
    {
        execute(s[x][0],tree[x].lz);
        execute(s[x][1],tree[x].lz);
        tree[x].lz=0;
    }

    void connect(int x,int y,int id)
    {
        if(x) f[x]=y;
        if(y) s[y][id]=x;
    }
    void rotate(int x)
    {
        int y=f[x],r=f[f[x]];
        int xid=sid(x),yid=sid(y);
        int z=s[x][1-xid];
        pushdown(y);
        pushdown(x);

        connect(z,y,xid);
        connect(y,x,1-xid);
        connect(x,r,yid);

        pushup(y);
        pushup(x);
    }
    void Splay(int x,int target=0)
    {
        if(target==0) root=x;
        for(;f[x]!=target;rotate(x))
            if(f[f[x]]!=target)
            {
                if(sid(x)==sid(f[x]))
                    rotate(f[x]);
                else rotate(x);
            }
    }
    int find(unsigned k)
    {
        int x=root;
        while(true)
        {
            pushdown(x);
            if(s[x][0]!=0&&k<=tree[s[x][0]].size)
                x=s[x][0];
            else if(tree[s[x][0]].size+1<k)
            {
                k-=tree[s[x][0]].size+1;
                x=s[x][1];
            }
            else
            {
                Splay(x);
                return x;
            }
        }
    }
public:
    int Root(){return root;}
    int build(int l,int r,std::vector<long long>& a)
    {
        if(l>r) return 0;
        int mid=(l+r)>>1;
        int x=newnode(a[mid]);
        if(root==0) root=x;
        connect(build(l,mid-1,a),x,0);
        connect(build(mid+1,r,a),x,1);
        pushup(x);
        return x;
    }
    void modify(int l,int r,long long k)
    {
        if(l!=1&&r!=tail)
        {
            int L=find(l-1),R=find(r+1);
            Splay(L);
            Splay(R,L);
            execute(s[s[root][1]][0],k);
        }
        else if(l!=1)
        {
            int L=find(l-1);
            Splay(L);
            execute(s[root][1],k);
        }
        else if(r!=tail)
        {
            int R=find(r+1);
            Splay(R);
            execute(s[root][0],k);
        }
        else
        {
            Splay(1);
            execute(1,k);
        }
    }
    long long query(int l,int r)
    {
        if(l!=1&&r!=tail)
        {
            int L=find(l-1),R=find(r+1);
            Splay(L);
            Splay(R,L);
            return tree[s[s[root][1]][0]].sum;
        }
        else if(l!=1)
        {
            int L=find(l-1);
            Splay(L);
            return tree[s[root][1]].sum;
        }
        else if(r!=tail)
        {
            int R=find(r+1);
            Splay(R);
            return tree[s[root][0]].sum;
        }
        else
        {
            Splay(1);
            return tree[1].sum;
        }
    }
    void Test_Tag(int x)
    {
        if(x==0) return;
        pushdown(x);
        Test_Tag(s[x][0]);
        Test_Tag(s[x][1]);
        pushup(x);
    }
    void Debug()
    {
        debug("\n\n\nroot=%d\n",root);
        for(int i=1;i<=tail;i++)
            debug("Node #%d -> ls = %d rs=%d size=%d sum=%lld    lz=%lld    val=%lld\n",i,s[i][0],s[i][1],tree[i].size,tree[i].sum,tree[i].lz,tree[i].val);
    }
}s;

namespace solve
{
    int n,m;
    std::vector<long long> a;
    void main()
    {
        std::cin>>n>>m;
        a.resize(n+1);
        for(int i=1;i<=n;i++)
            std::cin>>a[i];
        s.build(1,n,a);
        // s.Debug();
        while(m--)
        {
            int op;
            std::cin>>op;
            if(op==1)
            {
                int l,r;
                long long k;
                std::cin>>l>>r>>k;
                s.modify(l,r,k);
            }
            else
            {
                int l,r;
                std::cin>>l>>r;
                std::cout<<s.query(l,r)<<'\n';
            }
            // s.Debug();
        }
    }
}

int main()
{
    #define _PROBLEM_ ""
    // std::freopen(_PROBLEM_".in","r",stdin);
    // std::freopen(_PROBLEM_".out","w",stdout);
    std::ios::sync_with_stdio(0);
    std::cout.tie(0);
    std::cin.tie(0);
    solve::main();
    return 0;
}

by STLvector @ 2024-09-23 22:48:00

破案了,修改后要一直 pushup 到根


|