欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

PyG搭建GCN模型實(shí)現(xiàn)節(jié)點(diǎn)分類GCNConv參數(shù)詳解

 更新時(shí)間:2022年05月10日 15:23:46   作者:Cyril_KI  
這篇文章主要為大家介紹了PyG搭建GCN模型實(shí)現(xiàn)節(jié)點(diǎn)分類GCNConv參數(shù)詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪

前言

在上一篇文章PyG搭建GCN前的準(zhǔn)備:了解PyG中的數(shù)據(jù)格式中,大致了解了PyG中的數(shù)據(jù)格式,這篇文章主要是簡單搭建GCN來實(shí)現(xiàn)節(jié)點(diǎn)分類,主要目的是了解PyG中GCN的參數(shù)情況。

模型搭建

首先導(dǎo)入包:

from torch_geometric.nn import GCNConv

模型參數(shù):

in_channels:輸入通道,比如節(jié)點(diǎn)分類中表示每個(gè)節(jié)點(diǎn)的特征數(shù)。

out_channels:輸出通道,最后一層GCNConv的輸出通道為節(jié)點(diǎn)類別數(shù)(節(jié)點(diǎn)分類)。

improved:如果為True表示自環(huán)增加,也就是原始鄰接矩陣加上2I而不是I,默認(rèn)為False。

cached:如果為True,GCNConv在第一次對鄰接矩陣進(jìn)行歸一化時(shí)會進(jìn)行緩存,以后將不再重復(fù)計(jì)算。

add_self_loops:如果為False不再強(qiáng)制添加自環(huán),默認(rèn)為True。

normalize:默認(rèn)為True,表示對鄰接矩陣進(jìn)行歸一化。

bias:默認(rèn)添加偏置。

于是模型搭建如下:

class GCN(torch.nn.Module):
    def __init__(self, num_node_features, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_node_features, 16)
        self.conv2 = GCNConv(16, num_classes)
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = F.softmax(x, dim=1)
        return x

輸出一下模型:

data = Planetoid(root='/data/CiteSeer', name='CiteSeer')model = GCN(data.num_node_features, data.num_classes).to(device)print(model)GCN(
  (conv1): GCNConv(3703, 16)
  (conv2): GCNConv(16, 6)
)

輸出為:

GCN( (conv1): GCNConv(3703, 16) (conv2): GCNConv(16, 6))GCN(
  (conv1): GCNConv(3703, 16)
  (conv2): GCNConv(16, 6)
)

1. 前向傳播

查看官方文檔中GCNConv的輸入輸出要求:

可以發(fā)現(xiàn),GCNConv中需要輸入的是節(jié)點(diǎn)特征矩陣x和鄰接關(guān)系edge_index,還有一個(gè)可選項(xiàng)edge_weight。因此我們首先:

x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)

此時(shí)我們不妨輸出一下x及其size:

