python人工智能使用RepVgg實(shí)現(xiàn)圖像分類示例詳解
摘要
RepVgg通過(guò)結(jié)構(gòu)重參數(shù)化讓VGG再次偉大。 所謂“VGG式”指的是:
- 沒(méi)有任何分支結(jié)構(gòu)。即通常所說(shuō)的plain或feed-forward架構(gòu)。
- 僅使用3x3卷積。
- 僅使用ReLU作為激活函數(shù)。
RepVGG的更深版本達(dá)到了84.16%正確率!反超若干transformer!
RepVgg是如何到的呢?簡(jiǎn)單地說(shuō)就是:
- 首先, 訓(xùn)練一個(gè)多分支模型
- 然后,將多分支模型等價(jià)轉(zhuǎn)換為單路模型
- 最在,在部署的時(shí)候,部署轉(zhuǎn)換后單路模型
我這篇文章主要講解如何使用RepVgg完成圖像分類任務(wù),接下來(lái)我們一起完成項(xiàng)目的實(shí)戰(zhàn)。

通過(guò)這篇文章能讓你學(xué)到:
- 如何使用數(shù)據(jù)增強(qiáng),包括transforms的增強(qiáng)、CutOut、MixUp、CutMix等增強(qiáng)手段?
- 如何實(shí)現(xiàn)RepVGG模型實(shí)現(xiàn)訓(xùn)練?
- 如何將多分支模型等價(jià)轉(zhuǎn)換為單路模型?
- 如何使用pytorch自帶混合精度?
- 如何使用梯度裁剪防止梯度爆炸?
- 如何使用DP多顯卡訓(xùn)練?
- 如何繪制loss和acc曲線?
- 如何生成val的測(cè)評(píng)報(bào)告?
- 如何編寫(xiě)測(cè)試腳本測(cè)試測(cè)試集?
- 如何使用余弦退火策略調(diào)整學(xué)習(xí)率?
- 如何使用AverageMeter類統(tǒng)計(jì)ACC和loss等自定義變量?
- 如何理解和統(tǒng)計(jì)ACC1和ACC5?
- 如何使用EMA?
安裝包
安裝timm
使用pip就行,命令:
pip install timm
數(shù)據(jù)增強(qiáng)Cutout和Mixup
為了提高成績(jī)我在代碼中加入Cutout和Mixup這兩種增強(qiáng)方式。實(shí)現(xiàn)這兩種增強(qiáng)需要安裝torchtoolbox。安裝命令:
pip install torchtoolbox
Cutout實(shí)現(xiàn),在transforms中。
from torchtoolbox.transform import Cutout
# 數(shù)據(jù)預(yù)處理
transform = transforms.Compose([
transforms.Resize((224, 224)),
Cutout(),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
需要導(dǎo)入包:from timm.data.mixup import Mixup,
定義Mixup,和SoftTargetCrossEntropy
mixup_fn = Mixup(
mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
prob=0.1, switch_prob=0.5, mode='batch',
label_smoothing=0.1, num_classes=12)
criterion_train = SoftTargetCrossEntropy()
參數(shù)詳解:
mixup_alpha (float): mixup alpha 值,如果 > 0,則 mixup 處于活動(dòng)狀態(tài)。
cutmix_alpha (float):cutmix alpha 值,如果 > 0,cutmix 處于活動(dòng)狀態(tài)。
cutmix_minmax (List[float]):cutmix 最小/最大圖像比率,cutmix 處于活動(dòng)狀態(tài),如果不是 None,則使用這個(gè) vs alpha。
如果設(shè)置了 cutmix_minmax 則cutmix_alpha 默認(rèn)為1.0
prob (float): 每批次或元素應(yīng)用 mixup 或 cutmix 的概率。
switch_prob (float): 當(dāng)兩者都處于活動(dòng)狀態(tài)時(shí)切換cutmix 和mixup 的概率 。
mode (str): 如何應(yīng)用 mixup/cutmix 參數(shù)(每個(gè)'batch','pair'(元素對(duì)),'elem'(元素)。
correct_lam (bool): 當(dāng) cutmix bbox 被圖像邊框剪裁時(shí)應(yīng)用。 lambda 校正
label_smoothing (float):將標(biāo)簽平滑應(yīng)用于混合目標(biāo)張量。
num_classes (int): 目標(biāo)的類數(shù)。
EMA
EMA(Exponential Moving Average)是指數(shù)移動(dòng)平均值。在深度學(xué)習(xí)中的做法是保存歷史的一份參數(shù),在一定訓(xùn)練階段后,拿歷史的參數(shù)給目前學(xué)習(xí)的參數(shù)做一次平滑。具體實(shí)現(xiàn)如下:
class EMA():
def __init__(self, model, decay):
self.model = model
self.decay = decay
self.shadow = {}
self.backup = {}
def register(self):
for name, param in self.model.named_parameters():
if param.requires_grad:
self.shadow[name] = param.data.clone()
def update(self):
for name, param in self.model.named_parameters():
if param.requires_grad:
assert name in self.shadow
new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
self.shadow[name] = new_average.clone()
def apply_shadow(self):
for name, param in self.model.named_parameters():
if param.requires_grad:
assert name in self.shadow
self.backup[name] = param.data
param.data = self.shadow[name]
def restore(self):
for name, param in self.model.named_parameters():
if param.requires_grad:
assert name in self.backup
param.data = self.backup[name]
self.backup = {}
加入到模型中。
# 初始化
ema = EMA(model, 0.999)
ema.register()
# 訓(xùn)練過(guò)程中,更新完參數(shù)后,同步update shadow weights
def train():
optimizer.step()
ema.update()
# eval前,apply shadow weights;eval之后,恢復(fù)原來(lái)模型的參數(shù)
def evaluate():
ema.apply_shadow()
# evaluate
ema.restore()
這個(gè)ema最好放在微調(diào)的時(shí)候使用,否則驗(yàn)證集不上分,或者上分很慢。
項(xiàng)目結(jié)構(gòu)
RepVgg_demo ├─data1 │ ├─Black-grass │ ├─Charlock │ ├─Cleavers │ ├─Common Chickweed │ ├─Common wheat │ ├─Fat Hen │ ├─Loose Silky-bent │ ├─Maize │ ├─Scentless Mayweed │ ├─Shepherds Purse │ ├─Small-flowered Cranesbill │ └─Sugar beet ├─models │ ├─__init__.py │ ├─repvgg.py │ └─se_block.py ├─mean_std.py ├─makedata.py ├─ema.py ├─train.py └─test.py
mean_std.py:計(jì)算mean和std的值。 makedata.py:生成數(shù)據(jù)集。 ema.py:EMA腳本 models文件夾下的repvgg.py和se_block.py:來(lái)自官方的pytorch版本的代碼。 - repvgg.py:網(wǎng)絡(luò)文件。 - se_block.py:SE注意力機(jī)制。
為了能在DP方式中使用混合精度,還需要在模型的forward函數(shù)前增加@autocast()。

計(jì)算mean和std
為了使模型更加快速的收斂,我們需要計(jì)算出mean和std的值,新建mean_std.py,插入代碼:
from torchvision.datasets import ImageFolder
import torch
from torchvision import transforms
def get_mean_and_std(train_data):
train_loader = torch.utils.data.DataLoader(
train_data, batch_size=1, shuffle=False, num_workers=0,
pin_memory=True)
mean = torch.zeros(3)
std = torch.zeros(3)
for X, _ in train_loader:
for d in range(3):
mean[d] += X[:, d, :, :].mean()
std[d] += X[:, d, :, :].std()
mean.div_(len(train_data))
std.div_(len(train_data))
return list(mean.numpy()), list(std.numpy())
if __name__ == '__main__':
train_dataset = ImageFolder(root=r'data1', transform=transforms.ToTensor())
print(get_mean_and_std(train_dataset))
數(shù)據(jù)集結(jié)構(gòu):

運(yùn)行結(jié)果:
([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])
把這個(gè)結(jié)果記錄下來(lái),后面要用!
生成數(shù)據(jù)集
我們整理還的圖像分類的數(shù)據(jù)集結(jié)構(gòu)是這樣的
data ├─Black-grass ├─Charlock ├─Cleavers ├─Common Chickweed ├─Common wheat ├─Fat Hen ├─Loose Silky-bent ├─Maize ├─Scentless Mayweed ├─Shepherds Purse ├─Small-flowered Cranesbill └─Sugar beet
pytorch和keras默認(rèn)加載方式是ImageNet數(shù)據(jù)集格式,格式是
├─data │ ├─val │ │ ├─Black-grass │ │ ├─Charlock │ │ ├─Cleavers │ │ ├─Common Chickweed │ │ ├─Common wheat │ │ ├─Fat Hen │ │ ├─Loose Silky-bent │ │ ├─Maize │ │ ├─Scentless Mayweed │ │ ├─Shepherds Purse │ │ ├─Small-flowered Cranesbill │ │ └─Sugar beet │ └─train │ ├─Black-grass │ ├─Charlock │ ├─Cleavers │ ├─Common Chickweed │ ├─Common wheat │ ├─Fat Hen │ ├─Loose Silky-bent │ ├─Maize │ ├─Scentless Mayweed │ ├─Shepherds Purse │ ├─Small-flowered Cranesbill │ └─Sugar beet
新增格式轉(zhuǎn)化腳本makedata.py,插入代碼:
import glob
import os
import shutil
image_list=glob.glob('data1/*/*.png')
print(image_list)
file_dir='data'
if os.path.exists(file_dir):
print('true')
#os.rmdir(file_dir)
shutil.rmtree(file_dir)#刪除再建立
os.makedirs(file_dir)
else:
os.makedirs(file_dir)
from sklearn.model_selection import train_test_split
trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir='train'
val_dir='val'
train_root=os.path.join(file_dir,train_dir)
val_root=os.path.join(file_dir,val_dir)
for file in trainval_files:
file_class=file.replace("\\","/").split('/')[-2]
file_name=file.replace("\\","/").split('/')[-1]
file_class=os.path.join(train_root,file_class)
if not os.path.isdir(file_class):
os.makedirs(file_class)
shutil.copy(file, file_class + '/' + file_name)
for file in val_files:
file_class=file.replace("\\","/").split('/')[-2]
file_name=file.replace("\\","/").split('/')[-1]
file_class=os.path.join(val_root,file_class)
if not os.path.isdir(file_class):
os.makedirs(file_class)
shutil.copy(file, file_class + '/' + file_name)
完成上面的內(nèi)容就可以開(kāi)啟訓(xùn)練和測(cè)試了。
以上就是python人工智能使用RepVgg實(shí)現(xiàn)圖像分類示例詳解的詳細(xì)內(nèi)容,更多關(guān)于python人工智能RepVgg圖像分類的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
- Python使用?TCP協(xié)議實(shí)現(xiàn)智能聊天機(jī)器人功能
- 16行Python代碼實(shí)現(xiàn)微信聊天機(jī)器人并自動(dòng)智能回復(fù)功能
- python實(shí)現(xiàn)AI聊天機(jī)器人詳解流程
- python機(jī)器學(xué)習(xí)創(chuàng)建基于規(guī)則聊天機(jī)器人過(guò)程示例詳解
- python人工智能算法之線性回歸實(shí)例
- python人工智能算法之決策樹(shù)流程示例詳解
- python人工智能算法之人工神經(jīng)網(wǎng)絡(luò)
- python人工智能自定義求導(dǎo)tf_diffs詳解
- Python人工智能構(gòu)建簡(jiǎn)單聊天機(jī)器人示例詳解
相關(guān)文章
Python使用Pillow實(shí)現(xiàn)圖像基本變化
這篇文章主要為大家詳細(xì)介紹了Python如何使用Pillow實(shí)現(xiàn)圖像的基本變化處理,文中的示例代碼講解詳細(xì),具有一定的學(xué)習(xí)價(jià)值,需要的可以了解一下2022-10-10
帶你學(xué)習(xí)Python如何實(shí)現(xiàn)回歸樹(shù)模型
這篇文章主要介紹了Python如何實(shí)現(xiàn)回歸樹(shù)模型,文中講解非常細(xì)致,幫助大家更好的理解和學(xué)習(xí),感興趣的朋友可以了解下2020-07-07
轉(zhuǎn)換科學(xué)計(jì)數(shù)法的數(shù)值字符串為decimal類型的方法
今天小編就為大家分享一篇轉(zhuǎn)換科學(xué)計(jì)數(shù)法的數(shù)值字符串為decimal類型的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-07-07
Python爬蟲(chóng):url中帶字典列表參數(shù)的編碼轉(zhuǎn)換方法
今天小編就為大家分享一篇Python爬蟲(chóng):url中帶字典列表參數(shù)的編碼轉(zhuǎn)換方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-08-08
python實(shí)現(xiàn)模擬數(shù)字的魔術(shù)游戲
這篇文章介紹了python實(shí)現(xiàn)模擬數(shù)字的魔術(shù)游戲,小編覺(jué)得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧2021-12-12
python利用 keyboard 庫(kù)記錄鍵盤(pán)事件
這篇文章主要介紹了python利用 keyboard 庫(kù)記錄鍵盤(pán)事件,幫助大家更好的利用python進(jìn)行辦公,感興趣的朋友可以了解下2020-10-10
YOLOv5在圖片上顯示統(tǒng)計(jì)出單一檢測(cè)目標(biāo)的個(gè)數(shù)實(shí)例代碼
各位讀者首先要認(rèn)識(shí)到的問(wèn)題是,在YOLOv5中完成錨框計(jì)數(shù)是一件非常簡(jiǎn)單的工作,下面這篇文章主要給大家介紹了關(guān)于YOLOv5如何在圖片上顯示統(tǒng)計(jì)出單一檢測(cè)目標(biāo)的個(gè)數(shù)的相關(guān)資料,需要的朋友可以參考下2023-03-03
Python通用循環(huán)的構(gòu)造方法實(shí)例分析
這篇文章主要介紹了Python通用循環(huán)的構(gòu)造方法,結(jié)合實(shí)例形式分析了Python常見(jiàn)的交互循環(huán)、哨兵循環(huán)、文件循環(huán)、死循環(huán)等實(shí)現(xiàn)與處理技巧,需要的朋友可以參考下2018-12-12

