欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

詳解Pytorch 使用Pytorch擬合多項式(多項式回歸)

 更新時間:2018年05月24日 08:44:13   作者:ZhichaoDuan  
這篇文章主要介紹了詳解Pytorch 使用Pytorch擬合多項式(多項式回歸),小編覺得挺不錯的,現在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧

使用Pytorch來編寫神經網絡具有很多優(yōu)勢,比起Tensorflow,我認為Pytorch更加簡單,結構更加清晰。

希望通過實戰(zhàn)幾個Pytorch的例子,讓大家熟悉Pytorch的使用方法,包括數據集創(chuàng)建,各種網絡層結構的定義,以及前向傳播與權重更新方式。

比如這里給出

    

很顯然,這里我們只需要假定

這里我們只需要設置一個合適尺寸的全連接網絡,根據不斷迭代,求出最接近的參數即可。

但是這里需要思考一個問題,使用全連接網絡結構是毫無疑問的,但是我們的輸入與輸出格式是什么樣的呢?

只將一個x作為輸入合理嗎?顯然是不合理的,因為每一個神經元其實模擬的是wx+b的計算過程,無法模擬冪運算,所以顯然我們需要將x,x的平方,x的三次方,x的四次方組合成一個向量作為輸入,假設有n個不同的x值,我們就可以將n個組合向量合在一起組成輸入矩陣。

這一步代碼如下:

def make_features(x): 
 x = x.unsqueeze(1) 
 return torch.cat([x ** i for i in range(1,4)] , 1) 

我們需要生成一些隨機數作為網絡輸入:

def get_batch(batch_size=32): 
 random = torch.randn(batch_size) 
 x = make_features(random) 
 '''Compute the actual results''' 
 y = f(x) 
 if torch.cuda.is_available(): 
  return Variable(x).cuda(), Variable(y).cuda() 
 else: 
  return Variable(x), Variable(y) 

其中的f(x)定義如下:

w_target = torch.FloatTensor([0.5,3,2.4]).unsqueeze(1) 
b_target = torch.FloatTensor([0.9]) 
 
def f(x): 
 return x.mm(w_target)+b_target[0] 

接下來定義模型:

class poly_model(nn.Module): 
 def __init__(self): 
  super(poly_model, self).__init__() 
  self.poly = nn.Linear(3,1) 
 
 def forward(self, x): 
  out = self.poly(x) 
  return out 
if torch.cuda.is_available(): 
 model = poly_model().cuda() 
else: 
 model = poly_model() 

接下來我們定義損失函數和優(yōu)化器:

criterion = nn.MSELoss() 
optimizer = optim.SGD(model.parameters(), lr = 1e-3) 

網絡部件定義完后,開始訓練:

epoch = 0 
while True: 
 batch_x,batch_y = get_batch() 
 output = model(batch_x) 
 loss = criterion(output,batch_y) 
 print_loss = loss.data[0] 
 optimizer.zero_grad() 
 loss.backward() 
 optimizer.step() 
 epoch+=1 
 if print_loss < 1e-3: 
  break 

到此我們的所有代碼就敲完了,接下來我們開始詳細了解一下其中的一些代碼。

在make_features()定義中,torch.cat是將計算出的向量拼接成矩陣。unsqueeze是作一個維度上的變化。

get_batch中,torch.randn是產生指定維度的隨機數,如果你的機器支持GPU加速,可以將Variable放在GPU上進行運算,類似語句含義相通。

x.mm是作矩陣乘法。

模型定義是重中之重,其實當你掌握Pytorch之后,你會發(fā)現模型定義是十分簡單的,各種基本的層結構都已經為你封裝好了。所有的層結構和損失函數都來自torch.nn,所有的模型構建都是從這個基類 nn.Module繼承的。模型定義中,__init__與forward是有模板的,大家可以自己體會。

nn.Linear是做一個線性的運算,參數的含義代表了輸入層與輸出層的結構,即3*1;在訓練階段,有幾行是Pytorch不同于別的框架的,首先loss是一個Variable,通過loss.data可以取出一個Tensor,再通過data[0]可以得到一個int或者float類型的值,我們才可以進行基本運算或者顯示。每次計算梯度之前,都需要將梯度歸零,否則梯度會疊加。個人覺得別的語句還是比較好懂的,如果有疑問可以在下方評論。

下面是我們的擬合結果

其實效果肯定會很好,因為只是一個非常簡單的全連接網絡,希望大家通過這個小例子可以學到Pytorch的一些基本操作。往后我們會繼續(xù)更新,完整代碼請戳,https://github.com/ZhichaoDuan/PytorchCourse

以上就是本文的全部內容,希望對大家的學習有所幫助,也希望大家多多支持腳本之家。

相關文章

  • Python中表示字符串的三種方法

    Python中表示字符串的三種方法

    這篇文章主要介紹了Python中表示字符串的三種方法的相關資料,需要的朋友可以參考下
    2017-09-09
  • Python pathlib模塊使用方法及實例解析

    Python pathlib模塊使用方法及實例解析

    這篇文章主要介紹了Python pathlib模塊使用方法及實例解析,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友可以參考下
    2020-10-10
  • python的描述符(descriptor)、裝飾器(property)造成的一個無限遞歸問題分享

    python的描述符(descriptor)、裝飾器(property)造成的一個無限遞歸問題分享

    這篇文章主要介紹了python的描述符(descriptor)、裝飾器(property)造成的一個無限遞歸問題分享,一個不太會遇到的問題,需要的朋友可以參考下
    2014-07-07
  • Python基于百度API識別并提取圖片中文字

    Python基于百度API識別并提取圖片中文字

    本文主要實現了利用百度 AI 開發(fā)平臺的 OCR 文字識別 API 識別并提取圖片中的文字。具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2021-06-06
  • python爬蟲 正則表達式解析

    python爬蟲 正則表達式解析

    這篇文章主要介紹了python爬蟲 正則表達式解析,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友可以參考下
    2019-09-09
  • python實現自動打卡小程序

    python實現自動打卡小程序

    這篇文章主要為大家詳細介紹了python實現自動打卡小程序,文中示例代碼介紹的非常詳細,具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2021-03-03
  • python經典100題之皮球掉落的幾種解法

    python經典100題之皮球掉落的幾種解法

    這篇文章主要給大家介紹了關于python經典100題之皮球掉落的幾種解法,這個問題相信不少人都可以從網絡上找到相對應的答案本文提供了3種解法,需要的朋友可以參考下
    2023-11-11
  • 淺析python中的set類型

    淺析python中的set類型

    這篇文章主要介紹了python中的set類型,本文給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下
    2022-06-06
  • 詳解Pytorch+PyG實現GAT過程示例

    詳解Pytorch+PyG實現GAT過程示例

    這篇文章主要為大家介紹了Pytorch+PyG實現GAT過程示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪
    2023-04-04
  • python使用cookie庫操保存cookie詳解

    python使用cookie庫操保存cookie詳解

    Python中Cookie模塊(python3中為http.cookies)提供了一個類似字典的特殊對象SimpleCookie,其中存儲并管理著稱為Morsel的cookie值集合,這里介紹了python操作cookie的使用方法
    2014-03-03

最新評論