欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

解讀torch.cuda.amp自動(dòng)混合精度訓(xùn)練之節(jié)省顯存并加快推理速度

 更新時(shí)間:2023年08月03日 16:56:37   作者:Code_demon  
這篇文章主要介紹了torch.cuda.amp自動(dòng)混合精度訓(xùn)練之節(jié)省顯存并加快推理速度問題,具有很好的參考價(jià)值,希望對大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教

1、什么是amp?

amp:Automatic mixed precision,自動(dòng)混合精度,可以在神經(jīng)網(wǎng)絡(luò)推理過程中,針對不同的層,采用不同的數(shù)據(jù)精度進(jìn)行計(jì)算,從而實(shí)現(xiàn)節(jié)省顯存和加快速度的目的。

自動(dòng)混合精度的關(guān)鍵詞有兩個(gè):自動(dòng)、混合精度。

這是由PyTorch 1.6的torch.cuda.amp模塊帶來的:

from torch.cuda import amp

混合精度預(yù)示著有不止一種精度的Tensor,那在PyTorch的AMP模塊里是幾種呢?

2種:torch.FloatTensor(浮點(diǎn)型 32位)和torch.HalfTensor(半精度浮點(diǎn)型 16位);

自動(dòng)預(yù)示著Tensor的dtype類型會(huì)自動(dòng)變化,也就是框架按需自動(dòng)調(diào)整tensor的dtype(其實(shí)不是完全自動(dòng),有些地方還是需要手工干預(yù));

注意

  • torch.cuda.amp 的名字意味著這個(gè)功能只能在cuda上使用。
  • torch默認(rèn)的tensor精度類型是torch.FloatTensor

2、為什么需要自動(dòng)混合精度(amp)?

也可以這么問:為什么需要自動(dòng)混合精度,也就是torch.FloatTensortorch.HalfTensor的混合,而不全是torch.FloatTensor?或者全是torch.HalfTensor?

原因:

在某些上下文中torch.FloatTensor有優(yōu)勢,在某些上下文中torch.HalfTensor有優(yōu)勢。

torch.HalfTensor

  • torch.HalfTensor的優(yōu)勢就是存儲(chǔ)小、計(jì)算快、更好的利用CUDA設(shè)備的Tensor Core。因此訓(xùn)練的時(shí)候可以減少顯存的占用(可以增加batchsize了),同時(shí)訓(xùn)練速度更快;
  • torch.HalfTensor的劣勢就是:數(shù)值范圍?。ǜ菀譕verflow / Underflow)、舍入誤差(Rounding Error,導(dǎo)致一些微小的梯度信息達(dá)不到16bit精度的最低分辨率,從而丟失)。

可見,當(dāng)有優(yōu)勢的時(shí)候就用torch.HalfTensor,而為了消除torch.HalfTensor的劣勢,我們帶來了兩種解決方案:

  • 梯度scale,這正是上一小節(jié)中提到的torch.cuda.amp.GradScaler,通過放大loss的值來防止梯度消失underflow(這只是BP的時(shí)候傳遞梯度信息使用,真正更新權(quán)重的時(shí)候還是要把放大的梯度再unscale回去)
  • 回落到torch.FloatTensor,這就是混合一詞的由來。那怎么知道什么時(shí)候用torch.FloatTensor,什么時(shí)候用半精度浮點(diǎn)型呢?這是PyTorch框架決定的,AMP上下文中,一些常用的操作中tensor會(huì)被自動(dòng)轉(zhuǎn)化為半精度浮點(diǎn)型的torch.HalfTensor(如:conv1d、conv2d、conv3d、linear、prelu等)

3、如何在PyTorch中使用自動(dòng)混合精度?

答案是 autocast + GradScaler

3.1 autocast

使用torch.cuda.amp模塊中的autocast 類。

from torch.cuda import amp
# 創(chuàng)建model,默認(rèn)是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# 判斷能否使用自動(dòng)混合精度
enable_amp = True if "cuda" in device.type else False
for input, target in data:
    optimizer.zero_grad()
    # 前向過程(model + loss)開啟 autocast
    with amp.autocast(enabled=enable_amp):
        output = model(input)
        loss = loss_fn(output, target)
    # 反向傳播在autocast上下文之外
    loss.backward()
    optimizer.step()

