關于torch.scatter與torch_scatter庫的使用整理
最近在做圖結構相關的算法,scatter能把鄰接矩陣里的信息修改,或者把鄰居分組算個sum或者reduce,挺方便的,簡單整理一下。
torch.scatter 與 tensor._scatter
Pytorch自帶的函數,用來將作為 src
的tensor根據 index
的描述填充到 input
中,
形式如下:
ouput = torch.scatter(input, dim, index, src) # 或者是 input.scatter_(dim, index, src)
兩個方法的功能是相同的,而帶下劃線的 _scatter
方法是將原tensor input
直接修改了,不帶的則會返回一個新的tensor output
, input
不變。
其中 dim
決定 index
對應值是沿著哪個維度進行修改。而 src
為數據來源,當其為tensor張量時,shape要和index相同,這樣index中每個元素都能對應 src
中對應位置的信息。
理解 scatter
方法主要是要理解 index
實現(xiàn)的 src
和 input
之間的位置對應關系,舉個例子:
dim = 0 index = torch.tensor( [[0, 2, 2], [2, 1, 0]] )
dim
為0時,遵循的映射原則為: input[index[i][j]][j] = src[i][j]
.
也就是說,將位置 (i, j) 中 dim
對應的位置改為 index[i][j] 的值。
如位置(1,0),index[1][0]為2,則映射后的位置為(2,0),意味著 input
中(2,0)的位置被更改為 src
中(1,0)位置的值。
我個人形象理解是這些值會沿著dim方向滑動,上面例子中src[1][0]位置的值滑到2,成為input中的新值,這樣理解起來更形象一點。
基本理解了上面這個例子,多維情況和不同dim的情況都可以類推了。
需要注意:src和input的dtype需要相同,不然會報
Expected self.dtype to be equal to src.dtype
不一樣就先轉換再使用。
t = torch.arange(6).view(2, 3) t = t.to(torch.float32) print(t) output = torch.scatter(torch.zeros((3, 3)), 0, torch.tensor([[0, 2, 2], [2, 1, 0]]), t) print(torch.zeros((3, 3)).scatter_(0, torch.tensor([[0, 2, 2], [2, 1, 0]]), t))
輸出:
tensor([[0., 1., 2.],
[3., 4., 5.]])
tensor([[0., 0., 5.],
[0., 4., 0.],
[3., 1., 2.]])
torch_scatter庫
這個第三方庫對矩陣的分組處理這個概念做了更進一步的封裝,通過index來指定分組信息,將元素分組后進行對應處理,
最基礎的scatter方法形式如下:
torch_scatter.scatter(src, index, dim, out, dim_size, reduce)
src
: 數據源index
:分組序列dim
:分組遵循的維度out
:輸出的tensor,可以不指定直接讓函數輸出dim_size
:out不指定的時候,將輸出shape變?yōu)樵撝荡笮。籨im_size也不指定,就根據計算結果來reduce
:分組的操作,包括sum,mul,mean,min和max操作
這個方法理解關鍵在 index
的分組方法,
舉個例子:
dim = 1 index = torch.tensor([[0, 1, 1]])
torch_scatter.scatter
對 index
的順序是沒有特定規(guī)定的,相同數字對應的元素即為一組。
比如例子中,維度1上的第0個元素為一組,第1和2元素為另一組。
這樣,按照分組進行reduce定義的計算即可獲得輸出。如:
t = torch.arange(12).view(4, 3) print(t) t_s = torch_scatter.scatter(t, torch.tensor([[0, 1, 1]]), dim=1, reduce='sum') print(t_s)
輸出:
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
tensor([[ 0, 3],
[ 3, 9],
[ 6, 15]])
可以看出,每行的后兩個元素求了和,與index定義相同。
要注意的是,index的 shape[0]
為1時,會自動對dim對應的維度上每一層進行相同的分組處理,如上例所示,index大小為(1, 3),即對src的三行數據都進行了分組處理。
而另一種分組方式,如需要每行分組不同,則需要index的shape和src的shape相同,如下例:
t = torch.arange(12).view(4, 3) print(t) t_s = torch_scatter.scatter(t, torch.tensor([[0, 1, 1], [1, 1, 0], [0, 1, 1], [1, 1, 0]]), dim=1, reduce='sum') print(t_s)
輸出:
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
tensor([[ 0, 3],
[ 5, 7],
[ 6, 15]])
shape不相同時,則會報錯提示:
RuntimeError: The expanded size of the tensor (3) must match the existing size (2) at non-singleton dimension 0 .
同時,該庫還給出了另外兩種方法,分別為 torch_scatter.segment_coo
和 torch_scatter.segment_csr
.
torch_scatter.segment_coo
torch_scatter.segment_coo
和 scatter
的功能差不多,但它只支持index的shape[0]為1的狀態(tài),即每一行都為相同的分組方式。
同時,index中數值為順序排列,以提高計算速度。
torch_scatter.segment_csr
torch_scatter.segment_csr
的index格式不太相同,是一種區(qū)間格式,如[0, 2, 5],表示0,1為一組,2,3,4為一組,即取數值間的左閉右開區(qū)間。
這個方法是計算速度最快的。
官方文檔地址
torch_scatter庫doc
https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html
torch.scatter文檔
總結
以上為個人經驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
Python+OpenCV圖片去水印的多種方案實現(xiàn)
這篇文章主要為大家總結了Python結合OpenCV的幾種常見的水印去除方式,簡單圖片去水印效果良好,有需要的小伙伴可以跟隨小編一起了解下2025-02-02python tensorflow學習之識別單張圖片的實現(xiàn)的示例
本篇文章主要介紹了python tensorflow學習之識別單張圖片的實現(xiàn)的示例,小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧2018-02-02python實現(xiàn)AHP算法的方法實例(層次分析法)
這篇文章主要給大家介紹了關于python實現(xiàn)AHP算法(層次分析法)的相關資料,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2020-09-09