Pytorch中torch.nn.Softmax的dim參數(shù)用法說(shuō)明
Pytorch中torch.nn.Softmax的dim參數(shù)使用含義
涉及到多維tensor時(shí),對(duì)softmax的參數(shù)dim總是很迷,下面用一個(gè)例子說(shuō)明
import torch.nn as nn m = nn.Softmax(dim=0) n = nn.Softmax(dim=1) k = nn.Softmax(dim=2) input = torch.randn(2, 2, 3) print(input) print(m(input)) print(n(input)) print(k(input))
輸出:
input
tensor([[[ 0.5450, -0.6264, 1.0446],
[ 0.6324, 1.9069, 0.7158]],[[ 1.0092, 0.2421, -0.8928],
[ 0.0344, 0.9723, 0.4328]]])
dim=0
tensor([[[0.3860, 0.2956, 0.8741],
[0.6452, 0.7180, 0.5703]],[[0.6140, 0.7044, 0.1259],
[0.3548, 0.2820, 0.4297]]])
dim=0時(shí),在第0維上sum=1,即:
[0][0][0]+[1][0][0]=0.3860+0.6140=1
[0][0][1]+[1][0][1]=0.2956+0.7044=1
… …
dim=1
tensor([[[0.4782, 0.0736, 0.5815],
[0.5218, 0.9264, 0.4185]],[[0.7261, 0.3251, 0.2099],
[0.2739, 0.6749, 0.7901]]])
dim=1時(shí),在第1維上sum=1,即:
[0][0][0]+[0][1][0]=0.4782+0.5218=1
[0][0][1]+[0][1][1]=0.0736+0.9264=1
… …
dim=2
tensor([[[0.3381, 0.1048, 0.5572],
[0.1766, 0.6315, 0.1919]],[[0.6197, 0.2878, 0.0925],
[0.1983, 0.5065, 0.2953]]])
dim=2時(shí),在第2維上sum=1,即:
[0][0][0]+[0][0][1]+[0][0][2]=0.3381+0.1048+0.5572=1.0001(四舍五入問(wèn)題)
[0][1][0]+[0][1][1]+[0][1][2]=0.1766+0.6315+0.1919=1
… …
用圖表示223的張量如下:

