蒟蒻树剖 10pts 求调 T^T

P3384 【模板】重链剖分/树链剖分

miss_A @ 2024-01-18 22:57:28

1~#4 RE,#5~#10 TLE,#11 AC T^T

看了一圈就本蒟蒻错的最离谱......

#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>

using namespace std;
const int N = 2e5+100;
int n, m, r, MOD;
int d[N << 2], add[N << 2];
int son[N], new_id[N], fa[N], cnt, depth[N], siz[N], top[N];

struct node
{
    int nex, to;
}e[N << 1];
int tot, head[N], w[N], new_w[N];

namespace SP
{
    void add(int u, int v){
        e[++tot].to = v;
        e[tot].nex = head[u];
        head[u] = tot;
    }

    void dfs1(int now, int fath){
        depth[now] = depth[fath] + 1;
        fa[now] = fath;
        siz[now] = 1, son[now] = 0;
        for(int i = head[now]; i; i = e[i].nex){
            if(e[i].to == fath)continue;
            dfs1(e[i].to, now);
            siz[now] += siz[e[i].to];
            if(siz[son[now]] < siz[e[i].to]){
                son[now] = e[i].to;
            }
        }
    }
    void dfs2(int now, int topx){
        top[now] = topx;
        new_id[now] = ++cnt;
        new_w[cnt] = w[now];
        if(son[now])dfs2(son[now], topx);
        else return ;
        for(int i = head[now]; i; i = e[i].nex){
            if(e[i].to != topx && e[i].to != son[now])
                dfs2(e[i].to, e[i].to);//建立新的重链
        }
    }
}

namespace segt
{
    void build(int p, int l, int r){
        if(l == r){
            d[p] = new_w[l] % MOD;
            return ;
        }
        int mid = l + ((r - l) >> 1);
        build(p << 1, l , mid);
        build(p << 1|1, mid + 1, r);
        d[p] = (d[p << 1] + d[p << 1|1]) % MOD;
    }

    void pushdown(int p, int l, int r){
        if(l == r)return ;
        int mid = l + ((r - l) >> 1);
        add[p << 1] = (add[p << 1] + add[p]) % MOD;
        add[p << 1|1] = (add[p << 1|1] + add[p]) % MOD;
        d[p << 1] = (d[p << 1] + add[p] * (mid - l + 1)) % MOD;
        d[p << 1|1] = (d[p << 1|1] + add[p] * (r - mid)) % MOD;
        add[p] = 0;
    }
    void update(int p, int l, int r, int x, int y, int z){
        if(x <= l && y >= r){//当前区间在目标区间内部
            add[p] += z;
            d[p] += z * (r - l + 1) % MOD;
            d[p] %= MOD;
            return ;
        }
        int mid = l + ((r - l) << 1);
        if(add[p])pushdown(p, l, r);//标记下传
        if(x <= mid)update(p << 1, l, mid, x, y, z);
        if(y > mid)update(p << 1|1, mid + 1, r, x, y, z);
    }
    void updRange(int x, int y, int z){
        z %= MOD;
        while(top[x] != top[y]){
            if(depth[top[x]] < depth[top[y]])swap(x, y);
            update(1, 1, n, new_id[top[x]], new_id[x], z);//更新已经过的重链
            x = fa[top[x]];
        }
        //当下 x 与 y 已经在同一条重链上
        if(depth[x] > depth[y])swap(x, y);
        update(1, 1, n, new_id[x], new_id[y], z);//更新当前重链上 x, y 之间的点
    }

    int getsum(int p, int l, int r, int x, int y){
        if(l > y || r < x || l > r)return 0;
        if(l >= x && r <= y)return d[p];
        if(add[p])pushdown(p, l, r);
        int mid = l + ((r - l) >> 1);
        int sum = 0;
        if(x <= mid)sum = (sum + getsum(p << 1, l, mid, x, y)) % MOD;
        if(y > mid)sum = (sum + getsum(p << 1|1, mid + 1, r, x, y)) % MOD;
        return sum % MOD;
    }
    int getRange(int x, int y){
        int sum = 0;
        while(top[x] != top[y]){
            if(depth[top[x]] < depth[top[y]])swap(x, y);
            sum = (sum + getsum(1, 1, n, new_id[top[x]], new_id[x])) % MOD;
            x = fa[top[x]];
        }
        if(depth[x] > depth[y])swap(x, y);
        sum = (sum + getsum(1, 1, n, new_id[x], new_id[y])) % MOD;
        return sum % MOD;
    }
}

