标签 dp 下的文章

一.题目大意

一行有$$n$$个球,现在将这些球分成$$K$$组,每组可以有一个球或相邻两个球。一个球只能在至多一个组中。求对于$$1\le K\le m$$的所有$$K$$分别有多少种分组方法。答案对$$998244353$$取模。
$$1\le n \le 10^9$$
$$1\le K \le 2^{15}$$

二.解题报告

算法1

设$$f[i][j]$$表示把$$i$$个球分成$$j$$组的方案数。
则$$f[i][j]=f[i-1][j-1]+f[i-2][j-1]+f[i-1][j]$$
最终答案为$$f[n][1...m]$$
时间复杂度$$O(nm)$$

#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int N=5100,mod=998244353;
int n,m;
int f[N][N];
int main()
{
    //freopen("tt.in","r",stdin);
    scanf("%d%d",&n,&m);
    f[1][1]=f[1][0]=f[0][0]=1;
    for(int i=2;i<=n;i++)
    {
        f[i][0]=1;
        for(int j=1;j<=i&&j<=m;j++)
            f[i][j]=((f[i-1][j-1]+f[i-2][j-1])%mod+f[i-1][j])%mod;
    }
    for(int i=1;i<=m;i++)
        printf("%d ",f[n][i]);
    return 0;
}

算法2

设$$a(K)$$表示分为$$K$$组的答案。枚举相邻两个一组的个数$$i$$
$$a(K)=\sum\limits_{i=0}^{min(K,n-K)}C_{n-i}^K C_K^i$$

$$=\sum\limits_{i=0}^{min(K,n-K)} \frac{(n-i)!}{(n-i-K)!i!(K-i)!}$$

其中$$\frac{1}{i!}$$和$$\frac{1}{(K-i)!}$$可以通过$$O(m)$$预处理每次$$O(1)$$计算。
设$$g(i,K) = \frac{(n-i)!}{(n-i-K)!}$$
则$$g(i,K) = g(i,K-1)*(n-i-K+1)$$
对于每个$$K$$可以$$O(m)$$处理。
时间复杂度 $$O(m^2)$$

#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int N=21000,mod=998244353;
int n,m,ans;
int jc[N],njc[N],g[N];
int qpow(int x,int y)
{
    int ret=1;
    while(y)
    {
        if(y&1)ret=(ll)ret*x%mod;
        x=(ll)x*x%mod;y>>=1;
    }
    return ret;
}
void init()
{
    jc[0]=njc[0]=1;
    for(int i=1;i<=m;i++)jc[i]=(ll)jc[i-1]*i%mod;
    njc[m]=qpow(jc[m],mod-2);
    for(int i=m-1;i>=1;i--)
        njc[i]=(ll)njc[i+1]*(i+1)%mod;
}
int main()
{
    //freopen("tt.in","r",stdin);
    scanf("%d%d",&n,&m);
    init();
    for(int i=0;i<=m;i++)g[i]=1;
    for(int K=1;K<=m;K++)
    {
        ans=0;
        for(int i=0;i<=m;i++)
            g[i]=(ll)g[i]*(n-i-K+1)%mod;
        for(int i=0;i<=K&&i<=n-K;i++)
            ans=(ans+(ll)g[i]*njc[i]%mod*njc[K-i]%mod)%mod;
        printf("%d ",ans);
    }
    return 0;
}

算法3

