84pts求条

P3366 【模板】最小生成树

Cute_Furina @ 2024-09-25 20:56:22

#include <bits/stdc++.h>
using namespace std;
#define int long long
int n, m, ans, memo[5010];
bool vis[5010];
struct edge {
    int u, v, t;
}mp[200010];
bool cmp(edge a, edge b) {
    if(a.t != b.t) return a.t < b.t;
    return a.u < b.u;
}
int find(int a) {
    if(a != memo[a]) memo[a] = find(memo[a]);
    return memo[a];
}
signed main() {
    for(int i = 0; i < 5009;i ++) {
        memo[i] = i;
    }
    cin >> n >> m;
    for(int i = 0;i < m;i ++) {
        cin >> mp[i].u >> mp[i].v >> mp[i].t;
    }
    sort(mp, mp + m, cmp);
    for(int i = 0; i < m;i ++) {
        if(find(mp[i].u) == find(mp[i].v)) {
            continue;
        }
        memo[find(mp[i].u)] = memo[find(mp[i].v)];
        ans += mp[i].t;
        vis[mp[i].u] = 1;
        vis[mp[i].v] = 1;
    }
    for(int i = 1;i <= n;i ++) {
        if(vis[i] == 0) {
            cout << "orz" << endl;
            return 0;
        }
    }
    cout << ans << endl;
    return 0 ;
}

by yyyx_ @ 2024-09-25 21:12:02

考虑最小生成树有 n-1 条边,按选择了多少条边来算。

退出循环后,判断是否选够了 n-1 条边即可。

对于你的代码,如果只选择了 n-2 条边,但形成二分图,此时 vis 数组全部标记为 1,但显然不是一个完整的图。

修改后的代码:

for (int i = 0; i < m; i++)
{
    if (find(mp[i].u) == find(mp[i].v))
    {
        continue;
    }
    memo[find(mp[i].u)] = memo[find(mp[i].v)];
    ans += mp[i].t;
    if (++edge == n - 1)
        break;
}
if (edge == n - 1)
    cout << ans << endl;
else
    cout << "orz" << endl;

@Luowj


by yyyx_ @ 2024-09-25 21:14:43

更正:不是二分图,是恰好两个不连通的连通块


by Sheep_YCH @ 2024-09-25 21:15:00

@Luowj


by Sheep_YCH @ 2024-09-25 21:15:27

@Luowj

#include <bits/stdc++.h>
using namespace std;
#define int long long
int n, m, ans, memo[5010];
bool vis[5010];
struct edge {
    int u, v, t;
}mp[200010];
bool cmp(edge a, edge b) {
    if(a.t != b.t) return a.t < b.t;
    return a.u < b.u;
}
int find(int a) {
    if(a != memo[a]) memo[a] = find(memo[a]);
    return memo[a];
}
signed main() {
    for(int i = 0; i < 5009;i ++) {
        memo[i] = i;
    }
    cin >> n >> m;
    for(int i = 0;i < m;i ++) {
        cin >> mp[i].u >> mp[i].v >> mp[i].t;
    }
    sort(mp, mp + m, cmp);
    for(int i = 0; i < m;i ++) {

        int fu = find(mp[i].u),fv = find(mp[i].v);
        if(fu == fv) {
            continue;
        }
        memo[fu] = fv;
        ans += mp[i].t;
        vis[mp[i].u] = 1;
        vis[mp[i].v] = 1;
    }
    int sum = 0;
    for(int i = 1;i <= n;i ++) {
        if(find(i) == i) sum ++;
    }
    sum != 1 ? cout << "orz" : cout << ans << endl;
    return 0 ;
}

by Sheep_YCH @ 2024-09-25 21:17:23

@Luowj

注意 memo[find(mp[i].u)]=memo[find(mp[i].v)] 是不对的,应该是我发的代码里的那样


by Cute_Furina @ 2024-09-26 18:20:50

@yangyeyixuan @Sheep_YCH thx


|