pyTorch深入學(xué)習(xí)梯度和Linear Regression實(shí)現(xiàn)
梯度
PyTorch的數(shù)據(jù)結(jié)構(gòu)是tensor,它有個屬性叫做requires_grad,設(shè)置為True以后,就開始track在其上的所有操作,前向計(jì)算完成后,可以通過backward來進(jìn)行梯度回傳。
評估模型的時候我們并不需要梯度回傳,使用with torch.no_grad() 將不需要梯度的代碼段包裹起來。每個Tensor都有一個.grad_fn屬性,該屬性即創(chuàng)建該Tensor的Function,直接用構(gòu)造的tensor返回None,否則是生成該tensor的操作。
tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False) -> Tensor #require_grad默認(rèn)是false,下面我們將顯式的開啟 x = torch.tensor([1,2,3],requires_grad=True,dtype=torch.float)
注意只有數(shù)據(jù)類型是浮點(diǎn)型和complex類型才能require梯度,所以這里顯示指定dtype為torch.float32
x = torch.tensor([1,2,3],requires_grad=True,dtype=torch.float32) > tensor([1.,2.,3.],grad_fn=None) y = x + 2 > tensor([3.,4.,5.],grad_fn=<AddBackward0>) z = y * y * 3 > tensor([3.,4.,5.],grad_fn=<MulBackward0>)
像x這種直接創(chuàng)建的,沒有g(shù)rad_fn,被稱為葉子結(jié)點(diǎn)。grad_fn記錄了一個個基本操作用來進(jìn)行梯度計(jì)算的。
關(guān)于梯度回傳計(jì)算看下面一個例子
x = torch.ones((2,2),requires_grad=True) > tensor([[1.,1.], > [1.,1.]],requires_grad=True) y = x + 2 z = y * y * 3 out = z.mean() #out是一個標(biāo)量,無需指定求偏導(dǎo)的變量 out.backward() x.grad > tensor([[4.500,4.500], > [4.500,4.500]]) #每次計(jì)算梯度前,需要將梯度清零,否則會累加 x.grad.data.zero_()
值得注意的是只有葉子節(jié)點(diǎn)的梯度在回傳時才會被計(jì)算,也就是說,上面的例子中拿不到y(tǒng)和z的grad。
來看一個中斷求導(dǎo)的例子
x = torch.tensor(1.,requires_grad=True) y1 = x ** 2 with torch.no_grad() y2 = x ** 3 y3 = y1 + y2 y3.backward() print(x.grad) > 2
本來梯度應(yīng)該為5的,但是由于y2被with torch.no_grad()包裹,在梯度計(jì)算的時候不會被追蹤。
如果我們想要修改某個tensor的數(shù)值但是又不想被autograd記錄,那么需要使用對x.data進(jìn)行操作就行這也是一個張量。
線性回歸(linear regression)
利用線性回歸來預(yù)測一棟房屋的價格,價格取決于很多feature,這里簡化問題,假設(shè)價格只取決于兩個因素,面積(平方米)和房齡(年)
x1代表面積,x2代表房齡,售出價格為y
模擬數(shù)據(jù)集
假設(shè)我們的樣本數(shù)量為1000個,每個數(shù)據(jù)包括兩個features,則數(shù)據(jù)為1000 * 2的2-d張量,用正太分布來隨機(jī)取值。
labels是房屋的價格,長度為1000的一維張量。
真實(shí)w和b提前把值定好,然后再取一個干擾量 δ \delta δ(也用高斯分布取值,用來模擬真實(shí)數(shù)據(jù)集中的偏差)
num_features = 2#兩個特征 num_examples = 1000 #樣本個數(shù) w = torch.normal(0,1,(num_features,1)) b = torch.tensor(4.2) samples = torch.normal(0,1,(num_examples,num_features)) labels = samples.matmul(w) + b noise = torch.normal(0,.01,labels.shape) labels += noise
加載數(shù)據(jù)集
import random def data_iter(samples,labels,batch_size): num_samples = samples.shape[0] #獲得batch軸的長度 indices = [i for i in range(num_samples)] random.shuffle(indices)#將索引數(shù)組原地打亂 for i in range(0,num_samples,batch_size): j = torch.LongTensor(indices[i:min(i+batch_size,num_samples)]) yield samples.index_select(0,j),labels(0,j)
torch.index_select(dim,index)
dim表示tensor的軸,index是一個tensor,里面包含的是索引。
定義loss_function
def loss_function(predict,labels): loss = (predict,labels)** 2 / 2 return loss.mean()
定義優(yōu)化器
def loss_function(predict,labels): loss = (predict,labels)** 2 / 2 return loss.mean()
開始訓(xùn)練
w = torch.normal(0.,1.,(num_features,1),requires_grad=True) b = torch.zero(0.,dtype=torch.float32,requires_grad=True) batch_size = 100 for epoch in range(10): for data, label in data_iter(samples,labels,batch_size): predict = data.matmul(w) + b loss = loss_function(predict,label) loss.backward() optimizer([w,b],0.05) w.grad.data.zero_() b.grad.data.zero_()
以上就是pyTorch深入學(xué)習(xí)梯度和Linear Regression實(shí)現(xiàn)的詳細(xì)內(nèi)容,更多關(guān)于pyTorch實(shí)現(xiàn)梯度和Linear Regression的資料請關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
python實(shí)現(xiàn)員工管理系統(tǒng)
這篇文章主要介紹了python實(shí)現(xiàn)員工管理系統(tǒng),具有一定的參考價值,感興趣的小伙伴們可以參考一下2018-01-01解決使用Spyder IDE時matplotlib繪圖的顯示問題
這篇文章主要介紹了解決使用Spyder IDE時matplotlib繪圖的顯示問題,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2021-04-04pytorch實(shí)現(xiàn)用CNN和LSTM對文本進(jìn)行分類方式
今天小編就為大家分享一篇pytorch實(shí)現(xiàn)用CNN和LSTM對文本進(jìn)行分類方式,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-01-01python使用pygame創(chuàng)建精靈Sprite
這篇文章主要介紹了使用Pygame創(chuàng)建精靈Sprite,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2021-04-04Python基礎(chǔ)之hashlib模塊subprocess模塊logging模塊
這篇文章主要為大家介紹了Python基礎(chǔ)之hashlib模塊subprocess模塊logging模塊示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2022-11-11python爬蟲使用requests發(fā)送post請求示例詳解
這篇文章主要介紹了python爬蟲使用requests發(fā)送post請求示例詳解,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-08-08如何安裝多版本python python2和python3共存以及pip共存
這篇文章主要為大家詳細(xì)介紹了python多版本的安裝方法,解決python2和python3共存以及pip共存問題,具有一定的參考價值,感興趣的小伙伴們可以參考一下2018-09-09