Subtask 3 有两个测试点超时(TLE):#1 和 #4

P3806 【模板】点分治 1

Kniqht @ 2023-09-29 18:31:45

如题(rt),点分治代码超时,求帮忙找出 TLE 的原因

运行用时约 400ms,但题目时限为 200ms。

#include<bits/stdc++.h>
#define ll long long
using namespace std;
// Array bounds: N vertices, M directed edge slots (two per tree edge),
// K entries in the path-length table f.
const int N=1e5+10,M=N*2,K=1e7+10;
// n vertices, Q queries, m target path length, qt = entries used in q[];
// f[] is a scratch table indexed by path length, zeroed again after each use.
int n,Q,m,qt,f[K]; 
// Adjacency list: h[u] heads u's list (-1 = empty, see memset in main),
// e/ne hold each edge's target and next index; idx is the next free slot.
int h[N],e[M],ne[M],idx;
// w: edge weights; q[1..qt]: distances collected by get_dist.
ll w[M],q[N];
// st[u] is set once u has been used as a centroid, cutting it out of the tree.
bool st[N];
// Append the directed edge a -> b (weight c) to the front of a's adjacency
// list and advance the free-slot counter idx.
void add(int a,int b,int c){
    e[idx]=b;
    w[idx]=c;
    ne[idx]=h[a];   // link to the previous head before replacing it
    h[a]=idx;
    ++idx;
}
// Count the vertices reachable from u without stepping back to fa or into a
// removed vertex (st[] set). A removed u contributes 0.
ll get_sz(int u,int fa){
    if(st[u]) return 0;
    int cnt=1;   // u itself -- the count starts at 1, not 0
    for(int i=h[u];i!=-1;i=ne[i])
        if(e[i]!=fa) cnt+=get_sz(e[i],u);
    return cnt;
}
// Append the distance of every non-removed vertex in u's subtree (entered
// from fa, with u itself at distance dis) to q[1..qt].
void get_dist(int u,int fa,ll dis){
    if(st[u]) return;
    q[++qt]=dis;
    for(int i=h[u];i!=-1;i=ne[i]){
        const int v=e[i];
        if(v!=fa) get_dist(v,u,dis+w[i]);
    }
}
// Return the number of non-removed vertices in u's subtree (entered from fa).
// tot is the size of the whole component. Every vertex whose largest
// remaining part (each child subtree, or the tot - own vertices on fa's
// side) is at most tot/2 is written to rt in post-order, so the last one
// visited is what rt holds afterwards.
ll get_wc(int u,int fa,int tot,int &rt){
    if(st[u]) return 0;
    ll own=1;        // vertices in u's subtree, u included
    ll heaviest=0;   // largest part left if u were removed
    for(int i=h[u];i!=-1;i=ne[i]){
        const int v=e[i];
        if(v==fa) continue;
        const ll part=get_wc(v,u,tot,rt);
        if(part>heaviest) heaviest=part;
        own+=part;
    }
    if(tot-own>heaviest) heaviest=tot-own;
    if(heaviest<=tot/2) rt=u;
    return own;
}

ll check(ll a[],int X){
    ll res=0;
    for(int i=1;i<=X;i++)
        if(a[i]<=m) f[a[i]]++;
    for(int i=1;i<=X;i++)
        if(a[i]<=m&&f[m-a[i]]){
            if(m-a[i]==a[i]) res+=f[a[i]]-1;
            else res+=f[m-a[i]];
        }
    for(int i=1;i<=X;i++) 
        if(a[i]<=m)f[a[i]]=0;
    return res;
}
ll p[N];
ll calc(int u){
    if(st[u]) return 0;
    int sz=get_wc(u,-1,get_sz(u,-1),u),pt=0;ll res=0;
    st[u]=1;
    for(int i=h[u];~i;i=ne[i]){
        int j=e[i];
        qt=0;
        get_dist(j,u,w[i]);
        res-=check(q,qt);
        for(int k=1;k<=qt;k++){
            p[++pt]=q[k];
            if(q[k]==m) res++;
        }
    }
    res+=check(p,pt);
    for(int i=h[u];~i;i=ne[i]) res+=calc(e[i]);
    return res; 
}
int main(){
    memset(h,-1,sizeof(h));
    scanf("%d%d",&n,&Q);
    for(int i=1;i<n;i++){
        int a,b,c;scanf("%d%d%d",&a,&b,&c);
        add(a,b,c);add(b,a,c);
    }
    while(Q--){
        memset(st,0,sizeof(st));
        scanf("%d",&m);
        int t=calc(1);
        printf(t?"AYE":"NAY");
        putchar('\n');
    }
    return 0;
}

