Randomly Extracting Multiple Patches from an Image in Python
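Many image tasks require cropping fixed-size patches from a large image for training. This usually raises the following issues:
- Patch positions should be as random as possible; otherwise the data may lack diversity, which easily leads to overfitting.
- If the source image is large, reading it can incur heavy I/O overhead and slow down training, so it is best to crop multiple patches from each image read.
- We usually do not want randomness to leave some regions of the image uncovered, so the coverage of patch positions also matters.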
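Based on these issues, we can use the following strategy to obtain multiple randomly positioned patches from an image:
- Compute the top-left corners of all patches on a grid with a fixed stride
- Randomly jitter the top-left coordinates
- Add the patch width and height to each top-left corner to get the bottom-right corner
- Check whether each patch exceeds the image boundary; if it does, move it back inside, keeping the patch size unchanged
- Support an ROI (Region Of Interest): patches do not have to come from the whole image, and a specific ROI can be given instead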
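The boundary step above (move a patch back inside the image without changing its size) reduces to clamping each top-left coordinate to the range [0, W - w], assuming the patch is no larger than the image. The implementation later in this article does this in two passes (pull the bottom-right corner in, then the top-left corner); the minimal NumPy sketch below, with hypothetical helper names `clamp_top_left` and `shrink_two_pass`, shows that both forms give the same result. Apply it to x and y independently.

```python
import numpy as np


def clamp_top_left(tl, size, limit):
    """Clamp top-left coordinates so that [tl, tl + size) stays inside [0, limit).

    Assumes size <= limit. Only the position changes, never the patch size.
    """
    return np.clip(tl, 0, limit - size)


def shrink_two_pass(tl, size, limit):
    """Same result written as two passes (bottom-right first, then top-left)."""
    br = np.minimum(tl + size, limit)  # pull the bottom-right corner inside
    tl = np.maximum(br - size, 0)      # recompute top-left, then pull it inside
    return tl


# 1-D example: image width 100, patch width 30
tl_x = np.array([-10, 0, 40, 85])
print(clamp_top_left(tl_x, size=30, limit=100))  # [ 0  0 40 70]
```

The single `np.clip` call makes the size-preserving guarantee explicit: the right edge `tl + size` can never exceed `limit` because `tl` never exceeds `limit - size`.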
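Below is the implementation code and an example.
Note that the code only computes the bounding boxes of the patches; it does not actually crop the patches out of the image.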
```python
# -*- coding: utf-8 -*-
import cv2
import numpy as np


def get_random_patch_bboxes(image, bbox_size, stride, jitter, roi_bbox=None):
    """
    Generate random patch bounding boxes for an image around the ROI region

    Parameters
    ----------
    image: image data read by OpenCV, shape is [H, W, C]
    bbox_size: size of patch bbox, one digit or a list/tuple containing two
        digits, defined by (width, height)
    stride: stride between adjacent bboxes (before jitter), one digit or a
        list/tuple containing two digits, defined by (x, y)
    jitter: jitter size for evenly distributed bboxes, one digit or a
        list/tuple containing two digits, defined by (x, y)
    roi_bbox: roi region, defined by [xmin, ymin, xmax, ymax], default is
        the whole image region

    Returns
    -------
    patch_bboxes: randomly distributed patch bounding boxes, n x 4 numpy
        array. Each bounding box is defined by [xmin, ymin, xmax, ymax]
    """
    height, width = image.shape[:2]
    bbox_size = _process_geometry_param(bbox_size, min_value=1)
    stride = _process_geometry_param(stride, min_value=1)
    jitter = _process_geometry_param(jitter, min_value=0)

    if bbox_size[0] > width or bbox_size[1] > height:
        raise ValueError('bbox_size must be <= image size')

    if roi_bbox is None:
        roi_bbox = [0, 0, width, height]

    # tl is for top-left, br is for bottom-right
    tl_x, tl_y = _get_top_left_points(roi_bbox, bbox_size, stride, jitter)
    br_x = tl_x + bbox_size[0]
    br_y = tl_y + bbox_size[1]

    # pull bottom-right points back inside the image border
    br_x[br_x > width] = width
    br_y[br_y > height] = height
    # recompute top-left points (size unchanged), then pull them inside too
    tl_x = br_x - bbox_size[0]
    tl_y = br_y - bbox_size[1]
    tl_x[tl_x < 0] = 0
    tl_y[tl_y < 0] = 0
    # compute bottom-right points again
    br_x = tl_x + bbox_size[0]
    br_y = tl_y + bbox_size[1]

    patch_bboxes = np.concatenate((tl_x, tl_y, br_x, br_y), axis=1)
    return patch_bboxes


def _process_geometry_param(param, min_value):
    """
    Process and check param, which must be one digit or a list/tuple
    containing two digits, and its value must be >= min_value

    Parameters
    ----------
    param: parameter to be processed
    min_value: min value for param

    Returns
    -------
    param: param after processing, a list of two ints
    """
    if isinstance(param, (int, float, np.integer, np.floating)) or \
            (isinstance(param, np.ndarray) and param.size == 1):
        param = int(np.round(np.asarray(param).item()))
        param = [param, param]
    else:
        if len(param) != 2:
            raise ValueError('param must be one digit or two digits')
        param = [int(np.round(param[0])), int(np.round(param[1]))]

    # check data range using min_value
    if not (param[0] >= min_value and param[1] >= min_value):
        raise ValueError('param must be >= min_value (%d)' % min_value)
    return param


def _get_top_left_points(roi_bbox, bbox_size, stride, jitter):
    """
    Generate top-left points for bounding boxes

    Parameters
    ----------
    roi_bbox: roi region, defined by [xmin, ymin, xmax, ymax]
    bbox_size: size of patch bbox, a list/tuple containing two digits,
        defined by (width, height)
    stride: stride between adjacent bboxes (before jitter), a list/tuple
        containing two digits, defined by (x, y)
    jitter: jitter size for evenly distributed bboxes, a list/tuple
        containing two digits, defined by (x, y)

    Returns
    -------
    tl_x: x coordinates of top-left points, n x 1 numpy array
    tl_y: y coordinates of top-left points, n x 1 numpy array
    """
    xmin, ymin, xmax, ymax = roi_bbox
    roi_width = xmax - xmin
    roi_height = ymax - ymin

    # get the offset between the first top-left point of patch box and the
    # top-left point of roi_bbox, so that the grid of boxes is centered on
    # the ROI
    offset_x = np.arange(0, roi_width, stride[0])[-1] + bbox_size[0]
    offset_y = np.arange(0, roi_height, stride[1])[-1] + bbox_size[1]
    offset_x = (offset_x - roi_width) // 2
    offset_y = (offset_y - roi_height) // 2

    # get the coordinates of all top-left points
    tl_x = np.arange(xmin, xmax, stride[0]) - offset_x
    tl_y = np.arange(ymin, ymax, stride[1]) - offset_y
    tl_x, tl_y = np.meshgrid(tl_x, tl_y)
    tl_x = np.reshape(tl_x, [-1, 1])
    tl_y = np.reshape(tl_y, [-1, 1])

    # jitter the coordinates of all top-left points
    tl_x += np.random.randint(-jitter[0], jitter[0] + 1, size=tl_x.shape)
    tl_y += np.random.randint(-jitter[1], jitter[1] + 1, size=tl_y.shape)
    return tl_x, tl_y


if __name__ == '__main__':
    image = cv2.imread('1.bmp')
    if image is None:
        raise FileNotFoundError('cannot read 1.bmp')
    patch_bboxes = get_random_patch_bboxes(
        image,
        bbox_size=[64, 96],
        stride=[128, 128],
        jitter=[32, 32],
        roi_bbox=[500, 200, 1500, 800])

    colors = [
        (255, 0, 0),
        (0, 255, 0),
        (0, 0, 255),
        (255, 255, 0),
        (255, 0, 255),
        (0, 255, 255)]
    for i, bbox in enumerate(patch_bboxes):
        # OpenCV expects plain Python ints for point coordinates
        pt1 = (int(bbox[0]), int(bbox[1]))
        pt2 = (int(bbox[2]), int(bbox[3]))
        cv2.rectangle(image, pt1, pt2, color=colors[i % len(colors)],
                      thickness=2)

    cv2.namedWindow('image', cv2.WINDOW_NORMAL)
    cv2.imshow('image', image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    cv2.imwrite('image.png', image)
```
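Since the function above only returns boxes, the actual cropping is left to the caller. A minimal sketch, assuming a hypothetical helper `crop_patches` and the [xmin, ymin, xmax, ymax] box layout used above; the key detail is that NumPy (and therefore OpenCV) images are indexed as [row, column], i.e. y before x. A synthetic image stands in for a file on disk.

```python
import numpy as np


def crop_patches(image, patch_bboxes):
    """Crop patches given [xmin, ymin, xmax, ymax] boxes (hypothetical helper).

    Images are indexed [y, x], so the y range comes first in the slice.
    Slicing returns a view; .copy() detaches each patch from the source image.
    """
    patches = []
    for xmin, ymin, xmax, ymax in np.asarray(patch_bboxes, dtype=int):
        patches.append(image[ymin:ymax, xmin:xmax].copy())
    return patches


# Usage with a synthetic 200 x 300 color image
image = np.zeros((200, 300, 3), dtype=np.uint8)
bboxes = np.array([[0, 0, 64, 96], [236, 104, 300, 200]])
patches = crop_patches(image, bboxes)
print([p.shape for p in patches])  # [(96, 64, 3), (96, 64, 3)]
```

Because every box has the same size, `np.stack(patches)` turns the list into an N x H x W x C batch. Crop before drawing the rectangles if you reuse the demo above, since `cv2.rectangle` draws onto the image in place.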
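In practice, a few simple features can be added on top of this:
1. Filter by position. For example, discard patches that are too close to the image border: some algorithms suffer from strong border effects, in which case we may not want border data in the training set.
2. Filter with simple algorithmic criteria. In tasks such as super-resolution, for instance, we usually care little about large flat regions such as plain-colored walls or wide stretches of sky; such patches can be filtered out by their variance.
3. Set a maximum number of patches to keep. Source images can differ greatly in size, so the number of patches produced by the method above also differs greatly. To keep the training data balanced, we can cap the number of patches kept per image. To preserve coverage, the patches should generally be shuffled before truncating, or the stride can be computed so that roughly the desired number of patches is produced.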
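The three extensions listed above are simple post-processing steps on the n x 4 box array. A minimal sketch, assuming a hypothetical helper `filter_patch_bboxes` with illustrative thresholds (`edge_margin`, `min_variance`, `max_count` are not from the original code). Variance is computed over all pixels and channels of each patch; converting to grayscale first is an equally reasonable choice.

```python
import numpy as np


def filter_patch_bboxes(image, patch_bboxes, edge_margin=0, min_variance=None,
                        max_count=None, rng=None):
    """Post-filter [xmin, ymin, xmax, ymax] boxes (hypothetical helper).

    1. Drop boxes closer than edge_margin pixels to the image border.
    2. Drop flat boxes whose pixel variance is below min_variance.
    3. Shuffle, then keep at most max_count boxes.
    """
    height, width = image.shape[:2]
    bboxes = np.asarray(patch_bboxes, dtype=int)

    # 1. position filter: keep boxes at least edge_margin away from the border
    keep = ((bboxes[:, 0] >= edge_margin) & (bboxes[:, 1] >= edge_margin) &
            (bboxes[:, 2] <= width - edge_margin) &
            (bboxes[:, 3] <= height - edge_margin))
    bboxes = bboxes[keep]

    # 2. variance filter: flat regions (walls, sky) have low variance
    if min_variance is not None:
        variances = np.array([image[y0:y1, x0:x1].var()
                              for x0, y0, x1, y1 in bboxes])
        bboxes = bboxes[variances >= min_variance]

    # 3. cap the count; shuffle first so the kept boxes stay spread out
    if max_count is not None and len(bboxes) > max_count:
        rng = np.random.default_rng() if rng is None else rng
        order = rng.permutation(len(bboxes))
        bboxes = bboxes[order[:max_count]]
    return bboxes


# Usage: textured left half, flat right half
rng = np.random.default_rng(0)
image = np.zeros((200, 300), dtype=np.uint8)
image[:, :150] = (rng.random((200, 150)) * 255).astype(np.uint8)
bboxes = np.array([[0, 0, 64, 64], [10, 10, 74, 74], [200, 100, 264, 164]])
kept = filter_patch_bboxes(image, bboxes, edge_margin=5, min_variance=1.0)
print(kept.tolist())  # [[10, 10, 74, 74]]
```

The first box is dropped for touching the border and the third for lying in the flat region. Shuffling before truncating matters because the grid is generated row by row: cutting an unshuffled list would keep only the top rows of the image.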