一.题目大意

一行有$$n$$个球,现在将这些球分成$$K$$组,每组可以有一个球或相邻两个球。一个球只能在至多一个组中。求对于$$1\le K\le m$$的所有$$K$$分别有多少种分组方法。答案对$$998244353$$取模。
$$1\le n \le 10^9$$
$$1\le K \le 2^{15}$$

二.解题报告

算法1

设$$f[i][j]$$表示把$$i$$个球分成$$j$$组的方案数。
则$$f[i][j]=f[i-1][j-1]+f[i-2][j-1]+f[i-1][j]$$
最终答案为$$f[n][1...m]$$
时间复杂度$$O(nm)$$

#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int N=5100,mod=998244353;
int n,m;
int f[N][N];
int main()
{
    //freopen("tt.in","r",stdin);
    scanf("%d%d",&n,&m);
    f[1][1]=f[1][0]=f[0][0]=1;
    for(int i=2;i<=n;i++)
    {
        f[i][0]=1;
        for(int j=1;j<=i&&j<=m;j++)
            f[i][j]=((f[i-1][j-1]+f[i-2][j-1])%mod+f[i-1][j])%mod;
    }
    for(int i=1;i<=m;i++)
        printf("%d ",f[n][i]);
    return 0;
}

算法2

设$$a(K)$$表示分为$$K$$组的答案。枚举相邻两个一组的个数$$i$$
$$a(K)=\sum\limits_{i=0}^{min(K,n-K)}C_{n-i}^K C_K^i$$

$$=\sum\limits_{i=0}^{min(K,n-K)} \frac{(n-i)!}{(n-i-K)!i!(K-i)!}$$

其中$$\frac{1}{i!}$$和$$\frac{1}{(K-i)!}$$可以通过$$O(m)$$预处理每次$$O(1)$$计算。
设$$g(i,K) = \frac{(n-i)!}{(n-i-K)!}$$
则$$g(i,K) = g(i,K-1)*(n-i-K+1)$$
对于每个$$K$$可以$$O(m)$$处理。
时间复杂度 $$O(m^2)$$

#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int N=21000,mod=998244353;
int n,m,ans;
int jc[N],njc[N],g[N];
int qpow(int x,int y)
{
    int ret=1;
    while(y)
    {
        if(y&1)ret=(ll)ret*x%mod;
        x=(ll)x*x%mod;y>>=1;
    }
    return ret;
}
void init()
{
    jc[0]=njc[0]=1;
    for(int i=1;i<=m;i++)jc[i]=(ll)jc[i-1]*i%mod;
    njc[m]=qpow(jc[m],mod-2);
    for(int i=m-1;i>=1;i--)
        njc[i]=(ll)njc[i+1]*(i+1)%mod;
}
int main()
{
    //freopen("tt.in","r",stdin);
    scanf("%d%d",&n,&m);
    init();
    for(int i=0;i<=m;i++)g[i]=1;
    for(int K=1;K<=m;K++)
    {
        ans=0;
        for(int i=0;i<=m;i++)
            g[i]=(ll)g[i]*(n-i-K+1)%mod;
        for(int i=0;i<=K&&i<=n-K;i++)
            ans=(ans+(ll)g[i]*njc[i]%mod*njc[K-i]%mod)%mod;
        printf("%d ",ans);
    }
    return 0;
}

算法3

在算法1基础上将$$f[i][0...m]$$看成一个关于x的多项式,称为$$f[i]$$
通过算法1可知如果已知$$f[n],f[n+1]$$可以$$O(m)$$求出$$f[n+2]$$
考虑在已知$$f[a],f[b],f[a-1],f[b-1]$$时计算$$f[a+b]$$
$$f[a+b]$$即为将长度为$$a$$和长度为$$b$$的两段连接起来得到的。
在连接处有两种情况,如果有一个两个球的组跨过连接处那么为$$f[a-1] * [b-1]$$(乘号为多项式卷积)
否则为$$f[a] * f[b]$$
因此$$f[a+b]=f[a-1] * f[b-1]+f[a] * f[b]$$
卷积可以通过NTT在$$O(m \log m)$$计算
因此可以在$$O(m\log m)$$通过$$f[a],f[b],f[a-1],f[b-1]$$计算$$f[a+b]$$
因此如果已知$$f[a-2],f[a-1],f[a]$$以及$$f[b-2],f[b-1],f[b]$$
可以通过$$f[a-2],f[a-1]$$,$$f[b-2],f[b-1]$$计算$$f[a+b-2]$$
通过$$f[a-1],f[a]$$,$$f[b-2],f[b-1]$$计算$$f[a+b-1]$$
通过$$f[a+b-2],f[a+b-1]$$计算$$f[a+b]$$
因此可以用倍增来计算$$f[n]$$
时间复杂度$$O(m\log n\log m)$$

