萌新初学OI,求救

P4383 [八省联考 2018] 林克卡特树

Sai0511 @ 2019-03-17 10:09:02

请问这段代码哪里写错了？

#include <bits/stdc++.h>
// Shorthand for `inline`, used by the helpers below.
#define il inline
// 64-bit signed integer; used for accumulated weights (cnt) and the
// binary-search bounds (l, r) as well as the DP values in `data`.
using ll = long long;
// Per-vertex array capacity; edge arrays use twice this (maxn << 1).
constexpr int maxn = 300000 + 10;
using namespace std;
namespace Fast_Input {
    // Reads the next integer from stdin into `res`.
    // Leading non-digit characters are skipped; a '-' among them makes the
    // result negative. If EOF is reached before any digit, `res` is left at 0
    // and the function returns (the previous version looped forever there,
    // because isdigit(EOF) is false and getchar() keeps returning EOF).
    template<class T> inline void read(T& res) {
        res = 0;
        int ch;  // int, not char: must hold EOF and be a valid isdigit() argument
        bool sign = false;
        do {
            ch = getchar();
            if(ch == EOF) return;
            sign |= ch == '-';
        } while(!isdigit(ch));
        while(isdigit(ch)) {
            res = res * 10 + (ch & 15);  // ch & 15 maps '0'..'9' to 0..9
            ch = getchar();
        }
        if(sign) res = -res;
    }
}
using Fast_Input::read;
// Keeps the running maximum in x: assigns y to x only when x < y.
template<class A, class B> il void cmax(A& x, B y) {
    if(!(x < y)) return;
    x = y;
}
int n;     // vertex count (read in main)
int m;     // NOTE(review): not used in the visible code
int i;     // NOTE(review): not used; loops declare their own i
int j;     // NOTE(review): not used in the visible code
int k;     // read in main, then incremented to the number of paths to choose
int ecnt;  // number of directed edges stored by addedge so far
ll cnt;    // sum of |w| over all edges; bounds the penalty search
ll l;      // lower end of the penalty binary search
ll r;      // upper end of the penalty binary search
// DP value: x is the accumulated (penalised) weight, y counts closed paths
// (rew() subtracts the penalty from x and increments y).
// Ordering: larger x is better; on equal x, FEWER paths is better, so
// std::max picks the minimum path count among equal-weight optima.
// NOTE(review): with `using namespace std` under C++17, an unqualified
// `data` at namespace scope clashes with std::data; write ::data there.
struct data {
    ll x, y;
    data() { x = y = 0; }
    data(ll _x, ll _y) { x = _x; y = _y; }
    il bool operator < (const data& z) const {
        if(x == z.x) return y > z.y;  // tie: more paths compares as "smaller"
        return x < z.x;
    }
    // Component-wise sum: combines weights and path counts.
    il data operator + (const data& z) const {
        return data(x + z.x, y + z.y);
    }
    // Adds an edge weight without changing the path count.
    // const so it also works on const operands.
    il data operator + (int z) const {
        return data(x + z, y);
    }
} f[maxn][3];  // f[u][0..2]: per-vertex DP states filled by dfs()
// Closes one path: charges the penalty `mid` to the weight and counts the path.
// `mid` is ll because main() searches it over [-cnt, cnt], where cnt (an ll)
// is the sum of |w|. That range need not fit in int; the former `int mid`
// parameter silently truncated the penalty and broke the binary search.
// ::data is used because unqualified `data` at namespace scope is ambiguous
// with std::data under C++17 when `using namespace std` is in effect.
il ::data rew(::data z, ll mid) {
    return ::data(z.x - mid, z.y + 1);
}
// Adjacency lists as singly linked edge chains: head[u] is the newest edge id
// leaving u (-1 = none; main() fills head with -1), nxt[id] the next edge of
// the same vertex. Ids start at 1, and main() stores every undirected edge
// twice, hence the maxn << 1 capacity.
int head[maxn];
int wei[maxn << 1];
int ver[maxn << 1];
int nxt[maxn << 1];
// Pushes directed edge u -> v with weight w onto the front of u's chain.
il void addedge(int u, int v, int w) {
    const int id = ++ecnt;
    ver[id] = v;
    wei[id] = w;
    nxt[id] = head[u];
    head[u] = id;
}
// Absolute value of an int; main() sums it over edge weights to bound the search.
il int _abs(int x) {
    if(x < 0) return -x;
    return x;
}
void dfs(int u,int fa,ll mid) {
    cmax(f[u][2],data(-mid,1));
    for(int i = head[u];~i;i = nxt[i]) {
        int v = ver[i],w = wei[i];
        if(v != fa) {
            dfs(v,u,mid);
            f[u][2] = max(f[u][2] + f[v][0],rew(f[u][1] + f[v][1] + w,mid));
            f[u][1] = max(f[u][1] + f[v][0],f[u][0] + f[v][1] + w);
            f[u][0] = f[u][0] + f[v][0];
        }
    }
    cmax(f[u][0],max(rew(f[u][1],mid),f[u][2]));  
}
int main() {
    read(n);read(k); k++;  memset(head,-1,sizeof(head));
    for(int i = 1,u,v,w;i < n;i++) {
        read(u);read(v);read(w);
        addedge(u,v,w);
        addedge(v,u,w);   
        cnt += _abs(w);
    }
    r = cnt; l = -r;
    while(l <= r) {
        ll mid = (l + r) >> 1;
        memset(f,0,sizeof(f));
        dfs(1,0,mid);       
        if(f[1][0].y <= k) r = mid - 1;
        else l = mid + 1;
    }
    memset(f,0,sizeof(f)); dfs(1,0,l);
    printf("%lld\n",l * k + f[1][0].x);
    return 0;
}

by resftlmuttmotw @ 2019-03-17 10:10:07

@Sai_0511

巨佬，麻烦您帮我看看下面的内容，谢谢！


|