注意

  • 當(dāng)進(jìn)入autocast,自動(dòng)將torch.FloatTensor類型轉(zhuǎn)化為torch.HalfTensor,而不需要手動(dòng)設(shè)置model.half()/input.half,框架會(huì)自動(dòng)做,這也是自動(dòng)混合精度中“自動(dòng)”一詞的由來。
  • autocast上下文應(yīng)該只包含網(wǎng)絡(luò)的前向過程(包括loss的計(jì)算),而不要包含反向傳播。

3.2、GradScaler

這里GradScaler就是第二小節(jié)中提到的梯度scaler模塊,需要在訓(xùn)練最開始之前使用amp.GradScaler實(shí)例化一個(gè)GradScaler對象。

from torch.cuda import amp
# 創(chuàng)建model,默認(rèn)是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# 判斷能否使用自動(dòng)混合精度
enable_amp = True if "cuda" in device.type else False
# 在訓(xùn)練最開始之前實(shí)例化一個(gè)GradScaler對象
scaler = amp.GradScaler(enabled=enable_amp)
for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        # 前向過程(model + loss)開啟 autocast
        with amp.autocast(enabled=enable_amp):
            output = model(input)
            loss = loss_fn(output, target)
        # 1、Scales loss.  先將梯度放大 防止梯度消失
        scaler.scale(loss).backward()
        # 2、scaler.step()   再把梯度的值unscale回來.
        # 如果梯度的值不是 infs 或者 NaNs, 那么調(diào)用optimizer.step()來更新權(quán)重,
        # 否則,忽略step調(diào)用,從而保證權(quán)重不更新(不被破壞)
        scaler.step(optimizer)
        # 3、準(zhǔn)備著,看是否要增大scaler
        scaler.update()
        # 正常更新權(quán)重
        optimizer.zero_grad()

scaler的大小在每次迭代中動(dòng)態(tài)的估計(jì),為了盡可能的減少梯度underflow,scaler應(yīng)該更大;但是如果太大的話,半精度浮點(diǎn)型的tensor又容易o(hù)verflow(變成inf或者NaN)。

所以動(dòng)態(tài)估計(jì)的原理就是在不出現(xiàn)inf或者NaN梯度值的情況下盡可能的增大scaler的值——在每次scaler.step(optimizer)中,都會(huì)檢查是否又inf或NaN的梯度出現(xiàn):

  • 如果出現(xiàn)了inf或者NaN,scaler.step(optimizer)會(huì)忽略此次的權(quán)重更新(optimizer.step() ),并且將scaler的大小縮?。ǔ松蟗ackoff_factor);
  • 如果沒有出現(xiàn)inf或者NaN,那么權(quán)重正常更新,并且當(dāng)連續(xù)多次(growth_interval指定)沒有出現(xiàn)inf或者NaN,則scaler.update()會(huì)將scaler的大小增加(乘上growth_factor)。

注意

再強(qiáng)調(diào)一點(diǎn),amp只能在GPU環(huán)境下使用,因?yàn)橐粊韆mp是寫在torch.cuda中的函數(shù),而且amp的中的 amp.GradScaleramp.autocast函數(shù)構(gòu)造是這樣的:

amp.GradScaler

    def __init__(self,
                 init_scale=2.**16,
                 growth_factor=2.0,
                 backoff_factor=0.5,
                 growth_interval=2000,
                 enabled=True):
        if enabled and not torch.cuda.is_available():
            warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available.  Disabling.")
            self._enabled = False
        else:
            self._enabled = enabled

amp.autocast

 def __init__(self, enabled=True):
        if enabled and not torch.cuda.is_available():
            warnings.warn("torch.cuda.amp.autocast only affects CUDA ops, but CUDA is not available.  Disabling.")
            self._enabled = False
        else:
            self._enabled = enabled

4、多GPU訓(xùn)練