多分類問(wèn)題torch.nn.Softmax的使用
為什么談?wù)撨@個(gè)問(wèn)題呢?是因?yàn)槲以诠ぷ鞯倪^(guò)程中遇到了語(yǔ)義分割預(yù)測(cè)輸出特征圖個(gè)數(shù)為16,也就是所謂的16分類問(wèn)題。
因?yàn)槊總€(gè)通道的像素的值的大小代表了像素屬于該通道的類的大小,為了在一張圖上用不同的顏色顯示出來(lái),我不得不學(xué)習(xí)了torch.nn.Softmax的使用。
首先看一個(gè)簡(jiǎn)答的例子,倘若輸出為(3, 4, 4),也就是3張4x4的特征圖。
import torch img = torch.rand((3,4,4)) print(img)
輸出為:
tensor([[[0.0413, 0.8728, 0.8926, 0.0693],
[0.4072, 0.0302, 0.9248, 0.6676],
[0.4699, 0.9197, 0.3333, 0.4809],
[0.3877, 0.7673, 0.6132, 0.5203]],[[0.4940, 0.7996, 0.5513, 0.8016],
[0.1157, 0.8323, 0.9944, 0.2127],
[0.3055, 0.4343, 0.8123, 0.3184],
[0.8246, 0.6731, 0.3229, 0.1730]],[[0.0661, 0.1905, 0.4490, 0.7484],
[0.4013, 0.1468, 0.2145, 0.8838],
[0.0083, 0.5029, 0.0141, 0.8998],
[0.8673, 0.2308, 0.8808, 0.0532]]])
我們可以看到共三張?zhí)卣鲌D,每張?zhí)卣鲌D上對(duì)應(yīng)的值越大,說(shuō)明屬于該特征圖對(duì)應(yīng)類的概率越大。
import torch.nn as nn sogtmax = nn.Softmax(dim=0) img = sogtmax(img) print(img)
輸出為:
tensor([[[0.2780, 0.4107, 0.4251, 0.1979],
[0.3648, 0.2297, 0.3901, 0.3477],
[0.4035, 0.4396, 0.2993, 0.2967],
[0.2402, 0.4008, 0.3273, 0.4285]],[[0.4371, 0.3817, 0.3022, 0.4117],
[0.2726, 0.5122, 0.4182, 0.2206],
[0.3423, 0.2706, 0.4832, 0.2522],
[0.3718, 0.3648, 0.2449, 0.3028]],[[0.2849, 0.2076, 0.2728, 0.3904],
[0.3627, 0.2581, 0.1917, 0.4317],
[0.2543, 0.2898, 0.2175, 0.4511],
[0.3880, 0.2344, 0.4278, 0.2686]]])
可以看到,上面的代碼對(duì)每張?zhí)卣鲌D對(duì)應(yīng)位置的像素值進(jìn)行Softmax函數(shù)處理, 圖中標(biāo)紅位置加和=1,同理,標(biāo)藍(lán)位置加和=1。
我們看到Softmax函數(shù)會(huì)對(duì)原特征圖每個(gè)像素的值在對(duì)應(yīng)維度(這里dim=0,也就是第一維)上進(jìn)行計(jì)算,將其處理到0~1之間,并且大小固定不變。
print(torch.max(img,0))
輸出為:
torch.return_types.max(
values=tensor([[0.4371, 0.4107, 0.4251, 0.4117],
[0.3648, 0.5122, 0.4182, 0.4317],
[0.4035, 0.4396, 0.4832, 0.4511],
[0.3880, 0.4008, 0.4278, 0.4285]]),
indices=tensor([[1, 0, 0, 1],
[0, 1, 1, 2],
[0, 0, 1, 2],
[2, 0, 2, 0]]))
可以看到這里3x4x4變成了1x4x4,而且對(duì)應(yīng)位置上的值為像素對(duì)應(yīng)每個(gè)通道上的最大值,并且indices是對(duì)應(yīng)的分類。
清楚理解了上面的流程,那么我們就容易處理了。
看具體案例,這里輸出output的大小為:16x416x416.
output = torch.tensor(output)
sm = nn.Softmax(dim=0)
output = sm(output)
mask = torch.max(output,0).indices.numpy()
# 因?yàn)橐D(zhuǎn)化為RGB彩色圖,所以增加一維
rgb_img = np.zeros((output.shape[1], output.shape[2], 3))
for i in range(len(mask)):
for j in range(len(mask[0])):
if mask[i][j] == 0:
rgb_img[i][j][0] = 255
rgb_img[i][j][1] = 255
rgb_img[i][j][2] = 255
if mask[i][j] == 1:
rgb_img[i][j][0] = 255
rgb_img[i][j][1] = 180
rgb_img[i][j][2] = 0
if mask[i][j] == 2:
rgb_img[i][j][0] = 255
rgb_img[i][j][1] = 180
rgb_img[i][j][2] = 180
if mask[i][j] == 3:
rgb_img[i][j][0] = 255
rgb_img[i][j][1] = 180
rgb_img[i][j][2] = 255
if mask[i][j] == 4:
rgb_img[i][j][0] = 255
rgb_img[i][j][1] = 255
rgb_img[i][j][2] = 180
if mask[i][j] == 5:
rgb_img[i][j][0] = 255
rgb_img[i][j][1] = 255
rgb_img[i][j][2] = 0
if mask[i][j] == 6:
rgb_img[i][j][0] = 255
rgb_img[i][j][1] = 0
rgb_img[i][j][2] = 180
if mask[i][j] == 7:
rgb_img[i][j][0] = 255
rgb_img[i][j][1] = 0
rgb_img[i][j][2] = 255
if mask[i][j] == 8:
rgb_img[i][j][0] = 255
rgb_img[i][j][1] = 0
rgb_img[i][j][2] = 0
if mask[i][j] == 9:
rgb_img[i][j][0] = 180
rgb_img[i][j][1] = 0
rgb_img[i][j][2] = 0
if mask[i][j] == 10:
rgb_img[i][j][0] = 180
rgb_img[i][j][1] = 255
rgb_img[i][j][2] = 255
if mask[i][j] == 11:
rgb_img[i][j][0] = 180
rgb_img[i][j][1] = 0
rgb_img[i][j][2] = 180
if mask[i][j] == 12:
rgb_img[i][j][0] = 180
rgb_img[i][j][1] = 0
rgb_img[i][j][2] = 255
if mask[i][j] == 13:
rgb_img[i][j][0] = 180
rgb_img[i][j][1] = 255
rgb_img[i][j][2] = 180
if mask[i][j] == 14:
rgb_img[i][j][0] = 0
rgb_img[i][j][1] = 180
rgb_img[i][j][2] = 255
if mask[i][j] == 15:
rgb_img[i][j][0] = 0
rgb_img[i][j][1] = 0
rgb_img[i][j][2] = 0
cv2.imwrite('output.jpg', rgb_img)
最后保存得到的圖為:

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
在Flask使用TensorFlow的幾個(gè)常見(jiàn)錯(cuò)誤及解決
這篇文章主要介紹了在Flask使用TensorFlow的幾個(gè)常見(jiàn)錯(cuò)誤及解決,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2024-01-01
Python2和Python3中@abstractmethod使用方法
這篇文章主要介紹了Python2和Python3中@abstractmethod使用方法,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-02-02
Python中的TfidfVectorizer參數(shù)使用解析
這篇文章主要介紹了Python中的TfidfVectorizer參數(shù)使用解析,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-11-11
nginx搭建基于python的web環(huán)境的實(shí)現(xiàn)步驟
這篇文章主要介紹了nginx搭建基于python的web環(huán)境的實(shí)現(xiàn)步驟,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-01-01
Python實(shí)現(xiàn)隨機(jī)生成迷宮并自動(dòng)尋路
最近在學(xué)習(xí)Python,正好今天在學(xué)習(xí)隨機(jī)數(shù),本文實(shí)現(xiàn)了Python實(shí)現(xiàn)隨機(jī)生成迷宮并自動(dòng)尋路,感興趣的可以了解一下2021-06-06
教你如何在Pycharm中導(dǎo)入requests模塊
這篇文章主要介紹了教你如何在Pycharm中導(dǎo)入requests模塊,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2021-09-09
Qt通過(guò)QGraphicsview實(shí)現(xiàn)簡(jiǎn)單縮放及還原效果
本文主要介紹通過(guò)QGraphicsview實(shí)現(xiàn)簡(jiǎn)單的縮放以及縮放后還原原始大小,通過(guò)scale可以對(duì)view進(jìn)行放大或縮小,具體內(nèi)容詳情跟隨小編一起看看吧2021-09-09
Python中定時(shí)器用法詳解之Timer定時(shí)器和schedule庫(kù)
目前所在的項(xiàng)目組需要經(jīng)常執(zhí)行一些定時(shí)任務(wù),于是選擇使用 Python 的定時(shí)器,下面這篇文章主要給大家介紹了關(guān)于Python中定時(shí)器用法詳解之Timer定時(shí)器和schedule庫(kù)的相關(guān)資料,需要的朋友可以參考下2024-02-02

