欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

詳解Pytorch+PyG實現(xiàn)GAT過程示例

 更新時間:2023年04月21日 10:01:31   作者:實力  
這篇文章主要為大家介紹了Pytorch+PyG實現(xiàn)GAT過程示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪

導(dǎo)入庫和數(shù)據(jù)

GAT(圖注意力網(wǎng)絡(luò))是常見的圖神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)之一,它使用注意力機制來對節(jié)點進行特征加權(quán),并考慮其鄰居節(jié)點的交互。

首先,我們需要導(dǎo)入PyTorch和PyG庫,然后準備好我們的數(shù)據(jù)。例如,我們可以使用以下方式生成一個簡單的隨機數(shù)據(jù)集:

from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora', name='Cora')
train_loader = DataLoader(dataset[0], batch_size=128, shuffle=True)
test_loader = DataLoader(dataset[0], batch_size=128, shuffle=False)

其中, Planetoid 是PyG提供的圖形數(shù)據(jù)集之一。這里我們選擇了 Cora 數(shù)據(jù)集并存儲到 /tmp/Cora 文件夾中。然后我們將該數(shù)據(jù)集分成訓(xùn)練集和測試集,設(shè)置相應(yīng)的加載器。

定義模型結(jié)構(gòu)

接下來,我們需要定義GAT模型的結(jié)構(gòu)。通過PyTorch和PyG,我們可以自己定義完整的GAT模型或者利用現(xiàn)有的庫函數(shù)快速構(gòu)建模型。在這里,我們將使用 torch_geometric.nn.GATConv 函數(shù)逐層堆疊多個圖注意力層來實現(xiàn)GAT模型。以下是GAT模型定義的示例代碼:

import torch.nn.functional as F
from torch_geometric.nn import GATConv
class Net(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Net, self).__init__()
        self.num_layers = 2
        self.conv1 = GATConv(in_channels=in_channels, out_channels=16, heads=8, dropout=0.6)
        self.conv2 = GATConv(in_channels=16*8, out_channels=out_channels, heads=1, concat=False, dropout=0.6)
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

上述代碼中,我們定義了一個 Net 類用于構(gòu)建GAT網(wǎng)絡(luò),接收輸入通道數(shù)和輸出通道數(shù)作為參數(shù)。例如,我們可以按照以下方式創(chuàng)建一個將 CORD 參量作為輸入特征向量大小、64 個隱藏節(jié)點(每個注意力頭)。并將數(shù)字類別作為輸出大小的GAT模型:

model = Net(in_channels=dataset.num_features, out_channels=dataset.num_classes)

其中 num_featuresnum_classes 是PyG數(shù)據(jù)集中包含的屬性。

定義訓(xùn)練函數(shù)

然后,我們需要定義訓(xùn)練函數(shù)來訓(xùn)練我們的GAT神經(jīng)網(wǎng)絡(luò)。在這里,我們將使用交叉熵損失和Adam優(yōu)化器進行訓(xùn)練,并在每一個epoch結(jié)束時計算準確率并打印出來。以下是訓(xùn)練函數(shù)的示例代碼:

