pytorch中節(jié)約顯卡內(nèi)存的方法和技巧
pytorch中一些節(jié)約顯卡內(nèi)存的方法和技巧:
1,控制批(batch)的大?。号看笮∈怯绊慓PU內(nèi)存使用最直接的因素之一。較小的批量大小會(huì)使用更少的GPU內(nèi)存,但可能會(huì)降低模型的收斂速度和穩(wěn)定性。
2,使用梯度累計(jì):梯度累積是在每個(gè)訓(xùn)練步驟中計(jì)算梯度,但不立即更新模型參數(shù),而是將多個(gè)步驟的梯度累積起來(lái),然后一次性更新模型參數(shù)。這樣可以在不增加計(jì)算復(fù)雜性的情況下減少內(nèi)存使用。
3,優(yōu)化模型:控制模型層數(shù),以及每層的神經(jīng)元數(shù)量。
4,使用混合精度:混合精度訓(xùn)練是指同時(shí)使用32位浮點(diǎn)數(shù)(float32)和16位浮點(diǎn)數(shù)(float16)進(jìn)行訓(xùn)練。對(duì)于一些不需要非常高精度的模型,使用float16可以大大減少GPU內(nèi)存的使用。但需要注意的是,使用float16可能會(huì)導(dǎo)致數(shù)值不穩(wěn)定的問(wèn)題,因此需要使用一些技巧如梯度剪裁來(lái)避免這個(gè)問(wèn)題。PyTorch 1.6 版本后引入了自動(dòng)混合精度模塊(AMP)可以自動(dòng)實(shí)現(xiàn)這一功能。
5,刪除不再使用的變量:在訓(xùn)練過(guò)程中不再需要的變量可以停止更新,例如使用torch.no_grad()來(lái)停止計(jì)算梯度。
6,使用數(shù)據(jù)并行:如果有多個(gè)GPU,可以用torch.nn.DataParallel在多個(gè)GPU上并行運(yùn)行你的模型。
7,清理不再使用的緩存:在某些情況下,GPU內(nèi)存不會(huì)被自動(dòng)釋放。你可以手動(dòng)調(diào)用torch.cuda.empty_cache()來(lái)清理不再需要的緩存。
8,凍結(jié)部分網(wǎng)絡(luò)層
9,使用梯度檢查點(diǎn):梯度檢查點(diǎn)是一種保存中間計(jì)算結(jié)果的技術(shù),以便在反向傳播時(shí)重復(fù)使用它們,而不是每次都重新計(jì)算它們。這可以顯著減少GPU內(nèi)存的使用,特別是在深度很大的網(wǎng)絡(luò)中。檢查點(diǎn)的工作原理是用時(shí)間換空間。檢查點(diǎn)不保存整個(gè)計(jì)算圖的所有中間結(jié)果以進(jìn)行反向傳播的計(jì)算,而是在反向傳播的過(guò)程中重新計(jì)算中間結(jié)果。
拓展方法:
以下給大家提供一些節(jié)省PyTorch顯存占用的小技巧,雖然提升不大,但或許能幫你達(dá)到可以勉強(qiáng)運(yùn)行的及格線(xiàn)。
一、大幅減少顯存占用方法
想大幅減少顯存占用,必定要從最占用顯存的方面進(jìn)行縮減,即 模型 和 數(shù)據(jù)。
1. 模型
在模型上主要是將Backbone改用輕量化網(wǎng)絡(luò)或者減少網(wǎng)絡(luò)層數(shù)等方法,可以很大程度上減少模型參數(shù)量,從而減少顯存占用。
二、小幅減少顯存占用方法
有時(shí)候我們可能不想更改模型,而又恰好差一點(diǎn)點(diǎn)顯存或者想盡量多塞幾個(gè)BatchSize,有一些小技巧可以擠出一點(diǎn)點(diǎn)顯存。
1. 使用inplace
PyTorch中的一些函數(shù),例如 ReLU、LeakyReLU 等,均有 inplace
參數(shù),可以對(duì)傳入Tensor進(jìn)行就地修改,減少多余顯存的占用。
2. 加載、存儲(chǔ)等能用CPU就絕不用GPU
GPU存儲(chǔ)空間寶貴,我們可以選擇使用CPU做一些可行的分擔(dān),雖然數(shù)據(jù)傳輸會(huì)浪費(fèi)一些時(shí)間,但是以時(shí)間換空間,可以視情況而定,在模型加載中,如 torch.load_state_dict 時(shí),先加載再使用 model.cuda(),尤其是在 resume 斷點(diǎn)續(xù)訓(xùn)時(shí),可能會(huì)報(bào)顯存不足的錯(cuò)誤。數(shù)據(jù)加載也是,在送入模型前在送入GPU。其余中間的數(shù)據(jù)處理也可以依循這個(gè)原則。
3. 低精度計(jì)算
可以使用 float16 半精度混合計(jì)算,也可以有效減少顯存占用,但是要注意一些溢出情況,如 mean 和 sum等。
4. torch.no_grad
對(duì)于 eval 等不需要 bp 及 backward 的時(shí)候,可已使用with torch.no_grad
,這個(gè)和model.eval()
有一些差異,可以減少一部分顯存占用。
5. 及時(shí)清理不用的變量
對(duì)于一些使用完成后的變量,及時(shí)del
掉,例如 backward 完的 Loss,緩存torch.cuda.empty_cache()
等。
6. 分段計(jì)算
騷操作,我們可以將模型或者數(shù)據(jù)分段計(jì)算。
模型分段,利用
checkpoint
將模型分段計(jì)算
# 首先設(shè)置輸入的input=>requires_grad=True # 如果不設(shè)置可能會(huì)導(dǎo)致得到的gradient為0 input = torch.rand(1, 10, requires_grad=True) layers = [nn.Linear(10, 10) for _ in range(1000)] # 定義要計(jì)算的層函數(shù),可以看到我們定義了兩個(gè) # 一個(gè)計(jì)算前500個(gè)層,另一個(gè)計(jì)算后500個(gè)層 def run_first_half(*args): x = args[0] for layer in layers[:500]: x = layer(x) return x def run_second_half(*args): x = args[0] for layer in layers[500:-1]: x = layer(x) return x # 引入checkpoint from torch.utils.checkpoint import checkpoint x = checkpoint(run_first_half, input) x = checkpoint(run_second_half, x) # 最后一層單獨(dú)執(zhí)行 x = layers[-1](x) x.sum.backward()
數(shù)據(jù)分段,例如原來(lái)需要64個(gè)batch的數(shù)據(jù)forward一次后backward一次,現(xiàn)在改為32個(gè)batch的數(shù)據(jù)forward兩次后backward一次。
總結(jié)
以上是我總結(jié)的一些PyTorch節(jié)省顯存的一些小技巧,希望可以幫助到大家,如果有其它好方法,也歡迎和我討論。
到此這篇關(guān)于pytorch中節(jié)約顯卡內(nèi)存的方法和技巧的文章就介紹到這了,更多相關(guān)pytorch節(jié)約顯卡內(nèi)存內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python字典一個(gè)key對(duì)應(yīng)多個(gè)value幾種實(shí)現(xiàn)方式
python中字典的健和值是一一對(duì)應(yīng)的,如果對(duì)字典進(jìn)行添加操作時(shí)如果健的名字相同,則當(dāng)前健對(duì)應(yīng)的值就會(huì)被覆蓋,有時(shí)候我們想要一個(gè)健對(duì)應(yīng)多個(gè)值的場(chǎng)景,這篇文章主要給大家介紹了關(guān)于Python字典一個(gè)key對(duì)應(yīng)多個(gè)value幾種實(shí)現(xiàn)方式的相關(guān)資料,需要的朋友可以參考下2023-10-10十行Python代碼實(shí)現(xiàn)文字識(shí)別功能
這篇文章主要和大家分享如何調(diào)用百度的接口實(shí)現(xiàn)圖片的文字識(shí)別。整體是用Python實(shí)現(xiàn),所需要使用的第三方庫(kù)包括aip、PIL、keyboard、pyinstaller,需要的可以參考一下2022-05-05Python中aiohttp模塊的簡(jiǎn)單運(yùn)用方式
這篇文章主要介紹了Python中aiohttp模塊的簡(jiǎn)單運(yùn)用方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2024-06-06Python中集合類(lèi)型(set)學(xué)習(xí)小結(jié)
這篇文章主要介紹了Python中集合類(lèi)型(set)學(xué)習(xí)小結(jié),本文講解了set的初始化、運(yùn)算操作、基本方法等內(nèi)容,需要的朋友可以參考下2015-01-01用 Python 寫(xiě)的文檔批量翻譯工具效果竟然超出想象
這篇文章主要介紹了用 Python 寫(xiě)的文檔批量翻譯工具,效果竟然超越付費(fèi)軟件,這個(gè)非常適合python辦公自動(dòng)化腳本,非常不錯(cuò),實(shí)現(xiàn)方法也很簡(jiǎn)單,需要的朋友可以參考下2021-05-05windows下python模擬鼠標(biāo)點(diǎn)擊和鍵盤(pán)輸示例
這篇文章主要介紹了windows下python模擬鼠標(biāo)點(diǎn)擊和鍵盤(pán)輸示例,需要的朋友可以參考下2014-02-02python實(shí)現(xiàn)嵌套列表平鋪的兩種方法
今天小編就為大家分享一篇python實(shí)現(xiàn)嵌套列表平鋪的兩種方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-11-11Matlab常見(jiàn)最優(yōu)化方法的原理和深度分析
這篇文章主要介紹了Matlab常見(jiàn)最優(yōu)化方法的原理和深度分析,matlab只是個(gè)軟件,用來(lái)完成機(jī)械的計(jì)算,而如何安排這些計(jì)算,需要用戶(hù)掌握最基本的數(shù)學(xué)概念,需要的朋友可以參考下2023-07-07