by Kniqht @ 2023-09-29 18:34:02

把 memset 改掉之后还是不行(改为记录 st 数组中上次被置为 1 的位置,每次只重置这些位置),很奇怪,我这种点分治写法应该没问题吧?

#include<bits/stdc++.h>
#define ll long long
using namespace std;
// Array bounds: N vertices, M directed edge slots (two per tree edge),
// K entries in the path-length table f.
const int N=1e5+10,M=N*2,K=1e7+10;
// n vertices, Q queries, m target path length, qt = entries used in q[];
// f[] is a scratch table indexed by path length, zeroed again after each use;
// last[1..cntt] lists vertices whose st[] flag calc() set, so they can be
// reset selectively between queries.
int n,Q,m,qt,f[K],last[N],cntt; 
// Adjacency list: h[u] heads u's list (-1 = empty, see memset in main),
// e/ne hold each edge's target and next index; idx is the next free slot.
int h[N],e[M],ne[M],idx;
// w: edge weights; q[1..qt]: distances collected by get_dist.
ll w[M],q[N];
// st[u] is set once u has been used as a centroid, cutting it out of the tree.
bool st[N];
// Append the directed edge a -> b (weight c) to the front of a's adjacency
// list and advance the free-slot counter idx.
void add(int a,int b,int c){
    e[idx]=b;
    w[idx]=c;
    ne[idx]=h[a];   // link to the previous head before replacing it
    h[a]=idx;
    ++idx;
}
// Count the vertices reachable from u without stepping back to fa or into a
// removed vertex (st[] set). A removed u contributes 0.
ll get_sz(int u,int fa){
    if(st[u]) return 0;
    int cnt=1;   // u itself -- the count starts at 1, not 0
    for(int i=h[u];i!=-1;i=ne[i])
        if(e[i]!=fa) cnt+=get_sz(e[i],u);
    return cnt;
}
// Append the distance of every non-removed vertex in u's subtree (entered
// from fa, with u itself at distance dis) to q[1..qt].
void get_dist(int u,int fa,ll dis){
    if(st[u]) return;
    q[++qt]=dis;
    for(int i=h[u];i!=-1;i=ne[i]){
        const int v=e[i];
        if(v!=fa) get_dist(v,u,dis+w[i]);
    }
}
// Return the number of non-removed vertices in u's subtree (entered from fa).
// tot is the size of the whole component. Every vertex whose largest
// remaining part (each child subtree, or the tot - own vertices on fa's
// side) is at most tot/2 is written to rt in post-order, so the last one
// visited is what rt holds afterwards.
ll get_wc(int u,int fa,int tot,int &rt){
    if(st[u]) return 0;
    ll own=1;        // vertices in u's subtree, u included
    ll heaviest=0;   // largest part left if u were removed
    for(int i=h[u];i!=-1;i=ne[i]){
        const int v=e[i];
        if(v==fa) continue;
        const ll part=get_wc(v,u,tot,rt);
        if(part>heaviest) heaviest=part;
        own+=part;
    }
    if(tot-own>heaviest) heaviest=tot-own;
    if(heaviest<=tot/2) rt=u;
    return own;
}