#include <bits/stdc++.h>
using namespace std;
#define N (1<<16)+10
#define mod 998244353
#define ll long long
int len,n,m,flag;
int qpow(int x,int y)
{
    int ret=1;
    while(y)
    {
        if(y&1)ret=(ll)ret*x%mod;
        x=(ll)x*x%mod;y>>=1;
    }
    return ret;
}
void NTT(int *a,int len,int type)
{
    for(int i=0,t=0;i<len;i++)
    {
        if(i>t)swap(a[i],a[t]);
        for(int j=len>>1;(t^=j)<j;j>>=1);
    }
    for(int i=2;i<=len;i<<=1)
    {
        int wn=qpow(3,(mod-1)/i);
        for(int j=0;j<len;j+=i)
        {
            int w=1,t;
            for(int k=0;k<i>>1;k++,w=(ll)w*wn%mod)
            {
                t=(ll)a[j+k+(i>>1)]*w%mod;
                a[j+k+(i>>1)]=(a[j+k]-t+mod)%mod;
                a[j+k]=(a[j+k]+t)%mod;
            }
        }
    }
    if(type==-1)
    {
        for(int i=1;i<=len>>1;i++)swap(a[i],a[len-i]);
        int t=qpow(len,mod-2);
        for(int i=0;i<len;i++)a[i]=(ll)a[i]*t%mod;
    }
}
struct node
{
    int a[N],b[N];
    void dft()
    {
        memset(b,0,sizeof(b));
        for(int i=0;i<len>>1;i++)b[i]=a[i];
        NTT(b,len,1);
    }
}a1,a2,a3,b1,b2,b3,ans,ans1,ans2,c1,c2;
void mul(node &r1,const node &r2,const node &r3)
{
    for(int i=0;i<len;i++)
        r1.b[i]=r1.a[i]=(ll)r2.b[i]*r3.b[i]%mod;
    NTT(r1.a,len,-1);
    for(int i=len>>1;i<len;i++)r1.a[i]=0;
}
void inv(node &r1,const node &r2,const node &r3)
{
    r1.a[0]=1;
    for(int i=1;i<len>>1;i++)
        r1.a[i]=(r2.a[i]+r3.a[i-1])%mod;
    r1.dft();
}
int main()
{
    //freopen("tt.in","r",stdin);
    scanf("%d%d",&n,&m);
    for(len=1;len<(m+1)<<1;len<<=1);
    a2.a[0]=1;a3.a[0]=1;a3.a[1]=1;
    a1.dft();a2.dft();a3.dft();
    for(;n;n>>=1)
    {
        if(n&1)
        {
            if(!flag)
            {
                ans=a3;ans1=a2;ans2=a1;
                flag=1;
            }
            else
            {
                mul(c1,ans,a3);mul(c2,ans1,a2);
                inv(ans,c1,c2);
                mul(c1,ans1,a3);mul(c2,ans2,a2);
                inv(ans1,c1,c2);
                for(int i=0;i<len>>1;i++)
                    ans2.a[i]=((ans.a[i+1]-ans1.a[i+1]+mod)%mod-ans1.a[i]+mod)%mod;
                ans2.dft();
            }
        }
        mul(c1,a2,a2);mul(c2,a1,a1);
        inv(b1,c1,c2);
        mul(c1,a3,a3);mul(c2,a2,a2);
        inv(b3,c1,c2);
        b2.a[0]=1;
        for(int i=1;i<len>>1;i++)
            b2.a[i]=((b3.a[i]-b2.a[i-1]+mod)%mod-b1.a[i-1]+mod)%mod;
        b2.dft();
        a1=b1;a2=b2;a3=b3;
    }
    for(int i=1;i<=m;i++)
        printf("%d ",ans.a[i]);
    return 0;
}

算法4

