萌新求助, 堆优化dij

P1342 请柬

Jayx @ 2023-04-06 14:00:29

#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
#define int long long
using namespace std;
// Capacity of every per-vertex and per-edge array. The original 200001 was
// flagged as too small by the accepted fix in this thread; 2000001 is the
// size that fix uses.
const int maxn = 2000001;
// Fill pattern for dis[]: memset() copies only the low byte (0x3f) into every
// byte of the array, producing a large positive "not reached yet" value.
const int inf = 0x3f3f3f3f;
// Min-heap of (tentative distance, vertex) pairs used by both Dijkstra passes.
std::priority_queue<std::pair<int, int>, std::vector<std::pair<int, int> >, std::greater<std::pair<int, int> > > q;
int n, m, s, ans;                        // vertex count, edge count, source vertex, summed distances
int head[maxn], head1[maxn], dis[maxn];  // list heads for e[] and e1[], and the distance table
bool vis[maxn];                          // vertices already settled in the current pass
// One entry of a linked edge list: `to` is the vertex the edge leads to, `w`
// is its weight, and `next` is the index of the following entry in the same
// list (index 0 ends the list).
struct node
{
    int next;
    int to;
    int w;
};
node e[maxn];   // edges in input direction (filled by add1)
node e1[maxn];  // the same edges with endpoints swapped (filled by add2)
void add1(int u, int v, int w, int step)
{
    e[step].to = v;
    e[step].w = w;
    e[step].next = head[u];
    head[u] = step;
}
void add2(int u, int v, int w, int step)
{
    e1[step].to = v;
    e1[step].w = w;
    e1[step].next = head1[u];
    head1[u] = step;
}
void Dijstra1(int s)
{
    memset(vis, false, sizeof(vis));
    memset(dis, inf, sizeof(dis));
    dis[s] = 0;
    q.push({0, s});
    while (!q.empty())
    {
        int x = q.top().second;
        q.pop();
        if (vis[x]) continue;
        vis[x] = true;
        for (int i = head1[x]; i; i = e[i].next)
        {
            int v = e[i].to;
            if (dis[v] > dis[x] + e[i].w)
            {
                dis[v] = dis[x] + e[i].w;
                q.push({dis[v], v});
            }
        }
    }
}
// Heap-based Dijkstra from s over the swapped-endpoint edge list e1[] /
// head1[] (the edges stored by add2). Results go to the shared dis[] array.
void Dijstra2(int s)
{
    memset(dis, inf, sizeof(dis));
    memset(vis, false, sizeof(vis));
    dis[s] = 0;
    q.emplace(0, s);
    while (!q.empty())
    {
        const int cur = q.top().second;
        q.pop();
        if (!vis[cur])
        {
            vis[cur] = true;
            for (int idx = head1[cur]; idx != 0; idx = e1[idx].next)
            {
                const int nxt = e1[idx].to;
                const int cand = dis[cur] + e1[idx].w;
                if (cand < dis[nxt])
                {
                    // Shorter route found: record it and queue the vertex again.
                    dis[nxt] = cand;
                    q.emplace(cand, nxt);
                }
            }
        }
    }
}
// Reads the graph, runs one Dijkstra pass over each edge list from vertex 1
// and prints the sum of all distances found by both passes.
signed main()
{
    //freopen("in.txt", "r", stdin);

    // Fix: per the accepted fix in this thread, the header line holds only n
    // and m. The original also read a source vertex s here, which swallowed
    // the first edge's u and shifted every later read. The source is vertex 1.
    scanf("%lld %lld", &n, &m);
    s = 1;
    for (int i = 1; i <= m; i++)
    {
        int u, v, w;
        scanf("%lld %lld %lld", &u, &v, &w);
        add1(u, v, w, i);  // edge as given
        add2(v, u, w, i);  // same edge with endpoints swapped
    }
    Dijstra1(s);
    for (int i = 1; i <= n; i++)
    {
        ans += dis[i];
    }
    Dijstra2(s);
    for (int i = 1; i <= n; i++)
    {
        ans += dis[i];
    }
    printf("%lld\n", ans);
    return 0;
}

by yizhiming @ 2023-04-06 15:28:06

@Jayx

