pytorch基礎(chǔ)之損失函數(shù)與反向傳播詳解
1 損失函數(shù)
1.1 Loss Function的作用
- 每次訓(xùn)練神經(jīng)網(wǎng)絡(luò)的時(shí)候都會(huì)有一個(gè)目標(biāo),也會(huì)有一個(gè)輸出。目標(biāo)和輸出之間的誤差,就是用Loss Function來(lái)衡量的。所以Loss誤差是越小越好的。
- 此外,我們可以根據(jù)誤差Loss,指導(dǎo)輸出output接近目標(biāo)target。即我們可以以Loss為依據(jù),不斷訓(xùn)練神經(jīng)網(wǎng)絡(luò),優(yōu)化神經(jīng)網(wǎng)絡(luò)中各個(gè)模塊,從而優(yōu)化output 。
Loss Function的作用:
(1)計(jì)算實(shí)際輸出和目標(biāo)之間的差距
(2)為我們更新輸出提供一定的依據(jù),這個(gè)提供依據(jù)的過程也叫反向傳播。
我們可以看下pytorch為我們提供的損失函數(shù):https://pytorch.org/docs/stable/nn.html#loss-functions
1.2 損失函數(shù)簡(jiǎn)單示例
以L1Loss損失函數(shù)為例子,他其實(shí)很簡(jiǎn)單,就是把實(shí)際值與目標(biāo)值,挨個(gè)相減,再求個(gè)均值。就是結(jié)果。(這個(gè)結(jié)果就反映了實(shí)際值的好壞程度,這個(gè)結(jié)果越小,說(shuō)明越靠近目標(biāo)值)
示例代碼
import torch from torch.nn import L1Loss inputs = torch.tensor([1,2,3],dtype=torch.float32) # 實(shí)際值 targets = torch.tensor([1,2,5],dtype=torch.float32) # 目標(biāo)值 loss = L1Loss() result = loss(inputs,targets) print(result)
輸出結(jié)果:tensor(0.6667)
接下來(lái)我們看下兩個(gè)常用的損失函數(shù):均方差和交叉熵誤差
1.3 均方差
均方差:實(shí)際值與目標(biāo)值對(duì)應(yīng)做差,再平方,再求和,再求均值。
那么套用剛才的例子就是:(0+0+2^2)/3=4/3=1.33333…
代碼實(shí)現(xiàn)
import torch from torch.nn import L1Loss, MSELoss inputs = torch.tensor([1,2,3],dtype=torch.float32) # 實(shí)際值 targets = torch.tensor([1,2,5],dtype=torch.float32) # 目標(biāo)值 loss_mse = MSELoss() result = loss_mse(inputs,targets) print(result)
輸出結(jié)果:tensor(1.3333)
1.4 交叉熵誤差:
這個(gè)比較復(fù)雜一點(diǎn),首先我們看官方文檔給出的公式
這里先用代碼實(shí)現(xiàn)一下他的簡(jiǎn)單用法:
import torch from torch.nn import L1Loss, MSELoss, CrossEntropyLoss x = torch.tensor([0.1,0.2,0.3]) # 預(yù)測(cè)出三個(gè)類別的概率值 y = torch.tensor([1]) # 目標(biāo)值 應(yīng)該是這三類中的第二類 也就是下標(biāo)為1(從0開始的) x = torch.reshape(x,(1,3)) # 修改格式 交叉熵函數(shù)的要求格式是 (N,C) N是bitch_size C是類別 # print(x.shape) loss_cross = CrossEntropyLoss() result = loss_cross(x,y) print(result)
輸出結(jié)果:tensor(1.1019)
1.5 如何在神經(jīng)網(wǎng)絡(luò)中用到Loss Function
# -*- coding: utf-8 -*- # 作者:小土堆 # 公眾號(hào):土堆碎念 import torchvision from torch import nn from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear from torch.utils.data import DataLoader dataset = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(), download=True) dataloader = DataLoader(dataset, batch_size=1) class Tudui(nn.Module): def __init__(self): super(Tudui, self).__init__() self.model1 = Sequential( Conv2d(3, 32, 5, padding=2), MaxPool2d(2), Conv2d(32, 32, 5, padding=2), MaxPool2d(2), Conv2d(32, 64, 5, padding=2), MaxPool2d(2), Flatten(), Linear(1024, 64), Linear(64, 10) ) def forward(self, x): x = self.model1(x) return x loss = nn.CrossEntropyLoss() tudui = Tudui() for data in dataloader: imgs, targets = data outputs = tudui(imgs) result_loss = loss(outputs, targets) print(result_loss)
2 反向傳播
所謂的反向傳播,就是利用我們得到的loss值,來(lái)對(duì)我們神經(jīng)網(wǎng)絡(luò)中的一些參數(shù)做調(diào)整,以達(dá)到loss值降低的目的。(圖片經(jīng)過一層一層網(wǎng)絡(luò)的處理,最終得到結(jié)果,這是正向傳播。最終結(jié)果與期望值運(yùn)算得到loss,用loss反過來(lái)調(diào)整參數(shù),叫做反向傳播。個(gè)人理解,不一定嚴(yán)謹(jǐn)?。?/p>
2.1 backward
這里利用loss來(lái)調(diào)整參數(shù),主要使用的方法是梯度下降法。
這個(gè)方法原理其實(shí)還是有點(diǎn)復(fù)雜的,但是pytorch為我們實(shí)現(xiàn)好了,所以用起來(lái)很簡(jiǎn)單。
調(diào)用損失函數(shù)得到的值的backward函數(shù)即可。
loss = CrossEntropyLoss() # 定義loss函數(shù) # 實(shí)例化這個(gè)網(wǎng)絡(luò) test = Network() for data in dataloader: imgs,targets = data outputs = test(imgs) # 輸入圖片 result_loss = loss(outputs,targets) result_loss.backward() # 反向傳播 print('ok')
打斷點(diǎn)調(diào)試,可以看到,grad屬性被賦予了一些值。如果不用反向傳播,是沒有值的
當(dāng)然,計(jì)算出這個(gè)grad值只是梯度下降法的第一步,算出了梯度,如何下降呢,要靠?jī)?yōu)化器
2.2 optimizer
優(yōu)化器也有好幾種,官網(wǎng)對(duì)優(yōu)化器的介紹:https://pytorch.org/docs/stable/optim.html
不同的優(yōu)化器需要設(shè)置的參數(shù)不同,但是有兩個(gè)是大部分都有的:模型參數(shù)與學(xué)習(xí)速率
我們以SDG優(yōu)化器為例,看下用法:
# 實(shí)例化這個(gè)網(wǎng)絡(luò) test = Network() loss = CrossEntropyLoss() # 定義loss函數(shù) # 構(gòu)造優(yōu)化器 # 這里我們選擇的優(yōu)化器是SGD 傳入兩個(gè)參數(shù) 第一個(gè)是個(gè)模型test的參數(shù) 第二個(gè)是學(xué)習(xí)率 optim = torch.optim.SGD(test.parameters(),lr=0.01) for data in dataloader: imgs,targets = data outputs = test(imgs) # 輸入圖片 result_loss = loss(outputs,targets) # 計(jì)算loss optim.zero_grad() #因?yàn)檫@是在循環(huán)里面 所以每次開始優(yōu)化之前要把梯度置為0 防止上一次的結(jié)果影響這一次 result_loss.backward() # 反向傳播 求得梯度 optim.step() # 對(duì)參數(shù)進(jìn)行調(diào)優(yōu)
這里面我們剛學(xué)得主要是這三行:
清零,反向傳播求梯度,調(diào)優(yōu)
optim.zero_grad() #因?yàn)檫@是在循環(huán)里面 所以每次開始優(yōu)化之前要把梯度置為0 防止上一次的結(jié)果影響這一次 result_loss.backward() # 反向傳播 求得梯度 optim.step() # 對(duì)參數(shù)進(jìn)行調(diào)優(yōu)
我們可以打印一下loss,看下調(diào)優(yōu)后得loss有什么變化。
注意:我們dataloader是把數(shù)據(jù)拿出來(lái)一遍,那么看了一遍之后,經(jīng)過這一遍的調(diào)整,下一遍再看的時(shí)候,loss才有變化。
所以,我們先讓讓他學(xué)習(xí)20輪,然后看一下每一輪的loss是多少
# 實(shí)例化這個(gè)網(wǎng)絡(luò) test = Network() loss = CrossEntropyLoss() # 定義loss函數(shù) # 構(gòu)造優(yōu)化器 # 這里我們選擇的優(yōu)化器是SGD 傳入兩個(gè)參數(shù) 第一個(gè)是個(gè)模型test的參數(shù) 第二個(gè)是學(xué)習(xí)率 optim = torch.optim.SGD(test.parameters(),lr=0.01) for epoch in range(20): running_loss = 0.0 for data in dataloader: imgs,targets = data outputs = test(imgs) # 輸入圖片 result_loss = loss(outputs,targets) # 計(jì)算loss optim.zero_grad() #因?yàn)檫@是在循環(huán)里面 所以每次開始優(yōu)化之前要把梯度置為0 防止上一次的結(jié)果影響這一次 result_loss.backward() # 反向傳播 求得梯度 optim.step() # 對(duì)參數(shù)進(jìn)行調(diào)優(yōu) running_loss = running_loss + result_loss # 記錄下這一輪中每個(gè)loss的值之和 print(running_loss) # 打印每一輪的loss值之和
可以看到,loss之和一次比一次降低了。
總結(jié)
具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教
相關(guān)文章
零基礎(chǔ)小白多久能學(xué)會(huì)python
在本篇文章里小編給大家分享的是一篇關(guān)于零基礎(chǔ)學(xué)python要多久的相關(guān)文章內(nèi)容,有興趣的朋友們可以跟著學(xué)習(xí)下。2020-06-06Ubuntu 20.04安裝Pycharm2020.2及鎖定到任務(wù)欄的問題(小白級(jí)操作)
這篇文章主要介紹了Ubuntu 20.04安裝Pycharm2020.2及鎖定到任務(wù)欄的問題,本教程給大家講解的很詳細(xì),非常適合小白級(jí)操作,需要的朋友可以參考下2020-10-10基于Django?websocket實(shí)現(xiàn)視頻畫面的實(shí)時(shí)傳輸功能(最新推薦)
Django?Channels?是一個(gè)用于在?Django框架中實(shí)現(xiàn)實(shí)時(shí)、異步通信的擴(kuò)展庫(kù),本文給大家介紹基于Django?websocket實(shí)現(xiàn)視頻畫面的實(shí)時(shí)傳輸案例,本案例是基于B/S架構(gòu)的視頻監(jiān)控畫面的實(shí)時(shí)傳輸,使用django作為服務(wù)端的開發(fā)框架,需要的朋友可以參考下2023-06-06flask操作數(shù)據(jù)庫(kù)相關(guān)配置及實(shí)現(xiàn)示例步驟全解
這篇文章主要介紹了flask操作數(shù)據(jù)庫(kù)相關(guān)配置及實(shí)現(xiàn)示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2024-01-01Python中猜拳游戲與猜篩子游戲的實(shí)現(xiàn)方法
這篇文章主要給大家介紹了關(guān)于Python中猜拳游戲與猜篩子游戲的相關(guān)資料,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-09-09Python PyMySQL操作MySQL數(shù)據(jù)庫(kù)的方法詳解
PyMySQL是一個(gè)用于Python編程語(yǔ)言的純Python MySQL客戶端庫(kù),它遵循Python標(biāo)準(zhǔn)DB API接口,并提供了許多方便的功能,本文就來(lái)和大家簡(jiǎn)單介紹一下吧2023-05-05教你用python實(shí)現(xiàn)一個(gè)無(wú)界面的小型圖書管理系統(tǒng)
今天帶大家學(xué)習(xí)怎么用python實(shí)現(xiàn)一個(gè)無(wú)界面的小型圖書管理系統(tǒng),文中有非常詳細(xì)的圖文解說(shuō)及代碼示例,對(duì)正在學(xué)習(xí)python的小伙伴們有很好地幫助,需要的朋友可以參考下2021-05-05Python matplotlib畫圖與中文設(shè)置操作實(shí)例分析
這篇文章主要介紹了Python matplotlib畫圖與中文設(shè)置操作,結(jié)合實(shí)例形式分析了Python使用matplotlib進(jìn)行圖形繪制及中文設(shè)置相關(guān)操作技巧,需要的朋友可以參考下2019-04-04python切片復(fù)制列表的知識(shí)點(diǎn)詳解
在本篇文章里小編給大家整理的是一篇關(guān)于python切片復(fù)制列表的知識(shí)點(diǎn)相關(guān)內(nèi)容,有興趣的朋友們可以跟著學(xué)習(xí)下。2021-10-10Python實(shí)現(xiàn)將doc轉(zhuǎn)化pdf格式文檔的方法
這篇文章主要介紹了Python實(shí)現(xiàn)將doc轉(zhuǎn)化pdf格式文檔的方法,結(jié)合實(shí)例形式分析了Python實(shí)現(xiàn)doc格式文件讀取及轉(zhuǎn)換pdf格式文件的操作技巧,以及php調(diào)用py文件的具體實(shí)現(xiàn)方法,需要的朋友可以參考下2018-01-01