pytorch之添加BN的實(shí)現(xiàn)
pytorch之添加BN層
批標(biāo)準(zhǔn)化
模型訓(xùn)練并不容易,特別是一些非常復(fù)雜的模型,并不能非常好的訓(xùn)練得到收斂的結(jié)果,所以對(duì)數(shù)據(jù)增加一些預(yù)處理,同時(shí)使用批標(biāo)準(zhǔn)化能夠得到非常好的收斂結(jié)果,這也是卷積網(wǎng)絡(luò)能夠訓(xùn)練到非常深的層的一個(gè)重要原因。
數(shù)據(jù)預(yù)處理
目前數(shù)據(jù)預(yù)處理最常見(jiàn)的方法就是中心化和標(biāo)準(zhǔn)化,中心化相當(dāng)于修正數(shù)據(jù)的中心位置,實(shí)現(xiàn)方法非常簡(jiǎn)單,就是在每個(gè)特征維度上減去對(duì)應(yīng)的均值,最后得到 0 均值的特征。標(biāo)準(zhǔn)化也非常簡(jiǎn)單,在數(shù)據(jù)變成 0 均值之后,為了使得不同的特征維度有著相同的規(guī)模,可以除以標(biāo)準(zhǔn)差近似為一個(gè)標(biāo)準(zhǔn)正態(tài)分布,也可以依據(jù)最大值和最小值將其轉(zhuǎn)化為 -1 ~ 1之間,這兩種方法非常的常見(jiàn),如果你還記得,前面我們?cè)谏窠?jīng)網(wǎng)絡(luò)的部分就已經(jīng)使用了這個(gè)方法實(shí)現(xiàn)了數(shù)據(jù)標(biāo)準(zhǔn)化,至于另外一些方法,比如 PCA 或者 白噪聲已經(jīng)用得非常少了。
Batch Normalization
前面在數(shù)據(jù)預(yù)處理的時(shí)候,盡量輸入特征不相關(guān)且滿足一個(gè)標(biāo)準(zhǔn)的正態(tài)分布,
這樣模型的表現(xiàn)一般也較好。但是對(duì)于很深的網(wǎng)路結(jié)構(gòu),網(wǎng)路的非線性層會(huì)使得輸出的結(jié)果變得相關(guān),且不再滿足一個(gè)標(biāo)準(zhǔn)的 N(0, 1) 的分布,甚至輸出的中心已經(jīng)發(fā)生了偏移,這對(duì)于模型的訓(xùn)練,特別是深層的模型訓(xùn)練非常的困難。
所以在 2015 年一篇論文提出了這個(gè)方法,批標(biāo)準(zhǔn)化,簡(jiǎn)而言之,就是對(duì)于每一層網(wǎng)絡(luò)的輸出,對(duì)其做一個(gè)歸一化,使其服從標(biāo)準(zhǔn)的正態(tài)分布,這樣后一層網(wǎng)絡(luò)的輸入也是一個(gè)標(biāo)準(zhǔn)的正態(tài)分布,所以能夠比較好的進(jìn)行訓(xùn)練,加快收斂速度。batch normalization 的實(shí)現(xiàn)非常簡(jiǎn)單,對(duì)于給定的一個(gè) batch 的數(shù)據(jù)算法的公式如下
第一行和第二行是計(jì)算出一個(gè) batch 中數(shù)據(jù)的均值和方差,接著使用第三個(gè)公式對(duì) batch 中的每個(gè)數(shù)據(jù)點(diǎn)做標(biāo)準(zhǔn)化,ϵ是為了計(jì)算穩(wěn)定引入的一個(gè)小的常數(shù),通常取 ,最后利用權(quán)重修正得到最后的輸出結(jié)果,非常的簡(jiǎn)單,
實(shí)現(xiàn)一下簡(jiǎn)單的一維的情況,也就是神經(jīng)網(wǎng)絡(luò)中的情況
import sys sys.path.append('..') import torch def simple_batch_norm_1d(x, gamma, beta): eps = 1e-5 x_mean = torch.mean(x, dim=0, keepdim=True) # 保留維度進(jìn)行 broadcast x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True) x_hat = (x - x_mean) / torch.sqrt(x_var + eps) return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean) x = torch.arange(15).view(5, 3) gamma = torch.ones(x.shape[1]) beta = torch.zeros(x.shape[1]) print('before bn: ') print(x) y = simple_batch_norm_1d(x, gamma, beta) print('after bn: ') print(y)
可以看到這里一共是 5 個(gè)數(shù)據(jù)點(diǎn),三個(gè)特征,每一列表示一個(gè)特征的不同數(shù)據(jù)點(diǎn),使用批標(biāo)準(zhǔn)化之后,每一列都變成了標(biāo)準(zhǔn)的正態(tài)分布這個(gè)時(shí)候會(huì)出現(xiàn)一個(gè)問(wèn)題,就是測(cè)試的時(shí)候該使用批標(biāo)準(zhǔn)化嗎?答案是肯定的,因?yàn)橛?xùn)練的時(shí)候使用了,而測(cè)試的時(shí)候不使用肯定會(huì)導(dǎo)致結(jié)果出現(xiàn)偏差,但是測(cè)試的時(shí)候如果只有一個(gè)數(shù)據(jù)集,那么均值不就是這個(gè)值,方差為 0 嗎?這顯然是隨機(jī)的,所以測(cè)試的時(shí)候不能用測(cè)試的數(shù)據(jù)集去算均值和方差,而是用訓(xùn)練的時(shí)候算出的移動(dòng)平均均值和方差去代替
實(shí)現(xiàn)以下能夠區(qū)分訓(xùn)練狀態(tài)和測(cè)試狀態(tài)的批標(biāo)準(zhǔn)化方法
def batch_norm_1d(x, gamma, beta, is_training, moving_mean, moving_var, moving_momentum=0.1): eps = 1e-5 x_mean = torch.mean(x, dim=0, keepdim=True) # 保留維度進(jìn)行 broadcast x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True) if is_training: x_hat = (x - x_mean) / torch.sqrt(x_var + eps) moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean moving_var[:] = moving_momentum * moving_var + (1. - moving_momentum) * x_var else: x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps) return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)
下面使用深度神經(jīng)網(wǎng)絡(luò)分類 mnist 數(shù)據(jù)集的例子來(lái)試驗(yàn)一下批標(biāo)準(zhǔn)化是否有用
import numpy as np from torchvision.datasets import mnist # 導(dǎo)入 pytorch 內(nèi)置的 mnist 數(shù)據(jù) from torch.utils.data import DataLoader from torch import nn from torch.autograd import Variable
使用內(nèi)置函數(shù)下載 mnist 數(shù)據(jù)集
train_set = mnist.MNIST('./data', train=True) test_set = mnist.MNIST('./data', train=False) def data_tf(x): x = np.array(x, dtype='float32') / 255 x = (x - 0.5) / 0.5 # 數(shù)據(jù)預(yù)處理,標(biāo)準(zhǔn)化 x = x.reshape((-1,)) # 拉平 x = torch.from_numpy(x) return x train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True) # 重新載入數(shù)據(jù)集,申明定義的數(shù)據(jù)變換 test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True) train_data = DataLoader(train_set, batch_size=64, shuffle=True) test_data = DataLoader(test_set, batch_size=128, shuffle=False) class multi_network(nn.Module): def __init__(self): super(multi_network, self).__init__() self.layer1 = nn.Linear(784, 100) self.relu = nn.ReLU(True) self.layer2 = nn.Linear(100, 10) self.gamma = nn.Parameter(torch.randn(100)) self.beta = nn.Parameter(torch.randn(100)) self.moving_mean = Variable(torch.zeros(100)) self.moving_var = Variable(torch.zeros(100)) def forward(self, x, is_train=True): x = self.layer1(x) x = batch_norm_1d(x, self.gamma, self.beta, is_train, self.moving_mean, self.moving_var) x = self.relu(x) x = self.layer2(x) return x net = multi_network() # 定義 loss 函數(shù) criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(net.parameters(), 1e-1) # 使用隨機(jī)梯度下降,學(xué)習(xí)率 0.1 from datetime import datetime import torch import torch.nn.functional as F from torch import nn from torch.autograd import Variable def get_acc(output, label): total = output.shape[0] _, pred_label = output.max(1) num_correct = (pred_label == label).sum().item() return num_correct / total #定義訓(xùn)練函數(shù) def train(net, train_data, valid_data, num_epochs, optimizer, criterion): if torch.cuda.is_available(): net = net.cuda() prev_time = datetime.now() for epoch in range(num_epochs): train_loss = 0 train_acc = 0 net = net.train() for im, label in train_data: if torch.cuda.is_available(): im = Variable(im.cuda()) # (bs, 3, h, w) label = Variable(label.cuda()) # (bs, h, w) else: im = Variable(im) label = Variable(label) # forward output = net(im) loss = criterion(output, label) # backward optimizer.zero_grad() loss.backward() optimizer.step() train_loss += loss.item() train_acc += get_acc(output, label) cur_time = datetime.now() h, remainder = divmod((cur_time - prev_time).seconds, 3600) m, s = divmod(remainder, 60) time_str = "Time %02d:%02d:%02d" % (h, m, s) if valid_data is not None: valid_loss = 0 valid_acc = 0 net = net.eval() for im, label in valid_data: if torch.cuda.is_available(): im = Variable(im.cuda(), volatile=True) label = Variable(label.cuda(), volatile=True) else: im = Variable(im, volatile=True) label = Variable(label, volatile=True) output = net(im) loss = criterion(output, label) valid_loss += loss.item() valid_acc += get_acc(output, label) epoch_str = ( "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, " % (epoch, train_loss / len(train_data), train_acc / len(train_data), valid_loss / len(valid_data), valid_acc / len(valid_data))) else: epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " % (epoch, train_loss / len(train_data), train_acc / len(train_data))) prev_time = cur_time print(epoch_str + time_str) train(net, train_data, test_data, 10, optimizer, criterion)
#這里的 γ和 都作為參數(shù)進(jìn)行訓(xùn)練,初始化為隨機(jī)的高斯分布,
#moving_mean 和 moving_var 都初始化為 0,并不是更新的參數(shù),訓(xùn)練完 10 次之后,我們可以看看移動(dòng)平均和移動(dòng)方差被修改為了多少
#打出 moving_mean 的前 10 項(xiàng)
print(net.moving_mean[:10]) no_bn_net = nn.Sequential( nn.Linear(784, 100), nn.ReLU(True), nn.Linear(100, 10) ) optimizer = torch.optim.SGD(no_bn_net.parameters(), 1e-1) # 使用隨機(jī)梯度下降,學(xué)習(xí)率 0.1 train(no_bn_net, train_data, test_data, 10, optimizer, criterion)
可以看到雖然最后的結(jié)果兩種情況一樣,但是如果我們看前幾次的情況,可以看到使用批標(biāo)準(zhǔn)化的情況能夠更快的收斂,因?yàn)檫@只是一個(gè)小網(wǎng)絡(luò),所以用不用批標(biāo)準(zhǔn)化都能夠收斂,但是對(duì)于更加深的網(wǎng)絡(luò),使用批標(biāo)準(zhǔn)化在訓(xùn)練的時(shí)候能夠很快地收斂從上面可以看到,我們自己實(shí)現(xiàn)了 2 維情況的批標(biāo)準(zhǔn)化,對(duì)應(yīng)于卷積的 4 維情況的標(biāo)準(zhǔn)化是類似的,只需要沿著通道的維度進(jìn)行均值和方差的計(jì)算,但是我們自己實(shí)現(xiàn)批標(biāo)準(zhǔn)化是很累的,pytorch 當(dāng)然也為我們內(nèi)置了批標(biāo)準(zhǔn)化的函數(shù),一維和二維分別是 torch.nn.BatchNorm1d() 和 torch.nn.BatchNorm2d(),不同于我們的實(shí)現(xiàn),pytorch 不僅將 和 β作為訓(xùn)練的參數(shù),也將 moving_mean 和 moving_var 也作為參數(shù)進(jìn)行訓(xùn)練
下面在卷積網(wǎng)絡(luò)下試用一下批標(biāo)準(zhǔn)化看看效果
def data_tf(x): x = np.array(x, dtype='float32') / 255 x = (x - 0.5) / 0.5 # 數(shù)據(jù)預(yù)處理,標(biāo)準(zhǔn)化 x = torch.from_numpy(x) x = x.unsqueeze(0) return x train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True) # 重新載入數(shù)據(jù)集,申明定義的數(shù)據(jù)變換 test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True) train_data = DataLoader(train_set, batch_size=64, shuffle=True) test_data = DataLoader(test_set, batch_size=128, shuffle=False)
使用批標(biāo)準(zhǔn)化
class conv_bn_net(nn.Module): def __init__(self): super(conv_bn_net, self).__init__() self.stage1 = nn.Sequential( nn.Conv2d(1, 6, 3, padding=1), nn.BatchNorm2d(6), nn.ReLU(True), nn.MaxPool2d(2, 2), nn.Conv2d(6, 16, 5), nn.BatchNorm2d(16), nn.ReLU(True), nn.MaxPool2d(2, 2) ) self.classfy = nn.Linear(400, 10) def forward(self, x): x = self.stage1(x) x = x.view(x.shape[0], -1) x = self.classfy(x) return x net = conv_bn_net() optimizer = torch.optim.SGD(net.parameters(), 1e-1) # 使用隨機(jī)梯度下降,學(xué)習(xí)率 0.1 train(net, train_data, test_data, 5, optimizer, criterion)
不使用批標(biāo)準(zhǔn)化
class conv_no_bn_net(nn.Module): def __init__(self): super(conv_no_bn_net, self).__init__() self.stage1 = nn.Sequential( nn.Conv2d(1, 6, 3, padding=1), nn.ReLU(True), nn.MaxPool2d(2, 2), nn.Conv2d(6, 16, 5), nn.ReLU(True), nn.MaxPool2d(2, 2) ) self.classfy = nn.Linear(400, 10) def forward(self, x): x = self.stage1(x) x = x.view(x.shape[0], -1) x = self.classfy(x) return x net = conv_no_bn_net() optimizer = torch.optim.SGD(net.parameters(), 1e-1) # 使用隨機(jī)梯度下降,學(xué)習(xí)率 0.1 train(net, train_data, test_data, 5, optimizer, criterion)
以上這篇pytorch之添加BN的實(shí)現(xiàn)就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
numpy求矩陣的特征值與特征向量(np.linalg.eig函數(shù)用法)
這篇文章主要介紹了numpy求矩陣的特征值與特征向量(np.linalg.eig函數(shù)用法),具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-02-02Python實(shí)現(xiàn)爬取騰訊招聘網(wǎng)崗位信息
這篇文章主要介紹了如何用python爬取騰訊招聘網(wǎng)崗位信息保存到表格,并做成簡(jiǎn)單可視化。文中的示例代碼對(duì)學(xué)習(xí)Python有一定的幫助,感興趣的可以了解一下2022-01-01Python基于callable函數(shù)檢測(cè)對(duì)象是否可被調(diào)用
這篇文章主要介紹了Python基于callable函數(shù)檢測(cè)對(duì)象是否可被調(diào)用,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-10-10django多對(duì)多表的創(chuàng)建,級(jí)聯(lián)刪除及手動(dòng)創(chuàng)建第三張表
這篇文章主要介紹了django多對(duì)多表的創(chuàng)建,級(jí)聯(lián)刪除及手動(dòng)創(chuàng)建第三張表,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-07-07Python?FastApi結(jié)合異步執(zhí)行方式
這篇文章主要介紹了Python?FastApi結(jié)合異步執(zhí)行方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2024-06-06