欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch從csv加載自定義數(shù)據(jù)模板的操作

 更新時間:2021年03月06日 09:21:35   作者:追夢小狂魔  
這篇文章主要介紹了pytorch從csv加載自定義數(shù)據(jù)模板的操作,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧

整理了一套模板,全注釋了,這個難點終于克服了

from PIL import Image
import pandas as pd
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import os
#放文件的路徑
dir_path= './97/train/'
csv_path='./97/train.csv'
class Mydataset(Dataset):
 #傳遞數(shù)據(jù)路徑,csv路徑 ,數(shù)據(jù)增強方法
 def __init__(self, dir_path,csv, transform=None, target_transform=None):
  super(Mydataset, self).__init__()
  #一個個往列表里面加絕對路徑
  self.path = []
  #讀取csv
  self.data = pd.read_csv(csv)
  #對標簽進行硬編碼,例如0 1 2 3 4,把字母變成這個
  colorMap = {elem: index + 1 for index, elem in enumerate(set(self.data["label"]))}
  self.data['label'] = self.data['label'].map(colorMap)
  #創(chuàng)造空的label準備存放標簽
  self.num = int(self.data.shape[0]) # 一共多少照片
  self.label = np.zeros(self.num, dtype=np.int32)
  #迭代得到數(shù)據(jù)路徑和標簽一一對應
  for index, row in self.data.iterrows():
   self.path.append(os.path.join(dir_path,row['filename']))
   self.label[index] = row['label'] # 將數(shù)據(jù)全部讀取出來
  #訓練數(shù)據(jù)增強
  self.transform = transform
  #驗證數(shù)據(jù)增強在這里沒用
  self.target_transform = target_transform
 #最關鍵的部分,在這里使用前面的方法
 def __getitem__(self, index):
  img =Image.open(self.path[index]).convert('RGB')
  labels = self.label[index]
  #在這里做數(shù)據(jù)增強
  if self.transform is not None:
   img = self.transform(img) # 轉化tensor類型
  return img, labels
 def __len__(self):
  return len(self.data)
