【2021牛客寒假第五场】C-比武招亲(下)欧拉降幂+多项式求逆预处理伯努利数计算等幂求和

传送门:
https://ac.nowcoder.com/acm/contest/9985/C

这应该是我打的最好的一场了。
在这里插入图片描述

前置技能

\red{欧拉降幂、多项式求逆、伯努利数、等幂求和}

题意

n

[

1

,

m

]

给一个长度为n的序列,往里面填数字,数字范围为[1,m],可以重复填。

n[1,m]

m

=

n

n

n

.

.

m=n^{n^{n^{..}}}

m=nnn..

a

n

s

%

p

p

请计算所有序列贡献和ans \% p,p一定是一个质数。

ans%pp

思路

m

=

n

n

n

.

.

p

看到m=n^{n^{n^{..}}}这个式子,很多人就知道模p之后是一个定值,根据欧拉降幂。

m=nnn..p

具体请看

欧拉降幂

m

好,我们算出来m之后,到这里还是非常简单的,难点在后面。

m

n

m

便

n

m

因为n和m我们都知道,为了方便起来,我们把n和m互换一下。

nm便nm

n

m

n表示可以用的数字,m表示位置个数。

nm

B

我们还像B题一样,计算每一个最大值的贡献值。

B

x

m

1

x

假设最大值为x,则剩下m-1个位置填的数不能超过x,那么这么的序列一共有

xm1x

x

m

(

x

1

)

m

x^m-(x-1)^m

xm(x1)m

x

x

x

就是存在x数字的减掉不存在x数字的个数,并且一定有x。

xxx

m

x

m

有人问,为什么不是m*x^m,因为序列中是有可重复数字的。

mxm

m

=

3

,

x

=

2

2

1

x

2
  

2
  

1

2
  

2
  

1

比如m=3,x=2,剩下为2和1,把x放在第一个得2\;2\;1,放在第二个得2\;2\;1。这是不对的。

m=3,x=221x221221

但是比赛中我没有直接看出来,还是用二项式定理推出来,好菜!

:

根据计算出来的贡献,得到最终公式:

:

i

=

1

n

i

[

i

m

(

i

1

)

m

]
  

m

o

d
  

p

\sum_{i=1}^ni*[i^m-(i-1)^m]\;mod\;p

i=1ni[im(i1)m]modp

n

1

e

9

O

(

n

)

由于n是我们欧拉降幂求出来的,最大也会有\red{1e9}这么多,所以O(n)肯定是不行的。

n1e9O(n)

那该怎么办呢?这就是这题最难的地方了,伯努利数就诞生了。


https://oi-wiki.org/math/bernoulli/

因为\red{伯努利数一大应用就是等幂求和}。

O

I

W

i

k

i

O

(

n

2

)

O

(

n

l

o

g

n

)

根据OI-Wiki上介绍,暴力求复杂度为\red{O(n^2)},不过\red{多项式求逆}可以达到\red{O(nlogn)}

OIWikiO(n2)O(nlogn)

n

1

e

5

O

(

n

l

o

g

n

)

这里的n不是数字个数,而是幂次。所以为什么说长度为1e5,就是为O(nlogn)准备的。

n1e5O(nlogn)

O

(

n

)

O

(

n

l

o

g

n

)

最后等幂求和O(n)处理,最终复杂度为O(nlogn)。完全可以接受。

O(n)O(nlogn)

所以这题整个步骤就是

m

O

(

n

l

o

g

n

)

O

(

n

)

\red{欧拉降幂求m,多项式求逆O(nlogn)处理伯努利数,伯努利数O(n)处理等幂求和。}

mO(nlogn)O(n)

Code(715MS)

#include "bits/stdc++.h"

using namespace std;

typedef long long ll;
typedef long double ld;

const double eps = 1e-6;
const double PI = acos(-1);

const int N = 1e6 + 10, M = 1e5 + 10;

struct Complex {
    double x, y;
    Complex(double a = 0, double b = 0): x(a), y(b) {}
    Complex operator + (const Complex &rhs) { return Complex(x + rhs.x, y + rhs.y); }
    Complex operator - (const Complex &rhs) { return Complex(x - rhs.x, y - rhs.y); }
    Complex operator * (const Complex &rhs) { return Complex(x * rhs.x - y * rhs.y, x * rhs.y + y * rhs.x); }
    Complex conj() { return Complex(x, -y); }
} w[N];

ll mod;
int tr[N];
ll F[N], G[N];

ll quick_pow(ll a, ll b, ll p) {
    ll ans = 1;
    while(b) {
        if(b & 1) ans = ans * a % p;
        a = a * a % p;
        b >>= 1;
    }
    return ans % p;
}

void FFT(Complex *A, int len) {
    for (int i = 0; i < len; i++) if(i < tr[i]) swap(A[i], A[tr[i]]);
    for (int i = 2, lyc = len >> 1; i <= len; i <<= 1, lyc >>= 1)
        for (int j = 0; j < len; j += i) {
            Complex *l = A + j, *r = A + j + (i >> 1), *p = w;
            for (int k = 0; k < i >> 1; k++) {
                Complex tmp = *r * *p;
                *r = *l - tmp, *l = *l + tmp;
                ++l, ++r, p += lyc;
            }
        }
}

