回文自动机学习笔记

木xx木大

2020-12-18 22:34:24

Algo. & Theory

题外话:字符串算法真的好难理解!(当然也是因为我太菜了)这篇博客的内容我用了一整天才完成,希望能写得还算清楚吧。

简要介绍

回文自动机(PAM)是一种处理回文串的工具。它的每个结点表示一个本质不同的回文串。

因为回文串长度分为奇数和偶数,为了方便处理,回文自动机由两棵树组成, 一棵树中的节点对应的回文子串长度均为奇数,另一棵树中的节点对应的回文子串长度均为偶数。

一个节点的 fail 指针指向的是这个节点所代表的回文串的最长回文真后缀所对应的节点,转移边 c 表示在当前字符串的首尾分别加一个字符 c

我们还需要在每个节点上维护此节点对应回文子串的长度 len

构建

我们称两棵树的根为奇根、偶根,分别代表长度为 -1,0 的回文串。注意:它们不表示任何实际的字符串,仅作为初始状态存在 。 偶根的 fail 指针指向奇根,而我们并不关心奇根的 fail 指针,因为奇根不可能失配(奇根转移出的下一个状态长度为 1,即单个字符,一定是回文子串)。

考虑构造完前 pos-1 个字符的回文树后,向自动机中添加在原串里位置为 pos 的字符,则可能会新产生一些以字符 s_{pos} 为结尾的回文字符串,而这些字符串可以看作往一个满足 s_{a-1}=s_{pos} 的回文字符串s[a,pos-1] 前后各加了一个字符。那么,我们从以 pos-1 结尾的最长回文子串对应的节点开始,不断沿着 fail 指针走,直到找到一个节点 p 满足 s_{pos}=s_{pos-len_p-1} ,即满足此节点所对应回文子串的上一个字符与待添加字符相同。 如果一直沿着 fail 指针走,最终会到长度为 −1 的奇根,由于 pos−(−1)−1=pos,所以这个式子最终一定会成立。这个过程即代码中的 getfail 函数。

设第一个满足如上条件的点是 p。如果 p 已经有了字符 s_{pos} 的转移,则直接增加它的 cnt(该字符串出现次数)即可。否则,新建结点 q,显然有 len_q=len_p+2 。如何求 fail_q 呢?从 fail_p 走到第一个在后面加字符 s_{pos} 仍为回文串的地方(方法同getfail),把它加字符 s_{pos} 后转移到的点作为 fail_q

P5496 【模板】回文自动机(PAM)

#include<bits/stdc++.h>
typedef long long ll;
using namespace std;
namespace FGF
{
    int n,m;
    const int N=5e5+5;
    char s[N];  
    int fai[N],cnt[N],len[N],nxt[N][30],tot,lst,pos;
    int getfail(int x)
    {
        while(s[pos-len[x]-1]!=s[pos])x=fai[x];
        return x;
    }
    void inser(char c)
    {
        int x=c-'a',p=getfail(lst);
        if(!nxt[p][x])
        {
            len[++tot]=len[p]+2;
            fai[tot]=nxt[getfail(fai[p])][x];
            nxt[p][x]=tot;
            cnt[tot]=cnt[fai[tot]]+1;
        }
        lst=nxt[p][x];
    }
    void work()
    {
        scanf("%s",s+1);n=strlen(s+1);
        s[0]='#',fai[0]=1;
        len[0]=0,len[1]=-1,tot=1;
        for(pos=1;pos<=n;pos++)
        {
            inser(s[pos]);
            s[pos+1]=(s[pos+1]-97+cnt[lst])%26+97;
            printf("%d ",cnt[lst]);
        }
    }
}
int main()
{
    FGF::work();
    return 0;
}

相关性质

定义:若 0<p\le |s|\forall1\le i\le |s|-p,s_i=s_{i+p} ,就称 ps 的周期。

下面陈列一些也许有用的性质,证明详见 oi-wiki

一些应用

应用一:一个串的本质不同回文子串个数等于回文树的状态数(排除奇根和偶根两个状态)。

应用二 :求回文子串出现次数

因为插入的时候,我们只在当前点结尾的最长回文串的结点 cnt 上加1;且如果一个串出现了,它的最长回文真后缀一定也出现了。所以我们只需要逆序枚举所有状态,将当前状态的出现次数加到其 fail 指针对应状态的出现次数上即可。

例题: P3649 [APIO2014]回文串

