AC自动机

hyfhaha

2019-05-02 10:45:51

Personal

begin:2019/5/2 update 2020/6/12 更新了LaTeX(咕了好久 感谢大家支持! ## [更好的阅读体验](https://www.cnblogs.com/hyfhaha/p/10802604.html) # AC自动机详细讲解 **AC自动机**真是个好东西!之前学$KMP$被$Next$指针搞晕了,所以咕了许久都不敢开**AC自动机**,近期学完之后,发现**AC自动机**并不是很难,特别是对于$KMP$​,个人感觉**AC自动机**比$KMP$要好理解一些,可能是因为我对树上的东西比较敏感(实际是因为我到现在都不会$KMP$)。 很多人都说**AC自动机**是在$Trie$树上作$KMP$,我不否认这一种观点,因为这确实是这样,不过对于刚开始学**AC自动机**的同学们就一些误导性的理解(至少对我是这样的)。$KMP$是建立在一个字符串上的,现在把$KMP$搬到了树上,不是很麻烦吗?实际上**AC自动机**只是有$KMP$的一种思想,实际上跟一个字符串的$KMP$有着很大的不同。 所以看这篇blog,请放下$KMP$,理解好$Trie$,再来学习。 ## 前置技能 1.[$Trie$](https://www.luogu.org/blog/juruohyfhaha/trie-xue-xi-zong-jie)(很重要哦) 2.$KMP$的思想(懂思想就可以了,不需要很熟练) # 问题描述 给定$n$个模式串和$1$个文本串,求有多少个模式串在文本串里**出现过**。 注意:是出现过,就是出现多次只算一次。 默认这里每一个人都已经会了$Trie$。 我们将$n$个模式串建成一颗$Trie$树,建树的方式和建$Trie$完全一样。 ![AC自动机](https://i.loli.net/2019/05/02/5ccaaa22cbf29.png) 假如我们现在有文本串$ABCDBC$。 我们用文本串在$Trie$上匹配,刚开始会经过$2、3、4$号点,发现到$4$,成功地匹配了一个模式串,然后就不能再继续匹配了,这时我们还要重新继续从根开始匹配吗? 不,这样的效率太慢了。这时我们就要借用$KMP$的思想,从$Trie$上的某个点继续开始匹配。 明显在这颗$Trie$上,我们可以继续从$7$号点开始匹配,然后匹配到$8$。 那么我们怎么确定从那个点开始匹配呢?我们称$i$匹配失败后继续从$j$开始匹配,$j$是$i$的$Fail$(失配指针)。 ## 构建Fail指针 ### $Fail$的含义 $Fail$指针的实质含义是什么呢? 如果一个点$i$的$Fail$指针指向$j$。那么$root$到$j$的字符串是$root$到$i$的字符串的一个后缀。 举个例子:(例子来自上面的图 ```cpp i:4 j:7 root到i的字符串是“ABC” root到j的字符串是“BC” “BC”是“ABC”的一个后缀 所以i的Fail指针指向j ``` 同时我们发现,“$C$”也是“$ABC$”的一个后缀。 所以$Fail$指针指的$j$的深度要尽量大。 重申一下$Fail$指针的含义:**((最长的(当前字符串的后缀))**在$Trie$上可以查找到)的末尾编号。 感觉读起来挺绕口的蛤。感性理解一下就好了,没什么卵用的。知道$Fail$有什么用就行了。 ### 求$Fail$ 首先我们可以确定,每一个点$i$的$Fail$指针指向的点的深度一定是比$i$小的。(Fail指的是后缀啊) 第一层的$Fail$一定指的是$root$。(比深度$1$还浅的只有$root$了) 设点$i$的父亲$fa$的$Fail$指针指的是$fafail$,那么如果$fafail$有和$i$值相同的儿子$j$,那么$i$的$Fail$就指向$j$。这里可能比较难理解一点,建议画图理解,不过等会转换成代码就很好理解了。 由于我们在处理$i$的情况必须要先处理好$fa$的情况,所以求$Fail$我们使用$BFS$来实现。 #### 实现的一些细节: * 1、刚开始我们不是要初始化第一层的$fail$指针为$root$,其实我们可以建一个虚节点$0$号节点,将$0$的**所有儿子**指向$root$($root$编号为$1$,记得初始化),然后$root$的$fail$指向$0$就OK了。效果是一样的。 * 2、如果不存在一个节点$i$,那么我们可以将那个节点设为$fafail$的**((值和$i$相同)的儿子)**。保证存在性,就算是$0$也可以成功返回到根,因为$0$的所有儿子都是根。 * 3、无论$fafail$存不存在和$i$值相同的儿子$j$,我们都可以将$i$的$fail$指向$j$。因为在处理$i$的时候$j$已经处理好了,如果出现这种情况,$j$的值是第$2$种情况,也是有实际值的,所以没有问题。 * 4、实现时不记父亲,我们直接让父亲更新儿子 ```cpp void getFail(){ for(int i=0;i<26;i++)trie[0].son[i]=1; //初始化0的所有儿子都是1 q.push(1);trie[1].fail=0; //将根压入队列 while(!q.empty()){ int u=q.front();q.pop(); for(int i=0;i<26;i++){ //遍历所有儿子 int v=trie[u].son[i]; //处理u的i儿子的fail,这样就可以不用记父亲了 int Fail=trie[u].fail; //就是fafail,trie[Fail].son[i]就是和v值相同的点 if(!v){trie[u].son[i]=trie[Fail].son[i];continue;} //不存在该节点,第二种情况 trie[v].fail=trie[Fail].son[i]; //第三种情况,直接指就可以了 q.push(v); //存在实节点才压入队列 } } } ``` # 查询 求出了$Fail$指针,查询就变得十分简单了。 为了避免重复计算,我们每经过一个点就打个标记为$-1$,下一次经过就不重复计算了。 同时,如果一个字符串匹配成功,那么他的$Fail$也肯定可以匹配成功(后缀嘛),于是我们就把$Fail$再统计答案,同样,$Fail$的$Fail$也可以匹配成功,以此类推……经过的点累加$flag$,标记为$-1$。 最后主要还是和$Trie$的查询是一样的。 ```cpp int query(char* s){ int u=1,ans=0,len=strlen(s); for(int i=0;i<len;i++){ int v=s[i]-'a'; int k=trie[u].son[v]; //跳Fail while(k>1&&trie[k].flag!=-1){ //经过就不统计了 ans+=trie[k].flag,trie[k].flag=-1; //累加上这个位置的模式串个数,标记 已 经过 k=trie[k].fail; //继续跳Fail } u=trie[u].son[v]; //到儿子那,存在性看上面的第二种情况 } return ans; } ``` # 代码 ```cpp #include<bits/stdc++.h> #define maxn 1000001 using namespace std; struct kkk{ int son[26],flag,fail; }trie[maxn]; int n,cnt; char s[1000001]; queue<int >q; void insert(char* s){ int u=1,len=strlen(s); for(int i=0;i<len;i++){ int v=s[i]-'a'; if(!trie[u].son[v])trie[u].son[v]=++cnt; u=trie[u].son[v]; } trie[u].flag++; } void getFail(){ for(int i=0;i<26;i++)trie[0].son[i]=1; //初始化0的所有儿子都是1 q.push(1);trie[1].fail=0; //将根压入队列 while(!q.empty()){ int u=q.front();q.pop(); for(int i=0;i<26;i++){ //遍历所有儿子 int v=trie[u].son[i]; //处理u的i儿子的fail,这样就可以不用记父亲了 int Fail=trie[u].fail; //就是fafail,trie[Fail].son[i]就是和v值相同的点 if(!v){trie[u].son[i]=trie[Fail].son[i];continue;} //不存在该节点,第二种情况 trie[v].fail=trie[Fail].son[i]; //第三种情况,直接指就可以了 q.push(v); //存在实节点才压入队列 } } } int query(char* s){ int u=1,ans=0,len=strlen(s); for(int i=0;i<len;i++){ int v=s[i]-'a'; int k=trie[u].son[v]; //跳Fail while(k>1&&trie[k].flag!=-1){ //经过就不统计了 ans+=trie[k].flag,trie[k].flag=-1; //累加上这个位置的模式串个数,标记已经过 k=trie[k].fail; //继续跳Fail } u=trie[u].son[v]; //到下一个儿子 } return ans; } int main(){ cnt=1; //代码实现细节,编号从1开始 scanf("%d",&n); for(int i=1;i<=n;i++){ scanf("%s",s); insert(s); } getFail(); scanf("%s",s); printf("%d\n",query(s)); return 0; } ``` updata:2019/5/7 AC自动机的应用 # AC自动机的一些应用 先拿[**P3796** 【模板】AC自动机(加强版)](https://www.luogu.org/problemnew/show/P3796)来说吧。 无疑,作为**模板2**,这道题的解法也是十分的经典。 我们先来分析一下题目:输入和模板1一样 1、求出现次数最多的次数 2、求出现次数最多的模式串 明显,我们如果统计出每一个模式串在文本串出现的次数,那么这道题就变得十分简单了,那么问题就变成了如何统计每个模式串出现的次数。 **做法:AC自动机** 首先题目统计的是出现次数最多的字符串,所以有重复的字符串是没有关系的。(因为后面的会覆盖前面的,统计的答案也是一样的) 那么我们就将标记模式串的$flag$设为当前是第几个模式串。就是下面插入时的变化: ```cpp trie[u].flag++; 变为 trie[u].flag=num; //num表示该字符串是第num个输入的 ``` 求$Fail$指针没有变化,原先怎么求就怎么求。 **查询**:我们开一个数组$vis$,表示第$i$个字符串出现的次数。 因为是重复计算,所以不能标记为$-1$了。 我们每经过一个点,如果有模式串标记,就将$vis[模式串标记]++$。然后继续跳fail,原因上面说过了。 这样我们就可以将每个模式串的出现次数统计出来。剩下的大家应该都会QwQ! ### 总代码 ```cpp //AC自动机加强版 #include<bits/stdc++.h> #define maxn 1000001 using namespace std; char s[151][maxn],T[maxn]; int n,cnt,vis[maxn],ans; struct kkk{ int son[26],fail,flag; void clear(){memset(son,0,sizeof(son));fail=flag=0;} }trie[maxn]; queue<int>q; void insert(char* s,int num){ int u=1,len=strlen(s); for(int i=0;i<len;i++){ int v=s[i]-'a'; if(!trie[u].son[v])trie[u].son[v]=++cnt; u=trie[u].son[v]; } trie[u].flag=num; //变化1:标记为第num个出现的字符串 } void getFail(){ for(int i=0;i<26;i++)trie[0].son[i]=1; q.push(1);trie[1].fail=0; while(!q.empty()){ int u=q.front();q.pop(); int Fail=trie[u].fail; for(int i=0;i<26;i++){ int v=trie[u].son[i]; if(!v){trie[u].son[i]=trie[Fail].son[i];continue;} trie[v].fail=trie[Fail].son[i]; q.push(v); } } } void query(char* s){ int u=1,len=strlen(s); for(int i=0;i<len;i++){ int v=s[i]-'a'; int k=trie[u].son[v]; while(k>1){ if(trie[k].flag)vis[trie[k].flag]++; //如果有模式串标记,更新出现次数 k=trie[k].fail; } u=trie[u].son[v]; } } void clear(){ for(int i=0;i<=cnt;i++)trie[i].clear(); for(int i=1;i<=n;i++)vis[i]=0; cnt=1;ans=0; } int main(){ while(1){ scanf("%d",&n);if(!n)break; clear(); for(int i=1;i<=n;i++){ scanf("%s",s[i]); insert(s[i],i); } scanf("%s",T); getFail(); query(T); for(int i=1;i<=n;i++)ans=max(vis[i],ans); //最后统计答案 printf("%d\n",ans); for(int i=1;i<=n;i++) if(vis[i]==ans) printf("%s\n",s[i]); } } ``` update:2019/5/9 # AC自动机的优化 ## topo建图优化 让我们了分析一下刚才那个**模板2**的时间复杂度,算了不分析了,直接告诉你吧,这样暴力去跳$fail$的最坏时间复杂度是$O(模式串长度 · 文本串长度)$。 为什么?因为对于每一次跳$fail$我们都只使深度减$1$,那样深度是多少,每一次跳的时间复杂度就是多少。那么还要乘上文本串长度,就几乎是 $O(模式串长度 · 文本串长度)$的了。 那么**模板1**的时间复杂度为什么就只有$O(模式串总长)$。因为每一个$Trie$上的点都只会经过**一次**(打了标记),但**模板2**每一个点就不止经过一次了(重复算,不打标记),所以时间复杂度就爆炸了。 那么我们可不可以让**模板2**的$Trie$上每个点只经过一次呢? 嗯~,还真可以! 题目看这里:[**P5357** 【模板】AC自动机(二次加强版)](https://www.luogu.org/problemnew/show/P5357) ### 做法:拓扑排序 让我们把$Trie$上的$fail$都**想象**成一条条**有向边**,那么我们如果在一个点对那个点进行一些操作,那么沿着这个点连出去的点也会进行操作(就是跳$fail$),所以我们才要暴力跳$fail$去更新之后的点。 ![AC自动机](https://i.loli.net/2019/05/02/5ccaaa22cbf29.png) 我们还是用上面的图,举个例子解释一下我刚才的意思。 我们先找到了编号$4$这个点,编号$4$的$fail$连向编号$7$这个点,编号$7$的$fail$连向编号$9$这个点。那么我们要更新编号$4$这个点的值,同时也要更新编号$7$和编号$9$,这就是暴力跳$fail$的过程。 我们下一次找到编号$7$这个点,还要**再次**更新编号$9$,所以时间复杂度就在这里被浪费了。 那么我们可不可以在找到的点打一个标记,最后再**一次性**将标记全部上传 来 更新其他点的$ans$。例如我们找到编号$4$,在编号$4$这个点打一个$ans$标记为$1$,下一次找到了编号$7$,又在编号$7$这个点打一个$ans$标记为$1$,那么最后,我们直接从编号$4$开始跳$fail$,然后将标记$ans$上传,**((点i的fail)的ans)加上(点i的ans)**,最后使编号$4$的$ans$为$1$,编号$7$的$ans$为$2$,编号$9$的$ans$为$2$,这样的答案和暴力跳$fail$是一样的,并且每一个点只经过了**一次**。 最后我们将有$flag$标记的$ans$传到$vis$数组里,就求出了答案。 em……,建议先消化一下。 那么现在问题来了,怎么确定更新顺序呢?明显我们打了标记后肯定是从**深度大**的点开始更新上去的。 怎么实现呢?**拓扑排序!** 我们使每一个点向它的$fail$指针连一条边,明显,每一个点的**出度**为$1$($fail$只有一个),**入度**可能很多,所以我们就不需要像拓扑排序那样先建个图了,直接往$fail$指针跳就可以了。 最后我们根据$fail$指针建好图后(想象一下,程序里不用实现),一定是一个$DAG$,具体原因不解释(很简单的),那么我们就直接在上面跑拓扑排序,然后更新$ans$就可以了。 #### 代码实现: 首先是$getfail$这里,记得将$fail$的**入度**$in$更新。 ```cpp trie[v].fail=trie[Fail].son[i]; in[trie[v].fail]++; //记得加上入度 ``` 然后是$query$,不用暴力跳$fail$了,直接打上标记就行了,很简单吧 ```cpp void query(char* s){ int u=1,len=strlen(s); for(int i=0;i<len;++i) u=trie[u].son[s[i]-'a'],trie[u].ans++; //直接打上标记 } ``` 最后是拓扑,解释都在注释里了OwO! ```cpp void topu(){ for(int i=1;i<=cnt;++i) if(in[i]==0)q.push(i); //将入度为0的点全部压入队列里 while(!q.empty()){ int u=q.front();q.pop();vis[trie[u].flag]=trie[u].ans; //如果有flag标记就更新vis数组 int v=trie[u].fail;in[v]--; //将唯一连出去的出边fail的入度减去(拓扑排序的操作) trie[v].ans+=trie[u].ans; //更新fail的ans值 if(in[v]==0)q.push(v); //拓扑排序常规操作 } } ``` 应该还是很好理解的吧,实现起来也没有多难嘛! 对了还有重复单词的问题,和下面讲的"P3966[TJOI2013]单词"的解决方法一样的,不讲了吧。 # 习题讲解 基础题:[**P3966** [TJOI2013]单词](https://www.luogu.org/problemnew/show/P3966) 这道题和上面那道题没有什么不同,文本串就是将模式串用神奇的字符(例如"♂")隔起来的串。 但这道题有相同字符串要统计,所以我们用一个$Map$数组存这个字符串指的是$Trie$中的那个位置,最后把$vis[Map[i]]$输出就OK了。 下面是P5357【模板】AC自动机(二次加强版)的代码(套娃?大雾),剩下的大家怎么改应该还是知道的吧。 ```cpp #include<bits/stdc++.h> #define maxn 2000001 using namespace std; char s[maxn],T[maxn]; int n,cnt,vis[200051],ans,in[maxn],Map[maxn]; struct kkk{ int son[26],fail,flag,ans; }trie[maxn]; queue<int>q; void insert(char* s,int num){ int u=1,len=strlen(s); for(int i=0;i<len;++i){ int v=s[i]-'a'; if(!trie[u].son[v])trie[u].son[v]=++cnt; u=trie[u].son[v]; } if(!trie[u].flag)trie[u].flag=num; Map[num]=trie[u].flag; } void getFail(){ for(int i=0;i<26;i++)trie[0].son[i]=1; q.push(1); while(!q.empty()){ int u=q.front();q.pop(); int Fail=trie[u].fail; for(int i=0;i<26;++i){ int v=trie[u].son[i]; if(!v){trie[u].son[i]=trie[Fail].son[i];continue;} trie[v].fail=trie[Fail].son[i]; in[trie[v].fail]++; q.push(v); } } } void topu(){ for(int i=1;i<=cnt;++i) if(in[i]==0)q.push(i); //将入度为0的点全部压入队列里 while(!q.empty()){ int u=q.front();q.pop();vis[trie[u].flag]=trie[u].ans; //如果有flag标记就更新vis数组 int v=trie[u].fail;in[v]--; //将唯一连出去的出边fail的入度减去(拓扑排序的操作) trie[v].ans+=trie[u].ans; //更新fail的ans值 if(in[v]==0)q.push(v); //拓扑排序常规操作 } } void query(char* s){ int u=1,len=strlen(s); for(int i=0;i<len;++i) u=trie[u].son[s[i]-'a'],trie[u].ans++; } int main(){ scanf("%d",&n); cnt=1; for(int i=1;i<=n;++i){ scanf("%s",s); insert(s,i); }getFail();scanf("%s",T); query(T);topu(); for(int i=1;i<=n;++i)printf("%d\n",vis[Map[i]]); } ``` To be continue……