PyG搭建GCN模型實現(xiàn)節(jié)點分類GCNConv參數(shù)詳解
前言
在上一篇文章PyG搭建GCN前的準備:了解PyG中的數(shù)據(jù)格式中,大致了解了PyG中的數(shù)據(jù)格式,這篇文章主要是簡單搭建GCN來實現(xiàn)節(jié)點分類,主要目的是了解PyG中GCN的參數(shù)情況。
模型搭建
首先導入包:
from torch_geometric.nn import GCNConv
模型參數(shù):
in_channels:輸入通道,比如節(jié)點分類中表示每個節(jié)點的特征數(shù)。
out_channels:輸出通道,最后一層GCNConv的輸出通道為節(jié)點類別數(shù)(節(jié)點分類)。
improved:如果為True表示自環(huán)增加,也就是原始鄰接矩陣加上2I而不是I,默認為False。
cached:如果為True,GCNConv在第一次對鄰接矩陣進行歸一化時會進行緩存,以后將不再重復計算。
add_self_loops:如果為False不再強制添加自環(huán),默認為True。
normalize:默認為True,表示對鄰接矩陣進行歸一化。
bias:默認添加偏置。
于是模型搭建如下:
class GCN(torch.nn.Module): def __init__(self, num_node_features, num_classes): super(GCN, self).__init__() self.conv1 = GCNConv(num_node_features, 16) self.conv2 = GCNConv(16, num_classes) def forward(self, data): x, edge_index = data.x, data.edge_index x = self.conv1(x, edge_index) x = F.relu(x) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) x = F.relu(x) x = F.dropout(x, training=self.training) x = F.softmax(x, dim=1) return x
輸出一下模型:
data = Planetoid(root='/data/CiteSeer', name='CiteSeer')model = GCN(data.num_node_features, data.num_classes).to(device)print(model)GCN( (conv1): GCNConv(3703, 16) (conv2): GCNConv(16, 6) )
輸出為:
GCN( (conv1): GCNConv(3703, 16) (conv2): GCNConv(16, 6))GCN( (conv1): GCNConv(3703, 16) (conv2): GCNConv(16, 6) )
1. 前向傳播
查看官方文檔中GCNConv的輸入輸出要求:
可以發(fā)現(xiàn),GCNConv中需要輸入的是節(jié)點特征矩陣x和鄰接關系edge_index,還有一個可選項edge_weight。因此我們首先:
x, edge_index = data.x, data.edge_index x = self.conv1(x, edge_index) x = F.relu(x) x = F.dropout(x, training=self.training)
此時我們不妨輸出一下x及其size:
tensor([[0.0000, 0.1630, 0.0000, ..., 0.0000, 0.0488, 0.0000], [0.0000, 0.2451, 0.1614, ..., 0.0000, 0.0125, 0.0000], [0.1175, 0.0262, 0.2141, ..., 0.2592, 0.0000, 0.0000], ..., [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.1825, 0.0000], [0.0000, 0.1024, 0.0000, ..., 0.0498, 0.0000, 0.0000], [0.0000, 0.3263, 0.0000, ..., 0.0000, 0.0000, 0.0000]], device='cuda:0', grad_fn=<FusedDropoutBackward0>)
torch.Size([3327, 16])
此時的x一共3327行,每一行表示一個節(jié)點經過第一層卷積更新后的狀態(tài)向量。
那么同理,由于:
self.conv2 = GCNConv(16, num_classes)
所以經過第二層卷積后:
x = self.conv2(x, edge_index)x = F.relu(x)x = F.dropout(x, training=self.training)x = self.conv2(x, edge_index) x = F.relu(x) x = F.dropout(x, training=self.training)
此時得到的x的size應該為:
torch.Size([3327, 6])
即每個節(jié)點的維度為6的狀態(tài)向量。
由于我們需要進行6分類,所以最后需要加上一個softmax:
x = F.softmax(x, dim=1)
dim=1表示對每一行進行運算,最終每一行之和加起來為1,也就表示了該節(jié)點為每一類的概率。輸出此時的x:
tensor([[0.1607, 0.1727, 0.1607, 0.1607, 0.1607, 0.1846], [0.1654, 0.1654, 0.1654, 0.1654, 0.1654, 0.1731], [0.1778, 0.1622, 0.1733, 0.1622, 0.1622, 0.1622], ..., [0.1659, 0.1659, 0.1659, 0.1704, 0.1659, 0.1659], [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667], [0.1641, 0.1641, 0.1658, 0.1766, 0.1653, 0.1641]], device='cuda:0', grad_fn=<SoftmaxBackward0>)tensor([[0.1607, 0.1727, 0.1607, 0.1607, 0.1607, 0.1846], [0.1654, 0.1654, 0.1654, 0.1654, 0.1654, 0.1731], [0.1778, 0.1622, 0.1733, 0.1622, 0.1622, 0.1622], ..., [0.1659, 0.1659, 0.1659, 0.1704, 0.1659, 0.1659], [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667], [0.1641, 0.1641, 0.1658, 0.1766, 0.1653, 0.1641]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
2. 反向傳播
在訓練時,我們首先利用前向傳播計算出輸出:
out = model(data)
out即為最終得到的每個節(jié)點的6個概率值,但在實際訓練中,我們只需要計算出訓練集的損失,所以損失函數(shù)這樣寫:
loss = loss_function(out[data.train_mask], data.y[data.train_mask])
然后計算梯度,反向更新!
3. 訓練
訓練的完整代碼:
def train(): optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) loss_function = torch.nn.CrossEntropyLoss().to(device) model.train() for epoch in range(500): out = model(data) optimizer.zero_grad() loss = loss_function(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() print('Epoch {:03d} loss {:.4f}'.format(epoch, loss.item()))def train(): optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) loss_function = torch.nn.CrossEntropyLoss().to(device) model.train() for epoch in range(500): out = model(data) optimizer.zero_grad() loss = loss_function(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() print('Epoch {:03d} loss {:.4f}'.format(epoch, loss.item()))
4. 測試
我們首先需要算出模型對所有節(jié)點的預測值:
model(data)
此時得到的是每個節(jié)點的6個概率值,我們需要在每一行上取其最大值:
model(data).max(dim=1)
輸出一下:
torch.return_types.max( values=tensor([0.9100, 0.9071, 0.9786, ..., 0.4321, 0.4009, 0.8779], device='cuda:0', grad_fn=<MaxBackward0>), indices=tensor([3, 1, 5, ..., 3, 1, 5], device='cuda:0'))
返回的第一項是每一行的最大值,第二項為最大值在這一行中的索引,我們只需要取第二項,那么最終的預測值應該寫為:
_, pred = model(data).max(dim=1)
然后計算預測精度:
correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item()) acc = correct / int(data.test_mask.sum()) print('GCN Accuracy: {:.4f}'.format(acc))
完整代碼
完整代碼中實現(xiàn)了論文中提到的四種數(shù)據(jù)集,代碼地址:PyG-GCN。
以上就是PyG搭建GCN模型實現(xiàn)節(jié)點分類GCNConv參數(shù)詳解的詳細內容,更多關于PyG搭建GCNConv節(jié)點分類的資料請關注腳本之家其它相關文章!
相關文章
Python之日期與時間處理模塊(date和datetime)
這篇文章主要介紹了Python之日期與時間處理模塊(date和datetime),小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧2017-02-02一個Python案例帶你掌握xpath數(shù)據(jù)解析方法
xpath解析是最常用且最便捷高效的一種解析方式,通用性強。本文將通過一個Python爬蟲案例帶你詳細了解一下xpath數(shù)據(jù)解析方法,需要的可以參考一下2022-02-02如何在windows下安裝Pycham2020軟件(方法步驟詳解)
這篇文章主要介紹了在windows下安裝Pycham2020軟件方法,本文通過圖文并茂的形式給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下2020-05-05python3 破解 geetest(極驗)的滑塊驗證碼功能
這篇文章主要介紹了python3 破解 geetest(極驗)的滑塊驗證碼功能,本文通過實例代碼給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下2018-02-02