import torch.optim as optim
from tqdm import tqdm
def train(model, loader, optimizer, loss_fn):
    model.train()
    correct = 0
    total_loss = 0
    for data in tqdm(loader, desc='Training'):
        optimizer.zero_grad()
        out = model(data)
        pred = out.argmax(dim=1)
        loss = loss_fn(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
        correct += pred[data.train_mask].eq(data.y[data.train_mask]).sum().item()
    return total_loss / len(loader.dataset), correct / len(data.train_mask)

在上述代碼中,我們遍歷加載器中的每個數(shù)據(jù)批次,并對模型進行培訓(xùn)。對于每個圖數(shù)據(jù)批次,我們計算網(wǎng)絡(luò)輸出、預(yù)測和損失,然后通過反向傳播來更新權(quán)重。最后,我們將總損失和正確率記錄下來并返回。

定義測試函數(shù)

接下來,我們還需要定義測試函數(shù)來測試我們的GAT神經(jīng)網(wǎng)絡(luò)性能表現(xiàn)。我們將利用與訓(xùn)練函數(shù)相同的輸出參數(shù)進行測試,并打印出最終的測試準確率。以下是測試函數(shù)的示例代碼:

def test(model, loader, loss_fn):
    model.eval()
    correct = 0
    total_loss = 0
    with torch.no_grad():
        for data in tqdm(loader, desc='Testing'):
            out = model(data)
            pred = out.argmax(dim=1)
            loss = loss_fn(out[data.test_mask], data.y[data.test_mask])
            total_loss += loss.item() * data.num_graphs
            correct += pred[data.test_mask].eq(data.y[data.test_mask]).sum().item()
    return total_loss / len(loader.dataset), correct / len(data.test_mask)

在上述代碼中,我們對測試數(shù)據(jù)集中的所有數(shù)據(jù)進行了循環(huán),并計算網(wǎng)絡(luò)的輸出和預(yù)測。我們記錄下總損失和正確分類的數(shù)據(jù)量,并返回損失和準確率之間的比率。

訓(xùn)練模型并評估訓(xùn)練結(jié)果

最后,我們可以使用前面定義過的函數(shù)來定義主函數(shù),從而完成GAT神經(jīng)網(wǎng)絡(luò)的訓(xùn)練和測試。以下是主函數(shù)的示例代碼:

if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Net(in_channels=dataset.num_features, out_channels=dataset.num_classes).to(device)
    train_loader = DataLoader(dataset[0], batch_size=128, shuffle=True)
    test_loader = DataLoader(dataset[0], batch_size=128, shuffle=False)
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(1, 201):
        train_loss, train_acc = train(model, train_loader, optimizer, loss_fn)
        test_loss, test_acc = test(model, test_loader, loss_fn)
        print(f'Epoch {epoch:03d}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, '
              f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')

通過上述代碼,我們就可以完成GAT神經(jīng)網(wǎng)絡(luò)的訓(xùn)練和測試。我們使用 DataLoader 函數(shù)進行數(shù)據(jù)加載,設(shè)置學(xué)習(xí)率、損失函數(shù)、訓(xùn)練輪數(shù)等超參數(shù)。最后,我們可以在屏幕上看到每個時代的準確率和損失值,并通過它們評估模型的訓(xùn)練表現(xiàn)。

以上就是詳解Pytorch+PyG實現(xiàn)GAT過程示例的詳細內(nèi)容,更多關(guān)于Pytorch PyG實現(xiàn)GAT的資料請關(guān)注腳本之家其它相關(guān)文章!

相關(guān)文章

  • PySide2出現(xiàn)“ImportError: DLL load failed: 找不到指定的模塊”的問題及解決方法

    PySide2出現(xiàn)“ImportError: DLL load failed: 找不到指定的模塊”的問題及解決方法

    這篇文章主要介紹了PySide2出現(xiàn)“ImportError: DLL load failed: 找不到指定的模塊”的問題及解決方法,本文通過實例代碼給大家介紹的非常詳細,需要的朋友可以參考下
    2020-06-06
  • Python OpenCV圖像模糊處理介紹

    Python OpenCV圖像模糊處理介紹

    大家好,本篇文章主要講的是Python OpenCV圖像模糊處理介紹,感興趣的同學(xué)趕快來看一看吧,對你有幫助的話記得收藏一下
    2022-01-01
  • python處理multipart/form-data的請求方法

    python處理multipart/form-data的請求方法

    今天小編就為大家分享一篇python處理multipart/form-data的請求方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2018-12-12
  • Python列表的切片實例講解

    Python列表的切片實例講解

    在本篇文章里小編給大家分享了關(guān)于Python列表的切片的知識點實例,需要的朋友們可以參考下。
    2019-08-08
  • 基于opencv對高空拍攝視頻消抖處理方法

    基于opencv對高空拍攝視頻消抖處理方法

    這篇文章主要介紹了基于opencv對高空拍攝視頻消抖處理,首先對視頻進行抽第一幀與最后一幀,為什么抽取兩幀?這樣做的主要目的是,我們在做幀對齊時,使用幀中靜態(tài)物的關(guān)鍵點做對齊,需要的朋友可以參考下
    2022-10-10
  • tensorflow 獲取checkpoint中的變量列表實例

    tensorflow 獲取checkpoint中的變量列表實例

    今天小編就為大家分享一篇tensorflow 獲取checkpoint中的變量列表實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2020-02-02
  • python 生成器和迭代器的原理解析

    python 生成器和迭代器的原理解析

    這篇文章主要介紹了python 生成器和迭代器的原理解析,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下
    2019-10-10
  • pytorch如何自定義數(shù)據(jù)集

    pytorch如何自定義數(shù)據(jù)集

    這篇文章主要介紹了pytorch自定義數(shù)據(jù)集,在識別手寫數(shù)字的例子中,數(shù)據(jù)集是直接下載的,但如果我們自己收集了一些數(shù)據(jù),存在電腦文件夾里,我們該如何把這些數(shù)據(jù)變?yōu)榭梢栽赑yTorch框架下進行神經(jīng)網(wǎng)絡(luò)訓(xùn)練的數(shù)據(jù)集呢,即如何自定義數(shù)據(jù)集呢,需要的朋友可以參考下
    2024-01-01
  • python實現(xiàn)skywalking的trace模塊過濾和報警(實例代碼)

    python實現(xiàn)skywalking的trace模塊過濾和報警(實例代碼)

    Skywalking可以對鏈路追蹤到數(shù)據(jù)進行告警規(guī)則配置,例如響應(yīng)時間、響應(yīng)百分比等。發(fā)送警告通過調(diào)用webhook接口完成。webhook接口用戶可以自定義。本文給大家介紹python實現(xiàn)skywalking的trace模塊過濾和報警,感興趣的朋友跟隨小編一起看看吧
    2021-12-12
  • python基礎(chǔ)之文件的備份以及定位

    python基礎(chǔ)之文件的備份以及定位

    這篇文章主要介紹了python文件的備份以及定位,實例分析了Python中返回一個返回值與多個返回值的方法,需要的朋友可以參考下
    2021-10-10

最新評論