Danny_boodman
2018-05-31 15:40:41
好久没写博客了(主要是博主太弱全在学习),今天发现了个很有意思的东东,但炒鸡好用。于是,便把这位仁兄记录了下来。没错,他就是——树上差分。
先摆一波题
天天爱跑步;
运输计划;
疫情控制;
松鼠的新家;
一堆noip的难题由此诞生。
好了,四不四已经跃跃欲试了呢?让我们先看一下
放一个水题:
“给你一个m×n的矩阵,然后使用k块地毯铺地。每片地毯都给出左下角和右上角坐标。问所有地毯铺完之后,还有多少个整点(所谓整点,即横、纵坐标均为整数的点)没有被地毯覆盖。”
想到暴力:1.暴力枚举每张地毯
2.将所有被覆盖的点均做上标记
3.最后再枚举所有整点,若未被标记则ans+1;
然而时间复杂度是O(mnk)的,直接超时。
竟然有人想用线段树?太强了,考场祝您一路顺风。
考虑差分
用前缀和的方式进行维护,比如我们覆盖了2到5.
在2那里加上1,那么2之后的前缀和就都是1(表示覆盖)了。
然而我们要找的是2到5,不是2以后,所以在6那里要减1。
所以我们要求的前缀和
完美解决
让我们考虑树上差分
首先我们除了一般的grand,depth等数组以外,多开两个数组:tmp和prev。
tmp用来记录点的出现次数(具体点说实际上记录的是点到其父亲的边的出现次数),prev记录每个点到其父亲的那条边。对于一条起点s,终点t的路径。我们这样处理:
tmp[s]++,tmp[t]++,tmp[LCA(s,t)]-=2。(记住:最后要从所有叶结点把权值向上累加。)以一次操作为例,我们来看看效果(可以画一张图)。首先tmp[s]++,一直推上去到根,这时候s到root的路径访问次数都+1,tmp[t]++后,t到lca路径加了1,s到lca路径加了1,而lca到根的路径加了2。
这时,我们只需要tmp[LCA(s,t)]-=2,推到根,就能把那些多余的路径减掉,达到想要的目的。而这是一次操作,对于很多次操作的话,我们只需要维护tmp,而不必每次更新到根,维护好tmp最后Dfs一遍即可。这时如果tmp[i]==次数的话,说明i到其父亲的边是被所有路径覆盖的。如图
放一个例题代码:运输计划
#include<iostream>
#include<stdio.h>
#include<cstring>
using namespace std;
struct ss{
int next,to,val;
};ss data[600010];
struct truck{
int s,t,lca;
};truck node[300010];
int n,m,p,flag,cnt,maxn;
int head[300010],deep[300010],f[300010][25],dis[300010],pre[300010],sum[300010];
void change(int &a,int &b)
{
int t=a;a=b;b=t;
}
void add(int a,int b,int c)
{
data[++p].to=b;
data[p].next=head[a];
data[p].val=c;
head[a]=p;
}
void dfs(int a,int fa)
{
deep[a]=deep[fa]+1;
f[a][0]=fa;
for(int i=1;i<=20;i++)
f[a][i]=f[f[a][i-1]][i-1];
for(int i=head[a];i;i=data[i].next)
{
int v=data[i].to;
if(v==fa) continue;
dis[v]=dis[a]+data[i].val;
pre[v]=data[i].val;
dfs(v,a);
}
}
int getlca(int a,int b)
{
if(deep[a]>deep[b]) change(a,b);
for(int i=20;i>=0;i--)
{
if(deep[a]<=deep[f[b][i]])
b=f[b][i];
}
if(a==b) return a;
for(int i=20;i>=0;i--)
if(f[a][i]!=f[b][i])
{
a=f[a][i];
b=f[b][i];
}
return f[a][0];
}
int judge(int a,int fa,int cnt,int maxn)
{
int nsum=sum[a];
for(int i=head[a];i;i=data[i].next)
{
int v=data[i].to;
if(v==fa) continue;
nsum+=judge(v,a,cnt,maxn);
}
if(nsum>=cnt&&pre[a]>=maxn) flag=1;
return nsum;
}
int check(long long limit)
{
memset(sum,0,sizeof(sum));
cnt=0,flag=0,maxn=0;
for(int i=1;i<=m;i++)
if(dis[node[i].s]+dis[node[i].t]-2*dis[node[i].lca]>limit)
{
sum[node[i].s]++,sum[node[i].t]++,sum[node[i].lca]-=2;
cnt++;
maxn=max(maxn,dis[node[i].s]+dis[node[i].t]-2*dis[node[i].lca]);
}
if(cnt==0) return 1;
int wsb=judge(1,0,cnt,maxn-limit);
return flag;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n-1;i++)
{
int a,b,c;
scanf("%d%d%d",&a,&b,&c);
add(a,b,c);
add(b,a,c);
}
dfs(1,0);
for(int i=1;i<=m;i++)
{
scanf("%d%d",&node[i].s,&node[i].t);
node[i].lca=getlca(node[i].s,node[i].t);
}
long long l=0,r=3000000000;
while(l+1<r)
{
int mid=(l+r)/2;
if(check(mid)) r=mid;
else l=mid;
}
if(check(l)) printf("%lld\n",l);
else printf("%lld\n",r);
return 0;
}
此操作中我们这样维护:每次经过一条边,(如从u到v)我们让tmp[u]++,tmp[v]++,tmp[LCA(u,v)]--,tmp[grand[LCA(u,v)][0]]--。(最后要把tmp推上去)
以一次添加为例想象一下,首先u到根的路径上tmp都+1,此时u到根间结点tmp都为1,之后v到根路径上tmp+1,此时u到LCA前一个,v到LCA前一个点的tmp都+1,而LCA到根的所有点都+2,然后从tmp[LCA]--,更新上去,此时u-v路上所有tmp都+1,已经达到目的。
而多余的是什么部分呢,也就是LCA的上一个结点(grand[LCA][0])到根的这一段都多加了1,所以tmp[grand[LCA][0]]--,更新上去,也就完成了。
实际操作时也不需要每次更新都推上去,只要把四个tmp维护好,最后Dfs走一边就更新完了。
如图
放一个例题代码:松鼠的新家
#include<iostream>
#include<stdio.h>
using namespace std;
struct ss{
int next,to;
};ss data[600010];
int n,q;
int a[300010],head[600010],deep[300010],p[300010][25],sum[300010];
void change(int &a,int &b)
{
int t=a;a=b;b=t;
}
void add(int a,int b)
{
data[++q].to=b;
data[q].next=head[a];
head[a]=q;
}
void dfs(int a,int fa)
{
deep[a]=deep[fa]+1;
p[a][0]=fa;
for(int i=1;(1<<i)<=deep[a];i++)
p[a][i]=p[p[a][i-1]][i-1];
for(int i=head[a];i;i=data[i].next)
{
int v=data[i].to;
if(v!=fa)
dfs(v,a);
}
}
int lca(int a,int b)
{
if(deep[a]>deep[b]) change(a,b);
for(int i=20;i>=0;i--)
{
if(deep[a]<=deep[b]-(1<<i))
{
//printf("a=%d %d %d\n",a,b,i);
b=p[b][i];
}
}
//printf("%d %d\n",a,b);
if(a==b) return a;
for(int i=20;i>=0;i--)
{
if(p[a][i]!=p[b][i])
a=p[a][i],b=p[b][i];
}
//printf("a=%d\n",a);
return p[a][0];
}
void search(int a)
{
for(int i=head[a];i;i=data[i].next)
{
int v=data[i].to;
if(v==p[a][0]) continue;
search(v);
sum[a]+=sum[v];
}
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++)
scanf("%d",&a[i]);
for(int i=1;i<=n-1;i++)
{
int u,v;
scanf("%d%d",&u,&v);
add(u,v);add(v,u);
}
dfs(1,0);
for(int i=1;i<=n-1;i++)
{
int x=a[i],y=a[i+1],LCA=lca(x,y);
//printf("lca=%d\n",LCA);
sum[x]++;
sum[y]++;
sum[LCA]--;
sum[p[LCA][0]]--;
}
search(1);
for(int i=2;i<=n;i++)
sum[a[i]]--;
for(int i=1;i<=n;i++)
printf("%d\n",sum[i]);
return 0;
}