Training FCN on the VOC Segmentation Dataset with PyTorch: A Walkthrough
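Semantic segmentation classifies every pixel in an image, which partitions the image into regions. Its main applications are medical imaging and autonomous driving.
Like other vision tasks, image segmentation has moved from traditional algorithms to deep learning. Traditional segmentation methods include thresholding, the watershed algorithm and edge detection. They share the weakness of other classical image-processing techniques, namely limited robustness, although they are still widely used in simple scenes whose conditions do not change.
FCN, published in 2014, is the pioneering work on deep-learning semantic segmentation and laid the conceptual foundation for the field.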
Fully Convolutional Networks for Semantic Segmentation
Submitted on 14 Nov 2014
https://arxiv.org/abs/1411.4038
I. FCN Theory
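The figure above, taken from the original paper, shows the overall FCN architecture. The image passes through a series of convolutions, the result is upsampled back to the input resolution, and the network outputs class probabilities for every pixel.
The figure above shows the network in more detail. The backbone is VGG16, whose fully connected layers are re-expressed as convolutions, i.e. conv6 and conv7 (a convolution whose kernel is as large as its input feature map is equivalent to a fully connected layer; for VGG16 this turns fc6 into a 7×7 convolution and fc7 into a 1×1 convolution). Overall, the network has the following key points: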
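1. Fully convolutional: solves the per-pixel prediction problem. Replacing the final fully connected layers of the base network (e.g. VGG16) with convolutional layers lets the network accept inputs of any size and produce an output whose size corresponds to the input.
2. Transposed convolution: the upsampling step that restores the spatial resolution so that every pixel can be classified.
3. Skip architecture: fuses deep and shallow features. The pooling layers in the backbone downsample the image, and although transposed convolution restores the size, it is not the inverse of convolution, so some information is inevitably lost. The skip architecture combines fine-grained information from shallow layers with coarse, semantic information from deep layers, which makes the segmentation more precise.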
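To make point 1 concrete, here is a small PyTorch check that a fully connected layer and a convolution whose kernel covers the whole feature map compute the same result once the weights are reshaped. The sizes are illustrative assumptions: a 512×7×7 input, as at the end of VGG16's convolutional part, but only 10 outputs instead of fc6's 4096. The last lines show why the convolutional form accepts larger inputs: it slides over them and returns a grid of predictions instead of failing on a size mismatch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Fully connected layer over a flattened 512x7x7 map (10 outputs to keep it small)
fc = nn.Linear(512 * 7 * 7, 10)

# Equivalent convolution: the kernel is as large as the feature map (7x7)
conv = nn.Conv2d(512, 10, kernel_size=7)
with torch.no_grad():
    # flatten(1) orders inputs as (C, H, W), so the Linear weight (10, 25088)
    # reshapes directly into a conv weight of shape (out, C, kH, kW)
    conv.weight.copy_(fc.weight.view(10, 512, 7, 7))
    conv.bias.copy_(fc.bias)

x = torch.randn(2, 512, 7, 7)
out_fc = fc(x.flatten(1))        # (2, 10)
out_conv = conv(x).flatten(1)    # (2, 10, 1, 1) -> (2, 10)
print(torch.allclose(out_fc, out_conv, atol=1e-4))

# The Linear layer would reject this input; the conv returns a 4x6 grid of outputs
big = torch.randn(1, 512, 10, 12)
print(conv(big).shape)  # torch.Size([1, 10, 4, 6])
```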
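FCN-32s: no skip connections. The 1/32-resolution output is upsampled 32× back to the input size. In the paper this is a single upsampling step; the same factor can also be reached with five transposed convolutions that each enlarge the map 2×.
FCN-16s: one skip connection. The 1/32 map is upsampled to 1/16, added to the output of VGG's 1/16 (pool4) layer, and then upsampled to the input size.
FCN-8s: two skip connections. The 1/32 map is upsampled to 1/16 and added to the output of VGG's 1/16 (pool4) layer; the result is upsampled to 1/8 and added to the output of VGG's 1/8 (pool3) layer, and then upsampled to the input size.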
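A shape-level sketch of the FCN-8s fusion described above. The three score maps are random placeholders for a 224×224 input, and the final 8× layer (kernel 16, stride 8, padding 4) is simply one choice that maps 28 to 224, not necessarily the paper's exact configuration. The 2× layers use the same settings as the model in section 2.2 (kernel 3, stride 2, padding 1, output_padding 1). By PyTorch's ConvTranspose2d size formula, H_out = (H_in - 1)·stride - 2·padding + dilation·(kernel_size - 1) + output_padding + 1, which gives exactly 2·H_in for those settings.

```python
import torch
import torch.nn as nn

num_classes = 21
N = 1
# Placeholder per-class score maps for a 224x224 input
score32 = torch.randn(N, num_classes, 7, 7)    # 1/32, deepest layer
score16 = torch.randn(N, num_classes, 14, 14)  # 1/16, from pool4
score8 = torch.randn(N, num_classes, 28, 28)   # 1/8, from pool3


def up2():
    # (H - 1) * 2 - 2 * 1 + (3 - 1) + 1 + 1 = 2H
    return nn.ConvTranspose2d(num_classes, num_classes, kernel_size=3,
                              stride=2, padding=1, output_padding=1)


up_a, up_b = up2(), up2()
# (28 - 1) * 8 - 2 * 4 + (16 - 1) + 1 = 224
up8 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=16, stride=8, padding=4)

fused16 = up_a(score32) + score16  # skip 1: 1/32 -> 1/16, add pool4 scores
fused8 = up_b(fused16) + score8    # skip 2: 1/16 -> 1/8, add pool3 scores
out = up8(fused8)                  # 1/8 -> input resolution

print(fused16.shape, fused8.shape, out.shape)
```

Dropping the second skip (upsampling fused16 straight to full size) gives the FCN-16s variant, and dropping both gives FCN-32s.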
II. Training Procedure
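Training a deep learning model in PyTorch mainly takes three files, data.py, model.py and train.py (in this article: voc_seg_data.py, fcn8s_net.py and train.py). data.py loads and batches the data, model.py defines the network, and train.py implements the training loop.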
2.1 The VOC Dataset
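Download: Pascal VOC Dataset Mirror
The image names for each split are listed in /ImageSets/Segmentation/train.txt and val.txt.
The images are in ./data/VOC2012/JPEGImages; append .jpg to each name read from train.txt to get the file name.
The labels are in ./data/VOC2012/SegmentationClass; append .png to each name to get the file name. The code below uses ./VOCdevkit/VOC2012 as the root; set it to wherever the dataset was extracted.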
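The naming rules above amount to a simple path construction. A minimal sketch, with hypothetical helper names (read_split_names, pairs_from_names) and made-up example IDs in VOC's naming style; the VOC_SEG class below does the same thing inline.

```python
import os


def read_split_names(root, split="train"):
    """Read image IDs from <root>/ImageSets/Segmentation/<split>.txt (split: "train" or "val")."""
    list_file = os.path.join(root, "ImageSets", "Segmentation", split + ".txt")
    with open(list_file) as f:
        return f.read().split()


def pairs_from_names(root, names):
    """Map each ID to its (JPEG image, PNG label) paths under the VOC2012 root."""
    return [(os.path.join(root, "JPEGImages", n + ".jpg"),
             os.path.join(root, "SegmentationClass", n + ".png"))
            for n in names]


# Made-up IDs, only to show the resulting paths
pairs = pairs_from_names("./data/VOC2012", ["2007_000001", "2007_000002"])
print(pairs[0])
```

With the real dataset the two are chained: pairs_from_names(root, read_split_names(root, "train")).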
voc_seg_data.py
```python
import os

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import DataLoader, Dataset


class VOC_SEG(Dataset):
    def __init__(self, root, width, height, train=True, transforms=None):
        # Every image is cropped to (width, height)
        self.width = width
        self.height = height
        # VOC class names
        self.classes = ['background', 'aeroplane', 'bicycle', 'bird', 'boat',
                        'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
                        'dog', 'horse', 'motorbike', 'person', 'potted plant',
                        'sheep', 'sofa', 'train', 'tv/monitor']
        # RGB color of each class in the label PNGs
        self.colormap = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
                         [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], [192, 0, 0],
                         [64, 128, 0], [192, 128, 0], [64, 0, 128], [192, 0, 128],
                         [64, 128, 128], [192, 128, 128], [0, 64, 0], [128, 64, 0],
                         [0, 192, 0], [128, 192, 0], [0, 64, 128]]
        # Number of samples dropped because they are smaller than the crop size
        self.fnum = 0
        if transforms is None:
            normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            self.transforms = T.Compose([T.ToTensor(), normalize])
        else:
            self.transforms = transforms
        # Lookup table: packed RGB value -> class label (0, 1, 2, ...).
        # uint8 keeps the 256**3-entry table at 16 MB instead of 128 MB of float64.
        self.cm2lbl = np.zeros(256 ** 3, dtype=np.uint8)
        for i, cm in enumerate(self.colormap):
            self.cm2lbl[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = i

        split = "train.txt" if train else "val.txt"
        txt_fname = os.path.join(root, "ImageSets", "Segmentation", split)
        with open(txt_fname, 'r') as f:
            images = f.read().split()
        imgs = [os.path.join(root, "JPEGImages", item + ".jpg") for item in images]
        labels = [os.path.join(root, "SegmentationClass", item + ".png") for item in images]
        # Filter images and labels together so the two lists stay aligned
        self.imgs, self.labels = self._filter(imgs, labels)
        name = "Train set" if train else "Val set"
        print(f"{name}: loaded {len(self.imgs)} images and labels, filtered out {self.fnum} images")

    def _crop(self, data, label):
        """Crop from the top-left corner to (width, height); data and label are PIL Images."""
        box = (0, 0, self.width, self.height)
        return data.crop(box), label.crop(box)

    def _image2label(self, im):
        data = np.array(im, dtype="int64")
        idx = (data[:, :, 0] * 256 + data[:, :, 1]) * 256 + data[:, :, 2]
        return self.cm2lbl[idx].astype("int64")

    def _image_transforms(self, data, label):
        data, label = self._crop(data, label)
        data = self.transforms(data)
        label = torch.from_numpy(self._image2label(label))
        return data, label

    def _filter(self, imgs, labels):
        kept_imgs, kept_labels = [], []
        for im, lb in zip(imgs, labels):
            w, h = Image.open(im).size
            if h >= self.height and w >= self.width:
                kept_imgs.append(im)
                kept_labels.append(lb)
            else:
                self.fnum += 1
        return kept_imgs, kept_labels

    def __getitem__(self, index: int):
        img = Image.open(self.imgs[index]).convert("RGB")
        label = Image.open(self.labels[index]).convert("RGB")
        return self._image_transforms(img, label)

    def __len__(self):
        return len(self.imgs)


if __name__ == "__main__":
    root = "./VOCdevkit/VOC2012"
    height = 224
    width = 224
    voc_train = VOC_SEG(root, width, height, train=True)
    voc_test = VOC_SEG(root, width, height, train=False)
    # train_data = DataLoader(voc_train, batch_size=8, shuffle=True)
    # valid_data = DataLoader(voc_test, batch_size=8)
    for data, label in voc_train:
        print(data.shape)
        print(label.shape)
        break
```
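The color lookup inside VOC_SEG can also live outside the class, so that training and inference share one implementation. Below is a minimal sketch of such a standalone module (hypothetically voc_utils.py); the names rgb_to_label and label_to_rgb are illustrative, not part of any library. label_to_rgb goes the other way, turning a predicted class map back into VOC colors for visualization. Any color not in the table, such as the VOC object-boundary color (224, 224, 192), maps to 0 (background), the same behavior as the zero-initialized cm2lbl in the class above.

```python
import numpy as np

# VOC class colors; row index = class id (same table as VOC_SEG.colormap)
VOC_COLORMAP = np.array([
    [0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
    [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], [192, 0, 0],
    [64, 128, 0], [192, 128, 0], [64, 0, 128], [192, 0, 128],
    [64, 128, 128], [192, 128, 128], [0, 64, 0], [128, 64, 0],
    [0, 192, 0], [128, 192, 0], [0, 64, 128]], dtype=np.uint8)


def build_cm2lbl(colormap=VOC_COLORMAP):
    """Table mapping packed RGB (r*256+g)*256+b to a class id; unknown colors map to 0."""
    table = np.zeros(256 ** 3, dtype=np.uint8)
    packed = (colormap[:, 0].astype(np.int64) * 256 + colormap[:, 1]) * 256 + colormap[:, 2]
    table[packed] = np.arange(len(colormap), dtype=np.uint8)
    return table


CM2LBL = build_cm2lbl()


def rgb_to_label(rgb):
    """(H, W, 3) color mask -> (H, W) int64 class ids."""
    rgb = np.asarray(rgb, dtype=np.int64)
    return CM2LBL[(rgb[..., 0] * 256 + rgb[..., 1]) * 256 + rgb[..., 2]].astype(np.int64)


def label_to_rgb(label):
    """(H, W) class ids, e.g. argmax of the network output -> (H, W, 3) uint8 color mask."""
    return VOC_COLORMAP[np.asarray(label)]


# background, aeroplane / person, boundary color
mask = np.array([[[0, 0, 0], [128, 0, 0]],
                 [[192, 128, 128], [224, 224, 192]]], dtype=np.uint8)
lbl = rgb_to_label(mask)
print(lbl.tolist())  # [[0, 1], [15, 0]]
```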
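- For convenience, I put the helper functions such as _crop() and _filter(), and variables such as colormap, inside the class. It would be better to move them into a separate preprocessing module, so that inference code can call the same functions after training.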
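- The dataset returns (data, label). data is the image as a tensor; label is also a tensor in which each pixel's RGB color has been replaced by an integer class index. nn.CrossEntropyLoss accepts these class indices directly, so no explicit one-hot encoding is needed, just as when training a classification network.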
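A small PyTorch check of the point above, with random logits and labels standing in for real network output (all sizes are illustrative). nn.CrossEntropyLoss takes (N, C, H, W) scores and (N, H, W) int64 class ids and gives the same value as an explicit one-hot formulation. The last lines show that at inference the softmax does not change the decision: argmax over the logits already gives the per-pixel class.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
N, C, H, W = 2, 21, 4, 4
logits = torch.randn(N, C, H, W)          # one score per class per pixel
target = torch.randint(0, C, (N, H, W))   # int64 class id per pixel

# What train.py does: pass class ids directly
loss = nn.CrossEntropyLoss()(logits, target)

# Same value with an explicit one-hot encoding: (N, H, W, C) -> (N, C, H, W)
one_hot = F.one_hot(target, C).permute(0, 3, 1, 2).float()
manual = -(one_hot * F.log_softmax(logits, dim=1)).sum(dim=1).mean()
print(torch.allclose(loss, manual))  # True

# Per-pixel prediction: softmax is monotonic, so argmax is unchanged
pred = logits.argmax(dim=1)  # (N, H, W)
print(torch.equal(pred, F.softmax(logits, dim=1).argmax(dim=1)))
```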
2.2 Network Definition
fcn8s_net.py
```python
import torch
import torch.nn as nn
from torchvision import models


class FCN8s(nn.Module):
    def __init__(self, num_classes=21):
        super(FCN8s, self).__init__()
        # Load ImageNet-pretrained VGG16 weights
        # (torchvision >= 0.13; older versions use models.vgg16(pretrained=True))
        net = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        # Use only VGG16's five convolutional stages (feature extractor): (3,224,224) -> (512,7,7)
        self.premodel = net.features
        # self.conv6 = nn.Conv2d(512,512,kernel_size=1,stride=1,padding=0,dilation=1)
        # self.conv7 = nn.Conv2d(512,512,kernel_size=1,stride=1,padding=0,dilation=1)
        # (512,7,7)
        self.relu = nn.ReLU(inplace=True)
        self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)  # x2
        self.bn1 = nn.BatchNorm2d(512)  # (512, 14, 14)
        self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)  # x2
        self.bn2 = nn.BatchNorm2d(256)  # (256, 28, 28)
        self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)  # x2
        self.bn3 = nn.BatchNorm2d(128)  # (128, 56, 56)
        self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)  # x2
        self.bn4 = nn.BatchNorm2d(64)  # (64, 112, 112)
        self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)  # x2
        self.bn5 = nn.BatchNorm2d(32)  # (32, 224, 224)
        self.classifier = nn.Conv2d(32, num_classes, kernel_size=1)  # (num_classes, 224, 224)

    def forward(self, input):
        x = input
        for i, layer in enumerate(self.premodel):
            x = layer(x)
            if i == 16:
                x3 = x  # output of maxpool3 (1/8)
            if i == 23:
                x4 = x  # output of maxpool4 (1/16)
            if i == 30:
                x5 = x  # output of maxpool5 (1/32)
        # Five transposed convolutions, each doubling the size (the reverse of VGG16's
        # five poolings), with two skip connections
        score = self.relu(self.deconv1(x5))               # 1/16
        score = self.bn1(score + x4)
        score = self.relu(self.deconv2(score))            # 1/8
        score = self.bn2(score + x3)
        score = self.bn3(self.relu(self.deconv3(score)))  # 1/4
        score = self.bn4(self.relu(self.deconv4(score)))  # 1/2
        score = self.bn5(self.relu(self.deconv5(score)))  # 1 (input size)
        score = self.classifier(score)  # size unchanged; channels = number of classes
        return score


if __name__ == "__main__":
    model = FCN8s()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    print(model)
```
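FCN implementations found online differ in their details, but all follow the convolution + transposed convolution + skip connection pattern. All that is really needed is feature extraction (abstract features), then transposed convolution (restore the input size), then a per-pixel classifier.
This experiment uses the five convolutional stages of VGG16 as the feature extractor, followed by five transposed convolutions (2× each) that restore the input size, with two skip connections that add the pool4 and pool3 outputs, as in FCN-8s. A final 1×1 convolution maps the channels to the number of classes (21). The softmax itself is applied inside nn.CrossEntropyLoss during training; at inference, taking the argmax over the class channel is enough.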
2.3 Training
train.py
```python
import os

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from voc_seg_data import VOC_SEG
from fcn8s_net import FCN8s  # the model file above is fcn8s_net.py


# Confusion matrix: rows = ground truth, columns = prediction
def _fast_hist(label_true, label_pred, n_class):
    mask = (label_true >= 0) & (label_true < n_class)
    hist = np.bincount(
        n_class * label_true[mask].astype(int) + label_pred[mask],
        minlength=n_class ** 2).reshape(n_class, n_class)
    return hist


# Acc, mean per-class Acc and mIoU from the confusion matrix
def label_accuracy_score(label_trues, label_preds, n_class):
    """Returns accuracy score evaluation result.
      - overall accuracy
      - mean accuracy
      - mean IU
    """
    hist = np.zeros((n_class, n_class))
    for lt, lp in zip(label_trues, label_preds):
        hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
    acc = np.diag(hist).sum() / hist.sum()
    with np.errstate(divide='ignore', invalid='ignore'):
        acc_cls = np.diag(hist) / hist.sum(axis=1)
    acc_cls = np.nanmean(acc_cls)
    with np.errstate(divide='ignore', invalid='ignore'):
        iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
    mean_iu = np.nanmean(iu)
    return acc, acc_cls, mean_iu


def main():
    # 1. load dataset
    root = "./VOCdevkit/VOC2012"
    batch_size = 32
    height = 224
    width = 224
    voc_train = VOC_SEG(root, width, height, train=True)
    voc_test = VOC_SEG(root, width, height, train=False)
    train_dataloader = DataLoader(voc_train, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(voc_test, batch_size=batch_size, shuffle=False)

    # 2. load model
    num_class = 21
    model = FCN8s(num_classes=num_class)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # 3. prepare hyperparameters
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.7)
    num_epochs = 50

    # 4. train
    val_miou_list = []
    out_dir = "./checkpoints/"
    os.makedirs(out_dir, exist_ok=True)
    length = len(train_dataloader)
    for epoch in range(num_epochs):
        print('\nEpoch: %d' % (epoch + 1))
        model.train()
        sum_loss = 0.0
        for batch_idx, (images, labels) in enumerate(train_dataloader):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)  # torch.Size([batch_size, num_class, height, width])
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            sum_loss += loss.item()
            predicted = torch.argmax(outputs.detach(), 1)
            label_pred = predicted.cpu().numpy()
            label_true = labels.cpu().numpy()
            acc, acc_cls, mean_iu = label_accuracy_score(label_true, label_pred, num_class)
            print('[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% | Acc_cls: %.03f%% | Mean_iu: %.3f'
                  % (epoch + 1, (batch_idx + 1 + epoch * length), sum_loss / (batch_idx + 1),
                     100. * acc, 100. * acc_cls, mean_iu))

        # Evaluate on the validation set after each epoch
        print('Waiting Val...')
        mean_iu_epoch = 0.0
        mean_acc = 0.0
        mean_acc_cls = 0.0
        model.eval()  # switch BatchNorm to inference mode once, before the loop
        with torch.no_grad():
            for batch_idx, (images, labels) in enumerate(val_dataloader):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                predicted = torch.argmax(outputs, 1)
                label_pred = predicted.cpu().numpy()
                label_true = labels.cpu().numpy()
                acc, acc_cls, mean_iu = label_accuracy_score(label_true, label_pred, num_class)
                mean_iu_epoch += mean_iu
                mean_acc += acc
                mean_acc_cls += acc_cls
        n_val = len(val_dataloader)
        print('Acc_epoch: %.3f%% | Acc_cls_epoch: %.03f%% | Mean_iu_epoch: %.3f'
              % (100. * mean_acc / n_val, 100. * mean_acc_cls / n_val, mean_iu_epoch / n_val))

        val_miou_list.append(mean_iu_epoch / n_val)
        torch.save(model.state_dict(), out_dir + "last.pt")
        if mean_iu_epoch / n_val == max(val_miou_list):
            torch.save(model.state_dict(), out_dir + "best.pt")
            print("save epoch {} model".format(epoch + 1))


if __name__ == "__main__":
    main()
```
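The overall training flow works; adapt the evaluation metrics and related code as needed. This run mainly uses Acc (pixel accuracy) as the metric: the number of correctly classified pixels divided by the total number of pixels. The final training results are as follows: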
[Figure: final training results]
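Training-set Acc reached 0.8 and validation-set Acc reached 0.77. Because some functions, such as _fast_hist, were copied from elsewhere, the other metrics are not taken into account for now.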
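One reason the other metrics are hard to trust is how they are aggregated rather than the confusion-matrix code itself: train.py averages per-batch mIoU, which generally differs from the mIoU of the whole validation set. A minimal sketch, assuming the same confusion-matrix logic as _fast_hist, that accumulates one matrix over all batches and computes the scores once; scores_from_hist is a hypothetical helper name. Pixel Acc here is the same "correct pixels / all pixels" definition used above.

```python
import numpy as np


def fast_hist(label_true, label_pred, n_class):
    # Same logic as _fast_hist in train.py: rows = ground truth, columns = prediction
    mask = (label_true >= 0) & (label_true < n_class)
    return np.bincount(n_class * label_true[mask].astype(int) + label_pred[mask],
                       minlength=n_class ** 2).reshape(n_class, n_class)


def scores_from_hist(hist):
    """Pixel Acc, mean class Acc and mIoU from one (accumulated) confusion matrix."""
    acc = np.diag(hist).sum() / hist.sum()  # correct pixels / all pixels
    with np.errstate(divide='ignore', invalid='ignore'):
        acc_cls = np.nanmean(np.diag(hist) / hist.sum(axis=1))
        iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
    return acc, acc_cls, np.nanmean(iu)


# Two toy 2x2 "batches" with 3 classes: (ground truth, prediction)
batches = [
    (np.array([[0, 0], [1, 1]]), np.array([[0, 0], [1, 0]])),
    (np.array([[2, 2], [2, 0]]), np.array([[2, 2], [2, 0]])),
]

# Dataset-level: accumulate one matrix, compute scores once
hist = np.zeros((3, 3))
for true, pred in batches:
    hist += fast_hist(true.flatten(), pred.flatten(), 3)
acc, acc_cls, miou = scores_from_hist(hist)

# What train.py reports: the mean of per-batch mIoU
per_batch = [scores_from_hist(fast_hist(t.flatten(), p.flatten(), 3))[2] for t, p in batches]
print(acc, miou, np.mean(per_batch))  # 0.875, 0.75, about 0.792
```

In train.py this would mean adding each validation batch to one hist and calling the scoring function after the loop.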