TensorFlow Multithreaded and Multiprocess Data Loading Examples
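My project involved an extremely large dataset. That it could not fit in memory goes without saying, but even reading and processing it in batches on a single thread (and the processing was only a very simple head-to-tail concatenation) left CPU performance as the bottleneck. So I looked into multithreaded and multiprocess data reading and preprocessing, both implemented on top of the tf.data Dataset API.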
1. Multithreaded data reading
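The first approach reads data directly from CSV files. The return values are tensors, though, so they have to be run in a session (sess.run) before you get the actual values, and you cannot do true parallel processing on them in Python. If the feature values are already stored in CSV (or other) files and can be used for training as soon as they are read, this approach works well.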
```python
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, placeholders, initializable iterators)

# Column defaults: they only define the type of each column; the values themselves
# only matter for empty fields. Here each row has four columns: the first three
# are strings and the last one is an int.
record_defaults = [[""], [""], [""], [0]]


def decode_csv(line):
    parsed_line = tf.decode_csv(line, record_defaults)
    label = parsed_line[-1]  # the last column is the label
    del parsed_line[-1]  # remove the label from the list of columns
    # Stack the feature columns so they can be vectorized later (forward prop, etc.)
    features = tf.stack(parsed_line)
    # label = tf.stack(label)  # only needed if more than one column makes up the label
    return features, label


filenames = tf.placeholder(tf.string, shape=[None])
dataset5 = tf.data.Dataset.from_tensor_slices(filenames)
# num_parallel_calls sets how many lines are decoded in parallel
dataset5 = dataset5.flat_map(
    lambda filename: tf.data.TextLineDataset(filename)
    .skip(1)  # skip the header row
    .map(decode_csv, num_parallel_calls=15))
dataset5 = dataset5.shuffle(buffer_size=1000)
dataset5 = dataset5.batch(32)  # batch_size
iterator5 = dataset5.make_initializable_iterator()
next_element5 = iterator5.get_next()

# Files to load
training_filenames = ["train.csv"]
validation_filenames = ["vali.csv"]

with tf.Session() as sess:
    for _ in range(2):
        # Initialize the iterator with the training file names
        sess.run(iterator5.initializer, feed_dict={filenames: training_filenames})
        while True:
            try:
                # sess.run returns the actual values
                features, labels = sess.run(next_element5)
                # Train...
            except tf.errors.OutOfRangeError:
                print("Out of range error triggered (looped through training set 1 time)")
                break

    print("\nDone with the first iterator\n")
    sess.run(iterator5.initializer, feed_dict={filenames: validation_filenames})
    while True:
        try:
            features, labels = sess.run(next_element5)
            # Validate (cost, accuracy) on dev set
        except tf.errors.OutOfRangeError:
            print("Out of range error triggered (looped through dev set 1 time only)")
            break
```
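To make the role of `record_defaults` concrete, here is a pure-Python sketch of the parsing rule it encodes. It is not TensorFlow code, and `decode_line` / `decode_file` are hypothetical helpers: each field is converted to the type of its column's default, an empty field falls back to the default value, the header line is skipped like `.skip(1)`, and the last column is split off as the label, as `decode_csv` does above.

```python
import csv
import io

# Same defaults as above: three string columns and one int column.
RECORD_DEFAULTS = [[""], [""], [""], [0]]


def decode_line(line, record_defaults=RECORD_DEFAULTS):
    """Hypothetical pure-Python analogue of decode_csv above (not TensorFlow).

    Each field is cast to the type of its column default; an empty field
    falls back to the default value. The last column is returned as the label.
    """
    fields = next(csv.reader(io.StringIO(line)))
    if len(fields) != len(record_defaults):
        raise ValueError(f"expected {len(record_defaults)} fields, got {len(fields)}")
    parsed = []
    for field, (default,) in zip(fields, record_defaults):
        parsed.append(type(default)(field) if field != "" else default)
    return parsed[:-1], parsed[-1]


def decode_file(text):
    """Decode every data line of a CSV text, skipping the header row."""
    lines = text.splitlines()[1:]
    return [decode_line(line) for line in lines if line]


sample = "a,b,c,label\nx,y,z,1\nfoo,,bar,0\n"
print(decode_file(sample))
# [(['x', 'y', 'z'], 1), (['foo', '', 'bar'], 0)]
```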
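The second approach is based on a generator. It supports arbitrary preprocessing, and the results returned by sess.run can be fed straight into training, but you have to write the generator yourself. Here is the test code I used: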
```python
import random

import numpy as np
import tensorflow as tf

from data import load_image, load_wave  # my own loading helpers (not shown)


class SequenceData():
    def __init__(self, path, batch_size=32):
        self.path = path
        self.batch_size = batch_size
        with open(path) as f:
            self.datas = f.readlines()
        self.L = len(self.datas)
        # Shuffled permutation of all row indices
        self.index = random.sample(range(self.L), self.L)

    def __len__(self):
        # Number of full batches (a trailing partial batch is dropped)
        return self.L // self.batch_size

    def __getitem__(self, idx):
        # The idx-th batch: non-overlapping slices of the shuffled index list
        batch_indexs = self.index[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_datas = [self.datas[k] for k in batch_indexs]
        img1s, img2s, audios, labels = self.data_generation(batch_datas)
        return img1s, img2s, audios, labels

    def gen(self):
        # Yield every batch once (one epoch)
        for i in range(len(self)):
            yield self[i]

    def data_generation(self, batch_datas):
        # Preprocessing goes here: turn batch_datas (a list of CSV lines) into
        # img1s, img2s, audios, labels, e.g. with load_image / load_wave
        raise NotImplementedError


data = SequenceData('train.csv')
# output_types must match the types the generator actually returns. The generator
# already yields whole batches, so no extra .batch() call is needed.
dataset = tf.data.Dataset.from_generator(
    data.gen, output_types=(tf.float32, tf.float32, tf.float32, tf.int64))
# Note: this identity map does no work. from_generator calls the Python generator
# sequentially, so num_parallel_calls does not parallelize the preprocessing.
dataset = dataset.map(lambda x, y, z, w: (x, y, z, w), num_parallel_calls=32).prefetch(buffer_size=1000)
X, y, z, w = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    for _ in range(len(data)):
        a, b, c, d = sess.run([X, y, z, w])
        print(a.shape)
```
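The batch indexing in `SequenceData` is the easiest part to get wrong. A minimal stdlib sketch of the shuffle-then-slice scheme, assuming a hypothetical `make_batches` helper: slicing the shuffled index list in steps of `batch_size` gives non-overlapping batches that cover each row at most once, whereas advancing the start index by 1 would produce heavily overlapping windows.

```python
import random


def make_batches(num_rows, batch_size, seed=None):
    """Yield lists of row indices: a shuffled, non-overlapping split into batches.

    Mirrors SequenceData above: shuffle all indices once, then take the idx-th
    slice of length batch_size. A trailing partial batch is dropped.
    """
    rng = random.Random(seed)
    index = rng.sample(range(num_rows), num_rows)  # shuffled permutation
    num_batches = num_rows // batch_size  # same as SequenceData.__len__
    for idx in range(num_batches):
        yield index[idx * batch_size:(idx + 1) * batch_size]


batches = list(make_batches(num_rows=10, batch_size=3, seed=0))
print(len(batches))  # 3 full batches; 1 row is left over
```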
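However, Python multithreading is not true parallelism: in standard CPython builds the Global Interpreter Lock (GIL) lets only one thread execute Python bytecode at a time. So although it looks as if I started 32 threads, the CPU usage during the run told a different story.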
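A small experiment, independent of TensorFlow, shows the effect. The workload `busy_work` is a hypothetical stand-in for CPU-bound preprocessing: splitting pure-Python work across threads produces the same results, but under the GIL it usually brings no real speedup. Timings depend on the machine, so treat the printed numbers as indicative only.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def busy_work(n):
    """Hypothetical CPU-bound stand-in for preprocessing: pure-Python arithmetic."""
    total = 0
    for i in range(n):
        total += i * i % 7
    return total


def run_serial(jobs):
    return [busy_work(n) for n in jobs]


def run_threaded(jobs, workers=4):
    # Threads share one interpreter; under the GIL only one runs Python code at a time.
    with ThreadPoolExecutor(max_workers=workers) as pool:
        return list(pool.map(busy_work, jobs))


jobs = [200_000] * 8
for name, fn in [("serial", run_serial), ("threads", run_threaded)]:
    start = time.perf_counter()
    fn(jobs)
    print(f"{name}: {time.perf_counter() - start:.2f}s")
```

On a multicore machine the threaded time is typically close to the serial time.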
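With that many cores still sitting idle, I wrote a third version. It uses a queue to buffer preprocessed data, and training reads directly from the queue whenever it needs data. It is a transitional step toward multiprocessing (VS Code could not debug multiple processes, which tripped me up: I assumed my code was wrong, but the multiprocess version simply would not run inside VS Code). Several threads are started at initialization to preprocess the data: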
```python
import random
import threading
from queue import Queue

import numpy as np
import tensorflow as tf

from data import load_image, load_wave  # my own loading helpers (not shown)


class SequenceData():
    def __init__(self, path, batch_size=32, num_threads=32):
        self.path = path
        self.batch_size = batch_size
        with open(path) as f:
            self.datas = f.readlines()
        self.L = len(self.datas)
        self.index = random.sample(range(self.L), self.L)
        # Bounded buffer: worker threads block once 20 batches are waiting
        self.queue = Queue(maxsize=20)
        # Start the preprocessing threads at initialization
        for i in range(num_threads):
            threading.Thread(target=self.f, args=(i, num_threads), daemon=True).start()

    def __len__(self):
        return self.L // self.batch_size

    def __getitem__(self, idx):
        batch_indexs = self.index[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_datas = [self.datas[k] for k in batch_indexs]
        img1s, img2s, audios, labels = self.data_generation(batch_datas)
        return img1s, img2s, audios, labels

    def f(self, worker_id, num_workers):
        # Each thread handles every num_workers-th batch, so no batch is produced twice
        for i in range(worker_id, len(self), num_workers):
            self.queue.put(self[i])

    def gen(self):
        # Hand out one epoch's worth of batches from the queue (in completion order)
        for _ in range(len(self)):
            yield self.queue.get()

    def data_generation(self, batch_datas):
        # Preprocessing goes here
        raise NotImplementedError


data = SequenceData('train.csv')
# output_types must match what the generator returns; batches are already formed.
dataset = tf.data.Dataset.from_generator(
    data.gen, output_types=(tf.float32, tf.float32, tf.float32, tf.int64))
dataset = dataset.map(lambda x, y, z, w: (x, y, z, w), num_parallel_calls=1).prefetch(buffer_size=1000)
X, y, z, w = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    for _ in range(len(data)):
        a, b, c, d = sess.run([X, y, z, w])
        print(a.shape)
```
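The queue-based design is a classic producer-consumer pattern. Here is a stdlib-only sketch of it, where `prefetch_batches` and `tagged_batch` are hypothetical names rather than part of the original code. It makes two details explicit: the work is split so each batch is produced exactly once, and every worker puts a sentinel when it finishes so the consumer knows when to stop instead of blocking forever on `queue.get()`. The bounded `Queue(maxsize=...)` provides back-pressure: workers block when the consumer falls behind.

```python
import queue
import threading

_DONE = object()  # sentinel a worker puts on the queue when it has finished


def prefetch_batches(num_batches, make_batch, num_workers=4, maxsize=8):
    """Yield make_batch(i) for i in range(num_batches), built by worker threads.

    Worker w builds batches w, w + num_workers, ... so each batch is built once.
    Batches arrive in completion order, not necessarily in index order.
    """
    q = queue.Queue(maxsize=maxsize)  # bounded: producers block when it is full

    def worker(worker_id):
        try:
            for i in range(worker_id, num_batches, num_workers):
                q.put(make_batch(i))
        finally:
            q.put(_DONE)  # always signal completion (a failed batch is simply lost)

    for w in range(num_workers):
        threading.Thread(target=worker, args=(w,), daemon=True).start()

    finished = 0
    while finished < num_workers:
        item = q.get()
        if item is _DONE:
            finished += 1
        else:
            yield item


def tagged_batch(idx):
    """Hypothetical preprocessing: returns the batch index and some squares."""
    return idx, [k * k for k in range(idx, idx + 3)]


results = sorted(prefetch_batches(10, tagged_batch, num_workers=3))
print(results[0])  # (0, [0, 1, 4])
```

In the class above, `gen()` could in principle delegate to such a helper, e.g. `yield from prefetch_batches(len(self), self.__getitem__)`, which also lets the generator end cleanly after one epoch.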
2. Multiprocess data reading
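The code here is very similar to the third multithreaded version: simply start processes instead of threads and use the Queue from the multiprocessing module. But whatever you do, don't debug it directly in VS Code! Running it with F5 in VS Code does not start the processes.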
```python
import random
from multiprocessing import Process, Queue

import numpy as np
import scipy.io.wavfile  # used by the preprocessing in data_generation (not shown)
import skimage
import skimage.io
import skimage.transform
import tensorflow as tf


class SequenceData():
    def __init__(self, path, batch_size=32, epochs=1, num_processes=32):
        self.path = path
        self.batch_size = batch_size
        with open(path) as f:
            self.datas = f.readlines()
        self.L = len(self.datas)
        self.index = random.sample(range(self.L), self.L)
        self.queue = Queue(maxsize=30)  # multiprocessing.Queue, shared with the workers
        self.Process_num = num_processes
        for i in range(self.Process_num):
            print(i, 'start')
            # Worker i handles batches i, i + N, i + 2N, ... for every epoch
            t = Process(target=self.f, args=(i, epochs), daemon=True)
            t.start()

    def __len__(self):
        return self.L // self.batch_size

    def __getitem__(self, idx):
        batch_indexs = self.index[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_datas = [self.datas[k] for k in batch_indexs]
        img1s, img2s, audios, labels = self.data_generation(batch_datas)
        return img1s, img2s, audios, labels

    def f(self, worker_id, epochs):
        for _ in range(epochs):
            for i in range(worker_id, len(self), self.Process_num):
                self.queue.put(self[i])

    def gen(self):
        while True:
            t = self.queue.get()
            yield t[0], t[1], t[2], t[3]

    def data_generation(self, batch_datas):
        # Preprocessing goes here
        raise NotImplementedError


# The guard is required when processes are started with "spawn" (Windows, macOS)
if __name__ == '__main__':
    epochs = 2
    data_g = SequenceData('train_1.csv', batch_size=48, epochs=epochs)
    dataset = tf.data.Dataset.from_generator(
        data_g.gen, output_types=(tf.float32, tf.float32, tf.float32, tf.float32))
    X, y, z, w = dataset.make_one_shot_iterator().get_next()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        for i in range(epochs):
            for j in range(len(data_g)):
                face1, face2, voice, labels = sess.run([X, y, z, w])
                print(face1.shape)
```
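For comparison, here is a sketch of the same idea using `multiprocessing.Pool.imap`. This swaps in a different standard-library technique from the author's hand-managed `Process` objects and `Queue`, and `preprocess_batch` is a hypothetical CPU-bound stand-in. `imap` hands batch indices to worker processes, yields results in index order, and finishes cleanly once every batch is done; each process has its own interpreter, so CPU-bound Python work is not serialized by the GIL. `processes` and `chunksize` are tuning knobs.

```python
import multiprocessing as mp


def preprocess_batch(idx, batch_size=4):
    """Hypothetical CPU-bound preprocessing for batch idx; runs in a worker process."""
    rows = range(idx * batch_size, (idx + 1) * batch_size)
    return idx, [sum(k * j for j in range(1000)) for k in rows]


def batches_in_processes(num_batches, processes=4):
    """Yield preprocessed batches in index order, computed by a pool of processes."""
    with mp.Pool(processes=processes) as pool:
        # imap sends indices to the workers lazily and yields results in order
        yield from pool.imap(preprocess_batch, range(num_batches), chunksize=2)


if __name__ == "__main__":  # needed with the "spawn" and "forkserver" start methods
    for idx, batch in batches_in_processes(6):
        print(idx, batch[0])
```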
Finally, the end result.
That is everything in this article on TensorFlow multithreaded and multiprocess data loading examples. I hope it serves as a useful reference, and please continue to support Script House (脚本之家).