首页 > 编程 > Python > 正文

pytorch打印网络结构的实例

2019-11-25 11:56:25
字体:
来源:转载
供稿:网友

最简单的方法当然可以直接print(net),但是这样网络比较复杂的时候效果不太好,看着比较乱;以前使用caffe的时候有一个网站可以在线生成网络框图,tensorflow可以用tensor board,keras中可以用model.summary()、或者plot_model()。pytorch没有这样的API,但是可以用代码来完成。

(1)安装环境:graphviz

conda install -n pytorch python-graphviz

或:

sudo apt-get install graphviz

或者从官网下载,按此教程。

(2)生成网络结构的代码:

def make_dot(var, params=None):  """ Produces Graphviz representation of PyTorch autograd graph  Blue nodes are the Variables that require grad, orange are Tensors  saved for backward in torch.autograd.Function  Args:    var: output Variable    params: dict of (name, Variable) to add names to node that      require grad (TODO: make optional)  """  if params is not None:    assert isinstance(params.values()[0], Variable)    param_map = {id(v): k for k, v in params.items()}   node_attr = dict(style='filled',           shape='box',           align='left',           fontsize='12',           ranksep='0.1',           height='0.2')  dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))  seen = set()   def size_to_str(size):    return '('+(', ').join(['%d' % v for v in size])+')'  def add_nodes(var):    if var not in seen:      if torch.is_tensor(var):        dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')      elif hasattr(var, 'variable'):        u = var.variable        name = param_map[id(u)] if params is not None else ''        node_name = '%s/n %s' % (name, size_to_str(u.size()))        dot.node(str(id(var)), node_name, fillcolor='lightblue')      else:        dot.node(str(id(var)), str(type(var).__name__))      seen.add(var)      if hasattr(var, 'next_functions'):        for u in var.next_functions:          if u[0] is not None:            dot.edge(str(id(u[0])), str(id(var)))            add_nodes(u[0])      if hasattr(var, 'saved_tensors'):        for t in var.saved_tensors:          dot.edge(str(id(t)), str(id(var)))          add_nodes(t)  add_nodes(var.grad_fn)  return dot

(3)打印网络结构:

import torch from torch.autograd import Variable import torch.nn as nn from graphviz import Digraph class CNN(nn.module):  def __init__(self):   ******   def forward(self,x):   ******   return out *****************************def make_dot(): #复制上面的代码***************************** if __name__ == '__main__':   net = CNN()   x = Variable(torch.randn(1, 1, 1024,1024))   y = net(x)   g = make_dot(y)   g.view()    params = list(net.parameters())   k = 0   for i in params:     l = 1     print("该层的结构:" + str(list(i.size())))     for j in i.size():       l *= j     print("该层参数和:" + str(l))     k = k + l   print("总参数数量和:" + str(k))

(4)结果展示(例如这是一个resnet block类型的网络):

以上这篇pytorch打印网络结构的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持武林网。

发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表