In tensorflow在mnist集上的使用示例(一) (part 1 of this series), we already used TensorFlow to get good recognition results on the MNIST dataset. This post goes a step further and shows how to build a more completely structured neural network program: the model-building code is packaged into four functions, inference(), loss(), training() and evaluation(), and code for visualizing training status (summaries) is added.
Original tutorial: TensorFlow Mechanics 101; Chinese version: TensorFlow运作方式入门.
This example consists of two code files. The first is mnist.py, which defines the four functions needed to build the model, inference(), loss(), training() and evaluation(); it is lower-level code that we reuse directly (a condensed sketch is shown below). The second file is fully_connected_feed.py, which contains the main logic; following the tutorial, I rewrote it in a Jupyter notebook.
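Since mnist.py is treated as a ready-made library here, the following sketch shows roughly what its four functions look like. It is condensed from the official TF 1.x Mechanics 101 mnist.py (two ReLU hidden layers plus a linear softmax layer); the _layer helper is my own shorthand for the repeated name_scope/weights/biases pattern and does not appear in the real file.

import math
import tensorflow as tf

NUM_CLASSES = 10
IMAGE_SIZE = 28
IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE

def _layer(inputs, in_units, out_units, name, activation=None):
    # Shorthand helper (not in the real mnist.py): one fully connected layer.
    with tf.name_scope(name):
        weights = tf.Variable(
            tf.truncated_normal([in_units, out_units],
                                stddev=1.0 / math.sqrt(float(in_units))),
            name='weights')
        biases = tf.Variable(tf.zeros([out_units]), name='biases')
        pre_activation = tf.matmul(inputs, weights) + biases
        return activation(pre_activation) if activation is not None else pre_activation

def inference(images, hidden1_units, hidden2_units):
    # Two ReLU hidden layers followed by a linear output layer that produces logits.
    hidden1 = _layer(images, IMAGE_PIXELS, hidden1_units, 'hidden1', tf.nn.relu)
    hidden2 = _layer(hidden1, hidden1_units, hidden2_units, 'hidden2', tf.nn.relu)
    return _layer(hidden2, hidden2_units, NUM_CLASSES, 'softmax_linear')

def loss(logits, labels):
    # Mean cross-entropy between the logits and the integer labels.
    labels = tf.to_int64(labels)
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits, name='xentropy')
    return tf.reduce_mean(cross_entropy, name='xentropy_mean')

def training(loss, learning_rate):
    # Record the loss as a scalar summary and apply one gradient-descent step.
    tf.summary.scalar('loss', loss)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    global_step = tf.Variable(0, name='global_step', trainable=False)
    return optimizer.minimize(loss, global_step=global_step)

def evaluation(logits, labels):
    # Number of examples whose true label is the top-1 prediction.
    correct = tf.nn.in_top_k(logits, labels, 1)
    return tf.reduce_sum(tf.cast(correct, tf.int32))

With those pieces in hand, the notebook version of fully_connected_feed.py looks like this: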
import input_data
import mnist
import tensorflow as tf
import time
import os.path

# Build a dict of parameters.
FLAGS = {'learning_rate': 0.01,
         'max_steps': 2000,
         'hidden1': 128,
         'hidden2': 32,
         'batch_size': 100,
         'input_data_dir': 'MNIST_data/',
         'log_dir': 'logs_fully_connected_feed/',
         'fake_data': False}

# Import data
data_sets = input_data.read_data_sets(FLAGS['input_data_dir'], FLAGS['fake_data'])

Output:

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

def placeholder_inputs(batch_size):
    """Generate placeholder variables to represent the input tensors."""
    images_placeholder = tf.placeholder(tf.float32,
                                        shape=(batch_size, mnist.IMAGE_PIXELS))
    labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))
    return images_placeholder, labels_placeholder

def fill_feed_dict(data_set, images_pl, labels_pl):
    """Fills the feed_dict for training the given step."""
    images_feed, labels_feed = data_set.next_batch(FLAGS['batch_size'],
                                                   FLAGS['fake_data'])
    feed_dict = {
        images_pl: images_feed,
        labels_pl: labels_feed,
    }
    return feed_dict

def do_eval(sess, eval_correct, images_placeholder, labels_placeholder, data_set):
    """Runs one evaluation against the full epoch of data."""
    # And run one epoch of eval.
    true_count = 0  # Counts the number of correct predictions.
    steps_per_epoch = data_set.num_examples // FLAGS['batch_size']
    num_examples = steps_per_epoch * FLAGS['batch_size']
    for step in range(steps_per_epoch):
        feed_dict = fill_feed_dict(data_set, images_placeholder, labels_placeholder)
        true_count += sess.run(eval_correct, feed_dict=feed_dict)
    precision = float(true_count) / num_examples
    print('  Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' %
          (num_examples, true_count, precision))

# Generate placeholders for the images and labels.
images_placeholder, labels_placeholder = placeholder_inputs(FLAGS['batch_size'])
# Build a Graph that computes predictions from the inference model.
logits = mnist.inference(images_placeholder, FLAGS['hidden1'], FLAGS['hidden2'])
# Add to the Graph the Ops for loss calculation.
loss = mnist.loss(logits, labels_placeholder)
# Add to the Graph the Ops that calculate and apply gradients.
train_op = mnist.training(loss, FLAGS['learning_rate'])
# Add the Op to compare the logits to the labels during evaluation.
eval_correct = mnist.evaluation(logits, labels_placeholder)
# Build the summary Tensor based on the TF collection of Summaries.
summary = tf.summary.merge_all()
# Add the variable initializer Op.
init = tf.global_variables_initializer()
# Create a saver for writing training checkpoints.
saver = tf.train.Saver()
# Create a session for running Ops on the Graph.
sess = tf.Session()
# Instantiate a SummaryWriter to output summaries and the Graph.
summary_writer = tf.summary.FileWriter(FLAGS['log_dir'], sess.graph)
# Run the Op to initialize the variables.
sess.run(init)

# Start the training loop.
for step in range(FLAGS['max_steps']):
    start_time = time.time()
    # Fill a feed dictionary with the actual set of images and labels
    # for this particular training step.
    feed_dict = fill_feed_dict(data_sets.train, images_placeholder, labels_placeholder)
    # Run one step of the model.
    _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
    duration = time.time() - start_time
    # Write the summaries and print an overview fairly often.
    if step % 100 == 0:
        # Print status to stdout.
        print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
        # Update the events file.
        summary_str = sess.run(summary, feed_dict=feed_dict)
        summary_writer.add_summary(summary_str, step)
        summary_writer.flush()
    # Save a checkpoint and evaluate the model periodically.
    if (step + 1) % 1000 == 0 or (step + 1) == FLAGS['max_steps']:
        checkpoint_file = os.path.join(FLAGS['log_dir'], 'model.ckpt')
        saver.save(sess, checkpoint_file, global_step=step)
        # Evaluate against the training set.
        print('Training Data Eval:')
        do_eval(sess, eval_correct, images_placeholder, labels_placeholder,
                data_sets.train)
        # Evaluate against the validation set.
        print('Validation Data Eval:')
        do_eval(sess, eval_correct, images_placeholder, labels_placeholder,
                data_sets.validation)
        # Evaluate against the test set.
        print('Test Data Eval:')
        do_eval(sess, eval_correct, images_placeholder, labels_placeholder,
                data_sets.test)

Output:

Step 0: loss = 2.34 (0.406 sec)
Step 100: loss = 2.18 (0.004 sec)
Step 200: loss = 1.96 (0.004 sec)
Step 300: loss = 1.67 (0.004 sec)
Step 400: loss = 1.41 (0.004 sec)
Step 500: loss = 0.96 (0.004 sec)
Step 600: loss = 0.98 (0.004 sec)
Step 700: loss = 0.92 (0.004 sec)
Step 800: loss = 0.58 (0.004 sec)
Step 900: loss = 0.49 (0.004 sec)
Training Data Eval:
  Num examples: 55000  Num correct: 47612  Precision @ 1: 0.8657
Validation Data Eval:
  Num examples: 5000  Num correct: 4369  Precision @ 1: 0.8738
Test Data Eval:
  Num examples: 10000  Num correct: 8746  Precision @ 1: 0.8746
Step 1000: loss = 0.52 (0.015 sec)
Step 1100: loss = 0.48 (0.110 sec)
Step 1200: loss = 0.61 (0.004 sec)
Step 1300: loss = 0.34 (0.004 sec)
Step 1400: loss = 0.53 (0.004 sec)
Step 1500: loss = 0.30 (0.004 sec)
Step 1600: loss = 0.53 (0.004 sec)
Step 1700: loss = 0.37 (0.004 sec)
Step 1800: loss = 0.48 (0.004 sec)
Step 1900: loss = 0.39 (0.004 sec)
Training Data Eval:
  Num examples: 55000  Num correct: 49317  Precision @ 1: 0.8967
Validation Data Eval:
  Num examples: 5000  Num correct: 4510  Precision @ 1: 0.9020
Test Data Eval:
  Num examples: 10000  Num correct: 9030  Precision @ 1: 0.9030
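Two follow-ups are worth knowing about the status-visualization and checkpoint code above: the events file written by the SummaryWriter can be browsed with TensorBoard, and the checkpoints written by the Saver can be reloaded later. The sketch below assumes it runs as one more cell in the same notebook (so FLAGS, saver, eval_correct, the placeholders and data_sets are still defined); tf.train.latest_checkpoint and saver.restore are standard TF 1.x APIs.

# Browse the recorded loss summary and the graph by running, in a shell:
#   tensorboard --logdir=logs_fully_connected_feed/

# Reload the most recent checkpoint into a fresh session and repeat the test eval.
ckpt = tf.train.latest_checkpoint(FLAGS['log_dir'])
if ckpt is not None:
    with tf.Session() as restored_sess:
        saver.restore(restored_sess, ckpt)  # restores the variables written by saver.save()
        print('Test Data Eval (restored model):')
        do_eval(restored_sess, eval_correct, images_placeholder,
                labels_placeholder, data_sets.test)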