#數(shù)據(jù)增強的具體內容
transform = transforms.Compose(
 [transforms.ToTensor(),
  transforms.Resize(150),
  transforms.CenterCrop(150),
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
#加載數(shù)據(jù)
train_data = Mydataset(dir_path=dir_path,csv=csv_path, transform=transform)
trainloader = DataLoader(train_data, batch_size=16, shuffle=True, num_workers=0)
#迭代訓練
for i_batch,batch_data in enumerate(trainloader):
 image,label=batch_data

補充:pytorch—定義自己的數(shù)據(jù)集及加載訓練

筆記:pytorch Conv2d 的寬高公式理解,pytorch 使用自己的數(shù)據(jù)集并且加載訓練

一、pypi 鏡像使用幫助

pypi 鏡像每 5 分鐘同步一次。

臨時使用

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple some-package

注意,simple 不能少, 是 https 而不是 http

設為默認

修改 ~/.config/pip/pip.conf (Linux), %APPDATA%\pip\pip.ini (Windows 10)$HOME/Library/Application Support/pip/pip.conf (macOS) (沒有就創(chuàng)建一個), 修改 index-urltuna,例如

[global]
index-url = https://pypi.tuna.tsinghua.edu.cn/simple

pip 和 pip3 并存時,只需修改 ~/.pip/pip.conf。

二、pytorch Conv2d 的寬高公式理解

三、pytorch 使用自己的數(shù)據(jù)集并且加載訓練

import os
import sys
import numpy as np
import cv2
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import time
import random
import csv
from PIL import Image
def createImgIndex(dataPath, ratio):
 '''
 讀取目錄下面的圖片制作包含圖片信息、圖片label的train.txt和val.txt
 dataPath: 圖片目錄路徑
 ratio: val占比
 return:label列表
 '''
 fileList = os.listdir(dataPath)
 random.shuffle(fileList)
 classList = [] # label列表
 # val 數(shù)據(jù)集制作
 with open('data/val_section1015.csv', 'w') as f:
  writer = csv.writer(f)
  for i in range(int(len(fileList)*ratio)):
   row = []
   if '.jpg' in fileList[i]:
    fileInfo = fileList[i].split('_')
    sectionName = fileInfo[0] + '_' + fileInfo[1] # 切面名+標準與否
    row.append(os.path.join(dataPath, fileList[i])) # 圖片路徑
    if sectionName not in classList:
     classList.append(sectionName)
    row.append(classList.index(sectionName))
    writer.writerow(row)
  f.close()
 # train 數(shù)據(jù)集制作
 with open('data/train_section1015.csv', 'w') as f:
  writer = csv.writer(f)
  for i in range(int(len(fileList) * ratio)+1, len(fileList)):
   row = []
   if '.jpg' in fileList[i]:
    fileInfo = fileList[i].split('_')
    sectionName = fileInfo[0] + '_' + fileInfo[1] # 切面名+標準與否
    row.append(os.path.join(dataPath, fileList[i])) # 圖片路徑
    if sectionName not in classList:
     classList.append(sectionName)
    row.append(classList.index(sectionName))
    writer.writerow(row)
  f.close()
 print(classList, len(classList))
 return classList
def default_loader(path):
 '''定義讀取文件的格式'''
 return Image.open(path).resize((128, 128),Image.ANTIALIAS).convert('RGB')
class MyDataset(Dataset):
 '''Dataset類是讀入數(shù)據(jù)集數(shù)據(jù)并且對讀入的數(shù)據(jù)進行索引'''
 def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
  super(MyDataset, self).__init__() #對繼承自父類的屬性進行初始化
  fh = open(txt, 'r') #按照傳入的路徑和txt文本參數(shù),以只讀的方式打開這個文本
  reader = csv.reader(fh)
  imgs = []
  for row in reader:
   imgs.append((row[0], int(row[1]))) # (圖片信息,lable)
  self.imgs = imgs
  self.transform = transform
  self.target_transform = target_transform
  self.loader = loader
 
 def __getitem__(self, index):
  '''用于按照索引讀取每個元素的具體內容'''
  # fn是圖片path #fn和label分別獲得imgs[index]也即是剛才每行中row[0]和row[1]的信息
  fn, label = self.imgs[index]
  img = self.loader(fn)
  if self.transform is not None:
   img = self.transform(img) #數(shù)據(jù)標簽轉換為Tensor
  return img, label
 
 def __len__(self):
  '''返回數(shù)據(jù)集的長度'''
  return len(self.imgs)
class Model(nn.Module):
 def __init__(self, classNum=31):
  super(Model, self).__init__()
  # torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
  # torch.nn.MaxPool2d(kernel_size, stride, padding)
  # input 維度 [3, 128, 128]
  self.cnn = nn.Sequential(
   nn.Conv2d(3, 64, 3, 1, 1), # [64, 128, 128]
   nn.BatchNorm2d(64),
   nn.ReLU(),
   nn.MaxPool2d(2, 2, 0), # [64, 64, 64]
   nn.Conv2d(64, 128, 3, 1, 1), # [128, 64, 64]
   nn.BatchNorm2d(128),
   nn.ReLU(),
   nn.MaxPool2d(2, 2, 0), # [128, 32, 32]
   nn.Conv2d(128, 256, 3, 1, 1), # [256, 32, 32]
   nn.BatchNorm2d(256),
   nn.ReLU(),
   nn.MaxPool2d(2, 2, 0), # [256, 16, 16]
   nn.Conv2d(256, 512, 3, 1, 1), # [512, 16, 16]
   nn.BatchNorm2d(512),
   nn.ReLU(),
   nn.MaxPool2d(2, 2, 0), # [512, 8, 8]
   nn.Conv2d(512, 512, 3, 1, 1), # [512, 8, 8]
   nn.BatchNorm2d(512),
   nn.ReLU(),
   nn.MaxPool2d(2, 2, 0), # [512, 4, 4]
  )
  self.fc = nn.Sequential(
   nn.Linear(512 * 4 * 4, 1024),
   nn.ReLU(),
   nn.Linear(1024, 512),
   nn.ReLU(),
   nn.Linear(512, classNum)
  )
 def forward(self, x):
  out = self.cnn(x)
  out = out.view(out.size()[0], -1)
  return self.fc(out)
def train(train_set, train_loader, val_set, val_loader):
 model = Model()
 loss = nn.CrossEntropyLoss() # 因為是分類任務,所以loss function使用 CrossEntropyLoss
 optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # optimizer 使用 Adam
 num_epoch = 10
 # 開始訓練
 for epoch in range(num_epoch):
  epoch_start_time = time.time()
  train_acc = 0.0
  train_loss = 0.0
  val_acc = 0.0
  val_loss = 0.0
  model.train() # train model會開放Dropout和BN
  for i, data in enumerate(train_loader):
   optimizer.zero_grad() # 用 optimizer 將 model 參數(shù)的 gradient 歸零
   train_pred = model(data[0]) # 利用 model 的 forward 函數(shù)返回預測結果
   batch_loss = loss(train_pred, data[1]) # 計算 loss
   batch_loss.backward() # tensor(item, grad_fn=<NllLossBackward>)
   optimizer.step() # 以 optimizer 用 gradient 更新參數(shù)
   train_acc += np.sum(np.argmax(train_pred.data.numpy(), axis=1) == data[1].numpy())
   train_loss += batch_loss.item()
  model.eval()
  with torch.no_grad(): # 不跟蹤梯度
   for i, data in enumerate(val_loader):
    # data = [imgData, labelList]
    val_pred = model(data[0])
    batch_loss = loss(val_pred, data[1])
    val_acc += np.sum(np.argmax(val_pred.data.numpy(), axis=1) == data[1].numpy())
    val_loss += batch_loss.item()
   # 打印結果
   print('[%03d/%03d] %2.2f sec(s) Train Acc: %3.6f Loss: %3.6f | Val Acc: %3.6f loss: %3.6f' % \
     (epoch + 1, num_epoch, time.time() - epoch_start_time, \
     train_acc / train_set.__len__(), train_loss / train_set.__len__(), val_acc / val_set.__len__(),
     val_loss / val_set.__len__()))
if __name__ == '__main__':
 dirPath = '/data/Matt/QC_images/test0916' # 圖片文件目錄
 createImgIndex(dirPath, 0.2)    # 創(chuàng)建train.txt, val.txt
 root = os.getcwd() + '/data/'
 train_data = MyDataset(txt=root+'train_section1015.csv', transform=transforms.ToTensor())
 val_data = MyDataset(txt=root+'val_section1015.csv', transform=transforms.ToTensor())
 train_loader = DataLoader(dataset=train_data, batch_size=6, shuffle=True, num_workers = 4)
 val_loader = DataLoader(dataset=val_data, batch_size=6, shuffle=False, num_workers = 4)
 # 開始訓練模型
 train(train_data, train_loader, val_data, val_loader)

以上為個人經驗,希望能給大家一個參考,也希望大家多多支持腳本之家。如有錯誤或未考慮完全的地方,望不吝賜教。

相關文章

  • 利用python+ffmpeg合并B站視頻及格式轉換的實例代碼

    利用python+ffmpeg合并B站視頻及格式轉換的實例代碼

    這篇文章主要介紹了利用python+ffmpeg合并B站視頻及格式轉換的實例代碼,本文通過實例代碼給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下
    2020-11-11
  • 對pandas中apply函數(shù)的用法詳解

    對pandas中apply函數(shù)的用法詳解

    下面小編就為大家分享一篇對pandas中apply函數(shù)的用法詳解,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2018-04-04
  • Python中__str__()的妙用

    Python中__str__()的妙用

    本文主要介紹了Python中__str__()的妙用,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧
    2023-01-01
  • Python+Pygame實現(xiàn)之走四棋兒游戲的實現(xiàn)

    Python+Pygame實現(xiàn)之走四棋兒游戲的實現(xiàn)

    大家以前應該都聽說過一個游戲:叫做走四棋兒。直接在家里的水泥地上用燒完的炭火灰畫出幾條線,擺上幾顆石頭子即可。當時的火爆程度可謂是達到了一個新的高度。本文將利用Pygame實現(xiàn)這一游戲,需要的可以參考一下
    2022-07-07
  • 新手入門學習python Numpy基礎操作

    新手入門學習python Numpy基礎操作

    這篇文章主要介紹了新手入門學習python Numpy基礎操作,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友可以參考下
    2020-03-03
  • python3如何使用Requests測試帶簽名的接口

    python3如何使用Requests測試帶簽名的接口

    這篇文章主要介紹了python3如何使用Requests測試帶簽名的接口,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教
    2022-02-02
  • Python小實例混合使用turtle和tkinter讓小海龜互動起來

    Python小實例混合使用turtle和tkinter讓小海龜互動起來

    Tkinter模塊("Tk 接口")是Python的標準Tk GUI工具包的接口.Tk和Tkinter可以在大多數(shù)的Unix平臺下使用,同樣可以應用在Windows和Macintosh系統(tǒng)里.Tk8.0的后續(xù)版本可以實現(xiàn)本地窗口風格,并良好地運行在絕大多數(shù)平臺中
    2021-10-10
  • python實現(xiàn)梯度下降和邏輯回歸

    python實現(xiàn)梯度下降和邏輯回歸

    這篇文章主要為大家詳細介紹了python實現(xiàn)梯度下降和邏輯回歸,文中示例代碼介紹的非常詳細,具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2020-03-03
  • Python基于property實現(xiàn)類的特性操作示例

    Python基于property實現(xiàn)類的特性操作示例

    這篇文章主要介紹了Python基于property實現(xiàn)類的特性,結合實例形式分析了使用property實現(xiàn)類的特性相關操作技巧與注意事項,需要的朋友可以參考下
    2018-06-06
  • Django框架請求生命周期實現(xiàn)原理

    Django框架請求生命周期實現(xiàn)原理

    這篇文章主要介紹了Django框架請求生命周期實現(xiàn)原理,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友可以參考下
    2020-11-11

最新評論