欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Pytorch+PyG實現(xiàn)GIN過程示例詳解

 更新時間:2023年04月21日 09:57:03   作者:實力  
這篇文章主要為大家介紹了Pytorch+PyG實現(xiàn)GIN過程示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪

GIN簡介

GIN(Graph Isomorphism Network)是一類基于圖同構(gòu)的神經(jīng)網(wǎng)絡。在傳統(tǒng)的神經(jīng)網(wǎng)絡中,每個節(jié)點的特征只依賴于其自身特征,但在圖數(shù)據(jù)中,節(jié)點的特征還與其鄰居節(jié)點有關(guān)系。GIN網(wǎng)絡通過定義可重復均值池化運算來學習節(jié)點及其鄰居的特征表示,并使用多層感知器(MLP)作為逐層轉(zhuǎn)換函數(shù)進行特征提取。

實現(xiàn)步驟

數(shù)據(jù)準備

這里我們?nèi)匀贿x用Cora數(shù)據(jù)集作為示例數(shù)據(jù)。由于GIN采用基于點、簡單且無參數(shù)的鄰域聚合方式,因此不需要額外對數(shù)據(jù)做處理,直接使用即可。

import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import from_networkx, to_networkx
# 加載Cora數(shù)據(jù)集
dataset = Planetoid(root='./cora', name='Cora')
data = dataset[0]
# 將nx.Graph形式的圖轉(zhuǎn)換成PyG需要的格式
graph = to_networkx(data)
data = from_networkx(graph)
# 獲取節(jié)點數(shù)量和特征向量維度
num_nodes = data.num_nodes
num_features = dataset.num_features
num_classes = dataset.num_classes
# 建立需要訓練的節(jié)點分割數(shù)據(jù)集
data.train_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.val_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.test_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.train_mask[:num_nodes - 1000] = True
data.test_mask[-1000:] = True
data.val_mask[num_nodes - 2000: num_nodes - 1000] = True

實現(xiàn)模型

接下來,我們需要定義GIN模型。

from torch_geometric.nn import global_mean_pool
class GIN(torch.nn.Module):
    def __init__(self, hidden_dim, num_layers):
        super(GIN, self).__init__()
        self.conv1 = GINConv(mlp=nn.Sequential(nn.Linear(num_features, hidden_dim),
                                                nn.ReLU(),
                                                nn.Linear(hidden_dim, hidden_dim)))
        self.convs = nn.ModuleList()
        for _ in range(num_layers - 1):
            self.convs.append(GINConv(mlp=nn.Sequential(nn.Linear(hidden_dim, hidden_dim),
                                                        nn.ReLU(),
                                                        nn.Linear(hidden_dim, hidden_dim))))
        self.classify = nn.Sequential(nn.Linear(hidden_dim, num_classes))
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
        out = global_mean_pool(x, batch)
        return self.classify(out)

在上述代碼中,我們實現(xiàn)了多層GIN的“可重復均值池化”結(jié)構(gòu),并使用MLP作為轉(zhuǎn)換函數(shù)進行多層特征提取。

模型訓練

定義好模型后,可以開始針對Cora數(shù)據(jù)集進行模型訓練了。訓練模型前先設置好優(yōu)化器和損失函數(shù),并指定訓練周期及其過程中需要記錄輸出信息的參數(shù)。

from torch_geometric.nn import GINConv, global_add_pool
# 初始化GIN并指定參數(shù)
num_layers = 5
hidden_dim = 1024
model = GIN(hidden_dim=hidden_dim, num_layers=num_layers).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-06)
loss_func = nn.CrossEntropyLoss()
# 開始訓練
for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    pred = model(train_data)
    loss = loss_func(pred[train_mask], train_labels)
    loss.backward()
    optimizer.step()
    # 在各個測試階段檢測一下準確率
    with torch.no_grad():
        model.eval()
        pred = model(test_data)
        test_loss = loss_func(pred[test_mask], test_labels).item()
        pred = pred.argmax(dim=-1, keepdim=True)
        correct = float(pred[test_mask].eq(test_labels.view(-1, 1)[test_mask]).sum().item())
        acc = correct / test_mask.sum().item()
        if epoch % 10 == 0:
            print("Epoch {:03d}, Train Loss {:.4f}, Test Loss {:.4f}, Test Acc {:.4f}".format(
                epoch, loss.item(), test_loss, acc))

以上就是Pytorch+PyG實現(xiàn)GIN過程示例詳解的詳細內(nèi)容,更多關(guān)于Pytorch PyG實現(xiàn)GIN的資料請關(guān)注腳本之家其它相關(guān)文章!

相關(guān)文章

  • python中的decode()與encode()深入理解

    python中的decode()與encode()深入理解

    這篇文章主要介紹了python中的decode()與encode()函數(shù)詳解,本文通過實例代碼給大家講解的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下
    2022-12-12
  • Python迭代器模塊itertools使用原理解析

    Python迭代器模塊itertools使用原理解析

    這篇文章主要介紹了Python迭代器模塊itertools使用原理解析,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友可以參考下
    2019-12-12
  • 老生常談python字典用法

    老生常談python字典用法

    python 創(chuàng)建字典可以使用 dict 函數(shù),或者使用花括號,用花括號的方式更為常見。本文給大家介紹python字典用法,感興趣的朋友跟隨小編一起看看吧
    2021-12-12
  • 使用Flask和Django中解決跨域請求問題

    使用Flask和Django中解決跨域請求問題

    這篇文章主要介紹了使用Flask和Django中解決跨域請求問題,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2021-04-04
  • django一對多模型以及如何在前端實現(xiàn)詳解

    django一對多模型以及如何在前端實現(xiàn)詳解

    這篇文章主要介紹了django一對多模型以及如何在前端實現(xiàn)詳解,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友可以參考下
    2019-07-07
  • PyCharm:method may be static問題及解決

    PyCharm:method may be static問題及解決

    這篇文章主要介紹了PyCharm:method may be static問題及解決方案,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教
    2022-07-07
  • Python兩個內(nèi)置函數(shù) locals 和globals(學習筆記)

    Python兩個內(nèi)置函數(shù) locals 和globals(學習筆記)

    這篇文章主要介紹了Python兩個內(nèi)置函數(shù) locals 和globals(學習筆記),需要的朋友可以參考下
    2016-08-08
  • django框架實現(xiàn)模板中獲取request 的各種信息示例

    django框架實現(xiàn)模板中獲取request 的各種信息示例

    這篇文章主要介紹了django框架實現(xiàn)模板中獲取request 的各種信息,結(jié)合實例形式分析了Django框架模板直接獲取request信息的相關(guān)配置與操作技巧,需要的朋友可以參考下
    2019-07-07
  • 教你用Python寫一個水果忍者小游戲

    教你用Python寫一個水果忍者小游戲

    水果忍者游戲,又稱切水果游戲,玩法簡單,水果忍者游戲在兒童中很受歡迎,下面這篇文章主要給大家介紹了關(guān)于如何利用Python寫一個水果忍者小游戲的相關(guān)資料,需要的朋友可以參考下
    2022-03-03
  • python實現(xiàn)騰訊滑塊驗證碼識別

    python實現(xiàn)騰訊滑塊驗證碼識別

    這篇文章主要介紹了python如何實現(xiàn)騰訊滑塊驗證碼識別,幫助大家更好的理解和學習使用python,感興趣的朋友可以了解下
    2021-04-04

最新評論