pytorch 使用半精度模型部署的操作
背景
pytorch作為深度學(xué)習(xí)的計(jì)算框架正得到越來(lái)越多的應(yīng)用.
我們除了在模型訓(xùn)練階段應(yīng)用外,最近也把pytorch應(yīng)用在了部署上.
在部署時(shí),為了減少計(jì)算量,可以考慮使用16位浮點(diǎn)模型,而訓(xùn)練時(shí)涉及到梯度計(jì)算,需要使用32位浮點(diǎn),這種精度的不一致經(jīng)過(guò)測(cè)試,模型性能下降有限,可以接受.
但是推斷時(shí)計(jì)算量可以降低一半,同等計(jì)算資源下,并發(fā)度可提升近一倍
具體方法
在pytorch中,一般模型定義都繼承torch.nn.Moudle,torch.nn.Module基類的half()方法會(huì)把所有參數(shù)轉(zhuǎn)為16位浮點(diǎn),所以在模型加載后,調(diào)用一下該方法即可達(dá)到模型切換的目的.接下來(lái)只需要在推斷時(shí)把input的tensor切換為16位浮點(diǎn)即可
另外還有一個(gè)小的trick,在推理過(guò)程中模型輸出的tensor自然會(huì)成為16位浮點(diǎn),如果需要新創(chuàng)建tensor,最好調(diào)用已有tensor的new_zeros,new_full等方法而不是torch.zeros和torch.full,前者可以自動(dòng)繼承已有tensor的類型,這樣就不需要到處增加代碼判斷是使用16位還是32位了,只需要針對(duì)input tensor切換.
補(bǔ)充:pytorch 使用amp.autocast半精度加速訓(xùn)練
準(zhǔn)備工作
pytorch 1.6+
如何使用autocast?
根據(jù)官方提供的方法,
答案就是autocast + GradScaler。
如何在PyTorch中使用自動(dòng)混合精度?
答案:autocast + GradScaler。
1.autocast
正如前文所說(shuō),需要使用torch.cuda.amp模塊中的autocast 類。使用也是非常簡(jiǎn)單的
from torch.cuda.amp import autocast as autocast # 創(chuàng)建model,默認(rèn)是torch.FloatTensor model = Net().cuda() optimizer = optim.SGD(model.parameters(), ...) for input, target in data: optimizer.zero_grad() # 前向過(guò)程(model + loss)開(kāi)啟 autocast with autocast(): output = model(input) loss = loss_fn(output, target) # 反向傳播在autocast上下文之外 loss.backward() optimizer.step()
2.GradScaler
GradScaler就是梯度scaler模塊,需要在訓(xùn)練最開(kāi)始之前實(shí)例化一個(gè)GradScaler對(duì)象。
因此PyTorch中經(jīng)典的AMP使用方式如下:
from torch.cuda.amp import autocast as autocast # 創(chuàng)建model,默認(rèn)是torch.FloatTensor model = Net().cuda() optimizer = optim.SGD(model.parameters(), ...) # 在訓(xùn)練最開(kāi)始之前實(shí)例化一個(gè)GradScaler對(duì)象 scaler = GradScaler() for epoch in epochs: for input, target in data: optimizer.zero_grad() # 前向過(guò)程(model + loss)開(kāi)啟 autocast with autocast(): output = model(input) loss = loss_fn(output, target) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()
3.nn.DataParallel
單卡訓(xùn)練的話上面的代碼已經(jīng)夠了,親測(cè)在2080ti上能減少至少1/3的顯存,至于速度。。。
要是想多卡跑的話僅僅這樣還不夠,會(huì)發(fā)現(xiàn)在forward里面的每個(gè)結(jié)果都還是float32的,怎么辦?
class Model(nn.Module): def __init__(self): super(Model, self).__init__() def forward(self, input_data_c1): with autocast(): # code return
只要把forward里面的代碼用autocast代碼塊方式運(yùn)行就好啦!
自動(dòng)進(jìn)行autocast的操作
如下操作中tensor會(huì)被自動(dòng)轉(zhuǎn)化為半精度浮點(diǎn)型的torch.HalfTensor:
1、matmul
2、addbmm
3、addmm
4、addmv
5、addr
6、baddbmm
7、bmm
8、chain_matmul
9、conv1d
10、conv2d
11、conv3d
12、conv_transpose1d
13、conv_transpose2d
14、conv_transpose3d
15、linear
16、matmul
17、mm
18、mv
19、prelu
那么只有這些操作才能半精度嗎?不是。其他操作比如rnn也可以進(jìn)行半精度運(yùn)行,但是需要自己手動(dòng),暫時(shí)沒(méi)有提供自動(dòng)的轉(zhuǎn)換。
相關(guān)文章
python實(shí)現(xiàn)斐波那契數(shù)列的方法示例
每個(gè)碼農(nóng)大概都會(huì)用自己擅長(zhǎng)的語(yǔ)言寫(xiě)出一個(gè)斐波那契數(shù)列出來(lái),斐波那契數(shù)列簡(jiǎn)單地說(shuō),起始兩項(xiàng)為0和1,此后的項(xiàng)分別為它的前兩項(xiàng)之后。下面這篇文章就給大家詳細(xì)介紹了python實(shí)現(xiàn)斐波那契數(shù)列的方法,有需要的朋友們可以參考借鑒,下面來(lái)一起看看吧。2017-01-01PyQT中QTableWidget如何根據(jù)單元格內(nèi)容設(shè)置自動(dòng)寬度
這篇文章主要介紹了PyQT中QTableWidget如何根據(jù)單元格內(nèi)容設(shè)置自動(dòng)寬度問(wèn)題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-05-05Python入門及進(jìn)階筆記 Python 內(nèi)置函數(shù)小結(jié)
這篇文章主要介紹了Python的內(nèi)置函數(shù)小結(jié),需要的朋友可以參考下2014-08-08將本地Python項(xiàng)目打包成docker鏡像上傳到服務(wù)器并在docker中運(yùn)行
Docker是一個(gè)開(kāi)源項(xiàng)目,為開(kāi)發(fā)人員和系統(tǒng)管理員提供了一個(gè)開(kāi)放平臺(tái),可以將應(yīng)用程序構(gòu)建、打包為一個(gè)輕量級(jí)容器,并在任何地方運(yùn)行,這篇文章主要給大家介紹了關(guān)于將本地Python項(xiàng)目打包成docker鏡像上傳到服務(wù)器并在docker中運(yùn)行的相關(guān)資料,需要的朋友可以參考下2023-12-12基于python實(shí)現(xiàn)監(jiān)聽(tīng)Rabbitmq系統(tǒng)日志代碼示例
這篇文章主要介紹了基于python實(shí)現(xiàn)監(jiān)聽(tīng)Rabbitmq系統(tǒng)日志代碼示例,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-11-11python判斷鏈表是否有環(huán)的實(shí)例代碼
在本篇文章里小編給大家整理的是關(guān)于python判斷鏈表是否有環(huán)的知識(shí)點(diǎn)及實(shí)例代碼,需要的朋友們參考下。2020-01-01