首页 > 学院 > 开发设计 > 正文

UOJ86 mx的组合数

2019-11-08 18:28:01
字体:
来源:转载
供稿:网友

大概看到题的时候就会做了。好厉害的题。

组合数模质数p等于某值的方案数,很容易想到利用卢卡斯定理。然后要使p进制下每一位对出的结果的乘积在模p意义下为某值,数位dp一波就好。

暴力转移是每位p^2的,但是转移的形式是c[i*j]+=a[i]*b[j],可以考虑找原根R,这就变成了c[logRi+logRj]+=a[logRi]*b[logRj],这里NTT就好了,模数还刚好是998244353。

然后你让我写。就很麻烦了。

先是一波高精度处理。找原根并预处理阶。在阶的基础上转换卷积形式以及NTT板子。哦还有预处理一波阶乘来求组合数。最后就是数位dp。

写还有调搞得我心力憔悴。

代码:

#include <cstdio>#include <cstring>#include <iostream>#define ll long long#define AwD 998244353int p;struct bigint{ int v[105],len;}n,l,r;void read(bigint&a){ char s[35];scanf("%s",s); a.len=strlen(s); for(int i=0;i<a.len;i++) a.v[a.len-i]=s[i]-'0';}int Operator%(bigint a,int b){ int res=0; for(int i=a.len;i;i--) res=(res*10+a.v[i])%b; return res;}bigint operator/(bigint a,int b){ for(int i=a.len;i;i--){ a.v[i-1]+=a.v[i]%b*10; a.v[i]/=b; } while(a.len>1&&!a.v[a.len]) a.len--; return a;}bigint incr(bigint a){ a.v[1]++; for(int i=1;i<a.len;i++) if(a.v[i]>=10){ a.v[i]-=10;a.v[i+1]++; } if(a.v[a.len]>=10){ a.v[a.len]-=10;a.v[++a.len]=1; } return a;}bool zero(bigint a){ return a.len==1&&!a.v[1];}void trs(bigint&a){ int b[105],n=0; while(!zero(a)){ b[++n]=a%p; a=a/p; } for(int i=1;i<=n;i++) a.v[i]=b[i]; a.len=n;}void exp0(bigint&a,int L){ for(int i=a.len+1;i<=L;i++) a.v[i]=0;}int kth[30005],rk[30005],R;void findR(){ R=1; while(1){ for(int i=1;i<p;i++) rk[i]=0; bool flag=kth[0]=1; for(int i=1;i<p;i++){ kth[i]=kth[i-1]*R%p; if(rk[kth[i]]){ flag=0;break; } rk[kth[i]]=i; } if(!flag){ R++;continue; } rk[1]=0; return; }}ll pw(ll x,ll y){ if(y<0) y+=AwD-1; if(!y) return 1; ll res=pw(x,y>>1); (res*=res)%=AwD; if(y&1) (res*=x)%=AwD; return res;}ll ntt(ll*a,int n,int d){ int i,j,k; ll w,t,u,v; for(i=(n>>1),j=1;j<n;j++){ if(i<j) t=a[i],a[i]=a[j],a[j]=t; for(k=(n>>1);i&k;i^=k,k>>=1);i^=k; } for(k=2;k<=n;k<<=1){ w=pw(3,(AwD-1)/k*d); for(i=0;i<n;i+=k){ t=1; for(j=i;j<i+(k>>1);j++){ u=a[j];v=t*a[j+(k>>1)]%AwD; a[j]=(u+v)%AwD;a[j+(k>>1)]=(u-v+AwD)%AwD;t=t*w%AwD; } } }}ll t1[65555],t2[65555];void multi(int*a,int*b,int*res){ int res0=0; for(int i=0;i<p;i++){ res0=(res0+1ll*a[i]*b[0])%AwD; if(i) res0=(res0+1ll*a[0]*b[i])%AwD; } //for(int i=0;i<p;i++) PRintf("%d ",a[i]);printf("!!/n"); //for(int i=0;i<p;i++) printf("%d ",b[i]);printf("!!/n"); for(int i=1;i<p;i++) t1[rk[i]]=a[i],t2[rk[i]]=b[i]; int l=1,invl;while(l<p-1) l<<=1;invl=pw(l<<=1,-1); for(int i=p-1;i<l;i++) t1[i]=t2[i]=0; //for(int i=0;i<l;i++) printf("%lld ",t1[i]);printf("!!/n"); //for(int i=0;i<l;i++) printf("%lld ",t2[i]);printf("!!/n"); ntt(t1,l,1);ntt(t2,l,1); for(int i=0;i<l;i++) (t1[i]*=t2[i])%=AwD; ntt(t1,l,-1); for(int i=0;i<l;i++) t1[i]=t1[i]*invl%AwD; //for(int i=0;i<l;i++) printf("%lld ",t1[i]);printf("!!/n"); for(int i=1;i<p;i++) res[i]=0; for(int i=0;i<l;i++) (res[kth[i%(p-1)]]+=t1[i])%=AwD; res[0]=res0; //for(int i=0;i<p;i++) printf("%d ",res[i]);printf("!!/n");}int fac[30005],inv[30005];int C(int n,int m){ return n<m?0:fac[n]*inv[m]%p*inv[n-m]%p;}int L,dp[105][30005],tmp[30005],cur;void init(){ fac[0]=1;for(int i=1;i<p;i++) fac[i]=fac[i-1]*i%p; inv[p-1]=kth[p-1-rk[fac[p-1]]];for(int i=p-1;i;i--) inv[i-1]=inv[i]*i%p; for(int i=0;i<p;i++) dp[0][i]=i==1; for(int i=1;i<L;i++){ for(int j=0;j<p;j++) tmp[j]=0; for(int j=0;j<p;j++) tmp[C(j,n.v[i])]++; multi(dp[i-1],tmp,dp[i]); }}void solve(bigint a,int*ans){ //for(int i=0;i<L;i++,printf("/n")) for(int j=0;j<p;j++) printf("%d ",dp[i][j]); //printf("solving.../n"); for(int i=0;i<p;i++) ans[i]=0; cur=1; for(int i=L;i;i--){ for(int j=0;j<p;j++) tmp[j]=0; for(int j=0;j<a.v[i];j++) tmp[C(j,n.v[i])]++; //for(int j=0;j<p;j++) printf("%d ",tmp[j]);printf("/n"); multi(tmp,dp[i-1],tmp); //for(int j=0;j<p;j++) printf("%d ",tmp[j]);printf("::/n"); for(int j=0;j<p;j++) (ans[j*cur%p]+=tmp[j])%=AwD; (cur*=C(a.v[i],n.v[i]))%=p; } //printf("solved/n");}int ans1[30005],ans2[30005]; int main(){ scanf("%d",&p);read(n);read(l);read(r);r=incr(r); //printf("reading ok/n"); trs(n);trs(l);trs(r); //printf("transforming ok/n"); L=std::max(n.len,std::max(l.len,r.len)); //printf("L=%d/n",L); exp0(n,L);exp0(l,L);exp0(r,L); //for(int i=L;i;i--) printf("%d ",n.v[i]);printf("/n"); //for(int i=L;i;i--) printf("%d ",l.v[i]);printf("/n"); //for(int i=L;i;i--) printf("%d ",r.v[i]);printf("/n"); //printf("0-expanding ok/n"); findR(); //printf("%d/n",R); //for(int i=0;i<p;i++) printf("%d ",kth[i]);printf("/n"); //for(int i=1;i<p;i++) printf("%d ",rk[i]);printf("/n"); //printf("---/n"); init();solve(r,ans1);solve(l,ans2); for(int i=0;i<p;i++) printf("%d/n",(ans1[i]-ans2[i]+AwD)%AwD);}


发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表