Pytorch中torch.utils.checkpoint()及用法詳解
Pytorch中torch.utils.checkpoint()
在PyTorch中,torch.utils.checkpoint
模塊提供了實(shí)現(xiàn)梯度檢查點(diǎn)(也稱(chēng)為checkpointing)的功能。這個(gè)技術(shù)主要用于訓(xùn)練時(shí)內(nèi)存優(yōu)化,它允許我們以計(jì)算時(shí)間為代價(jià),減少訓(xùn)練深度網(wǎng)絡(luò)時(shí)的內(nèi)存占用。
原理
梯度檢查點(diǎn)技術(shù)的基本原理是,在前向傳播的過(guò)程中,并不保存所有的中間激活值。相反,它只保存一部分關(guān)鍵的激活值。在反向傳播時(shí),根據(jù)保留的激活值重新計(jì)算丟棄的中間激活值。因此內(nèi)存的使用量會(huì)下降,但計(jì)算量會(huì)增加,因?yàn)樾枰匦掠?jì)算一些前向傳播的部分。
用法
torch.utils.checkpoint
中主要的函數(shù)是 checkpoint。checkpoint 函數(shù)可以用來(lái)封裝模型的一部分或者一個(gè)復(fù)雜的運(yùn)算,這部分會(huì)使用梯度檢查點(diǎn)。它的一般用法是:
import torch from torch.utils.checkpoint import checkpoint # 定義一個(gè)前向傳播函數(shù) def custom_forward(*inputs): # 定義你的前向傳播邏輯 # 例如: x, y = inputs; result = x + y ... return result # 在訓(xùn)練的前向傳播過(guò)程中使用梯度檢查點(diǎn) model_output = checkpoint(custom_forward, *model_inputs)
在每次調(diào)用 custom_forward 函數(shù)時(shí),它都會(huì)返回正常的前向傳播結(jié)果。不過(guò),checkpoint 函數(shù)會(huì)確保僅保留必須的激活值(即 custom_forward 的輸出)。其他激活值不會(huì)保存在內(nèi)存中,需要在反向傳播時(shí)重新計(jì)算。
下面是一個(gè)具體的示例,演示了如何在一個(gè)簡(jiǎn)單的模型中使用 checkpoint 函數(shù):
import torch import torch.nn as nn from torch.utils.checkpoint import checkpoint class SomeModel(nn.Module): def __init__(self): super(SomeModel, self).__init__() self.conv1 = nn.Conv2d(1, 20, 5) self.conv2 = nn.Conv2d(20, 50, 5) def forward(self, x): # 使用checkpoint來(lái)減少第二層卷積的內(nèi)存使用量 x = self.conv1(x) x = checkpoint(self.conv2, x) return x model = SomeModel() input = torch.randn(1, 1, 28, 28) output = model(input) loss = output.sum() loss.backward()
在上面的例子中,conv2的前向計(jì)算是通過(guò) checkpoint 封裝的,這意味著在 conv1 的輸出和 conv2 的輸出之間的激活值不會(huì)被完全存儲(chǔ)。在反向傳播時(shí),這些丟失的激活值會(huì)通過(guò)再次前向傳遞 conv2 來(lái)重新計(jì)算。
使用梯度檢查點(diǎn)技術(shù)可以在訓(xùn)練大型模型時(shí)減少顯存的占用,但由于在反向傳播時(shí)額外的重新計(jì)算,它會(huì)增加一些計(jì)算成本。
到此這篇關(guān)于Pytorch中torch.utils.checkpoint()及用法詳解的文章就介紹到這了,更多相關(guān)Pytorch torch.utils.checkpoint()內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python實(shí)現(xiàn)點(diǎn)陣字體讀取與轉(zhuǎn)換的方法
今天小編就為大家分享一篇Python實(shí)現(xiàn)點(diǎn)陣字體讀取與轉(zhuǎn)換的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-01-01Python reshape的用法及多個(gè)二維數(shù)組合并為三維數(shù)組的實(shí)例
今天小編就為大家分享一篇Python reshape的用法及多個(gè)二維數(shù)組合并為三維數(shù)組的實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-02-02python格式化輸出%s與format()的用法對(duì)比
這篇文章主要為大家介紹了python格式化輸出%s與format()的用法對(duì)比,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步2021-10-10PyTorch詳解經(jīng)典網(wǎng)絡(luò)ResNet實(shí)現(xiàn)流程
ResNet全稱(chēng)residual neural network,主要是解決過(guò)深的網(wǎng)絡(luò)帶來(lái)的梯度彌散,梯度爆炸,網(wǎng)絡(luò)退化(即網(wǎng)絡(luò)層數(shù)越深時(shí),在數(shù)據(jù)集上表現(xiàn)的性能卻越差)的問(wèn)題2022-05-05詳解如何將Python可執(zhí)行文件(.exe)反編譯為Python腳本
將?Python?可執(zhí)行文件(.exe)反編譯為?Python?腳本是一項(xiàng)有趣的技術(shù)挑戰(zhàn),可以幫助我們理解程序的工作原理,下面我們就來(lái)看看具體實(shí)現(xiàn)步驟吧2024-03-03Python自動(dòng)化運(yùn)維_文件內(nèi)容差異對(duì)比分析
下面小編就為大家分享一篇Python自動(dòng)化運(yùn)維_文件內(nèi)容差異對(duì)比分析,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2017-12-12python爬蟲(chóng)框架talonspider簡(jiǎn)單介紹
本文給大家介紹的是使用python開(kāi)發(fā)的爬蟲(chóng)框架talonspider的簡(jiǎn)單介紹以及使用方法,有需要的小伙伴可以參考下2017-06-06