pytorch-RNN進行回歸曲線預測方式
更新時間:2020年01月14日 10:14:24 作者:馬飛飛
今天小編就為大家分享一篇pytorch-RNN進行回歸曲線預測方式,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
任務
通過輸入的sin曲線與預測出對應的cos曲線
#初始加載包 和定義參數(shù) import torch from torch import nn import numpy as np import matplotlib.pyplot as plt torch.manual_seed(1) #為了可復現(xiàn) #超參數(shù)設定 TIME_SETP=10 INPUT_SIZE=1 LR=0.02 DOWNLoad_MNIST=True
定義RNN網絡結構
from torch.autograd import Variable class RNN(nn.Module): def __init__(self): #在這個函數(shù)中,兩步走,先init,再逐步定義層結構 super(RNN,self).__init__() self.rnn=nn.RNN( #定義32隱層的rnn結構 input_size=1, hidden_size=32, #隱層有32個記憶體 num_layers=1, #隱層層數(shù)是1 batch_first=True ) self.out=nn.Linear(32,1) #32個記憶體對應一個輸出 def forward(self,x,h_state): #前向過程,獲取 rnn網絡輸出r_put(注意這里r_out并不是最后輸出,最后要經過全連接層) 和 記憶體情況h_state r_out,h_state=self.rnn(x,h_state) outs=[]#獲取所有時間點下得到的預測值 for time_step in range(r_out.size(1)): #將記憶rnn層的輸出傳到全連接層來得到最終輸出。 這樣每個輸入對應一個輸出,所以會有長度為10的輸出 outs.append(self.out(r_out[:,time_step,:])) return torch.stack(outs,dim=1),h_state #將10個數(shù) 通過stack方式壓縮在一起 rnn=RNN() print('RNN的網絡體系結構為:',rnn)
創(chuàng)建數(shù)據(jù)集及網絡訓練
以sin曲線為特征,以cos曲線為標簽進行網絡的訓練
#定義優(yōu)化器和 損失函數(shù) optimizer=torch.optim.Adam(rnn.parameters(),lr=LR) loss_fun=nn.MSELoss() h_state=None #記錄的隱藏層狀態(tài),記住這就是記憶體,初始時候為空,之后每次后面的都會使用到前面的記憶,自動生成全0的 #這樣加入記憶信息后,每次都會在之前的記憶矩陣基礎上再進行新的訓練,初始是全0的形式。 #啟動訓練,這里假定訓練的批次為100次 plt.ion() #可以設定持續(xù)不斷的繪圖,但是在這里看還是間斷的,這是jupyter的問題 for step in range(100): #我們以一個π為一個時間步 定義數(shù)據(jù), start,end=step*np.pi,(step+1)*np.pi steps=np.linspace(start,end,10,dtype=np.float32) #注意這里的10并不是間隔為10,而是將數(shù)按范圍分成10等分了 x_np=np.sin(steps) y_np=np.cos(steps) #將numpy類型轉成torch類型 *****當需要 求梯度時,一個 op 的兩個輸入都必須是要 Variable,輸入的一定要variable包下 x=Variable(torch.from_numpy(x_np[np.newaxis,:,np.newaxis]))#增加兩個維度,是三維的數(shù)據(jù)。 y=Variable(torch.from_numpy(y_np[np.newaxis,:,np.newaxis])) #將每個時間步上的10個值 輸入到rnn獲得結果 這里rnn會自動執(zhí)行forward前向過程. 這里輸入時10個,輸出也是10個,傳遞的是一個長度為32的記憶體 predition,h_state=rnn(x,h_state) #更新新的中間狀態(tài) h_state=Variable(h_state.data) #擦,這點一定要從新包裝 loss=loss_fun(predition,y) #print('loss:',loss) optimizer.zero_grad() loss.backward() optimizer.step() # plotting 畫圖,這里先平展了 flatten,這樣就是得到一個數(shù)組,更加直接 plt.plot(steps, y_np.flatten(), 'r-') plt.plot(steps, predition.data.numpy().flatten(), 'b-') #plt.draw(); plt.pause(0.05) plt.ioff() #關閉交互模式 plt.show()
以上這篇pytorch-RNN進行回歸曲線預測方式就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
使用python讀取csv文件快速插入數(shù)據(jù)庫的實例
今天小編就為大家分享一篇使用python讀取csv文件快速插入數(shù)據(jù)庫的實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-06-06解決Alexnet訓練模型在每個epoch中準確率和loss都會一升一降問題
這篇文章主要介紹了解決Alexnet訓練模型在每個epoch中準確率和loss都會一升一降問題,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-06-06Pycharm?cannot?set?up?a?python?SDK問題的原因及解決方法
這篇文章主要給大家介紹了關于Pycharm?cannot?set?up?a?python?SDK問題的原因及解決方法,這個問題已經不是第一次出現(xiàn)了,所以干脆總結下,需要的朋友可以參考下2022-06-06Python統(tǒng)計字符內容的占比的實現(xiàn)
本文介紹了如何使用Python統(tǒng)計字符占比,包括字符串中字母、數(shù)字、空格等字符的占比,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2023-08-08Pytorch使用技巧之Dataloader中的collate_fn參數(shù)詳析
collate_fn 參數(shù)的目的主要是為了隨心所欲的轉變數(shù)據(jù)的類型,這個數(shù)據(jù)是用DataLoader加載的,比如img,target,下面這篇文章主要給大家介紹了關于Pytorch使用技巧之Dataloader中的collate_fn參數(shù)的相關資料,需要的朋友可以參考下2022-03-03