詳解Pytorch+PyG實(shí)現(xiàn)GCN過程示例
一、模型結(jié)構(gòu)
在圖神經(jīng)網(wǎng)絡(luò)的研究中,GCN(Graph Convolutional Networks)是一種比較常見且有效的模型。
在GCN模型中,每個(gè)節(jié)點(diǎn)都包含了該節(jié)點(diǎn)鄰居節(jié)點(diǎn)信息的聚合,這意味著它是一個(gè)全局性模型。一個(gè)典型的GCN模型通常由兩部分組成:一個(gè)基于消息傳遞算法的卷積層以及一個(gè)多層感知器。其中,前者主要完成特征融合,后者負(fù)責(zé)分類任務(wù)。
對(duì)于一個(gè)具有n個(gè)節(jié)點(diǎn)的圖G,其特征矩陣X可以表示為:
步驟如下:
- 構(gòu)建一個(gè)兩層的卷積網(wǎng)絡(luò):第一層是GCN層,后面跟著ReLU激活和一個(gè)隨機(jī)失活層;第二層是輸出分類器。
- 模型在訓(xùn)練期間根據(jù)具體的損失函數(shù)(如交叉熵?fù)p失)進(jìn)行優(yōu)化,并用于預(yù)測(cè)新數(shù)據(jù)。
二、PyTorch實(shí)現(xiàn)
PyTorch使用dgl庫可以方便地構(gòu)建圖,PyG也提供了類似的工具。接下來看一下如何使用PyTorch + PyG實(shí)現(xiàn)一個(gè)簡單的GCN模型,以Cora數(shù)據(jù)集為例。
準(zhǔn)備數(shù)據(jù)
Cora是一個(gè)分類任務(wù)的數(shù)據(jù)集,其中包含2708個(gè)文本節(jié)點(diǎn)名稱,以及每個(gè)節(jié)點(diǎn)的1433維特征(詞匯相關(guān)性)。首先,我們需要在PyG中將其轉(zhuǎn)換為一個(gè)帶有相應(yīng)邊緣信息的圖形對(duì)象。具體而言,使用pyg.data.dataset工具加載Cora數(shù)據(jù)集,然后將其轉(zhuǎn)換為一個(gè)PyG圖。
from torch_geometric.datasets import Planetoid import torch_geometric.transforms as T dataset = Planetoid(root='/path/to/dataset', name='Cora', transform=T.NormalizeFeatures()) data = dataset[0] print(data)
定義GCN模型
在定義PyG的GCN網(wǎng)絡(luò)之前,需要定義Convolutional Layer,這個(gè)層以鄰接矩陣A作為輸入,通過權(quán)重權(quán)值矩陣W來散播消息,并輸出一個(gè)新特征向量。
import torch.nn.functional as F from torch_geometric.nn import GCNConv class Net(torch.nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = GCNConv(dataset.num_features, 16) self.conv2 = GCNConv(16, dataset.num_classes) def forward(self, x, edge_index): x = self.conv1(x, edge_index) x = F.relu(x) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1)
定義訓(xùn)練過程
訓(xùn)練具體流程如下:
- 對(duì)于每個(gè)epoch,進(jìn)行隨機(jī)梯度下降優(yōu)化。我們選擇交叉熵作為損失函數(shù),并使用Adam作為優(yōu)化器。
- 在測(cè)試期間,用驗(yàn)證集對(duì)精確度進(jìn)行評(píng)估。
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Net().to(device) data.to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) def train(): model.train() optimizer.zero_grad() out = model(data.x, data.edge_index) loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() def test(): model.eval() _, pred = model(data.x, data.edge_index).max(dim=1) correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item()) acc = correct / int(data.test_mask.sum()) return acc for epoch in range(1, 201): train() test_acc = test() print(f'Epoch: {epoch:03d}, Test Acc: {test_acc:.4f}')
以上就是詳解Pytorch+PyG實(shí)現(xiàn)GCN過程示例的詳細(xì)內(nèi)容,更多關(guān)于Pytorch PyG實(shí)現(xiàn)GCN的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
python Protobuf定義消息類型知識(shí)點(diǎn)講解
在本篇文章里小編給大家整理的是一篇關(guān)于python Protobuf定義消息類型知識(shí)點(diǎn)講解,有興趣的朋友們可以學(xué)習(xí)下。2021-03-03pytorch中如何使用DataLoader對(duì)數(shù)據(jù)集進(jìn)行批處理的方法
這篇文章主要介紹了pytorch中如何使用DataLoader對(duì)數(shù)據(jù)集進(jìn)行批處理的方法,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-08-08Python利用atexit模塊實(shí)現(xiàn)優(yōu)雅處理程序退出
Python的atexit模塊提供了一種方便的方式來注冊(cè)這些退出時(shí)執(zhí)行的函數(shù),文中的示例代碼講解詳細(xì),感興趣的小伙伴可以跟隨小編一起學(xué)習(xí)一下2024-03-03Python使用DPKT實(shí)現(xiàn)分析數(shù)據(jù)包
dpkt項(xiàng)目是一個(gè)Python模塊,主要用于對(duì)網(wǎng)絡(luò)數(shù)據(jù)包進(jìn)行解析和操作,z這篇文章主要為大家介紹了python如何利用DPKT實(shí)現(xiàn)分析數(shù)據(jù)包,有需要的可以參考下2023-10-10python SSH模塊登錄,遠(yuǎn)程機(jī)執(zhí)行shell命令實(shí)例解析
這篇文章主要介紹了python SSH模塊登錄,遠(yuǎn)程機(jī)執(zhí)行shell命令實(shí)例解析,具有一定借鑒價(jià)值,需要的朋友可以參考下2018-01-01Python實(shí)例方法、類方法、靜態(tài)方法區(qū)別詳解
這篇文章主要介紹了Python實(shí)例方法、類方法、靜態(tài)方法區(qū)別詳解,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-09-09Python基于隨機(jī)采樣一至性實(shí)現(xiàn)擬合橢圓(優(yōu)化版)
這篇文章主要對(duì)上一版的Python基于隨機(jī)采樣一至性實(shí)現(xiàn)擬合橢圓的優(yōu)化,文中的示例代碼講解詳細(xì),具有一定的借鑒價(jià)值,感興趣的可以了解一下2022-11-11Python如何利用struct進(jìn)行二進(jìn)制文件或數(shù)據(jù)流
這篇文章主要介紹了Python如何利用struct進(jìn)行二進(jìn)制文件或數(shù)據(jù)流問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2024-01-01在 Python 中如何使用 Re 模塊的正則表達(dá)式通配符
這篇文章主要介紹了在 Python 中如何使用 Re 模塊的正則表達(dá)式通配符,本文詳細(xì)解釋了如何在 Python 中使用帶有通配符的 re.sub() 來匹配字符串與正則表達(dá)式,需要的朋友可以參考下2023-06-06