萌新求助线段树,只过了hack求调,码风优秀,调疯了

P2572 [SCOI2010] 序列操作

isletfall @ 2024-08-13 13:42:48

#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int maxn = 2e5+7;
struct node{
    ll l,r,sum;
    ll sum1,lsum1,rsum1;
    ll sum0,lsum0,rsum0;
    ll lazy_turn,lazy_all;
}tr[maxn << 2];
ll a[maxn];
void push_up(ll p){
    tr[p].sum = tr[p << 1].sum + tr[p << 1 | 1].sum;
    tr[p].lsum1 = tr[p << 1].lsum1;
    if(tr[p << 1].lsum1 == tr[p << 1].r - tr[p << 1].l + 1)
        tr[p].lsum1 += tr[p << 1 | 1].lsum1;
    tr[p].rsum1 = tr[p << 1 | 1].rsum1;
    if(tr[p << 1 | 1].rsum1 == tr[p << 1 | 1].r - tr[p << 1 | 1].l + 1)
        tr[p].rsum1 += tr[p << 1].rsum1;
    tr[p].sum1 = max(tr[p << 1].rsum1 + tr[p << 1 | 1].lsum1,max(tr[p << 1].sum1,tr[p << 1 | 1].sum1));

    tr[p].lsum0 = tr[p << 1].lsum0;
    if(tr[p << 1].lsum0 == tr[p << 1].r - tr[p << 1].l + 1)
        tr[p].lsum0 += tr[p << 1 | 1].lsum0;
    tr[p].rsum0 = tr[p << 1 | 1].rsum0;
    if(tr[p << 1 | 1].rsum0 == tr[p << 1 | 1].r - tr[p << 1 | 1].l + 1)
        tr[p].rsum0 += tr[p << 1].rsum0;
    tr[p].sum0 = max(tr[p << 1].rsum0 + tr[p << 1 | 1].lsum0,max(tr[p << 1].sum0,tr[p << 1 | 1].sum0));
}
void build(ll p,ll l,ll r){
    tr[p].l = l,tr[p].r = r;           
    tr[p].lazy_all = -1;
    if(l == r){
        tr[p].sum = a[l];
        tr[p].lsum1 = tr[p].rsum1 = tr[p].sum1 = a[l];
        tr[p].lsum0 = tr[p].rsum0 = tr[p].sum0 = a[l] ^ 1;
        return;
    }
    ll mid = (l + r) >> 1;
    build(p << 1,l,mid);
    build(p << 1 | 1,mid + 1,r);
    push_up(p);
}
void push_down_turn(ll p){
    tr[p].sum = (tr[p].r - tr[p].l + 1) - tr[p].sum;
    swap(tr[p].sum1,tr[p].sum0);
    swap(tr[p].lsum1,tr[p].lsum0);
    swap(tr[p].rsum1,tr[p].rsum0);
    if(tr[p].lazy_all != -1){
        tr[p].lazy_all ^= 1;
    }
    else tr[p].lazy_turn ^= 1;
}
void push_down_all(ll p,ll v){
    tr[p].lazy_turn = 0;
    tr[p].lazy_all = v;
    if(v == 1){
        tr[p].sum = tr[p].sum1 = tr[p].lsum1 = tr[p].rsum1 = tr[p].r - tr[p].l + 1;
        tr[p].sum0 = tr[p].lsum0 = tr[p].rsum0 = 0; 
    }
    else {
        tr[p].sum = tr[p].sum1 = tr[p].lsum1 = tr[p].rsum1 = 0;
        tr[p].sum0 = tr[p].lsum0 = tr[p].rsum0 = tr[p].r - tr[p].l + 1;

    }
}
void push_down(ll p){
    if(tr[p].lazy_all != -1){
        push_down_all(p << 1,tr[p].lazy_all);
        push_down_all(p << 1 | 1,tr[p].lazy_all);
        tr[p].lazy_all = -1;
    }
    if(tr[p].lazy_turn){
        push_down_turn(p << 1);
        push_down_turn(p << 1 | 1);
        tr[p].lazy_turn = 0;
    }
}
void update_turn(ll p,ll l,ll r){
    if(l <= tr[p].l && tr[p].r <= r){
        push_down_turn(p);
        return;
    }
    push_down(p);
    if(tr[p << 1].r >= l)update_turn(p << 1,l,r);
    if(tr[p << 1 | 1].l <= r)update_turn(p << 1 | 1,l,r);
    push_up(p);
}
void update_all(ll p,ll l,ll r,ll v){
    if(l <= tr[p].l && tr[p].r <= r){
        push_down_all(p,v);
        return;
    }
    push_down(p);
    if(tr[p << 1].r >= l)update_all(p << 1,l,r,v);
    if(tr[p << 1 | 1].l <= r)update_all(p << 1 | 1,l,r,v);
    push_up(p);
}
ll query_all(ll p,ll l,ll r){
    if(l <= tr[p].l && tr[p].r <= r)
        return tr[p].sum;
    push_down(p);
    ll res=0;
    if(tr[p << 1].r >= l)res += query_all(p << 1,l,r);
    if(tr[p << 1 | 1].l <=r)res += query_all(p << 1 | 1,l,r);
    return res;
}
node query_continue(ll p,ll l,ll r){
    if(l <= tr[p].l && tr[p].r <= r)
        return tr[p];
    push_down(p);
    ll mid = (tr[p].l + tr[p].r) >> 1;
    if(l <= mid && mid < r){
        node lans,rans,ans;
        lans = query_continue(p << 1,l,r);
        rans = query_continue(p << 1 | 1,l,r);
        ans.lsum1 = lans.lsum1;
        if(lans.lsum1 == tr[p << 1].r - tr[p << 1].l + 1)
            ans.lsum1 += rans.lsum1;
        ans.rsum1 = rans.rsum1;
        if(rans.rsum1 == tr[p << 1 | 1].r - tr[p << 1 | 1].l + 1);
        ans.rsum1 += lans.rsum1;
        ans.sum1 = max(lans.rsum1 + rans.lsum1,max(lans.sum1,rans.sum1));
        return ans;
    }
    if(l <= mid) return query_continue(p << 1,l,r);
    if(r > mid)return query_continue(p << 1 | 1,l,r);
}
ll n,m;
int main(){
    cin >> n >> m;
    for(int i = 1;i <= n;i++)
        cin >> a[i];
    build(1,1,n);
    for(int i = 1;i <= m;i++){
        ll opt,x,y;
        cin >> opt >> x >> y;
        x++,y++;
        if(opt == 0 || opt == 1)update_all(1,x,y,opt);
        if(opt == 2)update_turn(1,x,y);
        if(opt == 3)cout<<query_all(1,x,y)<<endl;
        if(opt == 4)cout<<query_continue(1,x,y).sum1<<endl;
    }
    return 0;
}

by ImNot6Dora @ 2024-11-24 09:16:52

122 行的代码移到上一行的if里面去


|