cats142857 @ 2022-08-09 10:49:43
#include <bits/stdc++.h>
using namespace std;
vector<int> line[300000];
vector<int>::iterator it;
map<int,int> value1[300000],value2[300000];
map<int,int>::iterator itv,itv2;
int pw[21],mem[300000]={},mem2[300000]={},depth[300000],anc[21][300000],p1[300000],p2[300000],tim[300000],ans[300000]={};
stack<int> dfstack;
inline int fastread(){
char ch=getchar();
int x=0;
while(ch<'0'||ch>'9')
{
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=x*10+ch-'0';
ch=getchar();
}
return x;
}
void weightdfs(){
int k,flag;
while(!dfstack.empty())
{
flag=0;
k=dfstack.top();
dfstack.pop();
for(it=line[k].begin();it!=line[k].end();it++)
{
if(mem[*it]==0)
{
anc[0][*it]=k;
depth[*it]=depth[k]+1;
flag=1;
dfstack.push(*it);
mem[*it]=1;
}
}
}
}
inline int lca(int x,int y){
int tmp,i=0,lcans;
if(depth[x]>depth[y])
{
tmp=x;
x=y;
y=tmp;
}
tmp=depth[y]-depth[x];
while(tmp>0)
{
if(tmp%2==1)y=anc[i][y];
tmp/=2;
i++;
}
if(x==y)return x;
tmp=depth[x];
while(tmp>0)
{
if(anc[(int)log2(tmp)][x]!=anc[(int)log2(tmp)][y])
{
x=anc[(int)log2(tmp)][x];
y=anc[(int)log2(tmp)][y];
tmp-=pw[(int)log2(tmp)];
}
else
{
lcans=anc[(int)log2(tmp)][x];
tmp=pw[(int)log2(tmp)]-1;
}
}
return lcans;
}
void ansdfs(){
int k,flag,maxc1,maxp1,maxc2,maxp2;
while(!dfstack.empty())
{
flag=0;
k=dfstack.top();
if(mem2[k]==0)
{
mem2[k]=1;
for(it=line[k].begin();it!=line[k].end();it++)
{
if(mem[*it]==0)
{
flag=1;
mem[*it]=1;
dfstack.push(*it);
}
}
continue;
}
dfstack.pop();
maxc1=0;
maxp1=-1;
maxc2=0;
maxp2=-1;
for(it=line[k].begin();it!=line[k].end();it++)
{
if(*it!=anc[0][k])
{
if(value1[p1[*it]].size()>maxc1)
{
maxc1=value1[p1[*it]].size();
maxp1=*it;
}
if(value2[p2[*it]].size()>maxc2)
{
maxc2=value1[p2[*it]].size();
maxp2=*it;
}
}
}
if(value1[p1[k]].size()>value1[p1[maxp1]].size())maxp1=k;
if(value2[p2[k]].size()>value2[p2[maxp2]].size())maxp2=k;
for(it=line[k].begin();it!=line[k].end();it++)
{
if(*it!=anc[0][k]&&(*it!=maxp1))
{
for(itv=value1[p1[*it]].begin();itv!=value1[p1[*it]].end();itv++)
{
itv2=value1[p1[maxp1]].find((*itv).first);
if(itv2!=value1[p1[maxp1]].end())
{
(*itv2).second+=(*itv).second;
if((*itv2).second==0)value1[p1[maxp1]].erase(itv2);
}
else value1[p1[maxp1]].insert(*itv);
}
value1[p1[*it]].clear();
}
if(*it!=anc[0][k]&&(*it!=maxp2))
{
for(itv=value2[p2[*it]].begin();itv!=value2[p2[*it]].end();itv++)
{
itv2=value2[p2[maxp2]].find((*itv).first);
if(itv2!=value2[p2[maxp2]].end())
{
(*itv2).second+=(*itv).second;
if((*itv2).second==0)value2[p2[maxp2]].erase(itv2);
}
else value2[p2[maxp2]].insert(*itv);
}
value2[p2[*it]].clear();
}
}
if(maxp1!=-1&&maxp1!=k)
{
for(itv=value1[p1[k]].begin();itv!=value1[p1[k]].end();itv++)
{
itv2=value1[p1[maxp1]].find((*itv).first);
if(itv2!=value1[p1[maxp1]].end())
{
(*itv2).second+=(*itv).second;
if((*itv2).second==0)value1[p1[maxp1]].erase(itv2);
}
else value1[p1[maxp1]].insert(*itv);
}
value1[p1[k]].clear();
p1[k]=p1[maxp1];
}
if(maxp2!=-1&&maxp2!=k)
{
for(itv=value2[p2[k]].begin();itv!=value2[p2[k]].end();itv++)
{
itv2=value2[p2[maxp2]].find((*itv).first);
if(itv2!=value2[p2[maxp2]].end())
{
(*itv2).second+=(*itv).second;
if((*itv2).second==0)value2[p2[maxp2]].erase(itv2);
}
else value2[p2[maxp2]].insert(*itv);
}
value2[p2[k]].clear();
p2[k]=p2[maxp2];
}
itv=value1[p1[k]].find(tim[k]+depth[k]);
if(itv!=value1[p1[k]].end())ans[k]+=(*itv).second;
itv=value2[p2[k]].find(tim[k]-depth[k]);
if(itv!=value2[p2[k]].end())ans[k]+=(*itv).second;
}
}
int main(int argc, char** argv) {
ios::sync_with_stdio(false),cin.tie(0);
int n,m,i,j,u,v,lcatmp;
pw[0]=1;
for(i=1;i<21;i++)pw[i]=pw[i-1]*2;
n=fastread();
m=fastread();
for(i=0;i<n-1;i++)
{
u=fastread();
v=fastread();
line[u-1].push_back(v-1);
line[v-1].push_back(u-1);
}
for(i=0;i<n;i++)tim[i]=fastread();
dfstack.push(0);
mem[0]=1;
depth[0]=0;
anc[0][0]=-1;
weightdfs();
for(i=1;i<21;i++)
{
for(j=0;j<n;j++)
{
if(depth[j]>=pw[i])anc[i][j]=anc[i-1][anc[i-1][j]];
}
}
for(i=0;i<m;i++)
{
u=fastread();
v=fastread();
lcatmp=lca(u-1,v-1);
itv=value1[u-1].find(depth[u-1]);
if(itv!=value1[u-1].end())
{
(*itv).second++;
if((*itv).second==0)value1[u-1].erase(itv);
}
else value1[u-1].insert(make_pair(depth[u-1],1));
if(anc[0][lcatmp]!=-1)
{
itv=value1[anc[0][lcatmp]].find(depth[u-1]);
if(itv!=value1[anc[0][lcatmp]].end())
{
(*itv).second--;
if((*itv).second==0)value1[anc[0][lcatmp]].erase(itv);
}
else value1[anc[0][lcatmp]].insert(make_pair(depth[u-1],-1));
}
itv=value2[v-1].find(depth[u-1]-2*depth[lcatmp]);
if(itv!=value2[v-1].end())
{
(*itv).second++;
if((*itv).second==0)value2[v-1].erase(itv);
}
else value2[v-1].insert(make_pair(depth[u-1]-2*depth[lcatmp],1));
itv=value2[lcatmp].find(depth[u-1]-2*depth[lcatmp]);
if(itv!=value2[lcatmp].end())
{
(*itv).second--;
if((*itv).second==0)value2[lcatmp].erase(itv);
}
else value2[lcatmp].insert(make_pair(depth[u-1]-2*depth[lcatmp],-1));
}
for(i=0;i<n;i++)
{
p1[i]=i;
p2[i]=i;
mem[i]=0;
}
dfstack.push(0);
mem[0]=1;
ansdfs();
for(i=0;i<n;i++)
{
if(i<n-1)printf("%d ",ans[i]);
else printf("%d\n",ans[i]);
}
return 0;
}
by cats142857 @ 2022-08-09 12:05:11
按秩合并代码错误,实现的实际复杂度为n^2logn,已改正并切题