#include<bits/stdc++.h>
typedef long long ll;
using namespace std;
namespace FGF
{
    int n,m;
    const int N=3e5+5;
    char s[N];
    int fai[N],nxt[N][30],len[N],cnt[N],lst,tot;
    ll ans;
    int getfail(int x,int y)
    {
        while(s[y-1-len[x]]!=s[y])x=fai[x];
        return x;
    }
    void build()
    {
        s[0]='#';fai[0]=1,lst=0;
        len[0]=0,len[tot=1]=-1;
        for(int i=1;s[i];i++)
        {
            int x=s[i]-'a',p=getfail(lst,i);
            if(!nxt[p][x])
            {
                len[++tot]=len[p]+2;
                fai[tot]=nxt[getfail(fai[p],i)][x];
                nxt[p][x]=tot;
            }
            lst=nxt[p][x];cnt[lst]++;
        }
    }
    void work()
    {
        scanf("%s",s+1);
        n=strlen(s+1);
        build();
        for(int i=tot;i;i--)
            cnt[fai[i]]+=cnt[i],ans=max(ans,1LL*cnt[i]*len[i]);
        printf("%lld",ans);
    }
}
int main()
{
    FGF::work();
    return 0;
}

应用三:f 指针的引入

引入一个f 指针指向长度小于等于当前节点一半的最长回文后缀的节点,求法和 fail 指针的求法类似。

当我们新建一个节点后,如果它的长度小于等于 2,那么这个节点的 f 指针指向它的 fail 节点

否则,从它父亲的 f 指针指向的节点开始跳 fail 指针,直到跳到某一个节点所表示的回文串的两侧都能扩展这个字符且拓展后的长度小于等于当前节点长度的一半,那么新建节点的 f 的指针就指向该节点的儿子。

不同的人对这个指针的叫法不同,但本质是相同的。

例题:P4287 [SHOI2011]双倍回文

一个字符串满足双倍回文,当且仅当它的长度是 4 的倍数且它的 f 指针指向的节点所表示的回文串长度刚好是这个字符串长度的一半。枚举每个节点更新答案即可。

#include<bits/stdc++.h>
typedef long long ll;
using namespace std;
namespace FGF
{
    int n,m,ans;
    const int N=5e5+5;
    char s[N];
    int fai[N],lst,len[N],f[N],nxt[N][30],tot;
    int getfail(int x,int y)
    {
        while(s[y-len[x]-1]!=s[y])x=fai[x];
        return x;
    }
    void build()
    {
        s[0]='#';len[tot=1]=-1;fai[0]=1;
        for(int i=1;i<=n;i++)
        {
            int x=s[i]-'a',p=getfail(lst,i);
            if(!nxt[p][x])
            {
                len[++tot]=len[p]+2;
                fai[tot]=nxt[getfail(fai[p],i)][x];
                nxt[p][x]=tot;
                if(len[tot]<=2)f[tot]=fai[tot];
                else
                {
                    int tmp=f[p];
                    while(s[i-len[tmp]-1]!=s[i]||len[tmp]+2>len[tot]>>1)tmp=fai[tmp];
                    f[tot]=nxt[tmp][x];
                }
            }
            lst=nxt[p][x];
        }
    }
    void work()
    {
        scanf("%d",&n);
        scanf("%s",s+1);
        build();
        for(int i=1;i<=tot;i++)
            if(len[i]%4==0&&len[i]==len[f[i]]*2)ans=max(ans,len[i]);
        printf("%d",ans);
    }
}
int main()
{
    FGF::work();
    return 0;
}

P4762 [CERC2014]Virus synthesis

先建回文自动机,然后记 dp_i表示转移到 i 节点代表的回文串的最少的需要次数。

显然 2 操作越多越好,而经过 2 操作之后的串必定是一个回文串,所以最后的答案肯定是由一个回文串+不断暴力添加得来,所以有 ans=\min(ans,dp_i+n-len_i)

代码细节:注意赋初值和清零!

