24 pts 求调

P1285 队员分组

uncesspath @ 2023-10-23 17:31:02

代码写得比较乱：它是从 UVA 1627（多组数据）改过来的，所以保留了一些多组数据时才需要的清空操作。

#include <iostream>
#include <vector>
#include <algorithm>
#include <cstring>
#include <cmath>

using std::cin;
using std::cout;

const int N = 100 + 10;

int T, n;
// f[i][j] == 1 iff person i lists person j as an acquaintance.
// dp[i][j]: using components 1..i, |group 1| - |group 2| == j - n is reachable.
int f[N][N], dp[N][N << 1];
// Conflict graph: an edge joins two people who must be put in different groups.
std::vector<int> e[N];
// com1[c] / com2[c]: the two colour classes of component c.
// ans1[i][j] / ans2[i][j]: the members of group 1 / group 2 for DP state (i, j).
std::vector<int> com1[N], com2[N], ans1[N][N << 1], ans2[N][N << 1];
// w[c] = |com1[c]| - |com2[c]|; cnt = number of components; tot = DFS visit counter.
int fa[N], dfn[N], w[N], cnt, tot;
// Set by dfs when some component cannot be 2-coloured.
bool flag = false;
// side[u]: colour (+1 / -1) assigned to u by dfs. It is read only while dfn[u] != 0 and is
// always written together with dfn[u], so it never needs clearing between test cases.
int side[N];

/**
 * 2-colours the connected component (index `cnt`) that contains u.
 *
 * @param u  vertex to colour; must not have been visited yet (dfn[u] == 0)
 * @param d  colour for u: +1 puts u into com1[cnt], -1 into com2[cnt]
 *
 * Adds d to w[cnt] for every vertex. Sets `flag` if two adjacent vertices receive the
 * same colour, i.e. the component contains an odd cycle and the split is impossible.
 */
void dfs(int u, int d = 1)
{
    if (flag)
        return;
    dfn[u] = ++tot;
    side[u] = d;
    if (d == 1)
        com1[cnt].push_back(u);
    else
        com2[cnt].push_back(u);
    w[cnt] += d;
    for (auto v : e[u])
    {
        if (dfn[v] != 0)
        {
            // Already coloured (parent, ancestor, or finished descendant). An
            // even cycle is fine; only a same-colour neighbour is a contradiction.
            if (side[v] == d)
            {
                flag = true;
                return;
            }
            continue;
        }
        fa[v] = u;
        dfs(v, -d);
        if (flag)
            return;
    }
}

int main()
{
    std::ios::sync_with_stdio(false);
    cin.tie(nullptr);

    // cin >> T;
    // while (T--)
    // {
        flag = false;
        cnt = tot = 0;
        memset(fa, 0, sizeof(fa));
        memset(dfn, 0, sizeof(dfn));
        memset(w, 0, sizeof(w));
        memset(f, 0, sizeof(f));
        memset(dp, 0, sizeof(dp));

        cin >> n;
        for (int i = 1; i <= n; i++)
        {
            int x;
            while (cin >> x)
            {
                if (!x) break;
                f[i][x] = 1;
            }
        }

        for (int i = 1; i <= n; i++)
            for (int j = i + 1; j <= n; j++)
                if (!f[i][j] || !f[j][i])
                    e[i].push_back(j), e[j].push_back(i);
        for (int i = 1; i <= n; i++)
            if (!dfn[i]) cnt++, dfs(i);

        for (int i = 1; i <= n; i++)
            e[i].erase(e[i].begin(), e[i].end());

        if (flag)
        {
            for (int i = 1; i <= cnt; i++)
            {
                com1[i].erase(com1[i].begin(), com1[i].end());
                com2[i].erase(com2[i].begin(), com2[i].end());
            }
            cout << "No Solution\n\n";
            // continue;
            return 0;
        }

        dp[0][n] = 1;
        for (int i = 1; i <= cnt; i++)
        {
            for (int j = 0; j <= 2 * n; j++)
            {
                if (j + w[i] <= 2 * n)
                {
                    dp[i][j] |= dp[i - 1][j + w[i]];
                    if (dp[i - 1][j + w[i]])
                    {
                        ans1[i][j] = ans1[i - 1][j + w[i]];
                        ans2[i][j] = ans2[i - 1][j + w[i]];
                        for (auto v : com2[i])
                            ans1[i][j].push_back(v);
                        for (auto v : com1[i])
                            ans2[i][j].push_back(v);
                    }
                }
                if (j - w[i] >= 0)
                {
                    dp[i][j] |= dp[i - 1][j - w[i]];
                    if (dp[i - 1][j - w[i]])
                    {
                        ans1[i][j] = ans1[i - 1][j - w[i]];
                        ans2[i][j] = ans2[i - 1][j - w[i]];
                        for (auto v : com1[i])
                            ans1[i][j].push_back(v);
                        for (auto v : com2[i])
                            ans2[i][j].push_back(v);
                    }
                }
            }
        }

        int now = n, pos = 0;
        for (int i = 0; i <= 2 * n; i++)
        {
            if (dp[cnt][i] && now > abs(n - i))
            {
                now = std::min(now, abs(n - i));
                pos = i;
            }
        }
        std::sort(ans1[cnt][pos].begin(), ans1[cnt][pos].end());
        std::sort(ans2[cnt][pos].begin(), ans2[cnt][pos].end());
        cout << ans1[cnt][pos].size() << ' ';
        for (int i = 0; i < (int)ans1[cnt][pos].size() - 1; i++)
            cout << ans1[cnt][pos][i] << ' ';
        cout << ans1[cnt][pos][ans1[cnt][pos].size() - 1];
        cout << '\n';
        cout << ans2[cnt][pos].size() << ' ';
        for (int i = 0; i < (int)ans2[cnt][pos].size() - 1; i++)
            cout << ans2[cnt][pos][i] << ' ';
        cout << ans2[cnt][pos][ans2[cnt][pos].size() - 1];
        cout << '\n';
        // if (T)
        //     cout << '\n';

        for (int i = 1; i <= cnt; i++)
        {
            com1[i].erase(com1[i].begin(), com1[i].end());
            com2[i].erase(com2[i].begin(), com2[i].end());
            for (int j = 0; j <= 2 * n; j++)
            {
                ans1[i][j].erase(ans1[i][j].begin(), ans1[i][j].end());
                ans2[i][j].erase(ans2[i][j].begin(), ans2[i][j].end());
            }
        }
    // }
    return 0;
}

|