pytorch如何定義新的自動(dòng)求導(dǎo)函數(shù)
pytorch定義新的自動(dòng)求導(dǎo)函數(shù)
在pytorch中想自定義求導(dǎo)函數(shù),通過(guò)實(shí)現(xiàn)torch.autograd.Function并重寫(xiě)forward和backward函數(shù),來(lái)定義自己的自動(dòng)求導(dǎo)運(yùn)算。參考官網(wǎng)上的demo:傳送門(mén)
直接上代碼,定義一個(gè)ReLu來(lái)實(shí)現(xiàn)自動(dòng)求導(dǎo)
import torch
class MyRelu(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
# 我們使用ctx上下文對(duì)象來(lái)緩存,以便在反向傳播中使用,ctx存儲(chǔ)時(shí)候只能存tensor
# 在正向傳播中,我們接收一個(gè)上下文對(duì)象ctx和一個(gè)包含輸入的張量input;
# 我們必須返回一個(gè)包含輸出的張量,
# input.clamp(min = 0)表示講輸入中所有值范圍規(guī)定到0到正無(wú)窮,如input=[-1,-2,3]則被轉(zhuǎn)換成input=[0,0,3]
ctx.save_for_backward(input)
# 返回幾個(gè)值,backward接受參數(shù)則包含ctx和這幾個(gè)值
return input.clamp(min = 0)
@staticmethod
def backward(ctx, grad_output):
# 把ctx中存儲(chǔ)的input張量讀取出來(lái)
input, = ctx.saved_tensors
# grad_output存放反向傳播過(guò)程中的梯度
grad_input = grad_output.clone()
# 這兒就是ReLu的規(guī)則,表示原始數(shù)據(jù)小于0,則relu為0,因此對(duì)應(yīng)索引的梯度都置為0
grad_input[input < 0] = 0
return grad_input進(jìn)行輸入數(shù)據(jù)并測(cè)試
dtype = torch.float
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 使用torch的generator定義隨機(jī)數(shù),注意產(chǎn)生的是cpu隨機(jī)數(shù)還是gpu隨機(jī)數(shù)
generator=torch.Generator(device).manual_seed(42)
# N是Batch, H is hidden dimension,
# D_in is input dimension;D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10
x = torch.randn(N, D_in, device=device, dtype=dtype,generator=generator)
y = torch.randn(N, D_out, device=device, dtype=dtype, generator=generator)
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True, generator=generator)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True, generator=generator)
learning_rate = 1e-6
for t in range(500):
relu = MyRelu.apply
# 使用函數(shù)傳入?yún)?shù)運(yùn)算
y_pred = relu(x.mm(w1)).mm(w2)
# 計(jì)算損失
loss = (y_pred - y).pow(2).sum()
if t % 100 == 99:
print(t, loss.item())
# 傳播
loss.backward()
with torch.no_grad():
w1 -= learning_rate * w1.grad
w2 -= learning_rate * w2.grad
w1.grad.zero_()
w2.grad.zero_()pytorch自動(dòng)求導(dǎo)與邏輯回歸
自動(dòng)求導(dǎo)

retain_graph設(shè)為T(mén)rue,可以進(jìn)行兩次反向傳播


邏輯回歸


