詳解使用Pytorch Geometric實(shí)現(xiàn)GraphSAGE模型
正文
GraphSAGE是一種用于圖神經(jīng)網(wǎng)絡(luò)中的節(jié)點(diǎn)嵌入學(xué)習(xí)方法。它通過聚合節(jié)點(diǎn)鄰居的信息來(lái)生成節(jié)點(diǎn)的低維表示,使節(jié)點(diǎn)表示能夠更好地應(yīng)用于各種下游任務(wù),如節(jié)點(diǎn)分類、鏈路預(yù)測(cè)等。
圖構(gòu)建
在使用GraphSAGE對(duì)節(jié)點(diǎn)進(jìn)行嵌入學(xué)習(xí)之前,我們需要先將原始數(shù)據(jù)轉(zhuǎn)換為圖結(jié)構(gòu),并將其存儲(chǔ)為Pytorch Tensor格式。例如,我們可以使用networkx庫(kù)來(lái)構(gòu)建一個(gè)簡(jiǎn)單的圖:
import networkx as nx G = nx.karate_club_graph()
然后,我們可以使用Pytorch Geometric庫(kù)將NetworkX圖轉(zhuǎn)換為Pytorch Tensor格式。首先,我們需要安裝Pytorch Geometric并導(dǎo)入所需的類:
!pip install torch-geometric from torch_geometric.datasets import Planetoid from torch_geometric.transforms import NormalizeFeatures from torch_geometric.utils.convert import from_networkx
接著,我們可以使用from_networkx
函數(shù)將NetworkX圖轉(zhuǎn)換為Pytorch Tensor格式:
data = from_networkx(G)
此時(shí),data
對(duì)象包含了關(guān)于節(jié)點(diǎn)、邊及其屬性的信息,例如:
data.edge_index: 2x(#edges)的長(zhǎng)整型張量,表示邊的起點(diǎn)和終點(diǎn)
data.x
: n×dn \times dn×d 的浮點(diǎn)型張量,表示每個(gè)節(jié)點(diǎn)的特征向量(其中nnn是節(jié)點(diǎn)數(shù)量,ddd是特征維度)
注意,此時(shí)的data
對(duì)象并未包含鄰居信息。接下來(lái),我們將介紹如何使用Sampler方法采樣節(jié)點(diǎn)鄰居。
Sampler方法
GraphSAGE使用Sampler方法來(lái)聚合鄰居信息。在Pytorch Geometric中,可以使用Various Sampling方法來(lái)實(shí)現(xiàn)Sampler。例如,使用ClusterData方法將圖分成多個(gè)子圖,然后對(duì)每個(gè)子圖進(jìn)行采樣操作。
以下是ClusterData
的使用示例:
from torch_geometric.utils import degree, to_undirected from torch_geometric.transforms import ClusterData # Convert the graph to an undirected graph, so we can aggregate neighbors in both directions. G = to_undirected(G) # Compute the degree of each node. deg = degree(data.edge_index[0], num_nodes=data.num_nodes) # Use METIS algorithm to partition the graph into multiple subgraphs. cluster_data = ClusterData(data, num_parts=2, recursive=False, transform=NormalizeFeatures(), degree=deg)
這里我們將原始圖分成兩個(gè)子圖,并對(duì)每個(gè)子圖進(jìn)行規(guī)范化特征轉(zhuǎn)換。注意,在使用ClusterData方法之前,需要將原始圖轉(zhuǎn)換為無(wú)向圖。
另一個(gè)常用的Sampler方法是在隨機(jī)游動(dòng)時(shí)對(duì)鄰居進(jìn)行采樣,這種方法被稱為隨機(jī)游走采樣(Random Walk Sampling)。以下是隨機(jī)游走采樣的示例代碼:
from torch_geometric.utils import random_walk # Perform random walk sampling to obtain node neighbor samples. walk_length = 20 # The length of random walk trail. num_steps = 4 # The number of nodes to sample from each step. data.batch = None data.edge_index = to_undirected(data.edge_index) # Use undirected edge for random walk. rw_data = random_walk(data.edge_index, walk_length=walk_length, num_steps=num_steps)
這里我們將使用一個(gè)長(zhǎng)度為20、每個(gè)步驟采樣4個(gè)鄰居的隨機(jī)游走方法。注意,在使用隨機(jī)游走方法進(jìn)行采樣之前,需要使用無(wú)向邊。
GraphSAGE模型定義
GraphSAGE模型包含3個(gè)部分:1)圖卷積層;2)聚合器(Aggregator);3)輸出層。我們將在本節(jié)中介紹如何使用Pytorch實(shí)現(xiàn)這些組件。
首先,讓我們定義一個(gè)圖卷積層。圖卷積層的輸入是節(jié)點(diǎn)特征矩陣、鄰接矩陣和聚合器,輸出是新的節(jié)點(diǎn)特征矩陣。以下是圖卷積層的代碼實(shí)現(xiàn):
import torch.nn.functional as F from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn import global_mean_pool class GraphSageConv(MessagePassing): def __init__(self, in_channels, out_channels, aggr='mean'): super(GraphSageConv, self).__init__(aggr=aggr) self.lin = nn.Linear(in_channels, out_channels) def forward(self, x, edge_index): return self.propagate(edge_index, x=x) def message(self, x_j): return x_j def update(self, aggr_out, x): return F.relu(self.lin(torch.cat([x, aggr_out], dim=1)))
這里我們繼承了MessagePassing
類,并在__init__
函數(shù)中定義了一個(gè)全連接層,用于將輸入特征矩陣x
從 dind_{in}din? 維映射到 doutd_{out}dout? 維。在forward
函數(shù)中,我們使用propagate
方法來(lái)實(shí)現(xiàn)消息傳遞操作;在message
函數(shù)中,我們僅向下游節(jié)點(diǎn)發(fā)送原始特征數(shù)據(jù);在update
函數(shù)中,我們首先對(duì)聚合結(jié)果進(jìn)行ReLU非線性變換,然后再通過全連接層進(jìn)行節(jié)點(diǎn)特征的更新。
接下來(lái),讓我們定義一個(gè)聚合器。聚合器的輸入是采樣得到的鄰居特征矩陣,輸出是新的節(jié)點(diǎn)嵌入向量。以下是聚合器的代碼實(shí)現(xiàn):
class MeanAggregator(nn.Module): def __init__(self, input_dim, output_dim): super(MeanAggregator, self).__init__() self.input_dim = input_dim self.output_dim = output_dim self.lin = nn.Linear(input_dim, output_dim) def forward(self, neigh_mean): out = F.relu(self.lin(neigh_mean)) return out
這里我們定義了一個(gè)簡(jiǎn)單的均值聚合器,其將鄰居特征矩陣中每列的均值作為節(jié)點(diǎn)嵌入向量,并使用全連接層進(jìn)行維度變換。
最后,讓我們定義整個(gè)GraphSage模型。GraphSage模型包含2個(gè)圖卷積層和1個(gè)輸出層。以下是模型的代碼實(shí)現(xiàn):
class GraphSAGE(nn.Module): def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2): super(GraphSAGE, self).__init__() self.conv1 = GraphSageConv(in_channels, hidden_channels) self.aggreg1 = MeanAggregator(hidden_channels, hidden_channels) self.conv2 = GraphSageConv(hidden_channels, out_channels) def forward(self, x, edge_index): x = self.conv1(x, edge_index) x = global_mean_pool(x, edge_index) # Compute global mean over nodes. x = self.aggreg1(x) x = self.conv2(x, edge_index) return x
這里我們定義了一個(gè)包含2層GraphSAGE Conv層的神經(jīng)網(wǎng)絡(luò)。在最后一層GraphSAGE Conv層之后,我們使用global_mean_pool
函數(shù)來(lái)計(jì)算節(jié)點(diǎn)嵌入的全局平均值。注意,在本示例中,我們僅保留了一個(gè)輸出節(jié)點(diǎn),因此輸出矩陣的大小為1。如果需要輸出多個(gè)節(jié)點(diǎn),則需要設(shè)置global_mean_pool
函數(shù)中的參數(shù)。
模型訓(xùn)練與測(cè)試
在定義好模型后,我們可以使用Pytorch進(jìn)行模型訓(xùn)練和測(cè)試。首先,讓我們定義一個(gè)損失函數(shù)和優(yōu)化器:
criterion = nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
這里我們使用交叉熵作為損失函數(shù),并使用Adam優(yōu)化器來(lái)更新模型參數(shù)。
接著,我們可以開始訓(xùn)練模型。以下是訓(xùn)練過程的代碼實(shí)現(xiàn):
num_epochs = 100 for epoch in range(num_epochs): model.train() optimizer.zero_grad() out = model(data.x, data.edge_index) loss = criterion(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() print('Epoch {:03d}, Loss: {:.4f}'.format(epoch, loss.item()))
這里我們遍歷所有數(shù)據(jù)樣本,計(jì)算預(yù)測(cè)結(jié)果和真實(shí)標(biāo)簽之間的交叉熵?fù)p失,并使用反向傳播來(lái)更新權(quán)重。我們?cè)诿總€(gè)epoch結(jié)束后打印出當(dāng)前損失值。
最后,我們可以對(duì)模型進(jìn)行測(cè)試。以下是測(cè)試過程的代碼實(shí)現(xiàn):
model.eval() with torch.no_grad(): pred = model(data.x, data.edge_index) pred = pred.argmax(dim=1) acc = (pred[data.test_mask] == data.y[data.test_mask]).sum().item() / data.test_mask.sum().item() print('Test accuracy: {:.4f}'.format(acc))
這里我們使用測(cè)試集來(lái)計(jì)算模型的準(zhǔn)確率。注意,在執(zhí)行model.eval()
后,我們需要使用torch.no_grad()
包裝代碼塊,以禁止梯度計(jì)算。
總結(jié)
介紹了如何使用Pytorch Geometric實(shí)現(xiàn)GraphSAGE模型,包括構(gòu)建圖、定義Sampler方法、定義模型、訓(xùn)練和測(cè)試模型等步驟。GraphSAGE模型是一種常用的節(jié)點(diǎn)嵌入學(xué)習(xí)方法,可以應(yīng)用于各種下游任務(wù)中。
以上就是詳解使用Pytorch Geometric實(shí)現(xiàn)GraphSAGE模型的詳細(xì)內(nèi)容,更多關(guān)于Pytorch Geometric GraphSAGE的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
- PyTorch模型轉(zhuǎn)換為ONNX格式實(shí)現(xiàn)過程詳解
- 利用Pytorch實(shí)現(xiàn)ResNet網(wǎng)絡(luò)構(gòu)建及模型訓(xùn)練
- 詳解利用Pytorch實(shí)現(xiàn)ResNet網(wǎng)絡(luò)之評(píng)估訓(xùn)練模型
- pytorch模型的保存加載與續(xù)訓(xùn)練詳解
- AMP?Tensor?Cores節(jié)省內(nèi)存PyTorch模型詳解
- 詳解?PyTorch?Lightning模型部署到生產(chǎn)服務(wù)中
- Pytorch模型定義與深度學(xué)習(xí)自查手冊(cè)
- 一文詳解如何實(shí)現(xiàn)PyTorch模型編譯
相關(guān)文章
解決Python3錯(cuò)誤:SyntaxError: unexpected EOF while
這篇文章主要介紹了解決Python3錯(cuò)誤:SyntaxError: unexpected EOF while parsin問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2022-07-07Python實(shí)現(xiàn)一個(gè)自助取數(shù)查詢工具
在數(shù)據(jù)生產(chǎn)應(yīng)用部門,取數(shù)分析是一個(gè)很常見的需求,實(shí)際上業(yè)務(wù)人員需求時(shí)刻變化,最高效的方式是讓業(yè)務(wù)部門自己來(lái)取,減少不必要的重復(fù)勞動(dòng),本文介紹如何用Python實(shí)現(xiàn)一個(gè)自助取數(shù)查詢工具2021-06-06python中的數(shù)據(jù)結(jié)構(gòu)比較
這篇文章主要介紹了python中的數(shù)據(jù)結(jié)構(gòu)比較,本文給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2019-05-05python xmind 包使用詳解(其中解決導(dǎo)出的xmind文件 xmind8可以打開 xmind2020及之后版本打
xmind8 可以打開xmind2020 報(bào)錯(cuò),如何解決這個(gè)問題呢?下面小編給大家?guī)?lái)了python xmind 包使用(其中解決導(dǎo)出的xmind文件 xmind8可以打開 xmind2020及之后版本打開報(bào)錯(cuò)問題),感興趣的朋友一起看看吧2021-10-10python實(shí)現(xiàn)詩(shī)歌游戲(類繼承)
這篇文章主要為大家詳細(xì)介紹了python實(shí)現(xiàn)詩(shī)歌游戲,根據(jù)上句猜下句、猜作者、猜朝代、猜詩(shī)名,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2019-02-02一文教會(huì)你用Python獲取網(wǎng)頁(yè)指定內(nèi)容
Python用做數(shù)據(jù)處理還是相當(dāng)不錯(cuò)的,如果你想要做爬蟲,Python是很好的選擇,它有很多已經(jīng)寫好的類包,只要調(diào)用即可完成很多復(fù)雜的功能,下面這篇文章主要給大家介紹了關(guān)于Python獲取網(wǎng)頁(yè)指定內(nèi)容的相關(guān)資料,需要的朋友可以參考下2022-03-03詳解python中[-1]、[:-1]、[::-1]、[n::-1]使用方法
這篇文章主要介紹了詳解python中[-1]、[:-1]、[::-1]、[n::-1]使用方法,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2021-04-04Keras load_model 導(dǎo)入錯(cuò)誤的解決方式
這篇文章主要介紹了Keras load_model 導(dǎo)入錯(cuò)誤的解決方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來(lái)看看吧2020-06-06