欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Python基于Pytorch的特征圖提取實(shí)例

 更新時(shí)間:2022年03月29日 10:00:19   作者:淘啊淘啊淘  
在利用深度學(xué)習(xí)進(jìn)行分類時(shí),有時(shí)需要對(duì)中間的特征圖進(jìn)行提取操作,下面這篇文章主要給大家介紹了關(guān)于Python基于Pytorch的特征圖提取的相關(guān)資料,需要的朋友可以參考下

簡(jiǎn)述

為了方便理解卷積神經(jīng)網(wǎng)絡(luò)的運(yùn)行過(guò)程,需要對(duì)卷積神經(jīng)網(wǎng)絡(luò)的運(yùn)行結(jié)果進(jìn)行可視化的展示。

大致可分為如下步驟:

  • 單個(gè)圖片的提取
  • 神經(jīng)網(wǎng)絡(luò)的構(gòu)建
  • 特征圖的提取
  • 可視化展示

單個(gè)圖片的提取

根據(jù)目標(biāo)要求,需要對(duì)單個(gè)圖片進(jìn)行卷積運(yùn)算,但是Pytorch中讀取數(shù)據(jù)主要用到torch.utils.data.DataLoader類,因此我們需要編寫單個(gè)圖片的讀取程序

def get_picture(picture_dir, transform):
    '''
    該算法實(shí)現(xiàn)了讀取圖片,并將其類型轉(zhuǎn)化為Tensor
    '''
    tmp = []
    img = skimage.io.imread(picture_dir)
    tmp.append(img)
    img = skimage.io.imread('./picture/4.jpg')
    tmp.append(img)
    img256 = [skimage.transform.resize(img, (256, 256)) for img in tmp]
    img256 = np.asarray(img256)
    img256 = img256.astype(np.float32)

    return transform(img256[0])

注意: 神經(jīng)網(wǎng)絡(luò)的輸入是四維形式,我們返回的圖片是三維形式,需要使用unsqueeze()插入一個(gè)維度

神經(jīng)網(wǎng)絡(luò)的構(gòu)建

網(wǎng)絡(luò)的基于LeNet構(gòu)建,不過(guò)為了方便展示,將其中的參數(shù)按照2562563進(jìn)行的參數(shù)的修正

網(wǎng)絡(luò)構(gòu)建如下:

class LeNet(nn.Module):
    '''
    該類繼承了torch.nn.Modul類
    構(gòu)建LeNet神經(jīng)網(wǎng)絡(luò)模型
    '''
    def __init__(self):
        super(LeNet, self).__init__()

        # 第一層神經(jīng)網(wǎng)絡(luò),包括卷積層、線性激活函數(shù)、池化層
        self.conv1 = nn.Sequential( 
            nn.Conv2d(3, 32, 5, 1, 2),   # input_size=(3*256*256),padding=2
            nn.ReLU(),                  # input_size=(32*256*256)
            nn.MaxPool2d(kernel_size=2, stride=2),  # output_size=(32*128*128)
        )

        # 第二層神經(jīng)網(wǎng)絡(luò),包括卷積層、線性激活函數(shù)、池化層
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5, 1, 2),  # input_size=(32*128*128)
            nn.ReLU(),            # input_size=(64*128*128)
            nn.MaxPool2d(2, 2)    # output_size=(64*64*64)
        )

        # 全連接層(將神經(jīng)網(wǎng)絡(luò)的神經(jīng)元的多維輸出轉(zhuǎn)化為一維)
        self.fc1 = nn.Sequential(
            nn.Linear(64 * 64 * 64, 128),  # 進(jìn)行線性變換
            nn.ReLU()                    # 進(jìn)行ReLu激活
        )

        # 輸出層(將全連接層的一維輸出進(jìn)行處理)
        self.fc2 = nn.Sequential(
            nn.Linear(128, 84),
            nn.ReLU()
        )

        # 將輸出層的數(shù)據(jù)進(jìn)行分類(輸出預(yù)測(cè)值)
        self.fc3 = nn.Linear(84, 62)

    # 定義前向傳播過(guò)程,輸入為x
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        # nn.Linear()的輸入輸出都是維度為一的值,所以要把多維度的tensor展平成一維
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

特征圖的提取

直接上代碼:

class FeatureExtractor(nn.Module):
    def __init__(self, submodule, extracted_layers):
        super(FeatureExtractor, self).__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers
 
    def forward(self, x):
        outputs = []
        for name, module in self.submodule._modules.items():
        # 目前不展示全連接層
            if "fc" in name: 
                x = x.view(x.size(0), -1)
            print(module)
            x = module(x)
            print(name)
            if name in self.extracted_layers:
                outputs.append(x)
        return outputs

可視化展示

可視化展示使用matplotlib

