解決pytorch中的kl divergence計(jì)算問(wèn)題
偶然從pytorch討論論壇中看到的一個(gè)問(wèn)題,KL divergence different results from tf,kl divergence 在TensorFlow中和pytorch中計(jì)算結(jié)果不同,平時(shí)沒(méi)有注意到,記錄下
一篇關(guān)于KL散度、JS散度以及交叉熵對(duì)比的文章
kl divergence 介紹
KL散度( Kullback–Leibler divergence),又稱相對(duì)熵,是描述兩個(gè)概率分布 P 和 Q 差異的一種方法。計(jì)算公式:
可以發(fā)現(xiàn),P 和 Q 中元素的個(gè)數(shù)不用相等,只需要兩個(gè)分布中的離散元素一致。
舉個(gè)簡(jiǎn)單例子:
兩個(gè)離散分布分布分別為 P 和 Q
P 的分布為:{1,1,2,2,3}
Q 的分布為:{1,1,1,1,1,2,3,3,3,3}
我們發(fā)現(xiàn),雖然兩個(gè)分布中元素個(gè)數(shù)不相同,P 的元素個(gè)數(shù)為 5,Q 的元素個(gè)數(shù)為 10。但里面的元素都有 “1”,“2”,“3” 這三個(gè)元素。
當(dāng) x = 1時(shí),在 P 分布中,“1” 這個(gè)元素的個(gè)數(shù)為 2,故 P(x = 1) = 2/5 = 0.4,在 Q 分布中,“1” 這個(gè)元素的個(gè)數(shù)為 5,故 Q(x = 1) = 5/10 = 0.5
同理,
當(dāng) x = 2 時(shí),P(x = 2) = 2/5 = 0.4 ,Q(x = 2) = 1/10 = 0.1
當(dāng) x = 3 時(shí),P(x = 3) = 1/5 = 0.2 ,Q(x = 3) = 4/10 = 0.4
把上述概率帶入公式:
至此,就計(jì)算完成了兩個(gè)離散變量分布的KL散度。
pytorch 中的 kl_div 函數(shù)
pytorch中有用于計(jì)算kl散度的函數(shù) kl_div
torch.nn.functional.kl_div(input, target, size_average=None, reduce=None, reduction='mean')
計(jì)算 D (p||q)
1、不用這個(gè)函數(shù)的計(jì)算結(jié)果為:
與手算結(jié)果相同
2、使用函數(shù):
(這是計(jì)算正確的,結(jié)果有差異是因?yàn)閜ytorch這個(gè)函數(shù)中默認(rèn)的是以e為底)
注意:
1、函數(shù)中的 p q 位置相反(也就是想要計(jì)算D(p||q),要寫成kl_div(q.log(),p)的形式),而且q要先取 log
2、reduction 是選擇對(duì)各部分結(jié)果做什么操作,默認(rèn)為取平均數(shù),這里選擇求和
好別扭的用法,不知道為啥官方把它設(shè)計(jì)成這樣
補(bǔ)充:pytorch 的KL divergence的實(shí)現(xiàn)
看代碼吧~
import torch.nn.functional as F # p_logit: [batch, class_num] # q_logit: [batch, class_num] def kl_categorical(p_logit, q_logit): p = F.softmax(p_logit, dim=-1) _kl = torch.sum(p * (F.log_softmax(p_logit, dim=-1) - F.log_softmax(q_logit, dim=-1)), 1) return torch.mean(_kl)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
python中的多線程鎖lock=threading.Lock()使用方式
這篇文章主要介紹了python中的多線程鎖lock=threading.Lock()使用方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2022-06-06Python?庫(kù)?PySimpleGUI?制作自動(dòng)化辦公小軟件的方法
Python?在運(yùn)維和辦公自動(dòng)化中扮演著重要的角色,PySimpleGUI?是一款很棒的自動(dòng)化輔助模塊,讓你更輕松的實(shí)現(xiàn)日常任務(wù)的自動(dòng)化,下面通過(guò)本文給大家介紹下Python?庫(kù)?PySimpleGUI?制作自動(dòng)化辦公小軟件的過(guò)程,一起看看吧2021-12-12用python修改excel表某一列內(nèi)容的操作方法
這篇文章主要介紹了用python修改excel表某一列內(nèi)容的操作代碼,在實(shí)現(xiàn)過(guò)程中用到openpyxl這個(gè)庫(kù),要生成隨機(jī)數(shù)就要有random這個(gè)庫(kù),具體代碼跟隨小編一起看看吧2021-06-06python如何實(shí)現(xiàn)Dice系數(shù)
這篇文章主要介紹了python如何實(shí)現(xiàn)Dice系數(shù),具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-10-10python進(jìn)行數(shù)據(jù)預(yù)處理的4個(gè)重要步驟
在數(shù)據(jù)科學(xué)項(xiàng)目中,數(shù)據(jù)預(yù)處理是最重要的事情之一,本文詳細(xì)給大家介紹python進(jìn)行數(shù)據(jù)預(yù)處理的4個(gè)重要步驟:拆分訓(xùn)練集和測(cè)試集,處理缺失值,處理分類特征和進(jìn)行標(biāo)準(zhǔn)化處理,需要的朋友可以參考下2023-06-06react+django清除瀏覽器緩存的幾種方法小結(jié)
今天小編就為大家分享一篇react+django清除瀏覽器緩存的幾種方法小結(jié),具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-07-07python中自定義異常/raise關(guān)鍵字拋出異常的案例解析
在編程過(guò)程中合理的使用異??梢允沟贸绦蛘5膱?zhí)行,本篇文章給大家介紹python中自定義異常/raise關(guān)鍵字拋出異常案例解析,需要的朋友可以參考下2024-01-01