我的作业:codeforces 755G
一.题目大意
一行有$$n$$个球,现在将这些球分成$$K$$组,每组可以有一个球或相邻两个球。一个球只能在至多一个组中。求对于$$1\le K\le m$$的所有$$K$$分别有多少种分组方法。答案对$$998244353$$取模。
$$1\le n \le 10^9$$
$$1\le K \le 2^{15}$$
二.解题报告
算法1
设$$f[i][j]$$表示把$$i$$个球分成$$j$$组的方案数。
则$$f[i][j]=f[i-1][j-1]+f[i-2][j-1]+f[i-1][j]$$
最终答案为$$f[n][1...m]$$
时间复杂度$$O(nm)$$
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int N=5100,mod=998244353;
int n,m;
int f[N][N];
int main()
{
//freopen("tt.in","r",stdin);
scanf("%d%d",&n,&m);
f[1][1]=f[1][0]=f[0][0]=1;
for(int i=2;i<=n;i++)
{
f[i][0]=1;
for(int j=1;j<=i&&j<=m;j++)
f[i][j]=((f[i-1][j-1]+f[i-2][j-1])%mod+f[i-1][j])%mod;
}
for(int i=1;i<=m;i++)
printf("%d ",f[n][i]);
return 0;
}
算法2
设$$a(K)$$表示分为$$K$$组的答案。枚举相邻两个一组的个数$$i$$
$$a(K)=\sum\limits_{i=0}^{min(K,n-K)}C_{n-i}^K C_K^i$$
$$=\sum\limits_{i=0}^{min(K,n-K)} \frac{(n-i)!}{(n-i-K)!i!(K-i)!}$$
其中$$\frac{1}{i!}$$和$$\frac{1}{(K-i)!}$$可以通过$$O(m)$$预处理每次$$O(1)$$计算。
设$$g(i,K) = \frac{(n-i)!}{(n-i-K)!}$$
则$$g(i,K) = g(i,K-1)*(n-i-K+1)$$
对于每个$$K$$可以$$O(m)$$处理。
时间复杂度 $$O(m^2)$$
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int N=21000,mod=998244353;
int n,m,ans;
int jc[N],njc[N],g[N];
int qpow(int x,int y)
{
int ret=1;
while(y)
{
if(y&1)ret=(ll)ret*x%mod;
x=(ll)x*x%mod;y>>=1;
}
return ret;
}
void init()
{
jc[0]=njc[0]=1;
for(int i=1;i<=m;i++)jc[i]=(ll)jc[i-1]*i%mod;
njc[m]=qpow(jc[m],mod-2);
for(int i=m-1;i>=1;i--)
njc[i]=(ll)njc[i+1]*(i+1)%mod;
}
int main()
{
//freopen("tt.in","r",stdin);
scanf("%d%d",&n,&m);
init();
for(int i=0;i<=m;i++)g[i]=1;
for(int K=1;K<=m;K++)
{
ans=0;
for(int i=0;i<=m;i++)
g[i]=(ll)g[i]*(n-i-K+1)%mod;
for(int i=0;i<=K&&i<=n-K;i++)
ans=(ans+(ll)g[i]*njc[i]%mod*njc[K-i]%mod)%mod;
printf("%d ",ans);
}
return 0;
}
算法3
在算法1基础上将$$f[i][0...m]$$看成一个关于x的多项式,称为$$f[i]$$
通过算法1可知如果已知$$f[n],f[n+1]$$可以$$O(m)$$求出$$f[n+2]$$
考虑在已知$$f[a],f[b],f[a-1],f[b-1]$$时计算$$f[a+b]$$
$$f[a+b]$$即为将长度为$$a$$和长度为$$b$$的两段连接起来得到的。
在连接处有两种情况,如果有一个两个球的组跨过连接处那么为$$f[a-1] * [b-1]$$(乘号为多项式卷积)
否则为$$f[a] * f[b]$$
因此$$f[a+b]=f[a-1] * f[b-1]+f[a] * f[b]$$
卷积可以通过NTT在$$O(m \log m)$$计算
因此可以在$$O(m\log m)$$通过$$f[a],f[b],f[a-1],f[b-1]$$计算$$f[a+b]$$
因此如果已知$$f[a-2],f[a-1],f[a]$$以及$$f[b-2],f[b-1],f[b]$$
可以通过$$f[a-2],f[a-1]$$,$$f[b-2],f[b-1]$$计算$$f[a+b-2]$$
通过$$f[a-1],f[a]$$,$$f[b-2],f[b-1]$$计算$$f[a+b-1]$$
通过$$f[a+b-2],f[a+b-1]$$计算$$f[a+b]$$
因此可以用倍增来计算$$f[n]$$
时间复杂度$$O(m\log n\log m)$$
#include <bits/stdc++.h>
using namespace std;
#define N (1<<16)+10
#define mod 998244353
#define ll long long
int len,n,m,flag;
int qpow(int x,int y)
{
int ret=1;
while(y)
{
if(y&1)ret=(ll)ret*x%mod;
x=(ll)x*x%mod;y>>=1;
}
return ret;
}
void NTT(int *a,int len,int type)
{
for(int i=0,t=0;i<len;i++)
{
if(i>t)swap(a[i],a[t]);
for(int j=len>>1;(t^=j)<j;j>>=1);
}
for(int i=2;i<=len;i<<=1)
{
int wn=qpow(3,(mod-1)/i);
for(int j=0;j<len;j+=i)
{
int w=1,t;
for(int k=0;k<i>>1;k++,w=(ll)w*wn%mod)
{
t=(ll)a[j+k+(i>>1)]*w%mod;
a[j+k+(i>>1)]=(a[j+k]-t+mod)%mod;
a[j+k]=(a[j+k]+t)%mod;
}
}
}
if(type==-1)
{
for(int i=1;i<=len>>1;i++)swap(a[i],a[len-i]);
int t=qpow(len,mod-2);
for(int i=0;i<len;i++)a[i]=(ll)a[i]*t%mod;
}
}
struct node
{
int a[N],b[N];
void dft()
{
memset(b,0,sizeof(b));
for(int i=0;i<len>>1;i++)b[i]=a[i];
NTT(b,len,1);
}
}a1,a2,a3,b1,b2,b3,ans,ans1,ans2,c1,c2;
void mul(node &r1,const node &r2,const node &r3)
{
for(int i=0;i<len;i++)
r1.b[i]=r1.a[i]=(ll)r2.b[i]*r3.b[i]%mod;
NTT(r1.a,len,-1);
for(int i=len>>1;i<len;i++)r1.a[i]=0;
}
void inv(node &r1,const node &r2,const node &r3)
{
r1.a[0]=1;
for(int i=1;i<len>>1;i++)
r1.a[i]=(r2.a[i]+r3.a[i-1])%mod;
r1.dft();
}
int main()
{
//freopen("tt.in","r",stdin);
scanf("%d%d",&n,&m);
for(len=1;len<(m+1)<<1;len<<=1);
a2.a[0]=1;a3.a[0]=1;a3.a[1]=1;
a1.dft();a2.dft();a3.dft();
for(;n;n>>=1)
{
if(n&1)
{
if(!flag)
{
ans=a3;ans1=a2;ans2=a1;
flag=1;
}
else
{
mul(c1,ans,a3);mul(c2,ans1,a2);
inv(ans,c1,c2);
mul(c1,ans1,a3);mul(c2,ans2,a2);
inv(ans1,c1,c2);
for(int i=0;i<len>>1;i++)
ans2.a[i]=((ans.a[i+1]-ans1.a[i+1]+mod)%mod-ans1.a[i]+mod)%mod;
ans2.dft();
}
}
mul(c1,a2,a2);mul(c2,a1,a1);
inv(b1,c1,c2);
mul(c1,a3,a3);mul(c2,a2,a2);
inv(b3,c1,c2);
b2.a[0]=1;
for(int i=1;i<len>>1;i++)
b2.a[i]=((b3.a[i]-b2.a[i-1]+mod)%mod-b1.a[i-1]+mod)%mod;
b2.dft();
a1=b1;a2=b2;a3=b3;
}
for(int i=1;i<=m;i++)
printf("%d ",ans.a[i]);
return 0;
}
算法4
由算法3知$$f[i]=(x+1)f[i-1]+x f[i-2]$$
求$$f[i]$$的通项公式:
设$$f[i]=C_1 T_1(x)^i+C_2 T_2(x)^i$$
其中$$T_1(x),T_2(x)$$为关于$$T(x)$$的方程$$T(x)^2=(x+1)T(x)+x$$的两根。
解得$$
\left\{
\begin{aligned}
T_1(x)=\frac{x+1+\sqrt{x^2+6x+1}}{2}\\
T_2(x)=\frac{x+1-\sqrt{x^2+6x+1}}{2}
\end{aligned}
\right.
$$
将$$f[0]=1,f[1]=1+x$$带入
解得$$
\left\{
\begin{aligned}
C_1=\frac{T_1(x)}{T_1(x)-T_2(x)}\\
C_2=\frac{T_2(x)}{T_2(x)-T_1(x)}
\end{aligned}
\right.
$$
因此$$f[i]=\frac{T_1(x)^{i+1}-T_2(x)^{i+1}}{T_1(x)-T_2(x)}$$
$$ = \frac{(\frac{x+1+\sqrt{x^2+6x+1}}{2})^{n+1}-(\frac{x+1-\sqrt{x^2+6x+1}}{2})^{n+1}}{\sqrt{x^2+6x+1}}$$
其中$$(\frac{x+1-\sqrt{x^2+6x+1}}{2})^{n+1}$$最低次项大于$$n$$,因此可以舍掉
通过多项式开根,多项式求逆可以$$O(m\log m)$$计算$$\frac{x+1+\sqrt{x^2+6x+1}}{2}$$以及$$\frac{1}{\sqrt{x^2+6x+1}}$$
通过多项式exp以及多项式取ln可以$$O(m\log m)$$计算一个多项式的$$n+1$$次幂
时间复杂度$$O(m\log m)$$
#include <bits/stdc++.h>
using namespace std;
#define N (1<<16)+10
#define ll long long
#define mod 998244353
const int inv2=499122177;
int n,m,len;
int tmp[N],sq[N],inv_sq[N],rt1[N],ln1[N],a[N],ans[N],inv[N];
int test1[N],test2[N],test3[N];
int qpow(int x,int y)
{
int ret=1;
while(y)
{
if(y&1)ret=(ll)ret*x%mod;
x=(ll)x*x%mod;y>>=1;
}
return ret;
}
void NTT(int *a,int len,int type)
{
for(int i=0,t=0;i<len;i++)
{
if(i<t)swap(a[i],a[t]);
for(int j=len>>1;(t^=j)<j;j>>=1);
}
for(int i=2;i<=len;i<<=1)
{
int wn=qpow(3,(mod-1)/i);
for(int j=0;j<len;j+=i)
{
int w=1,t;
for(int k=0;k<i>>1;k++,w=(ll)w*wn%mod)
{
t=(ll)a[j+k+(i>>1)]*w%mod;
a[j+k+(i>>1)]=(a[j+k]-t+mod)%mod;
a[j+k]=(a[j+k]+t)%mod;
}
}
}
if(type==-1)
{
for(int i=1;i<len>>1;i++)swap(a[i],a[len-i]);
int t=qpow(len,mod-2);
for(int i=0;i<len;i++)a[i]=(ll)a[i]*t%mod;
}
}
////////////////
void test_root(int *a,int len)
{
memset(test1,0,sizeof(test1));
for(int i=0;i<len;i++)
test1[i]=a[i];
NTT(test1,len<<1,1);
for(int i=0;i<len<<1;i++)
test1[i]=(ll)test1[i]*test1[i]%mod;
NTT(test1,len<<1,-1);
for(int i=0;i<len;i++)
printf("#%d ",test1[i]);
puts("");
}
void test_inv(int *a,int *b,int len)
{
memset(test1,0,sizeof(test1));
memset(test2,0,sizeof(test2));
for(int i=0;i<len;i++)
test1[i]=a[i],test2[i]=b[i];
NTT(test1,len<<1,1);
NTT(test2,len<<1,1);
for(int i=0;i<len<<1;i++)
test3[i]=(ll)test1[i]*test2[i]%mod;
NTT(test3,len<<1,-1);
for(int i=0;i<len;i++)
printf("#%d ",test3[i]);
puts("");
}
////////////////
void get_inv(int *a,int *b,int len)
{
static int tmp[N];
if(len==1)
{
b[0]=qpow(a[0],mod-2);
return;
}
get_inv(a,b,len>>1);
for(int i=0;i<len;i++)tmp[i]=a[i];
for(int i=len;i<len<<1;i++)tmp[i]=0;
NTT(tmp,len<<1,1);
NTT(b,len<<1,1);
for(int i=0;i<len<<1;i++)
b[i]=(ll)b[i]*(2-(ll)b[i]*tmp[i]%mod+mod)%mod;
NTT(b,len<<1,-1);
for(int i=len;i<len<<1;i++)b[i]=0;
}
void get_root(int *a,int *b,int len)
{
static int invb[N],tmp[N];
if(len==1){b[0]=1;return;}
get_root(a,b,len>>1);
for(int i=0;i<len<<1;i++)invb[i]=0;
get_inv(b,invb,len);
for(int i=0;i<len;i++)tmp[i]=a[i];
for(int i=len;i<len<<1;i++)tmp[i]=0;
NTT(tmp,len<<1,1);
NTT(b,len<<1,1);
NTT(invb,len<<1,1);
for(int i=0;i<len<<1;i++)
b[i]=(ll)inv2*(b[i]+(ll)tmp[i]*invb[i]%mod)%mod;
NTT(b,len<<1,-1);
for(int i=len;i<len<<1;i++)b[i]=0;
}
void get_ln(int *a,int *b,int len)
{
static int inva[N],a1[N];
for(int i=0;i<len<<1;i++)inva[i]=0;
get_inv(a,inva,len);
for(int i=0;i<len;i++)a1[i]=(ll)(i+1)*a[i+1]%mod;
for(int i=len;i<len<<1;i++)a1[i]=0;
NTT(a1,len<<1,1);
NTT(inva,len<<1,1);
for(int i=0;i<len<<1;i++)a1[i]=(ll)a1[i]*inva[i]%mod;
NTT(a1,len<<1,-1);
b[0]=0;
for(int i=1;i<len;i++)
b[i]=(ll)a1[i-1]*inv[i]%mod;
for(int i=len;i<len<<1;i++)b[i]=0;
}
void get_exp(int *a,int *b,int len)
{
static int lnb[N],tmp[N];
if(len==1){b[0]=1;return;}
get_exp(a,b,len>>1);
for(int i=0;i<len<<1;i++)lnb[i]=0;
get_ln(b,lnb,len);
for(int i=0;i<len;i++)tmp[i]=(a[i]-lnb[i]+mod)%mod;
tmp[0]++;
for(int i=len;i<len<<1;i++)tmp[i]=0;
NTT(b,len<<1,1);
NTT(tmp,len<<1,1);
for(int i=0;i<len<<1;i++)
b[i]=(ll)b[i]*tmp[i]%mod;
NTT(b,len<<1,-1);
for(int i=len;i<len<<1;i++)b[i]=0;
}
int main()
{
//freopen("tt.in","r",stdin);
scanf("%d%d",&n,&m);
for(len=1;len<=m;len<<=1);
for(int i=1;i<len;i++)inv[i]=qpow(i,mod-2);
tmp[0]=1;tmp[1]=6;tmp[2]=1;
get_root(tmp,sq,len);
get_inv(sq,inv_sq,len);
rt1[0]=rt1[1]=1;
for(int i=0;i<len;i++)
rt1[i]=(ll)(rt1[i]+sq[i])%mod*inv2%mod;
get_ln(rt1,ln1,len);
for(int i=0;i<len;i++)
ln1[i]=(ll)ln1[i]*(n+1)%mod;
get_exp(ln1,a,len);
NTT(inv_sq,len<<1,1);
NTT(a,len<<1,1);
for(int i=0;i<len<<1;i++)
ans[i]=(ll)a[i]*inv_sq[i]%mod;
NTT(ans,len<<1,-1);
for(int i=1;i<=m;i++)
printf("%d ",i>n ? 0:ans[i]);
return 0;
}