pytorch中的reshape()、view()、nn.flatten()和flatten()使用
在使用pytorch定義神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)時,經(jīng)常會看到類似如下的.view() / flatten()用法,這里對其用法做出講解與演示。
torch.reshape用法
reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()調(diào)用,其作用是在不改變tensor元素數(shù)目的情況下改變tensor的shape。
torch.reshape() 需要兩個參數(shù),一個是待被改變的張量tensor,一個是想要改變的形狀。
torch.reshape(input, shape) → Tensor
- input(Tensor)-要重塑的張量
- shape(python的元組:ints)-新形狀`
案例1.
輸入:
import torch a = torch.tensor([[0,1],[2,3]]) x = torch.reshape(a,(-1,)) print (x) b = torch.arange(4.) Y = torch.reshape(a,(2,2)) print(Y)
結(jié)果:
tensor([0, 1, 2, 3])
tensor([[0, 1],
[2, 3]])
torch.view用法
view()的原理很簡單,其實就是把原先tensor中的數(shù)據(jù)進行排列,排成一行,然后根據(jù)所給的view()中的參數(shù)從一行中按順序選擇組成最終的tensor。
view()可以有多個參數(shù),這取決于你想要得到的是幾維的tensor,一般設(shè)置兩個參數(shù),也是神經(jīng)網(wǎng)絡(luò)中常用的(一般在全連接之前),代表二維。
view(h,w),h代表行(想要變?yōu)閹仔校?,當不知道要變?yōu)閹仔?,但知道要變?yōu)閹琢袝r可取-1;w代表的是列(想要變?yōu)閹琢校?,當不知道要變?yōu)閹琢?,但知道要變?yōu)閹仔袝r可取-1。
一、普通用法(手動調(diào)整)
view()相當于reshape、resize,重新調(diào)整Tensor的形狀。
案例2.
輸入
import torch a1 = torch.arange(0,16) print(a1)
輸出
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
輸入
a2 = a1.view(8, 2) a3 = a1.view(2, 8) a4 = a1.view(4, 4) print(a2) print(a3) print(a4)
輸出
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
二、特殊用法:參數(shù)-1(自動調(diào)整size)
view中一個參數(shù)定為-1,代表自動調(diào)整這個維度上的元素個數(shù),以保證元素的總數(shù)不變。
輸入
import torch a1 = torch.arange(0,16) print(a1)
輸出
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
輸入
a2 = a1.view(-1, 16) a3 = a1.view(-1, 8) a4 = a1.view(-1, 4) a5 = a1.view(-1, 2) a6 = a1.view(4*4, -1) a7 = a1.view(1*4, -1) a8 = a1.view(2*4, -1) print(a2) print(a3) print(a4) print(a5) print(a6) print(a7) print(a8)
輸出
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
tensor([[ 0],
[ 1],
[ 2],
[ 3],
[ 4],
[ 5],
[ 6],
[ 7],
[ 8],
[ 9],
[10],
[11],
[12],
[13],
[14],
[15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
torch.nn.Flatten(start_dim=1,end_dim=-1)
start_dim與end_dim分別表示開始的維度和終止的維度,默認值為1和-1,其中1表示第一維度,-1表示最后的維度。結(jié)合起來看意思就是從第一維度到最后一個維度全部給展平為張量。(注意:數(shù)據(jù)的維度是從0開始的,也就是存在第0維度,第一維度并不是真正意義上的第一個)。
因為其被用在神經(jīng)網(wǎng)絡(luò)中,輸入為一批數(shù)據(jù),第 0 維為batch(輸入數(shù)據(jù)的個數(shù)),通常要把一個數(shù)據(jù)拉成一維,而不是將一批數(shù)據(jù)拉為一維。所以torch.nn.Flatten()默認從第一維開始平坦化。
使用nn.Flatten(),使用默認參數(shù)
官方給出的示例:
input = torch.randn(32, 1, 5, 5) # With default parameters m = nn.Flatten() output = m(input) output.size() #torch.Size([32, 25]) # With non-default parameters m = nn.Flatten(0, 2) output = m(input) output.size() #torch.Size([160, 5])
#開頭的代碼是注釋
整段代碼的意思是:給定一個維度為(32,1,5,5)的隨機數(shù)據(jù)。
1.先使用一次nn.Flatten(),使用默認參數(shù):
m = nn.Flatten()
也就是說從第一維度展平到最后一個維度,數(shù)據(jù)的維度是從0開始的,第一維度實際上是數(shù)據(jù)的第二位置代表的維度,也就是樣例中的1。
因此進行展平后的結(jié)果也就是[32,155]→[32,25]
2.接著再使用一次指定參數(shù)的nn.Flatten(),即
m = nn.Flatten(0,2)
也就是說從第0維度展平到第2維度,0~2,對應(yīng)的也就是前三個維度。
因此結(jié)果就是[3215,5]→[160,25]
torch.flatten
torch.flatten()函數(shù)經(jīng)常用于寫分類神經(jīng)網(wǎng)絡(luò)的時候,經(jīng)過最后一個卷積層之后,一般會再接一個自適應(yīng)的池化層,輸出一個BCHW的向量。
這時候就需要用到torch.flatten()函數(shù)將這個向量拉平成一個Bx的向量(其中,x = CHW),然后送入到FC層中。
語句結(jié)構(gòu)
torch.flatten(input, start_dim=0, end_dim=-1)
input: 一個 tensor,即要被“攤平”的 tensor。
- start_dim: “攤平”的起始維度。
- end_dim: “攤平”的結(jié)束維度。
作用與 torch.nn.flatten 類似,都是用于展平 tensor 的,只是 torch.flatten 是 function 而不是類,其默認開始維度為第 0 維。
例1:
import torch data_pool = torch.randn(2,2,3,3) # 模擬經(jīng)過最后一個池化層或自適應(yīng)池化層之后的輸出,Batchsize*c*h*w print(data_pool) y=torch.flatten(data_pool,1) print(y)
輸出結(jié)果:
結(jié)果是一個B*x的向量。
總結(jié)
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python實現(xiàn)定時自動關(guān)閉的tkinter窗口方法
今天小編就為大家分享一篇Python實現(xiàn)定時自動關(guān)閉的tkinter窗口方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-02-02python網(wǎng)頁請求urllib2模塊簡單封裝代碼
這篇文章主要分享一個python網(wǎng)頁請求模塊urllib2模塊的簡單封裝代碼,有需要的朋友參考下2014-02-02圖解Python中淺拷貝copy()和深拷貝deepcopy()的區(qū)別
這篇文章主要介紹了Python中淺拷貝copy()和深拷貝deepcopy()的區(qū)別,淺拷貝和深拷貝想必大家在學(xué)習中遇到很多次,這也是面試中常常被問到的問題,本文就帶你詳細了解一下2023-05-05python常用函數(shù)random()函數(shù)詳解
這篇文章主要介紹了python常用函數(shù)random()函數(shù),本文通過實例代碼給大家介紹的非常詳細,對大家的學(xué)習或工作具有一定的參考借鑒價值,需要的朋友可以參考下2023-02-02Python深度學(xué)習pytorch神經(jīng)網(wǎng)絡(luò)塊的網(wǎng)絡(luò)之VGG
雖然AlexNet證明深層神經(jīng)網(wǎng)絡(luò)卓有成效,但它沒有提供一個通用的模板來指導(dǎo)后續(xù)的研究人員設(shè)計新的網(wǎng)絡(luò)。下面,我們將介紹一些常用于設(shè)計深層神經(jīng)網(wǎng)絡(luò)的啟發(fā)式概念2021-10-10Python內(nèi)存管理器如何實現(xiàn)池化技術(shù)
Python中的內(nèi)存管理是從三個方面來進行的,一對象的引用計數(shù)機制,二垃圾回收機制,三內(nèi)存池機制,下面這篇文章主要給大家介紹了關(guān)于Python內(nèi)存管理器如何實現(xiàn)池化技術(shù)的相關(guān)資料,需要的朋友可以參考下2022-05-05Python使用eval函數(shù)執(zhí)行動態(tài)標表達式過程詳解
這篇文章主要介紹了Python使用eval函數(shù)執(zhí)行動態(tài)標表達式過程詳解,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習或者工作具有一定的參考學(xué)習價值,需要的朋友可以參考下2020-10-10