PyTorch中多對(duì)象分割項(xiàng)目的實(shí)現(xiàn)
對(duì)象分割任務(wù)的目標(biāo)是找到圖像中目標(biāo)對(duì)象的邊界。實(shí)際應(yīng)用例如自動(dòng)駕駛汽車和醫(yī)學(xué)成像分析。這里將使用PyTorch開發(fā)一個(gè)深度學(xué)習(xí)模型來完成多對(duì)象分割任務(wù)。多對(duì)象分割的主要目標(biāo)是自動(dòng)勾勒出圖像中多個(gè)目標(biāo)對(duì)象的邊界。
對(duì)象的邊界通常由與圖像大小相同的分割掩碼定義,在分割掩碼中屬于目標(biāo)對(duì)象的所有像素基于預(yù)定義的標(biāo)記被標(biāo)記為相同。
創(chuàng)建數(shù)據(jù)集
from torchvision.datasets import VOCSegmentation
from PIL import Image
from torchvision.transforms.functional import to_tensor, to_pil_image
class myVOCSegmentation(VOCSegmentation):
def __getitem__(self, index):
img = Image.open(self.images[index]).convert('RGB')
target = Image.open(self.masks[index])
if self.transforms is not None:
augmented= self.transforms(image=np.array(img), mask=np.array(target))
img = augmented['image']
target = augmented['mask']
target[target>20]=0
img= to_tensor(img)
target= torch.from_numpy(target).type(torch.long)
return img, target
from albumentations import (
HorizontalFlip,
Compose,
Resize,
Normalize)
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
h,w=520,520
transform_train = Compose([ Resize(h,w),
HorizontalFlip(p=0.5),
Normalize(mean=mean,std=std)])
transform_val = Compose( [ Resize(h,w),
Normalize(mean=mean,std=std)])
path2data="./data/"
train_ds=myVOCSegmentation(path2data,
year='2012',
image_set='train',
download=False,
transforms=transform_train)
print(len(train_ds))
val_ds=myVOCSegmentation(path2data,
year='2012',
image_set='val',
download=False,
transforms=transform_val)
print(len(val_ds))
import torch
import numpy as np
from skimage.segmentation import mark_boundaries
import matplotlib.pylab as plt
%matplotlib inline
np.random.seed(0)
num_classes=21
COLORS = np.random.randint(0, 2, size=(num_classes+1, 3),dtype="uint8")
def show_img_target(img, target):
if torch.is_tensor(img):
img=to_pil_image(img)
target=target.numpy()
for ll in range(num_classes):
mask=(target==ll)
img=mark_boundaries(np.array(img) ,
mask,
outline_color=COLORS[ll],
color=COLORS[ll])
plt.imshow(img)
def re_normalize (x, mean = mean, std= std):
x_r= x.clone()
for c, (mean_c, std_c) in enumerate(zip(mean, std)):
x_r [c] *= std_c
x_r [c] += mean_c
return x_r展示訓(xùn)練數(shù)據(jù)集示例圖像
img, mask = train_ds[10]
print(img.shape, img.type(),torch.max(img))
print(mask.shape, mask.type(),torch.max(mask))
plt.figure(figsize=(20,20))
img_r= re_normalize(img)
plt.subplot(1, 3, 1)
plt.imshow(to_pil_image(img_r))
plt.subplot(1, 3, 2)
plt.imshow(mask)
plt.subplot(1, 3, 3)
show_img_target(img_r, mask)


展示驗(yàn)證數(shù)據(jù)集示例圖像
img, mask = val_ds[10] print(img.shape, img.type(),torch.max(img)) print(mask.shape, mask.type(),torch.max(mask)) plt.figure(figsize=(20,20)) img_r= re_normalize(img) plt.subplot(1, 3, 1) plt.imshow(to_pil_image(img_r)) plt.subplot(1, 3, 2) plt.imshow(mask) plt.subplot(1, 3, 3) show_img_target(img_r, mask)


