欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

利用Pytorch實現(xiàn)獲取特征圖的方法詳解

 更新時間:2022年10月13日 11:41:55   作者:拜陽  
這篇文章主要為大家詳細介紹了如何利用Pytorch實現(xiàn)獲取特征圖,包括提取單個特征圖和提取多個特征圖,文中的示例代碼講解詳細,需要的可以參考一下

簡單加載官方預(yù)訓(xùn)練模型

torchvision.models預(yù)定義了很多公開的模型結(jié)構(gòu)

如果pretrained參數(shù)設(shè)置為False,那么僅僅設(shè)定模型結(jié)構(gòu);如果設(shè)置為True,那么會啟動一個下載流程,下載預(yù)訓(xùn)練參數(shù)

如果只想調(diào)用模型,不想訓(xùn)練,那么設(shè)置model.eval()和model.requires_grad_(False)

想查看模型參數(shù)可以使用modules和named_modules,其中named_modules是一個長度為2的tuple,第一個變量是name,第二個變量是module本身。

# -*- coding: utf-8 -*-
from torch import nn
from torchvision import models

# load model. If pretrained is True, there will be a downloading process
model = models.vgg19(pretrained=True)
model.eval()
model.requires_grad_(False)

# get model component
features = model.features
modules = features.modules()
named_modules = features.named_modules()

# print modules
for module in modules:
    if isinstance(module, nn.Conv2d):
        weight = module.weight
        bias = module.bias
        print(module, weight.shape, bias.shape,
              weight.requires_grad, bias.requires_grad)
    elif isinstance(module, nn.ReLU):
        print(module)

print()
for named_module in named_modules:
    name = named_module[0]
    module = named_module[1]
    if isinstance(module, nn.Conv2d):
        weight = module.weight
        bias = module.bias
        print(name, module, weight.shape, bias.shape,
              weight.requires_grad, bias.requires_grad)
    elif isinstance(module, nn.ReLU):
        print(name, module)

圖片預(yù)處理

使用opencv和pil讀圖都可以使用transforms.ToTensor()把原本[H, W, 3]的數(shù)據(jù)轉(zhuǎn)成[3, H, W]的tensor。但opencv要注意把數(shù)據(jù)改成RGB順序。

vgg系列模型需要做normalization,建議配合torchvision.transforms來實現(xiàn)。

mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].

參考:https://pytorch.org/hub/pytorch_vision_vgg/

# -*- coding: utf-8 -*-
from PIL import Image
import cv2
import torch
from torchvision import transforms

# transforms for preprocess
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# load image using cv2
image_cv2 = cv2.imread('lena_std.bmp')
image_cv2 = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
image_cv2 = preprocess(image_cv2)

# load image using pil
image_pil = Image.open('lena_std.bmp')
image_pil = preprocess(image_pil)

# check whether image_cv2 and image_pil are same
print(torch.all(image_cv2 == image_pil))
print(image_cv2.shape, image_pil.shape)

提取單個特征圖

如果只提取單層特征圖,可以把模型截斷,以節(jié)省算力和顯存消耗。

下面索引之所以有+1是因為pytorch預(yù)訓(xùn)練模型里面第一個索引的module總是完整模塊結(jié)構(gòu),第二個才開始子模塊。

# -*- coding: utf-8 -*-
from PIL import Image
from torchvision import models
from torchvision import transforms

# load model. If pretrained is True, there will be a downloading process
model = models.vgg19(pretrained=True)
model = model.features[:16 + 1]  # 16 = conv3_4
model.eval()
model.requires_grad_(False)
model.to('cuda')
print(model)

# load and preprocess image
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
    transforms.Resize(size=(224, 224))
])
image = Image.open('lena_std.bmp')
image = preprocess(image)
inputs = image.unsqueeze(0)  # add batch dimension
inputs = inputs.cuda()

# forward
output = model(inputs)
print(output.shape)

提取多個特征圖

第一種方式:逐層運行model,如果碰到了需要保存的feature map就存下來。

第二種方式:使用register_forward_hook,使用這種方式需要用一個類把feature map以成員變量的形式緩存下來。

兩種方式的運行效率差不多

第一種方式簡單直觀,但是只能處理類似VGG這種沒有跨層連接的網(wǎng)絡(luò);第二種方式更加通用。

# -*- coding: utf-8 -*-
from PIL import Image
import torch
from torchvision import models
from torchvision import transforms

# load model. If pretrained is True, there will be a downloading process
model = models.vgg19(pretrained=True)
model = model.features[:16 + 1]  # 16 = conv3_4
model.eval()
model.requires_grad_(False)
model.to('cuda')

# check module name
for named_module in model.named_modules():
    name = named_module[0]
    module = named_module[1]
    print('-------- %s --------' % name)
    print(module)
    print()

# load and preprocess image
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
    transforms.Resize(size=(224, 224))
])
image = Image.open('lena_std.bmp')
image = preprocess(image)
inputs = image.unsqueeze(0)  # add batch dimension
inputs = inputs.cuda()

# forward - 1
layers = [2, 7, 8, 9, 16]
layers = sorted(set(layers))
feature_maps = {}
feature = inputs
for i in range(max(layers) + 1):
    feature = model[i](feature)
    if i in layers:
        feature_maps[i] = feature
for key in feature_maps:
    print(key, feature_maps.get(key).shape)


# forward - 2
class FeatureHook:
    def __init__(self, module):
        self.inputs = None
        self.output = None
        self.hook = module.register_forward_hook(self.get_features)

    def get_features(self, module, inputs, output):
        self.inputs = inputs
        self.output = output


layer_names = ['2', '7', '8', '9', '16']
hook_modules = []
for named_module in model.named_modules():
    name = named_module[0]
    module = named_module[1]
    if name in layer_names:
        hook_modules.append(module)

