Pytorch+PyG實現(xiàn)GraphSAGE過程示例詳解
GraphSAGE簡介
GraphSAGE(Graph Sampling and Aggregation)是一種常見的圖神經(jīng)網(wǎng)絡(luò)模型,主要用于結(jié)點級別的表征學習。該模型基于采樣和聚合策略,將一個結(jié)點及其鄰居節(jié)點信息融合在一起,得到其表征表示,并通過多輪迭代更新來提高表征的精度。
實現(xiàn)步驟
數(shù)據(jù)準備
在本次實現(xiàn)中,我們?nèi)匀皇褂肅ora數(shù)據(jù)集作為示例進行測試,由于GraphSage主要聚焦于單一節(jié)點特征的更新,因此這里不需要對數(shù)據(jù)集做特別處理,只需要將數(shù)據(jù)轉(zhuǎn)化成PyG格式即可。
import torch.nn.functional as F from torch_geometric.datasets import Planetoid from torch_geometric.utils import from_networkx, to_networkx # 加載cora數(shù)據(jù)集 dataset = Planetoid(root='./cora', name='Cora') data = dataset[0] # 將nx.Graph形式的圖轉(zhuǎn)換成PyG需要的格式 graph = to_networkx(data) data = from_networkx(graph) # 獲取節(jié)點數(shù)量和特征向量維度 num_nodes = data.num_nodes num_features = dataset.num_features num_classes = dataset.num_classes # 建立需要訓練的節(jié)點分割數(shù)據(jù)集 data.train_mask = torch.zeros(num_nodes, dtype=torch.bool) data.val_mask = torch.zeros(num_nodes, dtype=torch.bool) data.test_mask = torch.zeros(num_nodes, dtype=torch.bool) data.train_mask[:num_nodes - 1000] = True data.test_mask[-1000:] = True data.val_mask[num_nodes - 2000: num_nodes - 1000] = True
實現(xiàn)模型
接下來,我們需要定義GraphSAGE模型。與傳統(tǒng)的GCN中只需要一層卷積操作不同,GraphSAGE包含兩層卷積和采樣(也稱“聚合”)操作。
from torch.nn import Sequential as Seq, Linear as Lin, ReLU from torch_geometric.nn import SAGEConv class GraphSAGE(torch.nn.Module): def __init__(self, hidden_channels, num_layers): super(GraphSAGE, self).__init__() self.convs = nn.ModuleList() for i in range(num_layers): in_channels = hidden_channels if i != 0 else num_features out_channels = num_classes if i == num_layers - 1 else hidden_channels self.convs.append(SAGEConv(in_channels, out_channels)) def forward(self, x, edge_index): for _, conv in enumerate(self.convs[:-1]): x = F.relu(conv(x, edge_index)) # 最后一層不用激活函數(shù) x = self.convs[-1](x, edge_index) return F.log_softmax(x, dim=-1)
在上述代碼中,我們實現(xiàn)了多層GraphSAGE卷積和相應(yīng)的聚合函數(shù),并使用ReLU和softmax函數(shù)來進行特征提取和分類分數(shù)的輸出。
模型訓練
定義好模型之后,就可以開始針對Cora數(shù)據(jù)集進行模型訓練。首先還是需要先指定優(yōu)化器和損失函數(shù),并設(shè)定一些參數(shù)用于記錄訓練過程中的信息,如Epochs、Batch size、學習率等。
# 初始化GraphSage并指定參數(shù) num_layers = 2 hidden_channels = 256 model = GraphSAGE(hidden_channels, num_layers).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) loss_func = nn.CrossEntropyLoss() # 訓練過程 for epoch in range(500): model.train() optimizer.zero_grad() out = model(data.x.to(device), data.edge_index.to(device)) loss = loss_func(out[data.train_mask], data.y.to(device)[data.train_mask]) loss.backward() optimizer.step() # 在各個測試階段檢測一下準確率 if epoch % 10 == 0: with torch.no_grad(): _, pred = model(data.x.to(device), data.edge_index.to(device)).max(dim=1) correct = float(pred[data.test_mask].eq(data.y.to(device)[data.test_mask]).sum().item()) acc = correct / data.test_mask.sum().item() print("Epoch {:03d}, Train Loss {:.4f}, Test Acc {:.4f}".format( epoch, loss.item(), acc))
在上述代碼中,我們使用有標記的訓練數(shù)據(jù)擬合GraphSAGE模型,在各個驗證階段測試準確率,并通過梯度下降法優(yōu)化損失函數(shù)。
以上就是Pytorch+PyG實現(xiàn)GraphSAGE過程示例詳解的詳細內(nèi)容,更多關(guān)于Pytorch PyG實現(xiàn)GraphSAGE的資料請關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
初次部署django+gunicorn+nginx的方法步驟
這篇文章主要介紹了初次部署django+gunicorn+nginx的方法步驟,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2019-09-09python四個坐標點對圖片區(qū)域最小外接矩形進行裁剪
在圖像裁剪操作中,opencv和pillow兩個庫都具有相應(yīng)的函數(shù),如果想要對目標的最小外接矩形進行裁剪該如何操作呢?本文就來詳細的介紹一下2021-06-06Python爬蟲實戰(zhàn):分析《戰(zhàn)狼2》豆瓣影評
這篇文章主要介紹了Python爬蟲實戰(zhàn):《戰(zhàn)狼2》豆瓣影評分析,小編在這里使用的是python版本3.5,需要的朋友可以參考下2018-03-03Python使用Beautiful Soup實現(xiàn)解析網(wǎng)頁
在這篇文章中,我們將介紹如何使用 Python 編寫一個簡單的網(wǎng)絡(luò)爬蟲,以獲取并解析網(wǎng)頁內(nèi)容。我們將使用 Beautiful Soup 庫,它是一個非常強大的庫,用于解析和操作 HTML 和 XML 文檔。讓我們開始吧2023-05-05Python 編碼處理-str與Unicode的區(qū)別
本文主要介紹Python 編碼處理的問題,這里整理了相關(guān)資料,并詳細說明如何處理編碼問題,有需要的小伙伴可以參考下2016-09-09