創(chuàng)建數(shù)據(jù)加載器
通過torch.utils.data針對(duì)訓(xùn)練和驗(yàn)證集分別創(chuàng)建Dataloader,打印示例觀察效果
from torch.utils.data import DataLoader
train_dl = DataLoader(train_ds, batch_size=4, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=8, shuffle=False)
for img_b, mask_b in train_dl:
print(img_b.shape,img_b.dtype)
print(mask_b.shape, mask_b.dtype)
break
for img_b, mask_b in val_dl:
print(img_b.shape,img_b.dtype)
print(mask_b.shape, mask_b.dtype)
break

創(chuàng)建模型
創(chuàng)建并打印deeplab_resnet模型結(jié)構(gòu),使用預(yù)訓(xùn)練權(quán)重
from torchvision.models.segmentation import deeplabv3_resnet101
import torch
model=deeplabv3_resnet101(pretrained=True, num_classes=21)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model=model.to(device)
print(model)
部署模型
在驗(yàn)證數(shù)據(jù)集的數(shù)據(jù)批次上部署模型觀察效果
from torch import nn
model.eval()
with torch.no_grad():
for xb, yb in val_dl:
yb_pred = model(xb.to(device))
yb_pred = yb_pred["out"].cpu()
print(yb_pred.shape)
yb_pred = torch.argmax(yb_pred,axis=1)
break
print(yb_pred.shape)
plt.figure(figsize=(20,20))
n=2
img, mask= xb[n], yb_pred[n]
img_r= re_normalize(img)
plt.subplot(1, 3, 1)
plt.imshow(to_pil_image(img_r))
plt.subplot(1, 3, 2)
plt.imshow(mask)
plt.subplot(1, 3, 3)
show_img_target(img_r, mask)可見勾勒對(duì)象方面效果很好

