欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

解決Pytorch dataloader時(shí)報(bào)錯(cuò)每個(gè)tensor維度不一樣的問題

 更新時(shí)間:2021年05月28日 11:35:02   作者:XJTU-Qidong  
這篇文章主要介紹了解決Pytorch dataloader時(shí)報(bào)錯(cuò)每個(gè)tensor維度不一樣的問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教

使用pytorch的dataloader報(bào)錯(cuò):

RuntimeError: stack expects each tensor to be equal size, but got [2] at entry 0 and [1] at entry 1

1. 問題描述

報(bào)錯(cuò)定位:位于定義dataset的代碼中

def __getitem__(self, index):
 ...
 return y    #此處報(bào)錯(cuò)

報(bào)錯(cuò)內(nèi)容

File "D:\python\lib\site-packages\torch\utils\data\_utils\collate.py", line 55, in default_collate
return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [2] at entry 0 and [1] at entry 1

把前一行的報(bào)錯(cuò)帶上能夠更清楚地明白問題在哪里.

2.問題分析

從報(bào)錯(cuò)可以看到,是在代碼中執(zhí)行torch.stack時(shí)發(fā)生了報(bào)錯(cuò).因此必須要明白在哪里執(zhí)行了stack操作.

通過調(diào)試可以發(fā)現(xiàn),在通過loader加載一個(gè)batch數(shù)據(jù)的時(shí)候,是通過每一次給一個(gè)隨機(jī)的index取出相應(yīng)的向量.那么最終要形成一個(gè)batch的數(shù)據(jù)就必須要進(jìn)行拼接操作,而torch.stack就是進(jìn)行這里所說的拼接.

再來看看具體報(bào)的什么錯(cuò): 說是stack的向量維度不同. 這說明在每次給出一個(gè)隨機(jī)的index,返回的y向量的維度應(yīng)該是相同的,而我們這里是不同的.

這樣解決方法也就明確了:使返回的向量y的維度固定下來.

3.問題出處

為什么我會(huì)出現(xiàn)這樣的一個(gè)問題,是因?yàn)槲业奶卣飨蛄恐写嬖趍ulti-hot特征.而為了節(jié)省空間,我是用一個(gè)列表存儲(chǔ)這個(gè)特征的.示例如下:

feature=[[1,3,5],
  [0,2],
  [1,2,5,8]]

這就導(dǎo)致了我每次返回的向量的維度是不同的.因此可以采用向量補(bǔ)全的方法,把不同長(zhǎng)度的向量補(bǔ)全成等長(zhǎng)的.

 # 把所有向量的長(zhǎng)度都補(bǔ)為6
 multi = np.pad(multi, (0, 6-multi.shape[0]), 'constant', constant_values=(0, -1))

4.總結(jié)

在構(gòu)建dataset重寫的__getitem__方法中要返回相同長(zhǎng)度的tensor.

可以使用向量補(bǔ)全的方法來解決這個(gè)問題.

補(bǔ)充:pytorch學(xué)習(xí)筆記:torch.utils.data下的TensorDataset和DataLoader的使用

一、TensorDataset

對(duì)給定的tensor數(shù)據(jù)(樣本和標(biāo)簽),將它們包裝成dataset。注意,如果是numpy的array,或者Pandas的DataFrame需要先轉(zhuǎn)換成Tensor。

'''
data_tensor (Tensor) - 樣本數(shù)據(jù)
target_tensor (Tensor) - 樣本目標(biāo)(標(biāo)簽)
'''
 dataset=torch.utils.data.TensorDataset(data_tensor, 
                                        target_tensor)

下面舉個(gè)例子:

我們先定義一下樣本數(shù)據(jù)和標(biāo)簽數(shù)據(jù),一共有1000個(gè)樣本

import torch
import numpy as np
num_inputs = 2
num_examples = 1000
true_w = [2, -3.4]
true_b = 4.2
features = torch.tensor(np.random.normal(0, 1, 
                       (num_examples, num_inputs)), 
                       dtype=torch.float)

labels = true_w[0] * features[:, 0] + \
         true_w[1] * features[:, 1] + true_b

