亚洲乱码中文字幕综合,中国熟女仑乱hd,亚洲精品乱拍国产一区二区三区,一本大道卡一卡二卡三乱码全集资源,又粗又黄又硬又爽的免费视频

詳解使用Pytorch Geometric實現(xiàn)GraphSAGE模型

 更新時間:2023年04月24日 10:31:39   作者:實力  
這篇文章主要為大家介紹了詳解使用Pytorch Geometric實現(xiàn)GraphSAGE模型示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪

正文

GraphSAGE是一種用于圖神經(jīng)網(wǎng)絡(luò)中的節(jié)點嵌入學(xué)習(xí)方法。它通過聚合節(jié)點鄰居的信息來生成節(jié)點的低維表示,使節(jié)點表示能夠更好地應(yīng)用于各種下游任務(wù),如節(jié)點分類、鏈路預(yù)測等。

圖構(gòu)建

在使用GraphSAGE對節(jié)點進(jìn)行嵌入學(xué)習(xí)之前,我們需要先將原始數(shù)據(jù)轉(zhuǎn)換為圖結(jié)構(gòu),并將其存儲為Pytorch Tensor格式。例如,我們可以使用networkx庫來構(gòu)建一個簡單的圖:

import networkx as nx

G = nx.karate_club_graph()

然后,我們可以使用Pytorch Geometric庫將NetworkX圖轉(zhuǎn)換為Pytorch Tensor格式。首先,我們需要安裝Pytorch Geometric并導(dǎo)入所需的類:

!pip install torch-geometric

from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.utils.convert import from_networkx

接著,我們可以使用from_networkx函數(shù)將NetworkX圖轉(zhuǎn)換為Pytorch Tensor格式:

data = from_networkx(G)

此時,data對象包含了關(guān)于節(jié)點、邊及其屬性的信息,例如:

data.edge_index: 2x(#edges)的長整型張量,表示邊的起點和終點

  • data.x: n×dn \times dn×d 的浮點型張量,表示每個節(jié)點的特征向量(其中nnn是節(jié)點數(shù)量,ddd是特征維度)

注意,此時的data對象并未包含鄰居信息。接下來,我們將介紹如何使用Sampler方法采樣節(jié)點鄰居。

Sampler方法

GraphSAGE使用Sampler方法來聚合鄰居信息。在Pytorch Geometric中,可以使用Various Sampling方法來實現(xiàn)Sampler。例如,使用ClusterData方法將圖分成多個子圖,然后對每個子圖進(jìn)行采樣操作。

以下是ClusterData的使用示例:

from torch_geometric.utils import degree, to_undirected
from torch_geometric.transforms import ClusterData

# Convert the graph to an undirected graph, so we can aggregate neighbors in both directions.
G = to_undirected(G)

# Compute the degree of each node.
deg = degree(data.edge_index[0], num_nodes=data.num_nodes)

# Use METIS algorithm to partition the graph into multiple subgraphs.
cluster_data = ClusterData(data, num_parts=2, recursive=False, transform=NormalizeFeatures(),
                           degree=deg)

這里我們將原始圖分成兩個子圖,并對每個子圖進(jìn)行規(guī)范化特征轉(zhuǎn)換。注意,在使用ClusterData方法之前,需要將原始圖轉(zhuǎn)換為無向圖。

另一個常用的Sampler方法是在隨機游動時對鄰居進(jìn)行采樣,這種方法被稱為隨機游走采樣(Random Walk Sampling)。以下是隨機游走采樣的示例代碼:

from torch_geometric.utils import random_walk

# Perform random walk sampling to obtain node neighbor samples.
walk_length = 20  # The length of random walk trail.
num_steps = 4     # The number of nodes to sample from each step.
data.batch = None
data.edge_index = to_undirected(data.edge_index)  # Use undirected edge for random walk.

rw_data = random_walk(data.edge_index, walk_length=walk_length, num_steps=num_steps)

這里我們將使用一個長度為20、每個步驟采樣4個鄰居的隨機游走方法。注意,在使用隨機游走方法進(jìn)行采樣之前,需要使用無向邊。

GraphSAGE模型定義

GraphSAGE模型包含3個部分:1)圖卷積層;2)聚合器(Aggregator);3)輸出層。我們將在本節(jié)中介紹如何使用Pytorch實現(xiàn)這些組件。

