唐刚学树剖10min,求调(马蜂优良)

P3384 【模板】重链剖分/树链剖分

Rainypaster @ 2024-07-16 11:47:52

rt, 线段树用的是标记永久化。结构应该是挺清晰的。

#include <bits/stdc++.h>
using namespace std;
int n, m, r, p;
const int N = 1e5 + 5;
int w[N];
vector<int> g[N];
int dep[N], fa[N], size[N], top[N], son[N], id[N], a[N];
int cnt;

struct Segment_Tree
{
    struct node
    {
        int l, r, sum, lazy;
    }tr[N << 4];

    void push_up(int u)
    {
        tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
        tr[u].sum %= p;
    }
    void build(int u, int l, int r)
    {
        tr[u].l = l, tr[u].r = r;
        if(l == r){
            tr[u].sum = a[l];
            tr[u].sum %= p;
            return ;
        }
        int mid = (l + r) / 2;
        build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
        push_up(u);
    }
    void update(int u, int l, int r, int k)
    {
        if(l <= tr[u].l && tr[u].r <= r){
            tr[u].lazy += k;
            return ;
        }
        tr[u].sum += (min(r, tr[u].r) - max(l, tr[u].l) + 1) * k;
        int mid = (tr[u].l + tr[u].r) >> 1;
        if(l <= mid) update(u << 1, l, r, k);
        if(r >  mid) update(u << 1 | 1, l, r, k);
    }
    int query(int u, int l, int r)
    {
        if(l <= tr[u].l && tr[u].r <= r) {
            return tr[u].sum + (tr[u].r - tr[u].l + 1) * tr[u].lazy;
        }
        int res = (min(r, tr[u].r) - max(l, tr[u].l) + 1) * tr[u].lazy;
        int mid = (tr[u].l + tr[u].r) >> 1;
        if(l <= mid) res += query(u << 1, l, r);
        if(r >  mid) res += query(u << 1 | 1, l, r);
        return res;
    }
}seg;

struct TreePou
{
    int ans = 0;
    int query(int x, int y)
    {
        while(top[x] != top[y]){
            if(dep[top[x]] < dep[top[x]]) swap(x, y);
            ans += seg.query(1, id[top[x]], id[x]);
            ans %= p;
            x = fa[top[x]];
        }
        if(dep[x] > dep[y]) swap(x, y);
        ans += seg.query(1, id[x], id[y]);
        ans %= p;
        return ans;
    }
    void update(int x, int y, int val)
    {
        val %= p;
        while(top[x] != top[y]){
            if(dep[top[x]] < dep[top[x]]) swap(x, y);
            seg.update(1, id[top[x]], id[x], val);
            x = fa[top[x]];
        }
        if(dep[x] > dep[y]) swap(x, y);
        seg.update(1, id[x], id[y], val);
    }
    void tupdate(int x, int val){
        seg.update(1, id[x], id[x] + size[x] - 1, val);
    }
    int tquery(int x){
        return seg.query(1, id[x], id[x] + size[x] - 1);
    }
}tp;

void dfs1(int u, int f, int deep)
{
    dep[u] = deep;
    fa[u] = f;
    size[u] = 1;
    int maxn = -1;
    for(int i = 0;i < g[u].size();i ++ ){
        int v = g[u][i];
        if(v == f) continue;
        dfs1(v, u, deep + 1);
        size[u] += size[v];
        if(size[v] > maxn){
            maxn = size[v];
            son[u] = v;
        }
    }
}
void dfs2(int u, int topfa)
{
    id[u] = ++cnt;
    a[cnt] = w[u];
    top[u] = topfa;
    if(!son[u]) return ;
    dfs2(son[u], topfa);
    for(int i = 0;i < g[u].size();i ++ ){
        int v = g[u][i];
        if(v == fa[u] || v == son[u]) continue;
        dfs2(v, v);
    }
}

