欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

詳解使用Pytorch Geometric實(shí)現(xiàn)GraphSAGE模型

 更新時(shí)間:2023年04月24日 10:31:39   作者:實(shí)力  
這篇文章主要為大家介紹了詳解使用Pytorch Geometric實(shí)現(xiàn)GraphSAGE模型示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪

正文

GraphSAGE是一種用于圖神經(jīng)網(wǎng)絡(luò)中的節(jié)點(diǎn)嵌入學(xué)習(xí)方法。它通過聚合節(jié)點(diǎn)鄰居的信息來(lái)生成節(jié)點(diǎn)的低維表示,使節(jié)點(diǎn)表示能夠更好地應(yīng)用于各種下游任務(wù),如節(jié)點(diǎn)分類、鏈路預(yù)測(cè)等。

圖構(gòu)建

在使用GraphSAGE對(duì)節(jié)點(diǎn)進(jìn)行嵌入學(xué)習(xí)之前,我們需要先將原始數(shù)據(jù)轉(zhuǎn)換為圖結(jié)構(gòu),并將其存儲(chǔ)為Pytorch Tensor格式。例如,我們可以使用networkx庫(kù)來(lái)構(gòu)建一個(gè)簡(jiǎn)單的圖:

import networkx as nx

G = nx.karate_club_graph()

然后,我們可以使用Pytorch Geometric庫(kù)將NetworkX圖轉(zhuǎn)換為Pytorch Tensor格式。首先,我們需要安裝Pytorch Geometric并導(dǎo)入所需的類:

!pip install torch-geometric

from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.utils.convert import from_networkx

接著,我們可以使用from_networkx函數(shù)將NetworkX圖轉(zhuǎn)換為Pytorch Tensor格式:

data = from_networkx(G)

此時(shí),data對(duì)象包含了關(guān)于節(jié)點(diǎn)、邊及其屬性的信息,例如:

data.edge_index: 2x(#edges)的長(zhǎng)整型張量,表示邊的起點(diǎn)和終點(diǎn)

  • data.x: n×dn \times dn×d 的浮點(diǎn)型張量,表示每個(gè)節(jié)點(diǎn)的特征向量(其中nnn是節(jié)點(diǎn)數(shù)量,ddd是特征維度)

注意,此時(shí)的data對(duì)象并未包含鄰居信息。接下來(lái),我們將介紹如何使用Sampler方法采樣節(jié)點(diǎn)鄰居。

Sampler方法

GraphSAGE使用Sampler方法來(lái)聚合鄰居信息。在Pytorch Geometric中,可以使用Various Sampling方法來(lái)實(shí)現(xiàn)Sampler。例如,使用ClusterData方法將圖分成多個(gè)子圖,然后對(duì)每個(gè)子圖進(jìn)行采樣操作。

以下是ClusterData的使用示例:

from torch_geometric.utils import degree, to_undirected
from torch_geometric.transforms import ClusterData

# Convert the graph to an undirected graph, so we can aggregate neighbors in both directions.
G = to_undirected(G)

# Compute the degree of each node.
deg = degree(data.edge_index[0], num_nodes=data.num_nodes)

# Use METIS algorithm to partition the graph into multiple subgraphs.
cluster_data = ClusterData(data, num_parts=2, recursive=False, transform=NormalizeFeatures(),
                           degree=deg)

這里我們將原始圖分成兩個(gè)子圖,并對(duì)每個(gè)子圖進(jìn)行規(guī)范化特征轉(zhuǎn)換。注意,在使用ClusterData方法之前,需要將原始圖轉(zhuǎn)換為無(wú)向圖。

另一個(gè)常用的Sampler方法是在隨機(jī)游動(dòng)時(shí)對(duì)鄰居進(jìn)行采樣,這種方法被稱為隨機(jī)游走采樣(Random Walk Sampling)。以下是隨機(jī)游走采樣的示例代碼:

from torch_geometric.utils import random_walk

# Perform random walk sampling to obtain node neighbor samples.
walk_length = 20  # The length of random walk trail.
num_steps = 4     # The number of nodes to sample from each step.
data.batch = None
data.edge_index = to_undirected(data.edge_index)  # Use undirected edge for random walk.

rw_data = random_walk(data.edge_index, walk_length=walk_length, num_steps=num_steps)

這里我們將使用一個(gè)長(zhǎng)度為20、每個(gè)步驟采樣4個(gè)鄰居的隨機(jī)游走方法。注意,在使用隨機(jī)游走方法進(jìn)行采樣之前,需要使用無(wú)向邊。

GraphSAGE模型定義

GraphSAGE模型包含3個(gè)部分:1)圖卷積層;2)聚合器(Aggregator);3)輸出層。我們將在本節(jié)中介紹如何使用Pytorch實(shí)現(xiàn)這些組件。

