欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

淺談tensorflow中dataset.shuffle和dataset.batch dataset.repeat注意點

 更新時間:2020年06月08日 12:01:30   作者:青盞  
這篇文章主要介紹了淺談tensorflow中dataset.shuffle和dataset.batch dataset.repeat注意點,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧

batch很好理解,就是batch size。注意在一個epoch中最后一個batch大小可能小于等于batch size

dataset.repeat就是俗稱epoch,但在tf中與dataset.shuffle的使用順序可能會導(dǎo)致個epoch的混合

dataset.shuffle就是說維持一個buffer size 大小的 shuffle buffer,圖中所需的每個樣本從shuffle buffer中獲取,取得一個樣本后,就從源數(shù)據(jù)集中加入一個樣本到shuffle buffer中。

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf
np.random.seed(0)
x = np.random.sample((11,2))
# make a dataset from a numpy array
print(x)
print()
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(3)
dataset = dataset.batch(4)
dataset = dataset.repeat(2)

# create the iterator
iter = dataset.make_one_shot_iterator()
el = iter.get_next()

with tf.Session() as sess:
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
#源數(shù)據(jù)集
[[ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]
 [ 0.43758721 0.891773 ]
 [ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]
 [ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]

# 通過shuffle batch后取得的樣本
[[ 0.4236548  0.64589411]
 [ 0.60276338 0.54488318]
 [ 0.43758721 0.891773 ]
 [ 0.5488135  0.71518937]]
[[ 0.96366276 0.38344152]
 [ 0.56804456 0.92559664]
 [ 0.0202184  0.83261985]
 [ 0.79172504 0.52889492]]
[[ 0.07103606 0.0871293 ]
 [ 0.97861834 0.79915856]
 [ 0.77815675 0.87001215]] #最后一個batch樣本個數(shù)為3
[[ 0.60276338 0.54488318]
 [ 0.5488135  0.71518937]
 [ 0.43758721 0.891773 ]
 [ 0.79172504 0.52889492]]
[[ 0.4236548  0.64589411]
 [ 0.56804456 0.92559664]
 [ 0.0202184  0.83261985]
 [ 0.07103606 0.0871293 ]]
[[ 0.77815675 0.87001215]
 [ 0.96366276 0.38344152]
 [ 0.97861834 0.79915856]] #最后一個batch樣本個數(shù)為3

1、按照shuffle中設(shè)置的buffer size,首先從源數(shù)據(jù)集取得三個樣本:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.4236548 0.64589411]
2、從buffer中取一個樣本到batch中得:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
batch:
[ 0.4236548 0.64589411]
3、shuffle buffer不足三個樣本,從源數(shù)據(jù)集提取一個樣本:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.43758721 0.891773 ]
4、從buffer中取一個樣本到batch中得:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.43758721 0.891773 ]
batch:
[ 0.4236548 0.64589411]
[ 0.60276338 0.54488318]
5、如此反復(fù)。這就意味中如果shuffle 的buffer size=1,數(shù)據(jù)集不打亂。如果shuffle 的buffer size=數(shù)據(jù)集樣本數(shù)量,隨機打亂整個數(shù)據(jù)集

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf
np.random.seed(0)
x = np.random.sample((11,2))
# make a dataset from a numpy array
print(x)
print()
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(1)
dataset = dataset.batch(4)
dataset = dataset.repeat(2)

# create the iterator
iter = dataset.make_one_shot_iterator()
el = iter.get_next()

with tf.Session() as sess:
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))

[[ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]
 [ 0.43758721 0.891773 ]
 [ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]
 [ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]

[[ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]
 [ 0.43758721 0.891773 ]]
[[ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]]
[[ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]
[[ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]
 [ 0.43758721 0.891773 ]]
[[ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]]
[[ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]

注意如果repeat在shuffle之前使用:

官方說repeat在shuffle之前使用能提高性能,但模糊了數(shù)據(jù)樣本的epoch關(guān)系

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf
np.random.seed(0)
x = np.random.sample((11,2))
# make a dataset from a numpy array
print(x)
print()
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.repeat(2)
dataset = dataset.shuffle(11)
dataset = dataset.batch(4)

# create the iterator
iter = dataset.make_one_shot_iterator()
el = iter.get_next()

with tf.Session() as sess:
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))

[[ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]
 [ 0.43758721 0.891773 ]
 [ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]
 [ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]

[[ 0.56804456 0.92559664]
 [ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.07103606 0.0871293 ]]
[[ 0.96366276 0.38344152]
 [ 0.43758721 0.891773 ]
 [ 0.43758721 0.891773 ]
 [ 0.77815675 0.87001215]]
[[ 0.79172504 0.52889492]  #出現(xiàn)相同樣本出現(xiàn)在同一個batch中
 [ 0.79172504 0.52889492]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]]
[[ 0.07103606 0.0871293 ]
 [ 0.4236548  0.64589411]
 [ 0.96366276 0.38344152]
 [ 0.5488135  0.71518937]]
[[ 0.97861834 0.79915856]
 [ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.56804456 0.92559664]]
[[ 0.0202184  0.83261985]
 [ 0.97861834 0.79915856]]     #可以看到最后個batch為2,而前面都是4  

使用案例:

def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
  print('Parsing', filenames)
  def decode_libsvm(line):
    #columns = tf.decode_csv(value, record_defaults=CSV_COLUMN_DEFAULTS)
    #features = dict(zip(CSV_COLUMNS, columns))
    #labels = features.pop(LABEL_COLUMN)
    columns = tf.string_split([line], ' ')
    labels = tf.string_to_number(columns.values[0], out_type=tf.float32)
    splits = tf.string_split(columns.values[1:], ':')
    id_vals = tf.reshape(splits.values,splits.dense_shape)
    feat_ids, feat_vals = tf.split(id_vals,num_or_size_splits=2,axis=1)
    feat_ids = tf.string_to_number(feat_ids, out_type=tf.int32)
    feat_vals = tf.string_to_number(feat_vals, out_type=tf.float32)
    #feat_ids = tf.reshape(feat_ids,shape=[-1,FLAGS.field_size])
    #for i in range(splits.dense_shape.eval()[0]):
    #  feat_ids.append(tf.string_to_number(splits.values[2*i], out_type=tf.int32))
    #  feat_vals.append(tf.string_to_number(splits.values[2*i+1]))
    #return tf.reshape(feat_ids,shape=[-1,field_size]), tf.reshape(feat_vals,shape=[-1,field_size]), labels
    return {"feat_ids": feat_ids, "feat_vals": feat_vals}, labels

  # Extract lines from input files using the Dataset API, can pass one filename or filename list
  dataset = tf.data.TextLineDataset(filenames).map(decode_libsvm, num_parallel_calls=10).prefetch(500000)  # multi-thread pre-process then prefetch

  # Randomizes input using a window of 256 elements (read into memory)
  if perform_shuffle:
    dataset = dataset.shuffle(buffer_size=256)

  # epochs from blending together.
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size) # Batch size to use

  #return dataset.make_one_shot_iterator()
  iterator = dataset.make_one_shot_iterator()
  batch_features, batch_labels = iterator.get_next()
  #return tf.reshape(batch_ids,shape=[-1,field_size]), tf.reshape(batch_vals,shape=[-1,field_size]), batch_labels
  return batch_features, batch_labels