首先,讓我們定義一個圖卷積層。圖卷積層的輸入是節(jié)點特征矩陣、鄰接矩陣和聚合器,輸出是新的節(jié)點特征矩陣。以下是圖卷積層的代碼實現(xiàn):

import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn import global_mean_pool

class GraphSageConv(MessagePassing):
    def __init__(self, in_channels, out_channels, aggr='mean'):
        super(GraphSageConv, self).__init__(aggr=aggr)
        self.lin = nn.Linear(in_channels, out_channels)
        
    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)
    
    def message(self, x_j):
        return x_j
    
    def update(self, aggr_out, x):
        return F.relu(self.lin(torch.cat([x, aggr_out], dim=1)))

這里我們繼承了MessagePassing類,并在__init__函數(shù)中定義了一個全連接層,用于將輸入特征矩陣x從 dind_{in}din? 維映射到 doutd_{out}dout? 維。在forward函數(shù)中,我們使用propagate方法來實現(xiàn)消息傳遞操作;在message函數(shù)中,我們僅向下游節(jié)點發(fā)送原始特征數(shù)據(jù);在update函數(shù)中,我們首先對聚合結(jié)果進(jìn)行ReLU非線性變換,然后再通過全連接層進(jìn)行節(jié)點特征的更新。

接下來,讓我們定義一個聚合器。聚合器的輸入是采樣得到的鄰居特征矩陣,輸出是新的節(jié)點嵌入向量。以下是聚合器的代碼實現(xiàn):

