Pytorch-Geometric中的Message?Passing使用及說(shuō)明
Pytorch-Geometric中Message Passing使用
圖中的卷積計(jì)算通常被稱(chēng)為鄰域聚合或者消息傳遞 (neighborhood aggregation or message passing).
定義
為節(jié)點(diǎn)i在第(k−1)層的特征,ej,i表示節(jié)點(diǎn)j到節(jié)點(diǎn)i的邊特征,在GNN中消息傳遞可以表示為

其中 □ 表示具有置換不變性并且可微的函數(shù),例如 sum, mean, max 等, γ 和 ? 表示可微函數(shù)。
在 PyTorch Gemetric 中,所有卷積算子都是由 MessagePassing 類(lèi)派生而來(lái),理解 MessagePasing 有助于我們理解 PyG 中消息傳遞的計(jì)算方式和編寫(xiě)自定義的卷積。
在自定義卷積中,用戶(hù)只需定義消息傳遞函數(shù) ? message(), 節(jié)點(diǎn)更新函數(shù) γ update() 以及聚合方式 aggr='add', aggr='mean' 或則 aggr=max.
具體函數(shù)說(shuō)明如下
MessagePassing(aggr='add', flow='source_to_target', node_dim=-2)定義聚合計(jì)算的方式 ('add', 'mean'ormax) 以及消息的傳遞方向 (source_to_targetortarget_to_source). 在 PyG 中,中心節(jié)點(diǎn)為目標(biāo) target,鄰域節(jié)點(diǎn)為源 source.node_dim為消息聚合的維度MessagePassing.propagate(edge_index, size=None, **kwargs):該函數(shù)接受邊信息edge_index和其他額外的數(shù)據(jù)來(lái)執(zhí)行消息傳遞并更新節(jié)點(diǎn)嵌入MessagePassing.message(...):該函數(shù)的作用是計(jì)算節(jié)點(diǎn)消息,就是公式中的函數(shù) ? \phi ? . 如果flow='source_to_target',那么消息將由鄰域節(jié)點(diǎn) j j j 傳向中心節(jié)點(diǎn) i i i ;如果flow='target_to_source',消息則由中心節(jié)點(diǎn) i i i 傳向鄰域節(jié)點(diǎn) j j j . 傳入?yún)?shù)的節(jié)點(diǎn)類(lèi)型可以通過(guò)變量名后綴來(lái)確定,例如中心節(jié)點(diǎn)嵌入變量一般以_i為結(jié)尾,鄰域節(jié)點(diǎn)嵌入變量以x_j為結(jié)尾MessagePassing.update(arr_out, ...):該函數(shù)為節(jié)點(diǎn)嵌入的更新函數(shù) γ \gamma γ , 輸入?yún)?shù)為聚合函數(shù)MessagePassing.aggregate計(jì)算的結(jié)果
為了更好的理解 PyG 中 MessagePassing 的計(jì)算過(guò)程,我們來(lái)分析一下源代碼。
class MessagePassing(torch.nn.Module):
special_args: Set[str] = {
'edge_index', 'adj_t', 'edge_index_i', 'edge_index_j', 'size',
'size_i', 'size_j', 'ptr', 'index', 'dim_size'
}
def __init__(self, aggr: Optional[str] = "add",
flow: str = "source_to_target", node_dim: int = -2):
super(MessagePassing, self).__init__()
self.aggr = aggr
assert self.aggr in ['add', 'mean', 'max', None]
self.flow = flow
assert self.flow in ['source_to_target', 'target_to_source']
self.node_dim = node_dim
self.inspector = Inspector(self)
self.inspector.inspect(self.message)
self.inspector.inspect(self.aggregate, pop_first=True)
self.inspector.inspect(self.message_and_aggregate, pop_first=True)
self.inspector.inspect(self.update, pop_first=True)
self.__user_args__ = self.inspector.keys(
['message', 'aggregate', 'update']).difference(self.special_args)
self.__fused_user_args__ = self.inspector.keys(
['message_and_aggregate', 'update']).difference(self.special_args)
# Support for "fused" message passing.
self.fuse = self.inspector.implements('message_and_aggregate')
# Support for GNNExplainer.
self.__explain__ = False
self.__edge_mask__ = None
在初始化函數(shù)中,MessagePassing 定義了一個(gè) Inspector . Inspector 的中文意思是檢查員的意思,這個(gè)類(lèi)的作用就是檢查各個(gè)函數(shù)的輸入?yún)?shù),并保存到 Inspector的參數(shù)列表字典中 Inspector.params中。
如果 message的輸入?yún)?shù)為 x_i, x_j,那么Inspector.params['message']={'x_i': Parameter, 'x_j': Parameter} (注:這里僅作示意,實(shí)際 Inspector.params['message'] 類(lèi)型為 OrderedDict). Inspector.implements 檢查函數(shù)是否實(shí)現(xiàn).
MessagePasing 中最核心的是 propgate 函數(shù),假設(shè)鄰接矩陣 edge_index 的類(lèi)型為 Torch.LongTensor,消息由 edge_index[0] 傳向 edge_index[1] ,代碼實(shí)現(xiàn)如下
def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
# 為了簡(jiǎn)化問(wèn)題,這里不討論 edge_index 為 SparseTensor 的情況,感興趣的可閱讀 PyG 原始代碼
size = self.__check_input__(edge_index, size)
coll_dict = self.__collect__(self.__user_args__, edge_index, size,
kwargs)
msg_kwargs = self.inspector.distribute('message', coll_dict)
out = self.message(**msg_kwargs)
aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
out = self.aggregate(out, **aggr_kwargs)
update_kwargs = self.inspector.distribute('update', coll_dict)
return self.update(out, **update_kwargs)
在這段代碼中,首先是檢查節(jié)點(diǎn)數(shù)量和用戶(hù)自定義的輸入變量,然后依次執(zhí)行 message, aggregate 和 update 函數(shù)。
如果是自定義圖卷積,一般會(huì)重寫(xiě) message 和 update,這一點(diǎn)隨后再以 GCN 為例解釋?zhuān)@里首先來(lái)看一下 aggregate 的實(shí)現(xiàn)
def aggregate(self, inputs: Tensor, index: Tensor,
ptr: Optional[Tensor] = None,
dim_size: Optional[int] = None) -> Tensor:
if ptr is not None:
ptr = expand_left(ptr, dim=self.node_dim, dims=inputs.dim())
return segment_csr(inputs, ptr, reduce=self.aggr)
else:
return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
reduce=self.aggr)
ptr 變量是針對(duì)鄰接矩陣 edge_index 為 SparseTensor的情況,此處暫且不論
inputs為 message計(jì)算得到的消息, index 就是待更新節(jié)點(diǎn)的索引,實(shí)際上就是 edge_index_i. 聚合計(jì)算通過(guò) scatter 函數(shù)實(shí)現(xiàn)。scatter 具體實(shí)現(xiàn)參考鏈接
下面以 GCN 為例,我們來(lái)看一下 MessagePassing 的計(jì)算過(guò)程。
GCN 的計(jì)算公式如下

