pytorch加載自己的數(shù)據(jù)集源碼分享
一、標準的數(shù)據(jù)集流程梳理
分為幾個步驟
數(shù)據(jù)準備以及加載數(shù)據(jù)庫–>數(shù)據(jù)加載器的調(diào)用或者設(shè)計–>批量調(diào)用進行訓(xùn)練或者其他作用
數(shù)據(jù)來源
直接讀取了x和y的數(shù)據(jù)變量,對比后面的就從把對應(yīng)的路徑寫進了文本文件中,通過加載器進行讀取
x = torch.linspace(1, 10, 10) # 訓(xùn)練數(shù)據(jù) linspace返回一個一維的張量,(最小值,最大值,多少個數(shù)) print(x) y = torch.linspace(10, 1, 10) # 標簽 print(y)
將數(shù)據(jù)加載進數(shù)據(jù)庫
輸出的結(jié)果是<torch.utils.data.dataset.TensorDataset object at 0x00000145BD93F1C0>
,需要使用加載器進行加載,才能迭代遍歷
import torch.utils.data as Data torch_dataset = Data.TensorDataset(x, y) # 對給定的 tensor 數(shù)據(jù),將他們包裝成 dataset #輸出的結(jié)果是<torch.utils.data.dataset.TensorDataset object at 0x00000145BD93F1C0>,需要使用加載器進行加載,才能迭代遍歷 print(torch_dataset)
所以要想看里面的內(nèi)容,就需要用迭代進行操作或者查看。
BATCH_SIZE=5 loader = Data.DataLoader(#使用支持的默認的數(shù)據(jù)集加載的方式 # 從數(shù)據(jù)庫中每次抽出batch size個樣本 dataset=torch_dataset, # torch TensorDataset format 加載數(shù)據(jù)集 batch_size=BATCH_SIZE, # mini batch size 5 shuffle=False, # 要不要打亂數(shù)據(jù) (打亂比較好) num_workers=2, # 多線程來讀數(shù)據(jù) ) def show_batch(): for epoch in range(3): for step, (batch_x, batch_y) in enumerate(loader): #加載數(shù)據(jù)集的時候起的作用很奇怪 # training print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y)) print("*"*100) if __name__ == '__main__': show_batch()
二、實現(xiàn)加載自己的數(shù)據(jù)集
實現(xiàn)自己的數(shù)據(jù)集就需要完成對dataset類的重載。這個類的重載完成幾個函數(shù)的作用
- 初始化數(shù)據(jù)集中的數(shù)據(jù)以及標簽
__init__()
- 返回數(shù)據(jù)和對應(yīng)標簽
__getitem__
- 返回數(shù)據(jù)集的大小
__len__
基本的數(shù)據(jù)集的方法就是完成以上步驟,但是可以想想數(shù)據(jù)集通常是一些圖片和標簽組成,而這些數(shù)據(jù)集以及標簽是保存在計算機上,具有相對應(yīng)的位置,那么直接訪問對應(yīng)的位置因為是在文件夾下需要進行遍歷等一系列操作,而且這就顯得和dataset類沒有解耦,因為有時候在這些位置的操作可能會有一些特殊操作,所以如果能夠?qū)⑵湮恢帽4嬖谖谋疚募锌赡芫蜁奖愫芏?,所以就采取保存文本文件的方式?/p>
# 自定義數(shù)據(jù)集類 class MyDataset(torch.utils.data.Dataset): def __init__(self, *args): super().__init__() # 初始化數(shù)據(jù)集包含的數(shù)據(jù)和標簽 pass def __getitem__(self, index): # 根據(jù)索引index從文件中讀取一個數(shù)據(jù) # 對數(shù)據(jù)預(yù)處理 # 返回數(shù)據(jù)和對應(yīng)標簽 pass def __len__(self): # 返回數(shù)據(jù)集的大小 return len()
1. 保存在txt文件中(生成訓(xùn)練集和測試集,其實這里的訓(xùn)練集以及測試集也都是用文本文件的形式保存下來的)
所以這里新建一個數(shù)據(jù)庫就是新建了兩個文本文件,然后加載器通過文本文件就將圖片以及l(fā)abel加載進去了。而標準的數(shù)據(jù)集操作是使用了自帶的數(shù)據(jù)集接口,在加載的時候也不用再去實現(xiàn)相關(guān)的__getitem__方法
- 數(shù)組定義
- 將絕對路徑加載進數(shù)組中
- 數(shù)組定義
- 將絕對路徑加載進數(shù)組中
- 通過os.walk操作
- os.walk可以獲得根路徑、文件夾以及文件,并會一直進行迭代遍歷下去,直至只有文件才會結(jié)束
- 將數(shù)組的內(nèi)容打亂順序
- 分別將絕對路徑對應(yīng)的數(shù)組內(nèi)容寫進文本文件里,那么這里的文本文件就是保存的數(shù)據(jù)庫,其實數(shù)據(jù)就是一個保存相關(guān)信息或者其內(nèi)容的文件,而標準也是將將其數(shù)據(jù)保存在了一個地方,然后對應(yīng)到標準接口就可以加載了(Data.TensorDataset以及Data.DataLoader)
以下代碼用于生成對應(yīng)的train.txt val.txt
''' 生成訓(xùn)練集和測試集,保存在txt文件中 ''' import os import random train_ratio = 0.6 test_ratio = 1-train_ratio rootdata = r"dataset" #數(shù)組定義 train_list, test_list = [],[] data_list = [] class_flag = -1 # 將絕對路徑加載進數(shù)組中 for a,b,c in os.walk(rootdata):#os.walk可以獲得根路徑、文件夾以及文件,并會一直進行迭代遍歷下去,直至只有文件才會結(jié)束 print(a) for i in range(len(c)): data_list.append(os.path.join(a,c[i])) for i in range(0,int(len(c)*train_ratio)): train_data = os.path.join(a, c[i])+'\t'+str(class_flag)+'\n' #class_flag表示分類的類別 train_list.append(train_data) for i in range(int(len(c) * train_ratio),len(c)): test_data = os.path.join(a, c[i]) + '\t' + str(class_flag)+'\n' test_list.append(test_data) class_flag += 1 print(train_list) # 將數(shù)組的內(nèi)容打亂順序 random.shuffle(train_list) random.shuffle(test_list) #分別將絕對路徑對應(yīng)的數(shù)組內(nèi)容寫進文本文件里 with open('train.txt','w',encoding='UTF-8') as f: for train_img in train_list: f.write(str(train_img)) with open('test.txt','w',encoding='UTF-8') as f: for test_img in test_list: f.write(test_img)
2. 在繼承dataset類LoadData的三個函數(shù)里調(diào)用train.txt以及test.txt實現(xiàn)相關(guān)功能
初始化數(shù)據(jù)集中的數(shù)據(jù)以及標簽、相關(guān)變量__init__()
def __init__(self, txt_path, train_flag=True): #初始化圖片對應(yīng)的變量imgs_info以及一些相關(guān)變量 self.imgs_info = self.get_images(txt_path) #imgs_info保存了圖片以及標簽 self.train_flag = train_flag self.train_tf = transforms.Compose([#對訓(xùn)練集的圖片進行預(yù)處理 transforms.Resize(224), transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(), transforms.ToTensor(), transform_BZ ]) self.val_tf = transforms.Compose([#對測試集的圖片進行預(yù)處理 transforms.Resize(224), transforms.ToTensor(), transform_BZ ])
返回數(shù)據(jù)和對應(yīng)標簽__getitem__
def __getitem__(self, index): img_path, label = self.imgs_info[index] #打開圖片,并將RGBA轉(zhuǎn)換為RGB,這里是通過PIL庫打開圖片的 img = Image.open(img_path) img = img.convert('RGB') img = self.padding_black(img) #將圖片添加上黑邊的 if self.train_flag: #選擇是訓(xùn)練集還是測試集 img = self.train_tf(img) else: img = self.val_tf(img) label = int(label) return img, label
返回數(shù)據(jù)集的大小__len__
def __len__(self): return len(self.imgs_info)
由于前面已經(jīng)對集成dataset的類進行了實現(xiàn)三種方法,那么就可以在加載器中進行加載,將加載后的數(shù)據(jù)傳入到train函數(shù)或者test函數(shù)都可以
train_dataloader = DataLoader(dataset=train_data, num_workers=4, pin_memory=True, batch_size=batch_size, shuffle=True)
:使用加載器加載數(shù)據(jù)train(train_dataloader, model, loss_fn, optimizer) test(test_dataloader, model)
:將數(shù)據(jù)傳入train或者test中進行訓(xùn)練或者測試- 注意:LoadData是繼承了dataset的類
if __name__=='__main__': batch_size = 16 # # 給訓(xùn)練集和測試集分別創(chuàng)建一個數(shù)據(jù)集加載器 train_data = LoadData("train.txt", True) valid_data = LoadData("test.txt", False) train_dataloader = DataLoader(dataset=train_data, num_workers=4, pin_memory=True, batch_size=batch_size, shuffle=True) test_dataloader = DataLoader(dataset=valid_data, num_workers=4, pin_memory=True, batch_size=batch_size) for X, y in test_dataloader: print("Shape of X [N, C, H, W]: ", X.shape) print("Shape of y: ", y.shape, y.dtype) break
三、源碼
鏈接: https://pan.baidu.com/s/19Oo87gbcm9e8zvYGkBi95A 提取碼: 2tss
到此這篇關(guān)于pytorch加載自己的數(shù)據(jù)集源碼分享的文章就介紹到這了,更多相關(guān)pytorch加載自己的數(shù)據(jù)集內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
一款Python工具制作的動態(tài)條形圖(強烈推薦!)
有時為了方便看數(shù)據(jù)的變化情況,需要畫一個動態(tài)圖來看整體的變化情況,下面這篇文章主要給大家介紹了一款Python工具制作的動態(tài)條形圖的相關(guān)資料,文中通過實例代碼介紹的非常詳細,需要的朋友可以參考下2023-02-02python lambda表達式(匿名函數(shù))寫法解析
這篇文章主要介紹了python lambda表達式(匿名函數(shù))寫法解析,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下2019-09-09Linux CentOS Python開發(fā)環(huán)境搭建教程
這篇文章主要介紹了Linux CentOS Python開發(fā)環(huán)境搭建方法,非常不錯,具有一定的參考借鑒價值,需要的朋友可以參考下2018-11-11Python strip lstrip rstrip使用方法
Python中的strip用于去除字符串的首位字符,同理,lstrip用于去除左邊的字符,rstrip用于去除右邊的字符。這三個函數(shù)都可傳入一個參數(shù),指定要去除的首尾字符。2008-09-09使用Pycharm(Python工具)新建項目及創(chuàng)建Python文件的教程
這篇文章主要介紹了使用Pycharm(Python工具)新建項目及創(chuàng)建Python文件的教程,本文通過圖文并茂的形式給大家介紹的非常詳細,對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友可以參考下2020-04-04