欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

詳解Pytorch+PyG實(shí)現(xiàn)GCN過程示例

 更新時(shí)間:2023年04月21日 10:09:34   作者:實(shí)力  
這篇文章主要為大家介紹了Pytorch+PyG實(shí)現(xiàn)GCN過程示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪

一、模型結(jié)構(gòu)

在圖神經(jīng)網(wǎng)絡(luò)的研究中,GCN(Graph Convolutional Networks)是一種比較常見且有效的模型。

在GCN模型中,每個(gè)節(jié)點(diǎn)都包含了該節(jié)點(diǎn)鄰居節(jié)點(diǎn)信息的聚合,這意味著它是一個(gè)全局性模型。一個(gè)典型的GCN模型通常由兩部分組成:一個(gè)基于消息傳遞算法的卷積層以及一個(gè)多層感知器。其中,前者主要完成特征融合,后者負(fù)責(zé)分類任務(wù)。

對(duì)于一個(gè)具有n個(gè)節(jié)點(diǎn)的圖G,其特征矩陣X可以表示為:

步驟如下:

  • 構(gòu)建一個(gè)兩層的卷積網(wǎng)絡(luò):第一層是GCN層,后面跟著ReLU激活和一個(gè)隨機(jī)失活層;第二層是輸出分類器。
  • 模型在訓(xùn)練期間根據(jù)具體的損失函數(shù)(如交叉熵?fù)p失)進(jìn)行優(yōu)化,并用于預(yù)測(cè)新數(shù)據(jù)。

二、PyTorch實(shí)現(xiàn)

PyTorch使用dgl庫可以方便地構(gòu)建圖,PyG也提供了類似的工具。接下來看一下如何使用PyTorch + PyG實(shí)現(xiàn)一個(gè)簡單的GCN模型,以Cora數(shù)據(jù)集為例。

準(zhǔn)備數(shù)據(jù)

Cora是一個(gè)分類任務(wù)的數(shù)據(jù)集,其中包含2708個(gè)文本節(jié)點(diǎn)名稱,以及每個(gè)節(jié)點(diǎn)的1433維特征(詞匯相關(guān)性)。首先,我們需要在PyG中將其轉(zhuǎn)換為一個(gè)帶有相應(yīng)邊緣信息的圖形對(duì)象。具體而言,使用pyg.data.dataset工具加載Cora數(shù)據(jù)集,然后將其轉(zhuǎn)換為一個(gè)PyG圖。

from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
dataset = Planetoid(root='/path/to/dataset', name='Cora', transform=T.NormalizeFeatures())
data = dataset[0]
print(data)

定義GCN模型

在定義PyG的GCN網(wǎng)絡(luò)之前,需要定義Convolutional Layer,這個(gè)層以鄰接矩陣A作為輸入,通過權(quán)重權(quán)值矩陣W來散播消息,并輸出一個(gè)新特征向量。

import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GCNConv(dataset.num_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

定義訓(xùn)練過程

訓(xùn)練具體流程如下:

  • 對(duì)于每個(gè)epoch,進(jìn)行隨機(jī)梯度下降優(yōu)化。我們選擇交叉熵作為損失函數(shù),并使用Adam作為優(yōu)化器。
  • 在測(cè)試期間,用驗(yàn)證集對(duì)精確度進(jìn)行評(píng)估。
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
def test():
    model.eval()
    _, pred = model(data.x, data.edge_index).max(dim=1)
    correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
    acc = correct / int(data.test_mask.sum())
    return acc
for epoch in range(1, 201):
    train()
    test_acc = test()
    print(f'Epoch: {epoch:03d}, Test Acc: {test_acc:.4f}')

以上就是詳解Pytorch+PyG實(shí)現(xiàn)GCN過程示例的詳細(xì)內(nèi)容,更多關(guān)于Pytorch PyG實(shí)現(xiàn)GCN的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!

相關(guān)文章

最新評(píng)論