tensor([[0.0000, 0.1630, 0.0000,  ..., 0.0000, 0.0488, 0.0000],
        [0.0000, 0.2451, 0.1614,  ..., 0.0000, 0.0125, 0.0000],
        [0.1175, 0.0262, 0.2141,  ..., 0.2592, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.1825, 0.0000],
        [0.0000, 0.1024, 0.0000,  ..., 0.0498, 0.0000, 0.0000],
        [0.0000, 0.3263, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0', grad_fn=<FusedDropoutBackward0>)
torch.Size([3327, 16])

此時(shí)的x一共3327行,每一行表示一個(gè)節(jié)點(diǎn)經(jīng)過第一層卷積更新后的狀態(tài)向量。

那么同理,由于:

self.conv2 = GCNConv(16, num_classes)

所以經(jīng)過第二層卷積后:

x = self.conv2(x, edge_index)x = F.relu(x)x = F.dropout(x, training=self.training)x = self.conv2(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)

此時(shí)得到的x的size應(yīng)該為:

torch.Size([3327, 6])

即每個(gè)節(jié)點(diǎn)的維度為6的狀態(tài)向量。

由于我們需要進(jìn)行6分類,所以最后需要加上一個(gè)softmax:

x = F.softmax(x, dim=1)

dim=1表示對每一行進(jìn)行運(yùn)算,最終每一行之和加起來為1,也就表示了該節(jié)點(diǎn)為每一類的概率。輸出此時(shí)的x:

tensor([[0.1607, 0.1727, 0.1607, 0.1607, 0.1607, 0.1846], [0.1654, 0.1654, 0.1654, 0.1654, 0.1654, 0.1731], [0.1778, 0.1622, 0.1733, 0.1622, 0.1622, 0.1622], ..., [0.1659, 0.1659, 0.1659, 0.1704, 0.1659, 0.1659], [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667], [0.1641, 0.1641, 0.1658, 0.1766, 0.1653, 0.1641]], device='cuda:0', grad_fn=<SoftmaxBackward0>)tensor([[0.1607, 0.1727, 0.1607, 0.1607, 0.1607, 0.1846],
        [0.1654, 0.1654, 0.1654, 0.1654, 0.1654, 0.1731],
        [0.1778, 0.1622, 0.1733, 0.1622, 0.1622, 0.1622],
        ...,
        [0.1659, 0.1659, 0.1659, 0.1704, 0.1659, 0.1659],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1641, 0.1641, 0.1658, 0.1766, 0.1653, 0.1641]], device='cuda:0',
       grad_fn=<SoftmaxBackward0>)

2. 反向傳播

在訓(xùn)練時(shí),我們首先利用前向傳播計(jì)算出輸出:

out = model(data)

out即為最終得到的每個(gè)節(jié)點(diǎn)的6個(gè)概率值,但在實(shí)際訓(xùn)練中,我們只需要計(jì)算出訓(xùn)練集的損失,所以損失函數(shù)這樣寫:

loss = loss_function(out[data.train_mask], data.y[data.train_mask])

然后計(jì)算梯度,反向更新!

3. 訓(xùn)練

訓(xùn)練的完整代碼:

def train(): optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) loss_function = torch.nn.CrossEntropyLoss().to(device) model.train() for epoch in range(500): out = model(data) optimizer.zero_grad() loss = loss_function(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() print('Epoch {:03d} loss {:.4f}'.format(epoch, loss.item()))def train():
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    loss_function = torch.nn.CrossEntropyLoss().to(device)
    model.train()
    for epoch in range(500):
        out = model(data)
        optimizer.zero_grad()
        loss = loss_function(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        print('Epoch {:03d} loss {:.4f}'.format(epoch, loss.item()))

4. 測試

我們首先需要算出模型對所有節(jié)點(diǎn)的預(yù)測值:

model(data)

此時(shí)得到的是每個(gè)節(jié)點(diǎn)的6個(gè)概率值,我們需要在每一行上取其最大值:

model(data).max(dim=1)

輸出一下:

torch.return_types.max(
values=tensor([0.9100, 0.9071, 0.9786,  ..., 0.4321, 0.4009, 0.8779], device='cuda:0',
       grad_fn=<MaxBackward0>),
indices=tensor([3, 1, 5,  ..., 3, 1, 5], device='cuda:0'))

返回的第一項(xiàng)是每一行的最大值,第二項(xiàng)為最大值在這一行中的索引,我們只需要取第二項(xiàng),那么最終的預(yù)測值應(yīng)該寫為:

_, pred = model(data).max(dim=1)

然后計(jì)算預(yù)測精度:

correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct / int(data.test_mask.sum())
print('GCN Accuracy: {:.4f}'.format(acc))

完整代碼

完整代碼中實(shí)現(xiàn)了論文中提到的四種數(shù)據(jù)集,代碼地址:PyG-GCN。

以上就是PyG搭建GCN模型實(shí)現(xiàn)節(jié)點(diǎn)分類GCNConv參數(shù)詳解的詳細(xì)內(nèi)容,更多關(guān)于PyG搭建GCNConv節(jié)點(diǎn)分類的資料請關(guān)注腳本之家其它相關(guān)文章!

相關(guān)文章

  • 關(guān)于matlab圖像濾波詳解(二維傅里葉濾波)

    關(guān)于matlab圖像濾波詳解(二維傅里葉濾波)

    這篇文章主要介紹了關(guān)于matlab圖像濾波詳解(二維傅里葉濾波),具有很好的參考價(jià)值,希望對大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教
    2023-02-02
  • 用Python獲取亞馬遜商品信息

    用Python獲取亞馬遜商品信息

    大家好,本篇文章主要講的是用Python獲取亞馬遜商品信息,感興趣的同學(xué)趕快來看一看吧,對你有幫助的話記得收藏一下,方便下次瀏覽
    2022-01-01
  • Python之日期與時(shí)間處理模塊(date和datetime)

    Python之日期與時(shí)間處理模塊(date和datetime)

    這篇文章主要介紹了Python之日期與時(shí)間處理模塊(date和datetime),小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧
    2017-02-02
  • Python pathlib模塊實(shí)例詳解

    Python pathlib模塊實(shí)例詳解

    本文給大家介紹了Python的pathlib 模塊,為 Python 工程師對該模塊的使用提供了支撐,讓大家了解如何使用 pathlib 模塊讀寫文件、操縱文件路徑和基礎(chǔ)文件系統(tǒng),統(tǒng)計(jì)目錄下的文件類型以及查找匹配目錄下某一類型文件等,需要的朋友參考下吧
    2023-05-05
  • Python設(shè)計(jì)模式行為型責(zé)任鏈模式

    Python設(shè)計(jì)模式行為型責(zé)任鏈模式

    這篇文章主要介紹了Python設(shè)計(jì)模式行為型責(zé)任鏈模式,責(zé)任鏈模式將能處理請求的對象連成一條鏈,并沿著這條鏈傳遞該請求,直到有一個(gè)對象處理請求為止,避免請求的發(fā)送者和接收者之間的耦合關(guān)系,下圍繞改內(nèi)容介紹具有一點(diǎn)的參考價(jià)值,需要的朋友可以參考下
    2022-02-02
  • 一個(gè)Python案例帶你掌握xpath數(shù)據(jù)解析方法

    一個(gè)Python案例帶你掌握xpath數(shù)據(jù)解析方法

    xpath解析是最常用且最便捷高效的一種解析方式,通用性強(qiáng)。本文將通過一個(gè)Python爬蟲案例帶你詳細(xì)了解一下xpath數(shù)據(jù)解析方法,需要的可以參考一下
    2022-02-02
  • 如何在windows下安裝Pycham2020軟件(方法步驟詳解)

    如何在windows下安裝Pycham2020軟件(方法步驟詳解)

    這篇文章主要介紹了在windows下安裝Pycham2020軟件方法,本文通過圖文并茂的形式給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2020-05-05
  • Python常見文件操作的函數(shù)示例代碼

    Python常見文件操作的函數(shù)示例代碼

    Python常見文件操作的函數(shù)示例代碼,學(xué)習(xí)python的朋友可以參考下。
    2011-11-11
  • python3 破解 geetest(極驗(yàn))的滑塊驗(yàn)證碼功能

    python3 破解 geetest(極驗(yàn))的滑塊驗(yàn)證碼功能

    這篇文章主要介紹了python3 破解 geetest(極驗(yàn))的滑塊驗(yàn)證碼功能,本文通過實(shí)例代碼給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2018-02-02
  • Python中int()函數(shù)的用法淺析

    Python中int()函數(shù)的用法淺析

    這篇文章主要介紹了Python中int()函數(shù)的用法淺析的相關(guān)資料,需要的朋友可以參考下
    2017-10-10

最新評論