欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

PyTorch學(xué)習(xí)筆記之回歸實(shí)戰(zhàn)

 更新時間:2018年05月28日 11:34:46   作者:manong_wxd  
這篇文章主要介紹了PyTorch學(xué)習(xí)筆記之回歸實(shí)戰(zhàn),小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧

本文主要是用PyTorch來實(shí)現(xiàn)一個簡單的回歸任務(wù)。

編輯器:spyder

1.引入相應(yīng)的包及生成偽數(shù)據(jù)

import torch
import torch.nn.functional as F # 主要實(shí)現(xiàn)激活函數(shù)
import matplotlib.pyplot as plt # 繪圖的工具
from torch.autograd import Variable

# 生成偽數(shù)據(jù)
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim = 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())

# 變?yōu)閂ariable
x, y = Variable(x), Variable(y)

其中torch.linspace是為了生成連續(xù)間斷的數(shù)據(jù),第一個參數(shù)表示起點(diǎn),第二個參數(shù)表示終點(diǎn),第三個參數(shù)表示將這個區(qū)間分成平均幾份,即生成幾個數(shù)據(jù)。因?yàn)閠orch只能處理二維的數(shù)據(jù),所以我們用torch.unsqueeze給偽數(shù)據(jù)添加一個維度,dim表示添加在第幾維。torch.rand返回的是[0,1)之間的均勻分布。

2.繪制數(shù)據(jù)圖像

在上述代碼后面加下面的代碼,然后運(yùn)行可得偽數(shù)據(jù)的圖形化表示:

# 繪制數(shù)據(jù)圖像
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

3.建立神經(jīng)網(wǎng)絡(luò)

class Net(torch.nn.Module):
 def __init__(self, n_feature, n_hidden, n_output):
  super(Net, self).__init__()
  self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
  self.predict = torch.nn.Linear(n_hidden, n_output) # output layer

 def forward(self, x):
  x = F.relu(self.hidden(x))  # activation function for hidden layer
  x = self.predict(x)    # linear output
  return x

net = Net(n_feature=1, n_hidden=10, n_output=1)  # define the network
print(net) # net architecture

一般神經(jīng)網(wǎng)絡(luò)的類都繼承自torch.nn.Module__init__()和forward()兩個函數(shù)是自定義類的主要函數(shù)。在__init__()中都要添加一句super(Net, self).__init__(),這是固定的標(biāo)準(zhǔn)寫法,用于繼承父類的初始化函數(shù)。__init__()中只是對神經(jīng)網(wǎng)絡(luò)的模塊進(jìn)行了聲明,真正的搭建是在forwad()中實(shí)現(xiàn)。自定義類中的成員都通過self指針來進(jìn)行訪問,所以參數(shù)列表中都包含了self。

如果想查看網(wǎng)絡(luò)結(jié)構(gòu),可以用print()函數(shù)直接打印網(wǎng)絡(luò)。本文的網(wǎng)絡(luò)結(jié)構(gòu)輸出如下:

Net (
 (hidden): Linear (1 -> 10)
 (predict): Linear (10 -> 1)
)

4.訓(xùn)練網(wǎng)絡(luò)

# 訓(xùn)練100次
for t in range(100):
 prediction = net(x)  # input x and predict based on x

 loss = loss_func(prediction, y)  # 一定要是輸出在前,標(biāo)簽在后 (1. nn output, 2. target)

 optimizer.zero_grad() # clear gradients for next train
 loss.backward()   # backpropagation, compute gradients
 optimizer.step()  # apply gradients

訓(xùn)練網(wǎng)絡(luò)之前我們需要先定義優(yōu)化器和損失函數(shù)。torch.optim包中包括了各種優(yōu)化器,這里我們選用最常見的SGD作為優(yōu)化器。因?yàn)槲覀円獙W(wǎng)絡(luò)的參數(shù)進(jìn)行優(yōu)化,所以我們要把網(wǎng)絡(luò)的參數(shù)net.parameters()傳入優(yōu)化器中,并設(shè)置學(xué)習(xí)率(一般小于1)。

由于這里是回歸任務(wù),我們選擇torch.nn.MSELoss()作為損失函數(shù)。

由于優(yōu)化器是基于梯度來優(yōu)化參數(shù)的,并且梯度會保存在其中。所以在每次優(yōu)化前要通過optimizer.zero_grad()把梯度置零,然后再后向傳播及更新。

5.可視化訓(xùn)練過程

plt.ion() # something about plotting

for t in range(100):
 ...

 if t % 5 == 0:
  # plot and show learning process
  plt.cla()
  plt.scatter(x.data.numpy(), y.data.numpy())
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
  plt.text(0.5, 0, 'Loss=%.4f' % loss.data[0], fontdict={'size': 20, 'color': 'red'})
  plt.pause(0.1)

plt.ioff()
plt.show()

6.運(yùn)行結(jié)果


以上就是本文的全部內(nèi)容,希望對大家的學(xué)習(xí)有所幫助,也希望大家多多支持腳本之家。

相關(guān)文章

最新評論