分类 OI 下的文章

一.题目大意

一行有$$n$$个球,现在将这些球分成$$K$$组,每组可以有一个球或相邻两个球。一个球只能在至多一个组中。求对于$$1\le K\le m$$的所有$$K$$分别有多少种分组方法。答案对$$998244353$$取模。
$$1\le n \le 10^9$$
$$1\le K \le 2^{15}$$

二.解题报告

算法1

设$$f[i][j]$$表示把$$i$$个球分成$$j$$组的方案数。
则$$f[i][j]=f[i-1][j-1]+f[i-2][j-1]+f[i-1][j]$$
最终答案为$$f[n][1...m]$$
时间复杂度$$O(nm)$$

#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int N=5100,mod=998244353;
int n,m;
int f[N][N];
int main()
{
    //freopen("tt.in","r",stdin);
    scanf("%d%d",&n,&m);
    f[1][1]=f[1][0]=f[0][0]=1;
    for(int i=2;i<=n;i++)
    {
        f[i][0]=1;
        for(int j=1;j<=i&&j<=m;j++)
            f[i][j]=((f[i-1][j-1]+f[i-2][j-1])%mod+f[i-1][j])%mod;
    }
    for(int i=1;i<=m;i++)
        printf("%d ",f[n][i]);
    return 0;
}

算法2

设$$a(K)$$表示分为$$K$$组的答案。枚举相邻两个一组的个数$$i$$
$$a(K)=\sum\limits_{i=0}^{min(K,n-K)}C_{n-i}^K C_K^i$$

$$=\sum\limits_{i=0}^{min(K,n-K)} \frac{(n-i)!}{(n-i-K)!i!(K-i)!}$$

其中$$\frac{1}{i!}$$和$$\frac{1}{(K-i)!}$$可以通过$$O(m)$$预处理每次$$O(1)$$计算。
设$$g(i,K) = \frac{(n-i)!}{(n-i-K)!}$$
则$$g(i,K) = g(i,K-1)*(n-i-K+1)$$
对于每个$$K$$可以$$O(m)$$处理。
时间复杂度 $$O(m^2)$$

#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int N=21000,mod=998244353;
int n,m,ans;
int jc[N],njc[N],g[N];
int qpow(int x,int y)
{
    int ret=1;
    while(y)
    {
        if(y&1)ret=(ll)ret*x%mod;
        x=(ll)x*x%mod;y>>=1;
    }
    return ret;
}
void init()
{
    jc[0]=njc[0]=1;
    for(int i=1;i<=m;i++)jc[i]=(ll)jc[i-1]*i%mod;
    njc[m]=qpow(jc[m],mod-2);
    for(int i=m-1;i>=1;i--)
        njc[i]=(ll)njc[i+1]*(i+1)%mod;
}
int main()
{
    //freopen("tt.in","r",stdin);
    scanf("%d%d",&n,&m);
    init();
    for(int i=0;i<=m;i++)g[i]=1;
    for(int K=1;K<=m;K++)
    {
        ans=0;
        for(int i=0;i<=m;i++)
            g[i]=(ll)g[i]*(n-i-K+1)%mod;
        for(int i=0;i<=K&&i<=n-K;i++)
            ans=(ans+(ll)g[i]*njc[i]%mod*njc[K-i]%mod)%mod;
        printf("%d ",ans);
    }
    return 0;
}

算法3

