题解:CF1725E Electrical Efficiency

SkyStarOfficial

2024-11-19 20:41:19

Solution

Electrical Efficiency

题目大意

树上 n 个点,每个点有权值 a_i ,求出所有质因子相同的任意三个点联通块边数距离之和。

思路解析

因为每个质因子无论出现多少次,我们都把它们看做一次,所以我们认为它们对答案的贡献是独立的。

那么我们现在仅考虑一种质因子。

因为它是联通块,x,y,z 可以任意交换,所以三点之间的边的数量就是任意两点边数之和再除以二。

f(x,y,z)=\frac{d(x,y)+d(x,z)+d(y,z)}{2}

我们现在需要考虑一条边对于 f(x,y,z) 的贡献,因为一个联通块至少三个点,所以不能是这两个端点之一

它所包含的点可以是除了端点之外的任意一个合法的点,共有 k-2 种选择。k包含同种质因子的点的个数。

这些联通块每个都会包含这条边,边 (x,y) 对每个的贡献是 \frac{d(x,y)}{2},所以我们这条边对答案的贡献是 \frac{(k-2)d(x,y)}{2}

所以这种质因子对答案贡献是 \frac{(k-2)}{2}\sum{d(x,y)}

我们需要求出每个点两两之间的距离,这个可以通过树形 DP 实现。

假设我们已经求出了一个点子树内包含同种质因子的点p 个,那么这条边的贡献对两两之间距离和的贡献就是 p(k-p),原因是子树内每个包含同种质因子的点必须经过这条边才能到子树外的点。容易证明这样做是不重不漏。

这时候我们发现如果质因子太大的话可以最多需要遍历 n 遍树,但是如果用虚树维护,一个点就最多被访问 n\log{n} 次。

可以证明一个数的质因子不超过 \log{n} 个,所以最多被访问 n\log{n} 遍。

就可以使用虚树维护每种质因子。

但这种方法比较复杂,考虑一种比较简单的树上启发式合并的简单做法。

每次我们把小的子节点向大的节点合并,如果发现根节点较小就交换,这样保证了最多合并 n\log{n} 次。

我们记录每个子节点的质因数,和这个质因子一共出现的次数,那么它的系数就是 \frac{k-2}{2}

我们合并时每次减掉旧的贡献,加上新的贡献即可。

注意一个节点的初始贡献是只包含这个点的情况。

精细化实现复杂度为 O(n\log{n})。可以通过。

参考实现

#include<bits/extc++.h>
using namespace __gnu_pbds;
#define int long long
using namespace std;
const int MAXN=3e5+10;
const int mod=998244353;
gp_hash_table<int,int>mp[MAXN];
int n,colcnt[MAXN],res[MAXN],ans=0,inv;
vector<int>col[MAXN];
struct node{
    int nxt,to;
}e[MAXN*4];
int head[MAXN],tot=0;
void add(int x,int y){
    e[++tot].nxt=head[x];
    e[tot].to=y;
    head[x]=tot;
}
int qpow(int base,int ret){
    int ans=1;
    while(ret){
        if(ret&1)ans=ans*base%mod;
        base=base*base%mod;
        ret/=2;
    }
    return ans;
}
void dfs(int x,int fa){
    for(int i=head[x];i;i=e[i].nxt){
        int v=e[i].to;
        if(v==fa)continue;
        dfs(v,x);
        if(mp[x].size()<mp[v].size()){
            swap(mp[x],mp[v]);swap(res[x],res[v]);
        }
        for(auto i:mp[v]){
            res[x]-=(colcnt[i.first]-2)*mp[x][i.first]%mod*(colcnt[i.first]-mp[x][i.first])%mod*inv%mod;
            res[x]=(res[x]+mod)%mod;
            mp[x][i.first]+=i.second;
            res[x]+=(colcnt[i.first]-2)*mp[x][i.first]%mod*(colcnt[i.first]-mp[x][i.first])%mod*inv%mod;
            res[x]=(res[x]+mod)%mod;
        }
    }
    ans+=res[x];
    ans%=mod;
}
signed main(){
    ios::sync_with_stdio(0);
    cin>>n;inv=qpow(2,mod-2);
    for(int k=1;k<=n;k++){
        int x;cin>>x;
        for(int i=2;i<=sqrt(x);i++){
            if(x%i==0){
                col[k].emplace_back(i);
                colcnt[i]++;
            }
            while(x%i==0)x/=i;
        }
        if(x!=1)col[k].emplace_back(x),colcnt[x]++;
    }
    for(int i=1;i<n;i++){
        int x,y;cin>>x>>y;add(x,y);add(y,x);
    }
    for(int i=1;i<=n;i++){
        for(auto j:col[i]){
            mp[i][j]++;
            res[i]+=(colcnt[j]-2)*(colcnt[j]-1)%mod*inv%mod;
            res[i]%=mod;
        }
    }
    dfs(1,0);
    cout<<ans%mod;
    return 0;
}