欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

GCN?圖神經(jīng)網(wǎng)絡(luò)使用詳解?可視化?Pytorch

 更新時(shí)間:2022年12月17日 10:12:57   作者:LZZ?and?MYY  
這篇文章主要介紹了GCN?圖神經(jīng)網(wǎng)絡(luò)使用詳解?可視化?Pytorch,具有很好的參考價(jià)值,希望對大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教

手動(dòng)嘗試GCN圖神經(jīng)網(wǎng)絡(luò)

最近,圖上的深度學(xué)習(xí)已經(jīng)成為深度學(xué)習(xí)社區(qū)中最熱門的研究領(lǐng)域之一。 在這里,圖神經(jīng)網(wǎng)絡(luò)(GNN)旨在將經(jīng)典的深度學(xué)習(xí)概念推廣到不規(guī)則的結(jié)構(gòu)化數(shù)據(jù)(與圖像或文本形成對比),并使神經(jīng)網(wǎng)絡(luò)能夠推理出對象及其關(guān)系。

本內(nèi)容介紹一些關(guān)于通過基于PyTorch幾何(PyG)庫的圖神經(jīng)網(wǎng)絡(luò)對圖進(jìn)行深度學(xué)習(xí)的基本概念。

PyTorch geometry是流行的深度學(xué)習(xí)框架PyTorch的擴(kuò)展庫,由各種方法和實(shí)用程序組成,以簡化圖神經(jīng)網(wǎng)絡(luò)的實(shí)現(xiàn)。

在開始之前,先介紹一下配置環(huán)境:

Pytorch: 1.8.0       Cuda: 10.2    Torch-geometric

# 導(dǎo)入使用的模塊包
import torch
import networkx as nx
import matplotlib.pyplot as plt
 
# 定義最后可視化的函數(shù)
def visualize(h, color, epoch=None, loss=None):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
 
    if torch.is_tensor(h):
        h = h.detach().cpu().numpy()
        plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap="Set2")
        if epoch is not None and loss is not None:
            plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize=16)
    else:
        nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False,
                         node_color=color, cmap="Set2")
    plt.show()

在這里,我們使用一張KarateClub圖來進(jìn)行講解,這張圖描述了一個(gè)由34名空手道俱樂部成員組成的社交網(wǎng)絡(luò),并記錄了俱樂部外成員之間的聯(lián)系。在這里,我們感興趣的是檢測由成員的交互產(chǎn)生的社區(qū)。

KarateClub圖

from torch_geometric.datasets import KarateClub
 
dataset = KarateClub()
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}') # 1
print(f'Number of features: {dataset.num_features}') # 34
print(f'Number of classes: {dataset.num_classes}') # 4

這里輸出的分別是:

  • (1)圖的數(shù)量、
  • (2)特征的數(shù)量
  • (3)種類

在初始化KarateClub數(shù)據(jù)集之后,我們首先可以檢查它的一些屬性。

例如,我們可以看到這個(gè)數(shù)據(jù)集只持有一個(gè)圖,并且這個(gè)數(shù)據(jù)集中的每個(gè)節(jié)點(diǎn)被分配一個(gè)34維的特征向量(唯一地描述空手道俱樂部的成員)。

此外,圖中正好包含4個(gè)類,它們代表每個(gè)節(jié)點(diǎn)所屬的團(tuán)體。

現(xiàn)在讓我們更詳細(xì)地看一下底層圖

data = dataset[0]  # Get the first graph object.
 
print(data)
print('==============================================================')
 
# Gather some statistics about the graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Number of training nodes: {data.train_mask.sum()}')
print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}')
print(f'Contains isolated nodes: {data.contains_isolated_nodes()}')
print(f'Contains self-loops: {data.contains_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')
Data(edge_index=[2, 156], train_mask=[34], x=[34, 34], y=[34])
==============================================================
Number of nodes: 34
Number of edges: 156
Average node degree: 4.59
Number of training nodes: 4
Training node label rate: 0.12
Contains isolated nodes: False
Contains self-loops: False
Is undirected: True

PyTorch Geometric 中的每個(gè)圖形都由單個(gè) Data 對象表示,該對象包含描述其圖形表示的所有信息。

我們可以隨時(shí)通過 print(data) 打印數(shù)據(jù)對象,以接收有關(guān)其屬性及其形狀的簡短摘要:

Data(edge_index=[2,?156],?x=[34,?34],?y=[34],?train_mask=[34])

我們可以看到該數(shù)據(jù)對象具有4個(gè)屬性:

