Kev1nL1kesCod1ng
2024-11-14 17:04:53
提供一种复杂度正确的算法。
因为穿过一条边的链只有一条,所以考虑 dp 记录这条链的信息,设
不难发现链穿过点
考虑分别进行转移。
考虑在
如果穿过点
如果穿过点
最后如果根节点度数大于
这里时间复杂度均摊
先看
不难想到把
考虑将所有的
对于
这里使用启发式合并,将小的合并到大的李超树上去。
对于求
时间复杂度
const int N = 5e4 + 5;
const int M = 3e6 + 5;
const ll LNF = 1e12 + 128;
int n;
int fi[N], ne[N << 1], to[N << 1], ecnt;
int ru[N], d[N];
struct Line {
ll k, b;
} p[N]; int cnt;
int ls[M], rs[M], F[M], tot;
vector<int> e[N]; int id[N], rt[N];
ll b[N], g[N];
ll sq(ll x) {
return x * x;
}
ll calc(ll i, ll x) {
return p[i].k * x + p[i].b;
}
void push(int & u, int l, int r, int x) {
if(! u) u = ++ tot;
int mid = l + r >> 1;
int & y = F[u];
if(calc(x, mid) < calc(y, mid)) swap(x, y);
if(l == r) return;
if(calc(x, l) < calc(y, l)) push(ls[u], l, mid, x);
if(calc(x, r) < calc(y, r)) push(rs[u], mid + 1, r, x);
}
ll query(ll u, int l, int r, int p) {
if(! u) return LNF;
ll res = calc(F[u], p);
if(l == r) {
return res;
}
int mid = l + r >> 1;
if(p <= mid) chmin(res, query(ls[u], l, mid, p));
else chmin(res, query(rs[u], mid + 1, r, p));
return res;
}
void add(int u, int v) {
ne[++ecnt] = fi[u];
to[ecnt] = v;
fi[u] = ecnt;
}
void dfs(int u, int fa) {
if(u != 1 && ru[u] == 1) {
p[u] = {- 2 * d[u], sq(d[u])};
push(rt[u], 1, n << 1, u);
e[id[u]].push_back(u);
return;
}
ll res = 0;
for(int i = fi[u]; i; i = ne[i]) {
int v = to[i];
if(v == fa) continue;
d[v] = d[u] - 1;
dfs(v, u);
g[v] = query(rt[v], 1, n << 1, d[u]) + sq(d[u]) + b[v];
res += g[v];
}
p[u].b = LNF; p[u].k = - 2 * d[u];
for(int i = fi[u]; i; i = ne[i]) {
int v = to[i];
if(v == fa) continue;
int pos = v;
b[v] += res - g[v];
if(SZ(e[id[v]]) > SZ(e[id[u]])) {
swap(id[u], id[v]);
swap(rt[u], rt[v]);
swap(b[u], b[v]);
}
for(int x : e[id[v]]) {
int val = 2 * d[u] - d[x];
chmin(p[u].b, query(rt[u], 1, n << 1, val) + b[u] + sq(val) - res + p[x].b + b[v] - sq(d[x]));
}
for(int x : e[id[v]]) {
p[x].b += b[v] - b[u];
push(rt[u], 1, n << 1, x);
e[id[u]].push_back(x);
}
}
e[id[u]].push_back(u);
p[u].b -= b[u];
p[u].b += sq(d[u]);
push(rt[u], 1, n << 1, u);
}
void solve() {
cin >> n;
REP(_, n - 1) {
int u, v;
cin >> u >> v;
add(u, v), add(v, u);
ru[u] ++, ru[v] ++;
}
FOR(i, 1, n) id[i] = i;
p[0] = {0, LNF};
d[1] = n;
dfs(1, 0);
if(ru[1] == 1) {
ll ans = LNF;
FOR(i, 1, n) chmin(ans, p[i].b + p[i].k * d[1] + sq(d[1]) + b[1]);
cout << ans << endl;
}
else {
cout << p[1].b - sq(d[1]) + b[1] << endl;
}
}