题解:trie+可持久化trie
今天的互测题,感觉是道不错的字符串题,不过非官方数据较水,可以用hash过掉。
首先对于读入的字符串排序,然后按照顺序加入trie,对于trie中的每个节点都维护一下排序后的那一段区间可以到达这个点(就是那一段的前缀可以匹配到这里)。
然后按照排好的顺序将每个串反过来,建立可持久化trie。
对于每个询问,我们用前缀在trie上匹配,得到一段合法的区间[l,r],然后在可持久化trie上查询[l,r]区间中可以匹配上的后缀的数量即可。
#include<iostream>#include<cstring>#include<cstdio>#include<algorithm>#include<cmath>#define N 2000003using namespace std;int n,cnt,m,sz,tot,x[N],b[N],c[N]; int a[N],ch[N][27],ch1[N][27],size[N],root[N],ls[N],rs[N],l[N],r[N];char s[N],s1[N],s2[N];int cmp(int x,int y){ int len=min(r[x]-l[x]+1,r[y]-l[y]+1); for (int i=0;i<len;i++) if (a[l[x]+i]!=a[l[y]+i]) return a[l[x]+i]<a[l[y]+i]; return r[x]-l[x]<r[y]-l[y];}void insert(int l,int r,int j){ int now=0; for (int i=l;i<=r;i++) { int x=a[i]; if (!ch[now][x]) ch[now][x]=++sz; now=ch[now][x]; ls[now]=min(ls[now],j); rs[now]=max(rs[now],j); }}void buildtree(int i,int l,int r){ int PRe=root[i-1]; root[i]=++tot; int now=root[i]; for (int i=r;i>=l;i--) { int x=a[i]; size[now]=size[pre]+1; ch1[now][x]=++tot; for (int j=1;j<=26;j++) if (j!=x) ch1[now][j]=ch1[pre][j]; now=ch1[now][x]; pre=ch1[pre][x]; } size[now]=size[pre]+1;}int get_pos(){ int len=strlen(s1+1); int now=0; for (int i=1;i<=len;i++) { int x=b[i]; now=ch[now][x]; if (!now) return -1; } return now;}int find(int i,int j){ int len=strlen(s2+1); bool pd=false; for (int k=len;k>=1;k--) { int x=c[k]; if (!pd&&size[ch1[i][x]]==size[ch1[j][x]]||!size[ch1[j][x]]) return 0; //cout<<size[ch1[i][x]]<<" "<<size[ch1[j][x]]<<endl; if (!pd) i=ch1[i][x]; j=ch1[j][x]; if(!j) return 0; if (!i) pd=true; //if (pd) cout<<"!"<<endl; } if (pd) return size[j]; return size[j]-size[i];}int main(){ freopen("xiba.in","r",stdin); freopen("xiba.out","w",stdout); scanf("%d",&n); for (int i=1;i<=n;i++) { scanf("%s",s+1); int len=strlen(s+1); for (int j=1;j<=len;j++) a[++cnt]=s[j]-'a'+1; l[i]=r[i-1]+1; r[i]=cnt; } for (int i=1;i<=n;i++) x[i]=i; sort(x+1,x+n+1,cmp); memset(ls,127/3,sizeof(ls)); for (int i=1;i<=n;i++) { int t=x[i]; insert(l[t],r[t],i); } for (int i=1;i<=n;i++) buildtree(i,l[x[i]],r[x[i]]); scanf("%d",&m); int lastans=0; for (int i=1;i<=m;i++) { scanf("%s%s",s1+1,s2+1); int len=strlen(s1+1); int len1=strlen(s2+1); for (int j=1;j<=len;j++) b[j]=(s1[j]-'a'+1+lastans-1)%26+1; for (int j=1;j<=len1;j++) c[j]=(s2[j]-'a'+1+lastans-1)%26+1; int t=get_pos(); if (t==-1) { lastans=0; printf("0/n"); continue; } printf("%d/n",(lastans=find(root[ls[t]-1],root[rs[t]]))); }}
新闻热点
疑难解答