pytorch 如何使用amp進(jìn)行混合精度訓(xùn)練
簡(jiǎn)介
AMP:Automatic mixed precision,自動(dòng)混合精度,可以在神經(jīng)網(wǎng)絡(luò)推理過(guò)程中,針對(duì)不同的層,采用不同的數(shù)據(jù)精度進(jìn)行計(jì)算,從而實(shí)現(xiàn)節(jié)省顯存和加快速度的目的。
在Pytorch 1.5版本及以前,通過(guò)NVIDIA提供的apex庫(kù)可以實(shí)現(xiàn)amp功能。但是在使用過(guò)程中會(huì)伴隨著一些版本兼容和奇怪的報(bào)錯(cuò)問(wèn)題。
從1.6版本開(kāi)始,Pytorch原生支持自動(dòng)混合精度訓(xùn)練,并已進(jìn)入穩(wěn)定階段,AMP 訓(xùn)練能在 Tensor Core GPU 上實(shí)現(xiàn)更高的性能并節(jié)省多達(dá) 50% 的內(nèi)存。
環(huán)境
Python 3.8
Pytorch 1.7.1
CUDA 11 + cudnn 8
NVIDIA GeFore RTX 3070
ps:后續(xù)使用移動(dòng)端的3070,或者3080結(jié)合我目前訓(xùn)練的分類網(wǎng)絡(luò)來(lái)測(cè)試實(shí)際效果
原理
關(guān)于低精度計(jì)算
當(dāng)前的深度學(xué)習(xí)框架大都采用的都是FP32來(lái)進(jìn)行權(quán)重參數(shù)的存儲(chǔ),比如Python float的類型為雙精度浮點(diǎn)數(shù) FP64,PyTorch Tensor的默認(rèn)類型為單精度浮點(diǎn)數(shù)FP32。
隨著模型越來(lái)越大,加速訓(xùn)練模型的需求就產(chǎn)生了。在深度學(xué)習(xí)模型中使用FP32主要存在幾個(gè)問(wèn)題,第一模型尺寸大,訓(xùn)練的時(shí)候?qū)︼@卡的顯存要求高;第二模型訓(xùn)練速度慢;第三模型推理速度慢。
其解決方案就是使用低精度計(jì)算對(duì)模型進(jìn)行優(yōu)化。
推理過(guò)程中的模型優(yōu)化目前比較成熟的方案就是FP16量化和INT8量化,NVIDIA TensorRT等框架就可以支持,這里不再贅述。訓(xùn)練方面的方案就是混合精度訓(xùn)練,它的基本思想很簡(jiǎn)單: 精度減半(FP32→ FP16) ,訓(xùn)練時(shí)間減半。
與單精度浮點(diǎn)數(shù)float32(32bit,4個(gè)字節(jié))相比,半精度浮點(diǎn)數(shù)float16僅有16bit,2個(gè)字節(jié)組成。
可以很明顯的看到,使用FP16可以解決或者緩解上面FP32的兩個(gè)問(wèn)題:顯存占用更少:通用的模型FP16占用的內(nèi)存只需原來(lái)的一半,訓(xùn)練的時(shí)候可以使用更大的batchsize。
計(jì)算速度更快:有論文指出半精度的計(jì)算吞吐量可以是單精度的 2-8 倍。
從上到下依次為 fp16、fp32 、fp64
當(dāng)前很多NVIDIA GPU搭載了專門為快速FP16矩陣運(yùn)算設(shè)計(jì)的特殊用途Tensor Core,比如Tesla P100,Tesla V100、Tesla A100、GTX 20XX 和RTX 30XX等。
Tensor Core是一種矩陣乘累加的計(jì)算單元,每個(gè)Tensor Core每個(gè)時(shí)鐘執(zhí)行64個(gè)浮點(diǎn)混合精度操作(FP16矩陣相乘和FP32累加),英偉達(dá)宣稱使用Tensor Core進(jìn)行矩陣運(yùn)算可以輕易的提速,同時(shí)降低一半的顯存訪問(wèn)和存儲(chǔ)。
隨著Tensor Core的普及FP16計(jì)算也一步步走向成熟,低精度計(jì)算也是未來(lái)深度學(xué)習(xí)的一個(gè)重要趨勢(shì)。
Tensor Core 的 4x4x4 矩陣乘法與累加
Volta GV100 Tensor Core 流程圖
自動(dòng)混合精度訓(xùn)練
不同于在推理過(guò)程中直接削減權(quán)重精度,在模型訓(xùn)練的過(guò)程中,直接使用半精度進(jìn)行計(jì)算會(huì)導(dǎo)致的兩個(gè)問(wèn)題的處理:舍入誤差(Rounding Error)和溢出錯(cuò)誤(Grad Overflow / Underflow)。
舍入誤差: float16的最大舍入誤差約為 (~2 ^-10 ),比f(wàn)loat32的最大舍入誤差(~2 ^-23) 要大不少。 對(duì)足夠小的浮點(diǎn)數(shù)執(zhí)行的任何操作都會(huì)將該值四舍五入到零,在反向傳播中很多甚至大多數(shù)梯度更新值都非常小,但不為零。 在反向傳播中舍入誤差累積可以把這些數(shù)字變成0或者 nan, 這會(huì)導(dǎo)致不準(zhǔn)確的梯度更新,影響網(wǎng)絡(luò)的收斂。
溢出錯(cuò)誤: 由于float16的有效的動(dòng)態(tài)范圍約為 ( 5.96×10^-8 ~ 6.55×10^4),比單精度的float32(1.4x10^-45 ~ 1.7x10^38)要狹窄很多,精度下降(小數(shù)點(diǎn)后16相比較小數(shù)點(diǎn)后8位要精確的多)會(huì)導(dǎo)致得到的值大于或者小于fp16的有效動(dòng)態(tài)范圍,也就是上溢出或者下溢出。
在深度學(xué)習(xí)中,由于激活函數(shù)的的梯度往往要比權(quán)重梯度小,更易出現(xiàn)下溢出的情況。2018年ICLR論文 Mixed Precision Training 中提到,簡(jiǎn)單的在每個(gè)地方使用FP16會(huì)損失掉梯度更新小于2^-24的值——大約占他們的示例網(wǎng)絡(luò)所有梯度更新的5%。
解決方案就是使用混合精度訓(xùn)練(Mixed Precision)和損失縮放(Loss Scaling):
1、混合精度訓(xùn)練:
混合精度訓(xùn)練是一種通過(guò)在FP16上執(zhí)行盡可能多的操作來(lái)大幅度減少神經(jīng)網(wǎng)絡(luò)訓(xùn)練時(shí)間的技術(shù),在像線性層或是卷積操作上,F(xiàn)P16運(yùn)算較快,但像Reduction運(yùn)算又需要 FP32的動(dòng)態(tài)范圍。通過(guò)混合精度訓(xùn)練的方式,便可以在部分運(yùn)算操作使用FP16,另一部分則使用 FP32,混合精度功能會(huì)嘗試為每個(gè)運(yùn)算使用相匹配的數(shù)據(jù)類型,在內(nèi)存中用FP16做儲(chǔ)存和乘法從而加速計(jì)算,用FP32做累加避免舍入誤差。這樣在權(quán)重更新的時(shí)候就不會(huì)出現(xiàn)舍入誤差導(dǎo)致更新失敗,混合精度訓(xùn)練的策略有效地緩解了舍入誤差的問(wèn)題。
2、損失縮放:
即使用了混合精度訓(xùn)練,還是會(huì)存在無(wú)法收斂的情況,原因是激活梯度的值太小,造成了下溢出。損失縮放是指在執(zhí)行反向傳播之前,將損失函數(shù)的輸出乘以某個(gè)標(biāo)量數(shù)(論文建議從8開(kāi)始)。 乘性增加的損失值產(chǎn)生乘性增加的梯度更新值,提升許多梯度更新值到超過(guò)FP16的安全閾值2^-24。 只要確保在應(yīng)用梯度更新之前撤消縮放,并且不要選擇一個(gè)太大的縮放以至于產(chǎn)生inf權(quán)重更新(上溢出) ,從而導(dǎo)致網(wǎng)絡(luò)向相反的方向發(fā)散。
使用Pytorch AMP
Pytorch原生的amp模式使用起來(lái)相當(dāng)簡(jiǎn)單,只需要從torch.cuda.amp導(dǎo)入GradScaler和 autocast這兩個(gè)函數(shù)即可。torch.cuda.amp的名字意味著這個(gè)功能只能在cuda上使用,事實(shí)上,這個(gè)功能正是NVIDIA的開(kāi)發(fā)人員貢獻(xiàn)到PyTorch項(xiàng)目中的。
Pytorch在amp模式下維護(hù)兩個(gè)權(quán)重矩陣的副本,一個(gè)主副本用 FP32,一個(gè)半精度副本用 FP16。 梯度更新使用FP16矩陣計(jì)算,但更新于 FP32矩陣。 這使得應(yīng)用梯度更新更加安全。
autocast上下文管理器實(shí)現(xiàn)了 FP32到FP16的轉(zhuǎn)換,它會(huì)自動(dòng)判別哪些層可以進(jìn)行FP16哪些層不可以。 GradScaler對(duì)梯度更新計(jì)算(檢查是否溢出)和優(yōu)化器(將丟棄的batches轉(zhuǎn)換為 no-op)進(jìn)行控制,通過(guò)放大loss的值來(lái)防止梯度的溢出。
在訓(xùn)練中的具體使用方法如下所示:
def train(): batch_size = 8 epochs = 10 lr = 1e-3 size = 256 num_class = 35 use_amp = True device = 'cuda' if torch.cuda.is_available() else 'cpu' print('torch version: {}'.format(torch.__version__)) print('amp: {}'.format(use_amp)) print('device: {}'.format(device)) print('epochs: {}'.format(epochs)) print('learn rate: {}'.format(lr)) print('batch size: {}'.format(batch_size)) net = ERFNet(num_classes=num_class).to(device) train_data = CityScapesDataset('D:\\dataset\\cityscapes', 'D:\\dataset\\cityscapes\\trainImages.txt', 'D:\\dataset\\cityscapes\\trainLabels.txt', size, num_class) val_data = CityScapesDataset('D:\\dataset\\cityscapes', 'D:\\dataset\\cityscapes\\valImages.txt', 'D:\\dataset\\cityscapes\\valLabels.txt', size, num_class) train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=False, num_workers=8) val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=4) opt = torch.optim.Adam(net.parameters(), lr=lr) criterion = torch.nn.CrossEntropyLoss(ignore_index=255) if use_amp: scaler = torch.cuda.amp.GradScaler() writer = SummaryWriter("summary") train_loss = AverageMeter() val_acc = AverageMeter() val_miou = AverageMeter() for epoch in range(0, epochs): train_loss.reset() val_acc.reset() val_miou.reset() with tqdm(total=train_data.__len__(), unit='img', desc="Epoch {}/{}".format(epoch + 1, epochs)) as pbar: # train net.train() for img, mask in train_dataloader: img = img.to(device) mask = mask.to(device) n = img.size()[0] opt.zero_grad() if use_amp: with torch.cuda.amp.autocast(): output = net(img) loss = criterion(output, mask) scaler.scale(loss).backward() scaler.step(opt) scaler.update() else: output = net(img) loss = criterion(output, mask) loss.backward() opt.step() train_loss.update(loss.item(), n) pbar.set_postfix(**{"loss": train_loss.avg}) pbar.update(img.size()[0]) writer.add_scalar('train_loss', train_loss.avg, epoch) # eval net.eval() for img, mask in val_dataloader: img = img.to(device) mask = mask n = img.size()[0] output = net(img) pred_mask = torch.softmax(output, dim=1) pred_mask = pred_mask.detach().cpu().numpy() pred_mask = np.argmax(pred_mask, axis=1) true_mask = mask.numpy() acc, acc_cls, mean_iu, fwavacc = evaluate(pred_mask, true_mask, num_class) val_acc.update(acc) val_miou.update(mean_iu) writer.add_scalar('val_acc', val_acc.avg, epoch) writer.add_scalar('val_miou', val_miou.avg, epoch) pbar.set_postfix(**{"loss": train_loss.avg, "val_acc": val_acc.avg, "val_miou": val_miou.avg})
實(shí)驗(yàn)
硬件使用NVIDIA Geforce RTX 3070作為測(cè)試卡,這塊卡有184個(gè)Tensor Core,能比較好的支持amp模式。
模型使用ERFNet分割模型作為基準(zhǔn),cityscapes作為測(cè)試數(shù)據(jù),10個(gè)epoch下的測(cè)試效果如下所示:
在模型的訓(xùn)練性能方面,amp模式下的平均訓(xùn)練時(shí)間并沒(méi)有明顯節(jié)省,甚至還略低于正常模式。
顯存的占用大約節(jié)省了25%,對(duì)于需要大量顯存的模型來(lái)說(shuō)這個(gè)提升還是相當(dāng)可觀的。
理論上訓(xùn)練速度應(yīng)該也是有提升的,到Pytorch的GitHub issue里翻了一下,好像30系顯卡會(huì)存在速度提不上來(lái)的問(wèn)題,不太清楚是驅(qū)動(dòng)支持不到位還是軟件適配不到位。
Metrics | time | memory |
---|---|---|
AMP | 66.72s | 2.5G |
NO_AMP | 65.64s | 3.3G |
amp
no_amp
在模型的精度方面,在不進(jìn)行數(shù)據(jù)shuffle的情況下統(tǒng)計(jì)了10個(gè)epoch下兩種模式的train_loss和val_acc,可以看出不管是訓(xùn)練還是推理,amp模式并沒(méi)有帶來(lái)明顯的精度損失。
cmp
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
django foreignkey外鍵使用的例子 相當(dāng)于left join
今天小編就為大家分享一篇django foreignkey外鍵使用的例子 相當(dāng)于left join,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-08-08深入了解python的tkinter實(shí)現(xiàn)簡(jiǎn)單登錄
這篇文章主要為大家介紹了python的tkinter實(shí)現(xiàn)簡(jiǎn)單登錄,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下,希望能夠給你帶來(lái)幫助2021-12-12Python進(jìn)度條神器tqdm使用實(shí)例詳解
Python進(jìn)度條神器tqdm是一個(gè)快速、可擴(kuò)展的進(jìn)度條工具,可以輕松地為Python腳本添加進(jìn)度條。它可以在循環(huán)中自動(dòng)計(jì)算進(jìn)度,并在終端中顯示進(jìn)度條,讓用戶了解程序的運(yùn)行情況。tqdm還支持多線程和多進(jìn)程,并且可以自定義進(jìn)度條的樣式和顯示方式。2023-06-06Python數(shù)據(jù)結(jié)構(gòu)之圖的應(yīng)用示例
這篇文章主要介紹了Python數(shù)據(jù)結(jié)構(gòu)之圖的應(yīng)用,結(jié)合實(shí)例形式分析了Python數(shù)據(jù)結(jié)構(gòu)中圖的定義與遍歷算法相關(guān)操作技巧,需要的朋友可以參考下2018-05-05多線程python的實(shí)現(xiàn)及多線程有序性
這篇文章主要介紹了多線程python的實(shí)現(xiàn)及多線程有序性,多線程一般用于同時(shí)調(diào)用多個(gè)函數(shù),cpu時(shí)間片輪流分配給多個(gè)任務(wù)2022-06-06python 參數(shù)列表中的self 顯式不等于冗余
Self in the Argument List: Redundant is not Explicit2008-12-12Python檢測(cè)網(wǎng)絡(luò)延遲的代碼
這篇文章主要介紹了Python檢測(cè)網(wǎng)絡(luò)延遲的代碼,小編覺(jué)得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧2018-05-05python 讀寫(xiě)中文json的實(shí)例詳解
這篇文章主要介紹了 python 讀寫(xiě)中文json的實(shí)例詳解的相關(guān)資料,希望通過(guò)本文能幫助到大家,讓大家掌握這樣的內(nèi)容,需要的朋友可以參考下2017-10-10