Starstream @ 2023-09-28 16:58:53
rt,区间修改的时候总是少一个左端点没改。下面是代码。
#include <cstdio>
#include <iostream>
using namespace std;
// Capacity of the static node pool `tr` and of the input array `w`.
// The values -(N - 1) and N - 1 are also used as the two sentinel keys.
constexpr int N = 100010;
// Large "infinity" constant; not referenced by the code visible in this file.
constexpr int INF = 1000000000;
// One node of the splay tree. Nodes live in the static pool `tr` and are
// referred to by index; index 0 plays the role of "no node".
//
// The declaration order of the fields matters: insert() brace-initializes a
// node as {size, cnt, v, p}, so do not reorder them.
struct Splay_Node
{
    int size;   // number of elements in this subtree (children's sizes + cnt)
    int cnt;    // multiplicity of this node's key
    int v;      // key the tree is ordered by
    int p;      // index of the parent node
    int s[2];   // s[0] = left child index, s[1] = right child index
    int val;    // value stored at this node
    int sum;    // cnt * val plus the sums of both children
    int add;    // pending lazy increment, handed to the children by pushdown()

    // Resets key and parent and marks the node as a single-element subtree.
    void init(int _v, int _p)
    {
        v = _v;
        p = _p;
        size = 1;
    }
} tr[N];
int n;      // number of initial elements, read in main()
int m;      // number of operations, read in main()
int root;   // index of the current tree root (0 while the tree is empty)
int idx;    // index of the most recently allocated node in `tr`
int w[N];   // initial element values, filled for indices 1..n
void pushup(int x)
{
tr[x].size = tr[tr[x].s[0]].size + tr[tr[x].s[1]].size + tr[x].cnt;
tr[x].sum = tr[tr[x].s[0]].sum + tr[tr[x].s[1]].sum + tr[x].cnt * tr[x].val;
}
void pushdown(int x)
{
if (tr[x].add)
{
Splay_Node &L = tr[tr[x].s[0]], &R = tr[tr[x].s[1]];
if (L.v != -N + 1) L.add += tr[x].add, L.sum += tr[x].add * L.size, L.val += tr[x].add;
if (R.v != N - 1) R.add += tr[x].add, R.sum += tr[x].add * R.size, R.val += tr[x].add;
tr[x].add = 0;
}
}
// Rotates node x above its parent y, keeping the in-order sequence intact.
//
// Pending tags are pushed parent-first (y, then x) before any link changes:
// a tag left on y would otherwise no longer reach x and x's outer subtree
// once x sits above y. Links are never written into the null node (index 0).
void rotate(int x)
{
    int y = tr[x].p, z = tr[y].p;
    pushdown(y), pushdown(x);

    int k = tr[y].s[1] == x;  // 1 if x is the right child of y, else 0

    // x takes y's place under z (if y had a parent).
    if (z) tr[z].s[tr[z].s[1] == y] = x;
    tr[x].p = z;

    // x's inner subtree moves over to y.
    int inner = tr[x].s[k ^ 1];
    tr[y].s[k] = inner;
    if (inner) tr[inner].p = y;

    // y becomes x's child on the opposite side.
    tr[x].s[k ^ 1] = y, tr[y].p = x;

    // y is now below x, so refresh it first.
    pushup(y), pushup(x);
}
// Rotates x upward until its parent is k. With k == 0 x becomes the root of
// the whole tree and `root` is updated accordingly.
void splay(int x, int k)
{
    pushdown(x);
    for (;;)
    {
        const int parent = tr[x].p;
        if (parent == k) break;

        pushdown(x);
        const int grand = tr[parent].p;
        if (grand != k)
        {
            const bool x_on_right = (tr[parent].s[1] == x);
            const bool parent_on_right = (tr[grand].s[1] == parent);
            if (x_on_right != parent_on_right)
            {
                rotate(x);       // zig-zag: x and its parent hang on different sides
            }
            else
            {
                rotate(parent);  // zig-zig: same side, lift the parent first
            }
        }
        rotate(x);
    }
    if (k == 0) root = x;
}
// Finds the node holding the k-th smallest element (1-based, counting key
// multiplicities and every node in the tree), splays it to the root and
// returns its index. Returns -1 when k is outside [1, size of the tree].
//
// Tags are pushed down on the way so that the path handed to splay() carries
// no pending increments. No aggregate is recomputed during the descent:
// splay() refreshes every node on the path bottom-up afterwards.
int kth(int k)
{
    // Reject out-of-range ranks up front; k <= 0 would otherwise never
    // terminate the descent.
    if (k < 1 || k > tr[root].size) return -1;

    int u = root;
    while (u)
    {
        pushdown(u);
        const int left = tr[tr[u].s[0]].size;
        if (k <= left)
        {
            u = tr[u].s[0];
        }
        else if (k <= left + tr[u].cnt)
        {
            splay(u, 0);
            return u;
        }
        else
        {
            k -= left + tr[u].cnt;
            u = tr[u].s[1];
        }
    }
    return -1;
}
void insert(int v, int val)
{
int u = root, p = 0;
pushdown(u);
while (u && tr[u].v != v)
pushdown(u), p = u, u = tr[u].s[v > tr[u].v];
if (u) tr[u].cnt ++ ;
else
{
u = ++ idx;
if (p) tr[p].s[v > tr[p].v] = u;
tr[u] = {1, 1, v, p};
tr[u].val = val, tr[u].sum = val;
}
splay(u, 0);
}
void output(int u)
{
pushdown(u);
if (tr[u].s[0]) output(tr[u].s[0]);
printf("tr[%d]{size: %d, cnt: %d, id: %d, val: %d, sum: %d, add: %d}\n",\
u, tr[u].size, tr[u].cnt, tr[u].v, tr[u].val, tr[u].sum, tr[u].add);
if (tr[u].s[1]) output(tr[u].s[1]);
}
int main()
{
int op, l, r, x;
insert(-N + 1, 0), insert(N - 1, 0);
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i ++ )
scanf("%d", &w[i]), insert(i, w[i]);
puts("\n*********************************\n");
output(root);
puts("\n*********************************\n");
while (m -- )
{
scanf("%d%d%d", &op, &l, &r);
l = kth(l), r = kth(r + 2);
splay(l, 0), splay(r, l);
Splay_Node &L = tr[tr[r].s[0]];
if (op == 1)
{
scanf("%d", &x);
L.add += x, L.sum += L.size * x, L.v += x;
puts("\n*********************************\n");
output(root);
puts("\n*********************************\n");
}
else printf("%d\n", L.sum);
}
return 0;
}