亚洲乱码中文字幕综合,中国熟女仑乱hd,亚洲精品乱拍国产一区二区三区,一本大道卡一卡二卡三乱码全集资源,又粗又黄又硬又爽的免费视频

關于torch.scatter與torch_scatter庫的使用整理

 更新時間:2023年09月11日 14:36:18   作者:回爐重造P  
這篇文章主要介紹了關于torch.scatter與torch_scatter庫的使用整理,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教

最近在做圖結構相關的算法,scatter能把鄰接矩陣里的信息修改,或者把鄰居分組算個sum或者reduce,挺方便的,簡單整理一下。

torch.scatter 與 tensor._scatter

Pytorch自帶的函數,用來將作為 src 的tensor根據 index 的描述填充到 input 中,

形式如下:

ouput = torch.scatter(input, dim, index, src)
# 或者是
input.scatter_(dim, index, src)

兩個方法的功能是相同的,而帶下劃線的 _scatter 方法是將原tensor input 直接修改了,不帶的則會返回一個新的tensor output , input 不變。

其中 dim 決定 index 對應值是沿著哪個維度進行修改。而 src 為數據來源,當其為tensor張量時,shape要和index相同,這樣index中每個元素都能對應 src 中對應位置的信息。

理解 scatter 方法主要是要理解 index 實現(xiàn)的 src input 之間的位置對應關系,舉個例子:

dim = 0
index = torch.tensor(
	[[0, 2, 2], 
	[2, 1, 0]]
)

dim 為0時,遵循的映射原則為: input[index[i][j]][j] = src[i][j] .

也就是說,將位置 (i, j) 中 dim 對應的位置改為 index[i][j] 的值。

如位置(1,0),index[1][0]為2,則映射后的位置為(2,0),意味著 input 中(2,0)的位置被更改為 src 中(1,0)位置的值。

我個人形象理解是這些值會沿著dim方向滑動,上面例子中src[1][0]位置的值滑到2,成為input中的新值,這樣理解起來更形象一點。

基本理解了上面這個例子,多維情況和不同dim的情況都可以類推了。

需要注意:src和input的dtype需要相同,不然會報

Expected self.dtype to be equal to src.dtype

不一樣就先轉換再使用。

t = torch.arange(6).view(2, 3)
t = t.to(torch.float32)
print(t)
output = torch.scatter(torch.zeros((3, 3)), 0, torch.tensor([[0, 2, 2], [2, 1, 0]]), t)
print(torch.zeros((3, 3)).scatter_(0, torch.tensor([[0, 2, 2], [2, 1, 0]]), t))

輸出:

tensor([[0., 1., 2.],
        [3., 4., 5.]])
tensor([[0., 0., 5.],
        [0., 4., 0.],
        [3., 1., 2.]])

torch_scatter庫

這個第三方庫對矩陣的分組處理這個概念做了更進一步的封裝,通過index來指定分組信息,將元素分組后進行對應處理,

最基礎的scatter方法形式如下:

torch_scatter.scatter(src, index, dim, out, dim_size, reduce)
  • src : 數據源
  • index :分組序列
  • dim :分組遵循的維度
  • out :輸出的tensor,可以不指定直接讓函數輸出
  • dim_size :out不指定的時候,將輸出shape變?yōu)樵撝荡笮。籨im_size也不指定,就根據計算結果來
  • reduce :分組的操作,包括sum,mul,mean,min和max操作

這個方法理解關鍵在 index 的分組方法,

舉個例子:

dim = 1
index = torch.tensor([[0, 1, 1]])

torch_scatter.scatter index 的順序是沒有特定規(guī)定的,相同數字對應的元素即為一組。

比如例子中,維度1上的第0個元素為一組,第1和2元素為另一組。

這樣,按照分組進行reduce定義的計算即可獲得輸出。如:

t = torch.arange(12).view(4, 3)
print(t)
t_s = torch_scatter.scatter(t, torch.tensor([[0, 1, 1]]), dim=1, reduce='sum')
print(t_s)

輸出:

tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
tensor([[ 0,  3],
        [ 3,  9],
        [ 6, 15]])

可以看出,每行的后兩個元素求了和,與index定義相同。

要注意的是,index的 shape[0] 為1時,會自動對dim對應的維度上每一層進行相同的分組處理,如上例所示,index大小為(1, 3),即對src的三行數據都進行了分組處理。

而另一種分組方式,如需要每行分組不同,則需要index的shape和src的shape相同,如下例:

t = torch.arange(12).view(4, 3)
print(t)
t_s = torch_scatter.scatter(t, torch.tensor([[0, 1, 1], [1, 1, 0], [0, 1, 1], [1, 1, 0]]), dim=1, reduce='sum')
print(t_s)