int main()
{
    cin >> n >> m >> r >> p;
    for(int i = 1;i <= n;i ++ ) cin >> w[i];
    for(int i = 1;i < n;i ++ ){
        int x, y; cin >> x >> y;
        g[x].push_back(y), g[y].push_back(x);
    }
    dfs1(r, 0, 1);
    dfs2(r, r);
    seg.build(1, 1, n);
    while(m -- )
    {
        int op, x, y, z;
        cin >> op;
        if(op == 1)
        {
            cin >> x >> y >> z;
            tp.update(x, y, z);
        }
        else if(op == 2)
        {
            cin >> x >> y;
            cout << tp.query(x, y) << endl;
        }
        else if(op == 3)
        {
            cin >> x >> y;
            tp.tupdate(x, y);
        }
        else {
            cin >> x;
            cout << tp.tquery(x) << endl;
        }
    }
    return 0;
}

by Genius_Star @ 2024-07-16 11:56:53

@Rainypaster 在 update 处,您并没有对最后操作的 usum 更新。

即:

    void update(int u, int l, int r, int k)
    {
        if(l <= tr[u].l && tr[u].r <= r){
            tr[u].lazy += k;
            return ;
        }
        tr[u].sum += (min(r, tr[u].r) - max(l, tr[u].l) + 1) * k;
        int mid = (tr[u].l + tr[u].r) >> 1;
        if(l <= mid) update(u << 1, l, r, k);
        if(r >  mid) update(u << 1 | 1, l, r, k);
    }

改为:

    void update(int u, int l, int r, int k)
    {
          tr[u].sum += (min(r, tr[u].r) - max(l, tr[u].l) + 1) * k;
        if(l <= tr[u].l && tr[u].r <= r){
            tr[u].lazy += k;
            return ;
        }
        int mid = (tr[u].l + tr[u].r) >> 1;
        if(l <= mid) update(u << 1, l, r, k);
        if(r >  mid) update(u << 1 | 1, l, r, k);
    }

by Rainypaster @ 2024-07-16 12:13:51

@Genius_Star 不是这个问题,我线段树直接复制的https://www.luogu.com.cn/record/164096239,还有你这样改更不对了啊,标记永久化肯定不是这么写的


by Rainypaster @ 2024-07-16 12:14:57

lazy 更新完就不需要再更新 sum 了吧


by Rainypaster @ 2024-07-16 12:15:53

看看查询那


by Genius_Star @ 2024-07-16 14:18:19

@Rainypaster 你查询处 ans 要先赋值为 0,不然会累加上次的答案。


by Genius_Star @ 2024-07-16 14:18:38

树剖查询那


by Rainypaster @ 2024-07-16 15:03:16

@Genius_Star thx,但好像TLE了,本人在外面,回去再看看


by Rainypaster @ 2024-07-16 16:27:51

如上,修改后的,还是 TLE

#include <bits/stdc++.h>
using namespace std;
int n, m, r, p;
const int N = 1e5 + 5;
int w[N];
vector<int> g[N];
int dep[N], fa[N], size[N], top[N], son[N], id[N], a[N];
int cnt;

struct Segment_Tree
{
    struct node
    {
        int l, r, sum, lazy;
    }tr[N << 4];

    void push_up(int u)
    {
        tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
        tr[u].sum %= p;
    }
    void build(int u, int l, int r)
    {
        tr[u].l = l, tr[u].r = r;
        if(l == r){
            tr[u].sum = a[l];
            tr[u].sum %= p;
            return ;
        }
        int mid = (l + r) / 2;
        build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
        push_up(u);
    }
    void update(int u, int l, int r, int k)
    {
        if(l <= tr[u].l && tr[u].r <= r){
            tr[u].lazy += k;
            return ;
        }
        tr[u].sum += (min(r, tr[u].r) - max(l, tr[u].l) + 1) * k;
        int mid = (tr[u].l + tr[u].r) >> 1;
        if(l <= mid) update(u << 1, l, r, k);
        if(r >  mid) update(u << 1 | 1, l, r, k);
        push_up(u);
    }
    int query(int u, int l, int r)
    {
        if(l <= tr[u].l && tr[u].r <= r) {
            return tr[u].sum + (tr[u].r - tr[u].l + 1) * tr[u].lazy;
        }
        int res = (min(r, tr[u].r) - max(l, tr[u].l) + 1) * tr[u].lazy;
        int mid = (tr[u].l + tr[u].r) >> 1;
        if(l <= mid) res += query(u << 1, l, r);
        if(r >  mid) res += query(u << 1 | 1, l, r);
        return res;
    }
}seg;