在算法1基础上将$$f[i][0...m]$$看成一个关于x的多项式,称为$$f[i]$$
通过算法1可知如果已知$$f[n],f[n+1]$$可以$$O(m)$$求出$$f[n+2]$$
考虑在已知$$f[a],f[b],f[a-1],f[b-1]$$时计算$$f[a+b]$$
$$f[a+b]$$即为将长度为$$a$$和长度为$$b$$的两段连接起来得到的。
在连接处有两种情况,如果有一个两个球的组跨过连接处那么为$$f[a-1] * [b-1]$$(乘号为多项式卷积)
否则为$$f[a] * f[b]$$
因此$$f[a+b]=f[a-1] * f[b-1]+f[a] * f[b]$$
卷积可以通过NTT在$$O(m \log m)$$计算
因此可以在$$O(m\log m)$$通过$$f[a],f[b],f[a-1],f[b-1]$$计算$$f[a+b]$$
因此如果已知$$f[a-2],f[a-1],f[a]$$以及$$f[b-2],f[b-1],f[b]$$
可以通过$$f[a-2],f[a-1]$$,$$f[b-2],f[b-1]$$计算$$f[a+b-2]$$
通过$$f[a-1],f[a]$$,$$f[b-2],f[b-1]$$计算$$f[a+b-1]$$
通过$$f[a+b-2],f[a+b-1]$$计算$$f[a+b]$$
因此可以用倍增来计算$$f[n]$$
时间复杂度$$O(m\log n\log m)$$

#include <bits/stdc++.h>
using namespace std;
#define N (1<<16)+10
#define mod 998244353
#define ll long long
int len,n,m,flag;
int qpow(int x,int y)
{
    int ret=1;
    while(y)
    {
        if(y&1)ret=(ll)ret*x%mod;
        x=(ll)x*x%mod;y>>=1;
    }
    return ret;
}
void NTT(int *a,int len,int type)
{
    for(int i=0,t=0;i<len;i++)
    {
        if(i>t)swap(a[i],a[t]);
        for(int j=len>>1;(t^=j)<j;j>>=1);
    }
    for(int i=2;i<=len;i<<=1)
    {
        int wn=qpow(3,(mod-1)/i);
        for(int j=0;j<len;j+=i)
        {
            int w=1,t;
            for(int k=0;k<i>>1;k++,w=(ll)w*wn%mod)
            {
                t=(ll)a[j+k+(i>>1)]*w%mod;
                a[j+k+(i>>1)]=(a[j+k]-t+mod)%mod;
                a[j+k]=(a[j+k]+t)%mod;
            }
        }
    }
    if(type==-1)
    {
        for(int i=1;i<=len>>1;i++)swap(a[i],a[len-i]);
        int t=qpow(len,mod-2);
        for(int i=0;i<len;i++)a[i]=(ll)a[i]*t%mod;
    }
}
struct node
{
    int a[N],b[N];
    void dft()
    {
        memset(b,0,sizeof(b));
        for(int i=0;i<len>>1;i++)b[i]=a[i];
        NTT(b,len,1);
    }
}a1,a2,a3,b1,b2,b3,ans,ans1,ans2,c1,c2;
void mul(node &r1,const node &r2,const node &r3)
{
    for(int i=0;i<len;i++)
        r1.b[i]=r1.a[i]=(ll)r2.b[i]*r3.b[i]%mod;
    NTT(r1.a,len,-1);
    for(int i=len>>1;i<len;i++)r1.a[i]=0;
}
void inv(node &r1,const node &r2,const node &r3)
{
    r1.a[0]=1;
    for(int i=1;i<len>>1;i++)
        r1.a[i]=(r2.a[i]+r3.a[i-1])%mod;
    r1.dft();
}
int main()
{
    //freopen("tt.in","r",stdin);
    scanf("%d%d",&n,&m);
    for(len=1;len<(m+1)<<1;len<<=1);
    a2.a[0]=1;a3.a[0]=1;a3.a[1]=1;
    a1.dft();a2.dft();a3.dft();
    for(;n;n>>=1)
    {
        if(n&1)
        {
            if(!flag)
            {
                ans=a3;ans1=a2;ans2=a1;
                flag=1;
            }
            else
            {
                mul(c1,ans,a3);mul(c2,ans1,a2);
                inv(ans,c1,c2);
                mul(c1,ans1,a3);mul(c2,ans2,a2);
                inv(ans1,c1,c2);
                for(int i=0;i<len>>1;i++)
                    ans2.a[i]=((ans.a[i+1]-ans1.a[i+1]+mod)%mod-ans1.a[i]+mod)%mod;
                ans2.dft();
            }
        }
        mul(c1,a2,a2);mul(c2,a1,a1);
        inv(b1,c1,c2);
        mul(c1,a3,a3);mul(c2,a2,a2);
        inv(b3,c1,c2);
        b2.a[0]=1;
        for(int i=1;i<len>>1;i++)
            b2.a[i]=((b3.a[i]-b2.a[i-1]+mod)%mod-b1.a[i-1]+mod)%mod;
        b2.dft();
        a1=b1;a2=b2;a3=b3;
    }
    for(int i=1;i<=m;i++)
        printf("%d ",ans.a[i]);
    return 0;
}

