pytorch-RNN進(jìn)行回歸曲線預(yù)測(cè)方式
任務(wù)
通過(guò)輸入的sin曲線與預(yù)測(cè)出對(duì)應(yīng)的cos曲線
#初始加載包 和定義參數(shù) import torch from torch import nn import numpy as np import matplotlib.pyplot as plt torch.manual_seed(1) #為了可復(fù)現(xiàn) #超參數(shù)設(shè)定 TIME_SETP=10 INPUT_SIZE=1 LR=0.02 DOWNLoad_MNIST=True
定義RNN網(wǎng)絡(luò)結(jié)構(gòu)
from torch.autograd import Variable
class RNN(nn.Module):
def __init__(self):
#在這個(gè)函數(shù)中,兩步走,先init,再逐步定義層結(jié)構(gòu)
super(RNN,self).__init__()
self.rnn=nn.RNN( #定義32隱層的rnn結(jié)構(gòu)
input_size=1,
hidden_size=32, #隱層有32個(gè)記憶體
num_layers=1, #隱層層數(shù)是1
batch_first=True
)
self.out=nn.Linear(32,1) #32個(gè)記憶體對(duì)應(yīng)一個(gè)輸出
def forward(self,x,h_state):
#前向過(guò)程,獲取 rnn網(wǎng)絡(luò)輸出r_put(注意這里r_out并不是最后輸出,最后要經(jīng)過(guò)全連接層) 和 記憶體情況h_state
r_out,h_state=self.rnn(x,h_state)
outs=[]#獲取所有時(shí)間點(diǎn)下得到的預(yù)測(cè)值
for time_step in range(r_out.size(1)): #將記憶rnn層的輸出傳到全連接層來(lái)得到最終輸出。 這樣每個(gè)輸入對(duì)應(yīng)一個(gè)輸出,所以會(huì)有長(zhǎng)度為10的輸出
outs.append(self.out(r_out[:,time_step,:]))
return torch.stack(outs,dim=1),h_state #將10個(gè)數(shù) 通過(guò)stack方式壓縮在一起
rnn=RNN()
print('RNN的網(wǎng)絡(luò)體系結(jié)構(gòu)為:',rnn)

創(chuàng)建數(shù)據(jù)集及網(wǎng)絡(luò)訓(xùn)練
以sin曲線為特征,以cos曲線為標(biāo)簽進(jìn)行網(wǎng)絡(luò)的訓(xùn)練
#定義優(yōu)化器和 損失函數(shù)
optimizer=torch.optim.Adam(rnn.parameters(),lr=LR)
loss_fun=nn.MSELoss()
h_state=None #記錄的隱藏層狀態(tài),記住這就是記憶體,初始時(shí)候?yàn)榭?,之后每次后面的都?huì)使用到前面的記憶,自動(dòng)生成全0的
#這樣加入記憶信息后,每次都會(huì)在之前的記憶矩陣基礎(chǔ)上再進(jìn)行新的訓(xùn)練,初始是全0的形式。
#啟動(dòng)訓(xùn)練,這里假定訓(xùn)練的批次為100次
plt.ion() #可以設(shè)定持續(xù)不斷的繪圖,但是在這里看還是間斷的,這是jupyter的問(wèn)題
for step in range(100):
#我們以一個(gè)π為一個(gè)時(shí)間步 定義數(shù)據(jù),
start,end=step*np.pi,(step+1)*np.pi
steps=np.linspace(start,end,10,dtype=np.float32) #注意這里的10并不是間隔為10,而是將數(shù)按范圍分成10等分了
x_np=np.sin(steps)
y_np=np.cos(steps)
#將numpy類型轉(zhuǎn)成torch類型 *****當(dāng)需要 求梯度時(shí),一個(gè) op 的兩個(gè)輸入都必須是要 Variable,輸入的一定要variable包下
x=Variable(torch.from_numpy(x_np[np.newaxis,:,np.newaxis]))#增加兩個(gè)維度,是三維的數(shù)據(jù)。
y=Variable(torch.from_numpy(y_np[np.newaxis,:,np.newaxis]))
#將每個(gè)時(shí)間步上的10個(gè)值 輸入到rnn獲得結(jié)果 這里rnn會(huì)自動(dòng)執(zhí)行forward前向過(guò)程. 這里輸入時(shí)10個(gè),輸出也是10個(gè),傳遞的是一個(gè)長(zhǎng)度為32的記憶體
predition,h_state=rnn(x,h_state)
#更新新的中間狀態(tài)
h_state=Variable(h_state.data) #擦,這點(diǎn)一定要從新包裝
loss=loss_fun(predition,y)
#print('loss:',loss)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# plotting 畫圖,這里先平展了 flatten,這樣就是得到一個(gè)數(shù)組,更加直接
plt.plot(steps, y_np.flatten(), 'r-')
plt.plot(steps, predition.data.numpy().flatten(), 'b-')
#plt.draw();
plt.pause(0.05)
plt.ioff() #關(guān)閉交互模式
plt.show()

以上這篇pytorch-RNN進(jìn)行回歸曲線預(yù)測(cè)方式就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
使用python讀取csv文件快速插入數(shù)據(jù)庫(kù)的實(shí)例
今天小編就為大家分享一篇使用python讀取csv文件快速插入數(shù)據(jù)庫(kù)的實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-06-06
解決Alexnet訓(xùn)練模型在每個(gè)epoch中準(zhǔn)確率和loss都會(huì)一升一降問(wèn)題
這篇文章主要介紹了解決Alexnet訓(xùn)練模型在每個(gè)epoch中準(zhǔn)確率和loss都會(huì)一升一降問(wèn)題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-06-06
Python中url標(biāo)簽使用知識(shí)點(diǎn)總結(jié)
這篇文章主要介紹了Python中url標(biāo)簽使用知識(shí)點(diǎn)以及相關(guān)實(shí)例代碼,需要的朋友們參考下。2020-01-01
Pycharm?cannot?set?up?a?python?SDK問(wèn)題的原因及解決方法
這篇文章主要給大家介紹了關(guān)于Pycharm?cannot?set?up?a?python?SDK問(wèn)題的原因及解決方法,這個(gè)問(wèn)題已經(jīng)不是第一次出現(xiàn)了,所以干脆總結(jié)下,需要的朋友可以參考下2022-06-06
Python統(tǒng)計(jì)字符內(nèi)容的占比的實(shí)現(xiàn)
本文介紹了如何使用Python統(tǒng)計(jì)字符占比,包括字符串中字母、數(shù)字、空格等字符的占比,對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2023-08-08
Pytorch使用技巧之Dataloader中的collate_fn參數(shù)詳析
collate_fn 參數(shù)的目的主要是為了隨心所欲的轉(zhuǎn)變數(shù)據(jù)的類型,這個(gè)數(shù)據(jù)是用DataLoader加載的,比如img,target,下面這篇文章主要給大家介紹了關(guān)于Pytorch使用技巧之Dataloader中的collate_fn參數(shù)的相關(guān)資料,需要的朋友可以參考下2022-03-03

