GCN?圖神經(jīng)網(wǎng)絡(luò)使用詳解?可視化?Pytorch
手動(dòng)嘗試GCN圖神經(jīng)網(wǎng)絡(luò)
最近,圖上的深度學(xué)習(xí)已經(jīng)成為深度學(xué)習(xí)社區(qū)中最熱門的研究領(lǐng)域之一。 在這里,圖神經(jīng)網(wǎng)絡(luò)(GNN)旨在將經(jīng)典的深度學(xué)習(xí)概念推廣到不規(guī)則的結(jié)構(gòu)化數(shù)據(jù)(與圖像或文本形成對比),并使神經(jīng)網(wǎng)絡(luò)能夠推理出對象及其關(guān)系。
本內(nèi)容介紹一些關(guān)于通過基于PyTorch幾何(PyG)庫的圖神經(jīng)網(wǎng)絡(luò)對圖進(jìn)行深度學(xué)習(xí)的基本概念。
PyTorch geometry是流行的深度學(xué)習(xí)框架PyTorch的擴(kuò)展庫,由各種方法和實(shí)用程序組成,以簡化圖神經(jīng)網(wǎng)絡(luò)的實(shí)現(xiàn)。
在開始之前,先介紹一下配置環(huán)境:
Pytorch: 1.8.0 Cuda: 10.2 Torch-geometric
# 導(dǎo)入使用的模塊包 import torch import networkx as nx import matplotlib.pyplot as plt # 定義最后可視化的函數(shù) def visualize(h, color, epoch=None, loss=None): plt.figure(figsize=(7,7)) plt.xticks([]) plt.yticks([]) if torch.is_tensor(h): h = h.detach().cpu().numpy() plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap="Set2") if epoch is not None and loss is not None: plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize=16) else: nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False, node_color=color, cmap="Set2") plt.show()
在這里,我們使用一張KarateClub圖來進(jìn)行講解,這張圖描述了一個(gè)由34名空手道俱樂部成員組成的社交網(wǎng)絡(luò),并記錄了俱樂部外成員之間的聯(lián)系。在這里,我們感興趣的是檢測由成員的交互產(chǎn)生的社區(qū)。
KarateClub圖
from torch_geometric.datasets import KarateClub dataset = KarateClub() print(f'Dataset: {dataset}:') print('======================') print(f'Number of graphs: {len(dataset)}') # 1 print(f'Number of features: {dataset.num_features}') # 34 print(f'Number of classes: {dataset.num_classes}') # 4
這里輸出的分別是:
- (1)圖的數(shù)量、
- (2)特征的數(shù)量
- (3)種類
在初始化KarateClub數(shù)據(jù)集之后,我們首先可以檢查它的一些屬性。
例如,我們可以看到這個(gè)數(shù)據(jù)集只持有一個(gè)圖,并且這個(gè)數(shù)據(jù)集中的每個(gè)節(jié)點(diǎn)被分配一個(gè)34維的特征向量(唯一地描述空手道俱樂部的成員)。
此外,圖中正好包含4個(gè)類,它們代表每個(gè)節(jié)點(diǎn)所屬的團(tuán)體。
現(xiàn)在讓我們更詳細(xì)地看一下底層圖
data = dataset[0] # Get the first graph object. print(data) print('==============================================================') # Gather some statistics about the graph. print(f'Number of nodes: {data.num_nodes}') print(f'Number of edges: {data.num_edges}') print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}') print(f'Number of training nodes: {data.train_mask.sum()}') print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}') print(f'Contains isolated nodes: {data.contains_isolated_nodes()}') print(f'Contains self-loops: {data.contains_self_loops()}') print(f'Is undirected: {data.is_undirected()}')
Data(edge_index=[2, 156], train_mask=[34], x=[34, 34], y=[34]) ============================================================== Number of nodes: 34 Number of edges: 156 Average node degree: 4.59 Number of training nodes: 4 Training node label rate: 0.12 Contains isolated nodes: False Contains self-loops: False Is undirected: True
PyTorch Geometric 中的每個(gè)圖形都由單個(gè) Data 對象表示,該對象包含描述其圖形表示的所有信息。
我們可以隨時(shí)通過 print(data) 打印數(shù)據(jù)對象,以接收有關(guān)其屬性及其形狀的簡短摘要:
Data(edge_index=[2,?156],?x=[34,?34],?y=[34],?train_mask=[34])
我們可以看到該數(shù)據(jù)對象具有4個(gè)屬性:
(1)edge_index:屬性保存有關(guān)圖連接性的信息,即每個(gè)邊緣的源節(jié)點(diǎn)索引和目標(biāo)節(jié)點(diǎn)索引的元組。 PyG進(jìn)一步將
(2)節(jié)點(diǎn)特征稱為x(為34個(gè)節(jié)點(diǎn)中的每個(gè)節(jié)點(diǎn)分配了一個(gè)34維特征向量),并且將
(3)節(jié)點(diǎn)標(biāo)簽稱為y(每個(gè)節(jié)點(diǎn)被精確地分配為一個(gè)類別)。
(4)還有一個(gè)名為train_mask的附加屬性,它描述了我們已經(jīng)知道其社區(qū)歸屬的節(jié)點(diǎn)。 總共,我們只知道4個(gè)節(jié)點(diǎn)的基本標(biāo)簽(每個(gè)社區(qū)一個(gè)),任務(wù)是推斷其余節(jié)點(diǎn)的社區(qū)分配。數(shù)據(jù)對象還提供一些實(shí)用程序功能來推斷基礎(chǔ)圖的某些基本屬性。 例如,我們可以輕松推斷圖中是否存在孤立的節(jié)點(diǎn)(即,任何節(jié)點(diǎn)都沒有邊),圖是否包含自環(huán)(即(v,v)∈E)或圖是否為 無向的(即,對于每個(gè)邊(v,w)∈E也存在邊(w,v)∈E)。
現(xiàn)在讓我們更詳細(xì)地檢查edge_index的屬性
from IPython.display import Javascript # Restrict height of output cell. display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})''')) edge_index = data.edge_index print(edge_index.t())
tensor([[ 0, 1], [ 0, 2], [ 0, 3], [ 0, 4], [ 0, 5], [ 0, 6], [ 0, 7], [ 0, 8], ........
這個(gè)edge_index描述了34個(gè)人的相關(guān)性。通過輸出edge_index,我們可以進(jìn)一步了解PyG內(nèi)部是如何表示圖連通性的。
我們可以看到,對于每條邊,edge_index 包含兩個(gè)節(jié)點(diǎn)索引的元組,其中第一個(gè)值描述源節(jié)點(diǎn)的節(jié)點(diǎn)索引,第二個(gè)值描述邊的目標(biāo)節(jié)點(diǎn)的節(jié)點(diǎn)索引。
這種表示被稱為COO格式(坐標(biāo)格式),通常用于表示稀疏矩陣。
PyG使用稀疏矩陣代替以密集表示形式的鄰接矩陣A∈{0,1} | V |×| V | ,這是指僅保留A中的條目不為零的坐標(biāo)/值。
我們可以通過將圖轉(zhuǎn)換為networkx庫格式來進(jìn)一步可視化,這種格式除了圖形操作功能之外,還實(shí)現(xiàn)了用于可視化的強(qiáng)大工具
from torch_geometric.utils import to_networkx G = to_networkx(data, to_undirected=True) visualize(G, color=data.y)
數(shù)據(jù)庫可視化
灰色、黃色、綠色、藍(lán)色代表四類不同的俱樂部,其中每一個(gè)圓圈代表一個(gè)人,一共有34個(gè)人,每個(gè)人之間的關(guān)系就如edge_index所描述的那樣。
現(xiàn)在,我們要通過在torch.nn.Module類繼承中定義我們的網(wǎng)絡(luò)架構(gòu)來創(chuàng)建我們的第一個(gè)圖神經(jīng)網(wǎng)絡(luò)
import torch from torch.nn import Linear from torch_geometric.nn import GCNConv class GCN(torch.nn.Module): def __init__(self): super(GCN, self).__init__() torch.manual_seed(12345) self.conv1 = GCNConv(dataset.num_features, 4) self.conv2 = GCNConv(4, 4) self.conv3 = GCNConv(4, 2) self.classifier = Linear(2, dataset.num_classes) def forward(self, x, edge_index): h = self.conv1(x, edge_index) h = h.tanh() h = self.conv2(h, edge_index) h = h.tanh() h = self.conv3(h, edge_index) h = h.tanh() # Final GNN embedding space. # Apply a final (linear) classifier. out = self.classifier(h) return out, h model = GCN() print(model)
GCN( (conv1): GCNConv(34, 4) (conv2): GCNConv(4, 4) (conv3): GCNConv(4, 2) (classifier): Linear(in_features=2, out_features=4, bias=True) )
在這里,我們首先在 __init__ 中初始化我們所有的構(gòu)建塊,并定義我們forward網(wǎng)絡(luò)的計(jì)算流程。 我們首先定義并堆疊三個(gè)圖卷積層,這對應(yīng)于聚合每個(gè)節(jié)點(diǎn)周圍的 3 個(gè)鄰域信息(所有節(jié)點(diǎn)最多 3個(gè))。 此外,GCNConv 層將節(jié)點(diǎn)特征維數(shù)減少到 2 ,即 34→4→4→2 。 每個(gè) GCNConv 層都通過 tanh 非線性增強(qiáng)。(可以換成RELU試一試)
之后,我們應(yīng)用單個(gè)線性變換 (torch.nn.Linear) 作為分類器將我們的節(jié)點(diǎn)映射到 4 個(gè)類/社區(qū)中的 1 個(gè)。
我們返回最終分類器的輸出以及GNN生成的最終節(jié)點(diǎn)嵌入。 我們繼續(xù)通過 GCN() 初始化我們的最終模型,打印我們的模型會(huì)生成所有使用的子模塊的摘要。
嵌入 Karate Club Network
讓我們看看GNN產(chǎn)生的節(jié)點(diǎn)嵌入。這里,我們將初始節(jié)點(diǎn)特征x和圖連通性信息edge_index傳遞給模型,并可視化其二維嵌入。
model = GCN() _, h = model(data.x, data.edge_index) print(f'Embedding shape: {list(h.shape)}') visualize(h, color=data.y)
值得注意的是,即使在訓(xùn)練我們的模型的權(quán)重之前,該模型也會(huì)產(chǎn)生一個(gè)與圖中的社區(qū)結(jié)構(gòu)非常相似的節(jié)點(diǎn)嵌入。
相同顏色(社區(qū))的節(jié)點(diǎn)在嵌入空間中已經(jīng)緊密地聚在一起,盡管我們的模型的權(quán)值是完全隨機(jī)初始化的,而且到目前為止我們還沒有進(jìn)行任何訓(xùn)練!由此得出結(jié)論,gnn引入了很強(qiáng)的歸納偏置,導(dǎo)致輸入圖中彼此接近的節(jié)點(diǎn)產(chǎn)生類似的嵌入。
訓(xùn)練 Karate Club Network
但我們能做得更好嗎? 讓我們看一個(gè)示例,說明如何根據(jù)圖中 4 個(gè)節(jié)點(diǎn)的社區(qū)分配知識(shí)(每個(gè)社區(qū)一個(gè))來訓(xùn)練我們的網(wǎng)絡(luò)參數(shù):
由于我們模型中的所有內(nèi)容都是可微分和參數(shù)化的,我們可以添加一些標(biāo)簽、訓(xùn)練模型并觀察嵌入的反應(yīng)。 在這里,我們使用半監(jiān)督或轉(zhuǎn)導(dǎo)學(xué)習(xí)程序:我們只是針對每個(gè)類的一個(gè)節(jié)點(diǎn)進(jìn)行訓(xùn)練,但允許使用完整的輸入圖數(shù)據(jù)。
這個(gè)模型訓(xùn)練與任何其他PyTorch模型非常相似。除了定義我們的網(wǎng)絡(luò)架構(gòu)之外,我們還定義了一個(gè)損失標(biāo)準(zhǔn)(這里是CrossEntropyLoss),并初始化了一個(gè)隨機(jī)梯度優(yōu)化器(這里是Adam)。之后,我們執(zhí)行多輪優(yōu)化,每輪由前向和后向傳遞來計(jì)算我們的模型參數(shù)w.r.t.對前向傳遞的損失的梯度。
import time from IPython.display import Javascript # Restrict height of output cell. display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 430})''')) model = GCN() criterion = torch.nn.CrossEntropyLoss() # Define loss criterion. optimizer = torch.optim.Adam(model.parameters(), lr=0.01) # Define optimizer. def train(data): optimizer.zero_grad() # Clear gradients. out, h = model(data.x, data.edge_index) # Perform a single forward pass. loss = criterion(out[data.train_mask], data.y[data.train_mask]) # Compute the loss solely based on the training nodes. loss.backward() # Derive gradients. optimizer.step() # Update parameters based on gradients. return loss, h for epoch in range(401): loss, h = train(data) if epoch % 10 == 0: visualize(h, color=data.y, epoch=epoch, loss=loss) time.sleep(0.3)
可以看到,訓(xùn)練400輪后,它的聚類是比較明顯的。正如可以看到的,我們的3層GCN模型管理線性分隔社區(qū)和正確分類大多數(shù)節(jié)點(diǎn)。
此外,我們只用了幾行代碼就完成了這一切,這要感謝PyTorch geometry庫,它幫助我們完成了數(shù)據(jù)處理和GNN實(shí)現(xiàn)。
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
python實(shí)現(xiàn)b站直播自動(dòng)發(fā)送彈幕功能
這篇文章主要介紹了python如何實(shí)現(xiàn)b站直播自動(dòng)發(fā)送彈幕,幫助大家更好的理解和學(xué)習(xí)使用python,感興趣的朋友可以了解下2021-02-02python抓取網(wǎng)頁內(nèi)容并進(jìn)行語音播報(bào)的方法
今天小編就為大家分享一篇python抓取網(wǎng)頁內(nèi)容并進(jìn)行語音播報(bào)的方法,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-12-12Python中實(shí)現(xiàn)變量賦值傳遞時(shí)的引用和拷貝方法
下面小編就為大家分享一篇Python中實(shí)現(xiàn)變量賦值傳遞時(shí)的引用和拷貝方法,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-04-04Python實(shí)現(xiàn)按目錄層級輸出文件名并保存為excel
當(dāng)我們發(fā)現(xiàn)電腦的內(nèi)存很滿,或平時(shí)工作中文件夾管理不清晰,導(dǎo)致里面的文件數(shù)據(jù)很雜亂,查找很不方便,一個(gè)一個(gè)文件夾去看去找然后刪除又很浪費(fèi)時(shí)間。本文將介紹如何利用Python實(shí)現(xiàn)按目錄層級輸出文件名并保存為excel,需要的可以參考一下2022-02-02Django更新models數(shù)據(jù)庫結(jié)構(gòu)步驟
這篇文章主要介紹了Django更新models數(shù)據(jù)庫結(jié)構(gòu)的操作步驟,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-04-04Python?pandas數(shù)據(jù)合并merge函數(shù)用法詳解
這篇文章主要給大家介紹了關(guān)于Python?pandas數(shù)據(jù)合并merge函數(shù)用法的相關(guān)資料,數(shù)據(jù)分析中經(jīng)常會(huì)遇到數(shù)據(jù)合并的基本問題,文中通過示例代碼介紹的非常詳細(xì),需要的朋友可以參考下2023-07-07pyqt6實(shí)現(xiàn)關(guān)閉窗口前彈出確認(rèn)框的示例代碼
本文主要介紹了pyqt6實(shí)現(xiàn)關(guān)閉窗口前彈出確認(rèn)框的示例代碼,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2024-02-02