Pytorch+PyG實(shí)現(xiàn)EdgeCNN過(guò)程示例詳解
1.EdgeCNN簡(jiǎn)介
EdgeCNN是一種用于圖像點(diǎn)云處理的卷積神經(jīng)網(wǎng)絡(luò)(Convolutional Neural Network,CNN)模型。與傳統(tǒng)的CNN僅能處理圖片二維數(shù)據(jù)不同,EdgeCNN可以對(duì)三維點(diǎn)云中每個(gè)點(diǎn)周?chē)木植苦徲蜻M(jìn)行操作,并適用于物體識(shí)別、深度估計(jì)、自動(dòng)駕駛等多項(xiàng)任務(wù)。
2. 實(shí)現(xiàn)步驟
2.1 數(shù)據(jù)準(zhǔn)備
在本實(shí)驗(yàn)中,我們使用了一個(gè)包含4萬(wàn)個(gè)點(diǎn)云的數(shù)據(jù)集ModelNet10,作為示例。與其它標(biāo)準(zhǔn)圖像數(shù)據(jù)集不同的是,這個(gè)數(shù)據(jù)集中圖形的構(gòu)成量非常大,而且各圖之間結(jié)構(gòu)差異很大,因此需要進(jìn)行大量的預(yù)處理工作。
# 導(dǎo)入模型數(shù)據(jù)集 from torch_geometric.datasets import ModelNet # 加載ModelNet數(shù)據(jù)集 dataset = ModelNet(root='./modelnet', name='10') data = dataset[0] # 定義超級(jí)參數(shù) num_points = 1024 batch_size = 32 train_dataset_size = 8000 # 將數(shù)據(jù)集分割成訓(xùn)練、驗(yàn)證及測(cè)試三個(gè)數(shù)據(jù)集 train_dataset = data[0:train_dataset_size] val_dataset = data[train_dataset_size: 9000] test_dataset = data[9000:] # 定義數(shù)據(jù)加載批處理器 train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
通過(guò)上述代碼,我們先是導(dǎo)入ModelNet數(shù)據(jù)集并將其分割成訓(xùn)練、驗(yàn)證及測(cè)試三個(gè)數(shù)據(jù)集,并創(chuàng)建了數(shù)據(jù)加載批處理器,以便于在訓(xùn)練過(guò)程中對(duì)這些數(shù)據(jù)進(jìn)行有效的處理。
2.2 實(shí)現(xiàn)模型
在定義EdgeCNN模型時(shí),我們需要根據(jù)圖像點(diǎn)云經(jīng)常使用的架構(gòu)定義網(wǎng)絡(luò)結(jié)構(gòu)。同時(shí),在實(shí)現(xiàn)卷積操作時(shí)應(yīng)引入相應(yīng)的鄰域信息,來(lái)使得網(wǎng)絡(luò)能夠?qū)W習(xí)到系統(tǒng)中附近點(diǎn)之間的關(guān)系。
from torch.nn import Sequential as Seq, Linear as Lin, ReLU from torch_geometric.nn import EdgeConv, global_max_pool class EdgeCNN(torch.nn.Module): def __init__(self, dataset): super(EdgeCNN, self).__init__() # 定義基礎(chǔ)參數(shù) self.input_dim = dataset.num_features self.output_dim = dataset.num_classes self.num_points = num_points # 定義模型結(jié)構(gòu) self.conv1 = EdgeConv(Seq(Lin(self.input_dim, 32), ReLU())) self.conv2 = EdgeConv(Seq(Lin(32, 64), ReLU())) self.conv3 = EdgeConv(Seq(Lin(64, 128), ReLU())) self.conv4 = EdgeConv(Seq(Lin(128, 256), ReLU())) self.fc1 = torch.nn.Linear(256, 1024) self.fc2 = torch.nn.Linear(1024, self.output_dim) def forward(self, pos, batch): # 構(gòu)造圖 edge_index = radius_graph(pos, r=0.6, batch=batch, loop=False) # 第一層CNN模型的卷積 + 池化處理 x = F.relu(self.conv1(x=pos, edge_index=edge_index)) x = global_max_pool(x, batch) # 第二層CNN模型的卷積 + 池化處理 edge_index = radius_graph(x, r=0.9, batch=batch, loop=False) x = F.relu(self.conv2(x=x, edge_index=edge_index)) x = global_max_pool(x, batch) # 第三層CNN模型的卷積 + 池化處理 edge_index = radius_graph(x, r=1.2, batch=batch, loop=False) x = F.relu(self.conv3(x=x, edge_index=edge_index)) x = global_max_pool(x, batch) # 第四層CNN模型的卷積 + 池化處理 edge_index = radius_graph(x, r=1.5, batch=batch, loop=False) x = F.relu(self.conv4(x=x, edge_index=edge_index)) # 定義全連接網(wǎng)絡(luò) x = global_max_pool(x, batch) x = F.relu(self.fc1(x)) x = self.fc2(x) return F.log_softmax(x, dim=-1)
在上述代碼中,實(shí)現(xiàn)了基于EdgeCNN的模型的各個(gè)卷積層和全連接層,并使用radius_graph
等函數(shù)將局部區(qū)域問(wèn)題歸約到定義的卷積核檢測(cè)范圍之內(nèi),以便更好地對(duì)點(diǎn)進(jìn)行分析和特征提取。最后結(jié)合全連接層輸出一個(gè)維度為類(lèi)別數(shù)的向量,并通過(guò)softmax函數(shù)來(lái)計(jì)算損失。
2.3 模型訓(xùn)練
在定義好EdgeCNN網(wǎng)絡(luò)結(jié)構(gòu)之后,我們還需要指定合適的優(yōu)化器、損失函數(shù),并控制訓(xùn)練輪數(shù)、批大小與學(xué)習(xí)率等超參數(shù)。同時(shí)也需要記錄大量日志信息,方便后期跟蹤及管理。
# 定義訓(xùn)練計(jì)劃,包括損失函數(shù)、優(yōu)化器及迭代次數(shù)等 train_epochs = 50 learning_rate = 0.01 criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(edge_cnn.parameters(), lr=learning_rate) losses_per_epoch = [] accuracies_per_epoch = [] for epoch in range(train_epochs): running_loss = 0.0 running_corrects = 0.0 count = 0.0 for samples in train_loader: optimizer.zero_grad() pos, batch, label = samples.pos, samples.batch, samples.y.to(torch.long) out = edge_cnn(pos, batch) loss = criterion(out, label) loss.backward() optimizer.step() running_loss += loss.item() / len(train_dataset) running_corrects += torch.sum(torch.argmax(out, dim=1) == label).item() / len(train_dataset) count += 1 losses_per_epoch.append(running_loss) accuracies_per_epoch.append(running_corrects) if (epoch + 1) % 5 == 0: print("Train Epoch {}/{} Loss {:.4f} Accuracy {:.4f}".format( epoch + 1, train_epochs, running_loss, running_corrects))
在訓(xùn)練過(guò)程中,我們遍歷每個(gè)batch,通過(guò)反向傳播算法進(jìn)行優(yōu)化,并更新loss及accuracy輸出。同時(shí),為了方便可視化與記錄,需要將訓(xùn)練過(guò)程中的loss和accuracy輸出到相應(yīng)的容器中,以便后期進(jìn)行分析和處理。
以上就是Pytorch+PyG實(shí)現(xiàn)EdgeCNN過(guò)程示例詳解的詳細(xì)內(nèi)容,更多關(guān)于Pytorch PyG實(shí)現(xiàn)EdgeCNN的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
- 詳解如何使用Pytorch進(jìn)行多卡訓(xùn)練
- pytorch DataLoaderj基本使用方法詳解
- 詳解PyTorch預(yù)定義數(shù)據(jù)集類(lèi)datasets.ImageFolder使用方法
- 詳解Pytorch+PyG實(shí)現(xiàn)GCN過(guò)程示例
- 詳解Pytorch+PyG實(shí)現(xiàn)GAT過(guò)程示例
- Pytorch+PyG實(shí)現(xiàn)GIN過(guò)程示例詳解
- Pytorch+PyG實(shí)現(xiàn)GraphSAGE過(guò)程示例詳解
- PyTorch常用函數(shù)torch.cat()中dim參數(shù)使用說(shuō)明
相關(guān)文章
總結(jié)Pyinstaller的坑及終極解決方法(小結(jié))
這篇文章主要介紹了總結(jié)Pyinstaller的坑及終極解決方法,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-09-09python logging.info在終端沒(méi)輸出的解決
這篇文章主要介紹了python logging.info在終端沒(méi)輸出的解決,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-05-05推薦系統(tǒng)MostPopular算法的Python實(shí)現(xiàn)方式
這篇文章主要介紹了推薦系統(tǒng)MostPopular算法的Python實(shí)現(xiàn)方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2022-07-07Python cookbook(字符串與文本)在字符串的開(kāi)頭或結(jié)尾處進(jìn)行文本匹配操作
這篇文章主要介紹了Python cookbook(字符串與文本)在字符串的開(kāi)頭或結(jié)尾處進(jìn)行文本匹配操作,涉及Python使用str.startswith()和str.endswith()方法針對(duì)字符串開(kāi)始或結(jié)尾處特定文本匹配操作相關(guān)實(shí)現(xiàn)技巧,需要的朋友可以參考下2018-04-04基于python實(shí)現(xiàn)弱密碼檢測(cè)工具
Python中一個(gè)強(qiáng)大的加密模塊,提供了許多常見(jiàn)的加密算法和工具,本文我們將使用Python編寫(xiě)一個(gè)弱密碼檢測(cè)工具,感興趣的小伙伴可以了解一下2024-01-01Jupyter notebook 遠(yuǎn)程配置及SSL加密教程
這篇文章主要介紹了Jupyter notebook 遠(yuǎn)程配置及SSL加密教程,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-04-04Python爬蟲(chóng)中urllib庫(kù)的進(jìn)階學(xué)習(xí)
本篇文章主要介紹了Python爬蟲(chóng)中urllib庫(kù)的進(jìn)階學(xué)習(xí)內(nèi)容,對(duì)此有興趣的朋友趕緊學(xué)習(xí)分享下。2018-01-01python自動(dòng)化測(cè)試selenium核心技術(shù)三種等待方式詳解
這篇文章主要為大家介紹了python自動(dòng)化測(cè)試selenium的核心技術(shù)三種等待方式示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步早日升職加薪2021-11-11