原理:数据转化成example协议内存块,再将example协议内存块里的数据序列化成字符串,再通过TFRecordWriter把字符串写入到TFRecords文件中。
import os
import tensorflow as tf
from PIL import Image #注意Image,后面会用到
import matplotlib.pyplot as plt
import numpy as np
cwd='/Users/yyzanll/Desktop/pics/dogs/'
classes={'husky','chihuahua'} #人为 设定 2 类
writer= tf.python_io.TFRecordWriter("dog_train.tfrecords") #要生成的文件
for index,name in enumerate(classes):
class_path=cwd+name+'/'
#os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表
for img_name in os.listdir(class_path):
img_path=class_path+img_name #每一个图片的地址
img=Image.open(img_path)
img= img.resize((240,240))
img_raw=img.tobytes()#将图片转化为二进制格式
example = tf.train.Example(features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
})) #example对象对label和image数据进行封装
writer.write(example.SerializeToString()) #序列化为字符串后写到TFRecords文件中
writer.close()
#运行到以上内容截止,生成了dog_train.tfrecords文件夹
#使用tf.train.string_input_producer函数把我们需要的全部文件打包为一个tf内部的queue类型,之后tf开文件就从这个queue中取目录了,
#要注意一点的是这个函数的shuffle参数默认是True
filename_queue = tf.train.string_input_producer(["dog_train.tfrecords"]) #读入流中
reader = tf.TFRecordReader()
#tf.TFRecordReader.read(queue, name=None)
#return:
#A tuple of Tensors (key, value). key: A string scalar Tensor.
#value: A string scalar Tensor. 返回键值对,其中值表示读取的 件
_, serialized_example = reader.read(filename_queue) #返回文件名和文件
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
}) #取出包含image和label的feature对象
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [240, 240, 3])
label = tf.cast(features['label'], tf.int32)
with tf.Session() as sess: #开始一个会话
init_op = tf.initialize_all_variables()
sess.run(init_op)
coord=tf.train.Coordinator()
threads= tf.train.start_queue_runners(coord=coord)
for i in range(20):
example, l = sess.run([image,label])#在会话中取出image和label
img=Image.fromarray(example, 'RGB')#这里Image是之前提到的
img.save(cwd+str(i)+'_''Label_'+str(l)+'.jpg')#存下图片
print(example, l)
coord.request_stop()
coord.join(threads)
#后半段代码的目的是把dog_train.tfrecords里的数据读出来转成我们想要的图片的样子。
效果图:

TFRecords文件格式在图像识别中有很好的使用,其可以将二进制数据和标签数据(训练的类别标签)数据存储在同一个文件中,它可以在模型进行训练之前通过预处理步骤将图像转换为TFRecords格式,此格式最大的优点实践每幅输入图像和与之关联的标签放在同一个文件中.TFRecords文件是一种二进制文件,其不对数据进行压缩,所以可以被快速加载到内存中.格式不支持随机访问,因此它适合于大量的数据流,但不适用于快速分片或其他非连续存取。
tfrecord生成总结:
1、create_data.py必须在tf环境里,不用在slim下,也不用非在search目录下,data_dir不需要在tf环境里,任何位置都可以。
2、遇到引用不到的问题,使用PYTHONPATH添加环境变量,让里面的文件能够顺利找到.py的路径
export PYTHONPATH=$PYTHONPATH:/Users/yyzanll/anaconda/envs/tensorflow/lib/python3.5/site-packages/tensorflow/models/research/object_detection/utils
但是写入PYTHONPATH只在当前命令行框里管用,要想长久管用,就写入~/.bash_profile里。
3、生成tfrecord的时候如果有create_data.py文件有很多同级文件,最好一起拷贝到一个文件夹下使用,因为可能互相引用
网友评论