Tensorflow中TFRecord生成與讀取的實(shí)現(xiàn)
一、為什么使用TFRecord?
正常情況下我們訓(xùn)練文件夾經(jīng)常會(huì)生成 train, test 或者val文件夾,這些文件夾內(nèi)部往往會(huì)存著成千上萬(wàn)的圖片或文本等文件,這些文件被散列存著,這樣不僅占用磁盤(pán)空間,并且再被一個(gè)個(gè)讀取的時(shí)候會(huì)非常慢,繁瑣。占用大量?jī)?nèi)存空間(有的大型數(shù)據(jù)不足以一次性加載)。此時(shí)我們TFRecord格式的文件存儲(chǔ)形式會(huì)很合理的幫我們存儲(chǔ)數(shù)據(jù)。TFRecord內(nèi)部使用了“Protocol Buffer”二進(jìn)制數(shù)據(jù)編碼方案,它只占用一個(gè)內(nèi)存塊,只需要一次性加載一個(gè)二進(jìn)制文件的方式即可,簡(jiǎn)單,快速,尤其對(duì)大型訓(xùn)練數(shù)據(jù)很友好。而且當(dāng)我們的訓(xùn)練數(shù)據(jù)量比較大的時(shí)候,可以將數(shù)據(jù)分成多個(gè)TFRecord文件,來(lái)提高處理效率。
二、 生成TFRecord簡(jiǎn)單實(shí)現(xiàn)方式
我們可以分成兩個(gè)部分來(lái)介紹如何生成TFRecord,分別是TFRecord生成器以及樣本Example模塊。
- TFRecord生成器
writer = tf.python_io.TFRecordWriter(record_path) writer.write(tf_example.SerializeToString()) writer.close()
這里面writer
就是我們TFrecord生成器。接著我們就可以通過(guò)writer.write(tf_example.SerializeToString())
來(lái)生成我們所要的tfrecord文件了。這里需要注意的是我們TFRecord生成器在寫(xiě)完文件后需要關(guān)閉writer.close()
。這里tf_example.SerializeToString()
是將Example中的map壓縮為二進(jìn)制文件,更好的節(jié)省空間。那么tf_example是如何生成的呢?那就是下面所要介紹的樣本Example模塊了。
- Example模塊
首先們來(lái)看一下Example協(xié)議塊是什么樣子的。
message Example { Features features = 1; }; message Features { map<string, Feature> feature = 1; }; message Feature { oneof kind { BytesList bytes_list = 1; FloatList float_list = 2; Int64List int64_list = 3; } };
我們可以看出上面的tf_example可以寫(xiě)入的數(shù)據(jù)形式有三種,分別是BytesList, FloatList以及Int64List的類(lèi)型。那我們?nèi)绾螌?xiě)一個(gè)tf_example呢?下面有一個(gè)簡(jiǎn)單的例子。
def int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) tf_example = tf.train.Example( features=tf.train.Features(feature={ 'image/encoded': bytes_feature(encoded_jpg), 'image/format': bytes_feature('jpg'.encode()), 'image/class/label': int64_feature(label), 'image/height': int64_feature(height), 'image/width': int64_feature(width)}))
下面我們來(lái)好好從外部往內(nèi)部分解來(lái)解釋一下上面的內(nèi)容。
(1)tf.train.Example(features = None)
這里的features是tf.train.Features類(lèi)型的特征實(shí)例。
(2)tf.train.Features(feature = None)
這里的feature是以字典的形式存在,*key:要保存數(shù)據(jù)的名字 value:要保存的數(shù)據(jù),但是格式必須符合tf.train.Feature實(shí)例要求。
三、 生成TFRecord文件完整代碼實(shí)例
首先我們需要提供數(shù)據(jù)集
圖片文件夾
通過(guò)圖片文件夾我們可以知道這里面總共有七種分類(lèi)圖片,類(lèi)別的名稱(chēng)就是每個(gè)文件夾名稱(chēng),每個(gè)類(lèi)別文件夾存儲(chǔ)各自的對(duì)應(yīng)類(lèi)別的很多圖片。下面我們通過(guò)一下代碼(generate_annotation_json.py
和generate_tfrecord.py
)生成train.record。
- generate_annotation_json.py
# -*- coding: utf-8 -*- # @Time : 2018/11/22 22:12 # @Author : MaochengHu # @Email : wojiaohumaocheng@gmail.com # @File : generate_annotation_json.py # @Software: PyCharm import os import json def get_annotation_dict(input_folder_path, word2number_dict): label_dict = {} father_file_list = os.listdir(input_folder_path) for father_file in father_file_list: full_father_file = os.path.join(input_folder_path, father_file) son_file_list = os.listdir(full_father_file) for image_name in son_file_list: label_dict[os.path.join(full_father_file, image_name)] = word2number_dict[father_file] return label_dict def save_json(label_dict, json_path): with open(json_path, 'w') as json_path: json.dump(label_dict, json_path) print("label json file has been generated successfully!")
- generate_tfrecord.py
# -*- coding: utf-8 -*- # @Time : 2018/11/23 0:09 # @Author : MaochengHu # @Email : wojiaohumaocheng@gmail.com # @File : generate_tfrecord.py # @Software: PyCharm import os import tensorflow as tf import io from PIL import Image from generate_annotation_json import get_annotation_dict flags = tf.app.flags flags.DEFINE_string('images_dir', '/data2/raycloud/jingxiong_datasets/six_classes/images', 'Path to image(directory)') flags.DEFINE_string('annotation_path', '/data1/humaoc_file/classify/data/annotations/annotations.json', 'Path to annotation') flags.DEFINE_string('record_path', '/data1/humaoc_file/classify/data/train_tfrecord/train.record', 'Path to TFRecord') FLAGS = flags.FLAGS def int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) def process_image_channels(image): process_flag = False # process the 4 channels .png if image.mode == 'RGBA': r, g, b, a = image.split() image = Image.merge("RGB", (r,g,b)) process_flag = True # process the channel image elif image.mode != 'RGB': image = image.convert("RGB") process_flag = True return image, process_flag def process_image_reshape(image, resize): width, height = image.size if resize is not None: if width > height: width = int(width * resize / height) height = resize else: width = resize height = int(height * resize / width) image = image.resize((width, height), Image.ANTIALIAS) return image def create_tf_example(image_path, label, resize=None): with tf.gfile.GFile(image_path, 'rb') as fid: encode_jpg = fid.read() encode_jpg_io = io.BytesIO(encode_jpg) image = Image.open(encode_jpg_io) # process png pic with four channels image, process_flag = process_image_channels(image) # reshape image image = process_image_reshape(image, resize) if process_flag == True or resize is not None: bytes_io = io.BytesIO() image.save(bytes_io, format='JPEG') encoded_jpg = bytes_io.getvalue() width, height = image.size tf_example = tf.train.Example( features=tf.train.Features( feature={ 'image/encoded': bytes_feature(encode_jpg), 'image/format': bytes_feature(b'jpg'), 'image/class/label': int64_feature(label), 'image/height': int64_feature(height), 'image/width': int64_feature(width) } )) return tf_example def generate_tfrecord(annotation_dict, record_path, resize=None): num_tf_example = 0 writer = tf.python_io.TFRecordWriter(record_path) for image_path, label in annotation_dict.items(): if not tf.gfile.GFile(image_path): print("{} does not exist".format(image_path)) tf_example = create_tf_example(image_path, label, resize) writer.write(tf_example.SerializeToString()) num_tf_example += 1 if num_tf_example % 100 == 0: print("Create %d TF_Example" % num_tf_example) writer.close() print("{} tf_examples has been created successfully, which are saved in {}".format(num_tf_example, record_path)) def main(_): word2number_dict = { "combinations": 0, "details": 1, "sizes": 2, "tags": 3, "models": 4, "tileds": 5, "hangs": 6 } images_dir = FLAGS.images_dir #annotation_path = FLAGS.annotation_path record_path = FLAGS.record_path annotation_dict = get_annotation_dict(images_dir, word2number_dict) generate_tfrecord(annotation_dict, record_path) if __name__ == '__main__': tf.app.run()
* 這里需要說(shuō)明的是generate_annotation_json.py是為了得到圖片標(biāo)注的label_dict。通過(guò)這個(gè)代碼塊可以獲得我們需要的圖片標(biāo)注字典,key是圖片具體地址, value是圖片的類(lèi)別,具體實(shí)例如下:
{ "/images/hangs/862e67a8-5bd9-41f1-8c6d-876a3cb270df.JPG": 6, "/images/tags/adc264af-a76b-4477-9573-ac6c435decab.JPG": 3, "/images/tags/fd231f5a-b42c-43ba-9e9d-4abfbaf38853.JPG": 3, "/images/hangs/2e47d877-1954-40d6-bfa2-1b8e3952ebf9.jpg": 6, "/images/tileds/a07beddc-4b39-4865-8ee2-017e6c257e92.png": 5, "/images/models/642015c8-f29d-4930-b1a9-564f858c40e5.png": 4 }
- 如何運(yùn)行代碼
(1)首先我們的文件夾構(gòu)成形式是如下結(jié)構(gòu),其中images_root
是圖片根文件夾,combinations, details, sizes, tags, models, tileds, hangs
分別存放不同類(lèi)別的圖片文件夾。
-<images_root> -<combinations> -圖片.jpg -<details> -圖片.jpg -<sizes> -圖片.jpg -<tags> -圖片.jpg -<models> -圖片.jpg -<tileds> -圖片.jpg -<hangs> -圖片.jpg
(2)建立文件夾TFRecord
,并將generate_tfrecord.py
和generate_annotation_json.py
這兩個(gè)python文件放入文件夾內(nèi),需要注意的是我們需要將 generate_tfrecord.py
文件中字典word2number_dict換成自己的字典(即key是放不同類(lèi)別的圖片文件夾名稱(chēng),value是對(duì)應(yīng)的分類(lèi)number)
word2number_dict = { "combinations": 0, "details": 1, "sizes": 2, "tags": 3, "models": 4, "tileds": 5, "hangs": 6 }
(3)直接執(zhí)行代碼 python3/python2 ./TFRecord/generate_tfrecord.py --image_dir="images_root地址" --record_path="你想要保存record地址(.record文件全路徑)"
即可。如下是一個(gè)實(shí)例:
python3 generate_tfrecord.py --image_dir /images/ --record_path /classify/data/train_tfrecord/train.record
TFRecord讀取
上面我們介紹了如何生成TFRecord,現(xiàn)在我們嘗試如何通過(guò)使用隊(duì)列讀取讀取我們的TFRecord。
讀取TFRecord可以通過(guò)tensorflow兩個(gè)個(gè)重要的函數(shù)實(shí)現(xiàn),分別是tf.train.string_input_producer
和 tf.TFRecordReader
的tf.parse_single_example
解析器。如下圖
AnimatedFileQueues.gif
四、 讀取TFRecord的簡(jiǎn)單實(shí)現(xiàn)方式
解析TFRecord有兩種解析方式一種是利用tf.parse_single_example
, 另一種是通過(guò)tf.contrib.slim
(* 推薦使用)。
第一種方式(tf.parse_single_example)解析步驟如下:
(1).第一步,我們將train.record
文件讀入到隊(duì)列中,如下所示:filename_queue = tf.train.string_input_producer([tfrecords_filename])
(2) 第二步,我們需要通過(guò)TFRecord將生成的隊(duì)列讀入
reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) #返回文件名和文件
(3)第三步, 通過(guò)解析器tf.parse_single_example
將我們的example解析出來(lái)。
第二種方式(tf.contrib.slim)解析步驟如下:
(1) 第一步, 我們要設(shè)置decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
, 其中key_to_features
這個(gè)字典需要和TFrecord文件中定義的字典項(xiàng)匹配,items_to_handlers
中的關(guān)鍵字可以是任意值,但是它的handler的初始化參數(shù)必須要來(lái)自于keys_to_features中的關(guān)鍵字。
(2) 第二步, 我們要設(shè)定dataset = slim.dataset.Dataset(params)
, 其中params包括:
a. data_source
: 為tfrecord文件地址
b. reader
: 一般設(shè)置為tf.TFRecordReader閱讀器
c. decoder
: 為第一步設(shè)置的decoder
d. num_samples
: 樣本數(shù)量
e. items_to_description
: 對(duì)樣本及標(biāo)簽的描述
f. num_classes
: 分類(lèi)的數(shù)量
(3) 第三步, 我們?cè)O(shè)置provider = slim.dataset_data_provider.DatasetDataProvider(params)
, 其中params包括 :
a. dataset
: 第二步驟我們生成的數(shù)據(jù)集
b. num_reader
: 并行閱讀器數(shù)量
c. shuffle
: 是否打亂
d. num_epochs
:每個(gè)數(shù)據(jù)源被讀取的次數(shù),如果設(shè)為None數(shù)據(jù)將會(huì)被無(wú)限循環(huán)的讀取
e. common_queue_capacity
:讀取數(shù)據(jù)隊(duì)列的容量,默認(rèn)為256
f. scope
:范圍
g. common_queue_min
:讀取數(shù)據(jù)隊(duì)列的最小容量。
(4) 第四步, 我們可以通過(guò)provider.get
得到我們需要的數(shù)據(jù)了。
3. 對(duì)不同圖片大小的TFRecord讀取并resize成相同大小reshape_same_size
函數(shù)來(lái)對(duì)圖片進(jìn)行resize,這樣我們可以對(duì)我們的圖片進(jìn)行batch操作了,因?yàn)橛械纳窠?jīng)網(wǎng)絡(luò)訓(xùn)練需要一個(gè)batch一個(gè)batch操作,不同大小的圖片在組成一個(gè)batch的時(shí)候會(huì)報(bào)錯(cuò),因此我們我通過(guò)后期處理可以更好的對(duì)圖片進(jìn)行batch操作。
或者直接通過(guò)resized_image = tf.squeeze(tf.image.resize_bilinear([image], size=[FLAG.resize_height, FLAG.resize_width]))
即可。
五、tf.contrib.slim模塊讀取TFrecord文件完整代碼實(shí)例
# -*- coding: utf-8 -*- # @Time : 2018/12/1 11:06 # @Author : MaochengHu # @Email : wojiaohumaocheng@gmail.com # @File : read_tfrecord.py # @Software: PyCharm import os import tensorflow as tf flags = tf.app.flags flags.DEFINE_string('tfrecord_path', '/data1/humaoc_file/classify/data/train_tfrecord/train.record', 'path to tfrecord file') flags.DEFINE_integer('resize_height', 800, 'resize height of image') flags.DEFINE_integer('resize_width', 800, 'resize width of image') FLAG = flags.FLAGS slim = tf.contrib.slim def print_data(image, resized_image, label, height, width): with tf.Session() as sess: init_op = tf.global_variables_initializer() sess.run(init_op) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) for i in range(10): print("______________________image({})___________________".format(i)) print_image, print_resized_image, print_label, print_height, print_width = sess.run([image, resized_image, label, height, width]) print("resized_image shape is: ", print_resized_image.shape) print("image shape is: ", print_image.shape) print("image label is: ", print_label) print("image height is: ", print_height) print("image width is: ", print_width) coord.request_stop() coord.join(threads) def reshape_same_size(image, output_height, output_width): """Resize images by fixed sides. Args: image: A 3-D image `Tensor`. output_height: The height of the image after preprocessing. output_width: The width of the image after preprocessing. Returns: resized_image: A 3-D tensor containing the resized image. """ output_height = tf.convert_to_tensor(output_height, dtype=tf.int32) output_width = tf.convert_to_tensor(output_width, dtype=tf.int32) image = tf.expand_dims(image, 0) resized_image = tf.image.resize_nearest_neighbor( image, [output_height, output_width], align_corners=False) resized_image = tf.squeeze(resized_image) return resized_image def read_tfrecord(tfrecord_path, num_samples=14635, num_classes=7, resize_height=800, resize_width=800): keys_to_features = { 'image/encoded': tf.FixedLenFeature([], default_value='', dtype=tf.string,), 'image/format': tf.FixedLenFeature([], default_value='jpeg', dtype=tf.string), 'image/class/label': tf.FixedLenFeature([], tf.int64, default_value=0), 'image/height': tf.FixedLenFeature([], tf.int64, default_value=0), 'image/width': tf.FixedLenFeature([], tf.int64, default_value=0) } items_to_handlers = { 'image': slim.tfexample_decoder.Image(image_key='image/encoded', format_key='image/format', channels=3), 'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[]), 'height': slim.tfexample_decoder.Tensor('image/height', shape=[]), 'width': slim.tfexample_decoder.Tensor('image/width', shape=[]) } decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers) labels_to_names = None items_to_descriptions = { 'image': 'An image with shape image_shape.', 'label': 'A single integer between 0 and 9.'} dataset = slim.dataset.Dataset( data_sources=tfrecord_path, reader=tf.TFRecordReader, decoder=decoder, num_samples=num_samples, items_to_descriptions=None, num_classes=num_classes, ) provider = slim.dataset_data_provider.DatasetDataProvider(dataset=dataset, num_readers=3, shuffle=True, common_queue_capacity=256, common_queue_min=128, seed=None) image, label, height, width = provider.get(['image', 'label', 'height', 'width']) resized_image = tf.squeeze(tf.image.resize_bilinear([image], size=[resize_height, resize_width])) return resized_image, label, image, height, width def main(): resized_image, label, image, height, width = read_tfrecord(tfrecord_path=FLAG.tfrecord_path, resize_height=FLAG.resize_height, resize_width=FLAG.resize_width) #resized_image = reshape_same_size(image, FLAG.resize_height, FLAG.resize_width) #resized_image = tf.squeeze(tf.image.resize_bilinear([image], size=[FLAG.resize_height, FLAG.resize_width])) print_data(image, resized_image, label, height, width) if __name__ == '__main__': main()
代碼運(yùn)行方式
python3 read_tfrecord.py --tfrecord_path /data1/humaoc_file/classify/data/train_tfrecord/train.record --resize_height 800 --resize_width 800
最終我們可以看到我們讀取文件的部分內(nèi)容:
______________________image(0)___________________ resized_image shape is: (800, 800, 3) image shape is: (2000, 1333, 3) image label is: 5 image height is: 2000 image width is: 1333 ______________________image(1)___________________ resized_image shape is: (800, 800, 3) image shape is: (667, 1000, 3) image label is: 0 image height is: 667 image width is: 1000 ______________________image(2)___________________ resized_image shape is: (800, 800, 3) image shape is: (667, 1000, 3) image label is: 3 image height is: 667 image width is: 1000 ______________________image(3)___________________ resized_image shape is: (800, 800, 3) image shape is: (800, 800, 3) image label is: 5 image height is: 800 image width is: 800 ______________________image(4)___________________ resized_image shape is: (800, 800, 3) image shape is: (1424, 750, 3) image label is: 0 image height is: 1424 image width is: 750 ______________________image(5)___________________ resized_image shape is: (800, 800, 3) image shape is: (1196, 1000, 3) image label is: 6 image height is: 1196 image width is: 1000 ______________________image(6)___________________ resized_image shape is: (800, 800, 3) image shape is: (667, 1000, 3) image label is: 5 image height is: 667 image width is: 1000
參考:
[1] TensorFlow 自定義生成 .record 文件
[2] TensorFlow基礎(chǔ)5:TFRecords文件的存儲(chǔ)與讀取講解及代碼實(shí)現(xiàn)
[3] Slim讀取TFrecord文件
[4] Tensorflow針對(duì)不定尺寸的圖片讀寫(xiě)tfrecord文件總結(jié)
到此這篇關(guān)于Tensorflow中TFRecord生成與讀取的實(shí)現(xiàn)的文章就介紹到這了,更多相關(guān)Tensorflow TFRecord生成與讀取內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python實(shí)現(xiàn)批量更換指定目錄下文件擴(kuò)展名的方法
這篇文章主要介紹了Python實(shí)現(xiàn)批量更換指定目錄下文件擴(kuò)展名的方法,結(jié)合完整實(shí)例分析了Python批量修改文件擴(kuò)展名的技巧,并對(duì)比分析了shell命令及scandir的兼容性代碼,需要的朋友可以參考下2016-09-09將數(shù)據(jù)集制作成VOC數(shù)據(jù)集格式的實(shí)例
今天小編就為大家分享一篇將數(shù)據(jù)集制作成VOC數(shù)據(jù)集格式的實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-02-02Python微服務(wù)開(kāi)發(fā)之使用FastAPI構(gòu)建高效API
微服務(wù)架構(gòu)在現(xiàn)代軟件開(kāi)發(fā)中日益普及,它將復(fù)雜的應(yīng)用程序拆分成多個(gè)可獨(dú)立部署的小型服務(wù)。本文將介紹如何使用 Python 的 FastAPI 庫(kù)快速構(gòu)建和部署微服務(wù),感興趣的可以了解一下2023-05-05Blueprint實(shí)現(xiàn)路由分組及Flask中session的使用詳解
這篇文章主要為大家介紹了Blueprint實(shí)現(xiàn)路由分組及Flask中session的使用詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-11-11python自定義時(shí)鐘類(lèi)、定時(shí)任務(wù)類(lèi)
這篇文章主要為大家詳細(xì)介紹了Python自定義時(shí)鐘類(lèi)、定時(shí)任務(wù)類(lèi),文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2019-07-07python numpy中multiply與*及matul 的區(qū)別說(shuō)明
這篇文章主要介紹了python numpy中multiply與*及matul 的區(qū)別說(shuō)明,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2021-05-05python+excel接口自動(dòng)化獲取token并作為請(qǐng)求參數(shù)進(jìn)行傳參操作
這篇文章主要介紹了python+excel接口自動(dòng)化獲取token并作為請(qǐng)求參數(shù)進(jìn)行傳參操作,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-11-11