由算法3知$$f[i]=(x+1)f[i-1]+x f[i-2]$$
求$$f[i]$$的通项公式:
设$$f[i]=C_1 T_1(x)^i+C_2 T_2(x)^i$$
其中$$T_1(x),T_2(x)$$为关于$$T(x)$$的方程$$T(x)^2=(x+1)T(x)+x$$的两根。

解得$$
\left\{
\begin{aligned}
T_1(x)=\frac{x+1+\sqrt{x^2+6x+1}}{2}\\
T_2(x)=\frac{x+1-\sqrt{x^2+6x+1}}{2}
\end{aligned}
\right.
$$

将$$f[0]=1,f[1]=1+x$$带入

解得$$
\left\{
\begin{aligned}
C_1=\frac{T_1(x)}{T_1(x)-T_2(x)}\\
C_2=\frac{T_2(x)}{T_2(x)-T_1(x)}
\end{aligned}
\right.
$$

因此$$f[i]=\frac{T_1(x)^{i+1}-T_2(x)^{i+1}}{T_1(x)-T_2(x)}$$

$$ = \frac{(\frac{x+1+\sqrt{x^2+6x+1}}{2})^{n+1}-(\frac{x+1-\sqrt{x^2+6x+1}}{2})^{n+1}}{\sqrt{x^2+6x+1}}$$

其中$$(\frac{x+1-\sqrt{x^2+6x+1}}{2})^{n+1}$$最低次项大于$$n$$,因此可以舍掉

通过多项式开根,多项式求逆可以$$O(m\log m)$$计算$$\frac{x+1+\sqrt{x^2+6x+1}}{2}$$以及$$\frac{1}{\sqrt{x^2+6x+1}}$$
通过多项式exp以及多项式取ln可以$$O(m\log m)$$计算一个多项式的$$n+1$$次幂
时间复杂度$$O(m\log m)$$

#include <bits/stdc++.h>
using namespace std;
#define N (1<<16)+10
#define ll long long
#define mod 998244353
const int inv2=499122177;
int n,m,len;
int tmp[N],sq[N],inv_sq[N],rt1[N],ln1[N],a[N],ans[N],inv[N];
int test1[N],test2[N],test3[N];
int qpow(int x,int y)
{
    int ret=1;
    while(y)
    {
        if(y&1)ret=(ll)ret*x%mod;
        x=(ll)x*x%mod;y>>=1;
    }
    return ret;
}
void NTT(int *a,int len,int type)
{
    for(int i=0,t=0;i<len;i++)
    {
        if(i<t)swap(a[i],a[t]);
        for(int j=len>>1;(t^=j)<j;j>>=1);
    }
    for(int i=2;i<=len;i<<=1)
    {
        int wn=qpow(3,(mod-1)/i);
        for(int j=0;j<len;j+=i)
        {
            int w=1,t;
            for(int k=0;k<i>>1;k++,w=(ll)w*wn%mod)
            {
                t=(ll)a[j+k+(i>>1)]*w%mod;
                a[j+k+(i>>1)]=(a[j+k]-t+mod)%mod;
                a[j+k]=(a[j+k]+t)%mod;
            }
        }
    }
    if(type==-1)
    {
        for(int i=1;i<len>>1;i++)swap(a[i],a[len-i]);
        int t=qpow(len,mod-2);
        for(int i=0;i<len;i++)a[i]=(ll)a[i]*t%mod;
    }
}
////////////////
void test_root(int *a,int len)
{
    memset(test1,0,sizeof(test1));
    for(int i=0;i<len;i++)
        test1[i]=a[i];
    NTT(test1,len<<1,1);
    for(int i=0;i<len<<1;i++)
        test1[i]=(ll)test1[i]*test1[i]%mod;
    NTT(test1,len<<1,-1);
    for(int i=0;i<len;i++)
        printf("#%d ",test1[i]);
    puts("");
}
void test_inv(int *a,int *b,int len)
{
    memset(test1,0,sizeof(test1));
    memset(test2,0,sizeof(test2));
    for(int i=0;i<len;i++)
        test1[i]=a[i],test2[i]=b[i];
    NTT(test1,len<<1,1);
    NTT(test2,len<<1,1);
    for(int i=0;i<len<<1;i++)
        test3[i]=(ll)test1[i]*test2[i]%mod;
    NTT(test3,len<<1,-1);
    for(int i=0;i<len;i++)
        printf("#%d ",test3[i]);
    puts("");
}
////////////////
void get_inv(int *a,int *b,int len)
{
    static int tmp[N];
    if(len==1)
    {
        b[0]=qpow(a[0],mod-2);
        return;
    }
    get_inv(a,b,len>>1);
    for(int i=0;i<len;i++)tmp[i]=a[i];
    for(int i=len;i<len<<1;i++)tmp[i]=0;
    NTT(tmp,len<<1,1);
    NTT(b,len<<1,1);
    for(int i=0;i<len<<1;i++)
        b[i]=(ll)b[i]*(2-(ll)b[i]*tmp[i]%mod+mod)%mod;
    NTT(b,len<<1,-1);
    for(int i=len;i<len<<1;i++)b[i]=0;
}
void get_root(int *a,int *b,int len)
{
    static int invb[N],tmp[N];
    if(len==1){b[0]=1;return;}
    get_root(a,b,len>>1);
    for(int i=0;i<len<<1;i++)invb[i]=0;
    get_inv(b,invb,len);
    for(int i=0;i<len;i++)tmp[i]=a[i];
    for(int i=len;i<len<<1;i++)tmp[i]=0;
    NTT(tmp,len<<1,1);
    NTT(b,len<<1,1);
    NTT(invb,len<<1,1);
    for(int i=0;i<len<<1;i++)
        b[i]=(ll)inv2*(b[i]+(ll)tmp[i]*invb[i]%mod)%mod;
    NTT(b,len<<1,-1);
    for(int i=len;i<len<<1;i++)b[i]=0;
}
void get_ln(int *a,int *b,int len)
{
    static int inva[N],a1[N];
    for(int i=0;i<len<<1;i++)inva[i]=0;
    get_inv(a,inva,len);
    for(int i=0;i<len;i++)a1[i]=(ll)(i+1)*a[i+1]%mod;
    for(int i=len;i<len<<1;i++)a1[i]=0;
    NTT(a1,len<<1,1);
    NTT(inva,len<<1,1);
    for(int i=0;i<len<<1;i++)a1[i]=(ll)a1[i]*inva[i]%mod;
    NTT(a1,len<<1,-1);
    b[0]=0;
    for(int i=1;i<len;i++)
        b[i]=(ll)a1[i-1]*inv[i]%mod;
    for(int i=len;i<len<<1;i++)b[i]=0;
}
void get_exp(int *a,int *b,int len)
{
    static int lnb[N],tmp[N];
    if(len==1){b[0]=1;return;}
    get_exp(a,b,len>>1);
    for(int i=0;i<len<<1;i++)lnb[i]=0;
    get_ln(b,lnb,len);
    for(int i=0;i<len;i++)tmp[i]=(a[i]-lnb[i]+mod)%mod;
    tmp[0]++;
    for(int i=len;i<len<<1;i++)tmp[i]=0;
    NTT(b,len<<1,1);
    NTT(tmp,len<<1,1);
    for(int i=0;i<len<<1;i++)
        b[i]=(ll)b[i]*tmp[i]%mod;
    NTT(b,len<<1,-1);
    for(int i=len;i<len<<1;i++)b[i]=0;
}
int main()
{
    //freopen("tt.in","r",stdin);
    scanf("%d%d",&n,&m);
    for(len=1;len<=m;len<<=1);
    for(int i=1;i<len;i++)inv[i]=qpow(i,mod-2);
    tmp[0]=1;tmp[1]=6;tmp[2]=1;
    get_root(tmp,sq,len);
    get_inv(sq,inv_sq,len);
    rt1[0]=rt1[1]=1;
    for(int i=0;i<len;i++)
        rt1[i]=(ll)(rt1[i]+sq[i])%mod*inv2%mod;
    get_ln(rt1,ln1,len);
    for(int i=0;i<len;i++)
        ln1[i]=(ll)ln1[i]*(n+1)%mod;
    get_exp(ln1,a,len);
    NTT(inv_sq,len<<1,1);
    NTT(a,len<<1,1);
    for(int i=0;i<len<<1;i++)
        ans[i]=(ll)a[i]*inv_sq[i]%mod;
    NTT(ans,len<<1,-1);
    for(int i=1;i<=m;i++)
        printf("%d ",i>n ? 0:ans[i]);
    return 0;
}

标签: dp, 多项式, 数学

添加新评论