PyTorch加載數(shù)據(jù)集梯度下降優(yōu)化
一、實(shí)現(xiàn)過(guò)程
1、準(zhǔn)備數(shù)據(jù)
與PyTorch實(shí)現(xiàn)多維度特征輸入的邏輯回歸的方法不同的是:本文使用DataLoader
方法,并繼承DataSet抽象類,可實(shí)現(xiàn)對(duì)數(shù)據(jù)集進(jìn)行mini_batch
梯度下降優(yōu)化。
代碼如下:
import torch import numpy as np from torch.utils.data import Dataset,DataLoader class DiabetesDataSet(Dataset): ? ? def __init__(self, filepath): ? ? ? ? xy = np.loadtxt(filepath,delimiter=',',dtype=np.float32) ? ? ? ? self.len = xy.shape[0] ? ? ? ? self.x_data = torch.from_numpy(xy[:,:-1]) ? ? ? ? self.y_data = torch.from_numpy(xy[:,[-1]]) ? ? ? ?? ? ? def __getitem__(self, index): ? ? ? ? return self.x_data[index],self.y_data[index] ? ?? ? ? def __len__(self): ? ? ? ? return self.len dataset = DiabetesDataSet('G:/datasets/diabetes/diabetes.csv') train_loader = DataLoader(dataset=dataset,batch_size=32,shuffle=True,num_workers=0)
2、設(shè)計(jì)模型
class Model(torch.nn.Module): ? ? def __init__(self): ? ? ? ? super(Model,self).__init__() ? ? ? ? self.linear1 = torch.nn.Linear(8,6) ? ? ? ? self.linear2 = torch.nn.Linear(6,4) ? ? ? ? self.linear3 = torch.nn.Linear(4,1) ? ? ? ? self.activate = torch.nn.Sigmoid() ? ?? ? ? def forward(self, x): ? ? ? ? x = self.activate(self.linear1(x)) ? ? ? ? x = self.activate(self.linear2(x)) ? ? ? ? x = self.activate(self.linear3(x)) ? ? ? ? return x model = Model()
3、構(gòu)造損失函數(shù)和優(yōu)化器
criterion = torch.nn.BCELoss(reduction='mean') optimizer = torch.optim.SGD(model.parameters(),lr=0.1)
4、訓(xùn)練過(guò)程
每次拿出mini_batch個(gè)樣本進(jìn)行訓(xùn)練,代碼如下:
epoch_list = [] loss_list = [] for epoch in range(100): ? ? count = 0 ? ? loss1 = 0 ? ? for i, data in enumerate(train_loader,0): ? ? ? ? # 1.Prepare data ? ? ? ? inputs, labels = data ? ? ? ? # 2.Forward ? ? ? ? y_pred = model(inputs) ? ? ? ? loss = criterion(y_pred,labels) ? ? ? ? print(epoch,i,loss.item()) ? ? ? ? count += 1 ? ? ? ? loss1 += loss.item() ? ? ? ? # 3.Backward ? ? ? ? optimizer.zero_grad() ? ? ? ? loss.backward() ? ? ? ? # 4.Update ? ? ? ? optimizer.step() ? ? ? ?? ? ? epoch_list.append(epoch) ? ? loss_list.append(loss1/count)
5、結(jié)果展示
plt.plot(epoch_list,loss_list,'b') plt.xlabel('epoch') plt.ylabel('loss') plt.grid() plt.show()
二、參考文獻(xiàn)
到此這篇關(guān)于PyTorch加載數(shù)據(jù)集梯度下降優(yōu)化的文章就介紹到這了,更多相關(guān)PyTorch加載數(shù)據(jù)集內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
pytorch 實(shí)現(xiàn)凍結(jié)部分參數(shù)訓(xùn)練另一部分
這篇文章主要介紹了pytorch 實(shí)現(xiàn)凍結(jié)部分參數(shù)訓(xùn)練另一部分,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2021-03-03python爬蟲之利用selenium模塊自動(dòng)登錄CSDN
這篇文章主要介紹了python爬蟲之利用selenium模塊自動(dòng)登錄CSDN,文中有非常詳細(xì)的代碼示例,對(duì)正在學(xué)習(xí)python的小伙伴們有很好地幫助,需要的朋友可以參考下2021-04-04Python字符串本身作為bytes進(jìn)行解碼的問(wèn)題
這篇文章主要介紹了解決Python字符串本身作為bytes進(jìn)行解碼的問(wèn)題,文末給大家補(bǔ)充介紹了,Python字符串如何轉(zhuǎn)為bytes對(duì)象?Python字符串和bytes類型怎么互轉(zhuǎn),需要的朋友可以參考下2022-11-11python使用yield壓平嵌套字典的超簡(jiǎn)單方法
這篇文章主要給大家介紹了關(guān)于python使用yield壓平嵌套字典的超簡(jiǎn)單方法,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者使用python具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2019-11-11Python實(shí)現(xiàn)識(shí)別圖像中人物的示例代碼
這篇文章主要介紹了通過(guò)face_recognition提供的demo代碼,簡(jiǎn)單調(diào)整了一下,從而實(shí)現(xiàn)識(shí)別圖像中人物的功能,感興趣的可以跟隨小編一起試試2022-01-01Python3.4實(shí)現(xiàn)遠(yuǎn)程控制電腦開關(guān)機(jī)
這篇文章主要為大家詳細(xì)介紹了Python3.4實(shí)現(xiàn)遠(yuǎn)程控制電腦開關(guān)機(jī)的方法,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-02-02