yh2022mayu @ 2024-07-29 20:58:36
#include<bits/stdc++.h>
#include<map>
#include<set>
#include<list>
#include<queue>
#include<deque>
#include<stack>
#include<cmath>
#include<cstdio>
#include<vector>
#include<cstring>
#include<sstream>
#include<iostream>
#include<algorithm>
// ---- Template shorthand macros ---------------------------------------------
// NOTE(review): these are object-like macros, not type aliases: each one
// rewrites every later token with the same spelling anywhere in this file.
// Yes/no answer printers (not used by this solution).
#define YES puts("YES")
#define Yes puts("Yes")
#define NO puts("NO")
#define No puts("No")
// Integer type shorthands (__int128 is a GCC/Clang extension).
#define b_i __int128
#define ll long long
#define ull unsigned long long
// Pair member shorthands. WARNING: because `x` and `y` are macros, every
// identifier spelled `x` or `y` after this point is silently renamed to
// `first` / `second` (e.g. the `ll x` parameters of dfs and chk, and the
// `int x` parameter of write). Avoid introducing new names `x`/`y`/`first`.
#define x first
#define y second
// Pair / floating-point shorthands; `ld` expands to `long double`.
#define pr pair
#define dl double
#define pll pr<ll, ll>
#define pii pr<int, int>
#define pdd pr<dl, dl>
#define ld long dl
// Vector shorthands.
#define vc vector
#define vci vc<int>
#define vcl vc<ll>
#define vcd vc<dl>
#define vcs vc<string>
#define vcp vc<pii >
// Container member-function shorthands (push/pop/front/back/begin).
#define ps push
#define pp pop
#define frt front
#define bck back
#define psf push_front
#define psb push_back
#define ppb pop_back
#define ppf pop_front
#define bgn begin
using namespace std;
// Template constants; the solution code below does not read any of them
// (ksm's `md` parameter shadows the global of the same name).
constexpr long long md = 998244353;   // common NTT-friendly prime
constexpr long long mod = 1000000007; // 1e9 + 7
constexpr double esp = 1e-12;         // floating-point tolerance ("eps")
constexpr int inf = 2000000000;               // large int sentinel
constexpr long long INF = 2000000000000000000LL; // large long long sentinel
// Reads the next decimal integer from `in` (stdin by default).
// Characters before the first digit are skipped; any '-' among them makes the
// result negative. If end of input is reached before a digit, returns 0
// instead of looping forever (the old `char` version could never see EOF).
// The value is expected to fit in int.
inline int read(FILE* in = stdin){
    int c = getc(in);  // int, not char, so EOF stays distinguishable from data
    int sign = 1;
    while(c != EOF && (c < '0' || c > '9')){
        if(c == '-') sign = -1;
        c = getc(in);
    }
    long long value = 0;  // wider accumulator so "-2147483648" does not overflow
    while(c >= '0' && c <= '9'){
        value = value * 10 + (c - '0');
        c = getc(in);
    }
    return static_cast<int>(value * sign);
}
// Writes `x` in decimal to `out` (stdout by default), with a leading '-' when
// negative. Works for every int, including INT_MIN, whose negation as an int
// would overflow; the magnitude is computed in unsigned arithmetic instead.
inline void write(int x, FILE* out = stdout){
    unsigned int mag = static_cast<unsigned int>(x);
    if(x < 0){
        putc('-', out);
        mag = 0u - mag;  // well-defined unsigned negation
    }
    char buf[24];  // enough digits for any unsigned int up to 64 bits
    int len = 0;
    do{
        buf[len++] = static_cast<char>('0' + mag % 10);
        mag /= 10;
    }while(mag != 0);
    while(len > 0) putc(buf[--len], out);  // digits were produced in reverse
}
// Returns a^b mod md by binary (square-and-multiply) exponentiation.
// Preconditions: b >= 0, md >= 1, and (md - 1)^2 must fit in long long.
// `a` may be any value (negative or >= md): it is reduced first, so the result
// is always in [0, md) and the first squaring cannot overflow.
long long ksm(long long a, long long b, long long md){
    a %= md;
    if(a < 0) a += md;
    long long ans = 1 % md;  // for md == 1 every power is 0, not 1
    while(b > 0){
        if(b & 1) ans = ans * a % md;
        a = a * a % md;
        b >>= 1;
    }
    return ans;
}
// Returns a pseudo-random integer in [l, r] from std::rand() (precondition: l <= r).
// The span is computed in long long so r - l + 1 cannot overflow int
// (e.g. l = INT_MIN, r = INT_MAX). Note: rand() % span is slightly biased, and
// when the span exceeds RAND_MAX + 1 only [l, l + RAND_MAX] is reachable.
int rnd(int l, int r){
    const long long span = static_cast<long long>(r) - l + 1;
    return static_cast<int>(l + std::rand() % span);
}
// 4-neighbour grid direction offsets (template leftovers; unused here).
int dx[4]{1, 0, -1, 0};
int dy[4]{0, -1, 0, 1};
// n: vertex count. k: value read in main, then incremented there before
// being compared with the DP's chain count in chk.
int n;
int k;
// Unused global: DP has its own `ans` member, and nothing reads this one.
long long ans;
// Value of a partial solution in the penalised tree DP.
//   ans: total weight of chosen edges minus the penalty paid per chain so far.
//   k:   number of chains used so far.
struct DP{
    long long ans;
    int k;
    // Order by value; on equal value, the state using MORE chains ranks higher.
    // main binary-searches for the largest penalty whose best state still has
    // cnt.k >= k and prints cnt.ans + penalty * k. That is correct only if ties
    // report the maximum chain count. Preferring fewer chains can stop the search
    // at a penalty where exactly k chains are not optimal when several counts tie.
    bool operator < (const DP& a) const{
        if(a.ans != ans) return ans < a.ans;
        return k < a.k;
    }
    // Combine two independent partial solutions: values and chain counts add.
    DP operator + (const DP& a) const{
        return DP{a.ans + ans, a.k + k};
    }
} dp[600005][3], cdp[3], cnt;
// Adjacency lists: g[u] holds (neighbour, edge weight) pairs.
std::vector<std::pair<int, int>> g[600005];
void dfs(int u, int f, ll x){
for(int i = 0; i < g[u].size(); i++){
int v = g[u][i].x, w = g[u][i].y;
if(v == f) continue;
dfs(v, u, x);
cdp[0] = cdp[1] = cdp[2] = {(ll)-2e18, (int)2e9};
for(int j = 0; j < 3; j++)
for(int k = 0; k < 3; k++) cdp[j] = max(cdp[j], dp[u][j] + dp[v][k]);
cdp[1] = max(cdp[1], dp[u][0] + dp[v][0] + (DP){w - x, 1});
cdp[1] = max(cdp[1], dp[u][0] + dp[v][1] + (DP){w, 0});
cdp[2] = max(cdp[2], dp[u][1] + dp[v][1] + (DP){w + x, -1});
cdp[2] = max(cdp[2], dp[u][1] + dp[v][0] + (DP){w, 0});
dp[u][0] = cdp[0];
dp[u][1] = cdp[1];
dp[u][2] = cdp[2];
}
}
bool chk(ll x){
for(int i = 1; i <= n; i++) dp[i][0] = {0ll, 0}, dp[i][1] = {(ll)-2e18, (int)2e9}, dp[i][2] = {-x, 1};
dfs(1, 0, x);
cnt = max(max(dp[1][0], dp[1][1]), dp[1][2]);
// cout << x << ' ' << cnt.k << endl;
return cnt.k >= k;
}
int main(){
    // Input: n and k, then n - 1 weighted edges "a b c".
    cin >> n >> k;
    k += 1;  // chk compares the DP chain count against the input k plus one
    for(int e = 1; e < n; ++e){
        int a, b, c;
        cin >> a >> b >> c;
        // Undirected edge: record it in both adjacency lists.
        g[a].push_back({b, c});
        g[b].push_back({a, c});
    }
    // Largest penalty whose best DP state still uses at least k chains.
    long long lo = -1e12, hi = 1e12;
    while(lo < hi){
        const long long mid = (lo + hi + 1) >> 1;  // upper middle so lo = mid progresses
        const bool enough = chk(mid);
        // A best state using exactly k chains already gives the final answer.
        if(cnt.k == k){
            cout << cnt.ans + mid * k;
            return 0;
        }
        if(enough) lo = mid;
        else hi = mid - 1;
    }
    chk(lo);
    cout << cnt.ans + lo * k;
    return 0;
}
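将重载小于号中的 `k > a.k` 换成 `k < a.k`（即 `ans` 相同时优先选链数更多的状态），就 AC 了。
原因：二分找的是满足 `cnt.k >= k` 的最大惩罚 x，最后用 `cnt.ans + x * k` 计算答案，这要求“恰好 k 条链”在该 x 下也是最优方案之一。
平局时取链数最多可以保证这一点。平局时取链数最少的话，一旦有多个链数同时最优（相邻斜率相等），二分就会停在 k 条链并非最优的 x 上，导致答案错误。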
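或者将二分换为：
ll mid = (l + r + 1) >> 1;
// cout << l << ' ' << r << ' ';
if(chk(mid)) l = mid;
else r = mid - 1;
也可以。
（注：这段与上面代码 main 中的二分完全相同，单独看并不构成修改，可能贴错了片段。若想保留“平局取链数最少”的比较方式，应改为二分满足 `cnt.k <= k` 的最小惩罚 x，再输出 `cnt.ans + x * k`。）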