欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Pytorch卷積神經(jīng)網(wǎng)絡(luò)遷移學習的目標及好處

 更新時間:2022年05月12日 14:40:45   作者:淺念念52  
這篇文章主要為大家介紹了Pytorch卷積神經(jīng)網(wǎng)絡(luò)遷移學習的目標實現(xiàn)代碼及好處介紹,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪

前言

在深度學習訓練的過程中,隨著網(wǎng)絡(luò)層數(shù)的提升,我們訓練的次數(shù),參數(shù)都會提高,訓練時間相應(yīng)就會增加,我們今天來了解遷移學習

一、經(jīng)典的卷積神經(jīng)網(wǎng)絡(luò)

在pytorch官網(wǎng)中,我們可以看到許多經(jīng)典的卷積神經(jīng)網(wǎng)絡(luò)。

附官網(wǎng)鏈接:https://pytorch.org/

這里簡單介紹一下經(jīng)典的卷積神經(jīng)發(fā)展歷程

1.首先可以說是卷積神經(jīng)網(wǎng)絡(luò)的開山之作Alexnet(12年的奪冠之作)這里簡單說一下缺點 卷積核大,步長大,沒有填充層,大刀闊斧的提取特征,容易忽略一些重要的特征

2.第二個就是VGG網(wǎng)絡(luò),它的卷積核大小是3*3,有一個優(yōu)點是經(jīng)過池化層之后,通道數(shù)翻倍,可以更多的保留一些特征,這是VGG的一個特點

在接下來的一段時間中,出現(xiàn)了一個問題,我們都知道,深度學習隨著訓練次數(shù)的不斷增加,效果應(yīng)該是越來越好,但是這里出現(xiàn)了一個問題,研究發(fā)現(xiàn)隨著VGG網(wǎng)絡(luò)的不斷提高,效果卻沒有原來的好,這時候人們就認為,深度學習是不是只能發(fā)展到這里了,這時遇到了一個瓶頸。

3.接下來隨著殘差網(wǎng)絡(luò)(Resnet)的提出,解決了上面這個問題,這個網(wǎng)絡(luò)的優(yōu)點是保留了原有的特征,假如經(jīng)過卷積之后提取的特征還沒有原圖的好,這時候保留原有的特征,就會解決這一問題,下面就是resnet網(wǎng)絡(luò)模型

這是一些訓練對比:

二、遷移學習的目標

首先我們使用遷移學習的目標就是用人家訓練好的權(quán)重參數(shù),偏置參數(shù),來訓練我們的模型。

三、好處

深度學習要訓練的數(shù)據(jù)量是很大的,當我們數(shù)據(jù)量少時,我們訓練的權(quán)重參數(shù)就不會那么的好,所以這時候我們就可以使用別人訓練好的權(quán)重參數(shù),偏置參數(shù)來使用,會使我們的模型準確率得到提高

四、步驟

遷移學習大致可以分為三步

1.加載模型

2.凍結(jié)層數(shù)

3.全連接層

五、代碼

這里使用的是resnet152

import torch
import torchvision as tv
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch
from torch.utils import data
from torch import optim
from torch.autograd import Variable
model_name='resnet'
featuer_extract=True
train_on_gpu=torch.cuda.is_available()
if not train_on_gpu:
    print("沒有g(shù)pu")
else :
    print("是gpu")
devic=torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
teature_extract=True
def set_paremeter_requires_grad(model,featuer_extract):
    if featuer_extract:
        for parm in model.parameters():
            parm.requires_grad=False   #不做訓練
def initialize_model(model_name,num_classes,featuer_extract,use_pretrained=True):
    model_ft = None
    input_size = 0
    if model_name=="resnet":
        model_ft=tv.models.resnet152(pretrained=use_pretrained)#下載模型
        set_paremeter_requires_grad(model_ft,featuer_extract) #凍結(jié)層數(shù)
        num_ftrs=model_ft.fc.in_features #改動全連接層
        model_ft.fc=nn.Sequential(nn.Linear(num_ftrs,num_classes),
                                  nn.LogSoftmax(dim=1))
        input_size=224 #輸入維度
    return  model_ft,input_size
model_ft,iput_size=initialize_model(model_name,10,featuer_extract,use_pretrained=True)
model_ft=model_ft.to(devic)
params_to_updata=model_ft.parameters()
if featuer_extract:
    params_to_updata=[]
    for name,param in model_ft.named_parameters():
        if param.requires_grad==True:
            params_to_updata.append(param)
            print("\t",name)
else:
    for name,param in model_ft.parameters():
        if param.requires_grad==True:
            print("\t",name)