hooks = [FeatureHook(module) for module in hook_modules]
output = model(inputs)
features = [hook.output for hook in hooks]
for feature in features:
    print(feature.shape)

# check correctness
for i, layer in enumerate(layers):
    feature1 = feature_maps.get(layer)
    feature2 = features[i]
    print(torch.all(feature1 == feature2))

使用第二種方式(register_forward_hook),resnet特征圖也可以順利拿到。

而由于resnet的model已經(jīng)不可以用model[i]的形式索引,所以無法使用第一種方式。

# -*- coding: utf-8 -*-
from PIL import Image
from torchvision import models
from torchvision import transforms

# load model. If pretrained is True, there will be a downloading process
model = models.resnet18(pretrained=True)
model.eval()
model.requires_grad_(False)
model.to('cuda')

# check module name
for named_module in model.named_modules():
    name = named_module[0]
    module = named_module[1]
    print('-------- %s --------' % name)
    print(module)
    print()

# load and preprocess image
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
    transforms.Resize(size=(224, 224))
])
image = Image.open('lena_std.bmp')
image = preprocess(image)
inputs = image.unsqueeze(0)  # add batch dimension
inputs = inputs.cuda()


class FeatureHook:
    def __init__(self, module):
        self.inputs = None
        self.output = None
        self.hook = module.register_forward_hook(self.get_features)

    def get_features(self, module, inputs, output):
        self.inputs = inputs
        self.output = output


layer_names = [
    'conv1',
    'layer1.0.relu',
    'layer2.0.conv1'
]

hook_modules = []
for named_module in model.named_modules():
    name = named_module[0]
    module = named_module[1]
    if name in layer_names:
        hook_modules.append(module)

hooks = [FeatureHook(module) for module in hook_modules]
output = model(inputs)
features = [hook.output for hook in hooks]
for feature in features:
    print(feature.shape)

問題來了,resnet這種類型的網(wǎng)絡(luò)結(jié)構(gòu)怎么截斷?

使用如下命令就可以,print查看需要截斷到哪里,然后用nn.Sequential重組即可。

需注意重組后網(wǎng)絡(luò)的module_name會發(fā)生變化。

print(list(model.children())
model = torch.nn.Sequential(*list(model.children())[:6])

以上就是利用Pytorch實現(xiàn)獲取特征圖的方法詳解的詳細內(nèi)容,更多關(guān)于Pytorch獲取特征圖的資料請關(guān)注腳本之家其它相關(guān)文章!

相關(guān)文章

  • python--shutil移動文件到另一個路徑的操作

    python--shutil移動文件到另一個路徑的操作

    這篇文章主要介紹了python--shutil移動文件到另一個路徑的操作,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2020-07-07
  • python字符串基礎(chǔ)操作詳解

    python字符串基礎(chǔ)操作詳解

    這篇文章主要為大家詳細介紹了python字符串基礎(chǔ)操作,,文中示例代碼介紹的非常詳細,具有一定的參考價值,感興趣的小伙伴們可以參考一下,希望能夠給你帶來幫助
    2022-01-01
  • python樹的同構(gòu)學(xué)習(xí)筆記

    python樹的同構(gòu)學(xué)習(xí)筆記

    在本篇文章里小編給大家整理的是一篇關(guān)于python樹的同構(gòu)學(xué)習(xí)筆記以及相關(guān)實例代碼內(nèi)容,有需要的朋友們學(xué)習(xí)下。
    2019-09-09
  • Python?時間操作datetime詳情

    Python?時間操作datetime詳情

    這篇文章主要介紹了?Python?時間操作datetime,datetime?模塊提供處理時間和日期的多種類,簡單方便,下面文章將詳細介紹其內(nèi)容,需要的朋友可以參考一下
    2021-11-11
  • 3分鐘看懂Python后端必須知道的Django的信號機制

    3分鐘看懂Python后端必須知道的Django的信號機制

    這篇文章主要介紹了3分鐘看懂Python后端必須知道的Django的信號機制,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2020-07-07
  • python3中sort和sorted使用與區(qū)別

    python3中sort和sorted使用與區(qū)別

    python3中sort()和sorted()都可以用來排序,本文主要介紹了python3中sort和sorted使用與區(qū)別,具有一定的參考價值,感興趣的可以了解一下
    2024-02-02
  • python35種繪圖函數(shù)詳細總結(jié)

    python35種繪圖函數(shù)詳細總結(jié)

    Python有許多用于繪圖的函數(shù)和庫,比如Matplotlib,Plotly,Bokeh,Seaborn等,這只是一些常用的繪圖函數(shù)和庫,Python還有其他繪圖工具,如Pandas、ggplot等,選擇適合你需求的庫,可以根據(jù)你的數(shù)據(jù)類型、圖形需求和個人偏好來決定,本文給大家總結(jié)了python35種繪圖函數(shù)
    2023-08-08
  • 讓python json encode datetime類型

    讓python json encode datetime類型

    python2.6+ 自帶的json模塊,不支持datetime的json encode,每次都需要手動轉(zhuǎn)為字符串,很累人,我們可以自己封裝一個簡單的方法處理此問題。
    2010-12-12
  • python3.9不支持pillow包解決辦法

    python3.9不支持pillow包解決辦法

    本文主要介紹了python3.9不支持pillow包解決辦法,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2024-06-06
  • Python多進程multiprocessing、進程池用法實例分析

    Python多進程multiprocessing、進程池用法實例分析

    這篇文章主要介紹了Python多進程multiprocessing、進程池用法,結(jié)合實例形式分析了Python多進程multiprocessing、進程池相關(guān)概念、原理、用法及操作注意事項,需要的朋友可以參考下
    2020-03-03

最新評論