代碼如下:

    # 特征輸出可視化
    for i in range(32):
        ax = plt.subplot(6, 6, i + 1)
        ax.set_title('Feature {}'.format(i))
        ax.axis('off')
        plt.imshow(x[0].data.numpy()[0,i,:,:],cmap='jet')
    plt.plot()

完整代碼

在此貼上完整代碼

import os
import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
import skimage.data
import skimage.io
import skimage.transform
import numpy as np
import matplotlib.pyplot as plt

# 定義是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load training and testing datasets.
pic_dir = './picture/3.jpg'

# 定義數(shù)據(jù)預(yù)處理方式(將輸入的類似numpy中arrary形式的數(shù)據(jù)轉(zhuǎn)化為pytorch中的張量(tensor))
transform = transforms.ToTensor()


def get_picture(picture_dir, transform):
    '''
    該算法實(shí)現(xiàn)了讀取圖片,并將其類型轉(zhuǎn)化為Tensor
    '''
    img = skimage.io.imread(picture_dir)
    img256 = skimage.transform.resize(img, (256, 256))
    img256 = np.asarray(img256)
    img256 = img256.astype(np.float32)

    return transform(img256)


def get_picture_rgb(picture_dir):
    '''
    該函數(shù)實(shí)現(xiàn)了顯示圖片的RGB三通道顏色
    '''
    img = skimage.io.imread(picture_dir)
    img256 = skimage.transform.resize(img, (256, 256))
    skimage.io.imsave('./picture/4.jpg',img256)

    # 取單一通道值顯示
    # for i in range(3):
    #     img = img256[:,:,i]
    #     ax = plt.subplot(1, 3, i + 1)
    #     ax.set_title('Feature {}'.format(i))
    #     ax.axis('off')
    #     plt.imshow(img)

    # r = img256.copy()
    # r[:,:,0:2]=0
    # ax = plt.subplot(1, 4, 1)
    # ax.set_title('B Channel')
    # # ax.axis('off')
    # plt.imshow(r)

    # g = img256.copy()
    # g[:,:,0]=0
    # g[:,:,2]=0
    # ax = plt.subplot(1, 4, 2)
    # ax.set_title('G Channel')
    # # ax.axis('off')
    # plt.imshow(g)

    # b = img256.copy()
    # b[:,:,1:3]=0
    # ax = plt.subplot(1, 4, 3)
    # ax.set_title('R Channel')
    # # ax.axis('off')
    # plt.imshow(b)

    # img = img256.copy()
    # ax = plt.subplot(1, 4, 4)
    # ax.set_title('image')
    # # ax.axis('off')
    # plt.imshow(img)

    img = img256.copy()
    ax = plt.subplot()
    ax.set_title('image')
    # ax.axis('off')
    plt.imshow(img)

    plt.show()


class LeNet(nn.Module):
    '''
    該類繼承了torch.nn.Modul類
    構(gòu)建LeNet神經(jīng)網(wǎng)絡(luò)模型
    '''
    def __init__(self):
        super(LeNet, self).__init__()

        # 第一層神經(jīng)網(wǎng)絡(luò),包括卷積層、線性激活函數(shù)、池化層
        self.conv1 = nn.Sequential( 
            nn.Conv2d(3, 32, 5, 1, 2),   # input_size=(3*256*256),padding=2
            nn.ReLU(),                  # input_size=(32*256*256)
            nn.MaxPool2d(kernel_size=2, stride=2),  # output_size=(32*128*128)
        )

        # 第二層神經(jīng)網(wǎng)絡(luò),包括卷積層、線性激活函數(shù)、池化層
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5, 1, 2),  # input_size=(32*128*128)
            nn.ReLU(),            # input_size=(64*128*128)
            nn.MaxPool2d(2, 2)    # output_size=(64*64*64)
        )

        # 全連接層(將神經(jīng)網(wǎng)絡(luò)的神經(jīng)元的多維輸出轉(zhuǎn)化為一維)
        self.fc1 = nn.Sequential(
            nn.Linear(64 * 64 * 64, 128),  # 進(jìn)行線性變換
            nn.ReLU()                    # 進(jìn)行ReLu激活
        )

        # 輸出層(將全連接層的一維輸出進(jìn)行處理)
        self.fc2 = nn.Sequential(
            nn.Linear(128, 84),
            nn.ReLU()
        )

        # 將輸出層的數(shù)據(jù)進(jìn)行分類(輸出預(yù)測(cè)值)
        self.fc3 = nn.Linear(84, 62)

    # 定義前向傳播過(guò)程,輸入為x
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        # nn.Linear()的輸入輸出都是維度為一的值,所以要把多維度的tensor展平成一維
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

# 中間特征提取
class FeatureExtractor(nn.Module):
    def __init__(self, submodule, extracted_layers):
        super(FeatureExtractor, self).__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers
 
    def forward(self, x):
        outputs = []
        print(self.submodule._modules.items())
        for name, module in self.submodule._modules.items():
            if "fc" in name: 
                print(name)
                x = x.view(x.size(0), -1)
            print(module)
            x = module(x)
            print(name)
            if name in self.extracted_layers:
                outputs.append(x)
        return outputs