定義損失函數(shù)和優(yōu)化器
from torch import nn criterion = nn.CrossEntropyLoss(reduction="sum")
from torch import optim
opt = optim.Adam(model.parameters(), lr=1e-6)
def loss_batch(loss_func, output, target, opt=None):
loss = loss_func(output, target)
if opt is not None:
opt.zero_grad()
loss.backward()
opt.step()
return loss.item(), None
from torch.optim.lr_scheduler import ReduceLROnPlateau
lr_scheduler = ReduceLROnPlateau(opt, mode='min',factor=0.5, patience=20,verbose=1)
def get_lr(opt):
for param_group in opt.param_groups:
return param_group['lr']
current_lr=get_lr(opt)
print('current lr={}'.format(current_lr))
訓(xùn)練和驗(yàn)證模型
def loss_epoch(model,loss_func,dataset_dl,sanity_check=False,opt=None):
running_loss=0.0
len_data=len(dataset_dl.dataset)
for xb, yb in dataset_dl:
xb=xb.to(device)
yb=yb.to(device)
output=model(xb)["out"]
loss_b, _ = loss_batch(loss_func, output, yb, opt)
running_loss += loss_b
if sanity_check is True:
break
loss=running_loss/float(len_data)
return loss, None
import copy
def train_val(model, params):
num_epochs=params["num_epochs"]
loss_func=params["loss_func"]
opt=params["optimizer"]
train_dl=params["train_dl"]
val_dl=params["val_dl"]
sanity_check=params["sanity_check"]
lr_scheduler=params["lr_scheduler"]
path2weights=params["path2weights"]
loss_history={
"train": [],
"val": []}
metric_history={
"train": [],
"val": []}
best_model_wts = copy.deepcopy(model.state_dict())
best_loss=float('inf')
for epoch in range(num_epochs):
current_lr=get_lr(opt)
print('Epoch {}/{}, current lr={}'.format(epoch, num_epochs - 1, current_lr))
model.train()
train_loss, train_metric=loss_epoch(model,loss_func,train_dl,sanity_check,opt)
loss_history["train"].append(train_loss)
metric_history["train"].append(train_metric)
model.eval()
with torch.no_grad():
val_loss, val_metric=loss_epoch(model,loss_func,val_dl,sanity_check)
loss_history["val"].append(val_loss)
metric_history["val"].append(val_metric)
if val_loss < best_loss:
best_loss = val_loss
best_model_wts = copy.deepcopy(model.state_dict())
torch.save(model.state_dict(), path2weights)
print("Copied best model weights!")
lr_scheduler.step(val_loss)
if current_lr != get_lr(opt):
print("Loading best model weights!")
model.load_state_dict(best_model_wts)
print("train loss: %.6f" %(train_loss))
print("val loss: %.6f" %(val_loss))
print("-"*10)
model.load_state_dict(best_model_wts)
return model, loss_history, metric_history
import os
opt = optim.Adam(model.parameters(), lr=1e-6)
lr_scheduler = ReduceLROnPlateau(opt, mode='min',factor=0.5, patience=20,verbose=1)
path2models= "./models/"
if not os.path.exists(path2models):
os.mkdir(path2models)
params_train={
"num_epochs": 10,
"optimizer": opt,
"loss_func": criterion,
"train_dl": train_dl,
"val_dl": val_dl,
"sanity_check": True,
"lr_scheduler": lr_scheduler,
"path2weights": path2models+"sanity_weights.pt",
}
model, loss_hist, _ = train_val(model, params_train)
繪制了訓(xùn)練和驗(yàn)證損失曲線
num_epochs=params_train["num_epochs"]
plt.title("Train-Val Loss")
plt.plot(range(1,num_epochs+1),loss_hist["train"],label="train")
plt.plot(range(1,num_epochs+1),loss_hist["val"],label="val")
plt.ylabel("Loss")
plt.xlabel("Training Epochs")
plt.legend()
plt.show()
到此這篇關(guān)于PyTorch中多對(duì)象分割項(xiàng)目的實(shí)現(xiàn)的文章就介紹到這了,更多相關(guān)PyTorch 多對(duì)象分割項(xiàng)目?jī)?nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python實(shí)例一個(gè)類背后發(fā)生了什么
Python實(shí)例一個(gè)類背后發(fā)生了什么,本文為大家一一列出,感興趣的朋友可以參考一下2016-02-02
python numpy矩陣信息說明,shape,size,dtype
這篇文章主要介紹了python numpy矩陣信息說明,shape,size,dtype,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2020-05-05
詳解python 字符串和日期之間轉(zhuǎn)換 StringAndDate
這篇文章主要介紹了python 字符串和日期之間轉(zhuǎn)換 StringAndDate簡(jiǎn)單實(shí)例的相關(guān)資料,需要的朋友可以參考下2017-05-05
教你用Python寫一個(gè)植物大戰(zhàn)僵尸小游戲
這篇文章主要介紹了教你用Python寫一個(gè)植物大戰(zhàn)僵尸小游戲,文中有非常詳細(xì)的代碼示例,對(duì)正在學(xué)習(xí)python的小伙伴們有非常好的幫助,需要的朋友可以參考下2021-04-04
Python第三方庫xlrd/xlwt的安裝與讀寫Excel表格
最近開始學(xué)習(xí)python,想做做簡(jiǎn)單的自動(dòng)化測(cè)試,需要讀寫excel,于是就接觸到了Python的第三方庫xlrd和xlwt,下面這篇文章就給大家主要介紹了Python中第三方庫xlrd/xlwt的安裝與讀寫Excel表格的方法,需要的朋友可以參考借鑒。2017-01-01
matplotlib之pyplot模塊坐標(biāo)軸標(biāo)簽設(shè)置使用(xlabel()、ylabel())
這篇文章主要介紹了matplotlib之pyplot模塊坐標(biāo)軸標(biāo)簽設(shè)置使用(xlabel()、ylabel()),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2021-02-02
深入淺析Python科學(xué)計(jì)算庫Scipy及安裝步驟
這篇文章主要介紹了Python科學(xué)計(jì)算庫—Scipy的相關(guān)知識(shí),非常不錯(cuò),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2019-10-10

