欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Pytorch?PyG實現(xiàn)EdgePool圖分類

 更新時間:2023年04月21日 09:55:27   作者:實力  
這篇文章主要為大家介紹了Pytorch?PyG實現(xiàn)EdgePool圖分類示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪

EdgePool簡介

EdgePool是一種用于圖分類的卷積神經(jīng)網(wǎng)絡(Convolutional Neural Network,CNN)模型。其主要思想是通過 edge pooling 上下采樣優(yōu)化圖像大小,減少空間復雜度,提高分類性能。

實現(xiàn)步驟

 數(shù)據(jù)準備

一般來講,在構建較大規(guī)模數(shù)據(jù)集時,我們都需要對數(shù)據(jù)進行規(guī)范、歸一和清洗處理,以便后續(xù)語義分析或深度學習操作。而在圖像數(shù)據(jù)集中,則需使用特定的框架或工具庫完成。

# 導入MNIST數(shù)據(jù)集
from torch_geometric.datasets import MNISTSuperpixels
# 加載數(shù)據(jù)、劃分訓練集和測試集
dataset = MNISTSuperpixels(root='./mnist', transform=Compose([ToTensor(), NormalizeMeanStd()]))
data = dataset[0]
# 定義超級參數(shù)
num_features = dataset.num_features
num_classes = dataset.num_classes
# 構建訓練集和測試集索引文件
train_mask = torch.zeros(data.num_nodes, dtype=torch.uint8)
train_mask[:60000] = 1
test_mask = torch.zeros(data.num_nodes, dtype=torch.uint8)
test_mask[60000:] = 1
# 創(chuàng)建數(shù)據(jù)加載器
train_loader = DataLoader(data[train_mask], batch_size=32, shuffle=True)
test_loader = DataLoader(data[test_mask], batch_size=32, shuffle=False)

實現(xiàn)模型

在定義EdgePool模型時,我們需要重新考慮網(wǎng)絡結構中的上下采樣操作,以便讓整個網(wǎng)絡擁有更強大的表達能力,從而學習到更復雜的關系。

from torch.nn import Linear
from torch_geometric.nn import EdgePooling
class EdgePool(torch.nn.Module):
    def __init__(self, dataset):
        super(EdgePool, self).__init__()
        # 定義輸入與輸出維度數(shù)
        self.input_dim = dataset.num_features
        self.hidden_dim = 128
        self.output_dim = 10
        # 定義卷積層、歸一化層和pooling層等
        self.conv1 = GCNConv(self.input_dim, self.hidden_dim)
        self.norm1 = BatchNorm1d(self.hidden_dim)
        self.pool1 = EdgePooling(self.hidden_dim)
        self.conv2 = GCNConv(self.hidden_dim, self.hidden_dim)
        self.norm2 = BatchNorm1d(self.hidden_dim)
        self.pool2 = EdgePooling(self.hidden_dim)
        self.conv3 = GCNConv(self.hidden_dim, self.hidden_dim)
        self.norm3 = BatchNorm1d(self.hidden_dim)
        self.pool3 = EdgePooling(self.hidden_dim)
        self.lin = torch.nn.Linear(self.hidden_dim, self.output_dim)
    def forward(self, x, edge_index, batch):
        x = F.relu(self.norm1(self.conv1(x, edge_index)))
        x, edge_index, _, batch, _ = self.pool1(x, edge_index, None, batch)
        x = F.relu(self.norm2(self.conv2(x, edge_index)))
        x, edge_index, _, batch, _ = self.pool2(x, edge_index, None, batch)
        x = F.relu(self.norm3(self.conv3(x, edge_index)))
        x, edge_index, _, batch, _ = self.pool3(x, edge_index, None, batch)
        x = global_mean_pool(x, batch)
        x = self.lin(x)
        return x

在上述代碼中,我們使用了不同的卷積層、池化層和全連接層等神經(jīng)網(wǎng)絡功能塊來構建EdgePool模型。其中,每個 GCNConv 層被保持為128的隱藏尺寸;BatchNorm1d是一種旨在提高收斂速度并增強網(wǎng)絡泛化能力的方法;EdgePooling是一種在 GraphConvolution 上附加的特殊類別,它將給定圖下采樣至其一半的大小,并返回縮小后的圖與兩個跟蹤full-graph-to-pool雙向映射(keep and senders)的 edge index(edgendarcs)。 在這種情況下傳遞 None ,表明 batch 未更改。

模型訓練

在定義好 EdgePool 網(wǎng)絡結構之后,需要指定合適的優(yōu)化器、損失函數(shù),并控制訓練輪數(shù)、批量大小與學習率等超參數(shù)。同時還要記錄大量日志信息,方便后期跟蹤和駕駛員。

