pytorch通過自己的數(shù)據(jù)集訓(xùn)練Unet網(wǎng)絡(luò)架構(gòu)
在圖像分割這個問題上,主要有兩個流派:Encoder-Decoder和Dialated Conv。本文介紹的是編解碼網(wǎng)絡(luò)中最為經(jīng)典的U-Net。隨著骨干網(wǎng)路的進(jìn)化,很多相應(yīng)衍生出來的網(wǎng)絡(luò)大多都是對于Unet進(jìn)行了改進(jìn)但是本質(zhì)上的思路還是沒有太多的變化。比如結(jié)合DenseNet 和Unet的FCDenseNet, Unet++
一、Unet網(wǎng)絡(luò)介紹
論文:https://arxiv.org/abs/1505.04597v1(2015)
UNet的設(shè)計就是應(yīng)用與醫(yī)學(xué)圖像的分割。由于醫(yī)學(xué)影像處理中,數(shù)據(jù)量較少,本文提出的方法有效提升了使用少量數(shù)據(jù)集訓(xùn)練檢測的效果,提出了處理大尺寸圖像的有效方法。
UNet的網(wǎng)絡(luò)架構(gòu)繼承自FCN,并在此基礎(chǔ)上做了些改變。提出了Encoder-Decoder概念,實際上就是FCN那個先卷積再上采樣的思想。
上圖是Unet的網(wǎng)絡(luò)結(jié)構(gòu),從圖中可以看出,
結(jié)構(gòu)左邊為Encoder,即下采樣提取特征的過程。Encoder基本模塊為雙卷積形式,即輸入經(jīng)過兩個
conu 3x3,使用的valid卷積,在代碼實現(xiàn)時我們可以增加padding使用same卷積,來適應(yīng)Skip Architecture。下采樣采用的池化層直接縮小2倍。
結(jié)構(gòu)右邊是Decoder,即上采樣恢復(fù)圖像尺寸并預(yù)測的過程。Decoder一樣采用雙卷積的形式,其中上采樣使用轉(zhuǎn)置卷積實現(xiàn),每次轉(zhuǎn)置卷積放大2倍。
結(jié)構(gòu)中間copy and crop是一個cat操作,即feature map的通道疊加。
二、VOC訓(xùn)練Unet
2.1 Unet代碼實現(xiàn)
根據(jù)上面對于Unet網(wǎng)絡(luò)結(jié)構(gòu)的介紹,可見其結(jié)構(gòu)非常對稱簡單,代碼Unet.py實現(xiàn)如下:
from turtle import forward import torch.nn as nn import torch class DoubleConv(nn.Module): def __init__(self, in_ch, out_ch): super(DoubleConv, self).__init__() self.conv = nn.Sequential( nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True) ) def forward(self, x): return self.conv(x) class Unet(nn.Module): def __init__(self, in_ch, out_ch): super(Unet, self).__init__() # Encoder self.conv1 = DoubleConv(in_ch, 64) self.pool1 = nn.MaxPool2d(2) self.conv2 = DoubleConv(64, 128) self.pool2 = nn.MaxPool2d(2) self.conv3 = DoubleConv(128, 256) self.pool3 = nn.MaxPool2d(2) self.conv4 = DoubleConv(256, 512) self.pool4 = nn.MaxPool2d(2) self.conv5 = DoubleConv(512, 1024) # Decoder self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2) self.conv6 = DoubleConv(1024, 512) self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2) self.conv7 = DoubleConv(512, 256) self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2) self.conv8 = DoubleConv(256, 128) self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2) self.conv9 = DoubleConv(128, 64) self.output = nn.Conv2d(64, out_ch, 1) def forward(self, x): conv1 = self.conv1(x) pool1 = self.pool1(conv1) conv2 = self.conv2(pool1) pool2 = self.pool2(conv2) conv3 = self.conv3(pool2) pool3 = self.pool3(conv3) conv4 = self.conv4(pool3) pool4 = self.pool4(conv4) conv5 = self.conv5(pool4) up6 = self.up6(conv5) meger6 = torch.cat([up6, conv4], dim=1) conv6 = self.conv6(meger6) up7 = self.up7(conv6) meger7 = torch.cat([up7, conv3], dim=1) conv7 = self.conv7(meger7) up8 = self.up8(conv7) meger8 = torch.cat([up8, conv2], dim=1) conv8 = self.conv8(meger8) up9 = self.up9(conv8) meger9 = torch.cat([up9, conv1], dim=1) conv9 = self.conv9(meger9) out = self.output(conv9) return out if __name__=="__main__": model = Unet(3, 21) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) print(model)
2.2 數(shù)據(jù)集處理
數(shù)據(jù)來源于kaggle,下載地址我忘了。包含2個類別,1個車,還有1個背景類,共有5k+的數(shù)據(jù),按照比例分為訓(xùn)練集和驗證集即可。具體見carnava.py
from PIL import Image from requests import check_compatibility import torch import torch.nn as nn from torch.utils.data import DataLoader, Dataset from torchvision import transforms as T import numpy as np import os import matplotlib.pyplot as plt class Car(Dataset): def __init__(self, root, train=True): self.root = root self.crop_size = (256, 256) self.img_path = os.path.join(root, "train_hq") self.label_path = os.path.join(root, "train_masks") img_path_list = [os.path.join(self.img_path, im) for im in os.listdir(self.img_path)] train_path_list, val_path_list = self._split_data_set(img_path_list) if train: self.imgs_list = train_path_list else: self.imgs_list = val_path_list normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) self.transforms = T.Compose([ T.Resize(256), T.CenterCrop(256), T.ToTensor(), normalize ]) self.transforms_val = T.Compose([ T.Resize(256), T.CenterCrop(256) ]) self.color_map = [[0, 0, 0], [255, 255, 255]] def __getitem__(self, index: int): im_path = self.imgs_list[index] image = Image.open(im_path).convert("RGB") data = self.transforms(image) (filepath, filename) = os.path.split(im_path) filename = filename.split('.')[0] label = Image.open(self.label_path +"/"+filename+"_mask.gif").convert("RGB") label = self.transforms_val(label) cm2lb=np.zeros(256**3) for i,cm in enumerate(self.color_map): cm2lb[(cm[0]*256+cm[1])*256+cm[2]]=i image=np.array(label,dtype=np.int64) idx=(image[:,:,0]*256+image[:,:,1])*256+image[:,:,2] label=np.array(cm2lb[idx],dtype=np.int64) label=torch.from_numpy(label).long() return data, label def label2img(self, label): cmap = self.color_map cmap = np.array(cmap).astype(np.uint8) pred = cmap[label] return pred def __len__(self): return len(self.imgs_list) def _split_data_set(self, img_path_list): val_path_list = img_path_list[::8] train_path_list = [] for item in img_path_list: if item not in val_path_list: train_path_list.append(item) return train_path_list, val_path_list if __name__=="__main__": root = "../dataset/carvana" car_train = Car(root,train=True) train_dataloader = DataLoader(car_train, batch_size=8, shuffle=True) print(len(car_train)) print(len(train_dataloader)) # for data, label in car_train: # print(data.shape) # print(label.shape) # break (data, label) = car_train[190] label_np = label.data.numpy() label_im = car_train.label2img(label_np) plt.figure() plt.imshow(label_im) plt.show()
2.3 訓(xùn)練過程
分割其實就是給每個像素分類而已,所以損失函數(shù)依舊是交叉熵函數(shù),正確率為分類正確的像素點個數(shù)/全部的像素點個數(shù)
import torch import torch.nn as nn from torch.utils.data import DataLoader,Dataset from voc import VOC from carnava import Car from unet import Unet import os import numpy as np from torch import optim import torch.nn as nn import util # 計算混淆矩陣 def _fast_hist(label_true, label_pred, n_class): mask = (label_true >= 0) & (label_true < n_class) hist = np.bincount( n_class * label_true[mask].astype(int) + label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class) return hist def label_accuracy_score(label_trues, label_preds, n_class): """Returns accuracy score evaluation result. - overall accuracy - mean accuracy - mean IU """ hist = np.zeros((n_class, n_class)) for lt, lp in zip(label_trues, label_preds): hist += _fast_hist(lt.flatten(), lp.flatten(), n_class) acc = np.diag(hist).sum() / hist.sum() with np.errstate(divide='ignore', invalid='ignore'): acc_cls = np.diag(hist) / hist.sum(axis=1) acc_cls = np.nanmean(acc_cls) with np.errstate(divide='ignore', invalid='ignore'): iu = np.diag(hist) / ( hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist) ) mean_iu = np.nanmean(iu) freq = hist.sum(axis=1) / hist.sum() return acc, acc_cls, mean_iu out_path = "./out" if not os.path.exists(out_path): os.makedirs(out_path) log_path = os.path.join(out_path, "result.txt") if os.path.exists(log_path): os.remove(log_path) model_path = os.path.join(out_path, "best_model.pth") root = "../dataset/carvana" epochs = 5 numclasses = 2 train_data = Car(root, train=True) train_dataloader = DataLoader(train_data, batch_size=16, shuffle=True) val_data = Car(root, train=False) val_dataloader = DataLoader(val_data, batch_size=16, shuffle=True) net = Unet(3, numclasses) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") net = net.to(device) optimizer = optim.SGD(net.parameters(), lr=0.01, weight_decay=1e-4) criterion = nn.CrossEntropyLoss() def train_model(): best_score = 0.0 for e in range(epochs): net.train() train_loss = 0.0 label_true = torch.LongTensor() label_pred = torch.LongTensor() for batch_id, (data, label) in enumerate(train_dataloader): data, label = data.to(device), label.to(device) output = net(data) loss = criterion(output, label) pred = output.argmax(dim=1).squeeze().data.cpu() real = label.data.cpu() optimizer.zero_grad() loss.backward() optimizer.step() train_loss+=loss.cpu().item() label_true = torch.cat((label_true,real),dim=0) label_pred = torch.cat((label_pred,pred),dim=0) train_loss /= len(train_dataloader) acc, acc_cls, mean_iu = label_accuracy_score(label_true.numpy(),label_pred.numpy(),numclasses) print("\n epoch:{}, train_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}".format( e+1, train_loss, acc, acc_cls, mean_iu)) with open(log_path, 'a') as f: f.write('\n epoch:{}, train_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}'.format( e+1,train_loss,acc, acc_cls, mean_iu)) net.eval() val_loss = 0.0 val_label_true = torch.LongTensor() val_label_pred = torch.LongTensor() with torch.no_grad(): for batch_id, (data, label) in enumerate(val_dataloader): data, label = data.to(device), label.to(device) output = net(data) loss = criterion(output, label) pred = output.argmax(dim=1).squeeze().data.cpu() real = label.data.cpu() val_loss += loss.cpu().item() val_label_true = torch.cat((val_label_true, real), dim=0) val_label_pred = torch.cat((val_label_pred, pred), dim=0) val_loss/=len(val_dataloader) val_acc, val_acc_cls, val_mean_iu = label_accuracy_score(val_label_true.numpy(), val_label_pred.numpy(),numclasses) print('\n epoch:{}, val_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}'.format(e+1, val_loss, val_acc, val_acc_cls, val_mean_iu)) with open(log_path, 'a') as f: f.write('\n epoch:{}, val_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}'.format( e+1,val_loss,val_acc, val_acc_cls, val_mean_iu)) score = (val_acc_cls+val_mean_iu)/2 if score > best_score: best_score = score torch.save(net.state_dict(), model_path) def evaluate(): import util import random import matplotlib.pyplot as plt net.load_state_dict(torch.load(model_path)) index = random.randint(0, len(val_data)-1) val_image, val_label = val_data[index] out = net(val_image.unsqueeze(0).to(device)) pred = out.argmax(dim=1).squeeze().data.cpu().numpy() label = val_label.data.numpy() img_pred = val_data.label2img(pred) img_label = val_data.label2img(label) temp = val_image.numpy() temp = (temp-np.min(temp)) / (np.max(temp)-np.min(temp))*255 fig, ax = plt.subplots(1,3) ax[0].imshow(temp.transpose(1,2,0).astype("uint8")) ax[1].imshow(img_label) ax[2].imshow(img_pred) plt.show() if __name__=="__main__": # train_model() evaluate()
最終訓(xùn)練結(jié)果是:
由于數(shù)據(jù)比較簡單,訓(xùn)練到epoch為5時,mIOU就已經(jīng)達(dá)到0.97了。
最后測試一下效果:
從左到右分別是:原圖、真實label、預(yù)測label
備注:
其實最開始使用voc數(shù)據(jù)集訓(xùn)練的,但效果極差,也沒發(fā)現(xiàn)哪里有問題。換個數(shù)據(jù)集效果就好了,可能有兩個原因:
1. voc數(shù)據(jù)我在處理數(shù)據(jù)時出錯了,沒檢查出來
2. 這個數(shù)據(jù)集比較簡單,容易學(xué)習(xí),所以效果差不多。
到此這篇關(guān)于pytorch通過自己的數(shù)據(jù)集訓(xùn)練Unet網(wǎng)絡(luò)架構(gòu)的文章就介紹到這了,更多相關(guān)pytorch Unet內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python實現(xiàn)批量讀取圖片并存入mongodb數(shù)據(jù)庫的方法示例
這篇文章主要介紹了Python實現(xiàn)批量讀取圖片并存入mongodb數(shù)據(jù)庫的方法,涉及Python文件讀取及數(shù)據(jù)庫寫入相關(guān)操作技巧,需要的朋友可以參考下2018-04-04Django在視圖中使用表單并和數(shù)據(jù)庫進(jìn)行數(shù)據(jù)交互的實現(xiàn)
本文主要介紹了Django在視圖中使用表單并和數(shù)據(jù)庫進(jìn)行數(shù)據(jù)交互,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2022-07-07pytorch繪制并顯示loss曲線和acc曲線,LeNet5識別圖像準(zhǔn)確率
今天小編就為大家分享一篇pytorch繪制并顯示loss曲線和acc曲線,LeNet5識別圖像準(zhǔn)確率,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-01-01CentOS 7 安裝python3.7.1的方法及注意事項
這篇文章主要介紹了CentOS 7 安裝python3.7.1的方法,文中給大家提到了注意事項,需要的朋友可以參考下2018-11-11