81分TLE???

P3355 骑士共存问题

ABCD101 @ 2022-11-22 17:22:07

#include <bits/stdc++.h>
using namespace std;

// Short type aliases. They are kept as macros so that combinations such as
// `unsigned ll` keep working.
#define ll long long
#define ull unsigned ll
#define llu ull
#define db double
#define fl float
// Was misspelled `unsiged`, which made every use of `us` a compile error.
#define us unsigned
// Pair / container shorthands. `mp` and `pb` are function-like macros, so an
// identifier named `mp` that is not followed by '(' is left untouched.
#define fi first
#define se second
#define mp(a,b) make_pair(a,b)
#define pb(a) push_back(a)
#define pbp(a,b) pb(mp(a,b))

// "Infinity" sentinels. int_inf (0x3f3f3f3f) can be doubled without
// overflowing a 32-bit int; the others are the exact type limits.
#define int_inf 0x3f3f3f3f
#define INT_INF INT_MAX
#define UINT_INF UINT_MAX
#define LL_INF LLONG_MAX
#define ULL_INF ULONG_LONG_MAX
#define Const const int
#define pi pair<int,int>

// I/O helpers: redirect the standard streams to files, and route
// read() / print(a) to the hand-written fast I/O routines.
#define fin(name) freopen(name,"r",stdin)
#define fout(name) freopen(name,"w",stdout)
#define read() fastRead()
#define print(a) fastPrint(a)
/**
 * Reads the next decimal integer from stdin.
 *
 * Every character that is not a digit is skipped; if a '-' occurs among the
 * skipped characters the result is negated. The first non-digit character
 * after the number is consumed as well.
 *
 * Returns 0 when end of input is reached before any digit is found, instead
 * of looping forever on EOF.
 */
inline int fastRead(){
    // getchar() returns int: keeping the value as int lets EOF be told apart
    // from real characters even on platforms where plain char is unsigned.
    int chr=getchar();
    int absData=0, isLowerThanZero=1;
    while(chr!=EOF&&(chr<'0'||chr>'9')){
        if(chr=='-') isLowerThanZero=-1;
        chr=getchar();
    }
    while(chr>='0'&&chr<='9'){
        absData=absData*10+(chr-'0');
        chr=getchar();
    }
    return absData*isLowerThanZero;
}
/**
 * Reads the next decimal integer from stdin as a long long.
 *
 * Non-digit characters are skipped; a '-' among them negates the result.
 * The first non-digit character after the number is consumed.
 *
 * Returns 0 when end of input is reached before any digit is found, instead
 * of looping forever on EOF.
 */
long long readll(){
    // int (not char) so that EOF stays distinguishable from valid characters.
    int c = getchar();
    long long x = 0, f = 1;
    for (; c != EOF && (c < '0' || c > '9'); c = getchar())
        if (c == '-') f = -1;
    for (; c >= '0' && c <= '9'; c = getchar())
        x = x * 10 + (c - '0');
    return x * f;
}
/**
 * Writes `number` to stdout in decimal, without any trailing separator.
 *
 * The magnitude is handled as an unsigned value so that INT_MIN, whose
 * negation does not fit in an int, is printed correctly.
 */
inline void fastPrint(int number){
    unsigned int magnitude=static_cast<unsigned int>(number);
    if(number<0){
        putchar('-');
        // Unsigned negation is well defined and yields |number| even for INT_MIN.
        magnitude=0u-magnitude;
    }
    // Digits are produced least-significant first, then emitted in reverse.
    char digits[16];
    int count=0;
    do{
        digits[count++]=static_cast<char>('0'+magnitude%10);
        magnitude/=10;
    }while(magnitude!=0);
    while(count>0) putchar(digits[--count]);
}

