学习深度学习一年多了,一个感觉是实验结果的好坏在很大程度上取决于数据;数据对于深度学习算法十分关键,数据集的大小影响着模型的精度和泛化能力,好的数据处理技巧锦上添花,而合适的数据输入输出方法使Tensor
“流动”得更加顺畅更好的发挥机器的性能,为模型的训练节约时间。许多情况下,对于数据的处理花的时间往往比模型的修改花的时间多,因此本文专门针对数据处理(图像类)进行一次梳理归纳,利人利己。
TensorFlow有三种数据读取方式: 1. 预先加载数据 2. 使用python将数据feed
到Tensor
中 3. 从文件读取数据
预先加载数据
第一种方式直接把数据写在代码里进行运算,这种方式在一些简单的演示算法中很常见
import tensorflow as tfa = tf.constant(3.0)b = tf.constant(4.0)c = a + bwith tf.session() as sess: PRint(sess.run(c))feed
第二种方法是利用tf.placeholder
提供一个数据输入的接口,在启动计算图时将数据通过这个接口输入计算图
import tensorfow as tftrain_images = ...train_labels = ...X = tf.placeholder([], dtypes=tf.float32)Y = tf.placeholder([], dtypes=tf.uint8)train_op = ...with tf.Session() as sess: sess.run(train_op, feed_dict={X: train_images, Y: train_labels})从文件读取
第三种方法从文件中读取,涉及到数据的转换和读取两个方面,数据的转换又有各种格式可以选择,这里简单列举几个常用的数据存储与读取方法,最后介绍TensorFlow
标准存储格式TFRecord
的转换和读取方法。
1、 .pkl
.pkl
文件是一种特殊的串行化存储的二进制格式文件,可以存储大部分常见的Python对象,使用起来十分方便
import pickledef data_to_file(image_data, label): with open('somedata.pkl', 'wb') as f: pickle.dump([image_data, label], f)def file_to_data(pkl_file): with open(pkl_file, 'rb') as f: image_data, label = pickle.load(f) 实际应用案例可参考上一篇博客
2、 TFRecord
如上所述,TFRecord是TensorFlow的标准存储格式,尽管这种数据格式的转换方式不是很直观,不是一两行代码就能搞定的,但是在使用时TensorFlow设计了一套高效的API来专门处理这种文件,配合TensorFlow图像处理的API使其在数据处理方面就显得更有优势了。下面的代码简单的展示了怎样将一张图片转换成.tfrecord
文件,以及从文件解析出图片。
"""Converts image data to TFRecords file format with Example protos."""from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionfrom PIL import Imageimport tensorflow as tf# Input must be type int or long.def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))# Input must be type bytesdef _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))def convert_to(data, name): filename = name + '.tfrecords' print('Writing', filename) writer = tf.python_io.TFRecordWriter(filename) # when there is one picture example = tf.train.Example(features=tf.train.Features(feature={ 'label': _int64_feature(data[1]), 'image_raw': _bytes_feature(data[0])})) writer.write(example.SerializeToString()) # # TODO # # when there is many pictures # for index in range(num_examples): # .... writer.close()def main_1(): # Get the data. images = Image.open('image_0006.jpg') images.resize((224,224)) image_raw = images.tobytes() labels = 0 data_sets = [image_raw, labels] # TODO: for large scale image dataset, # a better way is reading while saving # Convert to Examples and write the result to TFRecords. convert_to(data_sets, 'test')def main_2(): # Get the data. img_file = tf.read_file('image_0006.jpg') images = tf.image.decode_jpeg(img_file) images = tf.image.resize_images(images, [224,224]) with tf.Session() as sess: image_raw = sess.run(tf.cast(images, tf.uint8)) image_raw = image_raw.tobytes() labels = 0 data_sets = [image_raw, labels] convert_to(data_sets, 'test')if __name__ == '__main__': main_2()from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionfrom PIL import Imageimport tensorflow as tfdef read_and_decode(filename): filename_queue = tf.train.string_input_producer([filename]) reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) features = tf.parse_single_example(serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img_raw' : tf.FixedLenFeature([], tf.string), }) image = tf.image.decode_jpeg(features['img_raw'], channels=3) image = tf.image.convert_image_dtype(image, dtype=tf.float32) label = tf.cast(features['label'], tf.int32) return image, label 说明: 1、代码中的example
是样本的意思哦 2、本代码只展示了将一张图转化为tfrecord
格式 3、图片的编解码、裁剪、缩放、旋转等操作,TensorFlow都有自己的函数可以代替第三方库的功能,根据习惯自己选择。
References:
tensorflow/g3doc/how_tos/reading_data/index.mdhttp://blog.csdn.net/u012759136/article/details/52232266