#include<bits/stdc++.h>
typedef long long ll;
using namespace std;
namespace FGF
{
    int n,m;
    const int N=1e5+5;
    char s[N];
    int val[N],f[N],fai[N],len[N],nxt[N][5],ans,dp[N],lst,tot;
    int getfail(int x,int y)
    {
        while(s[y-len[x]-1]!=s[y])x=fai[x];
        return x;
    }
    void build()
    {
        s[0]='#';len[1]=-1,tot=1,lst=0,fai[0]=1,len[0]=0;
        for(int i=1;i<=n;i++)
        {
            int x=val[(int)s[i]],p=getfail(lst,i);
            if(!nxt[p][x])
            {
                len[++tot]=len[p]+2;
                memset(nxt[tot],0,sizeof(nxt[tot]));
                fai[tot]=nxt[getfail(fai[p],i)][x];
                nxt[p][x]=tot;
                if(len[tot]<=2)f[tot]=fai[tot];
                else
                {
                    int tmp=f[p];
                    while(s[i-len[tmp]-1]!=s[i]||len[tmp]+2>len[tot]/2)tmp=fai[tmp];
                    f[tot]=nxt[tmp][x];
                }
            }
            lst=nxt[p][x];
        }
    }
    queue<int> q;
    void work()
    {
        scanf("%d",&m);
        val['A']=0,val['T']=1,val['C']=2,val['G']=3;
        while(m--)
        {
            memset(nxt[0],0,sizeof(nxt[0])),memset(nxt[1],0,sizeof(nxt[1]));
            scanf("%s",s+1);
            ans=n=strlen(s+1);
            build();
            for(int i=2;i<=tot;i++)dp[i]=len[i];
            dp[0]=1;q.push(0);
            while(q.size())
            {
                int u=q.front();q.pop();
                for(int i=0,x=nxt[u][i];i<4;i++,x=nxt[u][i])
                    if(x)
                    {
                        dp[x]=min(dp[x],min(dp[u]+1,dp[f[x]]+1+len[x]/2-len[f[x]]));
                        ans=min(ans,dp[x]+n-len[x]);
                        q.push(x);
                    }
            }
            printf("%d\n",ans);
        }
    }
}
int main()
{
    //freopen("1.in","r",stdin);
    //freopen("1.out","w",stdout);
    FGF::work();
    return 0;
}

应用四:优化dp

划重点 :这部分是本篇博客最难理解也最巧妙的地方!

问题描述:

给定一个长度为 n 的字符串,将其中若干个位置断开,要求每个分割出来的串都满足回文,询问最少需要断开几次/断开的方案数。

n\le 10^5

dp_i 表示 s 串长度为 i 的前缀的最小划分数/方案数。暴力转移是 O(n^2) 的,我们需要用上面提到的性质优化。

回文树上的每个节点 u 需要多维护两个信息:设 dif_u=len_u-len_{fail_u}sl_u 表示 u 沿着 fail 链向上跳遇到的第一个满足 dif_u\neq dif_v 的节点 v , 也就是 u 所在等差数列中长度最小的那个节点。

考虑将一个等差数列表示的所有回文串的 dp 值合并到最长的那一个回文串对应节点上。设 g_v 表示 v 所在等差数列的 dp 值之和,且 v 是这个等差数列中长度最长的节点;换句话说,就是 vsl_v 这一段上 (不包含 sl_v ,它被看做是下一条链的开头)转移位置的 dp 值的和。

假设当前枚举到第 i 个字符,回文树上对应节点为 x 。如图(图来自oi-wiki),最下方的橙色位置是增加 i 字符而产生的,g_x 为橙色三个位置的 dp 值之和。 fail_x出现的位置是 i-dif_xg_{fail_x} 包含的 dp 值是蓝色位置。 g_x 实际上等于 g_{fail_x}和多出来一个位置的 dp 值之和,多出来的位置是 i-(len_{sl_x}+dif_x) 。最后再用 g_x 去更新 dp_i ,这部分等差数列的贡献就计算完毕了,不断跳 sl_x ,重复这个过程即可。

考虑优化后的复杂度。回顾一下上面提到过的性质:

根据这个结论,如果按 sl 指针向上跳的话,最多向上跳 \log |s| 次。因此,这么做的复杂度为 O(n\log n)

感觉我讲的不是很清楚。如果没看懂建议参考 oi-wiki 和这位巨佬的博客。(讲得太棒啦!证明也很清楚orz)

例题:CF932G Palindrome Partition

a=s[1]s[n]s[2]s[n-1]\dots s[\dfrac{n}{2}-1]s[\dfrac{n}{2}]。那么问题等价于求 a 的偶回文串划分数。只在偶数位置更新dp数组即可。

