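/*
 * Author: EityDawn
 * Date:   2024-11-20 15:31:21
 *
 * Problem: counting colour segments on the path between two vertices (every
 * edge has colour 0 or 1). The program sums, modulo 1e9+7, the vertex weights
 * of all paths (single vertices included) whose number of colour changes is at
 * most k.
 *
 * Approach: centroid decomposition. For the current centroid $c$ and a vertex
 * $u$ in its component, denote by $d_u$ the number of colour changes on the
 * path $c \to u$, by $s_u = \sum_{x \in c \to u} a_x$ the weight sum of that
 * path, and by $f_u$ the colour of its first edge.
 * For a vertex $u$ of a newly visited subtree, a vertex $v$ from an earlier
 * subtree can be joined with it iff $d_u + d_v + [f_u \ne f_v] \le k$. Two
 * Fenwick trees (one per first-edge colour), indexed by $d$, store the count
 * and the sum of $s_v$, and each valid pair contributes $s_u + s_v - a_c$.
 * Complexity: $O(n \log^2 n)$.
 */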
#include<bits/stdc++.h>
#define all(x) (x).begin(),(x).end()
#define mset(x,y) memset((x),(y),sizeof((x)))
#define mcpy(x,y) memcpy((x),(y),sizeof((y)))
#define FileIn(x) freopen(""#x".in","r",stdin)
#define FileOut(x) freopen(""#x".out","w",stdout)
#define debug(x) cerr<<""#x" = "<<(x)<<'\n'
// Abort with a diagnostic when x is false. The do/while(0) wrapper makes the
// macro a single statement, so it is safe inside an unbraced if/else (the old
// bare `if` let a following `else` bind to the macro's own condition).
#define Assert(x) do{if(!(x)) cerr<<"Failed: "#x" at line "<<__LINE__<<'\n',exit(1);}while(0)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef __int128 Int;
// N: capacity of every per-vertex array; mod: modulus of the answer.
const int N=2e5+10,mod=1e9+7;
// Address marker; main() compares it with EdM to estimate static memory usage.
bool StM;
// Add y into x modulo `mod`; both operands are expected to already lie in [0, mod),
// so at most one subtraction is needed.
void Mod(int &x,int y)
{
    const int s=x+y;
    x=(s>=mod)?s-mod:s;
}
// n: vertex count; k: limit on colour changes checked in Get(); m: size of the
// component currently being searched by GetMid(); a[]: vertex weights;
// rt: centroid found by the latest GetMid() call.
int n,k,m,a[N],rt=1;
// Adjacency entry: neighbour `to` and edge colour `val` (used to index t[0..1]).
struct edge{int to,val;};
vector<edge>G[N];
// Fenwick tree keyed by distance (0-based distance x lives at 1-based slot x+1,
// slots 1..n). c[] keeps a sum of values modulo mod, d[] keeps how many values
// were inserted, so prefix queries answer "total / count among distance <= x".
struct BIT{
int c[N],d[N];
#define lowbit(x) (x&-x)
BIT(){mset(c,0);mset(d,0);}
// Insert value y at distance x (x >= 0).
void add(int x,int y)
{
    ++x;
    while(x<=n)
        Mod(c[x],y),++d[x],x+=lowbit(x);
}
// Remove a value y previously inserted at distance x.
void del(int x,int y)
{
    ++x;
    while(x<=n)
        Mod(c[x],mod-y),--d[x],x+=lowbit(x);
}
// Map a 0-based prefix bound to a 1-based start slot in [0, n]. Slots above n
// are never updated by add(), so walking from beyond n would skip stored items
// (and could index past the arrays, or overflow x+1, for very large bounds).
int lim(int x)
{
    if(x<0) return 0;
    return x>=n?n:x+1;
}
// Sum (mod `mod`) of values stored at distances <= x; 0 when x < 0.
int query(int x)
{
    int sum=0;
    for(x=lim(x);x;x-=lowbit(x))
        Mod(sum,c[x]);
    return sum;
}
// Number of values stored at distances <= x; 0 when x < 0.
int C(int x)
{
    int sum=0;
    for(x=lim(x);x;x-=lowbit(x))
        sum+=d[x];
    return sum;
}
}t[2];
// GetMid() scratch: siz = subtree size, ma = largest piece left if the vertex is removed.
int siz[N],ma[N];
// vis[u]: u has already been used as a centroid, so its component is split there.
bool vis[N];
// Filled per vertex by GetSon()/Get() for the current centroid:
// Dis = colour changes on the path, Val = path weight sum (mod), id = first-edge colour.
int Dis[N],Val[N],id[N];
// Son[1..top]: vertices collected by GetSon() for the current centroid.
int Son[N],top=0;
// Accumulated answer modulo mod.
int All=0;
void GetMid(int now,int from)
{
siz[now]=1,ma[now]=0;
for(auto [to,val]:G[now])
{
if(to==from||vis[to]) continue;
GetMid(to,now);
siz[now]+=siz[to];
ma[now]=max(ma[now],siz[to]);
}
ma[now]=max(ma[now],m-siz[now]);
if(ma[now]<=m/2) rt=now;
}
void GetSon(int now,int from,int Sum,int dis,int col)
{
Dis[now]=dis,Val[now]=(Sum+a[now])%mod;
Son[++top]=now;
for(auto [to,val]:G[now])
{
if(vis[to]||to==from) continue;
if(val==col) GetSon(to,now,(Sum+a[now])%mod,dis,col);
else GetSon(to,now,(Sum+a[now])%mod,dis+1,val);
}
return;
}
void Get(int now)
{
Mod(All,a[now]);
int l=0,r=0;
for(auto [to,val]:G[now])
{
if(vis[to]) continue;
int cur=top;
GetSon(to,now,a[now],0,val);
for(int i=cur+1;i<=top;i++)
{
if(Dis[Son[i]]<=k){
Mod(All,Val[Son[i]]);
l=1ll*t[val].C(k-Dis[Son[i]])*(Val[Son[i]]-a[now]+mod)%mod;
Mod(l,t[val].query(k-Dis[Son[i]]));Mod(All,l);
r=1ll*t[val^1].C(k-Dis[Son[i]]-1)*(Val[Son[i]]-a[now]+mod)%mod;
Mod(r,t[val^1].query(k-Dis[Son[i]]-1));Mod(All,r);
}
id[Son[i]]=val;
}
for(int i=cur+1;i<=top;i++)
t[val].add(Dis[Son[i]],Val[Son[i]]);
}
while(top)
t[id[Son[top]]].del(Dis[Son[top]],Val[Son[top]]),top--;
return;
}
void Calc(int now)
{
vis[now]=1,Get(now);
for(auto [to,now]:G[now])
{
if(vis[to]) continue;
GetMid(to,0),m=siz[to];
GetMid(to,0),Calc(rt);
}
return;
}
void Main()
{
cin>>n>>k;
for(int i=1;i<=n;i++) cin>>a[i];
for(int i=1,x,y,z;i<n;i++)
{
cin>>x>>y>>z;
G[x].push_back({y,z});
G[y].push_back({x,z});
}
m=n,GetMid(rt,0);Calc(rt);
cout<<All<<'\n';
}
// Address marker paired with StM; their distance estimates static memory usage.
bool EdM;
int main()
{
    // Compare the two addresses as integers: subtracting pointers to unrelated
    // objects (as &StM-&EdM did) is undefined behaviour.
    const double gap=fabs((double)reinterpret_cast<uintptr_t>(&StM)-(double)reinterpret_cast<uintptr_t>(&EdM));
    cerr<<gap/1024.0/1024.0<<" MB\n";
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    const clock_t StT=clock();          // keep clock_t: narrowing to int can overflow
    int T=1;
    while(T--) Main();
    const clock_t EdT=clock();
    cerr<<1e3*(EdT-StT)/CLOCKS_PER_SEC<<" ms\n";
    return 0;
}