PyTorch中的拷貝與就地操作詳解
前言
PyTroch中我們經(jīng)常使用到Numpy進(jìn)行數(shù)據(jù)的處理,然后再轉(zhuǎn)為T(mén)ensor,但是關(guān)系到數(shù)據(jù)的更改時(shí)我們要注意方法是否是共享地址,這關(guān)系到整個(gè)網(wǎng)絡(luò)的更新。本篇就In-palce操作,拷貝操作中的注意點(diǎn)進(jìn)行總結(jié)。
In-place操作
pytorch中原地操作的后綴為_(kāi),如.add_()或.scatter_(),就地操作是直接更改給定Tensor的內(nèi)容而不進(jìn)行復(fù)制的操作,即不會(huì)為變量分配新的內(nèi)存。Python操作類(lèi)似+=或*=也是就地操作。(我加了我自己~)
為什么in-place操作可以在處理高維數(shù)據(jù)時(shí)可以幫助減少內(nèi)存使用呢,下面使用一個(gè)例子進(jìn)行說(shuō)明,定義以下簡(jiǎn)單函數(shù)來(lái)測(cè)量PyTorch的異位ReLU(out-of-place)和就地ReLU(in-place)分配的內(nèi)存:
import torch # import main library import torch.nn as nn # import modules like nn.ReLU() import torch.nn.functional as F # import torch functions like F.relu() and F.relu_() def get_memory_allocated(device, inplace = False): ''' Function measures allocated memory before and after the ReLU function call. INPUT: - device: gpu device to run the operation - inplace: True - to run ReLU in-place, False - for normal ReLU call ''' # Create a large tensor t = torch.randn(10000, 10000, device=device) # Measure allocated memory torch.cuda.synchronize() start_max_memory = torch.cuda.max_memory_allocated() / 1024**2 start_memory = torch.cuda.memory_allocated() / 1024**2 # Call in-place or normal ReLU if inplace: F.relu_(t) else: output = F.relu(t) # Measure allocated memory after the call torch.cuda.synchronize() end_max_memory = torch.cuda.max_memory_allocated() / 1024**2 end_memory = torch.cuda.memory_allocated() / 1024**2 # Return amount of memory allocated for ReLU call return end_memory - start_memory, end_max_memory - start_max_memory # setup the device device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu") #開(kāi)始測(cè)試 # Call the function to measure the allocated memory for the out-of-place ReLU memory_allocated, max_memory_allocated = get_memory_allocated(device, inplace = False) print('Allocated memory: {}'.format(memory_allocated)) print('Allocated max memory: {}'.format(max_memory_allocated)) ''' Allocated memory: 382.0 Allocated max memory: 382.0 ''' #Then call the in-place ReLU as follows: memory_allocated_inplace, max_memory_allocated_inplace = get_memory_allocated(device, inplace = True) print('Allocated memory: {}'.format(memory_allocated_inplace)) print('Allocated max memory: {}'.format(max_memory_allocated_inplace)) ''' Allocated memory: 0.0 Allocated max memory: 0.0 '''
看起來(lái),使用就地操作可以幫助我們節(jié)省一些GPU內(nèi)存。但是,在使用就地操作時(shí)應(yīng)該格外謹(jǐn)慎。
就地操作的主要缺點(diǎn)主要原因有2點(diǎn),官方文檔:
1.可能會(huì)覆蓋計(jì)算梯度所需的值,這意味著破壞了模型的訓(xùn)練過(guò)程。
2.每個(gè)就地操作實(shí)際上都需要實(shí)現(xiàn)來(lái)重寫(xiě)計(jì)算圖。異地操作Out-of-place分配新對(duì)象并保留對(duì)舊圖的引用,而就地操作則需要更改表示此操作的函數(shù)的所有輸入的創(chuàng)建者。
在Autograd中支持就地操作很困難,并且在大多數(shù)情況下不鼓勵(lì)使用。Autograd積極的緩沖區(qū)釋放和重用使其非常高效,就地操作實(shí)際上降低內(nèi)存使用量的情況很少。除非在沉重的內(nèi)存壓力下運(yùn)行,否則可能永遠(yuǎn)不需要使用它們。
總結(jié):Autograd很香了,就地操作要慎用。
拷貝方法
淺拷貝方法: 共享 data 的內(nèi)存地址,數(shù)據(jù)會(huì)同步變化
* a.numpy() # Tensor—>Numpy array
* view() #改變tensor的形狀,但共享數(shù)據(jù)內(nèi)存,不要直接使用id進(jìn)行判斷
* y = x[:] # 索引
* torch.from_numpy() # Numpy array—>Tensor
* torch.detach() # 新的tensor會(huì)脫離計(jì)算圖,不會(huì)牽扯梯度計(jì)算。
* model:forward()
還有很多選擇函數(shù)也是數(shù)據(jù)共享內(nèi)存,如index_select() masked_select() gather()。
以及后文提到的就地操作in-place。
深拷貝方法:
* torch.clone() # 新的tensor會(huì)保留在計(jì)算圖中,參與梯度計(jì)算
下面進(jìn)行驗(yàn)證,首先驗(yàn)證淺拷貝:
import torch as t import numpy as np a = np.ones(4) b = t.from_numpy(a) # Numpy->Tensor print(a) print(b) '''輸出: [1. 1. 1. 1.] tensor([1., 1., 1., 1.], dtype=torch.float64) ''' b.add_(1)# add_會(huì)修改b自身 print(a) print(b) '''輸出: [2. 2. 2. 2.] tensor([2., 2., 2., 2.], dtype=torch.float64) b進(jìn)行add操作后, a,b同步發(fā)生了變化 '''
Tensor和numpy對(duì)象共享內(nèi)存(淺拷貝操作),所以他們之間的轉(zhuǎn)換很快,且會(huì)同步變化。
造torch中y = x + y這樣的運(yùn)算是會(huì)新開(kāi)內(nèi)存的,然后將y指向新內(nèi)存。為了進(jìn)行驗(yàn)證,我們可以使用Python自帶的id函數(shù):如果兩個(gè)實(shí)例的ID一致,那么它們所對(duì)應(yīng)的內(nèi)存地址相同;但需要注意是在torch中還有些特殊,數(shù)據(jù)共享時(shí)直接打印tensor的id仍然會(huì)出現(xiàn)不同。
x = torch.tensor([1, 2]) y = torch.tensor([3, 4]) id_0 = id(y) y = y + x print(id(y) == id_0) # False
這時(shí)使用索引操作不會(huì)開(kāi)辟新的內(nèi)存,而想指定結(jié)果到原來(lái)的y的內(nèi)存,我們可以使用索引來(lái)進(jìn)行替換操作。比如把x + y的結(jié)果通過(guò)[:]寫(xiě)進(jìn)y對(duì)應(yīng)的內(nèi)存中。
x = torch.tensor([1, 2]) y = torch.tensor([3, 4]) id_0 = id(y) y[:] = y + x print(id(y) == id_0) # True
另外,以下兩種方式也可以索引到相同的內(nèi)存:
- torch.add(x, y, out=y)
- y += x, y.add_(x)
x = torch.tensor([1, 2]) y = torch.tensor([3, 4]) id_0 = id(y) torch.add(x, y, out=y) # y += x, y.add_(x) print(id(y) == id_0) # True
clone() 與 detach() 對(duì)比
Torch 為了提高速度,向量或是矩陣的賦值是指向同一內(nèi)存的,這不同于 Matlab。如果需要保存舊的tensor即需要開(kāi)辟新的存儲(chǔ)地址而不是引用,可以用 clone() 進(jìn)行深拷貝,
首先我們來(lái)打印出來(lái)clone()操作后的數(shù)據(jù)類(lèi)型定義變化:
(1). 簡(jiǎn)單打印類(lèi)型
import torch a = torch.tensor(1.0, requires_grad=True) b = a.clone() c = a.detach() a.data *= 3 b += 1 print(a) # tensor(3., requires_grad=True) print(b) print(c) ''' 輸出結(jié)果: tensor(3., requires_grad=True) tensor(2., grad_fn=<AddBackward0>) tensor(3.) # detach()后的值隨著a的變化出現(xiàn)變化 '''
grad_fn=<CloneBackward>,表示clone后的返回值是個(gè)中間變量,因此支持梯度的回溯。clone操作在一定程度上可以視為是一個(gè)identity-mapping函數(shù)。
detach()操作后的tensor與原始tensor共享數(shù)據(jù)內(nèi)存,當(dāng)原始tensor在計(jì)算圖中數(shù)值發(fā)生反向傳播等更新之后,detach()的tensor值也發(fā)生了改變。
注意: 在pytorch中我們不要直接使用id是否相等來(lái)判斷tensor是否共享內(nèi)存,這只是充分條件,因?yàn)橐苍S底層共享數(shù)據(jù)內(nèi)存,但是仍然是新的tensor,比如detach(),如果我們直接打印id會(huì)出現(xiàn)以下情況。
import torch as t a = t.tensor([1.0,2.0], requires_grad=True) b = a.detach() #c[:] = a.detach() print(id(a)) print(id(b)) #140568935450520 140570337203616
顯然直接打印出來(lái)的id不等,我們可以通過(guò)簡(jiǎn)單的賦值后觀察數(shù)據(jù)變化進(jìn)行判斷。
(2). clone()的梯度回傳
detach()函數(shù)可以返回一個(gè)完全相同的tensor,與舊的tensor共享內(nèi)存,脫離計(jì)算圖,不會(huì)牽扯梯度計(jì)算。
而clone充當(dāng)中間變量,會(huì)將梯度傳給源張量進(jìn)行疊加,但是本身不保存其grad,即值為None
import torch a = torch.tensor(1.0, requires_grad=True) a_ = a.clone() y = a**2 z = a ** 2+a_ * 3 y.backward() print(a.grad) # 2 z.backward() print(a_.grad) # None. 中間variable,無(wú)grad print(a.grad) ''' 輸出: tensor(2.) None tensor(7.) # 2*2+3=7 '''
使用torch.clone()獲得的新tensor和原來(lái)的數(shù)據(jù)不再共享內(nèi)存,但仍保留在計(jì)算圖中,clone操作在不共享數(shù)據(jù)內(nèi)存的同時(shí)支持梯度梯度傳遞與疊加,所以常用在神經(jīng)網(wǎng)絡(luò)中某個(gè)單元需要重復(fù)使用的場(chǎng)景下。
通常如果原tensor的requires_grad=True,則:
- clone()操作后的tensor requires_grad=True
- detach()操作后的tensor requires_grad=False。
import torch torch.manual_seed(0) x= torch.tensor([1., 2.], requires_grad=True) clone_x = x.clone() detach_x = x.detach() clone_detach_x = x.clone().detach() f = torch.nn.Linear(2, 1) y = f(x) y.backward() print(x.grad) print(clone_x.requires_grad) print(clone_x.grad) print(detach_x.requires_grad) print(clone_detach_x.requires_grad) ''' 輸出結(jié)果如下: tensor([-0.0053, 0.3793]) True None False False '''
另一個(gè)比較特殊的是當(dāng)源張量的 require_grad=False,clone后的張量 require_grad=True,此時(shí)不存在張量回傳現(xiàn)象,可以得到clone后的張量求導(dǎo)。
如下:
import torch a = torch.tensor(1.0) a_ = a.clone() a_.requires_grad_() #require_grad=True y = a_ ** 2 y.backward() print(a.grad) # None print(a_.grad) ''' 輸出: None tensor(2.) '''
總結(jié):
torch.detach() —新的tensor會(huì)脫離計(jì)算圖,不會(huì)牽扯梯度計(jì)算
torch.clone() — 新的tensor充當(dāng)中間變量,會(huì)保留在計(jì)算圖中,參與梯度計(jì)算(回傳疊加),但是一般不會(huì)保留自身梯度。
原地操作(in-place, such as resize_ / resize_as_ / set_ / transpose_) 在上面兩者中執(zhí)行都會(huì)引發(fā)錯(cuò)誤或者警告。
引用官方文檔的話:如果你使用了in-place operation而沒(méi)有報(bào)錯(cuò)的話,那么你可以確定你的梯度計(jì)算是正確的。另外盡量避免in-place的使用。
到此這篇關(guān)于PyTorch中拷貝與就地操作的文章就介紹到這了,更多相關(guān)PyTorch拷貝與就地操作內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
python DataFrame獲取行數(shù)、列數(shù)、索引及第幾行第幾列的值方法
下面小編就為大家分享一篇python DataFrame獲取行數(shù)、列數(shù)、索引及第幾行第幾列的值方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-04-04Python將圖片轉(zhuǎn)換為字符畫(huà)的方法
這篇文章主要為大家詳細(xì)介紹了Python將圖片轉(zhuǎn)換為字符畫(huà)的方法,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-03-03Numpy數(shù)組轉(zhuǎn)置的實(shí)現(xiàn)
本文主要介紹了Numpy數(shù)組轉(zhuǎn)置的實(shí)現(xiàn),文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2023-02-02python讀取查看npz/npy文件數(shù)據(jù)以及數(shù)據(jù)完全顯示方法實(shí)例
前兩天從在GitHub下載了一個(gè)代碼,其中的數(shù)據(jù)集是.npz結(jié)尾的文件,之前沒(méi)有見(jiàn)過(guò)不知道如何處理,下面這篇文章主要給大家介紹了關(guān)于python讀取查看npz/npy文件數(shù)據(jù)以及數(shù)據(jù)完全顯示方法的相關(guān)資料,需要的朋友可以參考下2022-04-04Python計(jì)算庫(kù)numpy進(jìn)行方差/標(biāo)準(zhǔn)方差/樣本標(biāo)準(zhǔn)方差/協(xié)方差的計(jì)算
今天小編就為大家分享一篇關(guān)于Python計(jì)算庫(kù)numpy進(jìn)行方差/標(biāo)準(zhǔn)方差/樣本標(biāo)準(zhǔn)方差/協(xié)方差的計(jì)算,小編覺(jué)得內(nèi)容挺不錯(cuò)的,現(xiàn)在分享給大家,具有很好的參考價(jià)值,需要的朋友一起跟隨小編來(lái)看看吧2018-12-12PyQt5+serial模塊實(shí)現(xiàn)一個(gè)串口小工具
這篇文章主要為大家詳細(xì)介紹了如何利用PyQt5和serial模塊實(shí)現(xiàn)一個(gè)簡(jiǎn)單的串口小工具,文中的示例代碼講解詳細(xì),感興趣的小伙伴可以跟隨小編一起學(xué)習(xí)一下2023-01-01Python re.split方法分割字符串的實(shí)現(xiàn)示例
本文主要介紹了Python re.split方法分割字符串的實(shí)現(xiàn)示例,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2022-08-08Django ORM多對(duì)多查詢(xún)方法(自定義第三張表&ManyToManyField)
今天小編就為大家分享一篇Django ORM多對(duì)多查詢(xún)方法(自定義第三張表&ManyToManyField),具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-08-08python實(shí)現(xiàn)冒泡排序算法的兩種方法
本篇文章主要介紹了python實(shí)現(xiàn)冒泡排序的兩種方法,小編覺(jué)得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧2018-03-03