首页 > 编程 > Python > 正文

python使用tensorflow保存、加载和使用模型的方法

2020-02-22 23:04:45
字体:
来源:转载
供稿:网友

使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍:

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

我对这篇文章进行了整理和汇总。

首先是模型的保存。直接上代码:

#!/usr/bin/env python #-*- coding:utf-8 -*- ############################ #File Name: tut1_save.py #Author: Wang  #Mail: wang19920419@hotmail.com #Created Time:2017-08-30 11:04:25 ############################  import tensorflow as tf  # prepare to feed input, i.e. feed_dict and placeholders w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1') # name is very important in restoration w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2') b1 = tf.Variable(2.0, name = 'bias1') feed_dict = {w1:[10,3], w2:[5,5]}  # define a test operation that will be restored w3 = tf.add(w1, w2) # without name, w3 will not be stored w4 = tf.multiply(w3, b1, name = "op_to_restore")  #saver = tf.train.Saver() saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1) sess = tf.Session() sess.run(tf.global_variables_initializer()) print sess.run(w4, feed_dict) #saver.save(sess, 'my_test_model', global_step = 100) saver.save(sess, 'my_test_model') #saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False) 

需要说明的有以下几点:

1. 创建saver的时候可以指明要存储的tensor,如果不指明,就会全部存下来。在这里也可以指明最大存储数量和checkpoint的记录时间。具体细节看英文博客。

2. saver.save()函数里面可以设定global_step和write_meta_graph,meta存储的是网络结构,只在开始运行程序的时候存储一次即可,后续可以通过设置write_meta_graph = False加以限制。

3. 这个程序执行结束后,会在程序目录下生成四个文件,分别是.meta(存储网络结构)、.data和.index(存储训练好的参数)、checkpoint(记录最新的模型)。

下面是如何加载已经保存的网络模型。这里有两种方法,第一种是saver.restore(sess, 'aaaa.ckpt'),这种方法的本质是读取全部参数,并加载到已经定义好的网络结构上,因此相当于给网络的weights和biases赋值并执行tf.global_variables_initializer()。这种方法的缺点是使用前必须重写网络结构,而且网络结构要和保存的参数完全对上。第二种就比较高端了,直接把网络结构加载进来(.meta),上代码:

#!/usr/bin/env python #-*- coding:utf-8 -*- ############################ #File Name: tut2_import.py #Author: Wang  #Mail: wang19920419@hotmail.com #Created Time:2017-08-30 14:16:38 ############################  import tensorflow as tf sess = tf.Session() new_saver = tf.train.import_meta_graph('my_test_model.meta') new_saver.restore(sess, tf.train.latest_checkpoint('./')) print sess.run('w1:0')             
发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表