Pytorch中關(guān)于RNN輸入和輸出的形狀總結(jié)
Pytorch對(duì)RNN輸入和輸出的形狀總結(jié)
個(gè)人對(duì)于RNN的一些總結(jié)。
RNN的輸入和輸出
RNN的經(jīng)典圖如下所示
各個(gè)參數(shù)的含義
- Xt: t時(shí)刻的輸入,形狀為[batch_size, input_dim]。對(duì)于整個(gè)RNN來說,總的X輸入為[seq_len, batch_size, input_dim],具體如何理解batch_size和seq_len在下面有說明。
- St: t時(shí)刻隱藏層的狀態(tài),也有時(shí)用ht表示,形狀為[batch_size, hidden_size],St=f(U·Xt+W·St-1),通過W和U矩陣的映射,將embedding后的Xt和上一狀態(tài)St-1轉(zhuǎn)為St
- Ot: t時(shí)刻的輸出,Ot=g(V·St),形狀為[batch_size, hidden_size],總的為輸出O為[seq_len, batch_size, hidden_size]
Pytorch中的使用
Pytorch中RNN函數(shù)如下
RNN的主要參數(shù)如下
nn.RNN(input_size, hidden_size, num_layers=1, bias=True)
參數(shù)解釋
input_size
: 輸入特征的維度,一般rnn中輸入的是詞向量,那么就為embedding-dimhidden_size
: 隱藏層神經(jīng)元的個(gè)數(shù),或者也叫輸出的維度num_layers
: 隱藏層的個(gè)數(shù),默認(rèn)為1
output=輸出O, 隱藏狀態(tài)St,其中輸出O=[time_step, batch_size, hidden_size],St為t時(shí)刻的隱藏層狀態(tài)
理解RNN中的batch_size和seq_len
深度學(xué)習(xí)中采用mini-batch的方法進(jìn)行迭代優(yōu)化,在CNN中batch的思想較容易理解,一次輸入batch個(gè)圖片,進(jìn)行迭代。但是RNN中引入了seq_len(time_step), 理解較為困難,下面是我自己的一些理解。
首先假如我有五句話,作為訓(xùn)練的語料。
sentences = ["i like dog", "i love coffee", "i hate milk", "i like music", "i hate you"]
那么在輸入RNN之前要先進(jìn)行embedding,比如one-hot encoding,容易得到這里的embedding-dim為9.
那么輸入的sentences可以表示為如下方式
t=0 | t=1 | t=2 | |
---|---|---|---|
batch1 | i | like | dog |
batch2 | i | love | coffee |
batch3 | i | hate | milk |
batch4 | i | like | music |
batch5 | i | hate | you |
那么在RNN的訓(xùn)練中。
- t=0時(shí), 輸入第一個(gè)batch[i, i, i, i, i]這里用字符表示,其實(shí)應(yīng)該是對(duì)應(yīng)的one-hot編碼。
- t=1時(shí),輸入第二個(gè)batch[like, love, hate, like, hate]
- t=2時(shí),輸入第三個(gè)batch[dog, coffee, milk, music, you]
那么對(duì)應(yīng)的時(shí)間t來說,RNN需要對(duì)先后輸入的batch_size個(gè)字符進(jìn)行前向計(jì)算迭代,得到輸出。
Pytorch雙向RNN隱藏層和輸出層結(jié)果拆分
1 RNN隱藏層和輸出層結(jié)果的形狀
從Pytorch官方文檔可以得到,對(duì)于批量化輸入的RNN來講,其隱藏層的shape為(num_directions*num_layers, batch_size, hidden_size)。
其輸出的shape為(seq_len, batch_size, D*hidden_size)。
2 雙向RNN情況下,隱藏層和輸出層結(jié)果拆分
當(dāng)采用雙向RNN時(shí),其輸出的結(jié)果包含正向和反向兩個(gè)方向輸出的結(jié)果。
2.1 輸出層結(jié)果拆分
其中對(duì)于輸出output來講,從官方文檔我們可以得到,其拆分正向和反向兩個(gè)方向結(jié)果的方法為:
output.shape = (seq_len, batch_size, num_directions*hidden_size)
output.view(seq_len, batch, num_directions, hidden_size)
其中,對(duì)于(num_directions)方向維度,正向和反向的維度值分別為??0???和??1?。
2.2 隱藏層結(jié)果拆分
而對(duì)于隱藏層,包括初始值h_0以及最終輸出h_n,也都包含兩個(gè)方向的隱藏狀態(tài),但是其拆分方式跟輸出層不一樣。
方法如下:
h_0, h_n.shape = (num_directions*num_layers, batch_size, hidden_size)
h_0, h_n.view(num_layers, num_directions, batch_size, hidden_size)
可以從簡(jiǎn)單單層雙向RNN的輸出結(jié)果來驗(yàn)證,此時(shí)RNN的輸出結(jié)果與最后一層的隱藏層結(jié)果是一樣的。
import torch import torch.nn as nn if __name__ == "__main__": # input_size: 3, hidden_size: 5, num_layers: 3 BiRNN_Net = nn.RNN(3, 5, 3, bidirectional=True, batch_first=True) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # batch_size: 1, seq_len: 1, input_size: 3 inputs = torch.zeros(1, 1, 3, device=device) # state: (num_directions*num_layers, batch_size, hidden_size) state = torch.randn(6, 1, 5, device=device) BiRNN_Net.to(device) output, hidden = BiRNN_Net(inputs, state) output_re = output.reshape((1, 1, 2, 5)) hidden_re = hidden.reshape((3, 2, 1, 5)) print(output) print(output_re) print(hidden) print(hidden_re)
輸出結(jié)果可以看出,隱藏層的結(jié)果是優(yōu)先num_layers網(wǎng)絡(luò)層數(shù)這一個(gè)維度來構(gòu)成的。
tensor([[[ 0.3939, -0.9160, ?0.5054, ?0.2949, -0.5225, ?0.0533, ?0.4197, ? ? ? ? ? -0.7200, -0.1262, -0.7975]]], device='cuda:0', ? ? ? ?grad_fn=<CudnnRnnBackward0>) tensor([[[[ 0.3939, -0.9160, ?0.5054, ?0.2949, -0.5225], ? ? ? ? ? [ 0.0533, ?0.4197, -0.7200, -0.1262, -0.7975]]]], device='cuda:0', ? ? ? ?grad_fn=<ReshapeAliasBackward0>) tensor([[[-0.2606, ?0.5410, -0.2663, ?0.6418, -0.2902]], ? ? ? ? [[ 0.1367, ?0.7222, -0.3051, -0.6410, -0.3062]], ? ? ? ? [[ 0.2433, ?0.3287, -0.4809, -0.1782, -0.5582]], ? ? ? ? [[ 0.4824, -0.8529, ?0.7604, ?0.8508, -0.1902]], ? ? ? ? [[ 0.3939, -0.9160, ?0.5054, ?0.2949, -0.5225]], ? ? ? ? [[ 0.0533, ?0.4197, -0.7200, -0.1262, -0.7975]]], device='cuda:0', ? ? ? ?grad_fn=<CudnnRnnBackward0>) tensor([[[[-0.2606, ?0.5410, -0.2663, ?0.6418, -0.2902]], ? ? ? ? ?[[ 0.1367, ?0.7222, -0.3051, -0.6410, -0.3062]]], ? ? ? ? [[[ 0.2433, ?0.3287, -0.4809, -0.1782, -0.5582]], ? ? ? ? ?[[ 0.4824, -0.8529, ?0.7604, ?0.8508, -0.1902]]], ? ? ? ? [[[ 0.3939, -0.9160, ?0.5054, ?0.2949, -0.5225]], ? ? ? ? ?[[ 0.0533, ?0.4197, -0.7200, -0.1262, -0.7975]]]], device='cuda:0', ? ? ? ?grad_fn=<ReshapeAliasBackward0>)
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python實(shí)現(xiàn)遠(yuǎn)程調(diào)用MetaSploit的方法
這篇文章主要介紹了Python實(shí)現(xiàn)遠(yuǎn)程調(diào)用MetaSploit的方法,是很有借鑒價(jià)值的一個(gè)技巧,需要的朋友可以參考下2014-08-08淺談Python3中strip()、lstrip()、rstrip()用法詳解
這篇文章主要介紹了淺談Python3中strip()、lstrip()、rstrip()用法詳解,小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧2019-04-04Python調(diào)用DeepSeek?API的案例詳細(xì)教程
這篇文章主要為大家詳細(xì)介紹了以?Python?為例的調(diào)用?DeepSeek?API?的小白入門級(jí)詳細(xì)教程,文中的示例代碼講解詳細(xì),感興趣的小伙伴可以了解下2025-02-02Python?數(shù)據(jù)篩選功能實(shí)現(xiàn)
這篇文章主要介紹了Python?數(shù)據(jù)篩選,無論是在數(shù)據(jù)分析還是數(shù)據(jù)挖掘的時(shí)候,數(shù)據(jù)篩選總會(huì)涉及到,這里我總結(jié)了一下python中列表,字典,數(shù)據(jù)框中一些常用的數(shù)據(jù)篩選的方法,需要的朋友可以參考下2023-04-04python 五子棋如何獲得鼠標(biāo)點(diǎn)擊坐標(biāo)
這篇文章主要介紹了python 五子棋如何獲得鼠標(biāo)點(diǎn)擊坐標(biāo),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-11-11python實(shí)現(xiàn)掃描日志關(guān)鍵字的示例
下面小編就為大家分享一篇python實(shí)現(xiàn)掃描日志關(guān)鍵字的示例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2018-04-04