在PyTorch中實(shí)現(xiàn)可解釋的神經(jīng)網(wǎng)絡(luò)模型的方法詳解
目的
深度學(xué)習(xí)系統(tǒng)缺乏可解釋性對建立人類信任構(gòu)成了重大挑戰(zhàn)。這些模型的復(fù)雜性使人類幾乎不可能理解其決策背后的根本原因。
深度學(xué)習(xí)系統(tǒng)缺乏可解釋性阻礙了人類的信任。
為了解決這個(gè)問題,研究人員一直在積極研究新的解決方案,從而產(chǎn)生了重大創(chuàng)新,例如基于概念的模型。這些模型不僅提高了模型的透明度,而且通過在訓(xùn)練過程中結(jié)合高級人類可解釋的概念(如“顏色”或“形狀”),培養(yǎng)了對系統(tǒng)決策的新信任感。因此,這些模型可以根據(jù)學(xué)習(xí)到的概念為其預(yù)測提供簡單直觀的解釋,從而使人們能夠檢查其決策背后的原因。這還不是全部!它們甚至允許人類與學(xué)習(xí)到的概念進(jìn)行交互,讓我們能夠控制最終的決定。
基于概念的模型允許人類檢查深度學(xué)習(xí)預(yù)測背后的推理,并讓我們重新控制最終決策。
在本文中,我們將深入研究這些技術(shù),并為您提供使用簡單的 PyTorch 接口實(shí)現(xiàn)最先進(jìn)的基于概念的模型的工具。通過實(shí)踐經(jīng)驗(yàn),您將學(xué)習(xí)如何利用這些強(qiáng)大的模型來增強(qiáng)可解釋性并最終校準(zhǔn)人類對您的深度學(xué)習(xí)系統(tǒng)的信任。
概念瓶頸模型
在這個(gè)介紹中,我們將深入探討概念瓶頸模型。這模型在 2020 年國際機(jī)器學(xué)習(xí)會議上發(fā)表的一篇論文中介紹,旨在首先學(xué)習(xí)和預(yù)測一組概念,例如“顏色”或“形狀”,然后利用這些概念來解決下游分類任務(wù):
通過遵循這種方法,我們可以將預(yù)測追溯到提供解釋的概念,例如“輸入對象是一個(gè){apple},因?yàn)樗莧spherical}和{red}。”
概念瓶頸模型首先學(xué)習(xí)一組概念,例如“顏色”或“形狀”,然后利用這些概念來解決下游分類任務(wù)。
實(shí)現(xiàn)
為了說明概念瓶頸模型,我們將重新審視著名的 XOR 問題,但有所不同。我們的輸入將包含兩個(gè)連續(xù)的特征。為了捕捉這些特征的本質(zhì),我們將使用概念編碼器將它們映射為兩個(gè)有意義的概念,表示為“A”和“B”。我們?nèi)蝿?wù)的目標(biāo)是預(yù)測“A”和“B”的異或 (XOR)。通過這個(gè)例子,您將更好地理解概念瓶頸如何在實(shí)踐中應(yīng)用,并見證它們在解決具體問題方面的有效性。
我們可以從導(dǎo)入必要的庫并加載這個(gè)簡單的數(shù)據(jù)集開始:
import torch import torch_explain as te from torch_explain import datasets from sklearn.metrics import accuracy_score from sklearn.model_selection import train_test_split x, c, y = datasets.xor(500) x_train, x_test, c_train, c_test, y_train, y_test = train_test_split(x, c, y, test_size=0.33, random_state=42)
接下來,我們實(shí)例化一個(gè)概念編碼器以將輸入特征映射到概念空間,并實(shí)例化一個(gè)任務(wù)預(yù)測器以將概念映射到任務(wù)預(yù)測:
concept_encoder = torch.nn.Sequential( torch.nn.Linear(x.shape[1], 10), torch.nn.LeakyReLU(), torch.nn.Linear(10, 8), torch.nn.LeakyReLU(), torch.nn.Linear(8, c.shape[1]), torch.nn.Sigmoid(), ) task_predictor = torch.nn.Sequential( torch.nn.Linear(c.shape[1], 8), torch.nn.LeakyReLU(), torch.nn.Linear(8, 1), ) model = torch.nn.Sequential(concept_encoder, task_predictor)
然后我們通過優(yōu)化概念和任務(wù)的交叉熵?fù)p失來訓(xùn)練網(wǎng)絡(luò):
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01) loss_form_c = torch.nn.BCELoss() loss_form_y = torch.nn.BCEWithLogitsLoss() model.train() for epoch in range(2001): optimizer.zero_grad() # generate concept and task predictions c_pred = concept_encoder(x_train) y_pred = task_predictor(c_pred) # update loss concept_loss = loss_form_c(c_pred, c_train) task_loss = loss_form_y(y_pred, y_train) loss = concept_loss + 0.2*task_loss loss.backward() optimizer.step()
訓(xùn)練模型后,我們評估其在測試集上的性能:
c_pred = concept_encoder(x_test) y_pred = task_predictor(c_pred) concept_accuracy = accuracy_score(c_test, c_pred > 0.5) task_accuracy = accuracy_score(y_test, y_pred > 0)
現(xiàn)在,在幾個(gè) epoch 之后,我們可以觀察到概念和任務(wù)在測試集上的準(zhǔn)確性都非常好(~98% 的準(zhǔn)確性)!
由于這種架構(gòu),我們可以通過根據(jù)輸入概念查看任務(wù)預(yù)測器的響應(yīng)來為模型預(yù)測提供解釋,如下所示:
c_different = torch.FloatTensor([0, 1]) print(f"f({c_different}) = {int(task_predictor(c_different).item() > 0)}") c_equal = torch.FloatTensor([1, 1]) print(f"f({c_different}) = {int(task_predictor(c_different).item() > 0)}")
這會產(chǎn)生例如 f([0,1])=1 和 f([1,1])=0 ,如預(yù)期的那樣。這使我們能夠更多地了解模型的行為,并檢查它對于任何相關(guān)概念集的行為是否符合預(yù)期,例如,對于互斥的輸入概念 [0,1] 或 [1,0],它返回的預(yù)測y=1。
概念瓶頸模型通過將預(yù)測追溯到概念來提供直觀的解釋。
淹沒在準(zhǔn)確性與可解釋性的權(quán)衡中
概念瓶頸模型的主要優(yōu)勢之一是它們能夠通過揭示概念預(yù)測模式來為預(yù)測提供解釋,從而使人們能夠評估模型的推理是否符合他們的期望。
然而,標(biāo)準(zhǔn)概念瓶頸模型的主要問題是它們難以解決復(fù)雜問題!更一般地說,他們遇到了可解釋人工智能中眾所周知的一個(gè)眾所周知的問題,稱為準(zhǔn)確性-可解釋性權(quán)衡。實(shí)際上,我們希望模型不僅能實(shí)現(xiàn)高任務(wù)性能,還能提供高質(zhì)量的解釋。不幸的是,在許多情況下,當(dāng)我們追求更高的準(zhǔn)確性時(shí),模型提供的解釋往往會在質(zhì)量和忠實(shí)度上下降,反之亦然。
在視覺上,這種權(quán)衡可以表示如下:
可解釋模型擅長提供高質(zhì)量的解釋,但難以解決具有挑戰(zhàn)性的任務(wù),而黑盒模型以提供脆弱和糟糕的解釋為代價(jià)來實(shí)現(xiàn)高任務(wù)準(zhǔn)確性。
為了在具體設(shè)置中說明這種權(quán)衡,讓我們考慮一個(gè)概念瓶頸模型,該模型應(yīng)用于要求稍高的基準(zhǔn),即“三角學(xué)”數(shù)據(jù)集:
x, c, y = datasets.trigonometry(500) x_train, x_test, c_train, c_test, y_train, y_test = train_test_split(x, c, y, test_size=0.33, random_state=42)
在該數(shù)據(jù)集上訓(xùn)練相同的網(wǎng)絡(luò)架構(gòu)后,我們觀察到任務(wù)準(zhǔn)確性顯著降低,僅達(dá)到 80% 左右。
概念瓶頸模型未能在任務(wù)準(zhǔn)確性和解釋質(zhì)量之間取得平衡。
這就引出了一個(gè)問題:我們是永遠(yuǎn)被迫在準(zhǔn)確性和解釋質(zhì)量之間做出選擇,還是有辦法取得更好的平衡?
以上就是在PyTorch中實(shí)現(xiàn)可解釋的神經(jīng)網(wǎng)絡(luò)模型的方法詳解的詳細(xì)內(nèi)容,更多關(guān)于PyTorch 神經(jīng)網(wǎng)絡(luò)模型的資料請關(guān)注腳本之家其它相關(guān)文章!
- Pytorch神經(jīng)網(wǎng)絡(luò)參數(shù)管理方法詳細(xì)講解
- Pytorch之8層神經(jīng)網(wǎng)絡(luò)實(shí)現(xiàn)Cifar-10圖像分類驗(yàn)證集準(zhǔn)確率94.71%
- GCN?圖神經(jīng)網(wǎng)絡(luò)使用詳解?可視化?Pytorch
- pytorch簡單實(shí)現(xiàn)神經(jīng)網(wǎng)絡(luò)功能
- Pytorch卷積神經(jīng)網(wǎng)絡(luò)遷移學(xué)習(xí)的目標(biāo)及好處
- Pytorch深度學(xué)習(xí)經(jīng)典卷積神經(jīng)網(wǎng)絡(luò)resnet模塊訓(xùn)練
- Pytorch卷積神經(jīng)網(wǎng)絡(luò)resent網(wǎng)絡(luò)實(shí)踐
相關(guān)文章
python中的Json模塊dumps、dump、loads、load函數(shù)用法詳解
這篇文章主要介紹了python中的Json模塊dumps、dump、loads、load函數(shù)用法講解,本文逐一介紹結(jié)合實(shí)例代碼給大家講解的非常詳細(xì),需要的朋友可以參考下2022-11-11解決vscode python print 輸出窗口中文亂碼的問題
今天小編就為大家分享一篇解決vscode python print 輸出窗口中文亂碼的問題,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-12-12Python中Playwright模塊進(jìn)行自動(dòng)化測試的實(shí)現(xiàn)
playwright是由微軟開發(fā)的Web UI自動(dòng)化測試工具,本文主要介紹了Python中Playwright模塊進(jìn)行自動(dòng)化測試的實(shí)現(xiàn),具有一定的參考價(jià)值,感興趣的可以了解一下2023-12-12在Python下使用Txt2Html實(shí)現(xiàn)網(wǎng)頁過濾代理的教程
這篇文章主要介紹了在Python下使用Txt2Html實(shí)現(xiàn)網(wǎng)頁過濾代理的教程,來自IBM官方開發(fā)者技術(shù)文檔,需要的朋友可以參考下2015-04-04Python Sqlalchemy如何實(shí)現(xiàn)select for update
這篇文章主要介紹了Python Sqlalchemy如何實(shí)現(xiàn)select for update,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-10-10深入理解python中if?__name__?==?‘__main__‘
很多python的文件中會有語句if?__name=='__main__':,一直不太明白,最近查閱了一下資料,現(xiàn)在明白,本文就來深入理解一下,感興趣的可以了解一下2023-08-08