PyTorch模型容器與AlexNet構建示例詳解
模型容器與AlexNet構建
文章和代碼已經(jīng)歸檔至【Github倉庫:https://github.com/timerring/dive-into-AI 】
除了上述的模塊之外,還有一個重要的概念是模型容器 (Containers),常用的容器有 3 個,這些容器都是繼承自nn.Module
。
- nn.Sequetial:按照順序包裝多個網(wǎng)絡層
- nn.ModuleList:像 python 的 list 一樣包裝多個網(wǎng)絡層,可以迭代
- nn.ModuleDict:像 python 的 dict 一樣包裝多個網(wǎng)絡層,通過 (key, value) 的方式為每個網(wǎng)絡層指定名稱。
nn.Sequetial
深度學習中,特征提取和分類器這兩步被融合到了一個神經(jīng)網(wǎng)絡中。在卷積神經(jīng)網(wǎng)絡中,前面的卷積層以及池化層可以認為是特征提取部分,而后面的全連接層可以認為是分類器部分。比如 LeNet 就可以分為特征提取和分類器兩部分,這 2 部分都可以分別使用 nn.Seuqtial
來包裝。
代碼如下:
class LeNetSequetial(nn.Module): def __init__(self, classes): super(LeNet2, self).__init__() self.features = nn.Sequential( nn.Conv2d(3, 6, 5), nn.ReLU(), nn.MaxPool2d(2, 2), nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(2, 2) ) self.classifier = nn.Sequential( nn.Linear(16*5*5, 120), nn.ReLU(), nn.Linear(120, 84), nn.ReLU(), nn.Linear(84, classes) ) def forward(self, x): x = self.features(x) x = x.view(x.size()[0], -1) x = self.classifier(x) return x
在初始化時,nn.Sequetial
會調(diào)用__init__()
方法,將每一個子 module 添加到 自身的_modules
屬性中。這里可以看到,我們傳入的參數(shù)可以是一個 list,或者一個 OrderDict。如果是一個 OrderDict,那么則使用 OrderDict 里的 key,否則使用數(shù)字作為 key。
def __init__(self, *args): super(Sequential, self).__init__() if len(args) == 1 and isinstance(args[0], OrderedDict): for key, module in args[0].items(): self.add_module(key, module) else: for idx, module in enumerate(args): self.add_module(str(idx), module)
網(wǎng)絡初始化完成后有兩個子 module
:features
和classifier
。
而features
中的子 module 如下,每個網(wǎng)絡層以序號作為 key:
在進行前向傳播時,會進入 LeNet 的forward()
函數(shù),首先調(diào)用第一個Sequetial
容器:self.features
,由于self.features
也是一個 module,因此會調(diào)用__call__()
函數(shù),里面調(diào)用
result = self.forward(*input, **kwargs)
,進入nn.Seuqetial
的forward()
函數(shù),在這里依次調(diào)用所有的 module。上一個module的輸出是下一個module的輸入。
def forward(self, input): for module in self: input = module(input) return input
在上面可以看到在nn.Sequetial
中,里面的每個子網(wǎng)絡層 module 是使用序號來索引的,即使用數(shù)字來作為key。
一旦網(wǎng)絡層增多,難以查找特定的網(wǎng)絡層,這種情況可以使用 OrderDict (有序字典)??梢耘c上面的代碼對比一下
class LeNetSequentialOrderDict(nn.Module): def __init__(self, classes): super(LeNetSequentialOrderDict, self).__init__() self.features = nn.Sequential(OrderedDict({ 'conv1': nn.Conv2d(3, 6, 5), 'relu1': nn.ReLU(inplace=True), 'pool1': nn.MaxPool2d(kernel_size=2, stride=2), 'conv2': nn.Conv2d(6, 16, 5), 'relu2': nn.ReLU(inplace=True), 'pool2': nn.MaxPool2d(kernel_size=2, stride=2), })) self.classifier = nn.Sequential(OrderedDict({ 'fc1': nn.Linear(16*5*5, 120), 'relu3': nn.ReLU(), 'fc2': nn.Linear(120, 84), 'relu4': nn.ReLU(inplace=True), 'fc3': nn.Linear(84, classes), })) ... ... ...
總結(jié)
nn.Sequetial
是nn.Module
的容器,用于按順序包裝一組網(wǎng)絡層,有以下兩個特性。
- 順序性:各網(wǎng)絡層之間嚴格按照順序構建,我們在構建網(wǎng)絡時,一定要注意前后網(wǎng)絡層之間輸入和輸出數(shù)據(jù)之間的形狀是否匹配
- 自帶
forward()
函數(shù):在nn.Sequetial
的forward()
函數(shù)里通過 for 循環(huán)依次讀取每個網(wǎng)絡層,執(zhí)行前向傳播運算。這使得我們我們構建的模型更加簡潔
nn.ModuleList
nn.ModuleList
是nn.Module
的容器,用于包裝一組網(wǎng)絡層,以迭代的方式調(diào)用網(wǎng)絡層,主要有以下 3 個方法:
- append():在 ModuleList 后面添加網(wǎng)絡層
- extend():拼接兩個 ModuleList
- insert():在 ModuleList 的指定位置中插入網(wǎng)絡層
下面的代碼通過列表生成式來循環(huán)迭代創(chuàng)建 20 個全連接層,非常方便,只是在 forward()
函數(shù)中需要手動調(diào)用每個網(wǎng)絡層。
class ModuleList(nn.Module): def __init__(self): super(ModuleList, self).__init__() self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(20)]) def forward(self, x): for i, linear in enumerate(self.linears): x = linear(x) return x net = ModuleList() print(net) fake_data = torch.ones((10, 10)) output = net(fake_data) print(output)
nn.ModuleDict
nn.ModuleDict
是nn.Module
的容器,用于包裝一組網(wǎng)絡層,以索引的方式調(diào)用網(wǎng)絡層,主要有以下 5 個方法:
- clear():清空 ModuleDict
- items():返回可迭代的鍵值對 (key, value)
- keys():返回字典的所有 key
- values():返回字典的所有 value
- pop():返回一對鍵值,并從字典中刪除
下面的模型創(chuàng)建了兩個ModuleDict
:self.choices
和self.activations
,在前向傳播時通過傳入對應的 key 來執(zhí)行對應的網(wǎng)絡層。
class ModuleDict(nn.Module): def __init__(self): super(ModuleDict, self).__init__() self.choices = nn.ModuleDict({ 'conv': nn.Conv2d(10, 10, 3), 'pool': nn.MaxPool2d(3) }) self.activations = nn.ModuleDict({ 'relu': nn.ReLU(), 'prelu': nn.PReLU() }) def forward(self, x, choice, act): x = self.choices[choice](x) x = self.activations[act](x) return x net = ModuleDict() fake_img = torch.randn((4, 10, 32, 32)) output = net(fake_img, 'conv', 'relu') # output = net(fake_img, 'conv', 'prelu') print(output)
容器總結(jié)
- nn.Sequetial:順序性,各網(wǎng)絡層之間嚴格按照順序執(zhí)行,常用于 block 構建,在前向傳播時的代碼調(diào)用變得簡潔
- nn.ModuleList:迭代行,常用于大量重復網(wǎng)絡構建,通過 for 循環(huán)實現(xiàn)重復構建
- nn.ModuleDict:索引性,常用于可選擇的網(wǎng)絡層
AlexNet實現(xiàn)
AlexNet 特點如下:
- 采用 ReLU 替換飽和激活函數(shù),減輕梯度消失
- 采用 LRN (Local Response Normalization) 對數(shù)據(jù)進行局部歸一化,減輕梯度消失
- 采用 Dropout 提高網(wǎng)絡的魯棒性,增加泛化能力
- 使用 Data Augmentation,包括 TenCrop 和一些色彩修改
AlexNet 的網(wǎng)絡結(jié)構可以分為兩部分:features 和 classifier。
可以在計算機視覺庫torchvision.models
中找到 AlexNet 的代碼,通過看可知使用了nn.Sequential
來封裝網(wǎng)絡層。
class AlexNet(nn.Module): def __init__(self, num_classes=1000): super(AlexNet, self).__init__() self.features = nn.Sequential( nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=3, stride=2), nn.Conv2d(64, 192, kernel_size=5, padding=2), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=3, stride=2), nn.Conv2d(192, 384, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=3, stride=2), ) self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) self.classifier = nn.Sequential( nn.Dropout(), nn.Linear(256 * 6 * 6, 4096), nn.ReLU(inplace=True), nn.Dropout(), nn.Linear(4096, 4096), nn.ReLU(inplace=True), nn.Linear(4096, num_classes), ) def forward(self, x): x = self.features(x) x = self.avgpool(x) x = torch.flatten(x, 1) x = self.classifier(x) return x
以上就是PyTorch模型容器與AlexNet構建示例詳解的詳細內(nèi)容,更多關于PyTorch AlexNet構建的資料請關注腳本之家其它相關文章!
相關文章
在Django中管理Users和Permissions以及Groups的方法
這篇文章主要介紹了在Django中管理Users和Permissions以及Groups的方法,Django是最具人氣的Python web開發(fā)框架,需要的朋友可以參考下2015-07-07Python使用shutil操作文件、subprocess運行子程序
這篇文章介紹了Python使用shutil操作文件、subprocess運行子程序的方法,文中通過示例代碼介紹的非常詳細。對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下2022-05-05