int main(){

    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    cin >> n >> m >> r >> MOD;//结点个数、操作个数、根节点序号和取模数
    for(int i = 1; i <= n; i++){
        cin >> w[i];//结点的初始权值
    }
    for(int i = 1; i < n; i++){
        int u, v;
        cin >> u >> v;
        SP::add(u, v), SP::add(v, u);
    }

    SP::dfs1(r, 0);
    SP::dfs2(r, r);

    segt::build(1, 1, n);
    for(int i = 1; i <= m; i++){
        int k, x, y, z;
        cin >> k;
        switch (k){
        case 1:
        //将树从 x 到 y 结点最短路径上所有节点的值都加上 z
            cin >> x >> y >> z;
            segt::updRange(new_id[x], new_id[y], z);
            break;
        case 2:
        //求树从 x 到 y 结点最短路径上所有节点的值之和
            cin >> x >> y;
            cout << segt::getRange(new_id[x], new_id[y]) << endl;
            break;
        case 3:
        //将以 x 为根节点的子树内所有节点值都加上 z
            cin >> x >> z;
            segt::update(1, 1, n, new_id[x], new_id[x] + siz[x] - 1, z);
            break;
        case 4:
        //表示求以 x 为根节点的子树内所有节点值之和
            cin >> x;
            cout << segt::getsum(1, 1, n, new_id[x], new_id[x] + siz[x] - 1) << endl;
            break;
        }
    }

    return 0;
}

by small_john @ 2024-01-18 23:14:25

@miss_A 问题如下:

  1. 在求链顶时 if(e[i].to != topx && e[i].to != son[now]) 有错,应该是 if(e[i].to != f[now] && e[i].to != son[now]),这导致了 TLE;
  2. 线段树要多取模,比如 pushdown 函数中有乘法,要在前加上 1ll 并在后面取模;
  3. 线段树 updatemid 的值计算有误;
  4. 在对路径更改、查询路径时,传参数不用传 dfs 序,直接传节点编号即可。

by small_john @ 2024-01-18 23:14:53

@miss_A AC 代码:

#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>

using namespace std;
const int N = 2e5+100;
int n, m, r, MOD;
int d[N << 2], add[N << 2];
int son[N], new_id[N], fa[N], cnt, depth[N], siz[N], top[N];

struct node
{
    int nex, to;
}e[N << 1];
int tot, head[N], w[N], new_w[N];

namespace SP
{
    void add(int u, int v){
        e[++tot].to = v;
        e[tot].nex = head[u];
        head[u] = tot;
    }

    void dfs1(int now, int fath){
        depth[now] = depth[fath] + 1;
        fa[now] = fath;
        siz[now] = 1, son[now] = 0;
        for(int i = head[now]; i; i = e[i].nex){
            if(e[i].to == fath)continue;
            dfs1(e[i].to, now);
            siz[now] += siz[e[i].to];
            if(siz[son[now]] < siz[e[i].to]){
                son[now] = e[i].to;
            }
        }
    }
    void dfs2(int now, int topx){
        top[now] = topx;
        new_id[now] = ++cnt;
        new_w[cnt] = w[now];
        if(son[now])dfs2(son[now], topx);
        else return ;
        for(int i = head[now]; i; i = e[i].nex){
            if(e[i].to != fa[now] && e[i].to != son[now])
                dfs2(e[i].to, e[i].to);//建立新的重链
        }
    }
}

namespace segt
{
    void build(int p, int l, int r){
        if(l == r){
            d[p] = new_w[l] % MOD;
            return ;
        }
        int mid = l + ((r - l) >> 1);
        build(p << 1, l , mid);
        build(p << 1|1, mid + 1, r);
        d[p] = (d[p << 1] + d[p << 1|1]) % MOD;
    }