import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
torch.manual_seed(10)
#========生成數(shù)據(jù)=============
sample_nums = 100
mean_value = 1.7
bias = 1
n_data = torch.ones(sample_nums,2)
x0 = torch.normal(mean_value*n_data,1)+bias#類(lèi)別0數(shù)據(jù)
y0 = torch.zeros(sample_nums)#類(lèi)別0標(biāo)簽
x1 = torch.normal(-mean_value*n_data,1)+bias#類(lèi)別1數(shù)據(jù)
y1 = torch.ones(sample_nums)#類(lèi)別1標(biāo)簽
train_x = torch.cat((x0,x1),0)
train_y = torch.cat((y0,y1),0)
#==========選擇模型===========
class LR(nn.Module):
def __init__(self):
super(LR,self).__init__()
self.features = nn.Linear(2,1)
self.sigmoid = nn.Sigmoid()
def forward(self,x):
x = self.features(x)
x = self.sigmoid(x)
return x
lr_net = LR()#實(shí)例化邏輯回歸模型
#==============選擇損失函數(shù)===============
loss_fn = nn.BCELoss()
#==============選擇優(yōu)化器=================
lr = 0.01
optimizer = torch.optim.SGD(lr_net.parameters(),lr = lr,momentum=0.9)
#===============模型訓(xùn)練==================
for iteration in range(1000):
#前向傳播
y_pred = lr_net(train_x)#模型的輸出
#計(jì)算loss
loss = loss_fn(y_pred.squeeze(),train_y)
#反向傳播
loss.backward()
#更新參數(shù)
optimizer.step()
#繪圖
if iteration % 20 == 0:
mask = y_pred.ge(0.5).float().squeeze() #以0.5分類(lèi)
correct = (mask==train_y).sum()#正確預(yù)測(cè)樣本數(shù)
acc = correct.item()/train_y.size(0)#分類(lèi)準(zhǔn)確率
plt.scatter(x0.data.numpy()[:,0],x0.data.numpy()[:,1],c='r',label='class0')
plt.scatter(x1.data.numpy()[:,0],x1.data.numpy()[:,1],c='b',label='class1')
w0,w1 = lr_net.features.weight[0]
w0,w1 = float(w0.item()),float(w1.item())
plot_b = float(lr_net.features.bias[0].item())
plot_x = np.arange(-6,6,0.1)
plot_y = (-w0*plot_x-plot_b)/w1
plt.xlim(-5,7)
plt.ylim(-7,7)
plt.plot(plot_x,plot_y)
plt.text(-5,5,'Loss=%.4f'%loss.data.numpy(),fontdict={'size':20,'color':'red'})
plt.title('Iteration:{}\nw0:{:.2f} w1:{:.2f} b{:.2f} accuracy:{:2%}'.format(iteration,w0,w1,plot_b,acc))
plt.legend()
plt.show()
plt.pause(0.5)
if acc > 0.99:
break總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
- 使用pytorch進(jìn)行張量計(jì)算、自動(dòng)求導(dǎo)和神經(jīng)網(wǎng)絡(luò)構(gòu)建功能
- 在?pytorch?中實(shí)現(xiàn)計(jì)算圖和自動(dòng)求導(dǎo)
- Pytorch自動(dòng)求導(dǎo)函數(shù)詳解流程以及與TensorFlow搭建網(wǎng)絡(luò)的對(duì)比
- 淺談Pytorch中的自動(dòng)求導(dǎo)函數(shù)backward()所需參數(shù)的含義
- pytorch中的自定義反向傳播,求導(dǎo)實(shí)例
- 關(guān)于PyTorch 自動(dòng)求導(dǎo)機(jī)制詳解
- Pytorch反向求導(dǎo)更新網(wǎng)絡(luò)參數(shù)的方法
- 關(guān)于pytorch求導(dǎo)總結(jié)(torch.autograd)
相關(guān)文章
使用Python Fast API發(fā)布API服務(wù)的過(guò)程詳解
這篇文章主要介紹了使用Python Fast API發(fā)布API服務(wù),本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2023-04-04
Python基礎(chǔ)之?dāng)?shù)據(jù)結(jié)構(gòu)詳解
這篇文章主要介紹了Python基礎(chǔ)之?dāng)?shù)據(jù)結(jié)構(gòu)詳解,文中有非常詳細(xì)的代碼示例,對(duì)正在學(xué)習(xí)python基礎(chǔ)的小伙伴們有非常好的幫助,需要的朋友可以參考下2021-04-04
PyQt5中QTimer定時(shí)器的實(shí)例代碼
如果需要在程序中周期性地進(jìn)行某項(xiàng)操作,比如檢測(cè)某種設(shè)備的狀態(tài),就會(huì)用到定時(shí)器,本文主要介紹了PyQt5中QTimer定時(shí)器的實(shí)例代碼,感興趣的可以了解一下2021-06-06
Pytorch根據(jù)layers的name凍結(jié)訓(xùn)練方式
今天小編就為大家分享一篇Pytorch根據(jù)layers的name凍結(jié)訓(xùn)練方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-01-01
使用python制作一個(gè)為hex文件增加版本號(hào)的腳本實(shí)例
今天小編就為大家分享一篇使用python制作一個(gè)為hex文件增加版本號(hào)的腳本實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-06-06
python中的json數(shù)據(jù)和pyecharts模塊入門(mén)示例教程
JSON是一種輕量級(jí)的數(shù)據(jù)交互格式??梢园凑?JSON指定的格式去組織和封裝數(shù)據(jù),這篇文章主要介紹了python中的json數(shù)據(jù)和pyecharts模塊入門(mén),需要的朋友可以參考下2022-12-12
Python抓取框架Scrapy爬蟲(chóng)入門(mén):頁(yè)面提取
Scrapy吸引人的地方在于它是一個(gè)框架,任何人都可以根據(jù)需求方便的修改,下面這篇文章主要給大家介紹了關(guān)于Python抓取框架Scrapy爬蟲(chóng)入門(mén)之頁(yè)面提取的相關(guān)資料,文中通過(guò)示例代碼介紹的非常詳細(xì),需要的朋友可以參考下。2017-12-12