到此這篇關(guān)于淺談tensorflow中dataset.shuffle和dataset.batch dataset.repeat注意點的文章就介紹到這了,更多相關(guān)tensorflow中dataset.shuffle和dataset.batch dataset.repeat內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家! 

相關(guān)文章

  • pandas dataframe的合并實現(xiàn)(append, merge, concat)

    pandas dataframe的合并實現(xiàn)(append, merge, concat)

    這篇文章主要介紹了pandas dataframe的合并實現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2019-06-06
  • Python樹的鏡像的實現(xiàn)示例

    Python樹的鏡像的實現(xiàn)示例

    樹的鏡像是指將樹的每個節(jié)點的左右子樹交換,得到一棵新的樹,本文主要介紹了Python樹的鏡像的實現(xiàn)示例,具有一定的參考價值,感興趣的可以了解一下
    2023-11-11
  • 自定義python日志文件系統(tǒng)實例

    自定義python日志文件系統(tǒng)實例

    這篇文章主要介紹了自定義python日志文件系統(tǒng)方式,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教
    2023-08-08
  • Python操作MongoDB數(shù)據(jù)庫的方法示例

    Python操作MongoDB數(shù)據(jù)庫的方法示例

    這篇文章主要介紹了Python操作MongoDB數(shù)據(jù)庫的方法,結(jié)合實例形式分析了Python命令行模式下操作MongoDB數(shù)據(jù)庫實現(xiàn)連接、查找、刪除、排序等相關(guān)操作技巧,需要的朋友可以參考下
    2018-01-01
  • python實現(xiàn)單鏈表中刪除倒數(shù)第K個節(jié)點的方法

    python實現(xiàn)單鏈表中刪除倒數(shù)第K個節(jié)點的方法

    這篇文章主要為大家詳細(xì)介紹了python實現(xiàn)單鏈表中刪除倒數(shù)第K個節(jié)點的方法,具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2018-09-09
  • Python實現(xiàn)加載及解析properties配置文件的方法

    Python實現(xiàn)加載及解析properties配置文件的方法

    這篇文章主要介紹了Python實現(xiàn)加載及解析properties配置文件的方法,結(jié)合實例形式分析了Python針對properties配置文件的加載、讀取及解析相關(guān)操作技巧,需要的朋友可以參考下
    2018-03-03
  • 使用PIL(Python-Imaging)反轉(zhuǎn)圖像的顏色方法

    使用PIL(Python-Imaging)反轉(zhuǎn)圖像的顏色方法

    今天小編就為大家分享一篇使用PIL(Python-Imaging)反轉(zhuǎn)圖像的顏色方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2019-01-01
  • 用Python制作在地圖上模擬瘟疫擴散的Gif圖

    用Python制作在地圖上模擬瘟疫擴散的Gif圖

    這篇文章主要介紹了如何用Python制作在地圖上模擬瘟疫擴散的Gif圖,其中用到了歐拉公式等數(shù)學(xué)知識、需要一定的算法基礎(chǔ),需要的朋友可以參考下
    2015-03-03
  • pandas多層索引的創(chuàng)建和取值以及排序的實現(xiàn)

    pandas多層索引的創(chuàng)建和取值以及排序的實現(xiàn)

    這篇文章主要介紹了pandas多層索引的創(chuàng)建和取值以及排序的實現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2021-03-03
  • Python+OpenCV 實現(xiàn)圖片無損旋轉(zhuǎn)90°且無黑邊

    Python+OpenCV 實現(xiàn)圖片無損旋轉(zhuǎn)90°且無黑邊

    今天小編就為大家分享一篇Python+OpenCV 實現(xiàn)圖片無損旋轉(zhuǎn)90°且無黑邊,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2019-12-12

最新評論