pytorch?tensor按廣播賦值scatter_函數(shù)的用法
pytorch tensor按廣播賦值scatter函數(shù)
普通廣播
>>> import torch >>> a = torch.tensor([[1,2,3],[4,5,6]]) # 和a shape相同,但是用0填充 >>> b = torch.full_like(a,0) >>> c = torch.tensor([[0,0,1],[1,0,1]]) # 賦值索引 >>> c[:,0] tensor([0, 1]) # 賦值語句:使用廣播機制進行賦值 >>> b[range(n),c[:,0]] = 1 >>> b tensor([[1, 0, 0], ? ? ? ? [0, 1, 0]])
為什么會出現(xiàn)這樣的結果?
賦值語句的意思是:
- 1.range(n)表示對b的所有行進行賦值操作
- 2.c[:,0]] 表示執(zhí)行賦值操作的b的列索引,[0, 1] 表示第一行對索引為0的列進行操作(賦值為1);第二行對索引為1的列進行操作(賦值為1)
- 3.最右邊的1表示對應索引位置所賦的值
scatter函數(shù)
import torch label = torch.zeros(3, 6) #首先生成一個全零的多維數(shù)組 print("label:",label) a = torch.ones(3,5) b = [[0,1,2],[0,1,3],[1,2,3]] #這里需要解釋的是,b的行數(shù)要小于等于label的行數(shù),列數(shù)要小于等于a的列數(shù) print(a) label.scatter_(1,torch.LongTensor(b),a)? #參數(shù)解釋:‘1':需要賦值的維度,是label的維度;‘torch.LongTensor(b)':需要賦值的索引;‘a':要賦的值 print("new_label: ",label) label:? tensor([[0., 0., 0., 0., 0., 0.], ? ? ? ? [0., 0., 0., 0., 0., 0.], ? ? ? ? [0., 0., 0., 0., 0., 0.]]) tensor([[1., 1., 1., 1., 1.], ? ? ? ? [1., 1., 1., 1., 1.], ? ? ? ? [1., 1., 1., 1., 1.]]) new_label: ? tensor([[1., 1., 1., 0., 0., 0.], ? ? ? ? [1., 1., 0., 1., 0., 0.], ? ? ? ? [0., 1., 1., 1., 0., 0.]])
舉例
>>> b = torch.full_like(a,0) >>> b tensor([[0, 0, 0], ? ? ? ? [0, 0, 0]]) >>> c = torch.tensor([[0,0],[1,0]]) >>> c tensor([[0, 0], ? ? ? ? [1, 0]]) # 1表示對b的列進行賦值,以c的每一行的值作為b的列索引,一行一行地進行賦值 # c第一行 [0,0] 表示分別將b的 第一行 第0列、第0列 元素賦值為1 (重復操作了) # c第二行 [1,0] 表示 將b的 第1列、第0列 元素賦值為1 (逆序了) # 上面的這兩個賦值操作其實有重復的、逆序的 >>> b.scatter_(1,torch.LongTensor(c),1) >>> b tensor([[1, 0, 0], ? ? ? ? [1, 1, 0]])
scatter()和scatter_()的作用和區(qū)別
scatter和scatter_函數(shù)原型如下
Tensor.scatter_(dim, index, src, reduce=None)->Tensor scatter(input, dim, index, src)->Tensor
函數(shù)作用是將src中的數(shù)據按照dim中指定的維度和index中的索引寫入self中。
dim(int)
- 操作的維度index(LongTensor)
- 填充依據的索引,src(Tensor of float)
- 操作的src數(shù)據reduce(str, optional)
- reduce選擇運算方式,有’add’和’mutiply’方式, 默認為替換 dim(int)
在scatter中self指返回的tensor,scatter_中self指輸入的tensor自身。
對于一個三維張量,self更新結果如下
self[index[i][j][k]][j][k] = src[i][j][k] ?# if dim == 0 self[i][index[i][j][k]][k] = src[i][j][k] ?# if dim == 1 self[i][j][index[i][j][k]] = src[i][j][k] ?# if dim == 2
使用示例
>>> src = torch.arange(1, 11).reshape((2, 5)) >>> src tensor([[ 1, ?2, ?3, ?4, ?5], ? ? ? ? [ 6, ?7, ?8, ?9, 10]]) >>> index = torch.tensor([[0, 1, 2, 0]]) >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src) tensor([[1, 0, 0, 4, 0], ? ? ? ? [0, 2, 0, 0, 0], ? ? ? ? [0, 0, 3, 0, 0]])
dim=0, 說明按照行賦值,index[0][1]=1, 代表更改input中的第1行,src[0][1]=2,因此更改input中[1][1]中的元素為2
>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]]) >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src) tensor([[1, 2, 3, 0, 0], ? ? ? ? [6, 7, 0, 0, 8], ? ? ? ? [0, 0, 0, 0, 0]])
dim,說明按照列賦值,index[0][1]=1, 代表更改input中的第1列,src[0][1]=2, 更改input中[0][1]元素為2
>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]), ... ? ? ? ? ? ?1.23, reduce='multiply') tensor([[2.0000, 2.0000, 2.4600, 2.0000], ? ? ? ? [2.0000, 2.0000, 2.0000, 2.4600]]) >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]), ... ? ? ? ? ? ?1.23, reduce='add') tensor([[2.0000, 2.0000, 3.2300, 2.0000], ? ? ? ? [2.0000, 2.0000, 2.0000, 3.2300]])
scatter的應用, one-hot編碼
def one_hot(x, n_class, dtype=torch.float32): ? ? # X shape: (batch), output shape: (batch, n_class) ? ? x=x.long() ? ? res=torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device) # shape為[batch, n_class]全零向量 ? ? res.scatter_(1, x.view(-1,1), 1)? ? ? # scatter_(input, dim, index, src)將src中數(shù)據根據index的索引按照dim的方向填進input中 ? ? return res x=torch.tensor([5,7,0]) one_hot(x, 10) tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.], ? ? ? ? [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.], ? ? ? ? [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])
總結
以上為個人經驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
Python 中數(shù)組和數(shù)字相乘時的注意事項說明
這篇文章主要介紹了Python 中數(shù)組和數(shù)字相乘時的注意事項說明,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2021-05-05關于np.meshgrid函數(shù)中的indexing參數(shù)問題
Meshgrid函數(shù)在二維與三維空間中用于生成坐標網格,便于進行圖像處理和空間數(shù)據分析,二維情況下,默認使用笛卡爾坐標系,而三維meshgrid則涉及不同的坐標軸取法,在三維情況下,可能會出現(xiàn)坐標軸排列序混亂2024-09-09