pytorch對可變長度序列的處理方法詳解
主要是用函數(shù)torch.nn.utils.rnn.PackedSequence()和torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()來進(jìn)行的,分別來看看這三個函數(shù)的用法。
1、torch.nn.utils.rnn.PackedSequence()
NOTE: 這個類的實例不能手動創(chuàng)建。它們只能被 pack_padded_sequence() 實例化。
PackedSequence對象包括:
一個data對象:一個torch.Variable(令牌的總數(shù),每個令牌的維度),在這個簡單的例子中有五個令牌序列(用整數(shù)表示):(18,1)
一個batch_sizes對象:每個時間步長的令牌數(shù)列表,在這個例子中為:[6,5,2,4,1]
用pack_padded_sequence函數(shù)來構(gòu)造這個對象非常的簡單:
如何構(gòu)造一個PackedSequence對象(batch_first = True)
PackedSequence對象有一個很不錯的特性,就是我們無需對序列解包(這一步操作非常慢)即可直接在PackedSequence數(shù)據(jù)變量上執(zhí)行許多操作。特別是我們可以對令牌執(zhí)行任何操作(即對令牌的順序/上下文不敏感)。當(dāng)然,我們也可以使用接受PackedSequence作為輸入的任何一個pyTorch模塊(pyTorch 0.2)。
2、torch.nn.utils.rnn.pack_padded_sequence()
這里的pack,理解成壓緊比較好。 將一個 填充過的變長序列 壓緊。(填充時候,會有冗余,所以壓緊一下)
輸入的形狀可以是(T×B×* )。T是最長序列長度,B是batch size,*代表任意維度(可以是0)。如果batch_first=True的話,那么相應(yīng)的 input size 就是 (B×T×*)。
Variable中保存的序列,應(yīng)該按序列長度的長短排序,長的在前,短的在后。即input[:,0]代表的是最長的序列,input[:, B-1]保存的是最短的序列。
NOTE: 只要是維度大于等于2的input都可以作為這個函數(shù)的參數(shù)。你可以用它來打包labels,然后用RNN的輸出和打包后的labels來計算loss。通過PackedSequence對象的.data屬性可以獲取 Variable。
參數(shù)說明:
input (Variable) – 變長序列 被填充后的 batch
lengths (list[int]) – Variable 中 每個序列的長度。
batch_first (bool, optional) – 如果是True,input的形狀應(yīng)該是B*T*size。
返回值:
一個PackedSequence 對象。
3、torch.nn.utils.rnn.pad_packed_sequence()
填充packed_sequence。
上面提到的函數(shù)的功能是將一個填充后的變長序列壓緊。 這個操作和pack_padded_sequence()是相反的。把壓緊的序列再填充回來。
返回的Varaible的值的size是 T×B×*, T 是最長序列的長度,B 是 batch_size,如果 batch_first=True,那么返回值是B×T×*。
Batch中的元素將會以它們長度的逆序排列。
參數(shù)說明:
sequence (PackedSequence) – 將要被填充的 batch
batch_first (bool, optional) – 如果為True,返回的數(shù)據(jù)的格式為 B×T×*。
返回值: 一個tuple,包含被填充后的序列,和batch中序列的長度列表。
例子:
import torch import torch.nn as nn from torch.autograd import Variable from torch.nn import utils as nn_utils batch_size = 2 max_length = 3 hidden_size = 2 n_layers =1 tensor_in = torch.FloatTensor([[1, 2, 3], [1, 0, 0]]).resize_(2,3,1) tensor_in = Variable( tensor_in ) #[batch, seq, feature], [2, 3, 1] seq_lengths = [3,1] # list of integers holding information about the batch size at each sequence step # pack it pack = nn_utils.rnn.pack_padded_sequence(tensor_in, seq_lengths, batch_first=True) # initialize rnn = nn.RNN(1, hidden_size, n_layers, batch_first=True) h0 = Variable(torch.randn(n_layers, batch_size, hidden_size)) #forward out, _ = rnn(pack, h0) # unpack unpacked = nn_utils.rnn.pad_packed_sequence(out) print('111',unpacked)
輸出:
111 (Variable containing: (0 ,.,.) = 0.5406 0.3584 -0.1403 0.0308 (1 ,.,.) = -0.6855 -0.9307 0.0000 0.0000 [torch.FloatTensor of size 2x2x2] , [2, 1])
以上這篇pytorch對可變長度序列的處理方法詳解就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
- 對pytorch網(wǎng)絡(luò)層結(jié)構(gòu)的數(shù)組化詳解
- pytorch 轉(zhuǎn)換矩陣的維數(shù)位置方法
- pytorch 調(diào)整某一維度數(shù)據(jù)順序的方法
- 對PyTorch torch.stack的實例講解
- 使用pytorch進(jìn)行圖像的順序讀取方法
- mac安裝pytorch及系統(tǒng)的numpy更新方法
- 淺談pytorch和Numpy的區(qū)別以及相互轉(zhuǎn)換方法
- pytorch + visdom CNN處理自建圖片數(shù)據(jù)集的方法
- PyTorch CNN實戰(zhàn)之MNIST手寫數(shù)字識別示例
- PyTorch 1.0 正式版已經(jīng)發(fā)布了
相關(guān)文章
pycharm配置Anaconda虛擬環(huán)境全過程
這篇文章主要介紹了pycharm配置Anaconda虛擬環(huán)境全過程,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教2024-01-01python dataclass 快速創(chuàng)建數(shù)據(jù)類的方法
在Python中,dataclass是一種用于快速創(chuàng)建數(shù)據(jù)類的裝飾器和工具,本文實例代碼中我們定義了一個Person數(shù)據(jù)類,并使用fields()函數(shù)遍歷其字段,打印出每個字段的名稱、類型、默認(rèn)值和元數(shù)據(jù),對python dataclass 數(shù)據(jù)類相關(guān)知識感興趣的朋友一起看看吧2024-03-03Python內(nèi)建類型int源碼學(xué)習(xí)
這篇文章主要為大家介紹了Python內(nèi)建類型int源碼學(xué)習(xí),有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2022-05-05詳解Python如何利用Pandas與NumPy進(jìn)行數(shù)據(jù)清洗
許多數(shù)據(jù)科學(xué)家認(rèn)為獲取和清理數(shù)據(jù)的初始步驟占工作的 80%,花費大量時間來清理數(shù)據(jù)集并將它們歸結(jié)為可以使用的形式。本文將利用 Python 的 Pandas和 NumPy 庫來清理數(shù)據(jù),需要的可以參考一下2022-04-04Python3實現(xiàn)漢語轉(zhuǎn)換為漢語拼音
這篇文章主要為大家詳細(xì)介紹了Python3實現(xiàn)漢語轉(zhuǎn)換為漢語拼音,具有一定的參考價值,感興趣的小伙伴們可以參考一下2019-07-07