class MeanAggregator(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(MeanAggregator, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.lin = nn.Linear(input_dim, output_dim)
        
    def forward(self, neigh_mean):
        out = F.relu(self.lin(neigh_mean))
        return out

這里我們定義了一個簡單的均值聚合器,其將鄰居特征矩陣中每列的均值作為節(jié)點嵌入向量,并使用全連接層進(jìn)行維度變換。

最后,讓我們定義整個GraphSage模型。GraphSage模型包含2個圖卷積層和1個輸出層。以下是模型的代碼實現(xiàn):

class GraphSAGE(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
        super(GraphSAGE, self).__init__()
        self.conv1 = GraphSageConv(in_channels, hidden_channels)
        self.aggreg1 = MeanAggregator(hidden_channels, hidden_channels)
        self.conv2 = GraphSageConv(hidden_channels, out_channels)
        
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = global_mean_pool(x, edge_index)  # Compute global mean over nodes.
        x = self.aggreg1(x)
        x = self.conv2(x, edge_index)
        return x

這里我們定義了一個包含2層GraphSAGE Conv層的神經(jīng)網(wǎng)絡(luò)。在最后一層GraphSAGE Conv層之后,我們使用global_mean_pool函數(shù)來計算節(jié)點嵌入的全局平均值。注意,在本示例中,我們僅保留了一個輸出節(jié)點,因此輸出矩陣的大小為1。如果需要輸出多個節(jié)點,則需要設(shè)置global_mean_pool函數(shù)中的參數(shù)。

模型訓(xùn)練與測試

在定義好模型后,我們可以使用Pytorch進(jìn)行模型訓(xùn)練和測試。首先,讓我們定義一個損失函數(shù)和優(yōu)化器:

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

這里我們使用交叉熵作為損失函數(shù),并使用Adam優(yōu)化器來更新模型參數(shù)。

接著,我們可以開始訓(xùn)練模型。以下是訓(xùn)練過程的代碼實現(xiàn):

num_epochs = 100

for epoch in range(num_epochs):
    model.train()
    
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    
    print('Epoch {:03d}, Loss: {:.4f}'.format(epoch, loss.item()))

這里我們遍歷所有數(shù)據(jù)樣本,計算預(yù)測結(jié)果和真實標(biāo)簽之間的交叉熵?fù)p失,并使用反向傳播來更新權(quán)重。我們在每個epoch結(jié)束后打印出當(dāng)前損失值。

最后,我們可以對模型進(jìn)行測試。以下是測試過程的代碼實現(xiàn):

model.eval()

with torch.no_grad():
    pred = model(data.x, data.edge_index)
    pred = pred.argmax(dim=1)

acc = (pred[data.test_mask] == data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()
print('Test accuracy: {:.4f}'.format(acc))

這里我們使用測試集來計算模型的準(zhǔn)確率。注意,在執(zhí)行model.eval()后,我們需要使用torch.no_grad()包裝代碼塊,以禁止梯度計算。

總結(jié)

介紹了如何使用Pytorch Geometric實現(xiàn)GraphSAGE模型,包括構(gòu)建圖、定義Sampler方法、定義模型、訓(xùn)練和測試模型等步驟。GraphSAGE模型是一種常用的節(jié)點嵌入學(xué)習(xí)方法,可以應(yīng)用于各種下游任務(wù)中。

以上就是詳解使用Pytorch Geometric實現(xiàn)GraphSAGE模型的詳細(xì)內(nèi)容,更多關(guān)于Pytorch Geometric GraphSAGE的資料請關(guān)注腳本之家其它相關(guān)文章!

相關(guān)文章

  • 解決Python3錯誤:SyntaxError: unexpected EOF while parsin

    解決Python3錯誤:SyntaxError: unexpected EOF while

    這篇文章主要介紹了解決Python3錯誤:SyntaxError: unexpected EOF while parsin問題,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教
    2022-07-07
  • Python實現(xiàn)一個自助取數(shù)查詢工具

    Python實現(xiàn)一個自助取數(shù)查詢工具

    在數(shù)據(jù)生產(chǎn)應(yīng)用部門,取數(shù)分析是一個很常見的需求,實際上業(yè)務(wù)人員需求時刻變化,最高效的方式是讓業(yè)務(wù)部門自己來取,減少不必要的重復(fù)勞動,本文介紹如何用Python實現(xiàn)一個自助取數(shù)查詢工具
    2021-06-06
  • 如何使用Python 打印各種三角形

    如何使用Python 打印各種三角形

    這篇文章主要介紹了如何使用Python 打印各種三角形,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下
    2019-06-06
  • python中的數(shù)據(jù)結(jié)構(gòu)比較

    python中的數(shù)據(jù)結(jié)構(gòu)比較

    這篇文章主要介紹了python中的數(shù)據(jù)結(jié)構(gòu)比較,本文給大家介紹的非常詳細(xì),具有一定的參考借鑒價值,需要的朋友可以參考下
    2019-05-05
  • python xmind 包使用詳解(其中解決導(dǎo)出的xmind文件 xmind8可以打開 xmind2020及之后版本打開報錯問題)

    python xmind 包使用詳解(其中解決導(dǎo)出的xmind文件 xmind8可以打開 xmind2020及之后版本打

    xmind8 可以打開xmind2020 報錯,如何解決這個問題呢?下面小編給大家?guī)砹藀ython xmind 包使用(其中解決導(dǎo)出的xmind文件 xmind8可以打開 xmind2020及之后版本打開報錯問題),感興趣的朋友一起看看吧
    2021-10-10
  • python實現(xiàn)詩歌游戲(類繼承)

    python實現(xiàn)詩歌游戲(類繼承)

    這篇文章主要為大家詳細(xì)介紹了python實現(xiàn)詩歌游戲,根據(jù)上句猜下句、猜作者、猜朝代、猜詩名,具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2019-02-02
  • Python變量賦值的秘密分享

    Python變量賦值的秘密分享

    在Python中,我們令一個變量等于另外一個變量時,并不是把值傳遞給它,而是直接把指向的地址更改了,我們通過一個小例子來看看這個有趣的過程,需要的朋友可以參考下
    2018-04-04
  • 一文教會你用Python獲取網(wǎng)頁指定內(nèi)容

    一文教會你用Python獲取網(wǎng)頁指定內(nèi)容

    Python用做數(shù)據(jù)處理還是相當(dāng)不錯的,如果你想要做爬蟲,Python是很好的選擇,它有很多已經(jīng)寫好的類包,只要調(diào)用即可完成很多復(fù)雜的功能,下面這篇文章主要給大家介紹了關(guān)于Python獲取網(wǎng)頁指定內(nèi)容的相關(guān)資料,需要的朋友可以參考下
    2022-03-03
  • 詳解python中[-1]、[:-1]、[::-1]、[n::-1]使用方法

    詳解python中[-1]、[:-1]、[::-1]、[n::-1]使用方法

    這篇文章主要介紹了詳解python中[-1]、[:-1]、[::-1]、[n::-1]使用方法,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2021-04-04
  • Keras load_model 導(dǎo)入錯誤的解決方式

    Keras load_model 導(dǎo)入錯誤的解決方式

    這篇文章主要介紹了Keras load_model 導(dǎo)入錯誤的解決方式,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2020-06-06

最新評論