基于pytorch的lstm參數(shù)使用詳解
lstm(*input, **kwargs)
將多層長(zhǎng)短時(shí)記憶(LSTM)神經(jīng)網(wǎng)絡(luò)應(yīng)用于輸入序列。
參數(shù):
input_size:輸入'x'中預(yù)期特性的數(shù)量
hidden_size:隱藏狀態(tài)'h'中的特性數(shù)量
num_layers:循環(huán)層的數(shù)量。例如,設(shè)置' ' num_layers=2 ' '意味著將兩個(gè)LSTM堆疊在一起,形成一個(gè)'堆疊的LSTM ',第二個(gè)LSTM接收第一個(gè)LSTM的輸出并計(jì)算最終結(jié)果。默認(rèn)值:1
bias:如果' False',則該層不使用偏置權(quán)重' b_ih '和' b_hh '。默認(rèn)值:'True'
batch_first:如果' 'True ' ',則輸入和輸出張量作為(batch, seq, feature)提供。默認(rèn)值: 'False'
dropout:如果非零,則在除最后一層外的每個(gè)LSTM層的輸出上引入一個(gè)“dropout”層,相當(dāng)于:attr:'dropout'。默認(rèn)值:0
bidirectional:如果‘True',則成為雙向LSTM。默認(rèn)值:'False'
輸入:input,(h_0, c_0)
**input**of shape (seq_len, batch, input_size):包含輸入序列特征的張量。輸入也可以是一個(gè)壓縮的可變長(zhǎng)度序列。
see:func:'torch.nn.utils.rnn.pack_padded_sequence' 或:func:'torch.nn.utils.rnn.pack_sequence' 的細(xì)節(jié)。
**h_0** of shape (num_layers * num_directions, batch, hidden_size):張量包含批處理中每個(gè)元素的初始隱藏狀態(tài)。
如果RNN是雙向的,num_directions應(yīng)該是2,否則應(yīng)該是1。
**c_0** of shape (num_layers * num_directions, batch, hidden_size):張量包含批處理中每個(gè)元素的初始單元格狀態(tài)。
如果沒有提供' (h_0, c_0) ',則**h_0**和**c_0**都默認(rèn)為零。
輸出:output,(h_n, c_n)
**output**of shape (seq_len, batch, num_directions * hidden_size) :包含LSTM最后一層輸出特征' (h_t) '張量,
對(duì)于每個(gè)t. If a:class: 'torch.nn.utils.rnn.PackedSequence' 已經(jīng)給出,輸出也將是一個(gè)打包序列。
對(duì)于未打包的情況,可以使用'output.view(seq_len, batch, num_directions, hidden_size)',正向和反向分別為方向' 0 '和' 1 '。
同樣,在包裝的情況下,方向可以分開。
**h_n** of shape (num_layers * num_directions, batch, hidden_size):包含' t = seq_len '隱藏狀態(tài)的張量。
與*output*類似, the layers可以使用以下命令分隔
h_n.view(num_layers, num_directions, batch, hidden_size) 對(duì)于'c_n'相似
**c_n** (num_layers * num_directions, batch, hidden_size):張量包含' t = seq_len '的單元狀態(tài)
所有的權(quán)重和偏差都初始化自: where:
include:: cudnn_persistent_rnn.rst
import torch import torch.nn as nn # 雙向rnn例子 # rnn = nn.RNN(10, 20, 2) # input = torch.randn(5, 3, 10) # h0 = torch.randn(2, 3, 20) # output, hn = rnn(input, h0) # print(output.shape,hn.shape) # torch.Size([5, 3, 20]) torch.Size([2, 3, 20]) # 雙向lstm例子 rnn = nn.LSTM(10, 20, 2) #(input_size,hidden_size,num_layers) input = torch.randn(5, 3, 10) #(seq_len, batch, input_size) h0 = torch.randn(2, 3, 20) #(num_layers * num_directions, batch, hidden_size) c0 = torch.randn(2, 3, 20) #(num_layers * num_directions, batch, hidden_size) # output:(seq_len, batch, num_directions * hidden_size) # hn,cn(num_layers * num_directions, batch, hidden_size) output, (hn, cn) = rnn(input, (h0, c0)) print(output.shape,hn.shape,cn.shape) >>>torch.Size([5, 3, 20]) torch.Size([2, 3, 20]) torch.Size([2, 3, 20])
以上這篇基于pytorch的lstm參數(shù)使用詳解就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
深入講解Python函數(shù)中參數(shù)的使用及默認(rèn)參數(shù)的陷阱
這篇文章主要介紹了Python函數(shù)中參數(shù)的使用及默認(rèn)參數(shù)的陷阱,文中將函數(shù)的參數(shù)分為必選參數(shù)、默認(rèn)參數(shù)、可變參數(shù)和關(guān)鍵字參數(shù)來講,要的朋友可以參考下2016-03-03Python3自定義http/https請(qǐng)求攔截mitmproxy腳本實(shí)例
這篇文章主要介紹了Python3自定義http/https請(qǐng)求攔截mitmproxy腳本實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2020-05-05python將類似json的數(shù)據(jù)存儲(chǔ)到MySQL中的實(shí)例
今天小編就為大家分享一篇python將類似json的數(shù)據(jù)存儲(chǔ)到MySQL中的實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2019-07-07Python中rapidjson參數(shù)校驗(yàn)實(shí)現(xiàn)
通常需要對(duì)前端傳遞過來的參數(shù)進(jìn)行校驗(yàn),校驗(yàn)的方式有多種,本文主要介紹了Python中rapidjson參數(shù)校驗(yàn)實(shí)現(xiàn),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2021-07-07基于Python實(shí)現(xiàn)實(shí)時(shí)監(jiān)控CPU使用率
這篇文章主要為大家介紹了一款手寫編程代碼的小腳本,能夠輕松在界面上展示:利用Python實(shí)時(shí)監(jiān)控CPU使用率,隨時(shí)展現(xiàn)。也無需下載管理軟件,感興趣的可以了解一下2022-04-04Python程序中使用SQLAlchemy時(shí)出現(xiàn)亂碼的解決方案
這篇文章主要介紹了Python程序中使用SQLAlchemy時(shí)出現(xiàn)亂碼的解決方案,SQLAlchemy是Python常用的操作MySQL數(shù)據(jù)庫的工具,需要的朋友可以參考下2015-04-04