首先,讓我們定義一個(gè)圖卷積層。圖卷積層的輸入是節(jié)點(diǎn)特征矩陣、鄰接矩陣和聚合器,輸出是新的節(jié)點(diǎn)特征矩陣。以下是圖卷積層的代碼實(shí)現(xiàn):

import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn import global_mean_pool

class GraphSageConv(MessagePassing):
    def __init__(self, in_channels, out_channels, aggr='mean'):
        super(GraphSageConv, self).__init__(aggr=aggr)
        self.lin = nn.Linear(in_channels, out_channels)
        
    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)
    
    def message(self, x_j):
        return x_j
    
    def update(self, aggr_out, x):
        return F.relu(self.lin(torch.cat([x, aggr_out], dim=1)))

這里我們繼承了MessagePassing類,并在__init__函數(shù)中定義了一個(gè)全連接層,用于將輸入特征矩陣x從 dind_{in}din? 維映射到 doutd_{out}dout? 維。在forward函數(shù)中,我們使用propagate方法來(lái)實(shí)現(xiàn)消息傳遞操作;在message函數(shù)中,我們僅向下游節(jié)點(diǎn)發(fā)送原始特征數(shù)據(jù);在update函數(shù)中,我們首先對(duì)聚合結(jié)果進(jìn)行ReLU非線性變換,然后再通過全連接層進(jìn)行節(jié)點(diǎn)特征的更新。

接下來(lái),讓我們定義一個(gè)聚合器。聚合器的輸入是采樣得到的鄰居特征矩陣,輸出是新的節(jié)點(diǎn)嵌入向量。以下是聚合器的代碼實(shí)現(xiàn):

