首页 > 编程 > Python > 正文

详解tensorflow训练自己的数据集实现CNN图像分类

2020-02-22 23:09:48
字体:
来源:转载
供稿:网友

利用卷积神经网络训练图像数据分为以下几个步骤

1.读取图片文件
2.产生用于训练的批次
3.定义训练的模型(包括初始化参数,卷积、池化层等参数、网络)
4.训练

1 读取图片文件

def get_files(filename): class_train = [] label_train = [] for train_class in os.listdir(filename):  for pic in os.listdir(filename+train_class):   class_train.append(filename+train_class+'/'+pic)   label_train.append(train_class) temp = np.array([class_train,label_train]) temp = temp.transpose() #shuffle the samples np.random.shuffle(temp) #after transpose, images is in dimension 0 and label in dimension 1 image_list = list(temp[:,0]) label_list = list(temp[:,1]) label_list = [int(i) for i in label_list] #print(label_list) return image_list,label_list

这里文件名作为标签,即类别(其数据类型要确定,后面要转为tensor类型数据)。

然后将image和label转为list格式数据,因为后边用到的的一些tensorflow函数接收的是list格式数据。

2 产生用于训练的批次

def get_batches(image,label,resize_w,resize_h,batch_size,capacity): #convert the list of images and labels to tensor image = tf.cast(image,tf.string) label = tf.cast(label,tf.int64) queue = tf.train.slice_input_producer([image,label]) label = queue[1] image_c = tf.read_file(queue[0]) image = tf.image.decode_jpeg(image_c,channels = 3) #resize image = tf.image.resize_image_with_crop_or_pad(image,resize_w,resize_h) #(x - mean) / adjusted_stddev image = tf.image.per_image_standardization(image)  image_batch,label_batch = tf.train.batch([image,label],            batch_size = batch_size,            num_threads = 64,            capacity = capacity) images_batch = tf.cast(image_batch,tf.float32) labels_batch = tf.reshape(label_batch,[batch_size]) return images_batch,labels_batch

首先使用tf.cast转化为tensorflow数据格式,使用tf.train.slice_input_producer实现一个输入的队列。

label不需要处理,image存储的是路径,需要读取为图片,接下来的几步就是读取路径转为图片,用于训练。

CNN对图像大小是敏感的,第10行图片resize处理为大小一致,12行将其标准化,即减去所有图片的均值,方便训练。

接下来使用tf.train.batch函数产生训练的批次。

最后将产生的批次做数据类型的转换和shape的处理即可产生用于训练的批次。

3 定义训练的模型

(1)训练参数的定义及初始化

def init_weights(shape): return tf.Variable(tf.random_normal(shape,stddev = 0.01))#init weightsweights = { "w1":init_weights([3,3,3,16]), "w2":init_weights([3,3,16,128]), "w3":init_weights([3,3,128,256]), "w4":init_weights([4096,4096]), "wo":init_weights([4096,2]) }#init biasesbiases = { "b1":init_weights([16]), "b2":init_weights([128]), "b3":init_weights([256]), "b4":init_weights([4096]), "bo":init_weights([2]) }            
发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表