(1)edge_index:屬性保存有關(guān)圖連接性的信息,即每個(gè)邊緣的源節(jié)點(diǎn)索引和目標(biāo)節(jié)點(diǎn)索引的元組。 PyG進(jìn)一步將

(2)節(jié)點(diǎn)特征稱為x(為34個(gè)節(jié)點(diǎn)中的每個(gè)節(jié)點(diǎn)分配了一個(gè)34維特征向量),并且將

(3)節(jié)點(diǎn)標(biāo)簽稱為y(每個(gè)節(jié)點(diǎn)被精確地分配為一個(gè)類別)。

(4)還有一個(gè)名為train_mask的附加屬性,它描述了我們已經(jīng)知道其社區(qū)歸屬的節(jié)點(diǎn)。 總共,我們只知道4個(gè)節(jié)點(diǎn)的基本標(biāo)簽(每個(gè)社區(qū)一個(gè)),任務(wù)是推斷其余節(jié)點(diǎn)的社區(qū)分配。數(shù)據(jù)對象還提供一些實(shí)用程序功能來推斷基礎(chǔ)圖的某些基本屬性。 例如,我們可以輕松推斷圖中是否存在孤立的節(jié)點(diǎn)(即,任何節(jié)點(diǎn)都沒有邊),圖是否包含自環(huán)(即(v,v)∈E)或圖是否為 無向的(即,對于每個(gè)邊(v,w)∈E也存在邊(w,v)∈E)。

現(xiàn)在讓我們更詳細(xì)地檢查edge_index的屬性

from IPython.display import Javascript  # Restrict height of output cell.
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))
 