struct TreePou
{
    int query(int x, int y)
    {
        int ans = 0;
        while(top[x] != top[y]){
            if(dep[top[x]] < dep[top[x]]) swap(x, y);
            ans += seg.query(1, id[top[x]], id[x]);
            ans %= p;
            x = fa[top[x]];
        }
        if(dep[x] > dep[y]) swap(x, y);
        ans += seg.query(1, id[x], id[y]);
        ans %= p;
        return ans;
    }
    void update(int x, int y, int val)
    {
        val %= p;
        while(top[x] != top[y]){
            if(dep[top[x]] < dep[top[x]]) swap(x, y);
            seg.update(1, id[top[x]], id[x], val);
            x = fa[top[x]];
        }
        if(dep[x] > dep[y]) swap(x, y);
        seg.update(1, id[x], id[y], val);
    }
    void tupdate(int x, int val){
        seg.update(1, id[x], id[x] + size[x] - 1, val);
    }
    int tquery(int x){
        return seg.query(1, id[x], id[x] + size[x] - 1);
    }
}tp;

void dfs1(int u, int f, int deep)
{
    dep[u] = deep;
    fa[u] = f;
    size[u] = 1;
    int maxn = -1;
    for(int i = 0;i < g[u].size();i ++ ){
        int v = g[u][i];
        if(v == f) continue;
        dfs1(v, u, deep + 1);
        size[u] += size[v];
        if(size[v] > maxn){
            maxn = size[v];
            son[u] = v;
        }
    }
}
void dfs2(int u, int topfa)
{
    id[u] = ++cnt;
    a[cnt] = w[u];
    top[u] = topfa;
    if(!son[u]) return ;
    dfs2(son[u], topfa);
    for(int i = 0;i < g[u].size();i ++ ){
        int v = g[u][i];
        if(v == fa[u] || v == son[u]) continue;
        dfs2(v, v);
    }
}

int main()
{
    cin >> n >> m >> r >> p;
    for(int i = 1;i <= n;i ++ ) cin >> w[i];
    for(int i = 1;i < n;i ++ ){
        int x, y; cin >> x >> y;
        g[x].push_back(y), g[y].push_back(x);
    }
    dfs1(r, 0, 1);
    dfs2(r, r);
    seg.build(1, 1, n);
    while(m -- )
    {
        int op, x, y, z;
        cin >> op;
        if(op == 1)
        {
            cin >> x >> y >> z;
            tp.update(x, y, z);
        }
        else if(op == 2)
        {
            cin >> x >> y;
            cout << tp.query(x, y) << endl;
        }
        else if(op == 3)
        {
            cin >> x >> y;
            tp.tupdate(x, y);
        }
        else {
            cin >> x;
            cout << tp.tquery(x) << endl;
        }
    }
    return 0;
}

by Rainypaster @ 2024-07-16 16:45:03

改了一下,19pts,WA

#include <bits/stdc++.h>
using namespace std;
int n, m, r, p;
const int N = 1e5 + 5;
int w[N];
vector<int> g[N];
int dep[N], fa[N], size[N], top[N], son[N], id[N], a[N];
int cnt;

struct Segment_Tree
{
    struct node
    {
        int l, r, sum, lazy;
    }tr[N << 2];