# 定義訓練計劃,包括損失函數(shù)、優(yōu)化器及迭代次數(shù)等
train_epochs = 50
learning_rate = 0.01
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(edge_pool.parameters(), lr=learning_rate)
losses_per_epoch = []
accuracies_per_epoch = []
for epoch in range(train_epochs):
    running_loss = 0.0
    running_corrects = 0.0
    count = 0.0
    for samples in train_loader:
        optimizer.zero_grad()
        x, edge_index, batch = samples.x, samples.edge_index, samples.batch
        out = edge_pool(x, edge_index, batch)
        label = samples.y
        loss = criterion(out, label)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() / len(train_loader.dataset)
        pred = out.argmax(dim=1)
        running_corrects += pred.eq(label).sum().item() / len(train_loader.dataset)
        count += 1
    losses_per_epoch.append(running_loss)
    accuracies_per_epoch.append(running_corrects)
    if (epoch + 1) % 10 == 0:
        print("Train Epoch {}/{} Loss {:.4f} Accuracy {:.4f}".format(
            epoch + 1, train_epochs, running_loss, running_corrects))

在訓練過程中,我們遍歷了每個批次的數(shù)據(jù),并通過反向傳播算法進行優(yōu)化,并更新了 loss 和 accuracy 輸出值。 同時方便可視化與記錄,需要將訓練過程中的 loss 和 accuracy 輸出到相應的容器中,以便后期進行分析和處理。

以上就是Pytorch PyG實現(xiàn)EdgePool圖分類的詳細內容,更多關于Pytorch PyG EdgePool圖分類的資料請關注腳本之家其它相關文章!

相關文章

  • Python面向對象編程之繼承與多態(tài)詳解

    Python面向對象編程之繼承與多態(tài)詳解

    這篇文章主要介紹了Python面向對象編程之繼承與多態(tài),結合實例形式詳細分析了Python面向對象編程中繼承與多態(tài)的概念、使用方法及相關注意事項,需要的朋友可以參考下
    2018-01-01
  • Python?NumPy教程之數(shù)組的基本操作詳解

    Python?NumPy教程之數(shù)組的基本操作詳解

    Numpy?中的數(shù)組是一個元素表(通常是數(shù)字),所有元素類型相同,由正整數(shù)元組索引。本文將通過一些示例詳細講一下NumPy中數(shù)組的一些基本操作,需要的可以參考一下
    2022-08-08
  • Python調用訊飛語音合成API接口來實現(xiàn)文字轉語音

    Python調用訊飛語音合成API接口來實現(xiàn)文字轉語音

    這篇文章主要為大家介紹了Python調用訊飛語音合成API接口來實現(xiàn)文字轉語音方法示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪
    2023-04-04
  • K近鄰法(KNN)相關知識總結以及如何用python實現(xiàn)

    K近鄰法(KNN)相關知識總結以及如何用python實現(xiàn)

    這篇文章主要介紹了K近鄰法(KNN)相關知識總結以及如何用python實現(xiàn),幫助大家更好的利用python實現(xiàn)機器學習,感興趣的朋友可以了解下
    2021-01-01
  • python 圖像增強算法實現(xiàn)詳解

    python 圖像增強算法實現(xiàn)詳解

    這篇文章主要介紹了python 圖像增強算法實現(xiàn)詳解,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧
    2021-01-01
  • python?魔法方法之?__?slots?__的實現(xiàn)

    python?魔法方法之?__?slots?__的實現(xiàn)

    本文主要介紹了python?魔法方法之?__?slots?__的實現(xiàn),文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧
    2023-03-03
  • Pandas中字符串和時間轉換與格式化的實現(xiàn)

    Pandas中字符串和時間轉換與格式化的實現(xiàn)

    本文主要介紹了Pandas中字符串和時間轉換與格式化的實現(xiàn),文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧
    2023-01-01
  • Python獲取當前路徑實現(xiàn)代碼

    Python獲取當前路徑實現(xiàn)代碼

    這篇文章主要介紹了 Python獲取當前路徑實現(xiàn)代碼的相關資料,需要的朋友可以參考下
    2017-05-05
  • 詳解python中TCP協(xié)議中的粘包問題

    詳解python中TCP協(xié)議中的粘包問題

    這篇文章主要介紹了python中TCP協(xié)議中的粘包問題,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧
    2019-03-03
  • python數(shù)據(jù)分析必會的Pandas技巧匯總

    python數(shù)據(jù)分析必會的Pandas技巧匯總

    用Python做數(shù)據(jù)分析光是掌握numpy和matplotlib可不夠,numpy雖然能夠幫我們處理處理數(shù)值型數(shù)據(jù),但很多時候,還有字符串,還有時間序列等,比如:我們通過爬蟲獲取到了存儲在數(shù)據(jù)庫中的數(shù)據(jù),一些Pandas必會的用法,讓你的數(shù)據(jù)分析水平更上一層樓
    2021-08-08

最新評論