pytorch ImageFolder的覆寫(xiě)實(shí)例
在為數(shù)據(jù)分類訓(xùn)練分類器的時(shí)候,比如貓狗分類時(shí),我們經(jīng)常會(huì)使用pytorch的ImageFolder:
CLASS torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>, is_valid_file=None)
使用可見(jiàn)pytorch torchvision.ImageFolder的用法介紹
這里想實(shí)現(xiàn)的是如果想要覆寫(xiě)該函數(shù),即能使用它的特性,又可以實(shí)現(xiàn)自己的功能
首先先分析下其源代碼:
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', 'webp']
class ImageFolder(DatasetFolder):
"""A generic data loader where the images are arranged in this way: ::
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
Args:
root (string): Root directory path.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
Attributes:
classes (list): List of the class names.
class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples
"""
def __init__(self, root, transform=None, target_transform=None,
loader=default_loader):
super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
transform=transform,
target_transform=target_transform)
self.imgs = self.samples
ImageFolder的代碼很簡(jiǎn)單,主要是繼承了DatasetFolder:
def has_file_allowed_extension(filename, extensions):
"""查看文件是否是支持的可擴(kuò)展類型
Args:
filename (string): 文件路徑
extensions (iterable of strings): 可擴(kuò)展類型列表,即能接受的圖像文件類型
Returns:
bool: True if the filename ends with one of given extensions
"""
filename_lower = filename.lower()
return any(filename_lower.endswith(ext) for ext in extensions) # 返回True或False列表
def make_dataset(dir, class_to_idx, extensions):
"""
返回形如[(圖像路徑, 該圖像對(duì)應(yīng)的類別索引值),(),...]
"""
images = []
dir = os.path.expanduser(dir)
for target in sorted(class_to_idx.keys()):
d = os.path.join(dir, target)
if not os.path.isdir(d):
continue
for root, _, fnames in sorted(os.walk(d)): #層層遍歷文件夾,返回當(dāng)前文件夾路徑,存在的所有文件夾名,存在的所有文件名
for fname in sorted(fnames):
if has_file_allowed_extension(fname, extensions):查看文件是否是支持的可擴(kuò)展類型,是則繼續(xù)
path = os.path.join(root, fname)
item = (path, class_to_idx[target])
images.append(item)
return images
class DatasetFolder(data.Dataset):
"""A generic data loader where the samples are arranged in this way: ::
root/class_x/xxx.ext
root/class_x/xxy.ext
root/class_x/xxz.ext
root/class_y/123.ext
root/class_y/nsdf3.ext
root/class_y/asd932_.ext
Args:
root (string): 根目錄路徑
loader (callable): 根據(jù)給定的路徑來(lái)加載樣本的可調(diào)用函數(shù)
extensions (list[string]): 可擴(kuò)展類型列表,即能接受的圖像文件類型.
transform (callable, optional): 用于樣本的transform函數(shù),然后返回樣本transform后的版本
E.g, ``transforms.RandomCrop`` for images.
target_transform (callable, optional): 用于樣本標(biāo)簽的transform函數(shù)
Attributes:
classes (list): 類別名列表
class_to_idx (dict): 項(xiàng)目(class_name, class_index)字典,如{'cat': 0, 'dog': 1}
samples (list): (sample path, class_index) 元組列表,即(樣本路徑, 類別索引)
targets (list): 在數(shù)據(jù)集中每張圖片的類索引值,為列表
"""
def __init__(self, root, loader, extensions, transform=None, target_transform=None):
classes, class_to_idx = self._find_classes(root) # 得到類名和類索引,如['cat', 'dog']和{'cat': 0, 'dog': 1}
# 返回形如[(圖像路徑, 該圖像對(duì)應(yīng)的類別索引值),(),...],即對(duì)每個(gè)圖像進(jìn)行標(biāo)記
samples = make_dataset(root, class_to_idx, extensions)
if len(samples) == 0:
raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
"Supported extensions are: " + ",".join(extensions)))
self.root = root
self.loader = loader
self.extensions = extensions
self.classes = classes
self.class_to_idx = class_to_idx
self.samples = samples
self.targets = [s[1] for s in samples] #所有圖像的類索引值組成的列表
self.transform = transform
self.target_transform = target_transform
def _find_classes(self, dir):
"""
在數(shù)據(jù)集中查找類文件夾。
Args:
dir (string): 根目錄路徑
Returns:
返回元組: (classes, class_to_idx)即(類名, 類索引),其中classes即相應(yīng)的目錄名,如['cat', 'dog'];class_to_idx為形如{類名:類索引}的字典,如{'cat': 0, 'dog': 1}.
Ensures:
保證沒(méi)有類名是另一個(gè)類目錄的子目錄
"""
if sys.version_info >= (3, 5):
# Faster and available in Python 3.5 and above
classes = [d.name for d in os.scandir(dir) if d.is_dir()] #獲得根目錄dir的所有第一層子目錄名
else:
classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] #效果和上面的一樣,只是版本不同方法不同
classes.sort() #然后對(duì)類名進(jìn)行排序
class_to_idx = {classes[i]: i for i in range(len(classes))} #然后將類名和索引值一一對(duì)應(yīng)的到相應(yīng)字典,如{'cat': 0, 'dog': 1}
return classes, class_to_idx #然后返回類名和類索引
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
path, target = self.samples[index]
sample = self.loader(path) # 加載圖片
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
def __len__(self):
return len(self.samples)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
此時(shí)想要覆寫(xiě)ImageFolder,代碼為:
class CustomImageFolder(ImageFolder): """ 為了得到兩張圖(其中一張是隨機(jī)選取的)的圖像和索引值信息 """ def __init__(self, root, transform=None): super(CustomImageFolder, self).__init__(root, transform) self.indices = range(len(self)) #該文件夾中的長(zhǎng)度 def __getitem__(self, index1): index2 = random.choice(self.indices) #從[0,indices]中隨機(jī)抽取一個(gè)數(shù)字,為了隨機(jī)選取一張圖 path1 = self.imgs[index1][0] #此時(shí)的self.imgs等于self.samples,即內(nèi)容為[(圖像路徑, 該圖像對(duì)應(yīng)的類別索引值),(),...] label1 = self.imgs[index1][1] path2 = self.imgs[index2][0] label2 = self.imgs[index2][1] img1 = self.loader(path1) img2 = self.loader(path2) if self.transform is not None: img1 = self.transform(img1) img2 = self.transform(img2) return img1, img2, label1, label2
以上這篇pytorch ImageFolder的覆寫(xiě)實(shí)例就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
詳解Django項(xiàng)目中模板標(biāo)簽及模板的繼承與引用(網(wǎng)站中快速布置廣告)
這篇文章主要介紹了詳解Django項(xiàng)目中模板標(biāo)簽及模板的繼承與引用【網(wǎng)站中快速布置廣告】,小編覺(jué)得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧2019-03-03
Python使用海龜繪圖實(shí)現(xiàn)貪吃蛇游戲
這篇文章主要為大家詳細(xì)介紹了Python使用海龜繪圖實(shí)現(xiàn)貪吃蛇游戲,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2021-06-06
Python使用matplotlib繪制圓形代碼實(shí)例
這篇文章主要介紹了Python使用matplotlib繪制圓形代碼實(shí)例,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-05-05
OpenCV-Python使用cv2實(shí)現(xiàn)傅里葉變換
在OpenCV中,我們通過(guò)cv2.dft()來(lái)實(shí)現(xiàn)傅里葉變換,使用cv2.idft()來(lái)實(shí)現(xiàn)逆傅里葉變換。本文就詳細(xì)的介紹一下這兩種用法,感興趣的可以了解一下2021-06-06
python中函數(shù)返回多個(gè)結(jié)果的實(shí)例方法
在本篇文章里小編給大家整理了一篇關(guān)于python中函數(shù)返回多個(gè)結(jié)果的實(shí)例方法,有興趣的朋友們可以學(xué)習(xí)下。2020-12-12
python簡(jiǎn)單爬蟲(chóng)--get方式詳解
本篇文章介紹了python爬蟲(chóng)中g(shù)et和post方法介紹以及cookie作用,對(duì)此有興趣的朋友學(xué)習(xí)下,希望能夠給你帶來(lái)幫助2021-09-09
python下實(shí)現(xiàn)二叉堆以及堆排序的示例
下面小編就為大家?guī)?lái)一篇python下實(shí)現(xiàn)二叉堆以及堆排序的示例。小編覺(jué)得挺不錯(cuò)的,現(xiàn)在就分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧2017-09-09

