Pytorch DataLoader 變長數(shù)據(jù)處理方式
關(guān)于Pytorch中怎么自定義Dataset數(shù)據(jù)集類、怎樣使用DataLoader迭代加載數(shù)據(jù),這篇官方文檔已經(jīng)說得很清楚了,這里就不在贅述。
現(xiàn)在的問題:有的時候,特別對于NLP任務來說,輸入的數(shù)據(jù)可能不是定長的,比如多個句子的長度一般不會一致,這時候使用DataLoader加載數(shù)據(jù)時,不定長的句子會被胡亂切分,這肯定是不行的。
解決方法是重寫DataLoader的collate_fn,具體方法如下:
# 假如每一個樣本為: sample = { # 一個句子中各個詞的id 'token_list' : [5, 2, 4, 1, 9, 8], # 結(jié)果y 'label' : 5, } # 重寫collate_fn函數(shù),其輸入為一個batch的sample數(shù)據(jù) def collate_fn(batch): # 因為token_list是一個變長的數(shù)據(jù),所以需要用一個list來裝這個batch的token_list token_lists = [item['token_list'] for item in batch] # 每個label是一個int,我們把這個batch中的label也全取出來,重新組裝 labels = [item['label'] for item in batch] # 把labels轉(zhuǎn)換成Tensor labels = torch.Tensor(labels) return { 'token_list': token_lists, 'label': labels, } # 在使用DataLoader加載數(shù)據(jù)時,注意collate_fn參數(shù)傳入的是重寫的函數(shù) DataLoader(trainset, batch_size=4, shuffle=True, num_workers=4, collate_fn=collate_fn)
使用以上方法,可以保證DataLoader能Load出一個batch的數(shù)據(jù),load出來的東西就是重寫的collate_fn函數(shù)最后return出來的字典。
以上這篇Pytorch DataLoader 變長數(shù)據(jù)處理方式就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
快速上手基于Anaconda搭建Django環(huán)境的教程
Django具有完整的封裝,開發(fā)者可以高效率的開發(fā)項目,Django將大部分的功能進行了封裝,開發(fā)者只需要調(diào)用即可,接下來通過本文給大家介紹基于Anaconda搭建Django環(huán)境的教程,需要的朋友可以參考下2021-10-10