輸出:

tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
tensor([[ 0,  3],
        [ 5,  7],
        [ 6, 15]])

shape不相同時,則會報錯提示:

RuntimeError: The expanded size of the tensor (3) must match the existing size (2) at non-singleton dimension 0 .

同時,該庫還給出了另外兩種方法,分別為 torch_scatter.segment_coo torch_scatter.segment_csr .

torch_scatter.segment_coo

torch_scatter.segment_coo scatter 的功能差不多,但它只支持index的shape[0]為1的狀態(tài),即每一行都為相同的分組方式。

同時,index中數值為順序排列,以提高計算速度。

torch_scatter.segment_csr

torch_scatter.segment_csr 的index格式不太相同,是一種區(qū)間格式,如[0, 2, 5],表示0,1為一組,2,3,4為一組,即取數值間的左閉右開區(qū)間。

這個方法是計算速度最快的。

官方文檔地址

torch_scatter庫doc

https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html

torch.scatter文檔

https://pytorch-cn.readthedocs.io/zh/latest/package_references/Tensor/#scatter_input-dim-index-src-tensor

總結

以上為個人經驗,希望能給大家一個參考,也希望大家多多支持腳本之家。

相關文章

  • python網絡爬蟲精解之XPath的使用說明

    python網絡爬蟲精解之XPath的使用說明

    XPath 是一門在 XML 文檔中查找信息的語言。XPath 可用來在 XML 文檔中對元素和屬性進行遍歷。XPath 是 W3C XSLT 標準的主要元素,并且 XQuery 和 XPointer 都構建于 XPath 表達之上
    2021-09-09
  • Python高效計算庫Joblib的入門教程

    Python高效計算庫Joblib的入門教程

    Joblib庫是一個用于在Python中進行高效計算的開源庫,提供內存映射和并行計算工具,本文就來介紹一下Joblib庫的使用,具有一定的參考價值,感興趣的可以了解一下
    2025-01-01
  • Python Tkinter GUI編程入門介紹

    Python Tkinter GUI編程入門介紹

    這篇文章主要介紹了Python Tkinter GUI編程入門介紹,本文講解了Tkinter介紹、Tkinter的使用、Tkinter的幾何管理器等內容,并給出了一個完整示例,需要的朋友可以參考下
    2015-03-03
  • django使用xlwt導出excel文件實例代碼

    django使用xlwt導出excel文件實例代碼

    這篇文章主要介紹了django使用xlwt導出excel文件實例代碼,分享了相關代碼示例,小編覺得還是挺不錯的,具有一定借鑒價值,需要的朋友可以參考下
    2018-02-02
  • python數據結構之二叉樹的統(tǒng)計與轉換實例

    python數據結構之二叉樹的統(tǒng)計與轉換實例

    這篇文章主要介紹了python數據結構之二叉樹的統(tǒng)計與轉換實例,例如統(tǒng)計二叉樹的葉子、分支節(jié)點,以及二叉樹的左右兩樹互換等,需要的朋友可以參考下
    2014-04-04
  • Python+OpenCV圖片去水印的多種方案實現(xiàn)

    Python+OpenCV圖片去水印的多種方案實現(xiàn)

    這篇文章主要為大家總結了Python結合OpenCV的幾種常見的水印去除方式,簡單圖片去水印效果良好,有需要的小伙伴可以跟隨小編一起了解下
    2025-02-02
  • Python階乘求和的代碼詳解

    Python階乘求和的代碼詳解

    在本篇文章里小編給大家整理的是關于Python階乘求和的代碼實例,有需要的朋友們可以跟著學習下。
    2020-02-02
  • python tensorflow學習之識別單張圖片的實現(xiàn)的示例

    python tensorflow學習之識別單張圖片的實現(xiàn)的示例

    本篇文章主要介紹了python tensorflow學習之識別單張圖片的實現(xiàn)的示例,小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧
    2018-02-02
  • python實現(xiàn)AHP算法的方法實例(層次分析法)

    python實現(xiàn)AHP算法的方法實例(層次分析法)

    這篇文章主要給大家介紹了關于python實現(xiàn)AHP算法(層次分析法)的相關資料,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧
    2020-09-09
  • pandas數據清洗(缺失值和重復值的處理)

    pandas數據清洗(缺失值和重復值的處理)

    這篇文章主要介紹了pandas數據清洗(缺失值和重復值的處理),pandas對大數據有很多便捷的清洗用法,尤其針對缺失值和重復值,詳細介紹感興趣的小伙伴可以參考下面文章內容
    2022-08-08

最新評論