
Training a U-Net architecture with PyTorch on your own dataset

Updated: December 8, 2022, 09:15:34   Author: 專業女神殺手
U-Net is a network architecture that has become quite popular recently. Its theory has already been discussed at length by many experts; this article focuses on the practical side, explaining how to implement U-Net image segmentation with PyTorch.

There are two main schools of thought on image segmentation: Encoder-Decoder and Dilated Conv. This article covers U-Net, the most classic of the encoder-decoder networks. As backbone networks have evolved, many derived networks are essentially improvements on U-Net, but the underlying idea has not changed much; examples include FC-DenseNet (which combines DenseNet and U-Net) and UNet++.

1. Introduction to the U-Net network

Paper: https://arxiv.org/abs/1505.04597v1 (2015)

U-Net was designed for medical image segmentation. Since annotated data is scarce in medical imaging, the method proposed in the paper noticeably improves results when training on small datasets, and the paper also proposes an effective way to handle large images.

U-Net's architecture is inherited from FCN, with some changes made on top of it. It popularized the Encoder-Decoder framing, which is essentially FCN's idea of convolving first and then upsampling.

From the U-Net architecture diagram in the paper, we can see:

The left side of the structure is the Encoder, i.e., the downsampling path that extracts features. The basic Encoder block is a double convolution: the input passes through two 3x3 convolutions. The paper uses valid (unpadded) convolutions; in our implementation we can add padding (same convolution) so that feature maps line up in the skip architecture without cropping. Downsampling uses a pooling layer that directly halves the spatial size.
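As a quick illustration of the valid-vs-same choice (a minimal sketch added here, not from the original article): with a 3x3 kernel, padding=0 shrinks each spatial dimension by 2 per convolution, while padding=1 preserves it, which is what lets the skip connections concatenate without cropping.

import torch
import torch.nn as nn

x = torch.randn(1, 3, 256, 256)          # dummy input: (N, C, H, W)
valid = nn.Conv2d(3, 64, 3, padding=0)   # "valid" conv, as in the paper
same = nn.Conv2d(3, 64, 3, padding=1)    # "same" conv, as used below
print(valid(x).shape)                    # torch.Size([1, 64, 254, 254])
print(same(x).shape)                     # torch.Size([1, 64, 256, 256])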

The right side of the structure is the Decoder, i.e., the path that upsamples back to the input size and makes the prediction. The Decoder uses the same double-convolution blocks; upsampling is implemented with transposed convolutions, each of which doubles the spatial size.

The "copy and crop" connections in the middle are a cat operation, i.e., concatenation of feature maps along the channel dimension.
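A minimal shape check of one decoder step (an illustrative sketch, not part of the original code): the transposed convolution doubles the spatial size, and the result is concatenated with the matching encoder feature map along dim=1 (channels).

import torch
import torch.nn as nn

enc = torch.randn(1, 512, 32, 32)             # encoder map saved for the skip connection
bottom = torch.randn(1, 1024, 16, 16)         # bottleneck map
up = nn.ConvTranspose2d(1024, 512, 2, stride=2)
merged = torch.cat([up(bottom), enc], dim=1)  # channels: 512 + 512 = 1024
print(merged.shape)                           # torch.Size([1, 1024, 32, 32])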

2. Training U-Net on the Carvana dataset

2.1 U-Net implementation

From the description of the U-Net architecture above, the structure is very symmetric and simple. The implementation in Unet.py is as follows:

import torch
import torch.nn as nn
class DoubleConv(nn.Module):
    """(Conv3x3 -> BN -> ReLU) x 2, the basic U-Net block."""
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),  # "same" conv keeps H and W
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.conv(x)
class Unet(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()
        # Encoder
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        # Decoder
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.output = nn.Conv2d(64, out_ch, 1)
    def forward(self, x):
        conv1 = self.conv1(x)
        pool1 = self.pool1(conv1)
        conv2 = self.conv2(pool1)
        pool2 = self.pool2(conv2)
        conv3 = self.conv3(pool2)
        pool3 = self.pool3(conv3)
        conv4 = self.conv4(pool3)
        pool4 = self.pool4(conv4)
        conv5 = self.conv5(pool4)
        up6 = self.up6(conv5)
        merge6 = torch.cat([up6, conv4], dim=1)  # skip connection: channel concat
        conv6 = self.conv6(merge6)
        up7 = self.up7(conv6)
        merge7 = torch.cat([up7, conv3], dim=1)
        conv7 = self.conv7(merge7)
        up8 = self.up8(conv7)
        merge8 = torch.cat([up8, conv2], dim=1)
        conv8 = self.conv8(merge8)
        up9 = self.up9(conv8)
        merge9 = torch.cat([up9, conv1], dim=1)
        conv9 = self.conv9(merge9)
        out = self.output(conv9)
        return out
if __name__=="__main__":
    model = Unet(3, 21)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    print(model)
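A quick sanity check of the forward pass, continuing from the __main__ block above (a small addition, not in the original post): with same-padding, the output keeps the input's spatial size, and the channel count equals the number of classes.

x = torch.randn(1, 3, 256, 256).to(device)  # dummy batch
out = model(x)
print(out.shape)  # torch.Size([1, 21, 256, 256])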

2.2 Dataset processing

The data comes from Kaggle (I've forgotten the download link). It contains 2 classes, car and background, with 5k+ images; just split them into training and validation sets by a fixed ratio. Details are in carnava.py:

from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T
import numpy as np
import os
import matplotlib.pyplot as plt
class Car(Dataset):
    def __init__(self, root, train=True):
        self.root = root
        self.crop_size = (256, 256)
        self.img_path = os.path.join(root, "train_hq")
        self.label_path = os.path.join(root, "train_masks")
        img_path_list = [os.path.join(self.img_path, im) for im in os.listdir(self.img_path)]
        train_path_list, val_path_list = self._split_data_set(img_path_list)
        if train:
            self.imgs_list = train_path_list
        else:
            self.imgs_list = val_path_list
        normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.transforms = T.Compose([
                T.Resize(256),
                T.CenterCrop(256),
                T.ToTensor(),
                normalize
            ])
        self.transforms_val = T.Compose([
            # NEAREST keeps mask colors exact; bilinear resizing would blend
            # the two mask colors and break the color-map lookup below
            T.Resize(256, interpolation=T.InterpolationMode.NEAREST),
            T.CenterCrop(256)
        ])
        self.color_map = [[0, 0, 0], [255, 255, 255]]
    def __getitem__(self, index: int):
        im_path = self.imgs_list[index]
        image = Image.open(im_path).convert("RGB")
        data = self.transforms(image)
        (filepath, filename) = os.path.split(im_path)
        filename = filename.split('.')[0]
        label = Image.open(self.label_path + "/" + filename + "_mask.gif").convert("RGB")
        label = self.transforms_val(label)
        # build a lookup table from packed RGB color to class index
        cm2lb = np.zeros(256 ** 3)
        for i, cm in enumerate(self.color_map):
            cm2lb[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = i
        label_np = np.array(label, dtype=np.int64)
        # pack each pixel's RGB triple into a single integer, then look it up
        idx = (label_np[:, :, 0] * 256 + label_np[:, :, 1]) * 256 + label_np[:, :, 2]
        label = np.array(cm2lb[idx], dtype=np.int64)
        label = torch.from_numpy(label).long()
        return data, label
    def label2img(self, label):
        cmap = self.color_map
        cmap = np.array(cmap).astype(np.uint8)
        pred = cmap[label]
        return pred
    def __len__(self):
        return len(self.imgs_list)
    def _split_data_set(self, img_path_list):
        val_path_list = img_path_list[::8]
        train_path_list = []
        for item in img_path_list:
            if item not in val_path_list:
                train_path_list.append(item)
        return train_path_list, val_path_list
if __name__=="__main__":
    root = "../dataset/carvana"
    car_train = Car(root,train=True)
    train_dataloader = DataLoader(car_train, batch_size=8, shuffle=True)
    print(len(car_train))
    print(len(train_dataloader))
    # for data, label in car_train:
    #     print(data.shape)
    #     print(label.shape)
    #     break
    (data, label) = car_train[190]
    label_np = label.data.numpy()
    label_im = car_train.label2img(label_np)
    plt.figure()
    plt.imshow(label_im)
    plt.show()
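To make the color-map lookup in __getitem__ concrete, here is a tiny standalone round trip (an illustrative sketch with made-up pixel values, not from the original post): each pixel's RGB triple is packed into a single integer, looked up to get a class index, and the same cmap indexing that label2img uses maps indices back to colors.

import numpy as np

color_map = [[0, 0, 0], [255, 255, 255]]  # background, car
cm2lb = np.zeros(256 ** 3, dtype=np.int64)
for i, cm in enumerate(color_map):
    cm2lb[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = i

mask = np.array([[[0, 0, 0], [255, 255, 255]]])  # a 1x2 RGB "mask"
idx = (mask[:, :, 0] * 256 + mask[:, :, 1]) * 256 + mask[:, :, 2]
label = cm2lb[idx]
print(label)                                       # [[0 1]]
print(np.array(color_map, dtype=np.uint8)[label])  # back to the original colors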

2.3 Training

Segmentation is really just classifying every pixel, so the loss function is still cross-entropy, and accuracy is the number of correctly classified pixels divided by the total number of pixels.
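For per-pixel classification, nn.CrossEntropyLoss takes logits of shape (N, C, H, W) and integer targets of shape (N, H, W); a minimal shape check (illustrative only):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
logits = torch.randn(4, 2, 256, 256)         # (N, C, H, W) network output
target = torch.randint(0, 2, (4, 256, 256))  # (N, H, W) class index per pixel
print(criterion(logits, target))             # a scalar loss

The full training script: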

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from carnava import Car
from unet import Unet
import os
import numpy as np
from torch import optim
# compute the confusion matrix between predictions and ground truth
def _fast_hist(label_true, label_pred, n_class):
    mask = (label_true >= 0) & (label_true < n_class)
    hist = np.bincount(
        n_class * label_true[mask].astype(int) +
        label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
    return hist
def label_accuracy_score(label_trues, label_preds, n_class):
    """Returns accuracy score evaluation result.
      - overall accuracy
      - mean accuracy
      - mean IU
    """
    hist = np.zeros((n_class, n_class))
    for lt, lp in zip(label_trues, label_preds):
        hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
    acc = np.diag(hist).sum() / hist.sum()
    with np.errstate(divide='ignore', invalid='ignore'):
        acc_cls = np.diag(hist) / hist.sum(axis=1)
    acc_cls = np.nanmean(acc_cls)
    with np.errstate(divide='ignore', invalid='ignore'):
        iu = np.diag(hist) / (
            hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)
        )
    mean_iu = np.nanmean(iu)
    return acc, acc_cls, mean_iu
out_path = "./out"
if not os.path.exists(out_path):
    os.makedirs(out_path)
log_path = os.path.join(out_path, "result.txt")
if os.path.exists(log_path):
    os.remove(log_path)
model_path = os.path.join(out_path, "best_model.pth")
root = "../dataset/carvana"
epochs = 5
numclasses = 2
train_data = Car(root, train=True)
train_dataloader = DataLoader(train_data, batch_size=16, shuffle=True)
val_data = Car(root, train=False)
val_dataloader = DataLoader(val_data, batch_size=16, shuffle=True)
net = Unet(3, numclasses)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = net.to(device)
optimizer = optim.SGD(net.parameters(), lr=0.01, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()
def train_model():
    best_score = 0.0
    for e in range(epochs):
        net.train()
        train_loss = 0.0
        label_true = torch.LongTensor()
        label_pred = torch.LongTensor()
        for batch_id, (data, label) in enumerate(train_dataloader):
            data, label = data.to(device), label.to(device)
            output = net(data)
            loss = criterion(output, label)
            pred = output.argmax(dim=1).data.cpu()  # (N, H, W); squeeze() would drop a size-1 batch dim
            real = label.data.cpu()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss+=loss.cpu().item()
            label_true = torch.cat((label_true,real),dim=0)
            label_pred = torch.cat((label_pred,pred),dim=0)
        train_loss /= len(train_dataloader)
        acc, acc_cls, mean_iu = label_accuracy_score(label_true.numpy(),label_pred.numpy(),numclasses)
        print("\n epoch:{}, train_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}".format(
            e+1, train_loss, acc, acc_cls, mean_iu))
        with open(log_path, 'a') as f:
            f.write('\n epoch:{}, train_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}'.format(
                e+1,train_loss,acc, acc_cls, mean_iu))
        net.eval()
        val_loss = 0.0
        val_label_true = torch.LongTensor()
        val_label_pred = torch.LongTensor()
        with torch.no_grad():
            for batch_id, (data, label) in enumerate(val_dataloader):
                data, label = data.to(device), label.to(device)
                output = net(data)
                loss = criterion(output, label)
                pred = output.argmax(dim=1).data.cpu()  # (N, H, W)
                real = label.data.cpu()
                val_loss += loss.cpu().item()
                val_label_true = torch.cat((val_label_true, real), dim=0)
                val_label_pred = torch.cat((val_label_pred, pred), dim=0)
            val_loss/=len(val_dataloader)
            val_acc, val_acc_cls, val_mean_iu = label_accuracy_score(val_label_true.numpy(),
                                                                    val_label_pred.numpy(),numclasses)
        print('\n epoch:{}, val_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}'.format(e+1, val_loss, val_acc, val_acc_cls, val_mean_iu))
        with open(log_path, 'a') as f:
            f.write('\n epoch:{}, val_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}'.format(
            e+1,val_loss,val_acc, val_acc_cls, val_mean_iu))
        score = (val_acc_cls+val_mean_iu)/2
        if score > best_score:
            best_score = score
            torch.save(net.state_dict(), model_path)
def evaluate():
    import random
    import matplotlib.pyplot as plt
    net.load_state_dict(torch.load(model_path))
    net.eval()  # put BatchNorm into inference mode
    index = random.randint(0, len(val_data)-1)
    val_image, val_label = val_data[index]
    with torch.no_grad():  # no gradients needed for inference
        out = net(val_image.unsqueeze(0).to(device))
    pred = out.argmax(dim=1).squeeze().data.cpu().numpy()
    label = val_label.data.numpy()
    img_pred = val_data.label2img(pred)
    img_label = val_data.label2img(label)
    temp = val_image.numpy()
    temp = (temp-np.min(temp)) / (np.max(temp)-np.min(temp))*255
    fig, ax = plt.subplots(1,3)
    ax[0].imshow(temp.transpose(1,2,0).astype("uint8"))
    ax[1].imshow(img_label)
    ax[2].imshow(img_pred)
    plt.show()
if __name__=="__main__":
    # train_model()  # run this first to produce best_model.pth
    evaluate()
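As a quick sanity check of the metric code (an illustrative sketch, not part of the original script), label_accuracy_score can be fed small arrays directly:

import numpy as np

true = np.array([[0, 0, 1, 1]])
pred = np.array([[0, 1, 1, 1]])
acc, acc_cls, mean_iu = label_accuracy_score(true, pred, 2)
print(acc, acc_cls, mean_iu)  # 0.75 0.75 ~0.583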

The final training result: since the data is fairly simple, the mIoU already reaches 0.97 by epoch 5.

Finally, a quick look at the predictions: in the figure produced by evaluate(), from left to right are the original image, the ground-truth label, and the predicted label.

Notes:

Actually, I first trained on the VOC dataset, but the results were terrible and I couldn't find the problem. Switching datasets fixed it. There are probably two reasons:

1. I made a mistake somewhere in my VOC data processing and never caught it.

2. This dataset is relatively simple and easy to learn, so good results come easily.

This concludes this article on training a U-Net architecture with PyTorch on your own dataset. For more on PyTorch and U-Net, see 脚本之家's earlier articles. Thanks for your support!
