Implementing a Multi-Object Segmentation Project in PyTorch
The goal of object segmentation is to find the boundaries of target objects in an image; practical applications include self-driving cars and medical image analysis. Here we use PyTorch to build a deep learning model for multi-object segmentation, whose aim is to automatically outline the boundaries of several target objects in an image.
Object boundaries are usually defined by a segmentation mask of the same size as the image: in the mask, every pixel belonging to a given target object is assigned the same predefined class label.
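For intuition, here is a purely illustrative toy example (not part of the article's code) of what such a mask looks like: a 4x4 label array in which 0 marks background pixels and 1 and 2 mark two different objects.

import numpy as np

# Hypothetical 4x4 segmentation mask: 0 = background, 1 and 2 = two objects
toy_mask = np.array([[0, 0, 2, 2],
                     [1, 1, 2, 2],
                     [1, 1, 0, 0],
                     [1, 1, 0, 0]])
print(np.unique(toy_mask))   # class indices present in the mask: [0 1 2]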
Creating the Dataset
import numpy as np
import torch
from PIL import Image
from torchvision.datasets import VOCSegmentation
from torchvision.transforms.functional import to_tensor, to_pil_image
from albumentations import HorizontalFlip, Compose, Resize, Normalize

class myVOCSegmentation(VOCSegmentation):
    def __getitem__(self, index):
        img = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.masks[index])

        if self.transforms is not None:
            augmented = self.transforms(image=np.array(img), mask=np.array(target))
            img = augmented['image']
            target = augmented['mask']
            target[target > 20] = 0  # map the VOC "void" border label (255) to background

        img = to_tensor(img)
        target = torch.from_numpy(target).type(torch.long)
        return img, target

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
h, w = 520, 520

transform_train = Compose([
    Resize(h, w),
    HorizontalFlip(p=0.5),
    Normalize(mean=mean, std=std)])

transform_val = Compose([
    Resize(h, w),
    Normalize(mean=mean, std=std)])

path2data = "./data/"
train_ds = myVOCSegmentation(path2data, year='2012', image_set='train', download=False, transforms=transform_train)
print(len(train_ds))

val_ds = myVOCSegmentation(path2data, year='2012', image_set='val', download=False, transforms=transform_val)
print(len(val_ds))
import torch
import numpy as np
from skimage.segmentation import mark_boundaries
import matplotlib.pylab as plt
%matplotlib inline

np.random.seed(0)
num_classes = 21
COLORS = np.random.randint(0, 2, size=(num_classes+1, 3), dtype="uint8")

def show_img_target(img, target):
    if torch.is_tensor(img):
        img = to_pil_image(img)
        target = target.numpy()
    for ll in range(num_classes):
        mask = (target == ll)
        img = mark_boundaries(np.array(img), mask, outline_color=COLORS[ll], color=COLORS[ll])
    plt.imshow(img)

def re_normalize(x, mean=mean, std=std):
    x_r = x.clone()
    for c, (mean_c, std_c) in enumerate(zip(mean, std)):
        x_r[c] *= std_c
        x_r[c] += mean_c
    return x_r
Display a sample image from the training dataset
img, mask = train_ds[10]
print(img.shape, img.type(), torch.max(img))
print(mask.shape, mask.type(), torch.max(mask))

plt.figure(figsize=(20, 20))
img_r = re_normalize(img)

plt.subplot(1, 3, 1)
plt.imshow(to_pil_image(img_r))

plt.subplot(1, 3, 2)
plt.imshow(mask)

plt.subplot(1, 3, 3)
show_img_target(img_r, mask)
Display a sample image from the validation dataset
img, mask = val_ds[10]
print(img.shape, img.type(), torch.max(img))
print(mask.shape, mask.type(), torch.max(mask))

plt.figure(figsize=(20, 20))
img_r = re_normalize(img)

plt.subplot(1, 3, 1)
plt.imshow(to_pil_image(img_r))

plt.subplot(1, 3, 2)
plt.imshow(mask)

plt.subplot(1, 3, 3)
show_img_target(img_r, mask)
Creating the Data Loaders
Create a DataLoader for the training set and the validation set with torch.utils.data, then print one batch from each to check the tensor shapes and dtypes.
from torch.utils.data import DataLoader

train_dl = DataLoader(train_ds, batch_size=4, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=8, shuffle=False)

for img_b, mask_b in train_dl:
    print(img_b.shape, img_b.dtype)
    print(mask_b.shape, mask_b.dtype)
    break

for img_b, mask_b in val_dl:
    print(img_b.shape, img_b.dtype)
    print(mask_b.shape, mask_b.dtype)
    break
Creating the Model
Create the deeplabv3_resnet101 model with pretrained weights and print its structure.
from torchvision.models.segmentation import deeplabv3_resnet101
import torch

model = deeplabv3_resnet101(pretrained=True, num_classes=21)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
print(model)
Deploying the Model
Run the model on a batch from the validation DataLoader to see how well the pretrained weights perform before any fine-tuning.
from torch import nn

model.eval()
with torch.no_grad():
    for xb, yb in val_dl:
        yb_pred = model(xb.to(device))
        yb_pred = yb_pred["out"].cpu()
        print(yb_pred.shape)
        yb_pred = torch.argmax(yb_pred, axis=1)
        break
print(yb_pred.shape)

plt.figure(figsize=(20, 20))
n = 2
img, mask = xb[n], yb_pred[n]
img_r = re_normalize(img)

plt.subplot(1, 3, 1)
plt.imshow(to_pil_image(img_r))

plt.subplot(1, 3, 2)
plt.imshow(mask)

plt.subplot(1, 3, 3)
show_img_target(img_r, mask)
As the plots show, the pretrained model already outlines the objects quite well.
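To go beyond a visual check, the sketch below computes a rough per-class intersection-over-union (IoU) on the batch just predicted. This is not part of the original article; it assumes yb and yb_pred are still in scope from the cell above, and batch_iou is a hypothetical helper name.

def batch_iou(pred, target, num_classes=21):
    # Per-class IoU over one batch of predicted and ground-truth masks
    ious = {}
    for c in range(num_classes):
        pred_c = (pred == c)
        target_c = (target == c)
        union = (pred_c | target_c).sum().item()
        if union == 0:
            continue                     # class absent from this batch
        inter = (pred_c & target_c).sum().item()
        ious[c] = inter / union
    return ious

print(batch_iou(yb_pred, yb))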
Defining the Loss Function and Optimizer
from torch import nn
criterion = nn.CrossEntropyLoss(reduction="sum")
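Note that nn.CrossEntropyLoss applies per pixel when given (N, C, H, W) logits and (N, H, W) integer targets, so the model output and the masks can be passed in directly. A quick shape sanity check, illustrative only and using random tensors:

dummy_out = torch.randn(2, 21, 520, 520)             # fake model logits
dummy_target = torch.randint(0, 21, (2, 520, 520))   # fake ground-truth masks
print(criterion(dummy_out, dummy_target))             # one scalar, summed over all pixels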
from torch import optim
opt = optim.Adam(model.parameters(), lr=1e-6)

def loss_batch(loss_func, output, target, opt=None):
    loss = loss_func(output, target)
    if opt is not None:
        opt.zero_grad()
        loss.backward()
        opt.step()
    return loss.item(), None

from torch.optim.lr_scheduler import ReduceLROnPlateau
lr_scheduler = ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=20, verbose=1)

def get_lr(opt):
    for param_group in opt.param_groups:
        return param_group['lr']

current_lr = get_lr(opt)
print('current lr={}'.format(current_lr))
Training and Validating the Model
def loss_epoch(model, loss_func, dataset_dl, sanity_check=False, opt=None):
    running_loss = 0.0
    len_data = len(dataset_dl.dataset)
    for xb, yb in dataset_dl:
        xb = xb.to(device)
        yb = yb.to(device)
        output = model(xb)["out"]
        loss_b, _ = loss_batch(loss_func, output, yb, opt)
        running_loss += loss_b
        if sanity_check is True:
            break
    loss = running_loss / float(len_data)
    return loss, None

import copy

def train_val(model, params):
    num_epochs = params["num_epochs"]
    loss_func = params["loss_func"]
    opt = params["optimizer"]
    train_dl = params["train_dl"]
    val_dl = params["val_dl"]
    sanity_check = params["sanity_check"]
    lr_scheduler = params["lr_scheduler"]
    path2weights = params["path2weights"]

    loss_history = {"train": [], "val": []}
    metric_history = {"train": [], "val": []}

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')

    for epoch in range(num_epochs):
        current_lr = get_lr(opt)
        print('Epoch {}/{}, current lr={}'.format(epoch, num_epochs - 1, current_lr))

        model.train()
        train_loss, train_metric = loss_epoch(model, loss_func, train_dl, sanity_check, opt)
        loss_history["train"].append(train_loss)
        metric_history["train"].append(train_metric)

        model.eval()
        with torch.no_grad():
            val_loss, val_metric = loss_epoch(model, loss_func, val_dl, sanity_check)
        loss_history["val"].append(val_loss)
        metric_history["val"].append(val_metric)

        if val_loss < best_loss:
            best_loss = val_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), path2weights)
            print("Copied best model weights!")

        lr_scheduler.step(val_loss)
        if current_lr != get_lr(opt):
            print("Loading best model weights!")
            model.load_state_dict(best_model_wts)

        print("train loss: %.6f" % (train_loss))
        print("val loss: %.6f" % (val_loss))
        print("-" * 10)

    model.load_state_dict(best_model_wts)
    return model, loss_history, metric_history
import os

opt = optim.Adam(model.parameters(), lr=1e-6)
lr_scheduler = ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=20, verbose=1)

path2models = "./models/"
if not os.path.exists(path2models):
    os.mkdir(path2models)

params_train = {
    "num_epochs": 10,
    "optimizer": opt,
    "loss_func": criterion,
    "train_dl": train_dl,
    "val_dl": val_dl,
    "sanity_check": True,
    "lr_scheduler": lr_scheduler,
    "path2weights": path2models + "sanity_weights.pt",
}

model, loss_hist, _ = train_val(model, params_train)
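Since train_val saves the best weights to disk, they can be reloaded later for inference. A minimal sketch (not from the original article), assuming the same model definition and save path as above:

path2weights = path2models + "sanity_weights.pt"
model.load_state_dict(torch.load(path2weights, map_location=device))  # restore best weights
model.eval()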
Finally, plot the training and validation loss curves.
num_epochs = params_train["num_epochs"]

plt.title("Train-Val Loss")
plt.plot(range(1, num_epochs + 1), loss_hist["train"], label="train")
plt.plot(range(1, num_epochs + 1), loss_hist["val"], label="val")
plt.ylabel("Loss")
plt.xlabel("Training Epochs")
plt.legend()
plt.show()
This concludes this article on implementing a multi-object segmentation project in PyTorch.