PyTorch實(shí)現(xiàn)多維度特征輸入邏輯回歸
一、實(shí)現(xiàn)過程
1、準(zhǔn)備數(shù)據(jù)
本文數(shù)據(jù)采取文獻(xiàn)[1]給出的數(shù)據(jù)集,該數(shù)據(jù)集前8列為特征,最后1列為標(biāo)簽(0/1)。本模型使用pandas處理該數(shù)據(jù)集,需要注意的是,原始數(shù)據(jù)集沒有特征名稱,需要自己在第一行添加上去,否則,pandas會(huì)把第一行的數(shù)據(jù)當(dāng)成特征名稱處理,從而影響最后的分類效果。
代碼如下:
# 1、準(zhǔn)備數(shù)據(jù) import torch import pandas as pd import numpy as np xy = pd.read_csv('G:/datasets/diabetes/diabetes.csv',dtype=np.float32)?? ?# 文件路徑 x_data = torch.from_numpy(xy.values[:,:-1]) y_data = torch.from_numpy(xy.values[:,[-1]])
2、設(shè)計(jì)模型
本文采取文獻(xiàn)[1]的思路,激活函數(shù)使用ReLU,最后一層使用Sigmoid
函數(shù),
代碼如下:
class Model(torch.nn.Module): ? ? def __init__(self): ? ? ? ? super(Model,self).__init__() ? ? ? ? self.linear1 = torch.nn.Linear(8,6) ? ? ? ? self.linear2 = torch.nn.Linear(6,4) ? ? ? ? self.linear3 = torch.nn.Linear(4,1) ? ? ? ? self.activate = torch.nn.ReLU() ? ?? ? ? def forward(self, x): ? ? ? ? x = self.activate(self.linear1(x)) ? ? ? ? x = self.activate(self.linear2(x)) ? ? ? ? x = torch.sigmoid(self.linear3(x)) ? ? ? ? return x model = Model()
將模型和數(shù)據(jù)加載到GPU上,代碼如下:
### 將模型和訓(xùn)練數(shù)據(jù)加載到GPU上 # 模型加載到GPU上 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model.to(device) # 數(shù)據(jù)加載到GPU上 x = x_data.to(device) y = y_data.to(device)
3、構(gòu)造損失函數(shù)和優(yōu)化器 criterion = torch.nn.BCELoss(reduction='mean') optimizer = torch.optim.SGD(model.parameters(),lr=0.1)
4、訓(xùn)練過程
epoch_list = [] loss_list = [] epochs = 10000 for epoch in range(epochs): ? ? # Forward ? ? y_pred = model(x) ? ? loss = criterion(y_pred, y) ? ? print(epoch, loss) ? ? epoch_list.append(epoch) ? ? loss_list.append(loss.data.item()) ? ? # Backward ? ? optimizer.zero_grad() ? ? loss.backward() ? ? # Update ? ? optimizer.step()
5、結(jié)果展示
查看各個(gè)層的權(quán)重和偏置:
model.linear1.weight,model.linear1.bias model.linear2.weight,model.linear2.bias model.linear3.weight,model.linear3.bias
損失值隨迭代次數(shù)的變化曲線:
# 繪圖展示 plt.plot(epoch_list,loss_list,'b') plt.xlabel('epoch') plt.ylabel('loss') plt.grid() plt.show()
最終的損失和準(zhǔn)確率:
# 準(zhǔn)確率 y_pred_label = torch.where(y_pred.data.cpu() >= 0.5,torch.tensor([1.0]),torch.tensor([0.0])) acc = torch.eq(y_pred_label, y_data).sum().item()/y_data.size(0) print("loss = ",loss.item(), "acc = ",acc) loss = ?0.4232381284236908 acc = ?0.7931488801054019
二、參考文獻(xiàn)
- [1] https://www.bilibili.com/video/BV1Y7411d7Ys?p=7
- [2] https://blog.csdn.net/bit452/article/details/109682078
到此這篇關(guān)于PyTorch實(shí)現(xiàn)多維度特征輸入邏輯回歸的文章就介紹到這了,更多相關(guān)PyTorch邏輯回歸內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
python使用socket連接遠(yuǎn)程服務(wù)器的方法
這篇文章主要介紹了python使用socket連接遠(yuǎn)程服務(wù)器的方法,涉及Python中socket通信的基本技巧,具有一定參考借鑒價(jià)值,需要的朋友可以參考下2015-04-04Python?PyQt拖動(dòng)控件對(duì)齊到網(wǎng)格的方法步驟
pyqt是一個(gè)用于創(chuàng)建GUI應(yīng)用程序的跨平臺(tái)工具包,它將python與qt庫融為一體,下面這篇文章主要給大家介紹了關(guān)于Python?PyQt拖動(dòng)控件對(duì)齊到網(wǎng)格的方法步驟,需要的朋友可以參考下2022-12-12利用selenium 3.7和python3添加cookie模擬登陸的實(shí)現(xiàn)
這篇文章主要給大家介紹了關(guān)于利用selenium 3.7和python3添加cookie模擬登陸的相關(guān)資料,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家學(xué)習(xí)或者使用python具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧。2017-11-11python使用json序列化datetime類型實(shí)例解析
這篇文章主要介紹了python使用json序列化datetime類型實(shí)例解析,分享了相關(guān)代碼示例,小編覺得還是挺不錯(cuò)的,具有一定借鑒價(jià)值,需要的朋友可以參考下2018-02-02python使用tcp實(shí)現(xiàn)局域網(wǎng)內(nèi)文件傳輸
這篇文章主要介紹了python使用tcp實(shí)現(xiàn)局域網(wǎng)內(nèi)文件傳輸,文件包括文本,圖片,視頻等,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-07-07Python-OpenCV實(shí)現(xiàn)圖像缺陷檢測(cè)的實(shí)例
本文將結(jié)合實(shí)例代碼,在Jupyter Notebook上使用Python+opencv實(shí)現(xiàn)如下圖像缺陷檢測(cè)。需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2021-06-06在Pandas中DataFrame數(shù)據(jù)合并,連接(concat,merge,join)的實(shí)例
今天小編就為大家分享一篇在Pandas中DataFrame數(shù)據(jù)合并,連接(concat,merge,join)的實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2019-01-01python中的[1:]、[::-1]、X[:,m:n]和X[1,:]的使用
本文主要介紹了python中的[1:]、[::-1]、X[:,m:n]和X[1,:]的使用,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2022-08-08