首页 > 学院 > 开发设计 > 正文

bzoj 3992: [SDOI2015]序列统计 NTT

2019-11-06 07:55:32
字体:
来源:转载
供稿:网友

题意

小C有一个集合S,里面的元素都是小于M的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为N的数列,数列中的每个数都属于集合S。 小C用这个生成器生成了许多这样的数列。但是小C有一个问题需要你的帮助:给定整数x,求所有可以生成出的,且满足数列中所有数的乘积mod M的值等于x的不同的数列的有多少个。小C认为,两个数列{Ai}和{Bi}不同,当且仅当至少存在一个整数i,满足Ai≠Bi。另外,小C认为这个问题的答案可能很大,因此他只需要你帮助他求出答案mod 1004535809的值就可以了。 对于全部的数据,1<=N<=109,3<=M<=8000,M为质数,1<=x<=M-1,输入数据保证集合S中元素不重复

分析

NTT处女题。。。调了超久发现是把n次单位根算成n/2次单位根了。 可以先找到M的原根g,那么gs{1<=s<m−1}模m下互不同余,那么就可以把s数组的成绩转换成g的幂次的和的形式。 求原根的方法是,设m的标准分解式为pq11pq22...pqnn,那么若g为m的原根则必满足gmpk均不为1. 然后将s数组转换成一个多项式,因为带取模,所以直接上NTT即可。 NTT(快速数论变换): 这里写图片描述

代码

#include<iostream>#include<cstdio>#include<cstdlib>#include<cstring>#include<algorithm>#include<cmath>#define MOD 1004535809#define MAXN 100005#define LL long longusing namespace std;int n,m,x,S,top,sta[MAXN],a[MAXN],b[MAXN],c[MAXN],ans[MAXN],g,ny,ind[MAXN],rev[MAXN],N,lg;int ksm(int x,int y,int mo){ int ans=1; while (y) { if (y&1) ans=(LL)ans*x%mo; y>>=1;x=(LL)x*x%mo; } return ans;}int get_g(int m){ int tmp=m-1; for (int i=2;i<=m;i++) if (tmp%i==0) { sta[++top]=i; while (tmp%i==0) tmp/=i; } for (int i=2;i<=m;i++) { int j; for (j=1;j<=top;j++) if (ksm(i,(m-1)/sta[j],m)==1) break; if (j==top+1) return i; }}void NTT(int *a,int f){ for (int i=0;i<N;i++) if (i<rev[i]) swap(a[i],a[rev[i]]); for (int i=1;i<N;i<<=1) { int wn=ksm(3,f==1?(MOD-1)/i/2:MOD-1-(MOD-1)/i/2,MOD); for (int j=0;j<N;j+=(i<<1)) { int w=1; for (int k=0;k<i;k++) { int u=a[j+k],v=(LL)w*a[j+k+i]%MOD; a[j+k]=(u+v)%MOD;a[j+k+i]=(u-v+MOD)%MOD; w=(LL)w*wn%MOD; } } } if (f==-1) for (int i=0;i<N;i++) a[i]=(LL)a[i]*ny%MOD;}void mul(int *a,int *bb,int *cc){ for (int i=0;i<N;i++) b[i]=bb[i],c[i]=cc[i]; NTT(b,1);NTT(c,1); for (int i=0;i<N;i++) a[i]=(LL)b[i]*c[i]%MOD; NTT(a,-1); for (int i=m-1;i<N;i++) a[i-m+1]=(a[i-m+1]+a[i])%MOD,a[i]=0;}void solve(int *a){ while (n) { if (n&1) mul(ans,ans,a); mul(a,a,a);n>>=1; }}void bitrev(){ for (int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(lg-1));}int main(){ scanf("%d%d%d%d",&n,&m,&x,&S); for (N=1;N<=m*2;N*=2,lg++); g=get_g(m); for (int i=1,w=g;i<m-1;i++,w=(LL)w*g%m) ind[w]=i; for (int i=1;i<=S;i++) { int x; scanf("%d",&x); if (x==0) continue; a[ind[x]]=1; } ny=ksm(N,MOD-2,MOD); ans[0]=1; bitrev();solve(a); PRintf("%d",ans[ind[x]]); return 0;}
发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表