PyTorch搭建多項式回歸模型(三)
PyTorch基礎(chǔ)入門三:PyTorch搭建多項式回歸模型
1)理論簡介
對于一般的線性回歸模型,由于該函數(shù)擬合出來的是一條直線,所以精度欠佳,我們可以考慮多項式回歸來擬合更多的模型。所謂多項式回歸,其本質(zhì)也是線性回歸。也就是說,我們采取的方法是,提高每個屬性的次數(shù)來增加維度數(shù)。比如,請看下面這樣的例子:
如果我們想要擬合方程:
對于輸入變量和輸出值
,我們只需要增加其平方項、三次方項系數(shù)即可。所以,我們可以設(shè)置如下參數(shù)方程:
可以看到,上述方程與線性回歸方程并沒有本質(zhì)區(qū)別。所以我們可以采用線性回歸的方式來進(jìn)行多項式的擬合。下面請看代碼部分。
2)代碼實現(xiàn)
當(dāng)然最先要做的就是導(dǎo)包了,下面需要說明的只有一個:itertools中的count,這個是用來記數(shù)用的,其可以記數(shù)到無窮,第一個參數(shù)是記數(shù)的起始值,第二個參數(shù)是步長。其內(nèi)部實現(xiàn)相當(dāng)于如下代碼:
def count(firstval=0, step=1): x = firstval while 1: yield x x += step
下面是導(dǎo)包部分代碼,這里定義了一個常量POLY_DEGREE = 3用來指定多項式最高次數(shù)。
from itertools import count import torch import torch.autograd import torch.nn.functional as F POLY_DEGREE = 3
然后我們需要將數(shù)據(jù)處理成矩陣的形式:
在PyTorch里面使用torch.cat()函數(shù)來實現(xiàn)Tensor的拼接:
def make_features(x): """Builds features i.e. a matrix with columns [x, x^2, x^3, x^4].""" x = x.unsqueeze(1) return torch.cat([x ** i for i in range(1, POLY_DEGREE+1)], 1)
對于輸入的個數(shù)據(jù),我們將其擴(kuò)展成上面矩陣所示的樣子。
然后定義出我們需要擬合的多項式,可以隨機抽取一個多項式來作為我們的目標(biāo)多項式。當(dāng)然,系數(shù)和偏置
確定了,多項式也就確定了:
W_target = torch.randn(POLY_DEGREE, 1) b_target = torch.randn(1) def f(x): """Approximated function.""" return x.mm(W_target) + b_target.item()
這里的權(quán)重已經(jīng)定義好了,x.mm(W_target)表示做矩陣乘法,就是每次輸入一個
得到一個
的真實函數(shù)。
在訓(xùn)練的時候我們需要采樣一些點,可以隨機生成一批數(shù)據(jù)來得到訓(xùn)練集。下面的函數(shù)可以讓我們每次取batch_size這么多個數(shù)據(jù),然后將其轉(zhuǎn)化為矩陣形式,再把這個值通過函數(shù)之后的結(jié)果也返回作為真實的輸出值:
def get_batch(batch_size=32): """Builds a batch i.e. (x, f(x)) pair.""" random = torch.randn(batch_size) x = make_features(random) y = f(x) return x, y
接下來我們需要定義模型,這里采用一種簡寫的方式定義模型,torch.nn.Linear()表示定義一個線性模型,這里定義了是輸入值和目標(biāo)參數(shù)的行數(shù)一致(和POLY_DEGREE一致,本次實驗中為3),輸出值為1的模型。
# Define model fc = torch.nn.Linear(W_target.size(0), 1)
下面開始訓(xùn)練模型,訓(xùn)練的過程讓其不斷優(yōu)化,直到隨機取出的batch_size個點中計算出來的均方誤差小于0.001為止。
for batch_idx in count(1): # Get data batch_x, batch_y = get_batch() # Reset gradients fc.zero_grad() # Forward pass output = F.smooth_l1_loss(fc(batch_x), batch_y) loss = output.item() # Backward pass output.backward() # Apply gradients for param in fc.parameters(): param.data.add_(-0.1 * param.grad.data) # Stop criterion if loss < 1e-3: break
這樣就已經(jīng)訓(xùn)練出了我們的多項式回歸模型,為了方便觀察,定義了如下打印函數(shù)來打印出我們擬合的多項式表達(dá)式:
def poly_desc(W, b): """Creates a string description of a polynomial.""" result = 'y = ' for i, w in enumerate(W): result += '{:+.2f} x^{} '.format(w, len(W) - i) result += '{:+.2f}'.format(b[0]) return result print('Loss: {:.6f} after {} batches'.format(loss, batch_idx)) print('==> Learned function:\t' + poly_desc(fc.weight.view(-1), fc.bias)) print('==> Actual function:\t' + poly_desc(W_target.view(-1), b_target))
程序運行結(jié)果如下圖所示:
可以看出,真實的多項式表達(dá)式和我們擬合的多項式十分接近?,F(xiàn)實世界中很多問題都不是簡單的線性回歸,涉及到很多復(fù)雜的非線性模型。但是我們可以在其特征量上進(jìn)行研究,改變或者增加其特征,從而將非線性問題轉(zhuǎn)化為線性問題來解決,這種處理問題的思路是我們從多項式回歸的算法中應(yīng)該汲取到的。
以上就是本文的全部內(nèi)容,希望對大家的學(xué)習(xí)有所幫助,也希望大家多多支持腳本之家。
相關(guān)文章
Python?Pandas刪除替換并提取其中的缺失值NaN(dropna,fillna,isnull)
這篇文章主要給大家介紹了關(guān)于Python?Pandas刪除替換并提取其中的缺失值NaN(dropna,fillna,isnull)的相關(guān)資料,文中通過實例代碼介紹的非常詳細(xì),對大家學(xué)習(xí)或者使用Pandas具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下2022-01-01python繪制規(guī)則網(wǎng)絡(luò)圖形實例
今天小編大家分享一篇python繪制規(guī)則網(wǎng)絡(luò)圖形實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-12-12Python進(jìn)行數(shù)據(jù)可視化Plotly與Dash的應(yīng)用小結(jié)
數(shù)據(jù)可視化是數(shù)據(jù)分析中至關(guān)重要的一環(huán),它能夠幫助我們更直觀地理解數(shù)據(jù)并發(fā)現(xiàn)隱藏的模式和趨勢,本文主要介紹了Python進(jìn)行數(shù)據(jù)可視化Plotly與Dash的應(yīng)用小結(jié),具有一定的參考價值,感興趣的可以了解一下2024-04-04Android模擬器無法啟動,報錯:Cannot set up guest memory ‘a(chǎn)ndroid_arm’ I
這篇文章主要介紹了Android模擬器無法啟動,報錯:Cannot set up guest memory ‘a(chǎn)ndroid_arm’ Invalid argument的解決方法,通過模擬器ram設(shè)置的調(diào)整予以解決,需要的朋友可以參考下2016-07-07python基礎(chǔ)學(xué)習(xí)之遞歸函數(shù)知識總結(jié)
在函數(shù)中調(diào)用函數(shù)自身,我們把這樣的函數(shù)叫做遞歸函數(shù), 遞歸函數(shù)就是循環(huán)的調(diào)用,類似于俄羅斯套娃,本文給各位小伙伴詳細(xì)介紹了python遞歸函數(shù),需要的朋友可以參考下2021-05-05