记忆化搜索——最好写的数位 dp

Brilliant11001

2024-08-29 16:22:03

Algo. & Theory

简介

数位 dp 解决的是与数字有关的一类计数问题,在求解过程中常把一个数字的每一位都拆开来看,比如十进制下就是把千位、百位、十位、个位上的数字都拆开来看,其他进制类比十进制。

数位 dp 的问题一般比较显眼,有几个常见形式:

  1. 要求统计满足一定条件的数的数量(即,最终目的为计数);

  2. 这些条件经过转化后可以使用「数位」的思想去理解和判断;

  3. 输入会提供一个数字区间(有时也只提供上界)来作为统计的限制;

  4. 上界很大(比如 10^{18}),暴力枚举验证会超时。

(from OI Wiki)

在数位 dp 的实现上,我通常采用的是记忆化搜索,这样写不仅容易,而且易于拓展,还可以当板子来背,这已经是 dp 中少见的了。

例题

P2657 [SCOI2009] windy 数

题目大意

[l, r] 内有多少个数十进制表示下所有的相邻数位数值之差大于等于 2

思路

考虑从最高位开始填数,在记忆化搜索时记录 pos 表示当前填到第几位,pre\_num 表示上一个位置填的数是什么,limit 记录前面放的数是否顶上界,zero 记录当前这位之前是否是前导零。

先把上界的每一位抠出来,那么当搜索放第 i 位时,要先确定这一位能放什么数,若前面都是贴着上界放的,那么这一位最多只能放 num_{pos},否则就不受限制。

然后在枚举第 i 位放什么时还要满足相邻数位数值之差大于等于 2 的限制,这个很好转移。

当然,若是前导零的话还要特别注意,因为这时 0\sim 9 都可以放,而如果没有考虑到这一点,最高位就只能至少放 2 了。

在加记忆化时还要注意,若出现了顶上界或前导零的情况是不能记忆化的(当然你也可以多开两维来额外存,不过我觉得没什么必要,这个时候直接暴力搜索就行了,反正也费不了多少时间)。

\texttt{Code:}
#include <cmath>
#include <vector>
#include <cstring>
#include <iostream>

using namespace std;

const int N = 15;

int l, r;
int f[N][N];
vector<int> num;

int dfs(int pos, int pre_num, bool limit, bool zero) {
    if(pos < 0) return 1; //边界,若放完了最后一位就返回 1,因为我们一直是按要求放的,所以此时也是一种情况
    if(!limit && pre_num >= 0 && f[pos][pre_num] != -1) return f[pos][pre_num]; //记忆化
    int mx = (limit ? num[pos] : 9); //计算上界
    int res = 0;
    for(int i = 0; i <= mx; i++) {
        if(abs(i - pre_num) < 2) continue;
        if(!i && zero) //特判前导零的情况,这时 prenum 设为 -2 确保下一位不受任何限制
            res += dfs(pos - 1, -2, limit && (i == num[pos]), 1);
        else 
            res += dfs(pos - 1, i, limit && (i == num[pos]), 0);
    }
    if(!limit && !zero) f[pos][pre_num] = res;
    return res;
}

int calc(int x) {
    num.clear();
    int tmp = x;
    //先把上界的每一位抠出来
    while(tmp) {
        num.push_back(tmp % 10);
        tmp /= 10;
    }
    //初始化
    memset(f, -1, sizeof f);
    return dfs(num.size() - 1, -2, 1, 1);
}

int main() {
    scanf("%d%d", &l, &r);
    //数位 dp 通常都有这种类似前缀和的形式
    printf("%d\n", calc(r) - calc(l - 1));
    return 0;
}

P2602 [ZJOI2010] 数字计数

题目大意

[l, r] 中的数在十进制表示下 0\sim 9 各个数码分别出现了多少次。

思路

对每个数码分开计算,还是拆成 r 减去 l - 1 的形式。

在记忆化搜索时记录一下当前考虑的数码 d 填了多少次,在所有位填完后再计算即可。

后面都只放主要代码了,因为真的很板子。