    void push_up(int u)
    {
        tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
        tr[u].sum %= p;
    }
    void build(int u, int l, int r)
    {
        tr[u].l = l, tr[u].r = r;
        if(l == r){
            tr[u].sum = a[l];
            tr[u].sum %= p;
            return ;
        }
        int mid = (l + r) / 2;
        build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
        push_up(u);
    }
    void update(int u, int l, int r, int k)
    {
        if(l <= tr[u].l && tr[u].r <= r){
            tr[u].lazy += k;
            return ;
        }
        tr[u].sum += (min(r, tr[u].r) - max(l, tr[u].l) + 1) * k;
        int mid = (tr[u].l + tr[u].r) >> 1;
        if(l <= mid) update(u << 1, l, r, k);
        if(r >  mid) update(u << 1 | 1, l, r, k);
        push_up(u);
    }
    int query(int u, int l, int r)
    {
        if(l <= tr[u].l && tr[u].r <= r) {
            return tr[u].sum + (tr[u].r - tr[u].l + 1) * tr[u].lazy;
        }
        int res = (min(r, tr[u].r) - max(l, tr[u].l) + 1) * tr[u].lazy;
        int mid = (tr[u].l + tr[u].r) >> 1;
        if(l <= mid) res += query(u << 1, l, r);
        if(r >  mid) res += query(u << 1 | 1, l, r);
        return res;
    }
}seg;

struct TreePou
{
    int query(int x, int y)
    {
        int ans = 0;
        while(top[x] != top[y]){
            if(dep[top[x]] < dep[top[y]]) swap(x, y);
            ans += seg.query(1, id[top[x]], id[x]);
            ans %= p;
            x = fa[top[x]];
        }
        if(dep[x] > dep[y]) swap(x, y);
        ans += seg.query(1, id[x], id[y]);
        ans %= p;
        return ans;
    }
    void update(int x, int y, int val)
    {
        val %= p;
        while(top[x] != top[y]){
            if(dep[top[x]] < dep[top[y]]) swap(x, y);
            seg.update(1, id[top[x]], id[x], val);
            x = fa[top[x]];
        }
        if(dep[x] > dep[y]) swap(x, y);
        seg.update(1, id[x], id[y], val);
    }
    void tupdate(int x, int val){
        seg.update(1, id[x], id[x] + size[x] - 1, val);
    }
    int tquery(int x){
        return seg.query(1, id[x], id[x] + size[x] - 1);
    }
}tp;

void dfs1(int u, int f, int deep)
{
    dep[u] = deep;
    fa[u] = f;
    size[u] = 1;
    int maxn = -1;
    for(int i = 0;i < g[u].size();i ++ ){
        int v = g[u][i];
        if(v == f) continue;
        dfs1(v, u, deep + 1);
        size[u] += size[v];
        if(size[v] > maxn){
            maxn = size[v];
            son[u] = v;
        }
    }
}
void dfs2(int u, int topfa)
{
    id[u] = ++cnt;
    a[cnt] = w[u];
    top[u] = topfa;
    if(!son[u]) return ;
    dfs2(son[u], topfa);
    for(int i = 0;i < g[u].size();i ++ ){
        int v = g[u][i];
        if(v == fa[u] || v == son[u]) continue;
        dfs2(v, v);
    }
}

int main()
{
    cin >> n >> m >> r >> p;
    for(int i = 1;i <= n;i ++ ) cin >> w[i];
    for(int i = 1;i < n;i ++ ){
        int x, y; cin >> x >> y;
        g[x].push_back(y), g[y].push_back(x);
    }
    dfs1(r, 0, 1);
    dfs2(r, r);
    seg.build(1, 1, n);
    while(m -- )
    {
        int op, x, y, z;
        cin >> op;
        if(op == 1)
        {
            cin >> x >> y >> z;
            tp.update(x, y, z);
        }
        else if(op == 2)
        {
            cin >> x >> y;
            cout << tp.query(x, y) << endl;
        }
        else if(op == 3)
        {
            cin >> x >> y;
            tp.tupdate(x, y);
        }
        else {
            cin >> x;
            cout << tp.tquery(x) << endl;
        }
    }
    return 0;
}

by Rainypaster @ 2024-07-16 16:59:58

过了,结。


| 下一页