ll check(ll a[],int X){
    ll res=0;
    for(int i=1;i<=X;i++)
        if(a[i]<=m) f[a[i]]++;
    for(int i=1;i<=X;i++)
        if(a[i]<=m&&f[m-a[i]]){
            if(m-a[i]==a[i]) res+=f[a[i]]-1;
            else res+=f[m-a[i]];
        }
    for(int i=1;i<=X;i++) 
        if(a[i]<=m)f[a[i]]=0;
    return res;
}
ll p[N];
ll calc(int u){
    if(st[u]) return 0;
    int sz=get_wc(u,-1,get_sz(u,-1),u),pt=0;ll res=0;
    st[u]=1;last[++cntt]=u;
    for(int i=h[u];~i;i=ne[i]){
        int j=e[i];
        qt=0;
        get_dist(j,u,w[i]);
        res-=check(q,qt);
        for(int k=1;k<=qt;k++){
            p[++pt]=q[k];
            if(q[k]==m) res++;
        }
    }
    res+=check(p,pt);
    for(int i=h[u];~i;i=ne[i]) res+=calc(e[i]);
    return res; 
}
int main(){
    memset(h,-1,sizeof(h));
    scanf("%d%d",&n,&Q);
    for(int i=1;i<n;i++){
        int a,b,c;scanf("%d%d%d",&a,&b,&c);
        add(a,b,c);add(b,a,c);
    }
    while(Q--){
        for(int i=1;i<=cntt;i++) st[last[i]]=0;
        cntt=0;
        scanf("%d",&m);
        int t=calc(1);
        printf(t?"AYE":"NAY");
        putchar('\n');
    }
    return 0;
}

by Lee666666 @ 2023-10-04 21:37:58

我跟你一样啊(悲)

#include <cstdio>
#include <vector>
using namespace std;

// Read the next non-negative decimal integer from stdin, skipping every
// character that is not a digit before it (note: loops forever at EOF).
inline int read() {
    int value = 0;
    char c;
    do {
        c = getchar();
    } while (c < '0' || c > '9');
    for (; c >= '0' && c <= '9'; c = getchar()) {
        value = value * 10 + (c - '0');
    }
    return value;
}

// Smaller of two ints.
inline int min(int a, int b) {
    if (b < a) {
        return b;
    }
    return a;
}

// Larger of two ints.
inline int max(int a, int b) {
    if (b > a) {
        return b;
    }
    return a;
}

// One directed adjacency-list entry: target vertex v, weight w, and `to`,
// the index of the next entry from the same source (-1 ends the list).
// E holds two entries per tree edge (main calls insert in both directions).
struct edge {
    int v, w, to;
} E[20015];

// Per query: ok[i] = ruled out up front in main (shorter than the lightest
// edge or longer than mx); ans[i] = a matching path was found in solve.
// vis[u] = u has already been processed as a centroid.
bool ok[115], ans[115], vis[10015];
// n, m: vertex and query counts; mn: lightest edge weight (starts at 10000);
// mx: distance found by the two dfs sweeps in main; eid: next free slot in E;
// id: vertex returned by dfs / picked by dfs1; szid: smallest "largest part"
// seen by dfs1; Size: component size dfs1 measures against; cnt: entries
// used in point[]; Fa: child of the current centroid whose subtree dfs2 walks.
// Arrays: p = adjacency heads (-1 = empty), q = queries, sz = subtree sizes,
// point = vertices of the current component, s = owning child of the centroid
// per vertex, d = distance from the centroid, bct = per-length owner table
// (0 = nobody, a child id, or n + 1 = reached from several subtrees).
int n, m, mn = 10000, mx, eid, id, szid, cnt, Size, Fa, p[10015], q[115], sz[10015], point[10015], s[10015], d[10015], bct[10000015];

void insert(int u, int v, int w) {
    E[eid].v = v;
    E[eid].w = w;
    E[eid].to = p[u];
    p[u] = eid++;
    return;
}

// Walk the subtree below u (parent fa), writing d[child] = d[u] + weight for
// every vertex reached, and return the vertex with the largest d in that
// subtree (u itself if nothing is farther; earlier wins on ties).
int dfs(int u, int fa) {
    int farthest = u;
    for (int e = p[u]; e != -1; e = E[e].to) {
        const int child = E[e].v;
        if (child == fa) {
            continue;
        }
        d[child] = d[u] + E[e].w;
        const int cand = dfs(child, u);
        if (d[cand] > d[farthest]) {
            farthest = cand;
        }
    }
    return farthest;
}

