Python線性網(wǎng)絡(luò)實(shí)現(xiàn)分類糖尿病病例
1. 加載數(shù)據(jù)集
這次我們搭建一個(gè)小小的多層線性網(wǎng)絡(luò)對(duì)糖尿病的病例進(jìn)行分類
首先先導(dǎo)入需要的庫(kù)文件
先來(lái)看看我們的數(shù)據(jù)集
觀察可以發(fā)現(xiàn),前八列是我們的feature ,根據(jù)這八個(gè)特征可以判斷出病人是否得了糖尿病。所以最后一列是1,0 的一個(gè)二分類問(wèn)題
我們使用numpy 去導(dǎo)入數(shù)據(jù)集,delimiter 是定義分隔符,這里我們用逗號(hào)(,)分割
將前八列的特征放到我們的x_data里面,作為特征輸入,最后一列放到y(tǒng)_data作為label
Tip :這里y_data 里面的 [-1] 中括號(hào)不可以省略,否則y_data會(huì)變成向量的形式
如果不習(xí)慣這種寫法,可以用view改變一下形狀就行
y_data = torch.from_numpy(xy[:,-1]).view(-1,1) #將y_data 的代碼改成這樣就可以了
下面是xy , x_data , y_data 打印出前兩行的結(jié)果
2. 搭建網(wǎng)絡(luò)+優(yōu)化器
搭建網(wǎng)絡(luò)的時(shí)候,要保證兩層網(wǎng)絡(luò)之間的維數(shù)能對(duì)應(yīng)上
首先第一層的時(shí)候,因?yàn)榍鞍肆凶鳛槲覀兊膞_data ,也就是說(shuō)我們輸入的特征是 8 維度的,那么由于 y = x * wT + b ,因?yàn)檩斎霐?shù)據(jù)的x是(n * 8) 的,而我們定義的y維度是(n * 6) ,所以wT的維度應(yīng)該是(8,6)
這里不需要知道啥時(shí)候轉(zhuǎn)置,啥時(shí)候不轉(zhuǎn)置之類的,只要滿足線性的方程y = w*x+b,并且維度一致就行了。因?yàn)椴还苁寝D(zhuǎn)置,或者w和x誰(shuí)在前,只是為了保證滿足矩陣相乘而已
一個(gè)小的技巧就是:只需要看輸入特征是多少,然后保證第一層第一個(gè)參數(shù)對(duì)應(yīng)就行了,然后第一層第二個(gè)參數(shù)是想輸出的維度。其次是第二層的第一個(gè)參數(shù)對(duì)應(yīng)第一層第二個(gè)參數(shù),以此類推....
我們采用的激活函數(shù)是ReLU , 由于是二元分類,最后一個(gè)網(wǎng)絡(luò)的輸出我們采用sigmoid輸出
接下來(lái),搭建實(shí)例化我們的網(wǎng)絡(luò),然后建立優(yōu)化器
這里我們選擇SGD隨機(jī)梯度下降算法,學(xué)習(xí)率設(shè)置為0.01
3. 訓(xùn)練網(wǎng)絡(luò)
訓(xùn)練網(wǎng)絡(luò)的過(guò)程較為簡(jiǎn)單,大概的過(guò)程為
1. 計(jì)算預(yù)測(cè)值
2. 計(jì)算損失函數(shù)
3. 反向傳播,之前要進(jìn)行梯度清零
4. 梯度更新
5. 重復(fù)這個(gè)過(guò)程,epoch 為所有樣本計(jì)算一次的周期,這次讓epoch 迭代1000次
4. 代碼
import torch.nn as nn # 神經(jīng)網(wǎng)絡(luò)庫(kù) import matplotlib.pyplot as plt # 繪圖 import torch # 張量 from torch import optim # 優(yōu)化器庫(kù) import numpy as np # 數(shù)據(jù)處理 xy = np.loadtxt('./diabetes.csv.gz',delimiter=',',dtype=np.float32) # 加載數(shù)據(jù)集 x_data = torch.from_numpy(xy[:,:-1]) # 所有行,除了最后一列的元素 y_data = torch.from_numpy(xy[:,-1]).view(-1,1) # -1也能拿出來(lái)是向量,但是[-1]會(huì)保證拿出來(lái)的是個(gè)矩陣 epoch_list =[] loss_list = [] class Model(nn.Module): def __init__(self): super(Model,self).__init__() self.linear1 = nn.Linear(8,6) self.linear2 = nn.Linear(6,3) self.linear3 = nn.Linear(3,1) self.sigmoid = nn.Sigmoid() self.relu = nn.ReLU() def forward(self,x): x = self.relu(self.linear1(x)) x = self.relu(self.linear2(x)) x = self.sigmoid(self.linear3(x)) return x model = Model() criterion = nn.BCELoss() optimizer = optim.SGD(model.parameters(),lr =0.01) for epoch in range(1000): y_pred = model(x_data) loss = criterion(y_pred,y_data) # 計(jì)算損失 if epoch % 100 ==0: # 每隔100次打印一下 print(epoch,loss.item()) #back propagation optimizer.zero_grad() # 梯度清零 loss.backward() # 反向傳播 optimizer.step() # 梯度更新 epoch_list.append(epoch) loss_list.append(loss.item()) plt.plot(epoch_list,loss_list) plt.show()
輸出結(jié)果為:
到此這篇關(guān)于Python線性網(wǎng)絡(luò)實(shí)現(xiàn)分類糖尿病病例的文章就介紹到這了,更多相關(guān)Python線性網(wǎng)絡(luò)內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
tensorflow2.0實(shí)現(xiàn)復(fù)雜神經(jīng)網(wǎng)絡(luò)(多輸入多輸出nn,Resnet)
這篇文章主要介紹了tensorflow2.0實(shí)現(xiàn)復(fù)雜神經(jīng)網(wǎng)絡(luò)(多輸入多輸出nn,Resnet),文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2021-03-03Python調(diào)用edge-tts實(shí)現(xiàn)在線文字轉(zhuǎn)語(yǔ)音效果
edge-tts是一個(gè) Python 模塊,允許通過(guò)Python代碼或命令的方式使用 Microsoft Edge 的在線文本轉(zhuǎn)語(yǔ)音服務(wù),這篇文章主要介紹了Python調(diào)用edge-tts實(shí)現(xiàn)在線文字轉(zhuǎn)語(yǔ)音效果,需要的朋友可以參考下2024-03-03Python如何查看并打印matplotlib中所有的colormap(cmap)類型
這篇文章主要介紹了Python如何查看并打印matplotlib中所有的colormap(cmap)類型,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2022-11-11pip已經(jīng)安裝好第三方庫(kù)但pycharm中import時(shí)還是標(biāo)紅的解決方案
這篇文章主要介紹了python中pip已經(jīng)安裝好第三方庫(kù)但pycharm中import時(shí)還是標(biāo)紅的問(wèn)題,本文給大家分享解決方法,對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-10-10解決Pycharm無(wú)法import自己安裝的第三方module問(wèn)題
今天小編就為大家分享一篇解決Pycharm無(wú)法import自己安裝的第三方module問(wèn)題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-05-05詳談pandas中agg函數(shù)和apply函數(shù)的區(qū)別
下面小編就為大家分享一篇詳談pandas中agg函數(shù)和apply函數(shù)的區(qū)別,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-04-04解決Python pandas plot輸出圖形中顯示中文亂碼問(wèn)題
今天小編就為大家分享一篇解決Python pandas plot輸出圖形中顯示中文亂碼問(wèn)題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-12-12