pytorch中的transforms.ToTensor和transforms.Normalize的實(shí)現(xiàn)
transforms.ToTensor
最近看pytorch時(shí),遇到了對(duì)圖像數(shù)據(jù)的歸一化,如下圖所示:
該怎么理解這串代碼呢?我們一句一句的來(lái)看,先看transforms.ToTensor()
,我們可以先轉(zhuǎn)到官方給的定義,如下圖所示:
大概的意思就是說(shuō),transforms.ToTensor()
可以將PIL和numpy格式的數(shù)據(jù)從[0,255]范圍轉(zhuǎn)換到[0,1] ,具體做法其實(shí)就是將原始數(shù)據(jù)除以255。另外原始數(shù)據(jù)的shape是(H x W x C),通過(guò)transforms.ToTensor()
后shape會(huì)變?yōu)椋– x H x W)。這樣說(shuō)我覺(jué)得大家應(yīng)該也是能理解的,這部分并不難,但想著還是用一些例子來(lái)加深大家的映像??????
先導(dǎo)入一些包
import cv2 import numpy as np import torch from torchvision import transforms
定義一個(gè)數(shù)組模型圖片,注意數(shù)組數(shù)據(jù)類(lèi)型需要時(shí)np.uint8【官方圖示中給出】
data = np.array([ [[1,1,1],[1,1,1],[1,1,1],[1,1,1],[1,1,1]], [[2,2,2],[2,2,2],[2,2,2],[2,2,2],[2,2,2]], [[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3]], [[4,4,4],[4,4,4],[4,4,4],[4,4,4],[4,4,4]], [[5,5,5],[5,5,5],[5,5,5],[5,5,5],[5,5,5]] ],dtype='uint8')
這是可以看看data的shape,注意現(xiàn)在為(W H C)。
使用transforms.ToTensor()
將data進(jìn)行轉(zhuǎn)換
data = transforms.ToTensor()(data)
這時(shí)候我們來(lái)看看data中的數(shù)據(jù)及shape。
? 很明顯,數(shù)據(jù)現(xiàn)在都映射到了[0, 1]之間,并且data的shape發(fā)生了變換。
**注意:不知道大家是如何理解三維數(shù)組的,這里提供我的一個(gè)方法。**??????
??原始的data的shape為(5,5,3),則其表示有5個(gè)(5 , 3)的二維數(shù)組,即我們把最外層的[]去掉就得到了5個(gè)五行三列的數(shù)據(jù)。
??同樣的,變換后data的shape為(3,5,5),則其表示有3個(gè)(5 , 5)的二維數(shù)組,即我們把最外層的[]去掉就得到了3個(gè)五行五列的數(shù)據(jù)。
transforms.Normalize??
相信通過(guò)前面的敘述大家應(yīng)該對(duì)transforms.ToTensor
有了一定的了解,下面將來(lái)說(shuō)說(shuō)這個(gè)transforms.Normalize
??????同樣的,我們先給出官方的定義,如下圖所示:
可以看到這個(gè)函數(shù)的輸出output[channel] = (input[channel] - mean[channel]) / std[channel]
。這里[channel]的意思是指對(duì)特征圖的每個(gè)通道都進(jìn)行這樣的操作。【mean為均值,std為標(biāo)準(zhǔn)差】接下來(lái)我們看第一張圖片中的代碼,即
這里的第一個(gè)參數(shù)(0.5,0.5,0.5)表示每個(gè)通道的均值都是0.5,第二個(gè)參數(shù)(0.5,0.5,0.5)表示每個(gè)通道的方差都為0.5。【因?yàn)閳D像一般是三個(gè)通道,所以這里的向量都是1x3的??????】有了這兩個(gè)參數(shù)后,當(dāng)我們傳入一個(gè)圖像時(shí),就會(huì)按照上面的公式對(duì)圖像進(jìn)行變換。【注意:這里說(shuō)圖像其實(shí)也不夠準(zhǔn)確,因?yàn)檫@個(gè)函數(shù)傳入的格式不能為PIL Image,我們應(yīng)該先將其轉(zhuǎn)換為T(mén)ensor格式】
說(shuō)了這么多,那么這個(gè)函數(shù)到底有什么用呢?我們通過(guò)前面的ToTensor已經(jīng)將數(shù)據(jù)歸一化到了0-1之間,現(xiàn)在又接上了一個(gè)Normalize函數(shù)有什么用呢?其實(shí)Normalize函數(shù)做的是將數(shù)據(jù)變換到了[-1,1]之間。之前的數(shù)據(jù)為0-1,當(dāng)取0時(shí),output =(0 - 0.5)/ 0.5 = -1
;當(dāng)取1時(shí),output =(1 - 0.5)/ 0.5 = 1
。這樣就把數(shù)據(jù)統(tǒng)一到了[-1,1]之間了??????那么問(wèn)題又來(lái)了,數(shù)據(jù)統(tǒng)一到[-1,1]有什么好處呢?數(shù)據(jù)如果分布在(0,1)之間,可能實(shí)際的bias,就是神經(jīng)網(wǎng)絡(luò)的輸入b會(huì)比較大,而模型初始化時(shí)b=0的,這樣會(huì)導(dǎo)致神經(jīng)網(wǎng)絡(luò)收斂比較慢,經(jīng)過(guò)Normalize后,可以加快模型的收斂速度?!具@句話是再網(wǎng)絡(luò)上找到最多的解釋?zhuān)约阂膊淮_定其正確性】
讀到這里大家是不是以為就完了呢?這里還想和大家嘮上一嘮??????上面的兩個(gè)參數(shù)(0.5,0.5,0.5)是怎么得來(lái)的呢?這是根據(jù)數(shù)據(jù)集中的數(shù)據(jù)計(jì)算出的均值和標(biāo)準(zhǔn)差,所以往往不同的數(shù)據(jù)集這兩個(gè)值是不同的??????這里再舉一個(gè)例子幫助大家理解其計(jì)算過(guò)程。同樣采用上文例子中提到的數(shù)據(jù)。
上文已經(jīng)得到了經(jīng)ToTensor轉(zhuǎn)換后的數(shù)據(jù),現(xiàn)需要求出該數(shù)據(jù)每個(gè)通道的mean和std?!具@一部分建議大家自己運(yùn)行看看每一步的結(jié)果??????】
# 需要對(duì)數(shù)據(jù)進(jìn)行擴(kuò)維,增加batch維度 data = torch.unsqueeze(data,0) #在pytorch中一般都是(batch,C,H,W) nb_samples = 0. #創(chuàng)建3維的空列表 channel_mean = torch.zeros(3) channel_std = torch.zeros(3) N, C, H, W = data.shape[:4] data = data.view(N, C, -1) #將數(shù)據(jù)的H,W合并 #展平后,w,h屬于第2維度,對(duì)他們求平均,sum(0)為將同一緯度的數(shù)據(jù)累加 channel_mean += data.mean(2).sum(0) #展平后,w,h屬于第2維度,對(duì)他們求標(biāo)準(zhǔn)差,sum(0)為將同一緯度的數(shù)據(jù)累加 channel_std += data.std(2).sum(0) #獲取所有batch的數(shù)據(jù),這里為1 nb_samples += N #獲取同一batch的均值和標(biāo)準(zhǔn)差 channel_mean /= nb_samples channel_std /= nb_samples print(channel_mean, channel_std) #結(jié)果為tensor([0.0118, 0.0118, 0.0118]) tensor([0.0057, 0.0057, 0.0057])
將上述得到的mean和std帶入公式,計(jì)算輸出。
for i in range(3): data[i] = (data[i] - channel_mean[i]) / channel_std[i] print(data)
輸出結(jié)果:
? 從結(jié)果可以看出,我們計(jì)算的mean和std并不是0.5,且最后的結(jié)果也沒(méi)有在[-1,1]之間。
最后我們?cè)賮?lái)看一個(gè)有意思的例子,我們得到了最終的結(jié)果,要是我們想要變回去怎么辦,其實(shí)很簡(jiǎn)單啦,就是一個(gè)逆運(yùn)算,即input = std*output + mean
,然后再乘上255就可以得到原始的結(jié)果了。很多人獲取吐槽了,這也叫有趣!????哈哈哈這里我想說(shuō)的是另外的一個(gè)事,如果我們對(duì)一張圖像進(jìn)行了歸一化,這時(shí)候你用歸一化后的數(shù)據(jù)顯示這張圖像的時(shí)候,會(huì)發(fā)現(xiàn)同樣會(huì)是原圖。
參考鏈接1:https://zhuanlan.zhihu.com/p/414242338
參考鏈接2:https://blog.csdn.net/peacefairy/article/details/108020179
到此這篇關(guān)于pytorch中的transforms.ToTensor和transforms.Normalize的實(shí)現(xiàn)的文章就介紹到這了,更多相關(guān)pytorch transforms.ToTensor和transforms.Normalize內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
go貨幣計(jì)算時(shí)如何避免浮點(diǎn)數(shù)精度問(wèn)題
在開(kāi)發(fā)的初始階段,我們經(jīng)常會(huì)遇到“浮點(diǎn)數(shù)精度”和“貨幣值表示”的問(wèn)題,那么在golang中如何避免這一方面的問(wèn)題呢,下面就跟隨小編一起來(lái)學(xué)習(xí)一下吧2024-02-02小學(xué)生也能看懂的Golang異常處理recover panic
在其他語(yǔ)言里,宕機(jī)往往以異常的形式存在,底層拋出異常,上層邏輯通過(guò) try/catch 機(jī)制捕獲異常,沒(méi)有被捕獲的嚴(yán)重異常會(huì)導(dǎo)致宕機(jī),go語(yǔ)言追求簡(jiǎn)潔,優(yōu)雅,Go語(yǔ)言不支持傳統(tǒng)的 try…catch…finally 這種異常2021-09-09