\texttt{Code:}
ll dfs(int pos, ll cnt, bool limit, bool zero, int d) {
    if(pos < 0) return cnt;
    if(!limit && !zero && f[pos][cnt] != -1) return f[pos][cnt];
    int mx = (limit ? num[pos] : 9);
    ll res = 0;
    for(int i = 0; i <= mx; i++)
        res += dfs(pos - 1, cnt + ((!zero || i) && (i == d)), limit && (i == num[pos]), zero && (!i), d);
    if(!limit && !zero) f[pos][cnt] = res;
    return res;
}

Digit Sum

题目大意

[1, N] 中有多少个数在十进制表示下数码和是 D 的倍数。

数据范围:1\le N\le 10^{10000},1\le D\le 100

思路

很明显的数位 dp。

首先把上界 N 的每一位抠出来,然后进行填数,个人喜欢从最高位开始填。

加上记忆化,设 f(pos, r) 表示在没有顶上界和前导零的情况下,当前填到了第 pos 位,余数为 r 的数的个数。

然后在搜索过程中记一下当前数位和 \bmod p 等于多少,再简单转移一下即可,详细注释在代码中。

这里再讲一下数位 dp 如何分析时间复杂度。

注意到状态数为 D\cdot\lg N,每次转移时最多枚举 10 个可填的数,所以时间复杂度为 O(D\cdot \lg N),可以通过此题。

注意!由于最后要 -1,所以为防止减为负数要先加上模数再取模。

\texttt{Code:}
ll dfs(int pos, int r, bool limit, bool zero) {
    if(pos < 0) return (r == 0);
    if(!limit && !zero && f[pos][r] != -1) return f[pos][r];
    int mx = (limit ? num[pos] : 9);
    ll res = 0;
    for(int i = 0; i <= mx; i++)
        res = (res + dfs(pos - 1, (r + i) % d, limit && (i == num[pos]), zero && (!i))) % mod;
    if(!limit && !zero) f[pos][r] = res;
    return res;
}

P4127 [AHOI2009] 同类分布

题目大意

求出 [l, r] 中各位数字之和能整除原数的数的个数。

思路

若是要求整除的数是同一个数,那就和上一题一样,但若除的数都不一样该怎么办?

那我们就换一种思路,直接枚举数位和,然后在搜索时每填一个数就相应地减去,同时记录一下余数,其他的参数照搬即可。

由于最多有 18 位,所以要枚举 1\sim 162

\texttt{Code:}
ll dfs(int pos, int sum, int r, bool limit, bool zero) {
    if(sum < 0) return 0;
    if(pos < 0) return !sum && !r;
    if(!limit && !zero && f[pos][sum][r] != -1) return f[pos][sum][r];
    int mx = (limit ? num[pos] : 9);
    ll res = 0;
    for(ll i = 0; i <= mx; i++)
        res += dfs(pos - 1, sum - i, (r * 10 + i) % d, limit && (i == num[pos]), zero && (!i));
    if(!limit && !zero) f[pos][sum][r] = res;
    return res;
}

P10958 启示录

题目大意:

