一.题目大意
给出一个长度为$$N$$的整数序列$$a$$,求有多少$$1$$到$$N$$的排列$$p$$满足如下条件:
$$\bullet$$ 对于每个$$1\le i\le N$$满足$$p_i=a_i$$,$$p_{p_i}=a_i$$ 中至少一个。
答案对$$10^9+7$$ 取模。
$$1\le N\le 10^5$$
$$1\le a_i\le N$$
二.解题报告
考虑对于一个排列$$p$$,连一条从$$i$$到$$p_i$$的边。形成的图由一些环组成。考虑其中一个环。
将原来$$i$$到$$p_i$$的边改为$$i$$到$$p_i$$或$$p_{p_i}$$。
情况1:所有原来的边保持不变。则仍然为一个环。
情况2:环大小为奇数且不为$$1$$,原来$$i$$到$$p_i$$的边变为$$i$$到$$p_{p_i}$$。此时仍然为一个环。
情况3:环大小为偶数,原来$$i$$到$$p_i$$的边变为$$i$$到$$p_{p_i}$$。此时环分裂为两个。
情况4:其他情况时,形成一棵基环内向树。并且以环上每个点为根的树都为一条链。
所求即为有多少排列的图经过转化可以与$$i$$与$$a_i$$连边得到的图相同。
考虑$$i$$与$$a_i$$连边得到的图,将相同大小的环在一起处理。
可以将任意两个相同大小的环合并。
也可以单独成为一个环:对于大小为奇数且不为$$1$$的环有两种方法,对于大小为偶数或$$1$$的环有一种方法。
可以枚举合并环的对数,用组合数计算方案数。
对于基环树,如果存在以环上点为根的树不为一条链那么无解。否则对于一条链有两种放置方法:
方法1:链上的第二个点与根相距$$1$$
方法2:链上的第二个点与根相距$$2$$
一条链放置后的结尾不能超过下一条链的根。
通过当前链的根与下一个根的距离可以确定当前链有0种,一种或两种放法。
这样可以求出对于一个基环树的方案数。
有联通块不是环或基环树时无解。
时间复杂度$$O(n)$$
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int N=110000,mod=1000000007;
int n,ans;
int jc[N],njc[N],bir[N],nbir[N];
int a[N],fa[N],du[N],num[N],vis[N],tail[N],rem[N];
int f[N][2][2];
vector<int>vec[N],v1;
int qpow(int x,int y)
{
int ret=1;
while(y)
{
if(y&1)ret=(ll)ret*x%mod;
x=(ll)x*x%mod;y>>=1;
}
return ret;
}
void init()
{
for(int i=1;i<=n;i++)fa[i]=i;
jc[0]=njc[0]=bir[0]=nbir[0]=1;
for(int i=1;i<=n;i++)
{
jc[i]=(ll)jc[i-1]*i%mod;
bir[i]=bir[i-1]*2%mod;
}
njc[n]=qpow(jc[n],mod-2);
nbir[n]=qpow(bir[n],mod-2);
for(int i=n-1;i>=1;i--)
{
njc[i]=(ll)njc[i+1]*(i+1)%mod;
nbir[i]=nbir[i+1]*2%mod;
}
}
int find(int x){return fa[x]==x ? x : fa[x]=find(fa[x]);}
void quit(){puts("0");exit(0);}
void ins(int x)
{
if(!tail[x])return;
if(rem[x]<tail[x])quit();
else if(rem[x]>tail[x])ans=ans*2%mod;
}
int main()
{
//freopen("tt.in","r",stdin);
scanf("%d",&n);
init();
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
if(find(i)!=find(a[i]))
fa[find(i)]=find(a[i]);
du[a[i]]++;
}
for(int i=1;i<=n;i++)
{
if(du[i]>2)quit();
vec[find(i)].push_back(i);
}
ans=1;
for(int i=1,sz;i<=n;i++)if(sz=vec[i].size())
{
int p=-1;
for(int j=0;j<sz;j++)
if(du[vec[i][j]]==2)p=vec[i][j];
if(p==-1)num[sz]++;
else
{
int t=0,p1=a[p];
vis[p]=++t;
while(p1!=p)
{
if(vis[p1])quit();
vis[p1]=++t;
p1=a[p1];
}
for(int j=0,cnt;j<sz;j++)
if(du[p1=vec[i][j]]==0)
{
cnt=0;
while(!vis[p1])
{
cnt++;
if(du[p1]==2)quit();
p1=a[p1];
}
tail[p1]=cnt;
}
int now=0;p1=p;
for(int j=1;j<=t*2;j++,p1=a[p1])
{
if(tail[p1])
rem[p1]=max(rem[p1],now+1),now=0;
else now++;
}
v1.clear();
ins(p);p1=a[p];
while(p1!=p)
{
ins(p1);
p1=a[p1];
}
}
}
for(int i=1;i<=n;i++)if(num[i])
{
int t=0,now=1;
for(int j=0;j<=num[i];j+=2)
{
int t1=(ll)jc[num[i]]*njc[num[i]-j]%mod*njc[j>>1]%mod*nbir[j>>1]%mod*now%mod;
if((i&1)&&i!=1)t1=(ll)t1*bir[num[i]-j]%mod;
t=(t+t1)%mod;
now=(ll)now*i%mod;
}
ans=(ll)ans*t%mod;
}
printf("%d\n",ans);
return 0;
}