算法4

由算法3知$$f[i]=(x+1)f[i-1]+x f[i-2]$$
求$$f[i]$$的通项公式:
设$$f[i]=C_1 T_1(x)^i+C_2 T_2(x)^i$$
其中$$T_1(x),T_2(x)$$为关于$$T(x)$$的方程$$T(x)^2=(x+1)T(x)+x$$的两根。

解得$$
\left\{
\begin{aligned}
T_1(x)=\frac{x+1+\sqrt{x^2+6x+1}}{2}\\
T_2(x)=\frac{x+1-\sqrt{x^2+6x+1}}{2}
\end{aligned}
\right.
$$

将$$f[0]=1,f[1]=1+x$$带入

解得$$
\left\{
\begin{aligned}
C_1=\frac{T_1(x)}{T_1(x)-T_2(x)}\\
C_2=\frac{T_2(x)}{T_2(x)-T_1(x)}
\end{aligned}
\right.
$$

因此$$f[i]=\frac{T_1(x)^{i+1}-T_2(x)^{i+1}}{T_1(x)-T_2(x)}$$

$$ = \frac{(\frac{x+1+\sqrt{x^2+6x+1}}{2})^{n+1}-(\frac{x+1-\sqrt{x^2+6x+1}}{2})^{n+1}}{\sqrt{x^2+6x+1}}$$

其中$$(\frac{x+1-\sqrt{x^2+6x+1}}{2})^{n+1}$$最低次项大于$$n$$,因此可以舍掉

通过多项式开根,多项式求逆可以$$O(m\log m)$$计算$$\frac{x+1+\sqrt{x^2+6x+1}}{2}$$以及$$\frac{1}{\sqrt{x^2+6x+1}}$$
通过多项式exp以及多项式取ln可以$$O(m\log m)$$计算一个多项式的$$n+1$$次幂
时间复杂度$$O(m\log m)$$

#include <bits/stdc++.h>
using namespace std;
#define N (1<<16)+10
#define ll long long
#define mod 998244353
const int inv2=499122177;
int n,m,len;
int tmp[N],sq[N],inv_sq[N],rt1[N],ln1[N],a[N],ans[N],inv[N];
int test1[N],test2[N],test3[N];
int qpow(int x,int y)
{
    int ret=1;
    while(y)
    {
        if(y&1)ret=(ll)ret*x%mod;
        x=(ll)x*x%mod;y>>=1;
    }
    return ret;
}
void NTT(int *a,int len,int type)
{
    for(int i=0,t=0;i<len;i++)
    {
        if(i<t)swap(a[i],a[t]);
        for(int j=len>>1;(t^=j)<j;j>>=1);
    }
    for(int i=2;i<=len;i<<=1)
    {
        int wn=qpow(3,(mod-1)/i);
        for(int j=0;j<len;j+=i)
        {
            int w=1,t;
            for(int k=0;k<i>>1;k++,w=(ll)w*wn%mod)
            {
                t=(ll)a[j+k+(i>>1)]*w%mod;
                a[j+k+(i>>1)]=(a[j+k]-t+mod)%mod;
                a[j+k]=(a[j+k]+t)%mod;
            }
        }
    }
    if(type==-1)
    {
        for(int i=1;i<len>>1;i++)swap(a[i],a[len-i]);
        int t=qpow(len,mod-2);
        for(int i=0;i<len;i++)a[i]=(ll)a[i]*t%mod;
    }
}
////////////////
void test_root(int *a,int len)
{
    memset(test1,0,sizeof(test1));
    for(int i=0;i<len;i++)
        test1[i]=a[i];
    NTT(test1,len<<1,1);
    for(int i=0;i<len<<1;i++)
        test1[i]=(ll)test1[i]*test1[i]%mod;
    NTT(test1,len<<1,-1);
    for(int i=0;i<len;i++)
        printf("#%d ",test1[i]);
    puts("");
}
void test_inv(int *a,int *b,int len)
{
    memset(test1,0,sizeof(test1));
    memset(test2,0,sizeof(test2));
    for(int i=0;i<len;i++)
        test1[i]=a[i],test2[i]=b[i];
    NTT(test1,len<<1,1);
    NTT(test2,len<<1,1);
    for(int i=0;i<len<<1;i++)
        test3[i]=(ll)test1[i]*test2[i]%mod;
    NTT(test3,len<<1,-1);
    for(int i=0;i<len;i++)
        printf("#%d ",test3[i]);
    puts("");
}
////////////////
void get_inv(int *a,int *b,int len)
{
    static int tmp[N];
    if(len==1)
    {
        b[0]=qpow(a[0],mod-2);
        return;
    }
    get_inv(a,b,len>>1);
    for(int i=0;i<len;i++)tmp[i]=a[i];
    for(int i=len;i<len<<1;i++)tmp[i]=0;
    NTT(tmp,len<<1,1);
    NTT(b,len<<1,1);
    for(int i=0;i<len<<1;i++)
        b[i]=(ll)b[i]*(2-(ll)b[i]*tmp[i]%mod+mod)%mod;
    NTT(b,len<<1,-1);
    for(int i=len;i<len<<1;i++)b[i]=0;
}
void get_root(int *a,int *b,int len)
{
    static int invb[N],tmp[N];
    if(len==1){b[0]=1;return;}
    get_root(a,b,len>>1);
    for(int i=0;i<len<<1;i++)invb[i]=0;
    get_inv(b,invb,len);
    for(int i=0;i<len;i++)tmp[i]=a[i];
    for(int i=len;i<len<<1;i++)tmp[i]=0;
    NTT(tmp,len<<1,1);
    NTT(b,len<<1,1);
    NTT(invb,len<<1,1);
    for(int i=0;i<len<<1;i++)
        b[i]=(ll)inv2*(b[i]+(ll)tmp[i]*invb[i]%mod)%mod;
    NTT(b,len<<1,-1);
    for(int i=len;i<len<<1;i++)b[i]=0;
}
void get_ln(int *a,int *b,int len)
{
    static int inva[N],a1[N];
    for(int i=0;i<len<<1;i++)inva[i]=0;
    get_inv(a,inva,len);
    for(int i=0;i<len;i++)a1[i]=(ll)(i+1)*a[i+1]%mod;
    for(int i=len;i<len<<1;i++)a1[i]=0;
    NTT(a1,len<<1,1);
    NTT(inva,len<<1,1);
    for(int i=0;i<len<<1;i++)a1[i]=(ll)a1[i]*inva[i]%mod;
    NTT(a1,len<<1,-1);
    b[0]=0;
    for(int i=1;i<len;i++)
        b[i]=(ll)a1[i-1]*inv[i]%mod;
    for(int i=len;i<len<<1;i++)b[i]=0;
}
void get_exp(int *a,int *b,int len)
{
    static int lnb[N],tmp[N];
    if(len==1){b[0]=1;return;}
    get_exp(a,b,len>>1);
    for(int i=0;i<len<<1;i++)lnb[i]=0;
    get_ln(b,lnb,len);
    for(int i=0;i<len;i++)tmp[i]=(a[i]-lnb[i]+mod)%mod;
    tmp[0]++;
    for(int i=len;i<len<<1;i++)tmp[i]=0;
    NTT(b,len<<1,1);
    NTT(tmp,len<<1,1);
    for(int i=0;i<len<<1;i++)
        b[i]=(ll)b[i]*tmp[i]%mod;
    NTT(b,len<<1,-1);
    for(int i=len;i<len<<1;i++)b[i]=0;
}
int main()
{
    //freopen("tt.in","r",stdin);
    scanf("%d%d",&n,&m);
    for(len=1;len<=m;len<<=1);
    for(int i=1;i<len;i++)inv[i]=qpow(i,mod-2);
    tmp[0]=1;tmp[1]=6;tmp[2]=1;
    get_root(tmp,sq,len);
    get_inv(sq,inv_sq,len);
    rt1[0]=rt1[1]=1;
    for(int i=0;i<len;i++)
        rt1[i]=(ll)(rt1[i]+sq[i])%mod*inv2%mod;
    get_ln(rt1,ln1,len);
    for(int i=0;i<len;i++)
        ln1[i]=(ll)ln1[i]*(n+1)%mod;
    get_exp(ln1,a,len);
    NTT(inv_sq,len<<1,1);
    NTT(a,len<<1,1);
    for(int i=0;i<len<<1;i++)
        ans[i]=(ll)a[i]*inv_sq[i]%mod;
    NTT(ans,len<<1,-1);
    for(int i=1;i<=m;i++)
        printf("%d ",i>n ? 0:ans[i]);
    return 0;
}

题意:用三种颜色染一个长度为n的序列,m个条件,每个条件为一段区间内的颜色个数为x。求满足条件的序列个数。

设$$f[i][j][k]$$表示到$$i$$,最后出现位置最靠前的颜色最后出现位置为$$k$$,最后出现位置第二靠前的颜色最后出现位置为$$j$$的方案数。


#include <bits/stdc++.h>
using namespace std;
#define PA pair<int,int>
const int N=310,mod=1000000007;
int n,m,ans;
struct node
{
    int l,r,num;
    friend bool operator < (const node &r1,const node &r2)
    {return r1.r<r2.r;}
}a[N];
int f[N][N][N];
int main()
{
    //freopen("tt.in","r",stdin);
    scanf("%d%d",&n,&m);
    for(int i=1;i<=m;i++)
        scanf("%d%d%d",&a[i].l,&a[i].r,&a[i].num);
    sort(a+1,a+1+m);
    f[1][0][0]=3;
    for(int i=1,now=1;i<=n;i++)
    {
        while(now<=m&&a[now].r==i)
        {
            int l=a[now].l,num=a[now].num;
            for(int j=0;j<i;j++)
                for(int k=0;k<=j;k++)
                {
                    int t=1;
                    if(j>=l)t++;
                    if(k>=l)t++;
                    if(t!=num)f[i][j][k]=0;
                }
            now++;
        }
        if(i==n)break;
        for(int j=0;j<i;j++)
            for(int k=0;k<=j;k++)
            {
                (f[i+1][j][k]+=f[i][j][k])%=mod;
                (f[i+1][i][k]+=f[i][j][k])%=mod;
                (f[i+1][i][j]+=f[i][j][k])%=mod;
            }
    }
    for(int j=0;j<n;j++)
        for(int k=0;k<=j;k++)
            (ans+=f[n][j][k])%=mod;
    printf("%d\n",ans);
    return 0;
}

题意:一棵n个点的树,其中一些点已经有值,需要给所有没有值的点赋值,使任意相邻的两个点相差恰好为1,求是否可行,可行求一组解。

对所有有值的点建一棵虚树,每个儿子会对父亲的取值范围和奇偶性有一个限制,从下往上dp,如果出现不合法的情况无解。之后虚树上没有值的点会有一个取值范围。从上往下dp,虚树上每个点取范围内任意一个满足奇偶性的值,并对虚树边上的点赋值,最后对不在虚树边上的点赋值。

#include <bits/stdc++.h>
using namespace std;
#define N 110000
const int inf=1e9;
int n,tot,m,cnt,top;
int head[N],nex[N<<1],to[N<<1];
int fa[N][21],deep[N];
int val[N],L[N],R[N],pos[N],a[N],st[N],sig[N];
vector<int>vec[N];
void quit(){puts("No");exit(0);}
void add(int x,int y)
{
    tot++;
    nex[tot]=head[x];head[x]=tot;
    to[tot]=y;
}
void dfs(int x,int y)
{
    fa[x][0]=y;
    for(int i=1;i<=20;i++)
        fa[x][i]=fa[fa[x][i-1]][i-1];
    deep[x]=deep[y]+1;
    pos[x]=++cnt;
    for(int i=head[x];i;i=nex[i])
        if(to[i]!=y)
            dfs(to[i],x);
}
int cmp(int x,int y){return pos[x]<pos[y];}
int lca(int x,int y)
{
    if(deep[x]<deep[y])swap(x,y);
    for(int i=20;i>=0;i--)
        if(deep[fa[x][i]]>=deep[y])x=fa[x][i];
    if(x==y)return x;
    for(int i=20;i>=0;i--)
        if(fa[x][i]!=fa[y][i])
            x=fa[x][i],y=fa[y][i];
    return fa[x][0];
}
void dfs1(int x)
{
    L[x]=-inf;R[x]=inf;sig[x]=-1;
    for(int i=0,t,d,v;i<vec[x].size();i++)
    {
        dfs1(t=vec[x][i]);
        d=deep[t]-deep[x];
        v=(R[t]+d)&1;
        L[x]=max(L[x],L[t]-d);
        R[x]=min(R[x],R[t]+d);
        if(sig[x]==-1)sig[x]=v;
        else if(sig[x]!=v)quit();
    }
    if(L[x]>R[x])quit();
    if(val[x]!=val[0])
    {
        if(val[x]<L[x]||val[x]>R[x])quit();
        if(sig[x]!=-1&&(val[x]&1)!=sig[x])quit();
        L[x]=R[x]=val[x];
    }
}
void dfs2(int x)
{
    for(int i=0,t,d,d1;i<vec[x].size();i++)
    {
        t=vec[x][i];
        d=deep[t]-deep[x];
        val[t]=max(L[t],val[x]-d);
        d1=(val[x]+d-val[t])/2;
        int now=val[t],p=t;
        for(int j=1;j<=d1;j++)
            val[fa[p][0]]=val[p]+1,p=fa[p][0];
        for(int j=1;j<d-d1;j++)
            val[fa[p][0]]=val[p]-1,p=fa[p][0];
        dfs2(t);
    }
}
void dfs3(int x,int y)
{
    if(val[x]==val[0])
        val[x]=val[y]+1;
    for(int i=head[x];i;i=nex[i])
        if(to[i]!=y)dfs3(to[i],x);
}
int main()
{
    //freopen("tt.in","r",stdin);
    scanf("%d",&n);
    for(int i=1,x,y;i<n;i++)
    {
        scanf("%d%d",&x,&y);
        add(x,y);add(y,x);
    }
    memset(val,-0x3f,sizeof(val));
    scanf("%d",&m);
    for(int i=1;i<=m;i++)
    {
        scanf("%d",&a[i]);
        scanf("%d",&val[a[i]]);
    }
    dfs(1,0);
    sort(a+1,a+1+m,cmp);
    st[top=1]=1;
    for(int i=1,t;i<=m;i++)
    {
        while((t=lca(a[i],st[top]))!=st[top])
        {
            if(deep[t]>deep[st[top-1]])
            {
                vec[t].push_back(st[top]);
                st[top]=t;
            }
            else
            {
                vec[st[top-1]].push_back(st[top]);
                top--;
            }
        }
        if(st[top]!=a[i])
            st[++top]=a[i];
    }
    for(int i=top;i>1;i--)
        vec[st[i-1]].push_back(st[i]);
    dfs1(1);
    val[1]=L[1];
    dfs2(1);
    dfs3(1,0);
    puts("Yes");
    for(int i=1;i<=n;i++)
        printf("%d\n",val[i]);
    return 0;
}