opt=optim.Adam(params_to_updata,lr=0.01)
loss=nn.NLLLoss()
if __name__ == '__main__':
    transform = transforms.Compose([
        # 圖像增強
        transforms.Resize(1024),#裁剪
        transforms.RandomHorizontalFlip(),#隨機水平翻轉(zhuǎn)
        transforms.RandomCrop(224),#隨機裁剪
        transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5), #亮度
        # 轉(zhuǎn)變?yōu)閠ensor 正則化
        transforms.ToTensor(), #轉(zhuǎn)換格式
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))  # 歸一化處理
    ])
    trainset = tv.datasets.CIFAR10(
        root=r'E:\桌面\資料\cv3\數(shù)據(jù)集\cifar-10-batches-py',
        train=True,
        download=True,
        transform=transform
    )
    trainloader = data.DataLoader(
        trainset,
        batch_size=8,
        drop_last=True,
        shuffle=True,  # 亂序
        num_workers=4,
    )
    testset = tv.datasets.CIFAR10(
        root=r'E:\桌面\資料\cv3\數(shù)據(jù)集\cifar-10-batches-py',
        train=False,
        download=True,
        transform=transform
    )
    testloader = data.DataLoader(
        testset,
        batch_size=8,
        drop_last=True,
        shuffle=False,
        num_workers=4
    )
    for epoch in range(3):
        running_loss=0
        for index,data in enumerate(trainloader,0):
            inputs, labels = data
            inputs = inputs.to(devic)
            labels = labels.to(devic)
            inputs, labels = Variable(inputs), Variable(labels)
            opt.zero_grad()
            h=model_ft(inputs)
            loss1=loss(h,labels)
            loss1.backward()
            opt.step()
            h+=loss1.item()
            if index%10==9:
                avg_loss=loss1/10.
                running_loss=0
                print('avg_loss',avg_loss)
            if index%100==99 :
                correct=0
                total=0
                for data in testloader:
                    images,labels=data
                    outputs=model_ft(Variable(images.cuda()))
                    _,predicted=torch.max(outputs.cpu(),1)
                    total+=labels.size(0)
                    bool_tensor=(predicted==labels)
                    correct+=bool_tensor.sum()
                print('1000張測試集中的準確率為%d   %%'%(100*correct/total))

以上就是Pytorch卷積神經(jīng)網(wǎng)絡(luò)遷移學習的目標及好處的詳細內(nèi)容,更多關(guān)于Pytorch卷積神經(jīng)網(wǎng)絡(luò)遷移的資料請關(guān)注腳本之家其它相關(guān)文章!

相關(guān)文章

  • 2021年的Python 時間軸和即將推出的功能詳解

    2021年的Python 時間軸和即將推出的功能詳解

    這篇文章主要介紹了2021年的Python 時間軸和即將推出的功能,本文通過實例代碼給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下
    2020-07-07
  • pygame畫點線方法詳解

    pygame畫點線方法詳解

    這篇文章主要介紹了pygame畫點線的方法,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習吧
    2022-11-11
  • 利用python實現(xiàn)JSON文檔與Python對象互相轉(zhuǎn)換

    利用python實現(xiàn)JSON文檔與Python對象互相轉(zhuǎn)換

    這篇文章主要介紹了利用python實現(xiàn)JSON文檔與Python對象互相轉(zhuǎn)換,通過對將一個JSON文檔映射為Python對象問題的展開介紹主題內(nèi)容,需要的朋友可以參考一下
    2022-06-06
  • python 爬蟲 批量獲取代理ip的實例代碼

    python 爬蟲 批量獲取代理ip的實例代碼

    今天小編就為大家分享一篇python 爬蟲 批量獲取代理ip的實例代碼,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2018-05-05
  • python根據(jù)出生日期獲得年齡的方法

    python根據(jù)出生日期獲得年齡的方法

    這篇文章主要介紹了python根據(jù)出生日期獲得年齡的方法,涉及Python操作日期的技巧,具有一定參考借鑒價值,需要的朋友可以參考下
    2015-03-03
  • 解析Python的縮進規(guī)則的使用

    解析Python的縮進規(guī)則的使用

    這篇文章主要介紹了解析Python的縮進規(guī)則的使用,小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧
    2019-01-01
  • Pycharm中SQL語句提示SQL Dialect is Not Configured的解決

    Pycharm中SQL語句提示SQL Dialect is Not Config

    這篇文章主要介紹了Pycharm中SQL語句提示SQL Dialect is Not Configured的解決方案,具有很好的參考價值,希望對大家有所幫助。
    2022-07-07
  • 淺析Django接口版本控制

    淺析Django接口版本控制

    一個項目在升級迭代的時候,不會立馬拋棄舊的版本,甚至會出現(xiàn)多個版本共存同時維護的情況,因此需要版本控制
    2021-06-06
  • python?列表常用方法超詳細梳理總結(jié)

    python?列表常用方法超詳細梳理總結(jié)

    這篇文章主要為大家介紹了Python中列表的幾個常用方法總結(jié),文中的示例代碼講解詳細,對我們學習Python列表有一定幫助,需要的可以參考一下
    2022-03-03
  • python數(shù)據(jù)可視化使用pyfinance分析證券收益示例詳解

    python數(shù)據(jù)可視化使用pyfinance分析證券收益示例詳解

    這篇文章主要為大家介紹了python數(shù)據(jù)可視化使用pyfinance分析證券收益的示例詳解及pyfinance中returns模塊的應(yīng)用,有需要的朋友可以借鑒參考下,希望能夠有所幫助
    2021-11-11

最新評論