TLE on 20求助(树上差分+map,复杂度nlog^2n)

P1600 [NOIP2016 提高组] 天天爱跑步

cats142857 @ 2022-08-09 10:49:43

#include <bits/stdc++.h>
using namespace std;
vector<int> line[300000];
vector<int>::iterator it;
map<int,int> value1[300000],value2[300000];
map<int,int>::iterator itv,itv2;
int pw[21],mem[300000]={},mem2[300000]={},depth[300000],anc[21][300000],p1[300000],p2[300000],tim[300000],ans[300000]={};
stack<int> dfstack;
inline int fastread(){
    char ch=getchar();
    int x=0;
    while(ch<'0'||ch>'9')
    {
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
    {
        x=x*10+ch-'0';
        ch=getchar();
    }
    return x;
}
void weightdfs(){
    int k,flag;
    while(!dfstack.empty())
    {
        flag=0;
        k=dfstack.top();
        dfstack.pop();
        for(it=line[k].begin();it!=line[k].end();it++)
        {
            if(mem[*it]==0)
            {
                anc[0][*it]=k;
                depth[*it]=depth[k]+1;
                flag=1;
                dfstack.push(*it);
                mem[*it]=1;
            }
        }
    }
}
inline int lca(int x,int y){
    int tmp,i=0,lcans;
    if(depth[x]>depth[y])
    {
        tmp=x;
        x=y;
        y=tmp;
    }
    tmp=depth[y]-depth[x];
    while(tmp>0)
    {
        if(tmp%2==1)y=anc[i][y];
        tmp/=2;
        i++;
    }
    if(x==y)return x;
    tmp=depth[x];
    while(tmp>0)
    {
        if(anc[(int)log2(tmp)][x]!=anc[(int)log2(tmp)][y])
        {
            x=anc[(int)log2(tmp)][x];
            y=anc[(int)log2(tmp)][y];
            tmp-=pw[(int)log2(tmp)];
        }
        else
        {
            lcans=anc[(int)log2(tmp)][x];
            tmp=pw[(int)log2(tmp)]-1;
        }
    }
    return lcans;
}
void ansdfs(){
    int k,flag,maxc1,maxp1,maxc2,maxp2;
    while(!dfstack.empty())
    {
        flag=0;
        k=dfstack.top();
        if(mem2[k]==0)
        {
            mem2[k]=1;
            for(it=line[k].begin();it!=line[k].end();it++)
            {
                if(mem[*it]==0)
                {
                    flag=1;
                    mem[*it]=1;
                    dfstack.push(*it);
                }
            }
            continue;
        }
        dfstack.pop();
        maxc1=0;
        maxp1=-1;
        maxc2=0;
        maxp2=-1;
        for(it=line[k].begin();it!=line[k].end();it++)
        {
            if(*it!=anc[0][k])
            {
                if(value1[p1[*it]].size()>maxc1)
                {
                    maxc1=value1[p1[*it]].size();
                    maxp1=*it;
                }
                if(value2[p2[*it]].size()>maxc2)
                {
                    maxc2=value1[p2[*it]].size();
                    maxp2=*it;
                }
            }
        }
        if(value1[p1[k]].size()>value1[p1[maxp1]].size())maxp1=k;
        if(value2[p2[k]].size()>value2[p2[maxp2]].size())maxp2=k;
        for(it=line[k].begin();it!=line[k].end();it++)
        {
            if(*it!=anc[0][k]&&(*it!=maxp1))
            {
                for(itv=value1[p1[*it]].begin();itv!=value1[p1[*it]].end();itv++)
                {
                    itv2=value1[p1[maxp1]].find((*itv).first);
                    if(itv2!=value1[p1[maxp1]].end())
                    {
                        (*itv2).second+=(*itv).second;
                        if((*itv2).second==0)value1[p1[maxp1]].erase(itv2);
                    }
                    else value1[p1[maxp1]].insert(*itv);
                }
                value1[p1[*it]].clear();
            }
            if(*it!=anc[0][k]&&(*it!=maxp2))
            {
                for(itv=value2[p2[*it]].begin();itv!=value2[p2[*it]].end();itv++)
                {
                    itv2=value2[p2[maxp2]].find((*itv).first);
                    if(itv2!=value2[p2[maxp2]].end())
                    {
                        (*itv2).second+=(*itv).second;
                        if((*itv2).second==0)value2[p2[maxp2]].erase(itv2);
                    }
                    else value2[p2[maxp2]].insert(*itv);
                }
                value2[p2[*it]].clear();
            }
        }
        if(maxp1!=-1&&maxp1!=k)
        {
            for(itv=value1[p1[k]].begin();itv!=value1[p1[k]].end();itv++)
            {
                itv2=value1[p1[maxp1]].find((*itv).first);
                if(itv2!=value1[p1[maxp1]].end())
                {
                    (*itv2).second+=(*itv).second;
                    if((*itv2).second==0)value1[p1[maxp1]].erase(itv2);
                }
                else value1[p1[maxp1]].insert(*itv);
            }
            value1[p1[k]].clear();
            p1[k]=p1[maxp1];
        }
        if(maxp2!=-1&&maxp2!=k)
        {
            for(itv=value2[p2[k]].begin();itv!=value2[p2[k]].end();itv++)
            {
                itv2=value2[p2[maxp2]].find((*itv).first);
                if(itv2!=value2[p2[maxp2]].end())
                {
                    (*itv2).second+=(*itv).second;
                    if((*itv2).second==0)value2[p2[maxp2]].erase(itv2);
                }
                else value2[p2[maxp2]].insert(*itv);
            }
            value2[p2[k]].clear();
            p2[k]=p2[maxp2];
        }
        itv=value1[p1[k]].find(tim[k]+depth[k]);
        if(itv!=value1[p1[k]].end())ans[k]+=(*itv).second;
        itv=value2[p2[k]].find(tim[k]-depth[k]);
        if(itv!=value2[p2[k]].end())ans[k]+=(*itv).second;
    }
}
int main(int argc, char** argv) {
    ios::sync_with_stdio(false),cin.tie(0);
    int n,m,i,j,u,v,lcatmp;
    pw[0]=1;
    for(i=1;i<21;i++)pw[i]=pw[i-1]*2;
    n=fastread();
    m=fastread();
    for(i=0;i<n-1;i++)
    {
        u=fastread();
        v=fastread();
        line[u-1].push_back(v-1);
        line[v-1].push_back(u-1);
    }
    for(i=0;i<n;i++)tim[i]=fastread();
    dfstack.push(0);
    mem[0]=1;
    depth[0]=0;
    anc[0][0]=-1;
    weightdfs();
    for(i=1;i<21;i++)
    {
        for(j=0;j<n;j++)
        {
            if(depth[j]>=pw[i])anc[i][j]=anc[i-1][anc[i-1][j]];
        }
    }
    for(i=0;i<m;i++)
    {
        u=fastread();
        v=fastread();
        lcatmp=lca(u-1,v-1);
        itv=value1[u-1].find(depth[u-1]);
        if(itv!=value1[u-1].end())
        {
            (*itv).second++;
            if((*itv).second==0)value1[u-1].erase(itv);
        }
        else value1[u-1].insert(make_pair(depth[u-1],1));
        if(anc[0][lcatmp]!=-1)
        {
            itv=value1[anc[0][lcatmp]].find(depth[u-1]);
            if(itv!=value1[anc[0][lcatmp]].end())
            {
                (*itv).second--;
                if((*itv).second==0)value1[anc[0][lcatmp]].erase(itv);
            }
            else value1[anc[0][lcatmp]].insert(make_pair(depth[u-1],-1));
        }
        itv=value2[v-1].find(depth[u-1]-2*depth[lcatmp]);
        if(itv!=value2[v-1].end())
        {
            (*itv).second++;
            if((*itv).second==0)value2[v-1].erase(itv);
        }
        else value2[v-1].insert(make_pair(depth[u-1]-2*depth[lcatmp],1));
        itv=value2[lcatmp].find(depth[u-1]-2*depth[lcatmp]);
        if(itv!=value2[lcatmp].end())
        {
            (*itv).second--;
            if((*itv).second==0)value2[lcatmp].erase(itv);
        }
        else value2[lcatmp].insert(make_pair(depth[u-1]-2*depth[lcatmp],-1));
    }
    for(i=0;i<n;i++)
    {
        p1[i]=i;
        p2[i]=i;
        mem[i]=0;
    }
    dfstack.push(0);
    mem[0]=1;
    ansdfs();
    for(i=0;i<n;i++)
    {
        if(i<n-1)printf("%d ",ans[i]);
        else printf("%d\n",ans[i]);
    }
    return 0;
}

by cats142857 @ 2022-08-09 12:05:11

按秩合并代码错误,实现的实际复杂度为n^2logn,已改正并切题


|