labels += torch.tensor(np.random.normal(0, 0.01, 
                       size=labels.size()), 
                       dtype=torch.float)

print(features.shape)
print(labels.shape)

'''
輸出:torch.Size([1000, 2])
     torch.Size([1000])
'''

然后我們使用TensorDataset來生成數(shù)據(jù)集

import torch.utils.data as Data
# 將訓(xùn)練數(shù)據(jù)的特征和標(biāo)簽組合
dataset = Data.TensorDataset(features, labels)

二、DataLoader

數(shù)據(jù)加載器,組合數(shù)據(jù)集和采樣器,并在數(shù)據(jù)集上提供單進(jìn)程或多進(jìn)程迭代器。它可以對(duì)我們上面所說的數(shù)據(jù)集Dataset作進(jìn)一步的設(shè)置。

dataset (Dataset) – 加載數(shù)據(jù)的數(shù)據(jù)集。

batch_size (int, optional) – 每個(gè)batch加載多少個(gè)樣本(默認(rèn): 1)。

shuffle (bool, optional) – 設(shè)置為True時(shí)會(huì)在每個(gè)epoch重新打亂數(shù)據(jù)(默認(rèn): False).

sampler (Sampler, optional) – 定義從數(shù)據(jù)集中提取樣本的策略。如果指定,則shuffle必須設(shè)置成False。

num_workers (int, optional) – 用多少個(gè)子進(jìn)程加載數(shù)據(jù)。0表示數(shù)據(jù)將在主進(jìn)程中加載(默認(rèn): 0)

pin_memory:內(nèi)存寄存,默認(rèn)為False。在數(shù)據(jù)返回前,是否將數(shù)據(jù)復(fù)制到CUDA內(nèi)存中。

drop_last (bool, optional) – 如果數(shù)據(jù)集大小不能被batch size整除,則設(shè)置為True后可刪除最后一個(gè)不完整的batch。如果設(shè)為False并且數(shù)據(jù)集的大小不能被batch size整除,則最后一個(gè)batch將更小。(默認(rèn): False)

timeout:是用來設(shè)置數(shù)據(jù)讀取的超時(shí)時(shí)間的,如果超過這個(gè)時(shí)間還沒讀取到數(shù)據(jù)的話就會(huì)報(bào)錯(cuò)。 所以,數(shù)值必須大于等于0。

data_iter=torch.utils.data.DataLoader(dataset, batch_size=1, 
                            shuffle=False, sampler=None, 
                            batch_sampler=None, num_workers=0, 
                            collate_fn=None, pin_memory=False, 
                            drop_last=False, timeout=0, 
                            worker_init_fn=None, 
                            multiprocessing_context=None)

上面對(duì)一些重要常用的參數(shù)做了說明,其中有一個(gè)參數(shù)是sampler,下面我們對(duì)它有哪些具體取值再做一下說明。只列出幾個(gè)常用的取值:

torch.utils.data.sampler.SequentialSampler(dataset)

樣本元素按順序采樣,始終以相同的順序。

torch.utils.data.sampler.RandomSampler(dataset)

樣本元素隨機(jī)采樣,沒有替換。

torch.utils.data.sampler.SubsetRandomSampler(indices)

樣本元素從指定的索引列表中隨機(jī)抽取,沒有替換。

下面就來看一個(gè)例子,該例子使用的dataset就是上面所生成的dataset

data_iter=Data.DataLoader(dataset, 
                          batch_size=10, 
                          shuffle=False,
sampler=torch.utils.data.sampler.RandomSampler(dataset))

for X, y in data_iter:
    print(X,"\n", y)
    break

