基于矩阵乘法的线段树 TLE on #10 求调

P1253 扶苏的问题

lizihan250 @ 2024-10-24 19:16:13

rt. 核心思路:设线段树内存的状态为 \begin{pmatrix}a & 1 \end{pmatrix},则有:

  • 加上一个数 k\begin{pmatrix}a & 1 \end{pmatrix}\begin{pmatrix}1 & 0\\ k & 1 \end{pmatrix}.
  • 加上一个数 k\begin{pmatrix}a & 1 \end{pmatrix}\begin{pmatrix}0 & 0\\ k & 1 \end{pmatrix}.

由于矩阵乘法具有结合律,可以方便的维护懒标记。代码如下。请问如何进一步优化?

#include<bits/stdc++.h>
using namespace std;
const int Maxn=1000000;
int f,tp,n,m,opt,x,y,k,nums[Maxn+100];
static int stk[30];
long long w;
char c;
struct Martix
{
    int n,c;
    long long nums[3][3];
    void init(const int _n,const int _c,const bool opt)
    {
        n=_n;
        c=_c;
        for(int i=1;i<=_n;i++)
        {
            for(int j=1;j<=_c;j++)
            {
                nums[i][j]=(i==j)*opt;
            }
        }
        return;
    }
    Martix operator * (const Martix &x) const
    {
        Martix res;
        res.init(n,x.c,0);
        /*if(c!=x.n)
        {
            printf("Warning: invaild multipul of Martix ( %d %d ) * ( %d %d )\n",n,c,x.n,x.c);
            return res;
        }*/
        for(int i=1;i<=n;i++)
        {
            for(int j=1;j<=x.c;j++)
            {
                res.nums[i][j]=0;
                for(int k=1;k<=c;k++)
                {
                    res.nums[i][j]+=nums[i][k]*x.nums[k][j];
                }
            }
        }
        return res;
    }
    /*void print()
    {
        printf("row: %d  column: %d\n",n,c);
        for(int i=1;i<=n;i++)
        {
            for(int j=1;j<=c;j++)
            {
                printf("%lld ",nums[i][j]);
            }
            printf("\n");
        }
        return;
    }*/
}base;
struct node
{
    int l,r;
    Martix x,lazy;
    void init(const int _l,const int _r,const long long _num)
    {
        l=_l;
        r=_r;
        x.init(1,2,0);
        lazy.init(2,2,1);
        x.nums[1][1]=_num;
        x.nums[1][2]=1;
        return;
    }
    /*void print()
    {
        printf("** %d %d **\n",l,r);
        x.print();
        lazy.print();
    }*/
}seg[(Maxn<<2)+100];
inline long long max(const node x,const node y)
{
    return max(x.x.nums[1][1],y.x.nums[1][1]);
}
void push_down(const int nw)
{
    seg[nw<<1].x=seg[nw<<1].x*seg[nw].lazy;
    seg[(nw<<1)+1].x=seg[(nw<<1)+1].x*seg[nw].lazy;
    seg[nw<<1].lazy=seg[nw<<1].lazy*seg[nw].lazy;
    seg[(nw<<1)+1].lazy=seg[(nw<<1)+1].lazy*seg[nw].lazy;
    seg[nw].lazy.init(2,2,1);
    return;
}
void build(const int l,const int r,const int nw)
{
    if(l==r)
    {
        seg[nw].init(l,r,nums[l]);
        //seg[nw].print();
        return;
    }
    int mid=l+((r-l)>>1);
    build(l,mid,nw<<1);
    build(mid+1,r,(nw<<1)+1);
    seg[nw].init(l,r,max(seg[nw<<1],seg[(nw<<1)+1]));
    //seg[nw].print();
    return;
}
void modify(const int l,const int r,const int nw,const Martix b)
{
    //printf("modify: ");
    //seg[nw].print();
    if(l<=seg[nw].l&&seg[nw].r<=r)
    {
        seg[nw].x=seg[nw].x*b;
        seg[nw].lazy=seg[nw].lazy*b;
        //printf("return: ");
        //seg[nw].print();
        return;
    }
    if(seg[nw].lazy.nums[1][1]!=1||seg[nw].lazy.nums[1][2]!=0||seg[nw].lazy.nums[2][1]!=0||seg[nw].lazy.nums[2][2]!=1) push_down(nw);
    int mid=seg[nw].l+((seg[nw].r-seg[nw].l)>>1);
    if(l<=mid) modify(l,r,nw<<1,b);
    if(mid<r) modify(l,r,(nw<<1)+1,b);
    seg[nw].x.nums[1][1]=max(seg[nw<<1],seg[(nw<<1)+1]);
    //printf("return: ");
    //seg[nw].print();
    return;
}
long long query(const int l,const int r,const int nw)
{
    if(l<=seg[nw].l&&seg[nw].r<=r) return seg[nw].x.nums[1][1];
    if(seg[nw].lazy.nums[1][1]!=1||seg[nw].lazy.nums[1][2]!=0||seg[nw].lazy.nums[2][1]!=0||seg[nw].lazy.nums[2][2]!=1) push_down(nw);
    int mid=seg[nw].l+((seg[nw].r-seg[nw].l)>>1);
    long long ans=-2e18;
    if(l<=mid) ans=max(ans,query(l,r,nw<<1));
    if(mid<r) ans=max(ans,query(l,r,(nw<<1)+1)); 
    seg[nw].x.nums[1][1]=max(seg[nw<<1],seg[(nw<<1)+1]);
    return ans;
}
inline long long read()
{
    f=1;
    w=0;
    c=getchar();
    while(c<'0'||c>'9')
    {
        if(c=='-') f=-1;
        c=getchar();
    }
    while('0'<=c&&c<='9')
    {
        w=(w<<3)+(w<<1)+c-48;
        c=getchar();
    }
    return f*w;
}
inline void print(long long x)
{
    tp=0;
    if(x==0)
    {
        putchar('0');
        return;
    }
    if(x<0)
    {
        putchar('-');
        x=-x;
    }
    while(x>0)
    {
        tp++;
        stk[tp]=x%10;
        x/=10;
    }
    while(tp>0)
    {
        putchar(stk[tp]+'0');
        tp--;
    }
    putchar('\n');
    return;
}
int main()
{
    n=read();
    m=read();
    for(int i=1;i<=n;i++)
    {
        nums[i]=read();
    }
    base.init(2,2,1);
    build(1,n,1);
    /*printf("*****\n");
    seg[1].lazy.print();
    seg[2].x.print();
    (seg[2].x*seg[1].lazy).print();
    printf("*****\n");*/
    while(m--)
    {
        opt=read();
        x=read();
        y=read();
        if(opt<=2)
        {
            k=read();
            base.nums[1][1]=opt-1;
            base.nums[2][1]=k;
            modify(x,y,1,base);
        }
        else print(query(x,y,1));
    }
    return 0;
}

|