标签 树形dp 下的文章

题意:给出一棵树,树上有一些关键点,一个点集可行当且仅当它可以表示为一个关键点与距它距离d以内的所有点的集合。求可行点集总数。

先不考虑关键点(设所有点均为关键点)
对于一个除全集以外的可行点集可以表示它的点组成树的一个子树。只在子树中d最小的点处统计这个点集。可以发现子树中这样的点只有一个。
因此计算一个点的贡献时,设以他为根时深度最大的子树距该点距离为f1,深度次大的子树距该点距离为f2。则该点的d取值范围为[0,min(f1-1,f2+1)]
当存在关键点时,选取的d必须使一个关键点包含在可行子树中,因此选取的d必须完整覆盖一个包含关键点的子树

#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int N=210000;
int n,tot;
char s[N];
int head[N],nex[N<<1],to[N<<1];
int f1[N],f2[N],g[N],lb[N],size[N];
ll ans;
void add(int x,int y)
{
    tot++;
    nex[tot]=head[x];head[x]=tot;
    to[tot]=y;
}
void upd(int &x,int &y,int z)
{
    if(z>x)y=x,x=z;
    else if(z>y)y=z;
}
void dfs1(int x,int y)
{
    size[x]=s[x]-'0';
    for(int i=head[x],t;i;i=nex[i])if((t=to[i])!=y)
    {
        dfs1(t,x);
        size[x]+=size[t];
        upd(f1[x],f2[x],f1[t]+1);
        if(size[t])
            lb[x]=min(lb[x],f1[t]+1);
    }
}
void dfs2(int x,int y)
{
    for(int i=head[x],t,d;i;i=nex[i])if((t=to[i])!=y)
    {
        g[t]=max(g[x],f1[t]+1==f1[x] ? f2[x] : f1[x])+1;
        if(size[t]!=size[1])
            lb[t]=min(lb[t],g[t]);
        dfs2(t,x);
    }
}
int main()
{
    //freopen("tt.in","r",stdin);
    scanf("%d",&n);
    for(int i=1,x,y;i<n;i++)
    {
        scanf("%d%d",&x,&y);
        add(x,y);add(y,x);
    }
    scanf("%s",s+1);
    for(int i=1;i<=n;i++)
        if(s[i]=='0')lb[i]=n+1;
    dfs1(1,0);
    dfs2(1,0);
    for(int i=1;i<=n;i++)
    {
        upd(f1[i],f2[i],g[i]);
        ans+=max(0,min(f1[i]-1,f2[i]+1)-lb[i]+1);
    }
    printf("%lld\n",ans+1);
    return 0;
}

题意:一棵树,点i上有ai个石头,每次操作可以选择一条两个叶子之间的简单路径,并从路径上每个节点都移除一个石头。求是否可以经过一些操作之后移除所有石头。

从下往上dp,设点i子树中需要通过该点的路径数和为sum,那么就有sum-ai条路径跨过i点,2ai-sum条路径经过ai父亲。设需要通过该点的路径数最大的子树需要通过该点的路径数为mx,那么如果sum-mx<sum-ai则无解。

#include <bits/stdc++.h>
using namespace std;
#define N 110000
int n,tot,root;
int a[N],head[N],nex[N<<1],to[N<<1];
int du[N],rem[N];
void add(int x,int y)
{
    tot++;
    nex[tot]=head[x];head[x]=tot;
    to[tot]=y;
}
void quit(){puts("NO");exit(0);}
void dfs(int x,int y)
{
    if(du[x]==1)
    {
        rem[x]=a[x];
        return;
    }
    int mx=0,sum=0;
    for(int i=head[x];i;i=nex[i])
        if(to[i]!=y)
        {
            dfs(to[i],x);
            sum+=rem[to[i]];
            mx=max(mx,rem[to[i]]);
        }
    int v1=sum-a[x],v2=a[x]-v1;
    if(v1<0||v2<0)quit();
    if(sum-mx<v1)quit();
    rem[x]=v2;
}
int main()
{
    //freopen("tt.in","r",stdin);
    scanf("%d",&n);
    for(int i=1;i<=n;i++)
        scanf("%d",&a[i]);
    for(int i=1,x,y;i<n;i++)
    {
        scanf("%d%d",&x,&y);
        add(x,y);add(y,x);
        du[x]++;du[y]++;
    }
    if(n==2)return puts(a[1]==a[2] ? "YES" : "NO"),0;
    root=1;
    while(du[root]==1)root++;
    dfs(root,0);
    if(rem[root])quit();
    puts("YES");
    return 0;
}