class MeanAggregator(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(MeanAggregator, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.lin = nn.Linear(input_dim, output_dim)
        
    def forward(self, neigh_mean):
        out = F.relu(self.lin(neigh_mean))
        return out

這里我們定義了一個(gè)簡(jiǎn)單的均值聚合器,其將鄰居特征矩陣中每列的均值作為節(jié)點(diǎn)嵌入向量,并使用全連接層進(jìn)行維度變換。

最后,讓我們定義整個(gè)GraphSage模型。GraphSage模型包含2個(gè)圖卷積層和1個(gè)輸出層。以下是模型的代碼實(shí)現(xiàn):

class GraphSAGE(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
        super(GraphSAGE, self).__init__()
        self.conv1 = GraphSageConv(in_channels, hidden_channels)
        self.aggreg1 = MeanAggregator(hidden_channels, hidden_channels)
        self.conv2 = GraphSageConv(hidden_channels, out_channels)
        
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = global_mean_pool(x, edge_index)  # Compute global mean over nodes.
        x = self.aggreg1(x)
        x = self.conv2(x, edge_index)
        return x

這里我們定義了一個(gè)包含2層GraphSAGE Conv層的神經(jīng)網(wǎng)絡(luò)。在最后一層GraphSAGE Conv層之后,我們使用global_mean_pool函數(shù)來(lái)計(jì)算節(jié)點(diǎn)嵌入的全局平均值。注意,在本示例中,我們僅保留了一個(gè)輸出節(jié)點(diǎn),因此輸出矩陣的大小為1。如果需要輸出多個(gè)節(jié)點(diǎn),則需要設(shè)置global_mean_pool函數(shù)中的參數(shù)。

模型訓(xùn)練與測(cè)試

在定義好模型后,我們可以使用Pytorch進(jìn)行模型訓(xùn)練和測(cè)試。首先,讓我們定義一個(gè)損失函數(shù)和優(yōu)化器:

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

這里我們使用交叉熵作為損失函數(shù),并使用Adam優(yōu)化器來(lái)更新模型參數(shù)。

接著,我們可以開始訓(xùn)練模型。以下是訓(xùn)練過程的代碼實(shí)現(xiàn):

num_epochs = 100

for epoch in range(num_epochs):
    model.train()
    
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    
    print('Epoch {:03d}, Loss: {:.4f}'.format(epoch, loss.item()))

這里我們遍歷所有數(shù)據(jù)樣本,計(jì)算預(yù)測(cè)結(jié)果和真實(shí)標(biāo)簽之間的交叉熵?fù)p失,并使用反向傳播來(lái)更新權(quán)重。我們?cè)诿總€(gè)epoch結(jié)束后打印出當(dāng)前損失值。

最后,我們可以對(duì)模型進(jìn)行測(cè)試。以下是測(cè)試過程的代碼實(shí)現(xiàn):

model.eval()

with torch.no_grad():
    pred = model(data.x, data.edge_index)
    pred = pred.argmax(dim=1)

acc = (pred[data.test_mask] == data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()
print('Test accuracy: {:.4f}'.format(acc))

這里我們使用測(cè)試集來(lái)計(jì)算模型的準(zhǔn)確率。注意,在執(zhí)行model.eval()后,我們需要使用torch.no_grad()包裝代碼塊,以禁止梯度計(jì)算。

總結(jié)

介紹了如何使用Pytorch Geometric實(shí)現(xiàn)GraphSAGE模型,包括構(gòu)建圖、定義Sampler方法、定義模型、訓(xùn)練和測(cè)試模型等步驟。GraphSAGE模型是一種常用的節(jié)點(diǎn)嵌入學(xué)習(xí)方法,可以應(yīng)用于各種下游任務(wù)中。

以上就是詳解使用Pytorch Geometric實(shí)現(xiàn)GraphSAGE模型的詳細(xì)內(nèi)容,更多關(guān)于Pytorch Geometric GraphSAGE的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!

相關(guān)文章

  • 解決Python3錯(cuò)誤:SyntaxError: unexpected EOF while parsin

    解決Python3錯(cuò)誤:SyntaxError: unexpected EOF while

    這篇文章主要介紹了解決Python3錯(cuò)誤:SyntaxError: unexpected EOF while parsin問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教
    2022-07-07
  • Python實(shí)現(xiàn)一個(gè)自助取數(shù)查詢工具

    Python實(shí)現(xiàn)一個(gè)自助取數(shù)查詢工具

    在數(shù)據(jù)生產(chǎn)應(yīng)用部門,取數(shù)分析是一個(gè)很常見的需求,實(shí)際上業(yè)務(wù)人員需求時(shí)刻變化,最高效的方式是讓業(yè)務(wù)部門自己來(lái)取,減少不必要的重復(fù)勞動(dòng),本文介紹如何用Python實(shí)現(xiàn)一個(gè)自助取數(shù)查詢工具
    2021-06-06
  • 如何使用Python 打印各種三角形

    如何使用Python 打印各種三角形

    這篇文章主要介紹了如何使用Python 打印各種三角形,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下
    2019-06-06
  • python中的數(shù)據(jù)結(jié)構(gòu)比較

    python中的數(shù)據(jù)結(jié)構(gòu)比較

    這篇文章主要介紹了python中的數(shù)據(jù)結(jié)構(gòu)比較,本文給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2019-05-05
  • python xmind 包使用詳解(其中解決導(dǎo)出的xmind文件 xmind8可以打開 xmind2020及之后版本打開報(bào)錯(cuò)問題)

    python xmind 包使用詳解(其中解決導(dǎo)出的xmind文件 xmind8可以打開 xmind2020及之后版本打

    xmind8 可以打開xmind2020 報(bào)錯(cuò),如何解決這個(gè)問題呢?下面小編給大家?guī)?lái)了python xmind 包使用(其中解決導(dǎo)出的xmind文件 xmind8可以打開 xmind2020及之后版本打開報(bào)錯(cuò)問題),感興趣的朋友一起看看吧
    2021-10-10
  • python實(shí)現(xiàn)詩(shī)歌游戲(類繼承)

    python實(shí)現(xiàn)詩(shī)歌游戲(類繼承)

    這篇文章主要為大家詳細(xì)介紹了python實(shí)現(xiàn)詩(shī)歌游戲,根據(jù)上句猜下句、猜作者、猜朝代、猜詩(shī)名,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2019-02-02
  • Python變量賦值的秘密分享

    Python變量賦值的秘密分享

    在Python中,我們令一個(gè)變量等于另外一個(gè)變量時(shí),并不是把值傳遞給它,而是直接把指向的地址更改了,我們通過一個(gè)小例子來(lái)看看這個(gè)有趣的過程,需要的朋友可以參考下
    2018-04-04
  • 一文教會(huì)你用Python獲取網(wǎng)頁(yè)指定內(nèi)容

    一文教會(huì)你用Python獲取網(wǎng)頁(yè)指定內(nèi)容

    Python用做數(shù)據(jù)處理還是相當(dāng)不錯(cuò)的,如果你想要做爬蟲,Python是很好的選擇,它有很多已經(jīng)寫好的類包,只要調(diào)用即可完成很多復(fù)雜的功能,下面這篇文章主要給大家介紹了關(guān)于Python獲取網(wǎng)頁(yè)指定內(nèi)容的相關(guān)資料,需要的朋友可以參考下
    2022-03-03
  • 詳解python中[-1]、[:-1]、[::-1]、[n::-1]使用方法

    詳解python中[-1]、[:-1]、[::-1]、[n::-1]使用方法

    這篇文章主要介紹了詳解python中[-1]、[:-1]、[::-1]、[n::-1]使用方法,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧
    2021-04-04
  • Keras load_model 導(dǎo)入錯(cuò)誤的解決方式

    Keras load_model 導(dǎo)入錯(cuò)誤的解決方式

    這篇文章主要介紹了Keras load_model 導(dǎo)入錯(cuò)誤的解決方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來(lái)看看吧
    2020-06-06

最新評(píng)論