pytorch通過(guò)自己的數(shù)據(jù)集訓(xùn)練Unet網(wǎng)絡(luò)架構(gòu)
在圖像分割這個(gè)問(wèn)題上,主要有兩個(gè)流派:Encoder-Decoder和Dialated Conv。本文介紹的是編解碼網(wǎng)絡(luò)中最為經(jīng)典的U-Net。隨著骨干網(wǎng)路的進(jìn)化,很多相應(yīng)衍生出來(lái)的網(wǎng)絡(luò)大多都是對(duì)于Unet進(jìn)行了改進(jìn)但是本質(zhì)上的思路還是沒(méi)有太多的變化。比如結(jié)合DenseNet 和Unet的FCDenseNet, Unet++
一、Unet網(wǎng)絡(luò)介紹
論文:https://arxiv.org/abs/1505.04597v1(2015)
UNet的設(shè)計(jì)就是應(yīng)用與醫(yī)學(xué)圖像的分割。由于醫(yī)學(xué)影像處理中,數(shù)據(jù)量較少,本文提出的方法有效提升了使用少量數(shù)據(jù)集訓(xùn)練檢測(cè)的效果,提出了處理大尺寸圖像的有效方法。
UNet的網(wǎng)絡(luò)架構(gòu)繼承自FCN,并在此基礎(chǔ)上做了些改變。提出了Encoder-Decoder概念,實(shí)際上就是FCN那個(gè)先卷積再上采樣的思想。
上圖是Unet的網(wǎng)絡(luò)結(jié)構(gòu),從圖中可以看出,
結(jié)構(gòu)左邊為Encoder,即下采樣提取特征的過(guò)程。Encoder基本模塊為雙卷積形式,即輸入經(jīng)過(guò)兩個(gè)
conu 3x3,使用的valid卷積,在代碼實(shí)現(xiàn)時(shí)我們可以增加padding使用same卷積,來(lái)適應(yīng)Skip Architecture。下采樣采用的池化層直接縮小2倍。
結(jié)構(gòu)右邊是Decoder,即上采樣恢復(fù)圖像尺寸并預(yù)測(cè)的過(guò)程。Decoder一樣采用雙卷積的形式,其中上采樣使用轉(zhuǎn)置卷積實(shí)現(xiàn),每次轉(zhuǎn)置卷積放大2倍。
結(jié)構(gòu)中間copy and crop是一個(gè)cat操作,即feature map的通道疊加。
二、VOC訓(xùn)練Unet
2.1 Unet代碼實(shí)現(xiàn)
根據(jù)上面對(duì)于Unet網(wǎng)絡(luò)結(jié)構(gòu)的介紹,可見(jiàn)其結(jié)構(gòu)非常對(duì)稱簡(jiǎn)單,代碼Unet.py實(shí)現(xiàn)如下:
from turtle import forward import torch.nn as nn import torch class DoubleConv(nn.Module): def __init__(self, in_ch, out_ch): super(DoubleConv, self).__init__() self.conv = nn.Sequential( nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True) ) def forward(self, x): return self.conv(x) class Unet(nn.Module): def __init__(self, in_ch, out_ch): super(Unet, self).__init__() # Encoder self.conv1 = DoubleConv(in_ch, 64) self.pool1 = nn.MaxPool2d(2) self.conv2 = DoubleConv(64, 128) self.pool2 = nn.MaxPool2d(2) self.conv3 = DoubleConv(128, 256) self.pool3 = nn.MaxPool2d(2) self.conv4 = DoubleConv(256, 512) self.pool4 = nn.MaxPool2d(2) self.conv5 = DoubleConv(512, 1024) # Decoder self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2) self.conv6 = DoubleConv(1024, 512) self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2) self.conv7 = DoubleConv(512, 256) self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2) self.conv8 = DoubleConv(256, 128) self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2) self.conv9 = DoubleConv(128, 64) self.output = nn.Conv2d(64, out_ch, 1) def forward(self, x): conv1 = self.conv1(x) pool1 = self.pool1(conv1) conv2 = self.conv2(pool1) pool2 = self.pool2(conv2) conv3 = self.conv3(pool2) pool3 = self.pool3(conv3) conv4 = self.conv4(pool3) pool4 = self.pool4(conv4) conv5 = self.conv5(pool4) up6 = self.up6(conv5) meger6 = torch.cat([up6, conv4], dim=1) conv6 = self.conv6(meger6) up7 = self.up7(conv6) meger7 = torch.cat([up7, conv3], dim=1) conv7 = self.conv7(meger7) up8 = self.up8(conv7) meger8 = torch.cat([up8, conv2], dim=1) conv8 = self.conv8(meger8) up9 = self.up9(conv8) meger9 = torch.cat([up9, conv1], dim=1) conv9 = self.conv9(meger9) out = self.output(conv9) return out if __name__=="__main__": model = Unet(3, 21) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) print(model)
2.2 數(shù)據(jù)集處理
數(shù)據(jù)來(lái)源于kaggle,下載地址我忘了。包含2個(gè)類別,1個(gè)車(chē),還有1個(gè)背景類,共有5k+的數(shù)據(jù),按照比例分為訓(xùn)練集和驗(yàn)證集即可。具體見(jiàn)carnava.py
from PIL import Image from requests import check_compatibility import torch import torch.nn as nn from torch.utils.data import DataLoader, Dataset from torchvision import transforms as T import numpy as np import os import matplotlib.pyplot as plt class Car(Dataset): def __init__(self, root, train=True): self.root = root self.crop_size = (256, 256) self.img_path = os.path.join(root, "train_hq") self.label_path = os.path.join(root, "train_masks") img_path_list = [os.path.join(self.img_path, im) for im in os.listdir(self.img_path)] train_path_list, val_path_list = self._split_data_set(img_path_list) if train: self.imgs_list = train_path_list else: self.imgs_list = val_path_list normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) self.transforms = T.Compose([ T.Resize(256), T.CenterCrop(256), T.ToTensor(), normalize ]) self.transforms_val = T.Compose([ T.Resize(256), T.CenterCrop(256) ]) self.color_map = [[0, 0, 0], [255, 255, 255]] def __getitem__(self, index: int): im_path = self.imgs_list[index] image = Image.open(im_path).convert("RGB") data = self.transforms(image) (filepath, filename) = os.path.split(im_path) filename = filename.split('.')[0] label = Image.open(self.label_path +"/"+filename+"_mask.gif").convert("RGB") label = self.transforms_val(label) cm2lb=np.zeros(256**3) for i,cm in enumerate(self.color_map): cm2lb[(cm[0]*256+cm[1])*256+cm[2]]=i image=np.array(label,dtype=np.int64) idx=(image[:,:,0]*256+image[:,:,1])*256+image[:,:,2] label=np.array(cm2lb[idx],dtype=np.int64) label=torch.from_numpy(label).long() return data, label def label2img(self, label): cmap = self.color_map cmap = np.array(cmap).astype(np.uint8) pred = cmap[label] return pred def __len__(self): return len(self.imgs_list) def _split_data_set(self, img_path_list): val_path_list = img_path_list[::8] train_path_list = [] for item in img_path_list: if item not in val_path_list: train_path_list.append(item) return train_path_list, val_path_list if __name__=="__main__": root = "../dataset/carvana" car_train = Car(root,train=True) train_dataloader = DataLoader(car_train, batch_size=8, shuffle=True) print(len(car_train)) print(len(train_dataloader)) # for data, label in car_train: # print(data.shape) # print(label.shape) # break (data, label) = car_train[190] label_np = label.data.numpy() label_im = car_train.label2img(label_np) plt.figure() plt.imshow(label_im) plt.show()
2.3 訓(xùn)練過(guò)程
分割其實(shí)就是給每個(gè)像素分類而已,所以損失函數(shù)依舊是交叉熵函數(shù),正確率為分類正確的像素點(diǎn)個(gè)數(shù)/全部的像素點(diǎn)個(gè)數(shù)
import torch import torch.nn as nn from torch.utils.data import DataLoader,Dataset from voc import VOC from carnava import Car from unet import Unet import os import numpy as np from torch import optim import torch.nn as nn import util # 計(jì)算混淆矩陣 def _fast_hist(label_true, label_pred, n_class): mask = (label_true >= 0) & (label_true < n_class) hist = np.bincount( n_class * label_true[mask].astype(int) + label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class) return hist def label_accuracy_score(label_trues, label_preds, n_class): """Returns accuracy score evaluation result. - overall accuracy - mean accuracy - mean IU """ hist = np.zeros((n_class, n_class)) for lt, lp in zip(label_trues, label_preds): hist += _fast_hist(lt.flatten(), lp.flatten(), n_class) acc = np.diag(hist).sum() / hist.sum() with np.errstate(divide='ignore', invalid='ignore'): acc_cls = np.diag(hist) / hist.sum(axis=1) acc_cls = np.nanmean(acc_cls) with np.errstate(divide='ignore', invalid='ignore'): iu = np.diag(hist) / ( hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist) ) mean_iu = np.nanmean(iu) freq = hist.sum(axis=1) / hist.sum() return acc, acc_cls, mean_iu out_path = "./out" if not os.path.exists(out_path): os.makedirs(out_path) log_path = os.path.join(out_path, "result.txt") if os.path.exists(log_path): os.remove(log_path) model_path = os.path.join(out_path, "best_model.pth") root = "../dataset/carvana" epochs = 5 numclasses = 2 train_data = Car(root, train=True) train_dataloader = DataLoader(train_data, batch_size=16, shuffle=True) val_data = Car(root, train=False) val_dataloader = DataLoader(val_data, batch_size=16, shuffle=True) net = Unet(3, numclasses) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") net = net.to(device) optimizer = optim.SGD(net.parameters(), lr=0.01, weight_decay=1e-4) criterion = nn.CrossEntropyLoss() def train_model(): best_score = 0.0 for e in range(epochs): net.train() train_loss = 0.0 label_true = torch.LongTensor() label_pred = torch.LongTensor() for batch_id, (data, label) in enumerate(train_dataloader): data, label = data.to(device), label.to(device) output = net(data) loss = criterion(output, label) pred = output.argmax(dim=1).squeeze().data.cpu() real = label.data.cpu() optimizer.zero_grad() loss.backward() optimizer.step() train_loss+=loss.cpu().item() label_true = torch.cat((label_true,real),dim=0) label_pred = torch.cat((label_pred,pred),dim=0) train_loss /= len(train_dataloader) acc, acc_cls, mean_iu = label_accuracy_score(label_true.numpy(),label_pred.numpy(),numclasses) print("\n epoch:{}, train_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}".format( e+1, train_loss, acc, acc_cls, mean_iu)) with open(log_path, 'a') as f: f.write('\n epoch:{}, train_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}'.format( e+1,train_loss,acc, acc_cls, mean_iu)) net.eval() val_loss = 0.0 val_label_true = torch.LongTensor() val_label_pred = torch.LongTensor() with torch.no_grad(): for batch_id, (data, label) in enumerate(val_dataloader): data, label = data.to(device), label.to(device) output = net(data) loss = criterion(output, label) pred = output.argmax(dim=1).squeeze().data.cpu() real = label.data.cpu() val_loss += loss.cpu().item() val_label_true = torch.cat((val_label_true, real), dim=0) val_label_pred = torch.cat((val_label_pred, pred), dim=0) val_loss/=len(val_dataloader) val_acc, val_acc_cls, val_mean_iu = label_accuracy_score(val_label_true.numpy(), val_label_pred.numpy(),numclasses) print('\n epoch:{}, val_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}'.format(e+1, val_loss, val_acc, val_acc_cls, val_mean_iu)) with open(log_path, 'a') as f: f.write('\n epoch:{}, val_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}'.format( e+1,val_loss,val_acc, val_acc_cls, val_mean_iu)) score = (val_acc_cls+val_mean_iu)/2 if score > best_score: best_score = score torch.save(net.state_dict(), model_path) def evaluate(): import util import random import matplotlib.pyplot as plt net.load_state_dict(torch.load(model_path)) index = random.randint(0, len(val_data)-1) val_image, val_label = val_data[index] out = net(val_image.unsqueeze(0).to(device)) pred = out.argmax(dim=1).squeeze().data.cpu().numpy() label = val_label.data.numpy() img_pred = val_data.label2img(pred) img_label = val_data.label2img(label) temp = val_image.numpy() temp = (temp-np.min(temp)) / (np.max(temp)-np.min(temp))*255 fig, ax = plt.subplots(1,3) ax[0].imshow(temp.transpose(1,2,0).astype("uint8")) ax[1].imshow(img_label) ax[2].imshow(img_pred) plt.show() if __name__=="__main__": # train_model() evaluate()
最終訓(xùn)練結(jié)果是:
由于數(shù)據(jù)比較簡(jiǎn)單,訓(xùn)練到epoch為5時(shí),mIOU就已經(jīng)達(dá)到0.97了。
最后測(cè)試一下效果:
從左到右分別是:原圖、真實(shí)label、預(yù)測(cè)label
備注:
其實(shí)最開(kāi)始使用voc數(shù)據(jù)集訓(xùn)練的,但效果極差,也沒(méi)發(fā)現(xiàn)哪里有問(wèn)題。換個(gè)數(shù)據(jù)集效果就好了,可能有兩個(gè)原因:
1. voc數(shù)據(jù)我在處理數(shù)據(jù)時(shí)出錯(cuò)了,沒(méi)檢查出來(lái)
2. 這個(gè)數(shù)據(jù)集比較簡(jiǎn)單,容易學(xué)習(xí),所以效果差不多。
到此這篇關(guān)于pytorch通過(guò)自己的數(shù)據(jù)集訓(xùn)練Unet網(wǎng)絡(luò)架構(gòu)的文章就介紹到這了,更多相關(guān)pytorch Unet內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python實(shí)現(xiàn)批量讀取圖片并存入mongodb數(shù)據(jù)庫(kù)的方法示例
這篇文章主要介紹了Python實(shí)現(xiàn)批量讀取圖片并存入mongodb數(shù)據(jù)庫(kù)的方法,涉及Python文件讀取及數(shù)據(jù)庫(kù)寫(xiě)入相關(guān)操作技巧,需要的朋友可以參考下2018-04-04Golang與python線程詳解及簡(jiǎn)單實(shí)例
這篇文章主要介紹了Golang與python線程詳解及簡(jiǎn)單實(shí)例的相關(guān)資料,需要的朋友可以參考下2017-04-04OpenCV讀取與寫(xiě)入圖片的實(shí)現(xiàn)
這篇文章主要介紹了OpenCV讀取與寫(xiě)入圖片的實(shí)現(xiàn),文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-10-10Django在視圖中使用表單并和數(shù)據(jù)庫(kù)進(jìn)行數(shù)據(jù)交互的實(shí)現(xiàn)
本文主要介紹了Django在視圖中使用表單并和數(shù)據(jù)庫(kù)進(jìn)行數(shù)據(jù)交互,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2022-07-07pytorch繪制并顯示loss曲線和acc曲線,LeNet5識(shí)別圖像準(zhǔn)確率
今天小編就為大家分享一篇pytorch繪制并顯示loss曲線和acc曲線,LeNet5識(shí)別圖像準(zhǔn)確率,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-01-01CentOS 7 安裝python3.7.1的方法及注意事項(xiàng)
這篇文章主要介紹了CentOS 7 安裝python3.7.1的方法,文中給大家提到了注意事項(xiàng),需要的朋友可以參考下2018-11-11