pytorch訓(xùn)練時的顯存占用遞增的問題解決
遇到的問題:
在pytorch訓(xùn)練過程中突然out of memory。
解決方法:
1. 測試的時候爆顯存有可能是忘記設(shè)置no_grad
加入 with torch.no_grad()
model.eval() with torch.no_grad(): ? ? ? ? for idx, (data, target) in enumerate(data_loader): ? ? ? ? ? ? if args.gpu != -1: ? ? ? ? ? ? ? ? data, target = data.to(args.device), target.to(args.device) ? ? ? ? ? ? log_probs = net_g(data) ? ? ? ? ? ? probs.append(log_probs) ? ? ? ? ? ?? ? ? ? ? ? ? # sum up batch loss ? ? ? ? ? ? test_loss += F.cross_entropy(log_probs, target, reduction='sum').item() ? ? ? ? ? ? # get the index of the max log-probability ? ? ? ? ? ? y_pred = log_probs.data.max(1, keepdim=True)[1] ? ? ? ? ? ? correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()
2. loss.item()
寫成loss_train = loss_train + loss.item(),不能直接寫loss_train = loss_train + loss
3. 在代碼中添加以下兩行:
torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = True
4. del操作后再加上torch.cuda.empty_cache()
單獨使用del、torch.cuda.empty_cache()效果都不明顯,因為empty_cache()不會釋放還被占用的內(nèi)存。
所以這里使用了del讓對應(yīng)數(shù)據(jù)成為“沒標(biāo)簽”的垃圾,之后這些垃圾所占的空間就會被empty_cache()回收。
"""添加了最后兩行,img和segm是圖像和標(biāo)簽輸入,很明顯通過.cuda()已經(jīng)是被存在在顯存里了; ? ?outputs是模型的輸出,模型在顯存里當(dāng)然其輸出也在顯存里;loss是通過在顯存里的segm和 ? ?outputs算出來的,其也在顯存里。這4個對象都是一次性的,使用后應(yīng)及時把其從顯存中清除 ? ?(當(dāng)然如果你顯存夠大也可以忽略)。""" ? def train(model, data_loader, batch_size, optimizer): ? ? model.train() ? ? total_loss = 0 ? ? accumulated_steps = 32 // batch_size ? ? optimizer.zero_grad() ? ? for idx, (img, segm) in enumerate(tqdm(data_loader)): ? ? ? ? img = img.cuda() ? ? ? ? segm = segm.cuda() ? ? ? ? outputs = model(img) ? ? ? ? loss = criterion(outputs, segm) ? ? ? ? (loss/accumulated_steps).backward() ? ? ? ? if (idx + 1 ) % accumulated_steps == 0: ? ? ? ? ? ? optimizer.step()? ? ? ? ? ? ? optimizer.zero_grad() ? ? ? ? total_loss += loss.item() ? ? ? ?? ? ? ? ? # delete caches ? ? ? ? del img, segm, outputs, loss ? ? ? ? torch.cuda.empty_cache()
補充:Pytorch顯存不斷增長問題的解決思路
思路很簡單,就是在代碼的運行階段輸出顯存占用量,觀察在哪一塊存在顯存劇烈增加或者顯存異常變化的情況。
但是在這個過程中要分級確認(rèn)問題點,也即如果存在三個文件main.py、train.py、model.py。
在此種思路下,應(yīng)該先在main.py中確定問題點,然后,從main.py中進(jìn)入到train.py中,再次輸出顯存占用量,確定問題點在哪。
隨后,再從train.py中的問題點,進(jìn)入到model.py中,再次確認(rèn)。
如果還有更深層次的調(diào)用,可以繼續(xù)追溯下去。
例如:
main.py
def train(model,epochs,data): for e in range(epochs): print("1:{}".format(torch.cuda.memory_allocated(0))) train_epoch(model,data) print("2:{}".format(torch.cuda.memory_allocated(0))) eval(model,data) print("3:{}".format(torch.cuda.memory_allocated(0)))
若1與2之間顯存增加極為劇烈,說明問題出在train_epoch中,進(jìn)一步進(jìn)入到train.py中。
train.py
def train_epoch(model,data): model.train() optim=torch.optimizer() for batch_data in data: print("1:{}".format(torch.cuda.memory_allocated(0))) output=model(batch_data) print("2:{}".format(torch.cuda.memory_allocated(0))) loss=loss(output,data.target) print("3:{}".format(torch.cuda.memory_allocated(0))) optim.zero_grad() print("4:{}".format(torch.cuda.memory_allocated(0))) loss.backward() print("5:{}".format(torch.cuda.memory_allocated(0))) utils.func(model) print("6:{}".format(torch.cuda.memory_allocated(0)))
如果在1,2之間,5,6之間同時出現(xiàn)顯存增加異常的情況。此時需要使用控制變量法,例如我們先讓5,6之間的代碼失效,然后運行,觀察是否仍然存在顯存爆炸。如果沒有,說明問題就出在5,6之間下一級的代碼中。進(jìn)入到下一級代碼,進(jìn)行調(diào)試:
utils.py
def func(model): print("1:{}".format(torch.cuda.memory_allocated(0))) a=f1(model) print("2:{}".format(torch.cuda.memory_allocated(0))) b=f2(a) print("3:{}".format(torch.cuda.memory_allocated(0))) c=f3(b) print("4:{}".format(torch.cuda.memory_allocated(0))) d=f4(c) print("5:{}".format(torch.cuda.memory_allocated(0)))
此時我們再展示另一種調(diào)試思路,先注釋第5行之后的代碼,觀察顯存是否存在先訓(xùn)爆炸,如果沒有,則注釋掉第7行之后的,直至確定哪一行的代碼出現(xiàn)導(dǎo)致了顯存爆炸。假設(shè)第9行起作用后,代碼出現(xiàn)顯存爆炸,說明問題出在第九行,顯存爆炸的問題鎖定。
參考鏈接:
http://www.zzvips.com/article/196059.html
https://blog.csdn.net/fish_like_apple/article/details/101448551
到此這篇關(guān)于pytorch訓(xùn)練時的顯存占用遞增的問題解決的文章就介紹到這了,更多相關(guān)pytorch 顯存占用遞增內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
詳解使用python crontab設(shè)置linux定時任務(wù)
本篇文章主要介紹了使用python crontab設(shè)置linux定時任務(wù),具有一定的參考價值,有需要的可以了解一下。2016-12-12numpy 中l(wèi)inspace函數(shù)的使用
本文主要介紹了numpy 中l(wèi)inspace函數(shù)的使用,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2023-03-03Pandas對數(shù)值進(jìn)行分箱操作的4種方法總結(jié)
分箱是一種常見的數(shù)據(jù)預(yù)處理技術(shù)有時也被稱為分桶或離散化,他可用于將連續(xù)數(shù)據(jù)的間隔分組到“箱”或“桶”中。本文將使用python?Pandas庫對數(shù)值進(jìn)行分箱的4種方法,感興趣的可以了解一下2022-05-05結(jié)合Python網(wǎng)絡(luò)爬蟲做一個今日新聞小程序
本篇文章介紹了我在開發(fā)過程中遇到的一個問題,以及解決該問題的過程及思路,通讀本篇對大家的學(xué)習(xí)或工作具有一定的價值,需要的朋友可以參考下2021-09-09Python連接mysql數(shù)據(jù)庫及簡單增刪改查操作示例代碼
這篇文章主要介紹了Python連接mysql數(shù)據(jù)庫及簡單增刪改查操作示例代碼,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-08-08