// Fill sz[] for the non-removed subtree below u (parent fa). For each vertex,
// the largest piece left after deleting it is the bigger of its heaviest child
// subtree and Size - sz[u]; the first vertex that beats szid is recorded in id
// (and szid lowered), which picks the centroid when Size is the true size.
void dfs1(int u, int fa) {
    sz[u] = 1;
    int heaviest = 0;
    for (int e = p[u]; e != -1; e = E[e].to) {
        const int child = E[e].v;
        if (child == fa || vis[child]) {
            continue;
        }
        dfs1(child, u);
        if (sz[child] > heaviest) {
            heaviest = sz[child];
        }
        sz[u] += sz[child];
    }
    const int rest = Size - sz[u];
    if (rest > heaviest) {
        heaviest = rest;
    }
    if (heaviest < szid) {
        szid = heaviest;
        id = u;
    }
}

// Collect the subtree of the current centroid's child Fa: append u to point[],
// tag it with s[u] = Fa, and for every deeper vertex set its distance d[] and
// register that length in bct[] -- Fa if only this subtree has reached it so
// far, n + 1 once a different subtree already owned the slot.
void dfs2(int u, int fa) {
    point[cnt++] = u;
    s[u] = Fa;
    for (int e = p[u]; e != -1; e = E[e].to) {
        const int child = E[e].v;
        if (child == fa || vis[child]) {
            continue;
        }
        d[child] = d[u] + E[e].w;
        if (d[child] <= 10000000) {
            int &slot = bct[d[child]];
            slot = (slot != 0 && slot != Fa) ? n + 1 : Fa;
        }
        dfs2(child, u);
    }
}

void solve(int u) {
    vis[u] = 1;
    point[cnt++] = u;
    s[u] = u;
    d[u] = 0;
    bct[0] = u;
    int v, l, r;
    for (register int i = p[u]; ~i; i = E[i].to) {
        v = E[i].v;
        if (!vis[v]) {
            d[v] = E[i].w;
            if (d[v] <= 10000000) {
                if (bct[d[v]] && bct[d[v]] != v) {
                    bct[d[v]] = n + 1;
                }
                else {
                    bct[d[v]] = v;
                }
            }
            Fa = v;
            dfs2(v, u);
        }
    }
    for (register int i = 0; i < m; i++) {
        if (!ans[i] && !ok[i]) {
            for (register int j = 0; j < cnt; j++) {
                if (q[i] >= d[point[j]] && bct[q[i] - d[point[j]]] && bct[q[i] - d[point[j]]] != s[point[j]]) {
                    ans[i] = 1;
                    break;
                }
            }
        }
    }
    for (register int i = 0; i < cnt; i++) {
        if (d[point[i]] <= 10000000) {
            bct[d[point[i]]] = 0;
        }
    }
    cnt = 0;
    for (register int i = p[u]; ~i; i = E[i].to) {
        v = E[i].v;
        if (!vis[v] && sz[v] > 1) {
            szid = Size = sz[v];
            dfs1(v, u);
            solve(v);
        }
    }
    return;
}

int main() {
    // freopen("P3806.in", "r", stdin);
    // freopen("user.out", "w", stdout);
    n = read();
    m = read();
    for (register int i = 1; i <= n; i++) {
        p[i] = -1;
    }
    int u, v, w;
    for (register int i = 1; i < n; i++) {
        u = read();
        v = read();
        w = read();
        insert(u, v, w);
        insert(v, u, w);
        mn = min(mn, w);
    }
    id = dfs(1, 0);
    d[id] = 0;
    mx = d[dfs(id, 0)];
    for (register int i = 0; i < m; i++) {
        q[i] = read();
        if (q[i] < mn || q[i] > mx) {
            ok[i] = 1;
        }
    }
    szid = Size = n;
    dfs1(1, 0);
    solve(id);
    for (register int i = 0; i < m; i++) {
        if (ans[i]) {
            printf("AYE\n");
        }
        else {
            printf("NAY\n");
        }
    }
    return 0;
}

|