PyG搭建GCN需要準備的數據格式
前言
有關GCN的原理可以參考:GCN圖卷積神經網絡原理
一開始是打算手寫一下GCN,畢竟原理也不是很難,但想了想還是直接調包吧。在使用各種深度學習框架時我們首先需要知道的是框架內的數據集結構,因此這篇文章主要講講PyG中的數據結構。
1. PyG數據集
原始論文中使用的數據集:
本篇文章使用Citeseer網絡。Citeseer網絡是一個引文網絡,節(jié)點為論文,一共3327篇論文。論文一共分為六類:Agents、AI(人工智能)、DB(數據庫)、IR(信息檢索)、ML(機器語言)和HCI。如果兩篇論文間存在引用關系,那么它們之間就存在鏈接關系。
使用PyG加載數據集:
data = Planetoid(root='/data/CiteSeer', name='CiteSeer') print(len(data))
輸出:
1
CiteSeer中只有一個網絡,然后我們輸出一下這個網絡:
data = data[0] print(data) print(data.is_directed())
輸出:
Data(x=[3327, 3703], edge_index=[2, 9104], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327]) False
x=[3327, 3703]。表示一共有3327個節(jié)點,然后節(jié)點的特征維度為3703,這里實際上是去除停用詞和在文檔中出現頻率小于10次的詞,整理得到3703個唯一詞。
edge_index=[2, 9104],表示一共9104條edge。數據一共兩行,每一行都表示節(jié)點編號。
輸出一下data.y:
tensor([3, 1, 5, ..., 3, 1, 5])tensor([3, 1, 5, ..., 3, 1, 5])
data.y表示節(jié)點的標簽編號,比如3表示該篇論文屬于第3類。
輸出data.train_mask:
tensor([ True, True, True, ..., False, False, False])
data.train_mask的長度和y的長度一致,如果某個位置為True就表示該樣本為訓練樣本。val_mask和test_mask類似,分別表示驗證集和訓練集。
比如我們輸出:
print(data.y[data.test_mask])
結果為:
tensor([4, 5, 4, 4, 4, 1, 4, 2, 3, 3, 3, 3, 2, 3, 3, 4, 2, 0, 1, 2, 0, 3, 3, 4, 2, 4, 0, 4, 3, 3, 3, 5, 4, 5, 4, 5, 1, 1, 3, 3, 3, 3, 3, 1, 2, 3, 3, 3, 1, 2, 2, 3, 3, 1, 5, 5, 5, 3, 2, 3, 3, 3, 3, 3, 3, 3, 5, 1, 3, 1, 1, 4, 1, 3, 3, 1, 3, 3, 2, 4, 3, 3, 3, 1, 2, 2, 2, 3, 5, 2, 1, 3, 2, 2, 2, 4, 3, 3, 4, 0, 3, 1, 2, 2, 2, 2, 3, 2, 2, 2, 1, 1, 5, 2, 2, 1, 2, 4, 3, 1, 1, 3, 2, 3, 4, 3, 3, 4, 4, 3, 2, 2, 1, 3, 4, 4, 4, 4, 4, 4, 5, 0, 3, 1, 1, 3, 1, 3, 1, 3, 4, 4, 3, 2, 3, 5, 3, 3, 3, 4, 2, 2, 2, 5, 3, 1, 0, 3, 2, 5, 2, 3, 2, 4, 2, 2, 2, 0, 5, 1, 3, 4, 4, 4, 1, 1, 5, 1, 2, 0, 1, 0, 2, 2, 3, 3, 3, 3, 5, 4, 4, 3, 1, 1, 2, 1, 2, 2, 2, 2, 5, 0, 1, 2, 2, 4, 0, 4, 1, 1, 2, 3, 1, 1, 2, 3, 3, 5, 2, 5, 5, 3, 1, 0, 5, 5, 5, 5, 3, 3, 3, 0, 4, 5, 3, 4, 5, 4, 5, 2, 0, 5, 5, 5, 1, 1, 3, 1, 2, 2, 2, 3, 2, 4, 5, 3, 3, 1, 3, 1, 2, 2, 1, 3, 1, 3, 1, 2, 1, 2, 1, 2, 2, 2, 2, 5, 4, 4, 5, 0, 3, 4, 5, 4, 4, 4, 4, 4, 0, 0, 1, 4, 1, 1, 5, 0, 2, 2, 3, 3, 2, 2, 0, 0, 3, 2, 4, 1, 1, 0, 0, 1, 2, 2, 2, 2, 2, 0, 4, 0, 1, 4, 1, 1, 2, 2, 3, 3, 1, 3, 2, 4, 4, 0, 0, 3, 4, 4, 2, 2, 2, 5, 5, 2, 5, 5, 5, 5, 4, 0, 2, 2, 0, 2, 4, 5, 4, 0, 3, 3, 5, 3, 3, 4, 2, 1, 5, 5, 0, 1, 3, 3, 3, 5, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 2, 0, 2, 2, 2, 2, 4, 3, 3, 5, 5, 4, 5, 2, 4, 4, 4, 5, 5, 4, 2, 2, 3, 3, 4, 4, 3, 1, 3, 2, 0, 5, 5, 5, 3, 4, 1, 4, 0, 5, 5, 0, 3, 0, 2, 3, 5, 3, 4, 2, 2, 3, 5, 1, 5, 3, 4, 5, 5, 2, 2, 4, 3, 3, 3, 3, 2, 2, 2, 2, 2, 3, 0, 0, 5, 1, 2, 3, 3, 1, 3, 2, 4, 3, 1, 3, 3, 3, 3, 3, 1, 0, 5, 4, 4, 1, 1, 3, 4, 4, 4, 4, 5, 4, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 1, 4, 0, 1, 4, 4, 4, 1, 2, 1, 5, 5, 2, 4, 4, 2, 2, 3, 1, 1, 0, 0, 2, 1, 0, 1, 5, 1, 2, 2, 3, 2, 0, 0, 3, 3, 3, 2, 2, 2, 1, 1, 1, 3, 3, 3, 5, 3, 5, 2, 3, 2, 3, 1, 5, 2, 2, 3, 3, 3, 1, 1, 1, 3, 3, 3, 3, 4, 4, 1, 4, 4, 1, 3, 3, 1, 0, 3, 5, 4, 4, 2, 4, 1, 0, 3, 1, 4, 1, 4, 4, 0, 5, 3, 2, 2, 2, 5, 5, 0, 4, 4, 1, 2, 2, 3, 3, 3, 5, 5, 5, 1, 5, 1, 4, 3, 1, 5, 5, 4, 4, 2, 3, 1, 0, 0, 5, 3, 1, 2, 1, 4, 1, 4, 1, 2, 2, 5, 1, 2, 1, 4, 5, 5, 1, 4, 5, 5, 1, 1, 5, 5, 3, 1, 0, 0, 1, 0, 0, 2, 0, 4, 3, 4, 3, 3, 1, 2, 3, 5, 3, 5, 5, 5, 5, 5, 3, 4, 4, 5, 4, 2, 2, 5, 1, 4, 4, 4, 3, 1, 5, 3, 1, 3, 4, 2, 2, 4, 2, 1, 5, 2, 2, 5, 5, 3, 3, 4, 1, 1, 2, 5, 3, 4, 4, 4, 5, 5, 1, 5, 5, 1, 5, 5, 1, 1, 1, 4, 2, 3, 5, 4, 1, 1, 4, 5, 2, 3, 1, 2, 1, 4, 1, 4, 1, 1, 1, 0, 0, 1, 5, 0, 2, 1, 1, 5, 1, 1, 3, 2, 3, 3, 1, 1, 2, 3, 2, 3, 5, 5, 5, 5, 5, 5, 5, 5, 5, 3, 3, 5, 2, 2, 3, 4, 4, 4, 4, 0, 3, 0, 3, 4, 1, 1, 3, 3, 0, 4, 5, 0, 0, 0, 2, 1, 3, 4, 5, 2, 1, 1, 3, 3, 4, 4, 4, 2, 2, 1, 5, 4, 0, 5, 5, 4, 3, 4, 5, 0, 3, 0, 3, 4, 4, 3, 3, 3, 3, 3, 3, 3, 5, 2, 0, 0, 1, 0, 0, 0, 3, 1, 5, 3, 2, 3, 5, 3, 3, 3, 1, 5, 5, 5, 5, 1, 2, 1, 4, 5, 4, 3, 3, 5, 5, 1, 4, 2, 5, 4, 1, 4, 4, 4, 4, 5, 5, 4, 3, 4, 3, 5, 3, 3, 1, 1, 0, 4, 4, 3, 1, 1, 1, 1, 3, 3, 3, 4, 3, 1, 4, 1, 1, 3, 5, 5, 5, 4, 4, 1, 3, 1, 4, 3, 3, 3, 1, 2, 2, 5, 3, 2, 5, 1, 3, 3, 5, 5, 4, 0, 3, 5, 5, 5, 1, 2, 2, 4, 1, 4, 5, 5, 5, 4, 5, 2, 1, 5, 4, 4, 0, 3, 5, 4, 1, 3, 3, 5, 4, 2, 1, 0, 1, 3, 2, 4, 3, 2, 4, 4, 1, 1, 0, 3, 3, 3, 1, 5])
可以發(fā)現,我們輸出的是測試集的內容。
那么很顯然,如果我們最終得到了預測值,我們就可以通過以下代碼來計算分類的正確數:
correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
模型輸出的pred實際上包含了所有節(jié)點的預測值,而我們只需要取測試集中的內容,即:
pred[data.test_mask]
然后再與data.y[data.test_mask]進行比較,最后計算二者對應位置相等的個數即可。
2. 構造數據集
如果我們需要的數據集在PyG中沒有,我們就需要自己手動構造數據集。
例如對于一個無向圖,我們知道了其節(jié)點特征矩陣x:
x = torch.tensor([[-1, 1], [0, 1], [1, 3]], dtype=torch.float)
一共3個節(jié)點,每個節(jié)點具有兩個特征。
然后我們知道了節(jié)點間的鄰接關系:
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
一共4條邊,第一條邊為0->1,第2條邊為1->0。
然后我們就可以構造數據集:
data = Data(x=x, edge_index=edge_index)
有關GCN的實現放在下一篇文章!
以上就是PyG搭建GCN需要準備的數據格式的詳細內容,更多關于PyG搭建GCN數據格式的資料請關注腳本之家其它相關文章!
相關文章
numpy中實現ndarray數組返回符合特定條件的索引方法
下面小編就為大家分享一篇numpy中實現ndarray數組返回符合特定條件的索引方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-04-04