在算法1基础上将$$f[i][0...m]$$看成一个关于x的多项式,称为$$f[i]$$
通过算法1可知如果已知$$f[n],f[n+1]$$可以$$O(m)$$求出$$f[n+2]$$
考虑在已知$$f[a],f[b],f[a-1],f[b-1]$$时计算$$f[a+b]$$
$$f[a+b]$$即为将长度为$$a$$和长度为$$b$$的两段连接起来得到的。
在连接处有两种情况,如果有一个两个球的组跨过连接处那么为$$f[a-1] * [b-1]$$(乘号为多项式卷积)
否则为$$f[a] * f[b]$$
因此$$f[a+b]=f[a-1] * f[b-1]+f[a] * f[b]$$
卷积可以通过NTT在$$O(m \log m)$$计算
因此可以在$$O(m\log m)$$通过$$f[a],f[b],f[a-1],f[b-1]$$计算$$f[a+b]$$
因此如果已知$$f[a-2],f[a-1],f[a]$$以及$$f[b-2],f[b-1],f[b]$$
可以通过$$f[a-2],f[a-1]$$,$$f[b-2],f[b-1]$$计算$$f[a+b-2]$$
通过$$f[a-1],f[a]$$,$$f[b-2],f[b-1]$$计算$$f[a+b-1]$$
通过$$f[a+b-2],f[a+b-1]$$计算$$f[a+b]$$
因此可以用倍增来计算$$f[n]$$
时间复杂度$$O(m\log n\log m)$$

#include <bits/stdc++.h>
using namespace std;
#define N (1<<16)+10
#define mod 998244353
#define ll long long
int len,n,m,flag;
int qpow(int x,int y)
{
    int ret=1;
    while(y)
    {
        if(y&1)ret=(ll)ret*x%mod;
        x=(ll)x*x%mod;y>>=1;
    }
    return ret;
}
void NTT(int *a,int len,int type)
{
    for(int i=0,t=0;i<len;i++)
    {
        if(i>t)swap(a[i],a[t]);
        for(int j=len>>1;(t^=j)<j;j>>=1);
    }
    for(int i=2;i<=len;i<<=1)
    {
        int wn=qpow(3,(mod-1)/i);
        for(int j=0;j<len;j+=i)
        {
            int w=1,t;
            for(int k=0;k<i>>1;k++,w=(ll)w*wn%mod)
            {
                t=(ll)a[j+k+(i>>1)]*w%mod;
                a[j+k+(i>>1)]=(a[j+k]-t+mod)%mod;
                a[j+k]=(a[j+k]+t)%mod;
            }
        }
    }
    if(type==-1)
    {
        for(int i=1;i<=len>>1;i++)swap(a[i],a[len-i]);
        int t=qpow(len,mod-2);
        for(int i=0;i<len;i++)a[i]=(ll)a[i]*t%mod;
    }
}
struct node
{
    int a[N],b[N];
    void dft()
    {
        memset(b,0,sizeof(b));
        for(int i=0;i<len>>1;i++)b[i]=a[i];
        NTT(b,len,1);
    }
}a1,a2,a3,b1,b2,b3,ans,ans1,ans2,c1,c2;
void mul(node &r1,const node &r2,const node &r3)
{
    for(int i=0;i<len;i++)
        r1.b[i]=r1.a[i]=(ll)r2.b[i]*r3.b[i]%mod;
    NTT(r1.a,len,-1);
    for(int i=len>>1;i<len;i++)r1.a[i]=0;
}
void inv(node &r1,const node &r2,const node &r3)
{
    r1.a[0]=1;
    for(int i=1;i<len>>1;i++)
        r1.a[i]=(r2.a[i]+r3.a[i-1])%mod;
    r1.dft();
}
int main()
{
    //freopen("tt.in","r",stdin);
    scanf("%d%d",&n,&m);
    for(len=1;len<(m+1)<<1;len<<=1);
    a2.a[0]=1;a3.a[0]=1;a3.a[1]=1;
    a1.dft();a2.dft();a3.dft();
    for(;n;n>>=1)
    {
        if(n&1)
        {
            if(!flag)
            {
                ans=a3;ans1=a2;ans2=a1;
                flag=1;
            }
            else
            {
                mul(c1,ans,a3);mul(c2,ans1,a2);
                inv(ans,c1,c2);
                mul(c1,ans1,a3);mul(c2,ans2,a2);
                inv(ans1,c1,c2);
                for(int i=0;i<len>>1;i++)
                    ans2.a[i]=((ans.a[i+1]-ans1.a[i+1]+mod)%mod-ans1.a[i]+mod)%mod;
                ans2.dft();
            }
        }
        mul(c1,a2,a2);mul(c2,a1,a1);
        inv(b1,c1,c2);
        mul(c1,a3,a3);mul(c2,a2,a2);
        inv(b3,c1,c2);
        b2.a[0]=1;
        for(int i=1;i<len>>1;i++)
            b2.a[i]=((b3.a[i]-b2.a[i-1]+mod)%mod-b1.a[i-1]+mod)%mod;
        b2.dft();
        a1=b1;a2=b2;a3=b3;
    }
    for(int i=1;i<=m;i++)
        printf("%d ",ans.a[i]);
    return 0;
}