def get_feature():
    # 輸入數(shù)據(jù)
    img = get_picture(pic_dir, transform)
    # 插入維度
    img = img.unsqueeze(0)

    img = img.to(device)

    # 特征輸出
    net = LeNet().to(device)
    # net.load_state_dict(torch.load('./model/net_050.pth'))
    exact_list = ["conv1","conv2"]
    myexactor = FeatureExtractor(net, exact_list)
    x = myexactor(img)

    # 特征輸出可視化
    for i in range(32):
        ax = plt.subplot(6, 6, i + 1)
        ax.set_title('Feature {}'.format(i))
        ax.axis('off')
        plt.imshow(x[0].data.numpy()[0,i,:,:],cmap='jet')

    plt.show()

# 訓(xùn)練
if __name__ == "__main__":
    get_picture_rgb(pic_dir)
    # get_feature()
    

總結(jié)

到此這篇關(guān)于Python基于Pytorch的特征圖提取的文章就介紹到這了,更多相關(guān)Pytorch特征圖提取內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

  • python實(shí)現(xiàn)Virginia無(wú)密鑰解密

    python實(shí)現(xiàn)Virginia無(wú)密鑰解密

    這篇文章主要為大家詳細(xì)介紹了python實(shí)現(xiàn)Virginia無(wú)密鑰解密,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2019-03-03
  • python 利用openpyxl讀取Excel表格中指定的行或列教程

    python 利用openpyxl讀取Excel表格中指定的行或列教程

    這篇文章主要介紹了python 利用openpyxl讀取Excel表格中指定的行或列教程,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧
    2021-02-02
  • pytorch創(chuàng)建tensor函數(shù)詳情

    pytorch創(chuàng)建tensor函數(shù)詳情

    這篇文章主要介紹了pytorch創(chuàng)建tensor函數(shù)詳情,文章圍繞tensor函數(shù)的相關(guān)自來(lái)哦展開詳細(xì)內(nèi)容的介紹,需要的小伙伴可以參考一下,希望對(duì)你有所幫助
    2022-03-03
  • python pytest進(jìn)階之conftest.py詳解

    python pytest進(jìn)階之conftest.py詳解

    這篇文章主要介紹了python pytest進(jìn)階之conftest.py詳解,如果我們?cè)诰帉憸y(cè)試用的時(shí)候,每一個(gè)測(cè)試文件里面的用例都需要先登錄后才能完成后面的操作,那么們?cè)撊绾螌?shí)現(xiàn)呢?這就需要我們掌握conftest.py文件的使用了,需要的朋友可以參考下
    2019-06-06
  • 用Python制作簡(jiǎn)單的樸素基數(shù)估計(jì)器的教程

    用Python制作簡(jiǎn)單的樸素基數(shù)估計(jì)器的教程

    這篇文章主要介紹了用Python制作簡(jiǎn)單的樸素基數(shù)估計(jì)器的教程,同時(shí)介紹了如何去改進(jìn)精度來(lái)進(jìn)行算法優(yōu)化,需要的朋友可以參考下
    2015-04-04
  • 基于Python的圖像閾值化分割(迭代法)

    基于Python的圖像閾值化分割(迭代法)

    這篇文章主要介紹了基于Python的圖像閾值化分割(迭代法),文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧
    2020-11-11
  • Python讀寫Excel表格的方法

    Python讀寫Excel表格的方法

    這篇文章主要為大家詳細(xì)介紹了Python讀寫Excel表格的方法,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2021-03-03
  • Python并發(fā)編程隊(duì)列與多線程最快發(fā)送http請(qǐng)求方式

    Python并發(fā)編程隊(duì)列與多線程最快發(fā)送http請(qǐng)求方式

    假如有一個(gè)文件,里面有10萬(wàn)個(gè)url,需要對(duì)每個(gè)url發(fā)送http請(qǐng)求,并打印請(qǐng)求結(jié)果的狀態(tài)碼,如何編寫代碼盡可能快的完成這些任務(wù)呢
    2021-09-09
  • 解決python中使用plot畫圖,圖不顯示的問(wèn)題

    解決python中使用plot畫圖,圖不顯示的問(wèn)題

    今天小編就為大家分享一篇解決python中使用plot畫圖,圖不顯示的問(wèn)題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧
    2018-07-07
  • python中l(wèi)ist常用操作實(shí)例詳解

    python中l(wèi)ist常用操作實(shí)例詳解

    這篇文章主要介紹了python中l(wèi)ist常用操作,以實(shí)例形式較為詳細(xì)的分析了列表list中常用的建立、添加、刪除、搜索、過(guò)濾等操作技巧,需要的朋友可以參考下
    2015-06-06

最新評(píng)論