感觉是spread或者modify写错了,但是具体在哪不知道

P3372 【模板】线段树 1

Rhss @ 2023-02-26 14:43:51

#include <iostream>
#include <string>
#include <algorithm>
using namespace std;
const int N = 7;
int T , n , a[N] , tr[4 * N] , add[N];
void spread(int node,int start,int end)
{
    if(add[node])
    {
        int mid = start + end >> 1;
        tr[node * 2] += add[node] * (mid - start + 1);
        tr[node * 2 + 1] += add[node] * (end - mid);
        add[node * 2] += add[node];
        add[node * 2 + 1] += add[node];
        add[node] = 0;
    }
}
void modify(int node,int start,int end,int l,int r,int val)
{
    //当前节点表示区间被所求区间覆盖
    if(start >= l && end <= r)
    {
        //区间求和
        tr[node] += (end - start + 1) * val;
        //为当前节点打下懒标记、就不再下传节点了
        add[node] += val;
        return ;
    }
    spread(node , start , end);
    int mid = start + end >> 1;

    if(l <= mid)
    modify(node << 1 , start , mid , l , r , val);
    if(r >= mid + 1) 
    modify(node << 1 + 1, mid + 1 , end , l , r , val);
    tr[node] = tr[node * 2] + tr[node * 2 +];
}
int query(int node, int start, int end, int l,int r){
    //若目标区间与当时区间没有重叠,结束递归返回0 
    if (start > r || end < l){
        return 0;
    }
    //若目标区间包含当时区间,直接返回节点值 
    else if (l <= start && r >= end){
        return tr[node];
    }
    else {
        spread(node , start , end);
        int mid = (start + end) / 2;
        int left  = 2 * node;
        int right = 2 * node + 1;
        //计算左边区间的值 
        int sum_left = 0 , sum_right = 0;
        if(l <= mid)
        sum_left  = query(left , start, mid, l, r);
        //计算右边区间的值 
        if(r >= mid + 1)
        sum_right = query(right , mid+1, end, l, r);
        //相加即为答案 
        return sum_left + sum_right;
    } 
}
void build(int node , int start , int end)
{
    //递归边界(即遇到叶子节点)
    if(start == end){
        tr[node] = a[start];
    }
    else{
        //区间除二
        int mid = (start + end) / 2;
        //获取左右子树根节点下标
        int left = node * 2;
        int right = node * 2 + 1;
        build(left , start , mid);
        build(right , mid + 1 , end);
        tr[node] = tr[left] + tr[right];
    }
}
void solve(int c , int x)
{
    for(int i = 1 ; i <= c ; ++i) scanf("%d" , &a[i]);
    build(1 , 1 , c);
    while(x--)
    {
        int opr;
        cin>>opr;
        if(opr==1)
        {
            int r , w , q;
            cin >> r >> w >> q;
            modify(1 , 1 , c , r , w , q);
        }
        else
        {
            int r , w;
            cin >> r >> w;
            cout << query(1 , 1 , c , r , w) << endl;
        }
    }
}
int main()
{
    int c , t;
    while(cin >> c >> t)
    {
        solve(c , t);
    }
    return 0;
}

by Nwayy @ 2023-02-26 14:52:15

@Rhss update 部分整个都是有问题的,你这相当于退化成 n^2 了,懒标记打了跟没打一样。正确做法是遇到一个包含的区间就区间加,对当前节点打上懒标记,下次遇到就下传。


by Nwayy @ 2023-02-26 14:53:54

for 套上 update 操作我是没看懂,建议好好理解一下线段树的思想。


by 传奇666666 @ 2023-02-26 15:16:01

@Rhss 首先这个题需要开ll。其次,您的add数组很明显应该是4n。最后,您中间有一个地方写的是 node<<1+1 ,这个是不对的,这样算出来的结果是 node<<2 ,应改为 node<<1|1 。希望对您有帮助。


by 传奇666666 @ 2023-02-26 15:16:48

然后别的地方除了N就开了7之外都是没有问题的


by Rhss @ 2023-02-26 16:19:16

@传奇666666 开7是为了调试样例,按您的解法更改以后依旧是全W,可以麻烦您再看看吗

#include <iostream>
#include <string>
#include <algorithm>
using namespace std;
const int N = 2e5 + 10;
using ll = long long;
ll T , c , t , a[N] , tr[4 * N] , add[4 * N];
void spread(ll node,ll start,ll end)
{
    if(add[node])
    {
        ll mid = (start + end) / 2;
        tr[node * 2] += add[node] * (mid - start + 1);
        add[node * 2] += add[node];
        tr[node * 2 + 1] += add[node] * (end - mid);
        add[node * 2 + 1] += add[node];
    }
}
void modify(ll node,ll start,ll end,ll l,ll r,ll val)
{
    //当前节点表示区间被所求区间覆盖
    if(start >= l && end <= r)
    {
        //区间求和
        tr[node] += (end - start + 1) * val;
        //为当前节点打下懒标记、就不再下传节点了
        add[node] += val;
        return ;
    }
    spread(node , start , end);
    ll mid = (start + end) / 2;
    if(l <= mid)
    modify(node * 2 , start , mid , l , r , val);
    if(r >= mid + 1) 
    modify(node * 2 + 1, mid + 1 , end , l , r , val);
    tr[node] = tr[node * 2] + tr[node * 2 + 1];
}
ll query(ll node, ll start, ll end, ll l,ll r){
    //若目标区间与当时区间没有重叠,结束递归返回0 
    if (start > r || end < l){
        return 0;
    }
    //若目标区间包含当时区间,直接返回节点值 
    else if (l <= start && r >= end){
        return tr[node];
    }
    else {
        spread(node , start , end);
        ll mid = (start + end) / 2;
        //计算左边区间的值 
        ll sum_left = 0 , sum_right = 0;
        if(l <= mid)
        sum_left  = query(node * 2 , start, mid, l, r);
        //计算右边区间的值 
        if(r >= mid + 1)
        sum_right = query(node * 2 + 1 , mid+1, end, l, r);
        //相加即为答案 
        return sum_left + sum_right;
    } 
}
void build(ll node , ll start , ll end)
{
    //递归边界(即遇到叶子节点)
    if(start == end){
        tr[node] = a[start];
    }
    else{
        //区间除二
        ll mid = (start + end) / 2;
        //获取左右子树根节点下标
        ll left = node * 2;
        ll right = node * 2 + 1;
        build(left , start , mid);
        build(right , mid + 1 , end);
        tr[node] = tr[left] + tr[right];
    }
}
void solve(ll c , ll x)
{
    for(ll i = 1 ; i <= c ; ++i) scanf("%lld" , &a[i]);
    build(1 , 1 , c);
    while(x--)
    {
        ll opr;
        cin>>opr;
        if(opr==1)
        {
            ll r , w , q;
            cin >> r >> w >> q;
            modify(1 , 1 , c , r , w , q);
        }
        else
        {
            ll r , w;
            cin >> r >> w;
            cout << query(1 , 1 , c , r , w) << endl;
        }
    }
}
int main()
{
    while(cin >> c >> t)
    {
        solve(c , t);
    }
    return 0;
}

by Rhss @ 2023-02-26 16:32:04

@传奇666666 非常感谢您,我刚刚发现是因为懒标记传递给子树的时候没有清空而导致的,我是一个线段树初学者,非常感谢您的帮助


|