AT_arc087_d [ARC087F] Squirrel Migration 题解

_fairytale_

2024-11-20 17:44:25

Solution

如果能重现那个夏天你我多么渺小

把振幅逐一背下就能数出当时心跳

垂首看船裁开水面荡摇

慌乱对视后失笑

于是所有烦闷都四散奔逃

考虑 \sum{dis(i,p_i)} 的上界。对于树上的每条边 fa_u\to u ,这条边最多被经过 2\min(n-siz_u,siz_u) 次。因此答案上界是 \sum_{fa_u\to u} 2\min(n-siz_u,siz_u)

把重心提为根,这样对于一条边 fa_u\to u\min(n-siz_u,siz_u)=siz_u,于是答案上界可以写成 \sum_{fa_u\to u} 2siz_u

尝试构造一种方案使答案取到上界。不难发现,这等价于要求对于根的每个子树中的点 u,要求 p_u 不在子树中。对这个东西计数,设 f_{i,j,k} 表示考虑了前 i 棵子树,有 j 个点 u 没有找好 p_u,有 k 个点 v 还不存在一个 p_u=v 的方案数。转移是枚举当前子树中的 a 个点,把它们的 p 定为前面的点,再枚举当前子树中的 b 个点,把前面的点的 p 定为它们,记 c 为当前子树大小,从 f_{i-1,j,k} 转移到 f_{i,j-b+c-a,k-a+c-b},复杂度大概是 \mathcal O(n^4)

发现转移方程的第二维和第三维其实是一样的。这是因为我们可以认为过程是先把子树中的 c 个点合并进前面的子树中,然后对其中的 (a+b) 个点找到对应的 p,这样同时也有 (a+b) 个点被 p 对应了。所以重新设 f_{i,j} 表示考虑了前 i 棵子树,有 j 个点没有找好 p,同时有 j 个点没有被某个 p 对应的方案数。转移同样枚举 a,b,复杂度 \mathcal O(n^3)

考虑容斥。设 f_{i,j} 表示考虑了前 i 棵子树,钦定了 j 个点的 p 在自己子树中的方案数。这样不需考虑每个点的 p 不能在自己子树中的限制,也就是,剩下没有被钦定的点只需要构成一个排列。

最终我们以 \mathcal O(n^2) 的复杂度解决了这个问题。

