There are three main methods of getting data into a TensorFlow program:
Feeding: Python code provides the data when running each step.
Reading from files: an input pipeline reads the data from files at the beginning of a TensorFlow graph.
Preloaded data: a constant or variable in the TensorFlow graph holds all the data (for small data sets).

TensorFlow's feed mechanism lets you inject data into any Tensor in a computation graph. A Python computation can thus feed data directly into the graph.
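To illustrate that any tensor, not only a placeholder, can receive fed data, here is a minimal sketch. It assumes the TF 1.x graph-and-session API used throughout this article, reached through `tensorflow.compat.v1` so that it also runs on TF 2.x; the tensors `a` and `b` are hypothetical. A fed value replaces the tensor for that single `run()` call only.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: run in graph mode, as in TF 1.x

a = tf.constant(3.0, name="a")  # a constant tensor
b = a * 2.0                     # an intermediate tensor computed from a

with tf.Session() as sess:
    normal = sess.run(b)                          # uses the constant a = 3.0
    fed_const = sess.run(b, feed_dict={a: 10.0})  # a is replaced by 10.0 for this call
    fed_mid = sess.run(b, feed_dict={b: 1.5})     # b itself is replaced by 1.5

print(normal, fed_const, fed_mid)
```

Because the feeds apply per call, the third `run()` sees the original constant again for `a`; only `b` is overridden.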
Supply feed data through the feed_dict argument to a run() or eval() call that initiates computation.
```python
import tensorflow as tf

with tf.Session():  # installs a default session, which eval() uses
    inputs = tf.placeholder(tf.float32)
    classifier = ...  # model built on `inputs` (elided in the original)
    # my_python_preprocessing_fn() stands for your own Python preprocessing code.
    print(classifier.eval(feed_dict={inputs: my_python_preprocessing_fn()}))
```

While you can replace any Tensor with feed data, including variables and constants, the best practice is to use a placeholder op node. A placeholder exists solely to serve as the target of feeds. It is not initialized and contains no data. A placeholder generates an error if it is executed without a feed, so you won't forget to feed it.
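The two entry points mentioned above, `Session.run()` and `Tensor.eval()`, take the same `feed_dict`. A minimal sketch, again assuming the TF 1.x API through `tensorflow.compat.v1`; the names `x` and `total` are hypothetical. It also shows that a placeholder keeps nothing between calls: running it again without a feed fails.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: TF 1.x-style graph mode

x = tf.placeholder(tf.float32, shape=[None])  # 1-D input of any length
total = tf.reduce_sum(x)

with tf.Session() as sess:
    # Session.run() receives the feed through its feed_dict argument.
    via_run = sess.run(total, feed_dict={x: np.array([1.0, 2.0, 3.0])})
    # Tensor.eval() accepts the same feed_dict and runs in the default
    # session, which the `with` block has installed.
    via_eval = total.eval(feed_dict={x: [4.0, 5.0]})
    # The earlier feeds are not stored, so an unfed run raises an error.
    try:
        sess.run(total)
        unfed_failed = False
    except tf.errors.InvalidArgumentError:
        unfed_failed = True

print(via_run, via_eval, unfed_failed)
```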
An example using placeholder and feeding to train on MNIST data can be found in tensorflow/examples/tutorials/mnist/fully_connected_feed.py, and is described in the MNIST tutorial.
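```python
import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(1024, 1024))
y = tf.matmul(x, x)

with tf.Session() as sess:
    try:
        print(sess.run(y))  # ERROR: will fail because x was not fed.
    except tf.errors.InvalidArgumentError as e:
        print("Expected failure:", e.message)

    rand_array = np.random.rand(1024, 1024)
    print(sess.run(y, feed_dict={x: rand_array}))  # Will succeed.
```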
tf.placeholder(dtype, shape=None, name=None)
Args:
dtype: The type of elements in the tensor to be fed.
shape: The shape of the tensor to be fed (optional). If the shape is not specified, you can feed a tensor of any shape.
name: A name for the operation (optional).
Returns:
A Tensor that may be used as a handle for feeding a value, but not evaluated directly.
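The `shape` and `name` arguments affect what a feed may look like. A minimal sketch, assuming the TF 1.x API through `tensorflow.compat.v1`; the placeholders `any_shape` and `batch` are hypothetical. With `shape=None` any shape is accepted; with `shape=(None, 2)` the first dimension is free but the second must be 2, and an incompatible value is rejected with a `ValueError` before the graph runs. The name given by `name="batch"` makes the tensor addressable as the string `"batch:0"`, which `feed_dict` also accepts as a key.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: TF 1.x-style graph mode

any_shape = tf.placeholder(tf.float32)                              # shape=None
batch = tf.placeholder(tf.float32, shape=(None, 2), name="batch")  # N rows, 2 columns

doubled = any_shape * 2.0
row_sums = tf.reduce_sum(batch, axis=1)

with tf.Session() as sess:
    scalar_out = sess.run(doubled, feed_dict={any_shape: 4.0})
    matrix_out = sess.run(doubled, feed_dict={any_shape: np.ones((2, 3))})
    # Feed by the tensor's string name instead of the Python object.
    sums = sess.run(row_sums,
                    feed_dict={"batch:0": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]})
    try:
        # Three columns do not match the declared shape (None, 2).
        sess.run(row_sums, feed_dict={batch: np.ones((2, 3))})
        shape_rejected = False
    except ValueError:
        shape_rejected = True

print(batch.name, scalar_out, matrix_out.shape, sums, shape_rejected)
```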
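The same pattern drives a training loop: each step feeds a new mini-batch, together with the dropout keep probability, into the placeholders.

```python
# Fragment from a model class's training method: self.img_pl, self.label_pl and
# self.keep_prob are placeholders; sess, optimizer, cost, num_batch, train_imgs,
# train_labels, BATCH_SIZE and dropout are defined elsewhere.
for step in range(num_batch):
    batch_imgs, batch_labels = self.nextBatch(train_imgs, train_labels, step, BATCH_SIZE)
    # Fit the model on this batch of training data.
    _, single_cost = sess.run(
        [optimizer, cost],
        feed_dict={self.img_pl: batch_imgs,
                   self.label_pl: batch_labels,
                   self.keep_prob: dropout})
```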
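The training-loop fragment above depends on a class that is not shown. For a version that runs on its own, here is a minimal sketch under stated assumptions: TF 1.x API through `tensorflow.compat.v1`, a one-variable linear regression on synthetic data instead of images, and a hypothetical `next_batch` helper standing in for `nextBatch`. Only the feeding pattern is meant to match; the model, data and hyperparameters are illustrative.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: TF 1.x-style graph mode

BATCH_SIZE = 32
NUM_EPOCHS = 50

# Hypothetical synthetic data: y = 3x + 1 plus a little noise.
rng = np.random.RandomState(0)
train_x = rng.uniform(-1.0, 1.0, size=(256, 1)).astype(np.float32)
train_y = (3.0 * train_x + 1.0 + 0.01 * rng.randn(256, 1)).astype(np.float32)

x_pl = tf.placeholder(tf.float32, shape=(None, 1), name="x")
y_pl = tf.placeholder(tf.float32, shape=(None, 1), name="y")

w = tf.Variable(tf.zeros([1, 1]))
b = tf.Variable(tf.zeros([1]))
pred = tf.matmul(x_pl, w) + b
cost = tf.reduce_mean(tf.square(pred - y_pl))  # mean squared error
optimizer = tf.train.GradientDescentOptimizer(0.3).minimize(cost)


def next_batch(xs, ys, step, batch_size):
    """Hypothetical helper: return the step-th consecutive slice of the data."""
    start = step * batch_size
    return xs[start:start + batch_size], ys[start:start + batch_size]


num_batch = len(train_x) // BATCH_SIZE
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(NUM_EPOCHS):
        for step in range(num_batch):
            batch_x, batch_y = next_batch(train_x, train_y, step, BATCH_SIZE)
            # Each run() call feeds the current mini-batch into the placeholders.
            _, batch_cost = sess.run([optimizer, cost],
                                     feed_dict={x_pl: batch_x, y_pl: batch_y})
    learned_w, learned_b = sess.run([w, b])

print(learned_w, learned_b, batch_cost)
```

A dropout placeholder such as `keep_prob` would simply be one more entry in the same `feed_dict`, typically set below 1.0 during training and to 1.0 at evaluation time.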