有错的地方都用行尾的 `//` 标出来了：数组开小了（maxn 要开大）；Dijstra1 里遍历时表头用错了（应该是 head 而不是 head1）；另外本题输入只有 n 和 m，没有起点 s，起点固定为 1。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
#define int long long
using namespace std;
// Capacity shared by all per-vertex and per-edge arrays.
const int maxn = 2000001;
// Fill pattern for dis[]; memset() replicates its low byte (0x3f).
const int inf = 0x3f3f3f3f;
// Min-heap keyed on the first pair member (tentative distance); the second
// member is the vertex.
priority_queue<pair<int, int>, vector<pair<int, int> >, greater<pair<int, int> > > q;
int n;    // number of vertices
int m;    // number of edges
int s;    // source vertex
int ans;  // running sum of distances
int head[maxn];   // first entry of each vertex's list in e[]
int head1[maxn];  // first entry of each vertex's list in e1[]
int dis[maxn];    // distance table of the current Dijkstra pass
bool vis[maxn];   // vertices already settled in the current pass
// Entry of a linked edge list. Lists are chained through `next`, and index 0
// marks the end of a list.
struct node
{
    int next;  // index of the following entry in the same list
    int to;    // vertex this edge leads to
    int w;     // edge weight
};
node e[maxn];   // written by add1
node e1[maxn];  // written by add2
// Records edge u -> v (weight w) at index `step` of e[] and pushes it onto the
// front of u's list.
void add1(int u, int v, int w, int step)
{
    const int previous_first = head[u];
    head[u] = step;
    e[step].next = previous_first;
    e[step].to = v;
    e[step].w = w;
}
// Records edge u -> v (weight w) at index `step` of e1[] and pushes it onto
// the front of u's list in head1[].
void add2(int u, int v, int w, int step)
{
    const int previous_first = head1[u];
    head1[u] = step;
    e1[step].next = previous_first;
    e1[step].to = v;
    e1[step].w = w;
}
// Dijkstra with a binary heap over the edges stored by add1 (head[] / e[]).
// Fills dis[] with distances from s; vis[] marks vertices already settled.
void Dijstra1(int s)
{
    memset(vis, false, sizeof(vis));
    memset(dis, inf, sizeof(dis));
    dis[s] = 0;
    q.push({0, s});
    while (!q.empty())
    {
        const int u = q.top().second;
        q.pop();
        if (vis[u])
        {
            continue;  // outdated heap entry
        }
        vis[u] = true;
        // Walk u's list in e[] via head[] (not head1[], which indexes e1[]).
        int edge = head[u];
        while (edge != 0)
        {
            const int to = e[edge].to;
            const int via = dis[u] + e[edge].w;
            if (dis[to] > via)
            {
                dis[to] = via;
                q.push({via, to});
            }
            edge = e[edge].next;
        }
    }
}
// Dijkstra with a binary heap over the edges stored by add2 (head1[] / e1[]).
// Overwrites dis[] with distances from s along those swapped-endpoint edges.
void Dijstra2(int s)
{
    memset(vis, false, sizeof(vis));
    memset(dis, inf, sizeof(dis));
    dis[s] = 0;
    q.push({0, s});
    while (!q.empty())
    {
        const int u = q.top().second;
        q.pop();
        if (vis[u])
        {
            continue;  // outdated heap entry
        }
        vis[u] = true;
        int edge = head1[u];
        while (edge != 0)
        {
            const int to = e1[edge].to;
            const int via = dis[u] + e1[edge].w;
            if (dis[to] > via)
            {
                dis[to] = via;
                q.push({via, to});
            }
            edge = e1[edge].next;
        }
    }
}
// Reads n, m and m weighted edges, runs both Dijkstra passes from vertex 1 and
// prints the total of all distances from the two passes.
signed main()
{
    //freopen("in.txt", "r", stdin);

    scanf("%lld %lld", &n, &m);
    s = 1;  // the source is always vertex 1; it is not part of the input
    for (int idx = 1; idx <= m; ++idx)
    {
        int from, to, cost;
        scanf("%lld %lld %lld", &from, &to, &cost);
        add1(from, to, cost, idx);  // edge as given
        add2(to, from, cost, idx);  // edge with endpoints swapped
    }

    Dijstra1(s);
    for (int v = 1; v <= n; ++v)
        ans += dis[v];

    Dijstra2(s);
    for (int v = 1; v <= n; ++v)
        ans += dis[v];

    printf("%lld\n", ans);
    return 0;
}

by Jayx @ 2023-04-09 09:10:45

@yizhiming 关注了,已AC,谢谢大佬


|