线段树60pts求助

P1253 扶苏的问题

Untitled_628496 @ 2023-09-29 15:06:22


#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=7777777,M=999999999;
ll n,m,a[N];
struct node{
    ll l,r,tmax,tag,cover;
}tree[N*4];
int read(){
    int x=0,w=1;
    char ch=0;
    while(ch<'0'||ch>'9') { 
        if (ch=='-') w=-1; 
        ch=getchar();    
    }
    while(ch>='0'&&ch<='9') {
        x=x*10+(ch-'0'); 
        ch=getchar();
    }
    return x*w;
}

void build(ll i,ll l,ll r){
    tree[i].l=l;
    tree[i].r=r;
    tree[i].cover=M;
    if(l==r){
        tree[i].tmax=a[l];
        return ;
    }
    ll mid=l+r>>1;
    build(i<<1,l,mid);
    build(i<<1|1,mid+1,r);
    tree[i].tmax=max(tree[i<<1].tmax,tree[i<<1|1].tmax);
    return ;
}
void pushdown(ll i){
    if(tree[i].cover!=M){
        tree[i<<1].tag=tree[i<<1|1].tag=0;
        tree[i<<1].tmax=tree[i<<1|1].tmax=tree[i].cover;
        tree[i<<1].cover=tree[i<<1|1].cover=tree[i].cover;
        tree[i].cover=M;
    }
    tree[i<<1].tag+=tree[i].tag;
    tree[i<<1|1].tag+=tree[i].tag;
    tree[i<<1].tmax+=tree[i].tag;
    tree[i<<1|1].tmax+=tree[i].tag;
    tree[i].tag=0;
    return ;
}
void change(ll i,ll l,ll r,ll k){
    if(tree[i].l>=l&&tree[i].r<=r){
        tree[i].tag=0;
        tree[i].cover=k;
        tree[i].tmax=k;
        return ;
    }
    pushdown(i);
    ll mid=tree[i].l+tree[i].r>>1;
    if(l<=mid) change(i<<1,l,r,k);
    if(r>=mid+1) change(i<<1|1,l,r,k);
    tree[i].tmax=max(tree[i<<1].tmax,tree[i<<1|1].tmax);
    return ;
}
void add(ll i,ll l,ll r,ll k){
    if(tree[i].l>=l&&tree[i].r<=r){
        tree[i].tag+=k;
        tree[i].tmax+=k;
        return ;
    }
    if(tree[i].tag) pushdown(i);
    ll mid=tree[i].l+tree[i].r>>1;
    if(l<=mid) add(i<<1,l,r,k);
    if(r>=mid+1) add(i<<1|1,l,r,k);
    tree[i].tmax=max(tree[i<<1].tmax,tree[i<<1|1].tmax);
    return ;
}
ll find(ll i,ll l,ll r){
    if(tree[i].l>=l&&tree[i].r<=r){
        return tree[i].tmax;
    }
    if(tree[i].tag) pushdown(i);
    ll mid=tree[i].l+tree[i].r>>1;
    ll ans=-4655186461566;
    if(l<=mid) ans=max(ans,find(i<<1,l,r));
    if(r>=mid+1) ans=max(ans,find(i<<1|1,l,r));
    return ans;
}
int main(){
    n=read();
    m=read();
    for(int i=1;i<=n;i++){
        a[i]=read();
    }
    build(1,1,n);
    while(m--){
        ll op,l,r,k;
        op=read();
        l=read();
        r=read();
        if(op==1){
            k=read();
            change(1,l,r,k);
        }
        else if(op==2){
            k=read();
            add(1,l,r,k);
        }else{
            cout << find(1,l,r) << endl;
        }
    }
    return 0;
}

|