inline void MTT(ll *x, ll *y, ll *z, int n) {
    int len = 1; while (len <= n) len <<= 1;
    for (int i = 0; i < len; i++) tr[i] = (tr[i >> 1] >> 1) | (i & 1 ? len >> 1 : 0);
    for (int i = 0; i < len; i++) w[i] = w[i] = Complex(cos(2 * PI * i / len), sin(2 * PI * i / len));

    for (int i = 0; i < len; i++) (x[i] += mod) %= mod, (y[i] += mod) %= mod;
    static Complex a[N], b[N];
    static Complex dfta[N], dftb[N], dftc[N], dftd[N];

    for (int i = 0; i < len; i++) a[i] = Complex(x[i] & 32767, x[i] >> 15);
    for (int i = 0; i < len; i++) b[i] = Complex(y[i] & 32767, y[i] >> 15);
    FFT(a, len), FFT(b, len);
    for (int i = 0; i < len; i++) {
        int j = (len - i) & (len - 1);
        static Complex da, db, dc, dd;
        da = (a[i] + a[j].conj()) * Complex(0.5, 0);
        db = (a[i] - a[j].conj()) * Complex(0, -0.5);
        dc = (b[i] + b[j].conj()) * Complex(0.5, 0);
        dd = (b[i] - b[j].conj()) * Complex(0, -0.5);
        dfta[j] = da * dc;
        dftb[j] = da * dd;
        dftc[j] = db * dc;
        dftd[j] = db * dd;
    }
    for (int i = 0; i < len; i++) a[i] = dfta[i] + dftb[i] * Complex(0, 1);
    for (int i = 0; i < len; i++) b[i] = dftc[i] + dftd[i] * Complex(0, 1);
    FFT(a, len), FFT(b, len);
    for (int i = 0; i < len; i++) {
        int da = (ll)(a[i].x / len + 0.5) % mod;
        int db = (ll)(a[i].y / len + 0.5) % mod;
        int dc = (ll)(b[i].x / len + 0.5) % mod;
        int dd = (ll)(b[i].y / len + 0.5) % mod;
        z[i] = (da + ((ll)(db + dc) << 15) + ((ll)dd << 30)) % mod;
    }
}

int getLen(int n) {
    int len = 1; while (len < (n << 1)) len <<= 1;
    for (int i = 0; i < len; i++) tr[i] = (tr[i >> 1] >> 1) | (i & 1 ? len >> 1 : 0);
    for (int i = 0; i < len; i++) w[i] = w[i] = Complex(cos(2 * PI * i / len), sin(2 * PI * i / len));
    return len;
}

void Get_Inv(ll *f, ll *g, int n) {
    if(n == 1) { g[0] = quick_pow(f[0], mod - 2, mod); return ; }
    Get_Inv(f, g, (n + 1) >> 1);
    int len = getLen(n);
    static ll c[N];
    for(int i = 0;i < len; i++) c[i] = i < n ? f[i] : 0;
    MTT(c, g, c, len); MTT(c, g, c, len);
    for(int i = 0;i < n; i++) g[i] = (2ll * g[i] - c[i] + mod) % mod;
    for(int i = n;i < len; i++) g[i] = 0;
    for(int i = 0;i < len; i++) c[i] = 0;
}

ll ff[N], invff[N], inv[N];
ll B[N];

void Init() {
    ff[0] = ff[1] = inv[0] = inv[1] = invff[0] = invff[1] = 1;
    for(int i = 2;i < M; i++)
    {
        ff[i] = ff[i - 1] * i % mod;
        inv[i] = mod - (mod / i) * inv[mod % i] % mod;
        invff[i] = invff[i - 1] * inv[i] % mod;
    }
}

ll C(ll m, ll n) {
    if(m < 0 || n < 0 || n > m)
        return 0;
    ll ans = ff[m];
    ans = ans * invff[n] % mod;
    ans = ans * invff[m - n] % mod;
    return ans;
}

void init_B(int m) {
    for(int i = 0;i <= m + 10; i++) F[i] = invff[i + 1];
    Get_Inv(F, G, m + 10);
    for(int i = 0;i <= m + 10; i++) B[i] = G[i] * ff[i] % mod;
}

ll n;

ll ph(ll x) {
    ll res = x,a = x;
    for(ll i = 2;i * i <= x; i++) {
        if(a % i == 0) {
            res=res / i * (i - 1);
            while(a % i == 0) a /= i;
        }
    }
    if(a > 1) res = res / a * (a - 1);
    return res;
}

ll f(ll p) {
    if(p == 1) return 0;
    ll k = ph(p);
    return quick_pow(n, f(k) + k, p);
}

void solve() {
    cin >> n >> mod;
    Init();
    ll k = f(mod);
    ll m = n; n = k;
    init_B(m);
    ll ans = 0, prod = n % mod;
    for(int i = m; ~i ; i--) {
        ans = (ans + prod * B[i] % mod * C(m + 1, i) % mod) % mod;
        prod = prod * n % mod;
    }
    ans = ans * quick_pow(m + 1, mod - 2, mod) % mod;
    ans = (n * quick_pow(n, m, mod) % mod - ans) % mod;
    cout << (ans % mod + mod) % mod << endl;
}

signed main() {
    solve();
}

版权声明:本文为fztsilly原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
原文链接:https://blog.csdn.net/fztsilly/article/details/113947241