求助#12TLE

SP1825 FTOUR2 - Free tour II

KevinGenZe @ 2023-03-31 19:52:55


#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=2e5+5;
int n,m,k,b[N];
int sz[N],dp[N],vis[N],rt,s,tot;
ll edge[N],sum[N],ans;
vector<pair<int,int> >e[N];
// Fenwick (binary indexed) tree over positions 1..n+1 that answers prefix
// maxima. Untouched slots read as 0, and Add only ever raises a slot, so a
// stored negative value is effectively treated as 0 by Sum.
struct tree
{
    ll c[N];
    int lowbit(int x){return x&-x;}
    // Raise every node covering position p so that it is at least x.
    void Add(int p,ll x)
    {
        for(;p<=n+1;p+=lowbit(p))
        c[p]=max(c[p],x);
    }
    // Maximum over positions 1..p, never below 0.
    ll Sum(int p)
    {
        // Nodes past n+1 are never written by Add, so querying from there would
        // skip stored values (and could index past c[]); clamp to the real size.
        if(p>n+1) p=n+1;
        ll ans=0;
        // p>0 (not just p!=0) so a non-positive index cannot loop forever.
        for(;p>0;p-=lowbit(p))
        ans=max(ans,c[p]);
        return ans;
    }
    // Reset every node that Add(p,..) could have touched back to 0.
    void Clear(int p)
    {
        for(;p<=n+1;p+=lowbit(p))
        c[p]=0;
    }
}tr;

// Centroid search over the not-yet-removed component containing u.
// Fills sz[] (subtree sizes rooted at the call's starting node) and dp[]
// (largest piece left if the node were removed), and keeps in rt the node
// with the smallest dp[]. Callers set s to the component size, rt=0 and
// dp[0] to a large sentinel before calling.
void DP(int u,int p)
{
    sz[u]=1,dp[u]=0;
    for(auto x:e[u])
    {
        int v=x.first;
        if(v==p||vis[v]) continue;
        // The child's size must be computed before it is read; without this
        // recursion sz[v] is stale, rt is always u, and the decomposition
        // degenerates to O(n^2) (the reported TLE).
        DP(v,u);
        dp[u]=max(sz[v],dp[u]);
        sz[u]+=sz[v];
    }
    // The part of the component "above" u also becomes a piece when u is removed.
    dp[u]=max(dp[u],s-sz[u]);
    if(dp[u]<dp[rt]) rt=u;
}
// Walk downward from u (entered from p) through nodes not yet removed and
// record one entry per visited node: edge[] gets the accumulated weight d,
// sum[] the accumulated crowded-node count. A branch is abandoned as soon as
// the count exceeds k, since extending it can only add more crowded nodes.
void getdis(int u,int p,int d,int black)
{
    if(black<=k)
    {
        ++tot;
        edge[tot]=d;
        sum[tot]=black;
        for(const auto &nb:e[u])
        {
            const int to=nb.first;
            if(to!=p&&!vis[to])
                getdis(to,u,d+nb.second,black+b[to]);
        }
    }
}
// Update ans with every path through centroid u: u->x alone, and x->u->y
// with x and y in different child subtrees. The Fenwick tree tr is keyed by
// (crowded count of the stored half, including u) + 1 and holds the best
// weight for that count. It must be all-zero on entry and is restored to
// all-zero on exit.
void solve(int u)
{
    // Paths of all children are appended to edge[]/sum[] instead of overwriting
    // them, so the final cleanup sees every key that was inserted (the old
    // per-child reset left keys of earlier children in tr, leaking them into
    // later centroids).
    tot=0;
    for(auto x:e[u])
    {
        int v=x.first,w=x.second;
        if(vis[v]) continue;
        int st=tot+1; // first slot used by this child's paths
        getdis(v,u,w,b[v]);
        // Query before inserting so both halves never come from the same subtree.
        for(int i=st;i<=tot;++i)
        {
            // Sum() reports 0 even when nothing matches, which stands for
            // "u alone"; that is only a valid partner if u->x itself respects k.
            if(sum[i]+b[u]>k) continue;
            ans=max(ans,tr.Sum(k-sum[i]+1)+edge[i]);
        }
        for(int i=st;i<=tot;++i)
            if(sum[i]+b[u]<=k) tr.Add(sum[i]+1+b[u],edge[i]);
    }
    for(int i=1;i<=tot;++i)
        if(sum[i]+b[u]<=k) tr.Clear(sum[i]+1+b[u]);
}
// Centroid decomposition driver: remove centroid u, account for all paths
// through it, then locate and process the centroid of each remaining piece.
// On entry the global s holds the size of u's component (set by the caller
// right before the DP call that chose u).
void dfs(int u)
{
    int total=s; // s is overwritten by the recursive calls below
    vis[u]=1;
    solve(u);
    for(auto x:e[u])
    {
        int v=x.first;
        if(vis[v]) continue;
        // sz[] is relative to the node the last DP started from. The neighbour
        // lying towards that node has sz[v] > sz[u] (its count includes u), and
        // its piece really has total-sz[u] nodes; the other neighbours'
        // sz[v] are exact.
        s=(sz[v]<sz[u])?sz[v]:total-sz[u];
        rt=0,dp[0]=0x3f3f3f3f;
        DP(v,u),dfs(rt);
    }
}
// Read the tree (n nodes, limit k, m crowded nodes, then n-1 weighted edges),
// run the centroid decomposition from the centroid of the whole tree, and
// print the best path weight found (0 if no positive path exists).
signed main()
{
    #ifdef LOCAL
    freopen("hhy.in","r",stdin);
    freopen("hhy.out","w",stdout);
    #endif
    // Unsynced, untied streams: roughly 6e5 integers are read below.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    cin>>n>>k>>m;
    // `register` is ill-formed since C++17, so plain loop variables are used.
    for(int i=1;i<=m;++i)
    {
        int x;
        cin>>x;
        b[x]=1;
    }
    for(int i=1;i<n;++i)
    {
        int u,v,w;
        cin>>u>>v>>w;
        e[u].emplace_back(v,w);
        e[v].emplace_back(u,w);
    }
    dp[0]=0x3f3f3f3f,s=n;
    DP(1,0);
    dfs(rt);
    cout<<ans<<'\n';
    return 0;
}

|