OI萌新求调二分85分(玄关

P4383 [八省联考 2018] 林克卡特树

yh2022mayu @ 2024-07-29 20:58:36

#include<bits/stdc++.h>
#include<map>
#include<set>
#include<list>
#include<queue>
#include<deque>
#include<stack>
#include<cmath>
#include<cstdio>
#include<vector>
#include<cstring>
#include<sstream>
#include<iostream>
#include<algorithm>
#define YES puts("YES")
#define Yes puts("Yes")
#define NO puts("NO")
#define No puts("No")
#define b_i __int128
#define ll long long
#define ull unsigned long long
#define x first
#define y second
#define pr pair
#define dl double
#define pll pr<ll, ll>
#define pii pr<int, int>
#define pdd pr<dl, dl>
#define ld long dl
#define vc vector
#define vci vc<int>
#define vcl vc<ll>
#define vcd vc<dl>
#define vcs vc<string>
#define vcp vc<pii >
#define ps push
#define pp pop
#define frt front
#define bck back
#define psf push_front
#define psb push_back
#define ppb pop_back
#define ppf pop_front
#define bgn begin
using namespace std;
const ll md = 998244353;
const ll mod = 1e9 + 7;
const dl esp = 1e-12;
const int inf = 2e9;
const ll INF = 2e18;
inline int read(){
    char c = getchar();
    int ans = 0, cnt = 1;
    while(c < '0' || c > '9'){
        if(c == '-') cnt = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9'){
        ans *= 10;
        ans += (c - '0');
        c = getchar();
    }
    return ans * cnt;
}
inline void write(int x){
    if(x < 0){
        putchar('-');
        x = -x;
    }
    if(x > 9) write(x / 10);
    putchar(x % 10 + '0');
}
ll ksm(ll a, ll b, ll md){
    ll ans = 1;
    while(b){
        if(b % 2) ans = ans * a % md;
        a = a * a % md;
        b /= 2;
    }
    return ans;
}
int rnd(int l, int r){
    return l + rand() % (r - l + 1);
}
int dx[4] = {1, 0, -1, 0};
int dy[4] = {0, -1, 0, 1};
int n, k;
ll ans;
struct DP{
    ll ans;
    int k;
    bool operator < (DP a) const{
        if(a.ans != ans) return ans < a.ans;
        return k > a.k;
    }
    DP operator + (DP a) const{
        return (DP){a.ans + ans, a.k + k};
    }
} dp[600005][3], cdp[3], cnt;
vc<pii > g[600005];
void dfs(int u, int f, ll x){
    for(int i = 0; i < g[u].size(); i++){
        int v = g[u][i].x, w = g[u][i].y;
        if(v == f) continue;
        dfs(v, u, x);
        cdp[0] = cdp[1] = cdp[2] = {(ll)-2e18, (int)2e9};
        for(int j = 0; j < 3; j++)
            for(int k = 0; k < 3; k++) cdp[j] = max(cdp[j], dp[u][j] + dp[v][k]);
        cdp[1] = max(cdp[1], dp[u][0] + dp[v][0] + (DP){w - x, 1});
        cdp[1] = max(cdp[1], dp[u][0] + dp[v][1] + (DP){w, 0});
        cdp[2] = max(cdp[2], dp[u][1] + dp[v][1] + (DP){w + x, -1});
        cdp[2] = max(cdp[2], dp[u][1] + dp[v][0] + (DP){w, 0});
        dp[u][0] = cdp[0];
        dp[u][1] = cdp[1];
        dp[u][2] = cdp[2];
    }
}
bool chk(ll x){
    for(int i = 1; i <= n; i++) dp[i][0] = {0ll, 0}, dp[i][1] = {(ll)-2e18, (int)2e9}, dp[i][2] = {-x, 1};
    dfs(1, 0, x);
    cnt = max(max(dp[1][0], dp[1][1]), dp[1][2]);
//  cout << x << ' ' << cnt.k << endl;
    return cnt.k >= k;
}
int main(){
    //ios::sync_with_stdio(false);
    //cin.tie(NULL);
    //cout.tie(NULL);
    //freopen(".in", "r", stdin);
    //freopen(".out", "w", stdout);
    cin >> n >> k;
    k++;
    for(int i = 1, a, b, c; i < n; i++){
        cin >> a >> b >> c;
        g[a].psb({b, c});
        g[b].psb({a, c});
    }
    ll l = -1e12, r = 1e12;
    while(l < r){
        ll mid = (l + r + 1) >> 1;
//      cout << l << ' ' << r << ' ';
        if(chk(mid)) l = mid;
        else r = mid - 1;
        if(cnt.k == k){
            cout << cnt.ans + mid * k;
            return 0;
        }
    }
    chk(l);
//  cout << l << endl;
    cout << cnt.ans + l * k;
    return 0;
}

将重载小于号的k > a.k换成 k < a.k 就AC了

或者将二分换为


        ll mid = (l + r + 1) >> 1;
//      cout << l << ' ' << r << ' ';
        if(chk(mid)) l = mid;
        else r = mid - 1;

也可以


|