/*****-------------------------------------*****/
// Capacity of one board dimension; all per-cell arrays below are sized from it.
Const N = 300;
// n:   board side length read in main.
// num: id assigned to each cell (i, j) by main, numbered from 1.
// mch: for a cell id, the id of the cell currently matched to it (0 = unmatched).
// vis: per-cell visit mark used by dfs; vis[x]==y means x was already visited in pass y.
int n, num[N][N], mch[N*N], vis[N*N];
// Running answer: main starts from the number of obstacle-free cells, reduces it
// according to the matches found by dfs, and prints it.
double ans=0;
// Four (row, column) offsets; main subtracts them from (i, j) to reach cells in
// earlier rows that are a knight's move away.
int ddd[4][2]={1,2,2,1,2,-1,1,-2};
// mp[r][c] is true when cell (r, c) is an obstacle (set from the input in main).
bool mp[N][N];
// Adjacency lists indexed by cell id: cells a knight's move apart that are both free.
vector<int> g[N*N];
// Tries to find an augmenting path starting at cell id x.
// y is the mark of the current search pass: a cell whose vis entry already
// equals y has been explored in this pass and is not expanded again.
// On success the mch entries along the path are rewritten and true is returned.
inline bool dfs(int x, int y){
    if(vis[x]!=y){
        vis[x]=y;
        for(const int neighbour : g[x]){
            const int owner=mch[neighbour];
            // A neighbour that is already taken can only be used if its
            // current owner manages to move to another cell.
            if(owner!=0 && !dfs(owner, y)) continue;
            mch[neighbour]=x;
            return true;
        }
    }
    return false;
}
int main(){
    n=read(), ans=0;
    int m=read();
    while(m--){
        int tmp1=read(), tmp2=read();
        mp[tmp1][tmp2]=1;
    }
    for(int i=1, tot=0;i<=n;i++) 
        for(int j=1;j<=n;j++){
            num[i][j]=++tot, ans+=!mp[i][j];
            if(mp[i][j]) continue;
            for(int k=0;k<4;k++) 
                if(i>ddd[k][0] && j>ddd[k][1] && (j-ddd[k][1]<=n)
                && !mp[i-ddd[k][0]][j-ddd[k][1]])
                    g[num[i][j]].pb(num[i-ddd[k][0]][j-ddd[k][1]]),
                    g[num[i-ddd[k][0]][j-ddd[k][1]]].pb(num[i][j]);
        }
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++) ans-=dfs(num[i][j],1)/2.0,
            memset(vis, 0, sizeof vis);
    cout<<ans<<endl;
    return 0;
}

///////////////////

其中一个测试点在本地测试需要跑 5 秒。


by StkOvflow @ 2023-01-14 13:36:49

@ABCD101 匈牙利算法是会被卡掉的,这题正解是最大流来着


by StkOvflow @ 2023-01-14 14:09:07

@StkOvflow 但是看到有佬说匈牙利过了,不知道怎么写QWQ


by StkOvflow @ 2023-01-16 19:54:14

@ABCD101 我的匈牙利AC了,贴一下

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>

using namespace std;
using PII = pair<int, int>;

// N: capacity of one board dimension; M: capacity for cell ids.
const int N = 210, M = N * N;
// mat[j]: id of the cell currently matched to cell j (0 = unmatched).
int mat[M];
// n: board side length; k: number of obstacles read from the input;
// res: number of successful augmentations counted in main;
// m: not used in the visible code.
int n, m, k, res;
// st[j]: non-zero once cell j has been visited in the current find() pass.
// g[x][y]: non-zero when cell (x, y) is an obstacle.
int st[M], g[N][N];
// Adjacency lists stored as linked arrays: h[a] is the index of the most recently
// added edge of a (0 = no edges), e[i] is the target of edge i, ne[i] the next edge
// in the same list, idx the index of the last edge created. Eight slots per cell id.
int h[M], e[M << 3], ne[M << 3], idx;
// The eight (row, column) knight moves.
int dict[8][2] = {{-1, -2}, {-1, 2}, {1, -2}, {1, 2}, {-2, -1}, {-2, 1}, {2, -1}, {2, 1}};

// Converts 1-based board coordinates (x, y) into a 1-based cell id,
// numbering the cells row by row with the global side length n.
int get(int x, int y)
{
    const int cellsBefore = n * (x - 1);
    return cellsBefore + y;
}

// Adds a directed edge a -> b by putting a new entry at the front of a's list.
void add(int a, int b)
{
    ++idx;              // entry 0 is reserved as the end-of-list marker
    e[idx] = b;
    ne[idx] = h[a];
    h[a] = idx;
}

// Searches for an augmenting path that starts at cell id u.
// Each neighbour is tried at most once per pass (tracked in st); a neighbour
// is taken when it is unmatched or when its current partner can be re-matched.
// Returns true and updates mat when u has been matched.
bool find(int u)
{
    int edge = h[u];
    while (edge != 0)
    {
        const int target = e[edge];
        edge = ne[edge];
        if (!st[target])
        {
            st[target] = true;
            const int partner = mat[target];
            if (partner == 0 || find(partner))
            {
                mat[target] = u;
                return true;
            }
        }
    }
    return false;
}

int main() 
{
    scanf("%d%d", &n, &k);

    int t = k;
    while (t -- ) 
    {
        int x, y;
        scanf("%d%d", &x, &y);
        g[x][y] = true;
    }

    for (int i = 1; i <= n; i ++ )
        for (int j = 1; j <= n; j ++ )
            {
                if (g[i][j] || (i + j) % 2 == 0) continue ;
                for (int u = 0; u < 8; u ++ ) 
                {
                    int x = i + dict[u][0], y = j + dict[u][1];
                    if (x < 1 || x > n || y < 1 || y > n) continue ;
                    if (!g[x][y]) add(get(i, j), get(x, y));
                }
            }

    for (int i = 1; i <= n; i ++ ) 
        for (int j = 1; j <= n; j ++ ) 
        {
            if (g[i][j] || (i + j) % 2 == 0) continue ;
            memset(st, 0, sizeof st);
            if (find(get(i, j))) res ++ ;
        }
    printf("%d\n", n * n - k - res);

    return 0;
}

|