题解:AT_abc322_g [ABC322G] Two Kinds of Base

rzh01014

2024-11-17 13:55:29

Solution

ABC322G

前言

MnZn第一次出题解。

题目描述

有个序列 S=(S_1,S_2,S_3,…,S_n)
定义函数 f(S,a)=\sum\limits_{i=1}^{k}S_i \cdot a^{k-i}
给定 N,X,求所有满足 F(S,a,b)=X 的三元组数量。

做法

定义 F(S,a,b) 的结果为 ansans=\sum\limits_{i=1}^{k}S_i \times(a^{k-i}-b^{k-i})
对于 a^{k}-b^{k},可以化成 (a-b)\sum\limits_{i=0}^{k-1}a^{k-i-1}b^{i} 的形式,因此可以发现:(a-b)|ans
不难发现 ans 是关于 k 指数级增长的,故 k 的范围不会很大,小于 18

证明如下:
定义 ans_ik=i 时的结果,令 sum_i=\sum\limits_{j=1}^{i}(a-b)\times s_i \times b^{j-i}
易发现 ans_i=a \times ans_{i-1}-sum_i

因为本题的 k 的数据范围不大,因此可以分类讨论 k 的范围。

因此该问题是由 k=2k\geq 3 的情况加起来,总复杂度在 O(X\log^2X) 级别。

Code

#include <bits/stdc++.h>
#define int long long
typedef long long ll;
using namespace std;
const int N=2e5+5,mod=998244353;
int n;
ll ans=0,x;
inline ll ksm(int x,ll y) {
    ll ret=1;
    while(y) {/*这里不能取模,此时是求a^k-b^k是否符合要求,若取模了会使答案不正确*/
        if(y&1) {
            ret=ret*x;
        }
        x=x*x;
        y>>=1;
    }
    return ret;
}
signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    cin>>n>>x;
    for(int s=1; s<=sqrt(x); s++) {/*k=2的情况*/
        if(x%s) continue;
        if(s<10) {
            for(int a=s+1+(x/s); a<=n; a++) {
                int b=a-(x/s);
                if(b>=10) break;
                ans=(ans+b)%mod;
            }
            if(n-(x/s)>=10) ans=(ans+10*(n-(x/s)-9)%mod)%mod;
        }
        if(s*s!=x&&x/s<10) {
            for(int a=x/s+s+1; a<=n; a++) {
                int b=a-s;
                if(b>=10) break;
                ans=(ans+b)%mod;
            }
            if(n-s>=10) ans=(ans+10*(n-s-9)%mod)%mod;
        }
    }
    for(int k=3; k<=18; k++) {/*枚举k>2*/
        for(int a=1; a<=n; a++) {
            if(ksm(a,k-1)-ksm(a-1,k-1)>x) break;
            for(int s=1; s<a; s++) {
                int b=a-s;
                if(ksm(a,k-1)-ksm(b,k-1)>x) break;
                int v=x,flag=1;
                for(int i=1; i<k; i++) {
                    int noww=ksm(a,k-i)-ksm(b,k-i);
                    if(v/noww>=min(10ll,min(a,b))) {
                        flag=0;
                        break;
                    }
                    v=(v%noww);
                }
                if(flag&&!v) ans=(ans+min(10ll,min(a,b)))%mod;
            }
        }
    }
    cout<<ans;
    return 0;
}