Pytorch 使用 nii數(shù)據(jù)做輸入數(shù)據(jù)的操作
使用pix2pix-gan做醫(yī)學圖像合成的時候,如果把nii數(shù)據(jù)轉(zhuǎn)成png格式會損失很多信息,以為png格式圖像的灰度值有256階,因此直接使用nii的醫(yī)學圖像做輸入會更好一點。
但是Pythorch中的Dataloader是不能直接讀取nii圖像的,因此加一個CreateNiiDataset的類。
先來了解一下pytorch中讀取數(shù)據(jù)的主要途徑——Dataset類。在自己構(gòu)建數(shù)據(jù)層時都要基于這個類,類似于C++中的虛基類。
自己構(gòu)建的數(shù)據(jù)層包含三個部分
class Dataset(object): """An abstract class representing a Dataset. All other datasets should subclass it. All subclasses should override ``__len__``, that provides the size of the dataset, and ``__getitem__``, supporting integer indexing in range from 0 to len(self) exclusive. """ def __getitem__(self, index): raise NotImplementedError def __len__(self): raise NotImplementedError def __add__(self, other): return ConcatDataset([self, other])
根據(jù)自己的需要編寫CreateNiiDataset子類:
因為我是基于https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
做pix2pix-gan的實驗,數(shù)據(jù)包含兩個部分mr 和 ct,不需要標簽,因此上面的 def getitem(self, index):中不需要index這個參數(shù)了,類似地,根據(jù)需要,加入自己的參數(shù),去掉不需要的參數(shù)。
class CreateNiiDataset(Dataset): def __init__(self, opt, transform = None, target_transform = None): self.path1 = opt.dataroot # parameter passing self.A = 'MR' self.B = 'CT' lines = os.listdir(os.path.join(self.path1, self.A)) lines.sort() imgs = [] for line in lines: imgs.append(line) self.imgs = imgs self.transform = transform self.target_transform = target_transform def crop(self, image, crop_size): shp = image.shape scl = [int((shp[0] - crop_size[0]) / 2), int((shp[1] - crop_size[1]) / 2)] image_crop = image[scl[0]:scl[0] + crop_size[0], scl[1]:scl[1] + crop_size[1]] return image_crop def __getitem__(self, item): file = self.imgs[item] img1 = sitk.ReadImage(os.path.join(self.path1, self.A, file)) img2 = sitk.ReadImage(os.path.join(self.path1, self.B, file)) data1 = sitk.GetArrayFromImage(img1) data2 = sitk.GetArrayFromImage(img2) if data1.shape[0] != 256: data1 = self.crop(data1, [256, 256]) data2 = self.crop(data2, [256, 256]) if self.transform is not None: data1 = self.transform(data1) data2 = self.transform(data2) if np.min(data1)<0: data1 = (data1 - np.min(data1))/(np.max(data1)-np.min(data1)) if np.min(data2)<0: #data2 = data2 - np.min(data2) data2 = (data2 - np.min(data2))/(np.max(data2)-np.min(data2)) data = {} data1 = data1[np.newaxis, np.newaxis, :, :] data1_tensor = torch.from_numpy(np.concatenate([data1,data1,data1], 1)) data1_tensor = data1_tensor.type(torch.FloatTensor) data['A'] = data1_tensor # should be a tensor in Float Tensor Type data2 = data2[np.newaxis, np.newaxis, :, :] data2_tensor = torch.from_numpy(np.concatenate([data2,data2,data2], 1)) data2_tensor = data2_tensor.type(torch.FloatTensor) data['B'] = data2_tensor # should be a tensor in Float Tensor Type data['A_paths'] = [os.path.join(self.path1, self.A, file)] # should be a list, with path inside data['B_paths'] = [os.path.join(self.path1, self.B, file)] return data def load_data(self): return self def __len__(self): return len(self.imgs)
注意:最后輸出的data是一個字典,里面有四個keys=[‘A',‘B',‘A_paths',‘B_paths'], 一定要注意數(shù)據(jù)要轉(zhuǎn)成FloatTensor。
其次是data[‘A_paths'] 接收的值是一個list,一定要加[ ] 擴起來,要不然測試存圖的時候會有問題,找這個問題找了好久才發(fā)現(xiàn)。
然后直接在train.py的主函數(shù)里面把數(shù)據(jù)加載那行改掉就好了
data_loader = CreateNiiDataset(opt)
dataset = data_loader.load_data()
Over!
補充知識:nii格式圖像存為npy格式
我就廢話不多說了,大家還是直接看代碼吧!
import nibabel as nib import os import numpy as np img_path = '/home/lei/train/img/' seg_path = '/home/lei/train/seg/' saveimg_path = '/home/lei/train/npy_img/' saveseg_path = '/home/lei/train/npy_seg/' img_names = os.listdir(img_path) seg_names = os.listdir(seg_path) for img_name in img_names: print(img_name) img = nib.load(img_path + img_name).get_data() #載入 img = np.array(img) np.save(saveimg_path + str(img_name).split('.')[0] + '.npy', img) #保存 for seg_name in seg_names: print(seg_name) seg = nib.load(seg_path + seg_name).get_data() seg = np.array(seg) np.save(saveseg_path + str(seg_name).split('.')[0] + '.npy
以上這篇Pytorch 使用 nii數(shù)據(jù)做輸入數(shù)據(jù)的操作就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
python3的url編碼和解碼,自定義gbk、utf-8的例子
今天小編就為大家分享一篇python3的url編碼和解碼,自定義gbk、utf-8的例子,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-08-08詳解使用pymysql在python中對mysql的增刪改查操作(綜合)
本篇文章主要介紹了使用pymysql在python中對mysql的增刪改查操作,通過pymysql向數(shù)據(jù)庫進行查刪增改,具有一定的參考價值,有興趣的可以了解一下。2017-01-01python語言線程標準庫threading.local解讀總結(jié)
在本篇文章里我們給各位整理了一篇關(guān)于python threading.local源碼解讀的相關(guān)文章知識點,有需要的朋友們可以學習下。2019-11-11Python爬蟲實戰(zhàn)之使用Scrapy爬取豆瓣圖片
在用Python的urllib和BeautifulSoup寫過了很多爬蟲之后,本人決定嘗試著名的Python爬蟲框架——Scrapy.本次分享將詳細講述如何利用Scrapy來下載豆瓣名人圖片,需要的朋友可以參考下2021-06-06python3實現(xiàn)字符串的全排列的方法(無重復(fù)字符)
這篇文章主要介紹了python3實現(xiàn)字符串的全排列的方法(無重復(fù)字符),小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧2018-07-07利用Tensorboard繪制網(wǎng)絡(luò)識別準確率和loss曲線實例
今天小編就為大家分享一篇利用Tensorboard繪制網(wǎng)絡(luò)識別準確率和loss曲線實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-02-02