Pytorch+PyG實現(xiàn)GraphConv過程示例詳解
GraphConv簡介
GraphConv是一種使用圖形數(shù)據(jù)的卷積神經(jīng)網(wǎng)絡(Convolutional Neural Network, CNN)模型。與傳統(tǒng)的CNN僅能處理圖片二維數(shù)據(jù)不同,GraphConv可以對任意結構的圖進行卷積操作,并適用于基于圖的多項任務。
實現(xiàn)步驟
數(shù)據(jù)準備
在本實驗中,我們使用了一個包含4萬個圖像的數(shù)據(jù)集CIFAR-10,作為示例。與其它標準圖像數(shù)據(jù)集不同的是,在這個數(shù)據(jù)集中圖形的構成量非常大,而且各圖之間結構差異很大,因此需要進行大量的預處理工作。
# 導入cifar-10數(shù)據(jù)集 from torch_geometric.datasets import Planetoid # 加載數(shù)據(jù)、劃分訓練集和測試集 dataset = Planetoid(root='./cifar10', name='Cora') data = dataset[0] # 定義超級參數(shù) num_features = dataset.num_features num_classes = dataset.num_classes # 構建訓練集和測試集索引文件 train_mask = torch.zeros(data.num_nodes, dtype=torch.uint8) train_mask[:800] = 1 test_mask = torch.zeros(data.num_nodes, dtype=torch.uint8) test_mask[800:] = 1 # 創(chuàng)建數(shù)據(jù)加載器 train_loader = DataLoader(data[train_mask], batch_size=32, shuffle=True) test_loader = DataLoader(data[test_mask], batch_size=32, shuffle=False)
通過上述代碼,我們先是導入CIFAR-10數(shù)據(jù)集并將其分割為訓練及測試兩個數(shù)據(jù)集,并創(chuàng)建了相應的數(shù)據(jù)加載器以便于對數(shù)據(jù)進行有效處理。
實現(xiàn)模型
在定義GraphConv模型時,我們需要根據(jù)圖像經(jīng)常使用的架構定義網(wǎng)絡結構。同時,在實現(xiàn)卷積操作時應引入鄰接矩陣(adjacency matrix)和特征矩陣(feature matrix)作為輸入,來使得網(wǎng)絡能夠學習到節(jié)點之間的關系和提取重要特征。
from torch.nn import Linear, ModuleList, ReLU from torch_geometric.nn import GCNConv class GraphConv(torch.nn.Module): def __init__(self, dataset): super(GraphConv, self).__init__() # 定義基礎參數(shù) self.input_dim = dataset.num_features self.output_dim = dataset.num_classes # 定義GCN網(wǎng)絡結構 self.convs = ModuleList() self.convs.append(GCNConv(self.input_dim, 16)) self.convs.append(GCNConv(16, 32)) self.convs.append(GCNConv(32, self.output_dim)) def forward(self, x, edge_index): for conv in self.convs: x = conv(x, edge_index) x = F.relu(x) return F.log_softmax(x, dim=1)
在上述代碼中,我們實現(xiàn)了基于GraphConv的模型的各個卷積層,并使用GCNConv
將鄰接矩陣和特征矩陣作為輸入進行特征提取。最后結合全連接層輸出一個維度為類別數(shù)的向量,并通過softmax函數(shù)來計算損失。
模型訓練
在定義好GraphConv網(wǎng)絡結構之后,我們還需要指定合適的優(yōu)化器、損失函數(shù),并控制訓練輪數(shù)、批大小與學習率等超參數(shù)。同時也需要記錄大量日志信息,方便后期跟蹤及管理。
# 定義訓練計劃,包括損失函數(shù)、優(yōu)化器及迭代次數(shù)等 train_epochs = 200 learning_rate = 0.01 criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(graph_conv.parameters(), lr=learning_rate) losses_per_epoch = [] accuracies_per_epoch = [] for epoch in range(train_epochs): running_loss = 0.0 running_corrects = 0.0 count = 0.0 for samples in train_loader: optimizer.zero_grad() x, edge_index = samples.x, samples.edge_index out = graph_conv(x, edge_index) label = samples.y loss = criterion(out, label) loss.backward() optimizer.step() running_loss += loss.item() / len(train_loader.dataset) pred = out.argmax(dim=1) running_corrects += pred.eq(label).sum().item() / len(train_loader.dataset) count += 1 losses_per_epoch.append(running_loss) accuracies_per_epoch.append(running_corrects) if (epoch + 1) % 20 == 0: print("Train Epoch {}/{} Loss {:.4f} Accuracy {:.4f}".format( epoch + 1, train_epochs, running_loss, running_corrects))
在訓練過程中,我們遍歷每個batch,通過反向傳播算法進行優(yōu)化,并更新loss及accuracy輸出。同時,為了方便可視化與記錄,需要將訓練過程中的loss和accuracy輸出到相應的容器中,以便后期進行分析和處理。
以上就是Pytorch+PyG實現(xiàn)GraphConv過程示例詳解的詳細內容,更多關于Pytorch PyG實現(xiàn)GraphConv的資料請關注腳本之家其它相關文章!
相關文章
Pycharm的Available Packages為空的解決方法
這篇文章主要介紹了Pycharm的Available Packages為空的解決方法,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2020-09-09使用Python搭建服務器公網(wǎng)展示本地電腦文件的操作過程
這篇文章主要介紹了使用Python搭建服務器公網(wǎng)展示本地電腦文件,今天我們就嘗試用python,建立一個簡單的http服務器,用來展示本地電腦上指定的目錄和文件,需要的朋友可以參考下2023-08-08python pprint模塊中print()和pprint()兩者的區(qū)別
這篇文章主要介紹了python pprint模塊中print()和pprint()兩者的區(qū)別,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2020-02-02