利用Pytorch實現(xiàn)ResNet網(wǎng)絡(luò)構(gòu)建及模型訓(xùn)練
構(gòu)建網(wǎng)絡(luò)
ResNet由一系列堆疊的殘差塊組成,其主要作用是通過無限制地增加網(wǎng)絡(luò)深度,從而使其更加強大。在建立ResNet模型之前,讓我們先定義4個層,每個層由多個殘差塊組成。這些層的目的是降低空間尺寸,同時增加通道數(shù)量。
以ResNet50為例,我們可以使用以下代碼來定義ResNet網(wǎng)絡(luò):
class ResNet(nn.Module): def __init__(self, num_classes=1000): super().__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU(inplace (續(xù)) 即模型需要在輸入層加入一些 normalization 和激活層。 ```python import torch.nn.init as init class Flatten(nn.Module): def __init__(self): super().__init__() def forward(self, x): return x.view(x.size(0), -1) class ResNet(nn.Module): def __init__(self, num_classes=1000): super().__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU(inplace=True) self.layer1 = nn.Sequential( ResidualBlock(64, 256, stride=1), *[ResidualBlock(256, 256) for _ in range(1, 3)] ) self.layer2 = nn.Sequential( ResidualBlock(256, 512, stride=2), *[ResidualBlock(512, 512) for _ in range(1, 4)] ) self.layer3 = nn.Sequential( ResidualBlock(512, 1024, stride=2), *[ResidualBlock(1024, 1024) for _ in range(1, 6)] ) self.layer4 = nn.Sequential( ResidualBlock(1024, 2048, stride=2), *[ResidualBlock(2048, 2048) for _ in range(1, 3)] ) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.flatten = Flatten() self.fc = nn.Linear(2048, num_classes) for m in self.modules(): if isinstance(m, nn.Conv2d): init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): init.constant_(m.weight, 1) init.constant_(m.bias, 0) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.avgpool(x) x = self.flatten(x) x = self.fc(x) return x
改進點如下:
- 我們使用
nn.Sequential
組件,將多個殘差塊組合成一個功能塊(layer)。這樣可以方便地修改網(wǎng)絡(luò)深度,并將其與其他層分離九更容易上手,例如遷移學(xué)習(xí)中重新訓(xùn)練頂部分類器時。 - 我們在ResNet的輸出層添加了標(biāo)準(zhǔn)化和激活函數(shù)。它們有助于提高模型的收斂速度并改善性能。
- 對于
nn.Conv2d
和批標(biāo)準(zhǔn)化層等神經(jīng)網(wǎng)絡(luò)組件,我們使用了PyTorch中的內(nèi)置初始化函數(shù)。它們會自動為我們設(shè)置好每層的參數(shù)。 - 我們還添加了一個
Flatten
層,將4維輸出展平為2維張量,以便通過接下來的全連接層進行分類。
訓(xùn)練模型
我們現(xiàn)在已經(jīng)實現(xiàn)了ResNet50模型,接下來我們將解釋如何訓(xùn)練和測試該模型。
首先我們需要定義損失函數(shù)和優(yōu)化器。在這里,我們使用交叉熵損失函數(shù),以及Adam優(yōu)化器。
import torch.optim as optim device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = ResNet(num_classes=1000).to(device) criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=0.001)
在使用PyTorch進行訓(xùn)練時,我們通常會創(chuàng)建一個循環(huán),為每個批次的輸入數(shù)據(jù)計算損失并對模型參數(shù)進行更新。以下是該循環(huán)的代碼:
def train(model, optimizer, criterion, train_loader, device): model.train() train_loss = 0 correct = 0 total = 0 for batch_idx, (inputs, targets) in enumerate(train_loader): inputs, targets = inputs.to(device), targets.to(device) optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, targets) loss.backward() optimizer.step() train_loss += loss.item() _, predicted = outputs.max(1) total += targets.size(0) correct += predicted.eq(targets).sum().item() acc = 100 * correct / total avg_loss = train_loss / len(train_loader) return acc, avg_loss
在上面的訓(xùn)練循環(huán)中,我們首先通過model.train()
代表進入訓(xùn)練模式。然后使用optimizer.zero_grad()
清除
以上就是利用Pytorch實現(xiàn)ResNet網(wǎng)絡(luò)構(gòu)建及模型訓(xùn)練的詳細內(nèi)容,更多關(guān)于Pytorch ResNet構(gòu)建網(wǎng)絡(luò)模型訓(xùn)練的資料請關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
淺談flask截獲所有訪問及before/after_request修飾器
這篇文章主要介紹了淺談flask截獲所有訪問及before/after_request修飾器,具有一定借鑒價值,需要的朋友可以參考下2018-01-01詳解Python是如何實現(xiàn)issubclass的
這篇文章主要介紹了詳解Python是如何實現(xiàn)issubclass的,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-07-07Jupyter notebook之如何快速打開ipynb文件
這篇文章主要介紹了Jupyter notebook之如何快速打開ipynb文件問題,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教2023-09-09Python基礎(chǔ)教程,Python入門教程(超詳細)
Python由荷蘭數(shù)學(xué)和計算機科學(xué)研究學(xué)會 于1990 年代初設(shè)計,作為一門叫做ABC語言的替代品。Python語法和動態(tài)類型,以及解釋型語言的本質(zhì),使它成為多數(shù)平臺上寫腳本和快速開發(fā)應(yīng)用的編程語言2021-06-06深度理解Python中Class類、Object類、Type元類
本文主要介紹了深度理解Python中Class類、Object類、Type元類,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2023-06-06