詳解使用Pytorch Geometric實現(xiàn)GraphSAGE模型
正文
GraphSAGE是一種用于圖神經(jīng)網(wǎng)絡(luò)中的節(jié)點嵌入學(xué)習(xí)方法。它通過聚合節(jié)點鄰居的信息來生成節(jié)點的低維表示,使節(jié)點表示能夠更好地應(yīng)用于各種下游任務(wù),如節(jié)點分類、鏈路預(yù)測等。
圖構(gòu)建
在使用GraphSAGE對節(jié)點進(jìn)行嵌入學(xué)習(xí)之前,我們需要先將原始數(shù)據(jù)轉(zhuǎn)換為圖結(jié)構(gòu),并將其存儲為Pytorch Tensor格式。例如,我們可以使用networkx庫來構(gòu)建一個簡單的圖:
import networkx as nx G = nx.karate_club_graph()
然后,我們可以使用Pytorch Geometric庫將NetworkX圖轉(zhuǎn)換為Pytorch Tensor格式。首先,我們需要安裝Pytorch Geometric并導(dǎo)入所需的類:
!pip install torch-geometric from torch_geometric.datasets import Planetoid from torch_geometric.transforms import NormalizeFeatures from torch_geometric.utils.convert import from_networkx
接著,我們可以使用from_networkx
函數(shù)將NetworkX圖轉(zhuǎn)換為Pytorch Tensor格式:
data = from_networkx(G)
此時,data
對象包含了關(guān)于節(jié)點、邊及其屬性的信息,例如:
data.edge_index: 2x(#edges)的長整型張量,表示邊的起點和終點
data.x
: n×dn \times dn×d 的浮點型張量,表示每個節(jié)點的特征向量(其中nnn是節(jié)點數(shù)量,ddd是特征維度)
注意,此時的data
對象并未包含鄰居信息。接下來,我們將介紹如何使用Sampler方法采樣節(jié)點鄰居。
Sampler方法
GraphSAGE使用Sampler方法來聚合鄰居信息。在Pytorch Geometric中,可以使用Various Sampling方法來實現(xiàn)Sampler。例如,使用ClusterData方法將圖分成多個子圖,然后對每個子圖進(jìn)行采樣操作。
以下是ClusterData
的使用示例:
from torch_geometric.utils import degree, to_undirected from torch_geometric.transforms import ClusterData # Convert the graph to an undirected graph, so we can aggregate neighbors in both directions. G = to_undirected(G) # Compute the degree of each node. deg = degree(data.edge_index[0], num_nodes=data.num_nodes) # Use METIS algorithm to partition the graph into multiple subgraphs. cluster_data = ClusterData(data, num_parts=2, recursive=False, transform=NormalizeFeatures(), degree=deg)
這里我們將原始圖分成兩個子圖,并對每個子圖進(jìn)行規(guī)范化特征轉(zhuǎn)換。注意,在使用ClusterData方法之前,需要將原始圖轉(zhuǎn)換為無向圖。
另一個常用的Sampler方法是在隨機游動時對鄰居進(jìn)行采樣,這種方法被稱為隨機游走采樣(Random Walk Sampling)。以下是隨機游走采樣的示例代碼:
from torch_geometric.utils import random_walk # Perform random walk sampling to obtain node neighbor samples. walk_length = 20 # The length of random walk trail. num_steps = 4 # The number of nodes to sample from each step. data.batch = None data.edge_index = to_undirected(data.edge_index) # Use undirected edge for random walk. rw_data = random_walk(data.edge_index, walk_length=walk_length, num_steps=num_steps)
這里我們將使用一個長度為20、每個步驟采樣4個鄰居的隨機游走方法。注意,在使用隨機游走方法進(jìn)行采樣之前,需要使用無向邊。
GraphSAGE模型定義
GraphSAGE模型包含3個部分:1)圖卷積層;2)聚合器(Aggregator);3)輸出層。我們將在本節(jié)中介紹如何使用Pytorch實現(xiàn)這些組件。
首先,讓我們定義一個圖卷積層。圖卷積層的輸入是節(jié)點特征矩陣、鄰接矩陣和聚合器,輸出是新的節(jié)點特征矩陣。以下是圖卷積層的代碼實現(xiàn):
import torch.nn.functional as F from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn import global_mean_pool class GraphSageConv(MessagePassing): def __init__(self, in_channels, out_channels, aggr='mean'): super(GraphSageConv, self).__init__(aggr=aggr) self.lin = nn.Linear(in_channels, out_channels) def forward(self, x, edge_index): return self.propagate(edge_index, x=x) def message(self, x_j): return x_j def update(self, aggr_out, x): return F.relu(self.lin(torch.cat([x, aggr_out], dim=1)))
這里我們繼承了MessagePassing
類,并在__init__
函數(shù)中定義了一個全連接層,用于將輸入特征矩陣x
從 dind_{in}din? 維映射到 doutd_{out}dout? 維。在forward
函數(shù)中,我們使用propagate
方法來實現(xiàn)消息傳遞操作;在message
函數(shù)中,我們僅向下游節(jié)點發(fā)送原始特征數(shù)據(jù);在update
函數(shù)中,我們首先對聚合結(jié)果進(jìn)行ReLU非線性變換,然后再通過全連接層進(jìn)行節(jié)點特征的更新。
接下來,讓我們定義一個聚合器。聚合器的輸入是采樣得到的鄰居特征矩陣,輸出是新的節(jié)點嵌入向量。以下是聚合器的代碼實現(xiàn):
class MeanAggregator(nn.Module): def __init__(self, input_dim, output_dim): super(MeanAggregator, self).__init__() self.input_dim = input_dim self.output_dim = output_dim self.lin = nn.Linear(input_dim, output_dim) def forward(self, neigh_mean): out = F.relu(self.lin(neigh_mean)) return out
這里我們定義了一個簡單的均值聚合器,其將鄰居特征矩陣中每列的均值作為節(jié)點嵌入向量,并使用全連接層進(jìn)行維度變換。
最后,讓我們定義整個GraphSage模型。GraphSage模型包含2個圖卷積層和1個輸出層。以下是模型的代碼實現(xiàn):
class GraphSAGE(nn.Module): def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2): super(GraphSAGE, self).__init__() self.conv1 = GraphSageConv(in_channels, hidden_channels) self.aggreg1 = MeanAggregator(hidden_channels, hidden_channels) self.conv2 = GraphSageConv(hidden_channels, out_channels) def forward(self, x, edge_index): x = self.conv1(x, edge_index) x = global_mean_pool(x, edge_index) # Compute global mean over nodes. x = self.aggreg1(x) x = self.conv2(x, edge_index) return x
這里我們定義了一個包含2層GraphSAGE Conv層的神經(jīng)網(wǎng)絡(luò)。在最后一層GraphSAGE Conv層之后,我們使用global_mean_pool
函數(shù)來計算節(jié)點嵌入的全局平均值。注意,在本示例中,我們僅保留了一個輸出節(jié)點,因此輸出矩陣的大小為1。如果需要輸出多個節(jié)點,則需要設(shè)置global_mean_pool
函數(shù)中的參數(shù)。
模型訓(xùn)練與測試
在定義好模型后,我們可以使用Pytorch進(jìn)行模型訓(xùn)練和測試。首先,讓我們定義一個損失函數(shù)和優(yōu)化器:
criterion = nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
這里我們使用交叉熵作為損失函數(shù),并使用Adam優(yōu)化器來更新模型參數(shù)。
接著,我們可以開始訓(xùn)練模型。以下是訓(xùn)練過程的代碼實現(xiàn):
num_epochs = 100 for epoch in range(num_epochs): model.train() optimizer.zero_grad() out = model(data.x, data.edge_index) loss = criterion(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() print('Epoch {:03d}, Loss: {:.4f}'.format(epoch, loss.item()))
這里我們遍歷所有數(shù)據(jù)樣本,計算預(yù)測結(jié)果和真實標(biāo)簽之間的交叉熵?fù)p失,并使用反向傳播來更新權(quán)重。我們在每個epoch結(jié)束后打印出當(dāng)前損失值。
最后,我們可以對模型進(jìn)行測試。以下是測試過程的代碼實現(xiàn):
model.eval() with torch.no_grad(): pred = model(data.x, data.edge_index) pred = pred.argmax(dim=1) acc = (pred[data.test_mask] == data.y[data.test_mask]).sum().item() / data.test_mask.sum().item() print('Test accuracy: {:.4f}'.format(acc))
這里我們使用測試集來計算模型的準(zhǔn)確率。注意,在執(zhí)行model.eval()
后,我們需要使用torch.no_grad()
包裝代碼塊,以禁止梯度計算。
總結(jié)
介紹了如何使用Pytorch Geometric實現(xiàn)GraphSAGE模型,包括構(gòu)建圖、定義Sampler方法、定義模型、訓(xùn)練和測試模型等步驟。GraphSAGE模型是一種常用的節(jié)點嵌入學(xué)習(xí)方法,可以應(yīng)用于各種下游任務(wù)中。
以上就是詳解使用Pytorch Geometric實現(xiàn)GraphSAGE模型的詳細(xì)內(nèi)容,更多關(guān)于Pytorch Geometric GraphSAGE的資料請關(guān)注腳本之家其它相關(guān)文章!
- PyTorch模型轉(zhuǎn)換為ONNX格式實現(xiàn)過程詳解
- 利用Pytorch實現(xiàn)ResNet網(wǎng)絡(luò)構(gòu)建及模型訓(xùn)練
- 詳解利用Pytorch實現(xiàn)ResNet網(wǎng)絡(luò)之評估訓(xùn)練模型
- pytorch模型的保存加載與續(xù)訓(xùn)練詳解
- AMP?Tensor?Cores節(jié)省內(nèi)存PyTorch模型詳解
- 詳解?PyTorch?Lightning模型部署到生產(chǎn)服務(wù)中
- Pytorch模型定義與深度學(xué)習(xí)自查手冊
- 一文詳解如何實現(xiàn)PyTorch模型編譯
相關(guān)文章
解決Python3錯誤:SyntaxError: unexpected EOF while
這篇文章主要介紹了解決Python3錯誤:SyntaxError: unexpected EOF while parsin問題,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2022-07-07python中的數(shù)據(jù)結(jié)構(gòu)比較
這篇文章主要介紹了python中的數(shù)據(jù)結(jié)構(gòu)比較,本文給大家介紹的非常詳細(xì),具有一定的參考借鑒價值,需要的朋友可以參考下2019-05-05python xmind 包使用詳解(其中解決導(dǎo)出的xmind文件 xmind8可以打開 xmind2020及之后版本打
xmind8 可以打開xmind2020 報錯,如何解決這個問題呢?下面小編給大家?guī)砹藀ython xmind 包使用(其中解決導(dǎo)出的xmind文件 xmind8可以打開 xmind2020及之后版本打開報錯問題),感興趣的朋友一起看看吧2021-10-10一文教會你用Python獲取網(wǎng)頁指定內(nèi)容
Python用做數(shù)據(jù)處理還是相當(dāng)不錯的,如果你想要做爬蟲,Python是很好的選擇,它有很多已經(jīng)寫好的類包,只要調(diào)用即可完成很多復(fù)雜的功能,下面這篇文章主要給大家介紹了關(guān)于Python獲取網(wǎng)頁指定內(nèi)容的相關(guān)資料,需要的朋友可以參考下2022-03-03詳解python中[-1]、[:-1]、[::-1]、[n::-1]使用方法
這篇文章主要介紹了詳解python中[-1]、[:-1]、[::-1]、[n::-1]使用方法,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2021-04-04Keras load_model 導(dǎo)入錯誤的解決方式
這篇文章主要介紹了Keras load_model 導(dǎo)入錯誤的解決方式,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-06-06