首页 > 学院 > 开发设计 > 正文

caffe分类实现

2019-11-06 08:47:10
字体:
来源:转载
供稿:网友

本文介绍采用NVIDIA digits训练的模型对图片数据集进行预测分类,使用caffe训练的模型同样有效,在此主要介绍使用digits训练的模型。

NVIDIA digits 官网caffe源码

一、环境配置

1、digits环境安装 具体不介绍了,官方有:digits安装 2、caffe环境(可选) 因为digits已将caffe封装,可直接安装NVIDIA的digits,当然有caffe环境的可不看 可按照官方教程安装编译caffe(官方安装说明)

二、python实现

# -*- coding:utf-8 -*-import numpy as npimport sys,os,caffeimport jsonimport shutil#统计分类后的图片数glassesNum = 0no_glassesNum = 0def model_classify(image_path):# 如果环境变量里没有配置,可将注释去掉# sys.path.insert(0, caffe_root + 'python')# os.chdir(caffe_root) net = caffe.Net(net_file,caffe_model,caffe.TEST) transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape}) transformer.set_transpose('data', (2,0,1)) transformer.set_mean('data', np.load(mean_file).mean(1).mean(1)) transformer.set_raw_scale('data', 255) transformer.set_channel_swap('data', (2,1,0)) # if using RGB instead of BGR img = caffe.io.load_image(image_path) net.blobs['data'].data[...] = transformer.PReprocess('data',img) out = net.forward() labels = np.loadtxt(labels_file, str, delimiter='/t') top1 = net.blobs['softmax'].data[0].flatten() top_k = top1.argsort()[-1:-6:-1] value = round(top1[top_k[0]]*100,2) key = str(labels[top_k[0]]) value1 = round(top1[top_k[1]]*100,2) key1 = str(labels[top_k[1]]) keytemp =work_root+key keytemp1 =work_root+key1 unknow =work_root+'unknow' data = {} data[key] = value data[key1] = value1 jsonstr = json.dumps(data) if not os.path.exists(keytemp): os.makedirs(keytemp) if not os.path.exists(keytemp1): os.makedirs(keytemp1) if not os.path.exists(unknow): os.makedirs(unknow) image_path_target='/'+key+str(value)+'_'+image_path.split('/')[-1] image_path_target1='/'+key1+str(value1)+'_'+image_path.split('/')[-1] if value >95.00: shutil.copy(image_path,keytemp+image_path_target) global no_glassesNum no_glassesNum +=1 print "================="+str(no_glassesNum+glassesNum)+'/'+filesum,jsonstr+"========================================" if value1 > 95.00: shutil.copy(image_path,keytemp1+image_path_target1) global glassesNum glassesNum +=1 print "================="+str(no_glassesNum+glassesNum)+'/'+filesum,jsonstr+"============================================" else: shutil.copy(image_path,unknow+image_path_target)# 转换bp格式图像均值文件为npy格式def BpToNpy(): #sys.path.insert(0, caffe_root + 'python') MEAN_PROTO_PATH = 'mean.binaryproto' # 待转换的pb格式图像均值文件路径 MEAN_NPY_PATH = 'mean.npy' # 转换后的numpy格式图像均值文件路径 blob = caffe.proto.caffe_pb2.BlobProto() # 创建protobuf blob data = open(MEAN_PROTO_PATH, 'rb' ).read() # 读入mean.binaryproto文件内容 blob.ParseFromString(data) # 解析文件内容到blob array = np.array(caffe.io.blobproto_to_array(blob))# 将blob中的均值转换成numpy格式,array的shape (mean_number,channel, hight, width) mean_npy = array[0] # 一个array中可以有多组均值存在,故需要通过下标选择其中一组均值 np.save(MEAN_NPY_PATH ,mean_npy)# 获取模型文件def getFileName(path): global net_file,caffe_model,labels_file,mean_file f_list = os.listdir(path) # print f_list for filename in f_list: # os.path.splitext():分离文件名与扩展名 if os.path.splitext(filename)[1] == '.prototxt': net_file = work_root+filename if os.path.splitext(filename)[1] == '.caffemodel': caffe_model = work_root+filename if os.path.splitext(filename)[1] == '.txt': labels_file = work_root+filename if os.path.splitext(filename)[1] == '.npy': mean_file = work_root+filenameif __name__ == '__main__': work_root = os.getcwd()+'/' f_list = os.listdir(work_root) for filename in f_list: if not filename.endswith('npy'): BpToNpy() getFileName(work_root) #如果环境中没有设置caffe工作路径,设置路径 #caffe_root = '/dataTwo/caffe-ssd' image_path = raw_input("Input your image path: ") #判断路径是否存在 if os.path.exists(image_path): #判断是否为目录路径 if os.path.isdir(image_path): filesum =str(len(sum([i[2] for i in os.walk(image_path)],[]))) filenames = os.listdir(image_path) for fn in filenames: fullfilename = os.path.join(image_path,fn) model_classify(fullfilename) else: #判断是否为列表文件 if os.path.exists(image_path): imglist = open(image_path) line = imglist.readline() while line: print line, line = imglist.readline().replace('/n','').replace('/r/n','') model_classify(line) imglist.close() model_classify(image_path) else: iamgefile = urllib.urlopen(image_path) status=iamgefile.code #判路路径是否为网络路径 if(status==200): image_data = iamgefile.read() #获取图片名 image_name = os.path.basename(image_path) #创建新的图片地址 new_imagepath = filepath+"/"+image_name #保存图片 with open(new_imagepath, 'wb') as code: code.write(image_data) model_classify(new_imagepath) else: print "not found the folder!"

这里就不过多介绍了,可参考我的GitHub:wulivicte/caffe_classify


发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表