算法4

由算法3知$$f[i]=(x+1)f[i-1]+x f[i-2]$$
求$$f[i]$$的通项公式:
设$$f[i]=C_1 T_1(x)^i+C_2 T_2(x)^i$$
其中$$T_1(x),T_2(x)$$为关于$$T(x)$$的方程$$T(x)^2=(x+1)T(x)+x$$的两根。

解得$$
\left\{
\begin{aligned}
T_1(x)=\frac{x+1+\sqrt{x^2+6x+1}}{2}\\
T_2(x)=\frac{x+1-\sqrt{x^2+6x+1}}{2}
\end{aligned}
\right.
$$

将$$f[0]=1,f[1]=1+x$$带入

解得$$
\left\{
\begin{aligned}
C_1=\frac{T_1(x)}{T_1(x)-T_2(x)}\\
C_2=\frac{T_2(x)}{T_2(x)-T_1(x)}
\end{aligned}
\right.
$$

因此$$f[i]=\frac{T_1(x)^{i+1}-T_2(x)^{i+1}}{T_1(x)-T_2(x)}$$

$$ = \frac{(\frac{x+1+\sqrt{x^2+6x+1}}{2})^{n+1}-(\frac{x+1-\sqrt{x^2+6x+1}}{2})^{n+1}}{\sqrt{x^2+6x+1}}$$

其中$$(\frac{x+1-\sqrt{x^2+6x+1}}{2})^{n+1}$$最低次项大于$$n$$,因此可以舍掉

通过多项式开根,多项式求逆可以$$O(m\log m)$$计算$$\frac{x+1+\sqrt{x^2+6x+1}}{2}$$以及$$\frac{1}{\sqrt{x^2+6x+1}}$$
通过多项式exp以及多项式取ln可以$$O(m\log m)$$计算一个多项式的$$n+1$$次幂
时间复杂度$$O(m\log m)$$

#include <bits/stdc++.h>
using namespace std;
#define N (1<<16)+10
#define ll long long
#define mod 998244353
const int inv2=499122177;
int n,m,len;
int tmp[N],sq[N],inv_sq[N],rt1[N],ln1[N],a[N],ans[N],inv[N];
int test1[N],test2[N],test3[N];
int qpow(int x,int y)
{
    int ret=1;
    while(y)
    {
        if(y&1)ret=(ll)ret*x%mod;
        x=(ll)x*x%mod;y>>=1;
    }
    return ret;
}
void NTT(int *a,int len,int type)
{
    for(int i=0,t=0;i<len;i++)
    {
        if(i<t)swap(a[i],a[t]);
        for(int j=len>>1;(t^=j)<j;j>>=1);
    }
    for(int i=2;i<=len;i<<=1)
    {
        int wn=qpow(3,(mod-1)/i);
        for(int j=0;j<len;j+=i)
        {
            int w=1,t;
            for(int k=0;k<i>>1;k++,w=(ll)w*wn%mod)
            {
                t=(ll)a[j+k+(i>>1)]*w%mod;
                a[j+k+(i>>1)]=(a[j+k]-t+mod)%mod;
                a[j+k]=(a[j+k]+t)%mod;
            }
        }
    }
    if(type==-1)
    {
        for(int i=1;i<len>>1;i++)swap(a[i],a[len-i]);
        int t=qpow(len,mod-2);
        for(int i=0;i<len;i++)a[i]=(ll)a[i]*t%mod;
    }
}
////////////////
void test_root(int *a,int len)
{
    memset(test1,0,sizeof(test1));
    for(int i=0;i<len;i++)
        test1[i]=a[i];
    NTT(test1,len<<1,1);
    for(int i=0;i<len<<1;i++)
        test1[i]=(ll)test1[i]*test1[i]%mod;
    NTT(test1,len<<1,-1);
    for(int i=0;i<len;i++)
        printf("#%d ",test1[i]);
    puts("");
}
void test_inv(int *a,int *b,int len)
{
    memset(test1,0,sizeof(test1));
    memset(test2,0,sizeof(test2));
    for(int i=0;i<len;i++)
        test1[i]=a[i],test2[i]=b[i];
    NTT(test1,len<<1,1);
    NTT(test2,len<<1,1);
    for(int i=0;i<len<<1;i++)
        test3[i]=(ll)test1[i]*test2[i]%mod;
    NTT(test3,len<<1,-1);
    for(int i=0;i<len;i++)
        printf("#%d ",test3[i]);
    puts("");
}
////////////////
void get_inv(int *a,int *b,int len)
{
    static int tmp[N];
    if(len==1)
    {
        b[0]=qpow(a[0],mod-2);
        return;
    }
    get_inv(a,b,len>>1);
    for(int i=0;i<len;i++)tmp[i]=a[i];
    for(int i=len;i<len<<1;i++)tmp[i]=0;
    NTT(tmp,len<<1,1);
    NTT(b,len<<1,1);
    for(int i=0;i<len<<1;i++)
        b[i]=(ll)b[i]*(2-(ll)b[i]*tmp[i]%mod+mod)%mod;
    NTT(b,len<<1,-1);
    for(int i=len;i<len<<1;i++)b[i]=0;
}
void get_root(int *a,int *b,int len)
{
    static int invb[N],tmp[N];
    if(len==1){b[0]=1;return;}
    get_root(a,b,len>>1);
    for(int i=0;i<len<<1;i++)invb[i]=0;
    get_inv(b,invb,len);
    for(int i=0;i<len;i++)tmp[i]=a[i];
    for(int i=len;i<len<<1;i++)tmp[i]=0;
    NTT(tmp,len<<1,1);
    NTT(b,len<<1,1);
    NTT(invb,len<<1,1);
    for(int i=0;i<len<<1;i++)
        b[i]=(ll)inv2*(b[i]+(ll)tmp[i]*invb[i]%mod)%mod;
    NTT(b,len<<1,-1);
    for(int i=len;i<len<<1;i++)b[i]=0;
}
void get_ln(int *a,int *b,int len)
{
    static int inva[N],a1[N];
    for(int i=0;i<len<<1;i++)inva[i]=0;
    get_inv(a,inva,len);
    for(int i=0;i<len;i++)a1[i]=(ll)(i+1)*a[i+1]%mod;
    for(int i=len;i<len<<1;i++)a1[i]=0;
    NTT(a1,len<<1,1);
    NTT(inva,len<<1,1);
    for(int i=0;i<len<<1;i++)a1[i]=(ll)a1[i]*inva[i]%mod;
    NTT(a1,len<<1,-1);
    b[0]=0;
    for(int i=1;i<len;i++)
        b[i]=(ll)a1[i-1]*inv[i]%mod;
    for(int i=len;i<len<<1;i++)b[i]=0;
}
void get_exp(int *a,int *b,int len)
{
    static int lnb[N],tmp[N];
    if(len==1){b[0]=1;return;}
    get_exp(a,b,len>>1);
    for(int i=0;i<len<<1;i++)lnb[i]=0;
    get_ln(b,lnb,len);
    for(int i=0;i<len;i++)tmp[i]=(a[i]-lnb[i]+mod)%mod;
    tmp[0]++;
    for(int i=len;i<len<<1;i++)tmp[i]=0;
    NTT(b,len<<1,1);
    NTT(tmp,len<<1,1);
    for(int i=0;i<len<<1;i++)
        b[i]=(ll)b[i]*tmp[i]%mod;
    NTT(b,len<<1,-1);
    for(int i=len;i<len<<1;i++)b[i]=0;
}
int main()
{
    //freopen("tt.in","r",stdin);
    scanf("%d%d",&n,&m);
    for(len=1;len<=m;len<<=1);
    for(int i=1;i<len;i++)inv[i]=qpow(i,mod-2);
    tmp[0]=1;tmp[1]=6;tmp[2]=1;
    get_root(tmp,sq,len);
    get_inv(sq,inv_sq,len);
    rt1[0]=rt1[1]=1;
    for(int i=0;i<len;i++)
        rt1[i]=(ll)(rt1[i]+sq[i])%mod*inv2%mod;
    get_ln(rt1,ln1,len);
    for(int i=0;i<len;i++)
        ln1[i]=(ll)ln1[i]*(n+1)%mod;
    get_exp(ln1,a,len);
    NTT(inv_sq,len<<1,1);
    NTT(a,len<<1,1);
    for(int i=0;i<len<<1;i++)
        ans[i]=(ll)a[i]*inv_sq[i]%mod;
    NTT(ans,len<<1,-1);
    for(int i=1;i<=m;i++)
        printf("%d ",i>n ? 0:ans[i]);
    return 0;
}

题意:定义一个偶串为一个可写为一个字符串复制两次的字符串。S为一个偶串,定义f(S)为在S后加最少的字符(大于等于1个)使S成为的偶串。

设f(SS)=STST,g(S)=ST。则T为S最小的周期。
$$g(ST)=STT(|T| \\mid |S|) $$
$$g(ST)=STS(|T| \\nmid |S|)$$

#include <bits/stdc++.h>
using namespace std;
const int N=210000;
#define ll long long
char s[N];
ll L,R;
int nex[N];
ll len[210];
struct node
{
    ll a[26];
    void init(int x)
    {
        for(int i=1;i<=x;i++)
            a[s[i]-'a']++;
    }
    friend node operator + (const node &r1,const node &r2)
    {
        node ret;
        for(int i=0;i<26;i++)
            ret.a[i]=r1.a[i]+r2.a[i];
        return ret;
    }
    friend node operator - (const node &r1,const node &r2)
    {
        node ret;
        for(int i=0;i<26;i++)
            ret.a[i]=r1.a[i]-r2.a[i];
        return ret;
    }
    void print()
    {
        for(int i=0;i<26;i++)
            printf("%lld ",a[i]);
        puts("");
    }
}v[210];
node calc(ll x)
{
    int p;
    for(p=1;;p++)
        if(len[p]>=x)break;
    node ret,t;
    memset(&ret,0,sizeof(ret));
    memset(&t,0,sizeof(t));
    for(;p>=1;p--)
        if(len[p]<=x)
        {
            ret=ret+v[p];
            x-=len[p];
        }
    t.init(x);
    ret=ret+t;
    return ret;
}
int main()
{
    //freopen("tt.in","r",stdin);
    scanf("%s%lld%lld",s+1,&L,&R);
    len[2]=strlen(s+1)/2;
    nex[1]=0;
    for(int i=2,j=0;i<=len[2];i++)
    {
        while(j&&s[j+1]!=s[i])
            j=nex[j];
        if(s[j+1]==s[i])j++;
        nex[i]=j;
    }
    len[1]=len[2]-nex[len[2]];
    v[1].init(len[1]);
    v[2].init(len[2]);
    for(int i=3;;i++)
    {
        len[i]=len[i-1]+len[i-2];
        v[i]=v[i-1]+v[i-2];
        if(len[i]>=R)break;
    }
    (calc(R)-calc(L-1)).print();
    return 0;
}

整体二分裸题

#include <bits/stdc++.h>
using namespace std;
#define N 110000
#define PA pair<int,int>
int n,m,Q;
int X[N],Y[N],ans[N];
struct node
{
    int x,y,tar,pos;
}a[N],b[N];
struct Ufs
{
    int now;
    int fa[N],size[N],mem[N];
    void init()
    {
        for(int i=1;i<=n;i++)
            fa[i]=i,size[i]=1;
    }
    int find(int x){return fa[x]==x ? x : find(fa[x]);}
    void move(int tar)
    {
        int x,y;
        while(now<tar)
        {
            now++;
            if((x=find(X[now]))==(y=find(Y[now])))
                mem[now]=0;
            else
            {
                if(size[y]>size[x])swap(x,y);
                size[x]+=size[y];
                fa[y]=x;
                mem[now]=y;
            }
        }
        while(now>tar)
        {
            if(x=mem[now])
            {
                size[fa[x]]-=size[x];
                fa[x]=x;
            }
            now--;
        }
    }
    int calc(int x,int y)
    {
        x=find(x);y=find(y);
        if(x==y)return size[x];
        return size[x]+size[y];
    }
}ufs;
void solve(int l1,int r1,int l2,int r2)
{
    if(l2>r2)return;
    if(l1==r1)
    {
        for(int i=l2;i<=r2;i++)
            ans[a[i].pos]=l1;
        return;
    }
    int mid=(l1+r1)>>1,lm=l2,rm=r2;
    ufs.move(mid);
    for(int i=l2;i<=r2;i++)
    {
        if(ufs.calc(a[i].x,a[i].y)>=a[i].tar)
            b[lm++]=a[i];
        else b[rm--]=a[i];
    }
    for(int i=l2;i<=r2;i++)a[i]=b[i];
    solve(l1,mid,l2,lm-1);
    solve(mid+1,r1,rm+1,r2);
}
int main()
{
    //freopen("tt.in","r",stdin);
    scanf("%d%d",&n,&m);
    for(int i=1;i<=m;i++)
        scanf("%d%d",&X[i],&Y[i]);
    scanf("%d",&Q);
    for(int i=1;i<=Q;i++)
    {
        scanf("%d%d%d",&a[i].x,&a[i].y,&a[i].tar);
        a[i].pos=i;
    }
    ufs.init();
    solve(1,m,1,Q);
    for(int i=1;i<=Q;i++)
        printf("%d\n",ans[i]);
    return 0;
}

题意:m个区间,对1到n的i分别求有多少个区间中出现了i的倍数。

对于数i,长度大于等于的区间i一定包含i的倍数,长度小于i的区间最多包含一个i的倍数,从小到大枚举i,每次把长度小于i的区间从答案中减掉,插入树状数组中。枚举i的倍数,对于i的每个倍数在树状数组中查询。
写了个sb线段树。。。

#include <bits/stdc++.h>
using namespace std;
#define ls l,mid,now<<1
#define rs mid+1,r,now<<1|1
const int N=310000,M=110000;
int n,m,cnt;
struct node
{
    int l,r;
    friend bool operator < (const node &r1,const node &r2)
    {return r1.r<r2.r;}
}a[N];
vector<int>vec[M];
int b[M*20],pos[M],ed[M],beg[M],nv[M*20];
int val[M*20],bj[M*80],ans[M];
void pushdown(int now)
{
    bj[now<<1]+=bj[now];
    bj[now<<1|1]+=bj[now];
    bj[now]=0;
}
void modify(int l,int r,int now,int pos,int v)
{
    if(l==r)
    {
        val[l]+=bj[now]*nv[l];
        bj[now]=0;nv[l]=v;
        return;
    }
    int mid=(l+r)>>1;
    pushdown(now);
    if(mid>=pos)modify(ls,pos,v);
    else modify(rs,pos,v);
}
void insert(int l,int r,int now,int lq,int rq)
{
    if(lq<=l&&r<=rq)
        {bj[now]++;return;}
    int mid=(l+r)>>1;
    if(mid>=lq)insert(ls,lq,rq);
    if(mid<rq) insert(rs,lq,rq);
}
void down(int l,int r,int now)
{
    if(l==r)
    {
        val[l]+=bj[now]*nv[l];
        return;
    }
    int mid=(l+r)>>1;
    pushdown(now);
    down(ls);down(rs);
}
int main()
{
    //freopen("tt.in","r",stdin);
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)
        scanf("%d%d",&a[i].l,&a[i].r);
    sort(a+1,a+1+n);
    for(int i=1;i<=m;i++)
        for(int j=i;j<=m;j+=i)
            vec[j].push_back(i);
    for(int i=1;i<=m;i++)
    {
        beg[i]=cnt+1;
        for(int j=0;j<vec[i].size();j++)
            b[++cnt]=vec[i][j];
        ed[i]=cnt;
    }
    for(int i=1,now=1,t=0;i<=m;i++)
    {
        while(t<ed[i])
        {
            t++;
            if(pos[b[t]])
                modify(1,cnt,1,pos[b[t]],0);
            modify(1,cnt,1,t,1);
            pos[b[t]]=t;
        }
        while(now<=n&&a[now].r==i)
        {
            insert(1,cnt,1,beg[a[now].l],ed[a[now].r]);
            now++;  
        }
    }
    down(1,cnt,1);
    for(int i=1;i<=cnt;i++)
        ans[b[i]]+=val[i];
    for(int i=1;i<=m;i++)
        printf("%d\n",ans[i]);
    return 0;
}

题意:给出一棵树,树上有一些关键点,一个点集可行当且仅当它可以表示为一个关键点与距它距离d以内的所有点的集合。求可行点集总数。

先不考虑关键点(设所有点均为关键点)
对于一个除全集以外的可行点集可以表示它的点组成树的一个子树。只在子树中d最小的点处统计这个点集。可以发现子树中这样的点只有一个。
因此计算一个点的贡献时,设以他为根时深度最大的子树距该点距离为f1,深度次大的子树距该点距离为f2。则该点的d取值范围为[0,min(f1-1,f2+1)]
当存在关键点时,选取的d必须使一个关键点包含在可行子树中,因此选取的d必须完整覆盖一个包含关键点的子树

#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int N=210000;
int n,tot;
char s[N];
int head[N],nex[N<<1],to[N<<1];
int f1[N],f2[N],g[N],lb[N],size[N];
ll ans;
void add(int x,int y)
{
    tot++;
    nex[tot]=head[x];head[x]=tot;
    to[tot]=y;
}
void upd(int &x,int &y,int z)
{
    if(z>x)y=x,x=z;
    else if(z>y)y=z;
}
void dfs1(int x,int y)
{
    size[x]=s[x]-'0';
    for(int i=head[x],t;i;i=nex[i])if((t=to[i])!=y)
    {
        dfs1(t,x);
        size[x]+=size[t];
        upd(f1[x],f2[x],f1[t]+1);
        if(size[t])
            lb[x]=min(lb[x],f1[t]+1);
    }
}
void dfs2(int x,int y)
{
    for(int i=head[x],t,d;i;i=nex[i])if((t=to[i])!=y)
    {
        g[t]=max(g[x],f1[t]+1==f1[x] ? f2[x] : f1[x])+1;
        if(size[t]!=size[1])
            lb[t]=min(lb[t],g[t]);
        dfs2(t,x);
    }
}
int main()
{
    //freopen("tt.in","r",stdin);
    scanf("%d",&n);
    for(int i=1,x,y;i<n;i++)
    {
        scanf("%d%d",&x,&y);
        add(x,y);add(y,x);
    }
    scanf("%s",s+1);
    for(int i=1;i<=n;i++)
        if(s[i]=='0')lb[i]=n+1;
    dfs1(1,0);
    dfs2(1,0);
    for(int i=1;i<=n;i++)
    {
        upd(f1[i],f2[i],g[i]);
        ans+=max(0,min(f1[i]-1,f2[i]+1)-lb[i]+1);
    }
    printf("%lld\n",ans+1);
    return 0;
}