Splay 区间操作WA求调

P3372 【模板】线段树 1

Starstream @ 2023-09-28 16:58:53

如题：区间修改时，总有一个左端点没有被修改。代码如下：

#include <cstdio>
#include <iostream>

using namespace std;

// Node-pool capacity; main also derives the two sentinel keys from it (-(N - 1) and N - 1).
constexpr int N = 100010;
// Large sentinel value; not referenced in the visible code.
constexpr int INF = 1e9;

// One node of the splay tree. Nodes live in the static pool tr[1..idx]; index 0
// serves as the null child, whose size and sum are expected to stay 0.
// Member order matters: insert() builds nodes with the aggregate {size, cnt, v, p}.
struct Splay_Node
{
    int size, cnt, v;   // subtree size (counting multiplicity), multiplicity, BST key
    int p, s[2];        // parent and left/right children; 0 means "none"
    // 64-bit: element values and interval sums in P3372 overflow int.
    long long val;      // value stored at this node
    long long sum, add; // subtree sum (already including this node's pending add),
                        // and the add still owed to the children

    // Sets key, parent and size only; cnt and the value fields are left untouched.
    void init(int _v, int _p)
    {
        v = _v, p = _p;
        size = 1;
    }
}tr[N];

int n;      // number of elements
int m;      // number of operations
int root;   // index of the current splay-tree root (0 = empty tree)
int idx;    // last node index handed out from the pool tr[]
int w[N];   // per-position value buffer, 1-indexed

void pushup(int x)
{
    tr[x].size = tr[tr[x].s[0]].size + tr[tr[x].s[1]].size + tr[x].cnt;
    tr[x].sum = tr[tr[x].s[0]].sum + tr[tr[x].s[1]].sum + tr[x].cnt * tr[x].val;
}

void pushdown(int x)
{
    if (tr[x].add)
    {
        Splay_Node &L = tr[tr[x].s[0]], &R = tr[tr[x].s[1]];
        if (L.v != -N + 1) L.add += tr[x].add, L.sum += tr[x].add * L.size, L.val += tr[x].add;
        if (R.v != N - 1) R.add += tr[x].add, R.sum += tr[x].add * R.size, R.val += tr[x].add;
        tr[x].add = 0;
    }
}

// Rotates x above its parent y, preserving in-order sequence; y's parent z adopts x.
// Only x is pushed down here; y's pending tag is not applied, so callers are
// expected to have pushed y down already.
void rotate(int x)
{
    pushdown(x);
    int y = tr[x].p, z = tr[y].p;
    int k = tr[y].s[1] == x;                                 // 1 if x is y's right child
    tr[z].s[tr[z].s[1] == y] = x, tr[x].p = z;               // z points to x (writes tr[0].s when y was the root)
    tr[y].s[k] = tr[x].s[k ^ 1], tr[tr[x].s[k ^ 1]].p = y;   // x's inner subtree moves under y
    tr[x].s[k ^ 1] = y, tr[y].p = x;                         // y becomes x's child
    pushup(y), pushup(x);                                    // y first: it is now below x
}

void splay(int x, int k)
{
    pushdown(x);
    while (tr[x].p != k)
    {
        pushdown(x);
        int y = tr[x].p, z = tr[y].p;
        if (z != k)
            if ((tr[y].s[1] == x) ^ (tr[z].s[1] == y)) rotate(x);
            else rotate(y);
        rotate(x);
    }
    if (!k) root = x;
}

// Finds the node holding the k-th smallest key (1-based, counting multiplicity),
// splays it to the root and returns its index; returns -1 if k is out of range.
// The old pushup on the child before its pushdown was removed: it rebuilt the
// child's sum from children that had not yet received its pending add.
int kth(int k)
{
    if (k <= 0) return -1;
    int u = root;
    while (u && tr[u].size >= k)
    {
        pushdown(u);   // children must carry u's pending add before descending
        int left = tr[tr[u].s[0]].size;
        if (left >= k) u = tr[u].s[0];
        else if (left + tr[u].cnt >= k) return splay(u, 0), u;
        else k -= left + tr[u].cnt, u = tr[u].s[1];
    }
    return -1;
}

void insert(int v, int val)
{
    int u = root, p = 0;
    pushdown(u);
    while (u && tr[u].v != v)
        pushdown(u), p = u, u = tr[u].s[v > tr[u].v];
    if (u) tr[u].cnt ++ ;
    else
    {
        u = ++ idx;
        if (p) tr[p].s[v > tr[p].v] = u;
        tr[u] = {1, 1, v, p};
        tr[u].val = val, tr[u].sum = val;
    }
    splay(u, 0);
}

// Debug helper: prints subtree u in order, one line per node, pushing tags down
// on the way so val/sum are current. Value fields are cast to long long so the
// %lld conversions match whatever width they are declared with.
void output(int u)
{
    if (!u) return;
    pushdown(u);
    output(tr[u].s[0]);
    printf("tr[%d]{size: %d, cnt: %d, id: %d, val: %lld, sum: %lld, add: %lld}\n",
        u, tr[u].size, tr[u].cnt, tr[u].v,
        static_cast<long long>(tr[u].val),
        static_cast<long long>(tr[u].sum),
        static_cast<long long>(tr[u].add));
    output(tr[u].s[1]);
}

int main()
{
    int op, l, r, x;
    insert(-N + 1, 0), insert(N - 1, 0);

    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i ++ )
        scanf("%d", &w[i]), insert(i, w[i]);

    puts("\n*********************************\n");
    output(root);
    puts("\n*********************************\n");

    while (m -- )
    {
        scanf("%d%d%d", &op, &l, &r);
        l = kth(l), r = kth(r + 2);
        splay(l, 0), splay(r, l);
        Splay_Node &L = tr[tr[r].s[0]];
        if (op == 1)
        {
            scanf("%d", &x);
            L.add += x, L.sum += L.size * x, L.v += x;
            puts("\n*********************************\n");
            output(root);
            puts("\n*********************************\n");
        }
        else printf("%d\n", L.sum);
    }

    return 0;
}

|