使用Pytorch實現Swish激活函數的示例詳解
前言
激活函數是人工神經網絡的基本組成部分。他們將非線性引入模型,使其能夠學習數據中的復雜關系。Swish 激活函數就是此類激活函數之一,因其獨特的屬性和相對于廣泛使用的整流線性單元 (ReLU) 激活的潛在優(yōu)勢而受到關注。在本文中,我們將深入研究 Swish 激活函數,提供數學公式,探索其相對于 ReLU 的優(yōu)勢,并使用 PyTorch 演示其實現。
Swish 激活功能
Swish 激活函數由 Google 研究人員于 2017 年推出,其數學定義如下:
Swish(x) = x * sigmoid(x)
Where:
- x:激活函數的輸入值。
- sigmoid(x):sigmoid 函數,將任何實數值映射到范圍 [0, 1]。隨著 x 的增加,它從 0 平滑過渡到 1。
Swish 激活將線性分量(輸入 x)與非線性分量(sigmoid函數)相結合,產生平滑且可微的激活函數。
在哪里使用 Swish 激活?
Swish 可用于各種神經網絡架構,包括前饋神經網絡、卷積神經網絡 (CNN) 和循環(huán)神經網絡 (RNN)。它的優(yōu)勢在深度網絡中變得尤為明顯,它可以幫助緩解梯度消失問題。
Swish 激活函數相對于 ReLU 的優(yōu)點
現在,我們來探討一下 Swish 激活函數與流行的 ReLU 激活函數相比的優(yōu)勢。
平滑度和可微分性
由于 sigmoid 分量的存在,Swish 是一個平滑且可微的函數。此屬性使其非常適合基于梯度的優(yōu)化技術,例如隨機梯度下降 (SGD) 和反向傳播。相比之下,ReLU 在零處不可微(ReLU 的導數在 x=0 時未定義),這可能會帶來優(yōu)化挑戰(zhàn)。
改進深度網絡的學習
在深度神經網絡中,與 ReLU 相比,Swish 可以實現更好的學習和收斂。Swish 的平滑性有助于梯度在網絡中更平滑地流動,減少訓練期間梯度消失的可能性。這在非常深的網絡中尤其有用。
類似的計算成本
Swish 激活的計算效率很高,類似于 ReLU。這兩個函數都涉及基本的算術運算,不會顯著增加訓練或推理過程中的計算負擔。
使用 PyTorch 實現
現在,我們來看看如何使用 PyTorch 實現 Swish 激活函數。我們將創(chuàng)建一個自定義 Swish 模塊并將其集成到一個簡單的神經網絡中。
讓我們從導入必要的庫開始。
import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F from torchvision import datasets, transforms from torch.utils.data import DataLoader
一旦我們完成了庫的導入,我們就可以定義自定義激活——Swish。
以下代碼定義了一個繼承 PyTorch 基類的類。類內部有一個forward方法。該方法定義模塊如何處理輸入數據。它將輸入張量作為參數,并在應用 Swish 激活后返回輸出張量。
# Swish功能 class Swish(nn.Module): def forward(self, x): return x * torch.sigmoid(x)
定義 Swish 類后,我們繼續(xù)定義神經網絡模型。
在下面的代碼片段中,我們使用 PyTorch 定義了一個專為圖像分類任務設計的神經網絡模型。
輸入層有28×28像素。
隱藏層
- 第一個隱藏層由 256 個神經元組成。它采用扁平輸入并應用線性變換來產生輸出。
- 第二個隱藏層由 128 個神經元組成,從前一層獲取 256 維輸出并產生 128 維輸出。
- Swish 激活函數應用于兩個隱藏層,以向網絡引入非線性。
- 輸出層由 10 個神經元組成,用于執(zhí)行 10 個類別的分類。
# 定義神經網絡模型 class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.fc1 = nn.Linear(28 * 28, 256) self.fc2 = nn.Linear(256, 128) self.fc3 = nn.Linear(128, 10) self.swish = Swish() def forward(self, x): x = x.view(-1, 28 * 28) x = self.fc1(x) x = self.swish(x) x = self.fc2(x) x = self.swish(x) x = self.fc3(x) return x
為了設置用于訓練的神經網絡,我們創(chuàng)建模型的實例,定義損失函數、優(yōu)化器和數據轉換。
# 創(chuàng)建模型的實例 model = Net() # 定義損失函數和優(yōu)化器 criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.01) # 定義數據轉換 transform = transforms.Compose([ transforms.ToTensor(), ])
完成此步驟后,我們可以繼續(xù)在數據集上訓練和評估模型。讓我們使用以下代碼加載 MNIST 數據并創(chuàng)建用于訓練的數據加載器。
# 加載MNIST數據集 train_dataset = datasets.MNIST('', train=True, download=True, transform=transform) test_dataset = datasets.MNIST('', train=False, download=True, transform=transform) # 創(chuàng)建數據加載器 train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
有了這些數據加載器,我們就可以繼續(xù)訓練循環(huán)來迭代批量的訓練和測試數據。
在下面的代碼中,我們執(zhí)行了神經網絡的訓練循環(huán)。該循環(huán)將重復 5 個時期,在此期間更新模型的權重,以最大限度地減少損失并提高其在訓練數據上的性能。
# 訓練循環(huán) num_epochs = 5 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) for epoch in range(num_epochs): model.train() total_loss = 0.0 for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) optimizer.zero_grad() output = model(data) loss = criterion(output, target) loss.backward() optimizer.step() total_loss += loss.item() print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader)}")
輸出:
Epoch 1/5, Loss: 1.6938323568503062
Epoch 2/5, Loss: 0.4569567457397779
Epoch 3/5, Loss: 0.3522500048557917
Epoch 4/5, Loss: 0.31695075702369213
Epoch 5/5, Loss: 0.2961081813474496
最后一步是模型評估步驟。
# 評估循環(huán) model.eval() correct = 0 total = 0 with torch.no_grad(): for data, target in test_loader: data, target = data.to(device), target.to(device) outputs = model(data) _, predicted = torch.max(outputs.data, 1) total += target.size(0) correct += (predicted == target).sum().item() print(f"Accuracy on test set: {100 * correct / total}%")
輸出:
Accuracy on test set: 92.02%
結論
Swish 激活函數為 ReLU 等傳統(tǒng)激活函數提供了一種有前景的替代方案。它的平滑性、可微性和改善深度網絡學習的潛力使其成為現代神經網絡架構的寶貴工具。通過在 PyTorch 中實施 Swish,您可以利用其優(yōu)勢并探索其在各種機器學習任務中的有效性。
以上就是使用Pytorch實現Swish激活函數的示例詳解的詳細內容,更多關于Pytorch Swish激活函數的資料請關注腳本之家其它相關文章!