Kniqht @ 2023-09-29 18:31:45
rt（如题）：点分治 TLE 了，求帮忙找出超时的原因
用时 400ms，但是题目时限为 200ms。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
// Capacities and shared state for the per-query centroid decomposition.
const int N = 1e5 + 10;      // vertex capacity (array bound for h, q, st, p)
const int M = N * 2;         // every tree edge is stored twice (a->b and b->a)
const int K = 1e7 + 10;      // size of the distance histogram f
int n, Q, m, qt;             // vertices, queries, current query length, entries in q
int f[K];                    // f[d]: scratch count of collected distances equal to d
int h[N], e[M], ne[M], idx;  // adjacency list: head edge, edge target, next edge, edge counter
ll w[M], q[N];               // edge weights; distances gathered by get_dist (1-based)
bool st[N];                  // st[u]: u has been taken as a centroid and is removed
// Push the directed edge a -> b with weight c onto the front of a's edge list.
void add(int a,int b,int c){
    e[idx]=b;
    w[idx]=c;
    ne[idx]=h[a];
    h[a]=idx;
    ++idx;
}
// Number of vertices in u's component, where vertices with st set count as removed.
// fa is the vertex the traversal came from (-1 at the start).
ll get_sz(int u,int fa){
    if(st[u]) return 0;
    int cnt=1; // u itself, so the count starts at 1 rather than 0
    for(int i=h[u];i!=-1;i=ne[i])
        if(e[i]!=fa) cnt+=get_sz(e[i],u);
    return cnt;
}
// Appends (at q[++qt]) the distance of every vertex reachable from u without
// entering a removed (st) vertex; dis is the distance already accumulated at u.
void get_dist(int u,int fa,ll dis){
    if(st[u]) return;
    q[++qt]=dis;
    for(int i=h[u];i!=-1;i=ne[i]){
        const int v=e[i];
        if(v!=fa) get_dist(v,u,dis+w[i]);
    }
}
// Returns the size of u's subtree (rooted away from fa) in a component of `tot`
// vertices. Every vertex whose largest remaining piece after its removal is at
// most tot/2 is written to rt; the last such vertex in post-order wins.
ll get_wc(int u,int fa,int tot,int &rt){
    if(st[u]) return 0;
    ll subtree=1, heaviest=0;
    for(int i=h[u];i!=-1;i=ne[i]){
        const int v=e[i];
        if(v==fa) continue;
        const ll part=get_wc(v,u,tot,rt);
        subtree+=part;
        if(part>heaviest) heaviest=part;
    }
    const ll upward=tot-subtree;   // the piece on fa's side of u
    if(upward>heaviest) heaviest=upward;
    if(heaviest<=tot/2) rt=u;
    return subtree;
}
// Counts ordered index pairs (i, j), i != j, in a[1..X] with a[i] + a[j] == m.
// f is used as a scratch histogram and is left all-zero again on return.
ll check(ll a[],int X){
    // histogram of usable values: anything above m cannot be half of a pair
    for(int i=1;i<=X;i++)
        if(a[i]<=m) ++f[a[i]];
    ll pairs=0;
    for(int i=1;i<=X;i++){
        if(a[i]>m) continue;
        const ll need=m-a[i];
        if(!f[need]) continue;
        // when need == a[i], f[need] also counts index i itself
        pairs += (need==a[i]) ? f[need]-1 : f[need];
    }
    for(int i=1;i<=X;i++)
        if(a[i]<=m) f[a[i]]=0;
    return pairs;
}
ll p[N];  // distances from the current centroid to every vertex of its component
// Counts length-m paths in u's component via centroid decomposition and recurses
// into the pieces left after removing the centroid (marked in st).
// The result is positive iff such a path exists: a path whose endpoints are both
// different from the centroid that handles it is counted twice (ordered pairs),
// a path ending at that centroid once.
ll calc(int u){
    if(st[u]) return 0;
    // get_wc writes the centroid of u's component back into u (its rt reference).
    get_wc(u,-1,get_sz(u,-1),u);
    st[u]=1;
    int pt=0;
    ll res=0;
    for(int i=h[u];~i;i=ne[i]){
        qt=0;
        get_dist(e[i],u,w[i]);
        // pairs inside one subtree do not pass through the centroid: remove them
        res-=check(q,qt);
        for(int k=1;k<=qt;k++){
            p[++pt]=q[k];
            if(q[k]==m) res++;   // path that starts at the centroid itself
        }
    }
    res+=check(p,pt);
    for(int i=h[u];~i;i=ne[i]) res+=calc(e[i]);
    return res;
}
int main(){
memset(h,-1,sizeof(h));
scanf("%d%d",&n,&Q);
for(int i=1;i<n;i++){
int a,b,c;scanf("%d%d%d",&a,&b,&c);
add(a,b,c);add(b,a,c);
}
while(Q--){
memset(st,0,sizeof(st));
scanf("%d",&m);
int t=calc(1);
printf(t?"AYE":"NAY");
putchar('\n');
}
return 0;
}
by Kniqht @ 2023-09-29 18:34:02
byd，把 memset 改了一下还是不行（记录了 st 数组上次被赋值为 1 的位置，然后只重置这些位置），好奇怪，我这点分治的写法应该没毛病吧？
#include<bits/stdc++.h>
#define ll long long
using namespace std;
// Capacities and shared state; last/cntt remember which st entries were set.
const int N = 1e5 + 10;      // vertex capacity
const int M = N * 2;         // two directed copies of each tree edge
const int K = 1e7 + 10;      // size of the distance histogram f
int n, Q, m, qt;             // vertices, queries, current query length, entries in q
int f[K];                    // f[d]: scratch count of collected distances equal to d
int last[N], cntt;           // last[1..cntt]: vertices whose st was set this query
int h[N], e[M], ne[M], idx;  // adjacency list: head edge, target, next edge, edge counter
ll w[M], q[N];               // edge weights; distances gathered by get_dist (1-based)
bool st[N];                  // st[u]: u already taken as a centroid
void add(int a,int b,int c){
e[idx]=b,ne[idx]=h[a],w[idx]=c,h[a]=idx++;
}
// Vertex count of u's component, skipping vertices already removed (st).
ll get_sz(int u,int fa){
    if(st[u]) return 0;
    int total=1; // u itself: the size starts at 1, not 0
    int i=h[u];
    while(i!=-1){
        if(e[i]!=fa) total+=get_sz(e[i],u);
        i=ne[i];
    }
    return total;
}
// Records at q[++qt] the distance of each vertex reachable from u that is not
// removed (st); dis is u's distance from the traversal origin.
void get_dist(int u,int fa,ll dis){
    if(!st[u]){
        q[++qt]=dis;
        for(int i=h[u];i!=-1;i=ne[i])
            if(e[i]!=fa) get_dist(e[i],u,dis+w[i]);
    }
}
// Subtree size of u (away from fa) inside a component of `tot` vertices.
// Any vertex whose biggest leftover piece is at most tot/2 is stored in rt,
// the last one reached in post-order taking precedence.
ll get_wc(int u,int fa,int tot,int &rt){
    if(st[u]) return 0;
    ll size=1, worst=0;
    for(int i=h[u];i!=-1;i=ne[i]){
        if(e[i]==fa) continue;
        const ll child=get_wc(e[i],u,tot,rt);
        worst=max(worst,child);
        size+=child;
    }
    worst=max(worst,tot-size);   // piece containing fa
    if(worst<=tot/2) rt=u;
    return size;
}
// Number of ordered pairs (i, j), i != j, from a[1..X] whose values sum to m.
// Uses f as a temporary histogram and restores it to zero before returning.
ll check(ll a[],int X){
    for(int i=1;i<=X;++i)
        if(a[i]<=m) f[a[i]]+=1;
    ll total=0;
    for(int i=1;i<=X;++i){
        if(a[i]>m) continue;
        const ll other=m-a[i];
        const int c=f[other];
        if(c==0) continue;
        total += (other==a[i]) ? c-1 : c;   // exclude pairing i with itself
    }
    for(int i=1;i<=X;++i)
        if(a[i]<=m) f[a[i]]=0;
    return total;
}
ll p[N];  // distances from the current centroid to every vertex of its component
// Counts length-m paths in u's component with centroid decomposition, recording
// each chosen centroid in last[] so main can reset st cheaply.
// The result is positive iff such a path exists: paths not ending at the
// centroid that handles them are counted twice (ordered pairs), others once.
ll calc(int u){
    if(st[u]) return 0;
    // get_wc stores the centroid of u's component into u via its rt reference.
    get_wc(u,-1,get_sz(u,-1),u);
    st[u]=1;
    last[++cntt]=u;
    int pt=0;
    ll res=0;
    for(int i=h[u];~i;i=ne[i]){
        qt=0;
        get_dist(e[i],u,w[i]);
        // same-subtree pairs do not go through the centroid: subtract them
        res-=check(q,qt);
        for(int k=1;k<=qt;k++){
            p[++pt]=q[k];
            if(q[k]==m) res++;   // path from the centroid itself
        }
    }
    res+=check(p,pt);
    for(int i=h[u];~i;i=ne[i]) res+=calc(e[i]);
    return res;
}
int main(){
memset(h,-1,sizeof(h));
scanf("%d%d",&n,&Q);
for(int i=1;i<n;i++){
int a,b,c;scanf("%d%d%d",&a,&b,&c);
add(a,b,c);add(b,a,c);
}
while(Q--){
for(int i=1;i<=cntt;i++) st[last[i]]=0;
cntt=0;
scanf("%d",&m);
int t=calc(1);
printf(t?"AYE":"NAY");
putchar('\n');
}
return 0;
}
by Lee666666 @ 2023-10-04 21:37:58
我跟你一样啊（悲）
#include <cstdio>
#include <vector>
using namespace std;
// Reads the next run of decimal digits from stdin as a non-negative int.
// Any non-digit characters before it (including a '-') are skipped.
inline int read() {
    int c = getchar();
    while (c < '0' || c > '9') c = getchar();
    int value = 0;
    for (; c >= '0' && c <= '9'; c = getchar())
        value = value * 10 + (c - '0');
    return value;
}
// Smaller of two ints.
inline int min(int a, int b) {
    if (b < a) return b;
    return a;
}
// Larger of two ints.
inline int max(int a, int b) {
    if (b > a) return b;
    return a;
}
// One directed edge of the array-based adjacency list.
struct edge {
    int v;   // target vertex
    int w;   // edge weight
    int to;  // index of the next edge leaving the same source; -1 ends the list
} E[20015];  // storage for up to 20015 directed edges
bool ok[115];       // ok[i]: query i ruled out up front (below mn or above mx)
bool ans[115];      // ans[i]: a path of length q[i] has been found
bool vis[10015];    // vis[u]: u already handled by solve
int n, m;           // vertex count, query count
int mn = 10000;     // smallest edge weight read (starts at 10000)
int mx;             // length of the tree's longest path
int eid;            // next free slot in E
int id, szid;       // best centroid candidate from dfs1 and its largest-piece size
int cnt;            // number of entries in point[]
int Size;           // size of the component dfs1 searches
int Fa;             // centroid child whose subtree dfs2 is walking
int p[10015];       // p[u]: first edge index of u's list (-1 = none)
int q[115];         // queried path lengths
int sz[10015];      // subtree sizes computed by dfs1
int point[10015];   // vertices collected around the current centroid
int s[10015];       // s[x]: centroid child that x lies under (the centroid maps to itself)
int d[10015];       // distances from the current traversal root
int bct[10000015];  // bct[len]: 0 = unused, a vertex id = sole owner, n + 1 = several owners
void insert(int u, int v, int w) {
E[eid].v = v;
E[eid].w = w;
E[eid].to = p[u];
p[u] = eid++;
return;
}
// Walks u's subtree (parent fa), filling d[] with distances accumulated from d[u],
// and returns the vertex of largest d found there (the first one wins ties).
// `register` was dropped: that storage class is ill-formed since C++17.
int dfs(int u, int fa) {
    int far = u;
    for (int i = p[u]; ~i; i = E[i].to) {
        const int v = E[i].v;
        if (v == fa) continue;
        d[v] = d[u] + E[i].w;
        const int cand = dfs(v, u);
        if (d[cand] > d[far]) far = cand;
    }
    return far;
}
// Computes sz[] over u's component (vis vertices excluded, parent fa) and records
// in id/szid the vertex whose largest remaining piece is the smallest seen so far.
// The pieces are the child subtrees plus the Size - sz[u] part on fa's side, and
// only a strictly smaller value replaces the current candidate.
// `register` was dropped: that storage class is ill-formed since C++17.
void dfs1(int u, int fa) {
    sz[u] = 1;
    int heaviest = 0;
    for (int i = p[u]; ~i; i = E[i].to) {
        const int v = E[i].v;
        if (v == fa || vis[v]) continue;
        dfs1(v, u);
        heaviest = max(heaviest, sz[v]);
        sz[u] += sz[v];
    }
    heaviest = max(heaviest, Size - sz[u]);
    if (heaviest < szid) {
        szid = heaviest;
        id = u;
    }
}
// Visits u and everything below it (parent fa, vis vertices excluded):
//   - appends each vertex to point[] and tags it with s[] = Fa;
//   - sets d[] for each child;
//   - records in bct which centroid child owns each distance <= 10000000
//     (n + 1 once two different owners share that distance).
// `register` was dropped: that storage class is ill-formed since C++17.
void dfs2(int u, int fa) {
    point[cnt++] = u;
    s[u] = Fa;
    for (int i = p[u]; ~i; i = E[i].to) {
        const int v = E[i].v;
        if (v == fa || vis[v]) continue;
        d[v] = d[u] + E[i].w;
        if (d[v] <= 10000000) {
            int &owner = bct[d[v]];
            owner = (owner && owner != Fa) ? n + 1 : Fa;
        }
        dfs2(v, u);
    }
}
void solve(int u) {
vis[u] = 1;
point[cnt++] = u;
s[u] = u;
d[u] = 0;
bct[0] = u;
int v, l, r;
for (register int i = p[u]; ~i; i = E[i].to) {
v = E[i].v;
if (!vis[v]) {
d[v] = E[i].w;
if (d[v] <= 10000000) {
if (bct[d[v]] && bct[d[v]] != v) {
bct[d[v]] = n + 1;
}
else {
bct[d[v]] = v;
}
}
Fa = v;
dfs2(v, u);
}
}
for (register int i = 0; i < m; i++) {
if (!ans[i] && !ok[i]) {
for (register int j = 0; j < cnt; j++) {
if (q[i] >= d[point[j]] && bct[q[i] - d[point[j]]] && bct[q[i] - d[point[j]]] != s[point[j]]) {
ans[i] = 1;
break;
}
}
}
}
for (register int i = 0; i < cnt; i++) {
if (d[point[i]] <= 10000000) {
bct[d[point[i]]] = 0;
}
}
cnt = 0;
for (register int i = p[u]; ~i; i = E[i].to) {
v = E[i].v;
if (!vis[v] && sz[v] > 1) {
szid = Size = sz[v];
dfs1(v, u);
solve(v);
}
}
return;
}
int main() {
// freopen("P3806.in", "r", stdin);
// freopen("user.out", "w", stdout);
n = read();
m = read();
for (register int i = 1; i <= n; i++) {
p[i] = -1;
}
int u, v, w;
for (register int i = 1; i < n; i++) {
u = read();
v = read();
w = read();
insert(u, v, w);
insert(v, u, w);
mn = min(mn, w);
}
id = dfs(1, 0);
d[id] = 0;
mx = d[dfs(id, 0)];
for (register int i = 0; i < m; i++) {
q[i] = read();
if (q[i] < mn || q[i] > mx) {
ok[i] = 1;
}
}
szid = Size = n;
dfs1(1, 0);
solve(id);
for (register int i = 0; i < m; i++) {
if (ans[i]) {
printf("AYE\n");
}
else {
printf("NAY\n");
}
}
return 0;
}