'''
輸出:
tensor([[-1.6338,  0.8451],
        [ 0.7245, -0.7387],
        [ 0.4672,  0.2623],
        [-1.9082,  0.0980],
        [-0.3881,  0.5138],
        [-0.6983, -0.4712],
        [ 0.1400,  0.7489],
        [-0.7761, -0.4596],
        [-2.2700, -0.2532],
        [-1.2641, -2.8089]]) 

tensor([-1.9451,  8.1587,  4.2374,  0.0519,  1.6843,  4.3970,  
        1.9311,  4.1999,0.5253, 11.2277])
'''

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • Python中的模塊導(dǎo)入和讀取鍵盤輸入的方法

    Python中的模塊導(dǎo)入和讀取鍵盤輸入的方法

    這篇文章主要介紹了Python中的模塊導(dǎo)入和讀取鍵盤輸入的方法,相關(guān)import語句和input函數(shù)的使用是Python入門學(xué)習(xí)中的基礎(chǔ)知識(shí), 需要的朋友可以參考下
    2015-10-10
  • python文件處理--文件讀寫詳解

    python文件處理--文件讀寫詳解

    這篇文章主要介紹了Python 處理文件的幾種方式,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2021-08-08
  • Python 如何調(diào)試程序崩潰錯(cuò)誤

    Python 如何調(diào)試程序崩潰錯(cuò)誤

    這篇文章主要介紹了Python 如何調(diào)試程序崩潰錯(cuò)誤,文中講解非常細(xì)致,代碼幫助大家更好的理解和學(xué)習(xí),感興趣的朋友可以了解下
    2020-08-08
  • python 字典的打印實(shí)現(xiàn)

    python 字典的打印實(shí)現(xiàn)

    這篇文章主要介紹了python 字典的打印實(shí)現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2019-09-09
  • 分享Python獲取本機(jī)IP地址的幾種方法

    分享Python獲取本機(jī)IP地址的幾種方法

    這篇文章主要介紹了分享Python獲取本機(jī)IP地址的幾種方法,分享了使用專用網(wǎng)站、使用自帶socket庫、使用第三方netifaces庫等方式們需要的小伙伴可以參考一下
    2022-03-03
  • 回歸預(yù)測(cè)分析python數(shù)據(jù)化運(yùn)營(yíng)線性回歸總結(jié)

    回歸預(yù)測(cè)分析python數(shù)據(jù)化運(yùn)營(yíng)線性回歸總結(jié)

    本文主要介紹了python數(shù)據(jù)化運(yùn)營(yíng)中的線性回歸一般應(yīng)用場(chǎng)景,常用方法,回歸實(shí)現(xiàn),回歸評(píng)估指標(biāo),效果可視化等,并采用了回歸預(yù)測(cè)分析的數(shù)據(jù)預(yù)測(cè)方法
    2021-08-08
  • Tensorflow分類器項(xiàng)目自定義數(shù)據(jù)讀入的實(shí)現(xiàn)

    Tensorflow分類器項(xiàng)目自定義數(shù)據(jù)讀入的實(shí)現(xiàn)

    這篇文章主要介紹了Tensorflow分類器項(xiàng)目自定義數(shù)據(jù)讀入的實(shí)現(xiàn),小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧
    2019-02-02
  • Python使用內(nèi)存緩存實(shí)例分享

    Python使用內(nèi)存緩存實(shí)例分享

    Python中的內(nèi)存緩存是一種將計(jì)算結(jié)果存儲(chǔ)在內(nèi)存中,以便在后續(xù)調(diào)用時(shí)快速獲取結(jié)果的技術(shù)。通過使用裝飾器和字典等數(shù)據(jù)結(jié)構(gòu),可以輕松實(shí)現(xiàn)內(nèi)存緩存功能,提高程序的執(zhí)行效率。
    2023-09-09
  • Python SMTP配置參數(shù)并發(fā)送郵件

    Python SMTP配置參數(shù)并發(fā)送郵件

    這篇文章主要介紹了Python SMTP配置參數(shù)并發(fā)送郵件,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下
    2020-06-06
  • 基于Python實(shí)現(xiàn)批量縮放圖片(視頻)尺寸

    基于Python實(shí)現(xiàn)批量縮放圖片(視頻)尺寸

    這篇文章主要為大家詳細(xì)介紹了如何通過Python語言實(shí)現(xiàn)批量縮放圖片(視頻)尺寸的功能,文中的示例代碼簡(jiǎn)潔易懂,感興趣的小伙伴可以跟隨小編一起了解一下
    2023-03-03

最新評(píng)論