Pytorch中torch.repeat_interleave()函數(shù)使用及說明
torch.repeat_interleave()函數(shù)解析
1.函數(shù)說明
官網(wǎng):torch.repeat_interleave(),函數(shù)說明如下圖所示:
2. 函數(shù)原型
torch.repeat_interleave(input, repeats, dim=None) → Tensor
3. 函數(shù)功能
沿著指定的維度重復張量的元素
4. 輸入?yún)?shù)
1)input (類型:torch.Tensor):輸入張量
2)repeats(類型:int或torch.Tensor):每個元素的重復次數(shù)
3)dim(類型:int)需要重復的維度。默認情況下dim=None,表示將把給定的輸入張量展平(flatten)為向量,然后將每個元素重復repeats次,并返回重復后的張量。
5. 注意
1) 如果不指定dim,則默認將輸入張量扁平化(維數(shù)是1,因此這時repeats必須是一個數(shù),不能是數(shù)組),并且返回一個扁平化的輸出數(shù)組。
2) 返回的數(shù)組與輸入數(shù)組維數(shù)相同,并且除了給定的維度dim,其他維度大小與輸入數(shù)組相應維度大小相同
3) repeats:如果傳入數(shù)組,則必須是tensor格式。并且只能是一維數(shù)組,數(shù)組長度與輸入數(shù)組input的dim維度大小相同
6. 代碼例子
6.1 輸入一維張量,不指定dim,重復次數(shù)為2次,表示將把給定的輸入張量展平(flatten)為向量,然后將每個元素重復2次,并返回重復后的張量。
a = torch.randn(5) a,torch.repeat_interleave(a,2)
輸出結果如下所示:
(tensor([ 0.4030, -1.1536, -2.4513, 1.1454, -0.8818]),
tensor([ 0.4030, 0.4030, -1.1536, -1.1536, -2.4513, -2.4513, 1.1454, 1.1454,
-0.8818, -0.8818]))
6.2 輸入二維張量,不指定dim,重復次數(shù)為2次,表示將把給定的輸入張量展平(flatten)為向量,然后將每個元素重復2次,并返回重復后的張量。
a = torch.randn(3,2) a,a.repeat_interleave(2)
輸出結果如下:
(tensor([[-1.03, -0.32],
[ 0.43, 0.78],
[ 0.91, -0.11]]),
tensor([-1.03, -1.03, -0.32, -0.32, 0.43, 0.43, 0.78, 0.78, 0.91, 0.91,
-0.11, -0.11]))
6.3 輸入二維張量,指定dim=0,重復次數(shù)為3次,表示把輸入張量每行元素重復3次
a = torch.randn(3,2) a,torch.repeat_interleave(a,3,dim=0)
輸出結果如下:
(tensor([[ 0.14, 1.47],
[-1.52, -0.62],
[-0.24, -0.27]]),
tensor([[ 0.14, 1.47],
[ 0.14, 1.47],
[ 0.14, 1.47],
[-1.52, -0.62],
[-1.52, -0.62],
[-1.52, -0.62],
[-0.24, -0.27],
[-0.24, -0.27],
[-0.24, -0.27]]))
6.4 輸入二維張量,指定dim=1,重復次數(shù)為3次,表示把輸入張量每列元素重復3次
a = torch.randn(3,2) a,torch.repeat_interleave(a,3,dim=1)
輸出結果如下:
(tensor([[-0.81, 0.56],
[-2.41, -0.56],
[ 0.38, -0.90]]),
tensor([[-0.81, -0.81, -0.81, 0.56, 0.56, 0.56],
[-2.41, -2.41, -2.41, -0.56, -0.56, -0.56],
[ 0.38, 0.38, 0.38, -0.90, -0.90, -0.90]]))
6.5 輸入二維張量,指定dim=0,重復次數(shù)為一個張量列表[n1,n2,n3],表示在(dim=0)對應行上面重復n1,n2,n3遍,張量列表的長度必須與dim=0的維度的長度一樣,否則會報錯
a = torch.randn(3,2) a,torch.repeat_interleave(a,torch.tensor([2,3,4]),dim=0)#表示第一行重復2遍,第二行重復3遍,第三行重復4遍
輸出結果如下:
(tensor([[-0.79, 0.54],
[-0.47, -0.25],
[-0.13, 1.03]]),
tensor([[-0.79, 0.54],
[-0.79, 0.54],
[-0.47, -0.25],
[-0.47, -0.25],
[-0.47, -0.25],
[-0.13, 1.03],
[-0.13, 1.03],
[-0.13, 1.03],
[-0.13, 1.03]]))
7. 與torch.repeat()函數(shù)區(qū)別
兩個函數(shù)方法最大的區(qū)別就是repeat_interleave是一個元素一個元素地重復,而repeat是一組元素一組元素地重復.
總結
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
numpy數(shù)組的重塑和轉(zhuǎn)置實現(xiàn)
本文主要介紹了numpy數(shù)組的重塑和轉(zhuǎn)置實現(xiàn),文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2023-03-03Selenium獲取登錄Cookies并添加Cookies自動登錄的方法
這篇文章主要介紹了Selenium獲取登錄Cookies并添加Cookies自動登錄的方法,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2020-12-12Python Flask基礎到登錄功能的實現(xiàn)代碼
這篇文章主要介紹了Python Flask基礎到登錄功能的實現(xiàn)代碼,本文給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下2021-05-05