题解:CF1575E Eye-Pleasing City Park Tour

EityDawn

2024-11-20 15:31:21

Solution

思路:

统计两点间路径上颜色段数目 \le k+1 的所有点对间路径上的点权之和。

点分治。对于当前的分治中心 xx 到其子树内的贡献很好求。现在需要考虑怎么合并新子树的答案。

dis_ii 到分治中心的颜色段数目减 1val_ii 到分治中心的点权和,col_i 为分治中心到 i 的第一条边的颜色。

对于新子树的节点 y,若 dis_y\le k,我们还需求 \sum_{dis_{z}\le k-dis_{y}} val_z+val_y-a_x,其中 z 还满足 col_z=col_y,而对于 col_z\neq col_yz,则需要求 \sum_{dis_{z}\le k-dis_{y}-1} val_z+val_y-a_x。显然可以开两个树状数组维护。

复杂度为 O(n\log^2n)

#include<bits/stdc++.h>
#define all(x) x.begin(),x.end()
#define mset(x,y) memset((x),(y),sizeof((x)))
#define mcpy(x,y) memcpy((x),(y),sizeof((y)))
#define FileIn(x) freopen(""#x".in","r",stdin)
#define FileOut(x) freopen(""#x".out","w",stdout)
#define debug(x) cerr<<""#x" = "<<(x)<<'\n'
#define Assert(x) if(!(x)) cerr<<"Failed: "#x" at line "<<__LINE__,exit(1)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef __int128 Int;
const int N=2e5+10,mod=1e9+7;
bool StM;
void Mod(int &x,int y)
{
    x+=y;
    if(x>=mod) x-=mod;
}
int n,k,m,a[N],rt=1;
struct edge{int to,val;};
vector<edge>G[N];
struct BIT{
    int c[N],d[N];
    #define lowbit(x) (x&-x)
    BIT(){mset(c,0);}
    void add(int x,int y)
    {
        ++x;
        while(x<=n)
            Mod(c[x],y),++d[x],x+=lowbit(x);
        return;
    } 
    void del(int x,int y)
    {
        ++x;
        while(x<=n)
            Mod(c[x],mod-y),--d[x],x+=lowbit(x);
    }
    int query(int x)
    {
        ++x;
        int sum=0;
        while(x)
            Mod(sum,c[x]),x-=lowbit(x);
        return sum;
    }
    int C(int x)
    {
        ++x;
        int sum=0;
        while(x)
            sum+=d[x],x-=lowbit(x);
        return sum;
    }
}t[2];
int siz[N],ma[N];
bool vis[N];
int Dis[N],Val[N],id[N];
int Son[N],top=0;
int All=0;
void GetMid(int now,int from)
{
    siz[now]=1,ma[now]=0;
    for(auto [to,val]:G[now])
    {
        if(to==from||vis[to]) continue;
        GetMid(to,now);
        siz[now]+=siz[to];
        ma[now]=max(ma[now],siz[to]);
    }
    ma[now]=max(ma[now],m-siz[now]);
    if(ma[now]<=m/2) rt=now;
}
void GetSon(int now,int from,int Sum,int dis,int col)
{
    Dis[now]=dis,Val[now]=(Sum+a[now])%mod;
    Son[++top]=now;
    for(auto [to,val]:G[now])
    {
        if(vis[to]||to==from) continue;
        if(val==col) GetSon(to,now,(Sum+a[now])%mod,dis,col);
        else GetSon(to,now,(Sum+a[now])%mod,dis+1,val);
    }
    return;
}
void Get(int now)
{
    Mod(All,a[now]);
    int l=0,r=0;
    for(auto [to,val]:G[now])
    {
        if(vis[to]) continue;
        int cur=top;
        GetSon(to,now,a[now],0,val);
        for(int i=cur+1;i<=top;i++)
        {
            if(Dis[Son[i]]<=k){
                Mod(All,Val[Son[i]]);
                l=1ll*t[val].C(k-Dis[Son[i]])*(Val[Son[i]]-a[now]+mod)%mod;
                Mod(l,t[val].query(k-Dis[Son[i]]));Mod(All,l);
                r=1ll*t[val^1].C(k-Dis[Son[i]]-1)*(Val[Son[i]]-a[now]+mod)%mod;
                Mod(r,t[val^1].query(k-Dis[Son[i]]-1));Mod(All,r);
            }
            id[Son[i]]=val;
        }
        for(int i=cur+1;i<=top;i++)
            t[val].add(Dis[Son[i]],Val[Son[i]]);
    }
    while(top)
        t[id[Son[top]]].del(Dis[Son[top]],Val[Son[top]]),top--;
    return;
}
void Calc(int now)
{
    vis[now]=1,Get(now);
    for(auto [to,now]:G[now])
    {  
        if(vis[to]) continue;
        GetMid(to,0),m=siz[to];
        GetMid(to,0),Calc(rt);
    }
    return;
}
void Main()
{
    cin>>n>>k;
    for(int i=1;i<=n;i++) cin>>a[i];
    for(int i=1,x,y,z;i<n;i++)
    {
        cin>>x>>y>>z;
        G[x].push_back({y,z});
        G[y].push_back({x,z});
    }
    m=n,GetMid(rt,0);Calc(rt);
    cout<<All<<'\n';
}
bool EdM;
int main()
{
    cerr<<fabs(&StM-&EdM)/1024.0/1024.0<<" MB\n";
    ios::sync_with_stdio(false);
    cin.tie(0);cout.tie(0);
    int StT=clock();
    int T=1;
    while(T--) Main();
    int EdT=clock();
    cerr<<1e3*(EdT-StT)/CLOCKS_PER_SEC<<" ms\n";
    return 0;
}