欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

PyTorch 遷移學(xué)習(xí)實(shí)踐(幾分鐘即可訓(xùn)練好自己的模型)

 更新時(shí)間:2021年03月26日 14:22:18   作者:YXHPY  
這篇文章主要介紹了PyTorch 遷移學(xué)習(xí)實(shí)踐(幾分鐘即可訓(xùn)練好自己的模型),文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧

前言

如果你認(rèn)為深度學(xué)習(xí)非常的吃GPU,或者說非常的耗時(shí)間,訓(xùn)練一個(gè)模型要非常久,但是你如果了解了遷移學(xué)習(xí)那你的模型可能只需要幾分鐘,而且準(zhǔn)確率不比你自己訓(xùn)練的模型準(zhǔn)確率低,本節(jié)我們將會(huì)介紹兩種方法來實(shí)現(xiàn)遷移學(xué)習(xí)

遷移學(xué)習(xí)方法介紹

  • 微調(diào)網(wǎng)絡(luò)的方法實(shí)現(xiàn)遷移學(xué)習(xí),更改最后一層全連接,并且微調(diào)訓(xùn)練網(wǎng)絡(luò)
  • 將模型看成特征提取器,如果一個(gè)模型的預(yù)訓(xùn)練模型非常的好,那完全就把前面的層看成特征提取器,凍結(jié)所有層并且更改最后一層,只訓(xùn)練最后一層,這樣我們只訓(xùn)練了最后一層,訓(xùn)練會(huì)非常的快速

在這里插入圖片描述 

遷移基本步驟

  •  數(shù)據(jù)的準(zhǔn)備
  • 選擇數(shù)據(jù)增廣的方式
  • 選擇合適的模型
  • 更換最后一層全連接
  • 凍結(jié)層,開始訓(xùn)練
  • 選擇預(yù)測結(jié)果最好的模型保存

需要導(dǎo)入的包

import zipfile # 解壓文件
import torchvision
from torchvision import datasets, transforms, models
import torch
from torch.utils.data import DataLoader, Dataset
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import copy

數(shù)據(jù)準(zhǔn)備

本次實(shí)驗(yàn)的數(shù)據(jù)到這里下載
首先按照上一章節(jié)講的數(shù)據(jù)讀取方法來準(zhǔn)備數(shù)據(jù)

# 解壓數(shù)據(jù)到指定文件
def unzip(filename, dst_dir):
  z = zipfile.ZipFile(filename)
  z.extractall(dst_dir)
unzip('./data/hymenoptera_data.zip', './data/')
# 實(shí)現(xiàn)自己的Dataset方法,主要實(shí)現(xiàn)兩個(gè)方法__len__和__getitem__
class MyDataset(Dataset):
  def __init__(self, dirname, transform=None):
    super(MyDataset, self).__init__()
    self.classes = os.listdir(dirname)
    self.images = []
    self.transform = transform
    for i, classes in enumerate(self.classes):
      classes_path = os.path.join(dirname, classes)
      for image_name in os.listdir(classes_path):
        self.images.append((os.path.join(classes_path, image_name), i))
  def __len__(self):
    return len(self.images)
  def __getitem__(self, idx):
    image_name, classes = self.images[idx]
    image = Image.open(image_name)
    if self.transform:
      image = self.transform(image)
    return image, classes
  def get_claesses(self):
    return self.classes