#include<bits/stdc++.h>
bool Mst;
#define rep(x,qwq,qaq) for(int x=(qwq);x<=(qaq);++x)
#define per(x,qwq,qaq) for(int x=(qwq);x>=(qaq);--x)
using namespace std;
#define m107 1000000007
template<class _T>
void ckmax(_T &x,_T y) {
    x=max(x,y);
}
template <int MOD>
struct modint {
    int val;
    static int norm(const int& x) {
        return x < 0 ? x + MOD : x;
    }
    static constexpr int get_mod() {
        return MOD;
    }
    modint() : val(0) {}
    modint(const int& m) : val(norm(m)) {}
    modint(const long long& m) : val(norm(m % MOD)) {}
    modint operator-() const {
        return modint(norm(-val));
    }
    bool operator==(const modint& o) {
        return val == o.val;
    }
    bool operator<(const modint& o) {
        return val < o.val;
    }
    modint& operator+=(const modint& o) {
        return val = (1ll * val + o.val) % MOD, *this;
    }
    modint& operator-=(const modint& o) {
        return val = norm(1ll * val - o.val), *this;
    }
    modint& operator*=(const modint& o) {
        return val = static_cast<int>(1ll * val * o.val % MOD), *this;
    }
    modint& operator/=(const modint& o) {
        return *this *= o.inv();
    }
    modint& operator^=(const modint& o) {
        return val ^= o.val, *this;
    }
    modint& operator>>=(const modint& o) {
        return val >>= o.val, *this;
    }
    modint& operator<<=(const modint& o) {
        return val <<= o.val, *this;
    }
    modint operator-(const modint& o) const {
        return modint(*this) -= o;
    }
    modint operator+(const modint& o) const {
        return modint(*this) += o;
    }
    modint operator*(const modint& o) const {
        return modint(*this) *= o;
    }
    modint operator/(const modint& o) const {
        return modint(*this) /= o;
    }
    modint operator^(const modint& o) const {
        return modint(*this) ^= o;
    }
    bool operator!=(const modint& o) {
        return val != o.val;
    }
    modint operator>>(const modint& o) const {
        return modint(*this) >>= o;
    }
    modint operator<<(const modint& o) const {
        return modint(*this) <<= o;
    }
    friend std::istream& operator>>(std::istream& is, modint& a) {
        long long v;
        return is >> v, a.val = norm(v % MOD), is;
    }
    friend std::ostream& operator<<(std::ostream& os, const modint& a) {
        return os << a.val;
    }
    friend std::string tostring(const modint& a) {
        return std::to_string(a.val);
    }
    template <typename T>
    friend modint qpow(const modint a, const T& b) {
        assert(b >= 0);
        modint x = a, res = 1;
        for (T p = b; p; x *= x, p >>= 1)
            if (p & 1) res *= x;
        return res;
    }
    modint inv() const {
        return qpow(*this,MOD-2);
    }
};
using M107 = modint<1000000007>;
using mint = M107;
#define inf 0x3f3f3f3f
#define maxn 5100
#define mod m107
template<typename Tp>
int qp(int x,Tp y) {
    assert(y>=0);
    x%=mod;
    int res=1;
    while(y) {
        if(y&1)res=1ll*res*x%mod;
        x=1ll*x*x%mod;
        y>>=1;
    }
    return res;
}
int inv(int x) {
    return qp(x,mod-2);
}

struct Combinatorics {
#define Lim 2000000
    int fac[Lim+10],invfac[Lim+10];
    Combinatorics() {
        fac[0]=invfac[0]=1;
        rep(i,1,Lim)fac[i]=1ll*fac[i-1]*i%mod;
        invfac[Lim]=inv(fac[Lim]);
        per(i,Lim-1,1)invfac[i]=1ll*invfac[i+1]*(i+1)%mod;
    }
    int C(int n,int m) {
        if(n<m||n<0||m<0)return 0;
        return 1ll*fac[n]*invfac[m]%mod*invfac[n-m]%mod;
    }
    int A(int n,int m) {
        if(n<m||n<0||m<0)return 0;
        return 1ll*fac[n]*invfac[n-m]%mod;
    }
} comb;
bool Med;
signed main() {
    cerr<<(&Mst-&Med)/1024.0/1024.0<<" MB\n";
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);
    int n;cin>>n;
    vector<vector<int>>g(n+1);
    for(int i=1,u,v; i<n; ++i) {
        cin>>u>>v;
        g[u].emplace_back(v),g[v].emplace_back(u);
    }
    int rt=0,maxp=inf;
    auto dfs=[&](auto &self,int u,int f)->int {
        int sz=1;
        int mxp=0;
        for(int v:g[u]) {
            if(v==f)continue;
            int sizv=self(self,v,u);
            ckmax(mxp,sizv);
            sz+=sizv;
        }
        ckmax(mxp,n-sz+1);
        if(mxp<maxp)rt=u,maxp=mxp;
        return sz;
    };
    dfs(dfs,1,0);
    vector<vector<mint>>f(n+1,vector<mint>(n+1));
    int pre=0;
    int m=g[rt].size();
    f[0][0]=1;
    rep(i,1,m){
        int v=g[rt][i-1];
        int c=dfs(dfs,v,rt);
        rep(j,0,pre) {
            rep(k,0,c) {
                f[i][j+k]+=f[i-1][j]*comb.C(c,k)*comb.C(c,k)*comb.fac[k];
            }
        }
        pre+=c;
    }
    mint ans=0;
    rep(i,0,n)ans+=f[m][i]*comb.fac[n-i]*(i&1?-1:1);
    cout<<ans<<'\n';
    return 0;
}