tensorflow入門:TFRecordDataset變長(zhǎng)數(shù)據(jù)的batch讀取詳解
在上一篇文章tensorflow入門:tfrecord 和tf.data.TFRecordDataset的使用里,講到了使用如何使用tf.data.TFRecordDatase來(lái)對(duì)tfrecord文件進(jìn)行batch讀取,即使用dataset的batch方法進(jìn)行;但如果每條數(shù)據(jù)的長(zhǎng)度不一樣(常見(jiàn)于語(yǔ)音、視頻、NLP等領(lǐng)域),則不能直接用batch方法獲取數(shù)據(jù),這時(shí)則有兩個(gè)解決辦法:
1.在把數(shù)據(jù)寫(xiě)入tfrecord時(shí),先把數(shù)據(jù)pad到統(tǒng)一的長(zhǎng)度再寫(xiě)入tfrecord;這個(gè)方法的問(wèn)題在于:若是有大量數(shù)據(jù)的長(zhǎng)度都遠(yuǎn)遠(yuǎn)小于最大長(zhǎng)度,則會(huì)造成存儲(chǔ)空間的大量浪費(fèi)。
2.使用dataset中的padded_batch方法來(lái)進(jìn)行,參數(shù)padded_shapes #指明每條記錄中各成員要pad成的形狀,成員若是scalar,則用[],若是list,則用[mx_length],若是array,則用[d1,...,dn],假如各成員的順序是scalar數(shù)據(jù)、list數(shù)據(jù)、array數(shù)據(jù),則padded_shapes=([], [mx_length], [d1,...,dn]);該方法的函數(shù)說(shuō)明如下:
padded_batch( batch_size, padded_shapes, padding_values=None #默認(rèn)使用各類型數(shù)據(jù)的默認(rèn)值,一般使用時(shí)可忽略該項(xiàng) )
使用mnist數(shù)據(jù)來(lái)舉例說(shuō)明,首先在把mnist寫(xiě)入tfrecord之前,把mnist數(shù)據(jù)進(jìn)行更改,以使得每個(gè)mnist圖像的大小不等,如下:
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
mnist = read_data_sets("MNIST_data/", one_hot=True)
def get_tfrecords_example(feature, label):
tfrecords_features = {}
feat_shape = feature.shape
tfrecords_features['feature'] = tf.train.Feature(float_list=tf.train.FloatList(value=feature))
tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))
def make_tfrecord(data, outf_nm='mnist-train'):
feats, labels = data
outf_nm += '.tfrecord'
tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
ndatas = len(labels)
print(feats[0].dtype, feats[0].shape, ndatas)
assert len(labels[0]) > 1
for inx in range(ndatas):
ed = random.randint(0,3) #隨機(jī)丟掉幾個(gè)數(shù)據(jù)點(diǎn),以使長(zhǎng)度不等
exmp = get_tfrecords_example(feats[inx][:-ed], labels[inx])
exmp_serial = exmp.SerializeToString()
tfrecord_wrt.write(exmp_serial)
tfrecord_wrt.close()
import random
nDatas = len(mnist.train.labels)
inx_lst = range(nDatas)
random.shuffle(inx_lst)
random.shuffle(inx_lst)
ntrains = int(0.85*nDatas)
# make training set
data = ([mnist.train.images[i] for i in inx_lst[:ntrains]], \
[mnist.train.labels[i] for i in inx_lst[:ntrains]])
make_tfrecord(data, outf_nm='mnist-train')
# make validation set
data = ([mnist.train.images[i] for i in inx_lst[ntrains:]], \
[mnist.train.labels[i] for i in inx_lst[ntrains:]])
make_tfrecord(data, outf_nm='mnist-val')
# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')
用dataset加載批量數(shù)據(jù),在解析數(shù)據(jù)時(shí)用到tf.VarLenFeature(tf.datatype),而非tf.FixedLenFeature([], tf.datatype)},且要配合tf.sparse_tensor_to_dense函數(shù)使用,如下:
import tensorflow as tf
train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]
def parse_exmp(serial_exmp):
feats = tf.parse_single_example(serial_exmp, features={'feature':tf.VarLenFeature(tf.float32),\
'label':tf.FixedLenFeature([10],tf.float32), 'shape':tf.FixedLenFeature([], tf.int64)})
image = tf.sparse_tensor_to_dense(feats['feature']) #使用VarLenFeature讀入的是一個(gè)sparse_tensor,用該函數(shù)進(jìn)行轉(zhuǎn)換
label = tf.reshape(feats['label'],[2,5]) #把label變成[2,5],以說(shuō)明array數(shù)據(jù)如何padding
shape = tf.cast(feats['shape'], tf.int32)
return image, label, shape
def get_dataset(fname):
dataset = tf.data.TFRecordDataset(fname)
return dataset.map(parse_exmp) # use padded_batch method if padding needed
epochs = 16
batch_size = 50
padded_shapes = ([784],[3,5],[]) #把image pad至784,把label pad至[3,5],shape是一個(gè)scalar,不輸入數(shù)字
# training dataset
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).padded_batch(batch_size, padded_shapes=padded_shapes)
以上這篇tensorflow入門:TFRecordDataset變長(zhǎng)數(shù)據(jù)的batch讀取詳解就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Windows下安裝python2和python3多版本教程
這篇文章主要介紹下Windows(我用的Win10)環(huán)境下的python2.x 和 python3.x 的安裝,以及python2.x 與 python3.x 共存時(shí)的配置問(wèn)題。2017-03-03
Python Pytorch深度學(xué)習(xí)之圖像分類器
今天小編就為大家分享一篇關(guān)于Pytorch圖像分類器的文章,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2021-10-10
Python使用sftp實(shí)現(xiàn)傳文件夾和文件
這篇文章主要為大家詳細(xì)介紹了Python使用sftp實(shí)現(xiàn)傳文件夾和文件,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2021-04-04
Python中subprocess模塊用法實(shí)例詳解
這篇文章主要介紹了Python中subprocess模塊用法,實(shí)例分析了subprocess模塊的相關(guān)使用技巧,需要的朋友可以參考下2015-05-05
關(guān)于Python字典(Dictionary)操作詳解
這篇文章主要介紹了關(guān)于Python字典(Dictionary)操作詳解,Python字典是另一種可變?nèi)萜髂P停铱纱鎯?chǔ)任意類型對(duì)象,如字符串、數(shù)字、元組等其他容器模型,需要的朋友可以參考下2023-04-04
python利用openpyxl拆分多個(gè)工作表的工作簿的方法
這篇文章主要介紹了python利用openpyxl拆分多個(gè)工作表的工作簿的方法,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2019-09-09
詳解Open Folder as PyCharm Project怎么添加的方法
這篇文章主要介紹了詳解Open Folder as PyCharm Project怎么添加的方法,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-12-12
Python字典中的鍵映射多個(gè)值的方法(列表或者集合)
今天小編就為大家分享一篇Python字典中的鍵映射多個(gè)值的方法(列表或者集合),具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-10-10
python opencv實(shí)現(xiàn)切變換 不裁減圖片
這篇文章主要為大家詳細(xì)介紹了python opencv實(shí)現(xiàn)切變換,不裁減圖片,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-07-07