單卡訓(xùn)練的話上面的代碼已經(jīng)夠了。

要是想多卡跑的話僅僅這樣還不夠,會(huì)發(fā)現(xiàn)在forward里面的每個(gè)結(jié)果都還是float32的,怎么辦?

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
    def forward(self, input_data_c1):
    	with autocast():
    		# code
    	return

只要把model中的forward里面的代碼用autocast代碼塊方式運(yùn)行就好了。

總結(jié)

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • python通過移動(dòng)端訪問查看電腦界面

    python通過移動(dòng)端訪問查看電腦界面

    這篇文章主要介紹了python通過移動(dòng)端訪問查看電腦界面,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下
    2020-01-01
  • python 實(shí)現(xiàn)添加標(biāo)簽&打標(biāo)簽的操作

    python 實(shí)現(xiàn)添加標(biāo)簽&打標(biāo)簽的操作

    這篇文章主要介紹了python 實(shí)現(xiàn)添加標(biāo)簽&打標(biāo)簽的操作,具有很好的參考價(jià)值,希望對大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教
    2021-05-05
  • pyqt5 實(shí)現(xiàn)在別的窗口彈出進(jìn)度條

    pyqt5 實(shí)現(xiàn)在別的窗口彈出進(jìn)度條

    今天小編就為大家分享一篇pyqt5 實(shí)現(xiàn)在別的窗口彈出進(jìn)度條,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2019-06-06
  • python繪制已知點(diǎn)的坐標(biāo)的直線實(shí)例

    python繪制已知點(diǎn)的坐標(biāo)的直線實(shí)例

    今天小編就為大家分享一篇python繪制已知點(diǎn)的坐標(biāo)的直線實(shí)例,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2019-07-07
  • python實(shí)現(xiàn)拓?fù)渑判虻姆椒ú襟E

    python實(shí)現(xiàn)拓?fù)渑判虻姆椒ú襟E

    拓?fù)渑判蚴菍τ邢驘o環(huán)圖進(jìn)行排序的一種算法,本文主要介紹了python實(shí)現(xiàn)拓?fù)渑判虻姆椒ú襟E,具有一定的參考價(jià)值,感興趣的可以了解一下
    2024-03-03
  • Python可變和不可變、類的私有屬性實(shí)例分析

    Python可變和不可變、類的私有屬性實(shí)例分析

    這篇文章主要介紹了Python可變和不可變、類的私有屬性,結(jié)合實(shí)例形式分析了Python值可變與不可變的情況及內(nèi)存地址變化,類的私有屬性定義、訪問相關(guān)操作技巧,需要的朋友可以參考下
    2019-05-05
  • 在?Python?中使用變量創(chuàng)建文件名的方法

    在?Python?中使用變量創(chuàng)建文件名的方法

    這篇文章主要介紹了在?Python?中使用變量創(chuàng)建文件名,格式化的字符串文字使我們能夠通過在字符串前面加上 f 來在字符串中包含表達(dá)式和變量,本文給大家詳細(xì)講解,需要的朋友可以參考下
    2023-03-03
  • pygame實(shí)現(xiàn)貪吃蛇游戲(上)

    pygame實(shí)現(xiàn)貪吃蛇游戲(上)

    這篇文章主要為大家詳細(xì)介紹了pygame實(shí)現(xiàn)貪吃蛇游戲,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2019-10-10
  • 正確的使用Python臨時(shí)文件

    正確的使用Python臨時(shí)文件

    這篇文章主要介紹了正確的使用Python臨時(shí)文件,幫助大家更好的理解和學(xué)習(xí)使用python,感興趣的朋友可以了解下
    2021-03-03
  • python3使用sqlite3構(gòu)建本地持久化緩存的過程

    python3使用sqlite3構(gòu)建本地持久化緩存的過程

    日常python開發(fā)中會(huì)遇到數(shù)據(jù)持久化的問題,今天記錄下如何使用sqlite3進(jìn)行數(shù)據(jù)持久化,并提供示例代碼及數(shù)據(jù)查看工具,需要的朋友可以參考下
    2023-11-11

最新評論