實(shí)際計(jì)算工程可以分為下面幾步
- 1.在鄰接矩陣中增加自循環(huán),即把鄰接矩陣的對(duì)角線(xiàn)上的元素設(shè)為1
- 2.對(duì)節(jié)點(diǎn)特征矩陣做線(xiàn)性變換
- 3.計(jì)算節(jié)點(diǎn)的歸一化系數(shù),也就是節(jié)點(diǎn)度乘積的開(kāi)方
- 4.對(duì)節(jié)點(diǎn)特征做歸一化處理
- 5.聚合(求和)節(jié)點(diǎn)特征得到新的節(jié)點(diǎn)嵌入
代碼如下
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add') # "Add" aggregation (Step 5).
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]
# Step 1: Add self-loops to the adjacency matrix.
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Step 2: Linearly transform node feature matrix.
x = self.lin(x)
# Step 3: Compute normalization.
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# Step 4-5: Start propagating messages.
return self.propagate(edge_index, x=x, norm=norm)
def message(self, x_j, norm):
# x_j has shape [E, out_channels]
# Step 4: Normalize node features.
return norm.view(-1, 1) * x_j
在 forward 函數(shù)中,首先是給節(jié)點(diǎn)邊增加自循環(huán)。設(shè)輸入變量如下
edge_index = torch.tensor([[0, 0, 2], [1, 2, 3]], dtype=torch.long) x = torch.rand((4, 3)) conv = GCNConv(3, 8)
注意到默認(rèn)消息傳遞方向?yàn)?source_to_target,此時(shí)edge_index[0]=x_j 為 source, edge_index[1]=x_i 為 target.
在 GCN 中,第一步是增加節(jié)點(diǎn)的自循環(huán),add_self_loops 計(jì)算前后變化如下
# before add_self_loops
# edge_index=
tensor([[0, 0, 2],
[1, 2, 3]])
# after add_self_loops
# edge_index=
tensor([[0, 0, 2, 0, 1, 2, 3],
[1, 2, 3, 0, 1, 2, 3]])
# norm=
tensor([0.7071, 0.7071, 0.5000, 1.0000, 0.5000, 0.5000, 0.5000]
此處的 propagate 的輸出參數(shù)由 edge_index, x, norm , edge_index 是 propagete 必須輸入的參數(shù),x, norm 為用戶(hù)自定義參數(shù)。
在 __collect__ 會(huì)根據(jù)變量名稱(chēng)來(lái)收集 message 需要的輸入?yún)?shù)。
在 GCN 中,norm 保持不變,x 將被映射到 x_j ,并且經(jīng)過(guò) __lift__ 函數(shù),其值也會(huì)發(fā)生變化。__lift__ 函數(shù)如下
def __lift__(self, src, edge_index, dim):
if isinstance(edge_index, Tensor):
index = edge_index[dim]
return src.index_select(self.node_dim, index)
在本例中,輸入的特征 shape=[4, 8],經(jīng)過(guò) __lift__ 后,節(jié)點(diǎn)特征 shape=[7, 8] . 經(jīng)過(guò) message 計(jì)算后,就可以執(zhí)行 aggregate 和 update 了。
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python用戶(hù)自定義異常的實(shí)現(xiàn)
這篇文章主要介紹了Python用戶(hù)自定義異常的實(shí)現(xiàn),文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-12-12
python實(shí)現(xiàn)比較類(lèi)的兩個(gè)instance(對(duì)象)是否相等的方法分析
這篇文章主要介紹了python實(shí)現(xiàn)比較類(lèi)的兩個(gè)instance(對(duì)象)是否相等的方法,結(jié)合實(shí)例形式分析了Python判斷類(lèi)的實(shí)例是否相等的判斷操作實(shí)現(xiàn)技巧,需要的朋友可以參考下2019-06-06
你們要的Python繪畫(huà)3D太陽(yáng)系詳細(xì)代碼
這篇文章主要給大家介紹了關(guān)于如何利用Python 繪畫(huà)3D太陽(yáng)系,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2021-10-10
Python中用字符串調(diào)用函數(shù)或方法示例代碼
字符串作為python中常用的數(shù)據(jù)類(lèi)型,掌握字符串的常用方法十分必要。下面這篇文章主要給大家介紹了關(guān)于Python中通過(guò)字符串調(diào)用函數(shù)或方法的相關(guān)資料,需要的朋友可以參考借鑒,下面來(lái)一起看看吧。2017-08-08
python目錄操作之python遍歷文件夾后將結(jié)果存儲(chǔ)為xml
需求是獲取服務(wù)器某個(gè)目錄下的某些類(lèi)型的文件,考慮到服務(wù)器即有Linux、又有Windows,所以寫(xiě)了一個(gè)Python小程序來(lái)完成這項(xiàng)工作,大家參考使用吧2014-01-01
Python?OpenCV實(shí)現(xiàn)3種濾鏡效果實(shí)例
opencv是一個(gè)很強(qiáng)大的庫(kù),支持多個(gè)編程語(yǔ)言,下面這篇文章主要給大家介紹了關(guān)于Python?OpenCV實(shí)現(xiàn)3種濾鏡效果的相關(guān)資料,文中通過(guò)示例代碼介紹的非常詳細(xì),需要的朋友可以參考下2022-04-04
python神經(jīng)網(wǎng)絡(luò)ResNet50模型的復(fù)現(xiàn)詳解
這篇文章主要為大家介紹了python神經(jīng)網(wǎng)絡(luò)ResNet50模型的復(fù)現(xiàn)詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2022-05-05

