tensorflow使用tf.data.Dataset 處理大型數(shù)據(jù)集問(wèn)題
最近深度學(xué)習(xí)用到的數(shù)據(jù)集比較大,如果一次性將數(shù)據(jù)集讀入內(nèi)存,那服務(wù)器是頂不住的,所以需要分批進(jìn)行讀取,這里就用到了tf.data.Dataset構(gòu)建數(shù)據(jù)集:
概括一下,tf.data.Dataset主要有幾個(gè)部分最重要:
- 構(gòu)建生成器函數(shù)
- 使用tf.data.Dataset的from_generator函數(shù),通過(guò)指定數(shù)據(jù)類型,數(shù)據(jù)的shape等參數(shù),構(gòu)建一個(gè)Dataset
- 指定batch_size
- 使用make_one_shot_iterator()函數(shù),構(gòu)建一個(gè)iterator
- 使用上面構(gòu)建的迭代器開始get_next() 。(必須要有這個(gè)get_next(),迭代器才會(huì)工作)
一.構(gòu)建生成器
生成器的要點(diǎn)是要在while True中加入yield,yield的功能有點(diǎn)類似return,有yield才能起到迭代的作用。
我的數(shù)據(jù)是一個(gè)[6047, 6000, 1]的文本數(shù)據(jù),我每次迭代返回的shape為[1,6000,1],要注意的是返回的shape要和構(gòu)建Dataset時(shí)的shape一致,下面會(huì)說(shuō)到。
代碼如下:
def gen():?? ??? ??? ??? ? ?? ??? ?train=pd.read_csv('/home/chenqiren/PycharmProjects/code/test/formal/small_sample/train2.csv', header=None) ? ? ? ? train.fillna(0, inplace = True) ? ? ? ? label_encoder = LabelEncoder().fit(train[6000]) ? ? ? ? label = label_encoder.transform(train[6000]) ? ? ? ? ? train = train.drop([6000], axis=1)? ? ? ? ? scaler = StandardScaler().fit(train.values) ? #train.values中的值是csv文件中的那些值, ? ? 這步標(biāo)準(zhǔn)化可以保留 ? ? ? ? scaled_train = scaler.transform(train.values) ? ? ? ? #print(scaled_train) ? ? ? ? #拆分訓(xùn)練集和測(cè)試集-------------- ? ? ? ? sss=StratifiedShuffleSplit(test_size=0.1, random_state=23) ? ? ? ? for train_index, valid_index in sss.split(scaled_train, label): ? #需要的是數(shù)組,train.values得到的是數(shù)組 ? ? ? ? ? ? X_train, X_valid=scaled_train[train_index], scaled_train[valid_index] ?#https://www.cnblogs.com/Allen-rg/p/9453949.html ? ? ? ? ? ? y_train, y_valid=label[train_index], label[valid_index] ? ? ? ? X_train_r=np.zeros((len(X_train), 6000, 1)) ? #先構(gòu)建一個(gè)框架出來(lái),下面再賦值 ? ? ? ? X_train_r[:,: ,0]=X_train[:,0:6000] ? ?? ? ?? ? ? ? ? X_valid_r=np.zeros((len(X_valid), 6000, 1)) ? ? ? ? X_valid_r[:,: ,0]=X_valid[:,0:6000] ? ?? ? ? ? ? y_train=np_utils.to_categorical(y_train, 3) ? ? ? ? y_valid=np_utils.to_categorical(y_valid, 3) ? ? ? ?? ? ? ? ? leng=len(X_train_r) ? ? ? ? index=0 ? ? ? ? while True: ? ? ? ? ? ? x_train_batch=X_train_r[index, :, 0:1] ? ? ? ? ? ? y_train_batch=y_train[index, :] ? ? ? ? ? ? yield (x_train_batch, y_train_batch) ? ? ? ? ? ? index=index+1 ? ? ? ? ? ? if index>leng: ? ? ? ? ? ? ? ? break
代碼中while True上面的部分是標(biāo)準(zhǔn)化數(shù)據(jù)的代碼,可以不用看,只需要看 while True中的代碼即可。
x_train_batch, y_train_batch都只是一行的數(shù)據(jù),這里是一行一行數(shù)據(jù)迭代。
二.使用tf.data.Dataset包裝生成器
data=tf.data.Dataset.from_generator(gen_1, (tf.float32, tf.float32), (tf.TensorShape([6000,1]), tf.TensorShape([3]))) data=data.batch(128) iterator=data.make_one_shot_iterator()
這里的tf.TensorShape([6000,1]) 和 tf.TensorShape([3])中的shape要和上面生成器yield返回的數(shù)據(jù)的shape一致。
data=data.batch(128)
是設(shè)置batchsize,這里設(shè)為128,在運(yùn)行時(shí),因?yàn)槲覀儁ield的是一行的數(shù)據(jù)[1, 6000, 1],所以將會(huì)循環(huán)yield夠128次,得到[128, 6000, 1],即一個(gè)batch,才會(huì)開始訓(xùn)練。iterator=data.make_one_shot_iterator()
是構(gòu)建迭代器,one_shot迭代器人如其名,意思就是數(shù)據(jù)輸出一次后就丟棄了。
三.獲取生成器返回的數(shù)據(jù)
x, y=iterator.get_next() x_batch, y_batch=sess.run([x,y])
注意要有g(shù)et_next(),迭代器才能開始工作。
第二行是run第一行代碼。獲取訓(xùn)練數(shù)據(jù)和訓(xùn)練標(biāo)簽。
這里做個(gè)關(guān)于yield的小筆記:
上一次迭代,yield返回了值,然后get_next()開啟了下一次迭代,此時(shí),程序是從yield處開始運(yùn)行的,也就是說(shuō),如果yield后面還有程序,那就會(huì)運(yùn)行yield后面的程序。一直運(yùn)行的是while True中的程序,沒有運(yùn)行while True外面的程序。
下面是我寫的總的代碼??梢圆挥每?。
import os import keras import numpy as np import pandas as pd from sklearn.preprocessing import LabelEncoder from sklearn.preprocessing import StandardScaler from sklearn.model_selection import StratifiedShuffleSplit from sklearn.model_selection import train_test_split from keras.models import Sequential, Model from keras.layers import Dense, Activation, Flatten, Conv1D, Dropout, MaxPooling1D, GlobalAveragePooling1D from keras.layers import GlobalAveragePooling2D,BatchNormalization, UpSampling1D, RepeatVector,Reshape from keras.layers.core import Lambda from keras.optimizers import SGD, Adam, Adadelta from keras.utils import np_utils from keras.applications.inception_resnet_v2 import InceptionResNetV2 from keras.backend import conv3d,reshape, shape, categorical_crossentropy, mean, square from keras.applications.vgg16 import VGG16 from keras.layers import Input,LSTM from keras import regularizers from keras.utils import multi_gpu_model import tensorflow as tf import keras.backend.tensorflow_backend as KTF os.environ["CUDA_VISIBLE_DEVICES"]="2" config = tf.ConfigProto() config.gpu_options.allow_growth = True session = tf.Session(config=config) keep_prob = tf.placeholder("float") # 設(shè)置session KTF.set_session(session ) #-----生成訓(xùn)練數(shù)據(jù)----------------------------------------------- def gen_1(): train=pd.read_csv('/home/chenqiren/PycharmProjects/code/test/formal/small_sample/train2.csv', header=None) train.fillna(0, inplace = True) label_encoder = LabelEncoder().fit(train[6000]) label = label_encoder.transform(train[6000]) train = train.drop([6000], axis=1) scaler = StandardScaler().fit(train.values) #train.values中的值是csv文件中的那些值, 這步標(biāo)準(zhǔn)化可以保留 scaled_train = scaler.transform(train.values) #print(scaled_train) #拆分訓(xùn)練集和測(cè)試集-------------- sss=StratifiedShuffleSplit(test_size=0.1, random_state=23) for train_index, valid_index in sss.split(scaled_train, label): #需要的是數(shù)組,train.values得到的是數(shù)組 X_train, X_valid=scaled_train[train_index], scaled_train[valid_index] #https://www.cnblogs.com/Allen-rg/p/9453949.html y_train, y_valid=label[train_index], label[valid_index] X_train_r=np.zeros((len(X_train), 6000, 1)) #先構(gòu)建一個(gè)框架出來(lái),下面再賦值 #開始賦值 #https://stackoverflow.com/questions/43290202/python-typeerror-unhashable-type-slice-for-encoding-categorical-data X_train_r[:,: ,0]=X_train[:,0:6000] X_valid_r=np.zeros((len(X_valid), 6000, 1)) X_valid_r[:,: ,0]=X_valid[:,0:6000] y_train=np_utils.to_categorical(y_train, 3) y_valid=np_utils.to_categorical(y_valid, 3) leng=len(X_train_r) index=0 while True: x_train_batch=X_train_r[index, :, 0:1] y_train_batch=y_train[index, :] yield (x_train_batch, y_train_batch) index=index+1 if index>leng: break #----生成測(cè)試數(shù)據(jù)-------------------------------------- def gen_2(): train=pd.read_csv('/home/chenqiren/PycharmProjects/code/test/formal/small_sample/train2.csv', header=None) train.fillna(0, inplace = True) label_encoder = LabelEncoder().fit(train[6000]) label = label_encoder.transform(train[6000]) train = train.drop([6000], axis=1) scaler = StandardScaler().fit(train.values) #train.values中的值是csv文件中的那些值, 這步標(biāo)準(zhǔn)化可以保留 scaled_train = scaler.transform(train.values) #print(scaled_train) #拆分訓(xùn)練集和測(cè)試集-------------- sss=StratifiedShuffleSplit(test_size=0.1, random_state=23) for train_index, valid_index in sss.split(scaled_train, label): #需要的是數(shù)組,train.values得到的是數(shù)組 X_train, X_valid=scaled_train[train_index], scaled_train[valid_index] #https://www.cnblogs.com/Allen-rg/p/9453949.html y_train, y_valid=label[train_index], label[valid_index] X_train_r=np.zeros((len(X_train), 6000, 1)) #先構(gòu)建一個(gè)框架出來(lái),下面再賦值 #開始賦值 #https://stackoverflow.com/questions/43290202/python-typeerror-unhashable-type-slice-for-encoding-categorical-data X_train_r[:,: ,0]=X_train[:,0:6000] X_valid_r=np.zeros((len(X_valid), 6000, 1)) X_valid_r[:,: ,0]=X_valid[:,0:6000] y_train=np_utils.to_categorical(y_train, 3) y_valid=np_utils.to_categorical(y_valid, 3) leng=len(X_valid_r) index=0 while True: x_test_batch=X_valid_r[index, :, 0:1] y_test_batch=y_valid[index, :] yield (x_test_batch, y_test_batch) index=index+1 if index>leng: break #--------------------------------------------------------------------- def custom_mean_squared_error(y_true, y_pred): return mean(square(y_pred - y_true)) def custom_categorical_crossentropy(y_true, y_pred): return categorical_crossentropy(y_true, y_pred) def loss_func(y_loss, x_loss): return categorical_crossentropy + 0.05 * mean_squared_error #建立模型 with tf.device('/cpu:0'): inputs1=tf.placeholder(tf.float32, shape=(None,6000,1)) x1=LSTM(128, return_sequences=True)(inputs1) encoded=LSTM(64 ,return_sequences=True)(x1) print('encoded shape:',shape(encoded)) #decode x1=LSTM(128, return_sequences=True)(encoded) decoded=LSTM(1, return_sequences=True,name='decode')(x1) #classify labels=tf.placeholder(tf.float32, shape=(None,3)) x2=Conv1D(20,kernel_size=50, strides=2, activation='relu' )(encoded) #步數(shù)論文中未提及,第一層 x2=MaxPooling1D(pool_size=2, strides=1)(x2) x2=Conv1D(20,kernel_size=50, strides=2, activation='relu')(x2) #第二層 x2=MaxPooling1D(pool_size=2, strides=1)(x2) x2=Dropout(0.25)(x2) x2=Conv1D(24,kernel_size=30, strides=2, activation='relu')(x2) #第三層 x2=MaxPooling1D(pool_size=2, strides=1)(x2) x2=Dropout(0.25)(x2) x2=Conv1D(24,kernel_size=30, strides=2, activation='relu')(x2) #第四層 x2=MaxPooling1D(pool_size=2, strides=1)(x2) x2=Dropout(0.25)(x2) x2=Conv1D(24,kernel_size=10, strides=2, activation='relu')(x2) #第五層 x2=MaxPooling1D(pool_size=2, strides=1)(x2) x2=Dropout(0.25)(x2) x2=Dense(192)(x2) #第一個(gè)全連接層 x2=Dense(192)(x2) #第二個(gè)全連接層 x2=Flatten()(x2) x2=Dense(3,activation='softmax', name='classify')(x2) def get_accuracy(x2, labels): current = tf.cast(tf.equal(tf.argmax(x2, 1), tf.argmax(labels, 1)), 'float') accuracy = tf.reduce_mean(current) return accuracy #實(shí)例化獲取準(zhǔn)確率函數(shù) getAccuracy = get_accuracy(x2, labels) #定義損失函數(shù) all_loss=tf.reduce_mean(categorical_crossentropy(x2 , labels) + tf.convert_to_tensor(0.5)*square(decoded-inputs1)) train_option=tf.train.AdamOptimizer(0.01).minimize(all_loss) #----------------------------------------- #生成訓(xùn)練數(shù)據(jù) data=tf.data.Dataset.from_generator(gen_1, (tf.float32, tf.float32), (tf.TensorShape([6000,1]), tf.TensorShape([3]))) data=data.batch(128) iterator=data.make_one_shot_iterator() #生成測(cè)試數(shù)據(jù) data2=tf.data.Dataset.from_generator(gen_2, (tf.float32, tf.float32), (tf.TensorShape([6000,1]), tf.TensorShape([3]))) data2=data2.batch(128) iterator2=data2.make_one_shot_iterator() #----------------------------------------- with tf.Session() as sess: init=tf.global_variables_initializer() sess.run(init) i=-1 for k in range(20): #----------------------------------------- x, y=iterator.get_next() x_batch, y_batch=sess.run([x,y]) print('batch shape:',x_batch.shape, y_batch.shape) #----------------------------------------- if k%2==0: print('第',k,'輪') x3=sess.run(x2, feed_dict={inputs1:x_batch, labels:y_batch }) dc=sess.run(decoded, feed_dict={inputs1:x_batch}) accuracy=sess.run(getAccuracy, feed_dict={x2:x3, labels:y_batch, keep_prob: 1.0}) loss=sess.run(all_loss, feed_dict={x2:x3, labels:y_batch, inputs1:x_batch, decoded:dc}) print("step(s): %d ----- accuracy: %g -----loss: %g" % (i, accuracy, loss)) sess.run(train_option, feed_dict={inputs1:x_batch, labels:y_batch, keep_prob: 0.5}) x, y=iterator2.get_next() x_test_batch, y_test_batch=sess.run([x,y]) print('batch shape:',x_test_batch.shape, y_test_batch.shape) x_test=sess.run(x2, feed_dict={inputs1:x_test_batch, labels:y_test_batch }) print ("test accuracy %f"%getAccuracy.eval(feed_dict={x2:x_test, labels:y_test_batch, keep_prob: 1.0}))
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
對(duì)django 2.x版本中models.ForeignKey()外鍵說(shuō)明介紹
這篇文章主要介紹了對(duì)django 2.x版本中models.ForeignKey()外鍵說(shuō)明介紹,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-03-03cv2.getStructuringElement()函數(shù)及開、閉、腐蝕、膨脹原理講解
getStructuringElement()函數(shù)可用于構(gòu)造一個(gè)特定大小和形狀的結(jié)構(gòu)元素,用于圖像形態(tài)學(xué)處理,這篇文章主要介紹了cv2.getStructuringElement()函數(shù)及開、閉、腐蝕、膨脹原理講解的相關(guān)資料,需要的朋友可以參考下2022-12-12python機(jī)器學(xué)習(xí)XGBoost梯度提升決策樹的高效且可擴(kuò)展實(shí)現(xiàn)
這篇文章主要為大家介紹了python機(jī)器學(xué)習(xí)XGBoost梯度提升決策樹的高效且可擴(kuò)展實(shí)現(xiàn),有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2024-01-01Python實(shí)現(xiàn)為PDF去除水印的示例代碼
這篇文章主要介紹了如何利用Python實(shí)現(xiàn)PDF去除水印功能,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2022-04-04python多線程實(shí)現(xiàn)同時(shí)執(zhí)行兩個(gè)while循環(huán)的操作
這篇文章主要介紹了python多線程實(shí)現(xiàn)同時(shí)執(zhí)行兩個(gè)while循環(huán)的操作,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-05-05Python實(shí)現(xiàn)iOS自動(dòng)化打包詳解步驟
這篇文章主要介紹了Python實(shí)現(xiàn)iOS自動(dòng)化打包詳解步驟,文中通過(guò)示例代碼以及圖文介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2018-10-10