edge_index = data.edge_index
print(edge_index.t())
tensor([[ 0,  1],
        [ 0,  2],
        [ 0,  3],
        [ 0,  4],
        [ 0,  5],
        [ 0,  6],
        [ 0,  7],
        [ 0,  8],
         ........

這個(gè)edge_index描述了34個(gè)人的相關(guān)性。通過輸出edge_index,我們可以進(jìn)一步了解PyG內(nèi)部是如何表示圖連通性的。

我們可以看到,對于每條邊,edge_index 包含兩個(gè)節(jié)點(diǎn)索引的元組,其中第一個(gè)值描述源節(jié)點(diǎn)的節(jié)點(diǎn)索引,第二個(gè)值描述邊的目標(biāo)節(jié)點(diǎn)的節(jié)點(diǎn)索引。

這種表示被稱為COO格式(坐標(biāo)格式),通常用于表示稀疏矩陣。

PyG使用稀疏矩陣代替以密集表示形式的鄰接矩陣A∈{0,1} | V |×| V | ,這是指僅保留A中的條目不為零的坐標(biāo)/值。

我們可以通過將圖轉(zhuǎn)換為networkx庫格式來進(jìn)一步可視化,這種格式除了圖形操作功能之外,還實(shí)現(xiàn)了用于可視化的強(qiáng)大工具

from torch_geometric.utils import to_networkx
 
G = to_networkx(data, to_undirected=True)
visualize(G, color=data.y)

數(shù)據(jù)庫可視化

灰色、黃色、綠色、藍(lán)色代表四類不同的俱樂部,其中每一個(gè)圓圈代表一個(gè)人,一共有34個(gè)人,每個(gè)人之間的關(guān)系就如edge_index所描述的那樣。

現(xiàn)在,我們要通過在torch.nn.Module類繼承中定義我們的網(wǎng)絡(luò)架構(gòu)來創(chuàng)建我們的第一個(gè)圖神經(jīng)網(wǎng)絡(luò)

import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv
 
 
class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(dataset.num_features, 4)
        self.conv2 = GCNConv(4, 4)
        self.conv3 = GCNConv(4, 2)
        self.classifier = Linear(2, dataset.num_classes)
 
    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index)
        h = h.tanh()
        h = self.conv2(h, edge_index)
        h = h.tanh()
        h = self.conv3(h, edge_index)
        h = h.tanh()  # Final GNN embedding space.
        
        # Apply a final (linear) classifier.
        out = self.classifier(h)
 
        return out, h
 
model = GCN()
print(model)
GCN(
  (conv1): GCNConv(34, 4)
  (conv2): GCNConv(4, 4)
  (conv3): GCNConv(4, 2)
  (classifier): Linear(in_features=2, out_features=4, bias=True)
)

在這里,我們首先在 __init__ 中初始化我們所有的構(gòu)建塊,并定義我們forward網(wǎng)絡(luò)的計(jì)算流程。 我們首先定義并堆疊三個(gè)圖卷積層,這對應(yīng)于聚合每個(gè)節(jié)點(diǎn)周圍的 3 個(gè)鄰域信息(所有節(jié)點(diǎn)最多 3個(gè))。 此外,GCNConv 層將節(jié)點(diǎn)特征維數(shù)減少到 2 ,即 34→4→4→2 。 每個(gè) GCNConv 層都通過 tanh 非線性增強(qiáng)。(可以換成RELU試一試)

之后,我們應(yīng)用單個(gè)線性變換 (torch.nn.Linear) 作為分類器將我們的節(jié)點(diǎn)映射到 4 個(gè)類/社區(qū)中的 1 個(gè)。

我們返回最終分類器的輸出以及GNN生成的最終節(jié)點(diǎn)嵌入。 我們繼續(xù)通過 GCN() 初始化我們的最終模型,打印我們的模型會(huì)生成所有使用的子模塊的摘要。

嵌入 Karate Club Network

讓我們看看GNN產(chǎn)生的節(jié)點(diǎn)嵌入。這里,我們將初始節(jié)點(diǎn)特征x和圖連通性信息edge_index傳遞給模型,并可視化其二維嵌入。

model = GCN()
 
_, h = model(data.x, data.edge_index)
print(f'Embedding shape: {list(h.shape)}')
 
visualize(h, color=data.y)

值得注意的是,即使在訓(xùn)練我們的模型的權(quán)重之前,該模型也會(huì)產(chǎn)生一個(gè)與圖中的社區(qū)結(jié)構(gòu)非常相似的節(jié)點(diǎn)嵌入。

相同顏色(社區(qū))的節(jié)點(diǎn)在嵌入空間中已經(jīng)緊密地聚在一起,盡管我們的模型的權(quán)值是完全隨機(jī)初始化的,而且到目前為止我們還沒有進(jìn)行任何訓(xùn)練!由此得出結(jié)論,gnn引入了很強(qiáng)的歸納偏置,導(dǎo)致輸入圖中彼此接近的節(jié)點(diǎn)產(chǎn)生類似的嵌入。

訓(xùn)練 Karate Club Network

但我們能做得更好嗎? 讓我們看一個(gè)示例,說明如何根據(jù)圖中 4 個(gè)節(jié)點(diǎn)的社區(qū)分配知識(shí)(每個(gè)社區(qū)一個(gè))來訓(xùn)練我們的網(wǎng)絡(luò)參數(shù):

由于我們模型中的所有內(nèi)容都是可微分和參數(shù)化的,我們可以添加一些標(biāo)簽、訓(xùn)練模型并觀察嵌入的反應(yīng)。 在這里,我們使用半監(jiān)督或轉(zhuǎn)導(dǎo)學(xué)習(xí)程序:我們只是針對每個(gè)類的一個(gè)節(jié)點(diǎn)進(jìn)行訓(xùn)練,但允許使用完整的輸入圖數(shù)據(jù)。

這個(gè)模型訓(xùn)練與任何其他PyTorch模型非常相似。除了定義我們的網(wǎng)絡(luò)架構(gòu)之外,我們還定義了一個(gè)損失標(biāo)準(zhǔn)(這里是CrossEntropyLoss),并初始化了一個(gè)隨機(jī)梯度優(yōu)化器(這里是Adam)。之后,我們執(zhí)行多輪優(yōu)化,每輪由前向和后向傳遞來計(jì)算我們的模型參數(shù)w.r.t.對前向傳遞的損失的梯度。

import time
from IPython.display import Javascript  # Restrict height of output cell.
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 430})'''))
 
model = GCN()
criterion = torch.nn.CrossEntropyLoss()  # Define loss criterion.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Define optimizer.
 
def train(data):
    optimizer.zero_grad()  # Clear gradients.
    out, h = model(data.x, data.edge_index)  # Perform a single forward pass.
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss solely based on the training nodes.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    return loss, h
 
for epoch in range(401):
    loss, h = train(data)
    if epoch % 10 == 0:
        visualize(h, color=data.y, epoch=epoch, loss=loss)
        time.sleep(0.3)

可以看到,訓(xùn)練400輪后,它的聚類是比較明顯的。正如可以看到的,我們的3層GCN模型管理線性分隔社區(qū)和正確分類大多數(shù)節(jié)點(diǎn)。

此外,我們只用了幾行代碼就完成了這一切,這要感謝PyTorch geometry庫,它幫助我們完成了數(shù)據(jù)處理和GNN實(shí)現(xiàn)。

總結(jié)

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

最新評論