题目链接
题目大意:给定一个值S和一棵树。在树的每个节点有一个正整数,问有多少条路径的节点总和达到S。路径中节点的深度必须是升序的。假设节点1是根节点,根的深度是0,它的儿子节点的深度为1。路径不必一定从根节点开始。
题解:倍增+二分O(nlog^2n),(Orz hzwer神奇做法)
#include <iostream>#include <cstdio>#include <algorithm>#include <cstring> using namespace std;const int M=100005;int n,s,t;int head[M],w[M],f[M][22],l[M][22];struct edge{ int to,nex;}e[M*2];int read(){ int x=0,f=1;char c=getchar(); while(c>'9'||c<'0') {if(c=='-') f=-1; c=getchar();} while(c>='0'&&c<='9') x=x*10+c-48,c=getchar(); return x*f;}void add(int i,int j){ e[t].to=j; e[t].nex=head[i]; head[i]=t++;}void dfs(int x){ for(int i=head[x];i!=-1;i=e[i].nex){ int v=e[i].to; if(v!=f[x][0]){ f[v][0]=x; l[v][0]=w[x]; dfs(v); } }}void st(){ for(int j=1;j<=20;j++) for(int i=1;i<=n;i++) f[i][j]=f[f[i][j-1]][j-1], l[i][j]=l[i][j-1]+l[f[i][j-1]][j-1];}int get_path(int x,int dep){ int ret=0; for(int i=18;i>=0;i--) { if(dep>=(1<<i)){ dep-=(1<<i); ret+=l[x][i]; x=f[x][i]; } } return ret;}bool solve(int x){ int l=0,r=10000,mid; while(l<=r) { mid=(l+r)>>1; int len=get_path(x,mid)+w[x]; if(len>s) r=mid-1; else if(len<s) l=mid+1; else return 1; } return 0;}void work(){ int tot=0; dfs(1); st(); for(int i=1;i<=n;i++) tot+=solve(i); cout<<tot<<endl;}void init(){ int x,y; memset(head,-1,sizeof(head)); cin>>n>>s; for(int i=1;i<=n;i++) w[i]=read(); for(int i=1;i<n;i++){ x=read();y=read(); add(x,y),add(y,x); }}int main(){ init(); work(); return 0;}#include <iostream>#include <cstdio>#include <cstring>#include <set>#include <algorithm>#define inf 2000000000using namespace std;const int M=100005;int n,s,cnt,ans,t;int w[M],head[M],sum[M];multiset<int> st;struct edge{ int to,nex;}e[M*2];int read(){ int x=0,f=1;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();} return x*f;}void add(int i,int j){ e[t].to=j; e[t].nex=head[i]; head[i]=t++;}void dfs(int x,int fa){ if(st.find(sum[x]-s)!=st.end()) ans++; st.insert(sum[x]); for(int i=head[x];i!=-1;i=e[i].nex) { int v=e[i].to; if(v!=fa){ sum[v]=sum[x]+w[v]; dfs(v,x); } } st.erase(sum[x]);}void work(){ dfs(1,0); PRintf("%d/n",ans);}void init(){ int u,v; memset(head,-1,sizeof(head)); cin>>n>>s; st.insert(0); for(int i=1;i<=n;i++) w[i]=read(); for(int i=1;i<n;i++) { u=read(),v=read(); add(u,v),add(v,u); } sum[1]=w[1];}int main(){ init(); work(); return 0;}新闻热点
疑难解答