    void pushdown(int p, int l, int r){
        if(l == r)return ;
        int mid = l + ((r - l) >> 1);
        add[p << 1] = (add[p << 1] + add[p]) % MOD;
        add[p << 1|1] = (add[p << 1|1] + add[p]) % MOD;
        d[p << 1] = (d[p << 1] + 1ll * add[p] * (mid - l + 1) % MOD) % MOD;
        d[p << 1|1] = (d[p << 1|1] + 1ll * add[p] * (r - mid) % MOD) % MOD;
        add[p] = 0;
    }
    void update(int p, int l, int r, int x, int y, int z){
        if(x <= l && y >= r){//当前区间在目标区间内部
            add[p] += z;
            add[p] %= MOD;
            d[p] += z * (r - l + 1) % MOD;
            d[p] %= MOD;
            return ;
        }
        int mid = l + ((r - l) >> 1);
        if(add[p])pushdown(p, l, r);//标记下传
        if(x <= mid)update(p << 1, l, mid, x, y, z);
        if(y > mid) update(p << 1|1, mid + 1, r, x, y, z);
        d[p] = (d[p << 1] + d[p << 1|1]) % MOD;
    }
    void updRange(int x, int y, int z){
        z %= MOD;
        while(top[x] != top[y]){
            if(depth[top[x]] < depth[top[y]])swap(x, y);
            update(1, 1, n, new_id[top[x]], new_id[x], z);//更新已经过的重链
            x = fa[top[x]];
        }
        //当下 x 与 y 已经在同一条重链上
        if(depth[x] > depth[y])swap(x, y);
        update(1, 1, n, new_id[x], new_id[y], z);//更新当前重链上 x, y 之间的点
    }

    int getsum(int p, int l, int r, int x, int y){
        if(l >= x && r <= y)return d[p];
        if(add[p])pushdown(p, l, r);
        int mid = l + ((r - l) >> 1);
        int sum = 0;
        if(x <= mid)sum = (sum + getsum(p << 1, l, mid, x, y)) % MOD;
        if(y > mid)sum = (sum + getsum(p << 1|1, mid + 1, r, x, y)) % MOD;
        return sum % MOD;
    }
    int getRange(int x, int y){
        int sum = 0;
        while(top[x] != top[y]){
            if(depth[top[x]] < depth[top[y]])swap(x, y);
            sum = (sum + getsum(1, 1, n, new_id[top[x]], new_id[x])) % MOD;
            x = fa[top[x]];
        }
        if(depth[x] > depth[y])swap(x, y);
        sum = (sum + getsum(1, 1, n, new_id[x], new_id[y])) % MOD;
        return sum % MOD;
    }
}

int main(){

    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    cin >> n >> m >> r >> MOD;//结点个数、操作个数、根节点序号和取模数
    for(int i = 1; i <= n; i++){
        cin >> w[i];//结点的初始权值
    }
    for(int i = 1; i < n; i++){
        int u, v;
        cin >> u >> v;
        SP::add(u, v), SP::add(v, u);
    }

    SP::dfs1(r, 0);
    SP::dfs2(r, r);

    segt::build(1, 1, n);
    for(int i = 1; i <= m; i++){
        int k, x, y, z;
        cin >> k;
        switch (k){
        case 1:
        //将树从 x 到 y 结点最短路径上所有节点的值都加上 z
            cin >> x >> y >> z;
            segt::updRange(x, y, z);
            break;
        case 2:
        //求树从 x 到 y 结点最短路径上所有节点的值之和
            cin >> x >> y;
            cout << segt::getRange(x, y) << endl;
            break;
        case 3:
        //将以 x 为根节点的子树内所有节点值都加上 z
            cin >> x >> z;
            segt::update(1, 1, n, new_id[x], new_id[x] + siz[x] - 1, z);
            break;
        case 4:
        //表示求以 x 为根节点的子树内所有节点值之和
            cin >> x;
            cout << segt::getsum(1, 1, n, new_id[x], new_id[x] + siz[x] - 1) << endl;
            break;
        }
    }

    return 0;
}

by small_john @ 2024-01-18 23:16:07

@miss_A 还漏了一个错,就是线段树更改时没 pushup


by miss_A @ 2024-01-18 23:41:23

@pyy1 orz

理解了,求链顶时不走回头路,两数相乘结果可能溢出,取 mid 时用 >>,原代码查询和更改路径时误将参数的 dfs 序当作参数本身。

跪谢大佬


|