#include<bits/stdc++.h>
typedef long long ll;
using namespace std;
namespace FGF
{
    int n,m;
    const int N=1e6+5;
    const int mo=1e9+7;
    char s[N],a[N];
    int dif[N],sl[N],nxt[N][30],fai[N],len[N],tot,lst;
    ll g[N],dp[N];
    int getfail(int x,int y)
    {
        while(a[y-len[x]-1]!=a[y])x=fai[x];
        return x;
    }
    void inser(char c,int y)
    {
        int x=c-'a',p=getfail(lst,y);
        if(!nxt[p][x])
        {
            len[++tot]=len[p]+2;
            fai[tot]=nxt[getfail(fai[p],y)][x];
            nxt[p][x]=tot;
            dif[tot]=len[tot]-len[fai[tot]];
            if(dif[tot]==dif[fai[tot]])sl[tot]=sl[fai[tot]];
            else sl[tot]=fai[tot];
        }
        lst=nxt[p][x];
    }
    void work()
    {
        scanf("%s",s+1);
        n=strlen(s+1);
        if(n&1)
        {
            puts("0");
            return;
        }
        for(int i=1;i<=n;i+=2)a[i]=s[(i+1)/2];
        reverse(s+1,s+n+1);
        for(int i=2;i<=n;i+=2)a[i]=s[(i+1)/2];
        dp[0]=1;
        s[0]='#',len[1]=-1,fai[0]=1,tot=1;
        for(int i=1;i<=n;i++)
        {
            inser(a[i],i);
            for(int x=lst;x;x=sl[x])
            {
                g[x]=dp[i-dif[x]-len[sl[x]]];
                if(dif[x]==dif[fai[x]])g[x]=(g[x]+g[fai[x]])%mo;
                if(i%2==0)dp[i]=(dp[i]+g[x])%mo;
            }
        }
        printf("%lld",dp[n]);
    }
}
int main()
{
    FGF::work();
    return 0;
}

再丢个题:CF906E Reverses。好像也是要这样优化来做的,以后有时间再补吧

upd:

补了。令 s=a_1b_1a_2b_2\dots,问题转化为求 s 的最小偶回文划分,转移的时候记一下路径即可。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
namespace FGF
{
    int n,m;
    const int N=1e6+5;
    char s[N],a[N],b[N];
    int dif[N],sl[N],tot,len[N],dp[N],fai[N],lst,g[N],pre[N],ans[N],nxt[N][26];
    int getfail(int x,int y)
    {
        while(s[y-len[x]-1]!=s[y])x=fai[x];
        return x;
    }
    void inser(int x,int y)
    {
        int pos=getfail(lst,y);
        if(!nxt[pos][x])
        {
            len[++tot]=len[pos]+2;
            fai[tot]=nxt[getfail(fai[pos],y)][x];
            nxt[pos][x]=tot;
            dif[tot]=len[tot]-len[fai[tot]];
            if(dif[tot]==dif[fai[tot]])sl[tot]=sl[fai[tot]];
            else sl[tot]=fai[tot];
        }
        lst=nxt[pos][x];
    }
    void work()
    {
        scanf("%s%s",a+1,b+1);
        n=strlen(a+1);
        for(int i=1,cnt=0;i<=n;i++)
            s[++cnt]=a[i],s[++cnt]=b[i];
        n*=2;
        memset(dp,0x3f,sizeof(dp)),memset(g,0x3f,sizeof(g));
        s[0]='#',tot=1,len[1]=-1,dp[0]=0,g[0]=0,fai[0]=1;
        for(int i=1;i<=n;i++)
        {
            inser(s[i]-'a',i);
            for(int x=lst;x;x=sl[x])
            {
                g[x]=dp[i-dif[x]-len[sl[x]]],pre[x]=i-dif[x]-len[sl[x]];
                if(dif[x]==dif[fai[x]]&&g[fai[x]]<g[x])g[x]=g[fai[x]],pre[x]=pre[fai[x]];
                if(i%2==0&&g[x]+1<dp[i])dp[i]=g[x]+1,ans[i]=pre[x];
                if(i%2==0&&s[i]==s[i-1]&&dp[i-2]<dp[i])dp[i]=dp[i-2],ans[i]=i-2;
            }
        }
        if(dp[n]>=0x3f3f3f3f)
        {
            puts("-1");
            return;
        }
        printf("%d\n",dp[n]);
        for(int x=n;x;x=ans[x])
            if(x-ans[x]!=2)printf("%d %d\n",ans[x]/2+1,x/2);
    }
}
int main()
{
    FGF::work();
    return 0;
}

参考资料