# 分布實(shí)現(xiàn)訓(xùn)練和預(yù)測的transform
train_transform = transforms.Compose([
  transforms.Grayscale(3),
  transforms.RandomResizedCrop(224), #隨機(jī)裁剪一個(gè)area然后再resize
  transforms.RandomHorizontalFlip(), #隨機(jī)水平翻轉(zhuǎn)
  transforms.Resize(size=(256, 256)),
  transforms.ToTensor(),
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
val_transform = transforms.Compose([
  transforms.Grayscale(3),
  transforms.Resize(size=(256, 256)),
  transforms.CenterCrop(224),
  transforms.ToTensor(),
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# 分別實(shí)現(xiàn)loader
train_dataset = MyDataset('./data/hymenoptera_data/train/', train_transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32)
val_dataset = MyDataset('./data/hymenoptera_data/val/', val_transform)
val_loader = DataLoader(val_dataset, shuffle=True, batch_size=32)

選擇預(yù)訓(xùn)練的模型

這里我們選擇了resnet18在ImageNet 1000類上進(jìn)行了預(yù)訓(xùn)練的

model = models.resnet18(pretrained=True) # 使用預(yù)訓(xùn)練

使用model.buffers查看網(wǎng)絡(luò)基本結(jié)構(gòu)

<bound method Module.buffers of ResNet(
 (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
 (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
 (relu): ReLU(inplace=True)
 (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
 (layer1): Sequential(
  (0): BasicBlock(
   (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
   (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (relu): ReLU(inplace=True)
   (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
   (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (1): BasicBlock(
   (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
   (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (relu): ReLU(inplace=True)
   (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
   (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
 )
 (layer2): Sequential(
  (0): BasicBlock(
   (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
   (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (relu): ReLU(inplace=True)
   (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
   (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (downsample): Sequential(
    (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   )
  )
  (1): BasicBlock(
   (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
   (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (relu): ReLU(inplace=True)
   (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
   (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
 )
 (layer3): Sequential(
  (0): BasicBlock(
   (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
   (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (relu): ReLU(inplace=True)
   (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
   (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (downsample): Sequential(
    (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   )
  )
  (1): BasicBlock(
   (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
   (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (relu): ReLU(inplace=True)
   (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
   (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
 )
 (layer4): Sequential(
  (0): BasicBlock(
   (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
   (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (relu): ReLU(inplace=True)
   (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
   (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (downsample): Sequential(
    (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   )
  )
  (1): BasicBlock(
   (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
   (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (relu): ReLU(inplace=True)
   (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
   (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
 )
 (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
 (fc): Linear(in_features=512, out_features=1000, bias=True)
)>

我們現(xiàn)在需要做的就是將最后一層進(jìn)行替換

only_train_fc = True
if only_train_fc:
  for param in model.parameters():
    param.requires_grad_(False)
fc_in_features = model.fc.in_features
model.fc = torch.nn.Linear(fc_in_features, 2, bias=True)

注釋:only_train_fc如果我們設(shè)置為True那么就只訓(xùn)練最后的fc層
現(xiàn)在觀察一下可導(dǎo)的參數(shù)有那些(在只訓(xùn)練最后一層的情況下)

for i in model.parameters():
  if i.requires_grad:
    print(i)
Parameter containing:
tensor([[ 0.0342, -0.0336, 0.0279, ..., -0.0428, 0.0421, 0.0366],
    [-0.0162, 0.0286, -0.0379, ..., -0.0203, -0.0016, -0.0440]],
    requires_grad=True)
Parameter containing:
tensor([-0.0120, -0.0086], requires_grad=True)

注釋:由于最后一層使用了bias因此我們會(huì)多加兩個(gè)參數(shù)

訓(xùn)練主體的實(shí)現(xiàn)

epochs = 50
loss_fn = torch.nn.CrossEntropyLoss()
opt = torch.optim.SGD(lr=0.01, params=model.parameters())
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')
model.to(device)
opt_step = torch.optim.lr_scheduler.StepLR(opt, step_size=20, gamma=0.1)
max_acc = 0
epoch_acc = []
epoch_loss = []
for epoch in range(epochs):
  for type_id, loader in enumerate([train_loader, val_loader]):
    mean_loss = []
    mean_acc = []
    for images, labels in loader:
      if type_id == 0:
        # opt_step.step()
        model.train()
      else:
        model.eval()
      images = images.to(device)
      labels = labels.to(device).long()
      opt.zero_grad()
      with torch.set_grad_enabled(type_id==0):
        outputs = model(images)
        _, pre_labels = torch.max(outputs, 1)
        loss = loss_fn(outputs, labels)
      if type_id == 0:
        loss.backward()
        opt.step()
      acc = torch.sum(pre_labels==labels) / torch.tensor(labels.shape[0], dtype=torch.float32)    
      mean_loss.append(loss.cpu().detach().numpy())
      mean_acc.append(acc.cpu().detach().numpy())
    if type_id == 1:
      epoch_acc.append(np.mean(mean_acc))
      epoch_loss.append(np.mean(mean_loss))
      if max_acc < np.mean(mean_acc):
        max_acc = np.mean(mean_acc)
    print(type_id, np.mean(mean_loss),np.mean(mean_acc))
print(max_acc)

在使用cpu訓(xùn)練的情況,也能快速得到較好的結(jié)果,這里訓(xùn)練了50次,其實(shí)很快的就已經(jīng)得到了很好的結(jié)果了

在這里插入圖片描述

總結(jié)

本節(jié)我們使用了預(yù)訓(xùn)練模型,發(fā)現(xiàn)大概10個(gè)epoch就可以很快的得到較好的結(jié)果了,即使在使用cpu情況下訓(xùn)練,這也是遷移學(xué)習(xí)為什么這么受歡迎的原因之一了,如果讀者有興趣可以自己試一試在不凍結(jié)層的情況下,使用方法一能否得到更好的結(jié)果

到此這篇關(guān)于PyTorch 遷移學(xué)習(xí)實(shí)踐(幾分鐘即可訓(xùn)練好自己的模型)的文章就介紹到這了,更多相關(guān)PyTorch 遷移內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

  • python中decimal模塊的用法

    python中decimal模塊的用法

    本文主要介紹了python中decimal模塊的用法,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2023-02-02
  • Python函數(shù)命名空間,作用域LEGB及Global詳析

    Python函數(shù)命名空間,作用域LEGB及Global詳析

    這篇文章主要介紹了Python函數(shù)命名空間,作用域LEGB及Global詳析,文章圍繞主題展開詳細(xì)的內(nèi)容介紹,具有一定的參考價(jià)值,需要的朋友可以參考一下
    2022-09-09
  • 淺析Python的Django框架中的Memcached

    淺析Python的Django框架中的Memcached

    這篇文章主要介紹了淺析Python的Django框架中的緩存機(jī)制,其中著重講到了Memcached,需要的朋友可以參考下
    2015-07-07
  • Jupyter Lab無法打開終端窗口的解決方法

    Jupyter Lab無法打開終端窗口的解決方法

    本文主要介紹了Jupyter Lab無法打開終端窗口的解決方法,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2023-02-02
  • python實(shí)現(xiàn)Windows電腦定時(shí)關(guān)機(jī)

    python實(shí)現(xiàn)Windows電腦定時(shí)關(guān)機(jī)

    這篇文章主要為大家詳細(xì)介紹了python實(shí)現(xiàn)Windows電腦定時(shí)關(guān)機(jī)功能,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2018-06-06
  • Python計(jì)算公交發(fā)車時(shí)間的完整代碼

    Python計(jì)算公交發(fā)車時(shí)間的完整代碼

    這篇文章主要介紹了Python計(jì)算公交發(fā)車時(shí)間的完整代碼,代碼簡單易懂,非常不錯(cuò),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2020-02-02
  • 安裝PyTorch的詳細(xì)過程記錄

    安裝PyTorch的詳細(xì)過程記錄

    PyTorch是一個(gè)基于Python的科學(xué)計(jì)算框架,用于進(jìn)行深度學(xué)習(xí)相關(guān)研究,下面這篇文章主要給大家介紹了關(guān)于安裝PyTorch的詳細(xì)過程,文中通過圖文介紹的非常詳細(xì),需要的朋友可以參考下
    2022-03-03
  • Python實(shí)現(xiàn)Linux下守護(hù)進(jìn)程的編寫方法

    Python實(shí)現(xiàn)Linux下守護(hù)進(jìn)程的編寫方法

    這篇文章主要介紹了Python實(shí)現(xiàn)Linux下守護(hù)進(jìn)程的編寫方法,比較實(shí)用的一個(gè)技巧,需要的朋友可以參考下
    2014-08-08
  • Django中ORM表的創(chuàng)建和增刪改查方法示例

    Django中ORM表的創(chuàng)建和增刪改查方法示例

    這篇文章主要給大家介紹了關(guān)于Django中ORM表的創(chuàng)建和增刪改查等基本操作的方法,還給大家分享了django orm常用查詢篩選的相關(guān)內(nèi)容,分享出來供大家參考學(xué)習(xí),需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧。
    2017-11-11
  • 使用Python Fast API發(fā)布API服務(wù)的過程詳解

    使用Python Fast API發(fā)布API服務(wù)的過程詳解

    這篇文章主要介紹了使用Python Fast API發(fā)布API服務(wù),本文給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2023-04-04

最新評論