#### 思路: 考虑数位 dp。 一般数位 dp 问题有两种常见形式: 1. 询问 $[l, r]$ 内有多少个符合条件的数; 2. 询问满足条件的第 $k$ 大(小)的数是什么。 很显然这道题是第二种形式。 首先问题 $1$ 很简单,那我们考虑将第二个问题转化成第一个问题来做。 因为答案具有单调性,于是可以二分判定。 每次二分到一个值 $mid$,计算 $[1, mid]$ 的魔鬼数个数,若大于等于 $x$,则说明所求在 $mid$ 左侧,否则在 $mid$ 右侧。 接着考虑问题 $1$,这里采用记忆化搜索的方式,注释在代码中。 ```cpp int dfs(int pos, int cnt, bool flag, bool limit) { if(pos < 0) return flag; if(!limit && f[pos][cnt][flag] != -1) return f[pos][cnt][flag]; int mx = (limit ? num[pos] : 9); int res = 0; for(int i = 0; i <= mx; i++) { int ncnt; if(i == 6) ncnt = cnt + 1; else ncnt = 0; res += dfs(pos - 1, ncnt, flag || (ncnt >= 3), limit && (i == num[pos])); } if(!limit) f[pos][cnt][flag] = res; return res; } ``` 这里我直接把二分值域拉满了,但是实测发现第 $50000000$ 个魔鬼数只有 $6668056399$。 时间复杂度为:$O(N^2MT\log V)$,这里 $N$ 表示数字位数,$V$ 表示二分值域,$M$ 表示每次枚举填的数的个数,可看作 $10$。 ## 拓展: 数位 dp 一般会与 Lucas 定理一起食用,毕竟 Lucas 定理就是逐位求组合数。 ### [P7976 「Stoi2033」园游会](https://www.luogu.com.cn/problem/P7976) ### 题目大意: 设函数 $F(x) := (x + 1) \bmod 3 − 1$,$T$ 次询问,计算: $$\sum\limits_{i = 0}^{n}\sum\limits_{j}F\left({i\choose j}\right)$$ ### 思路: 看到奇奇怪怪的组合数求和首先考虑 $\text{Lucas}$,将原数在 $3$ 进制下拆位,得: $${i\choose j} = \prod\limits_{k = 1}^{m}{i_k\choose j_k}\bmod 3$$ 其中 $m$ 表示 $i$ 和三进制下较长的那个数的数字位数。 接着注意到 $F$ 函数是一个**积性函数**(这个可以分九类讨论证明),即 $F(xy) = F(x)F(y)$,所以实际上 $F\left({i\choose j}\right)$ 要计算的就是所有 $F\left({i_k\choose j_k}\right)$ 的乘积。 对于每一个 $i$,$j$ 的每一位就独立了,这时候再分类讨论: 1. 当 $i_k = 0$ 时,$j_k$ 取 $0$ 的时候有贡献,此时这一位的值为 $F(1) = 1$; 2. 当 $i_k = 1$ 时,$j_k$ 取 $0,1$ 的时候有贡献,此时这一位的值为 $F(1) + F(1) = 2$; 3. 当 $i_k = 2$ 时,$j_k$ 取 $0,1,2$ 的时候有贡献,此时这一位的值为 $F(1) + F(2) + F(1) = 1$。 而乘 $1$ 是不会是答案增加的,所以只用考虑乘 $2$ 的个数就行了,即: $$\sum\limits_{j}F\left(i\choose j\right) = \left(\prod\limits_{i_k = 0}1\right)\left(\prod\limits_{i_k = 1}2\right)\left(\prod\limits_{i_k = 2}1\right) = 2^{\#\{i_k = 1\}}$$ 然后就会发现统计一下三进制表示下 $1$ 的个数就行了,数位 dp 即可。 $\texttt{Code:}
#include <cmath>
#include <vector>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;
typedef long long ll;

const int N = 65, mod = 1732073999;

int T;
ll vmax, n;
vector<int> num;
ll f[N][N];

ll qpow(int a, int b) {
    ll ans = 1, base = a;
    while(b) {
        if(b & 1) ans = ans * base % mod;
        base = base * base % mod;
        b >>= 1;
    }
    return ans;
}

ll dfs(int pos, int cnt, bool limit, bool zero) {
    if(pos < 0) return qpow(2, cnt);
    if(!limit && !zero && ~f[pos][cnt]) return f[pos][cnt];
    int mx = (limit ? num[pos] : 2);
    ll res = 0;
    for(int i = 0; i <= mx; i++)
        res = (res + dfs(pos - 1, cnt + (i == 1), limit && (i == num[pos]), zero && (!i))) % mod;
    if(!limit && !zero) f[pos][cnt] = res;
    return res;
}

ll calc(ll x) {
    num.clear();
    ll tmp = x;
    while(tmp) {
        num.push_back(tmp % 3);
        tmp /= 3;
    }
    return dfs(num.size() - 1, 0, 1, 1);
}

int main() {
    scanf("%d%lld", &T, &vmax);
    memset(f, -1, sizeof f);
    while(T--) {
        scanf("%lld", &n);
        printf("%lld\n", calc(n));
    }
    return 0;
}

习题:

P8764 [蓝桥杯 2021 国 BC] 二进制问题
P6218 [USACO06NOV] Round Numbers S
P4124 [CQOI2016] 手机号码
P4317 花神的数论题
P7976 「Stoi2033」园游会
P3413 SAC#1 - 萌数
P3286 [SCOI2014] 方伯伯的商场之旅
P2481 [SDOI2010] 代码拍卖会