python量化之搭建Transformer模型用于股票價(jià)格預(yù)測(cè)
前言
下面的這篇文章主要教大家如何搭建一個(gè)基于Transformer的簡(jiǎn)單預(yù)測(cè)模型,并將其用于股票價(jià)格預(yù)測(cè)當(dāng)中。原代碼在文末進(jìn)行獲取。
1、Transformer模型
Transformer 是 Google 的團(tuán)隊(duì)在 2017 年提出的一種 NLP 經(jīng)典模型,現(xiàn)在比較火熱的 Bert 也是基于 Transformer。Transformer 模型使用了 Self-Attention 機(jī)制,不采用 RNN 的順序結(jié)構(gòu),使得模型可以并行化訓(xùn)練,而且能夠擁有全局信息。這篇文章的目的主要是將帶大家通過(guò)Pytorch框架搭建一個(gè)基于Transformer的簡(jiǎn)單股票價(jià)格預(yù)測(cè)模型。
Transformer的基本架構(gòu):
具體地,我們用到了上證指數(shù)的收盤(pán)價(jià)數(shù)據(jù)為例,進(jìn)行預(yù)測(cè)t+1時(shí)刻的收盤(pán)價(jià)。需要注意的是,本文只是通過(guò)這樣一個(gè)簡(jiǎn)單的基本模型,帶大家梳理一下數(shù)據(jù)預(yù)處理,模型構(gòu)建以及模型評(píng)估的流程。模型還有很多可以改進(jìn)的地方,例如選擇更有意義的特征,如何進(jìn)行有效的多步預(yù)測(cè)等。
2、環(huán)境準(zhǔn)備
本地環(huán)境:
Python 3.7
IDE:Pycharm
庫(kù)版本:
numpy 1.18.1
pandas 1.0.3
sklearn 0.22.2
matplotlib 3.2.1
torch 1.10.1
3、代碼實(shí)現(xiàn)
3.1. 導(dǎo)入庫(kù)以及定義超參
首先,需要導(dǎo)入用到庫(kù),以及模型的一些超參數(shù)的設(shè)置。其中,input_window和output_window分別用于設(shè)置輸入數(shù)據(jù)的長(zhǎng)度以及輸出數(shù)據(jù)的長(zhǎng)度。當(dāng)然,這些參數(shù)大家也可以根據(jù)實(shí)際應(yīng)用場(chǎng)景進(jìn)行修改。
import torch import torch.nn as nn import numpy as np import time import math import matplotlib.pyplot as plt from sklearn.preprocessing import MinMaxScaler import pandas as pd torch.manual_seed(0) np.random.seed(0) input_window = 20 output_window = 1 batch_size = 64 device = torch. device("cuda" if torch.cuda.is_available() else "cpu") print(device)
3.2. 模型構(gòu)建
Transformer中很重要的一個(gè)組件是提出了一種新的位置編碼的方式。我們知道,循環(huán)神經(jīng)網(wǎng)絡(luò)本身就是一種順序結(jié)構(gòu),天生就包含了詞在序列中的位置信息。當(dāng)拋棄循環(huán)神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu),完全采用Attention取而代之,這些詞序信息就會(huì)丟失,模型就沒(méi)有辦法知道每個(gè)詞在句子中的相對(duì)和絕對(duì)的位置信息。因此,有必要把詞序信號(hào)加到詞向量上幫助模型學(xué)習(xí)這些信息,位置編碼(PositionalEncoding)就是用來(lái)解決這種問(wèn)題的方法。它的原理是將生成的不同頻率的正弦和余弦數(shù)據(jù)作為位置編碼添加到輸入序列中,從而使得模型可以捕捉輸入變量的相對(duì)位置關(guān)系。
class PositionalEncoding(nn.Module): def __init__(self, d_model, max_len=5000): super(PositionalEncoding, self).__init__() pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0).transpose(0, 1) self.register_buffer('pe', pe) def forward(self, x): return x + self.pe[:x.size(0), :]
之后,搭建Transformer
的基本結(jié)構(gòu),在Pytorch中有已經(jīng)實(shí)現(xiàn)的封裝好的Transformer組件,可以很方便地進(jìn)行調(diào)用和修改。其中需要注意的是,文中并沒(méi)有采用原論文中的Encoder-Decoder的架構(gòu),而是將Decoder用了一個(gè)全連接層進(jìn)行代替,用于輸出預(yù)測(cè)值。另外,其中的create_mask將輸入進(jìn)行mask,從而避免引入未來(lái)信息。
class TransAm(nn.Module): def __init__(self, feature_size=250, num_layers=1, dropout=0.1): super(TransAm, self).__init__() self.model_type = 'Transformer' self.src_mask = None self.pos_encoder = PositionalEncoding(feature_size) self.encoder_layer = nn.TransformerEncoderLayer(d_model=feature_size, nhead=10, dropout=dropout) self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers) self.decoder = nn.Linear(feature_size, 1) self.init_weights() def init_weights(self): initrange = 0.1 self.decoder.bias.data.zero_() self.decoder.weight.data.uniform_(-initrange, initrange) def forward(self, src): if self.src_mask is None or self.src_mask.size(0) != len(src): device = src.device mask = self._generate_square_subsequent_mask(len(src)).to(device) self.src_mask = mask src = self.pos_encoder(src) output = self.transformer_encoder(src, self.src_mask) output = self.decoder(output) return output def _generate_square_subsequent_mask(self, sz): mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) return mask
3.3. 數(shù)據(jù)預(yù)處理
接下來(lái)需要對(duì)數(shù)據(jù)進(jìn)行預(yù)處理,首先定義一個(gè)窗口劃分的函數(shù)。它的作用是將輸入按照延遲output_windw的方式來(lái)劃分?jǐn)?shù)據(jù)以及其標(biāo)簽,文中是進(jìn)行單步預(yù)測(cè),所以假設(shè)輸入是1到20,則其標(biāo)簽就是2到21,以適應(yīng)Transformer
的seq2seq
的形式的輸出。
def create_inout_sequences(input_data, tw): inout_seq = [] L = len(input_data) for i in range(L - tw): train_seq = input_data[i:i + tw] train_label = input_data[i + output_window:i + tw + output_window] inout_seq.append((train_seq, train_label)) return torch.FloatTensor(inout_seq)
之后劃分訓(xùn)練集和測(cè)試集,其中前70%條數(shù)據(jù)用于模型訓(xùn)練,后面的數(shù)據(jù)用于模型測(cè)試。具體地,我們用到了前input_window個(gè)收盤(pán)價(jià)來(lái)預(yù)測(cè)下一時(shí)刻的收盤(pán)價(jià)數(shù)據(jù)。
def get_data(): series = pd.read_csv('./000001_Daily.csv', usecols=['Close']) # series = pd.read_csv('./daily-min-temperatures.csv', usecols=['Temp']) scaler = MinMaxScaler(feature_range=(-1, 1)) series = scaler.fit_transform(series.values.reshape(-1, 1)).reshape(-1) train_samples = int(0.7 * len(series)) train_data = series[:train_samples] test_data = series[train_samples:] train_sequence = create_inout_sequences(train_data, input_window) train_sequence = train_sequence[:-output_window] test_data = create_inout_sequences(test_data, input_window) test_data = test_data[:-output_window] return train_sequence.to(device), test_data.to(device)
接下來(lái)實(shí)現(xiàn)一個(gè)databatch generator
,便于從數(shù)據(jù)中按照batch的形式進(jìn)行讀取數(shù)據(jù)。
def get_batch(source, i, batch_size): seq_len = min(batch_size, len(source) - 1 - i) data = source[i:i + seq_len] input = torch.stack(torch.stack([item[0] for item in data]).chunk(input_window, 1)) target = torch.stack(torch.stack([item[1] for item in data]).chunk(input_window, 1)) return input, target
3.4. 模型訓(xùn)練以及評(píng)估
下面是模型訓(xùn)練的代碼。具體地,就是通過(guò)遍歷訓(xùn)練集,通過(guò)既定的loss,對(duì)參數(shù)進(jìn)行反向傳播,其中用到了梯度裁剪的技巧用于防止梯度爆炸,然后每間隔幾個(gè)間隔打印一下loss。
def train(train_data): model.train() for batch_index, i in enumerate(range(0, len(train_data) - 1, batch_size)): start_time = time.time() total_loss = 0 data, targets = get_batch(train_data, i, batch_size) optimizer.zero_grad() output = model(data) loss = criterion(output, targets) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 0.7) optimizer.step() total_loss += loss.item() log_interval = int(len(train_data) / batch_size / 5) if batch_index % log_interval == 0 and batch_index > 0: cur_loss = total_loss / log_interval elapsed = time.time() - start_time print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.6f} | {:5.2f} ms | loss {:5.5f} | ppl {:8.2f}' .format(epoch, batch_index, len(train_data) // batch_size, scheduler.get_lr()[0], elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss)))
接下來(lái)是對(duì)模型進(jìn)行評(píng)估的代碼。
def evaluate(eval_model, data_source): eval_model.eval() total_loss = 0 eval_batch_size = 1000 with torch.no_grad(): for i in range(0, len(data_source) - 1, eval_batch_size): data, targets = get_batch(data_source, i, eval_batch_size) output = eval_model(data) total_loss += len(data[0]) * criterion(output, targets).cpu().item() return total_loss / len(data_source)
最后,是模型運(yùn)行過(guò)程的可視化:
def plot_and_loss(eval_model, data_source, epoch): eval_model.eval() total_loss = 0. test_result = torch.Tensor(0) truth = torch.Tensor(0) with torch.no_grad(): for i in range(0, len(data_source) - 1): data, target = get_batch(data_source, i, 1) output = eval_model(data) total_loss += criterion(output, target).item() test_result = torch.cat((test_result, output[-1].view(-1).cpu()), 0) truth = torch.cat((truth, target[-1].view(-1).cpu()), 0) plt.plot(test_result, color="red") plt.plot(truth, color="blue") plt.grid(True, which='both') plt.axhline(y=0, color='k') plt.savefig('graph/transformer-epoch%d.png' % epoch) plt.close() return total_loss / i
3.5. 模型運(yùn)行
最后,對(duì)模型進(jìn)行運(yùn)行。其中用到了mse作為loss,adam作為優(yōu)化器,以及設(shè)定學(xué)習(xí)率的調(diào)度器,最后運(yùn)行200個(gè)epoch,每隔10個(gè)epoch在測(cè)試集上評(píng)估一下模型。
train_data, val_data = get_data() model = TransAm().to(device) criterion = nn.MSELoss() lr = 0.005 optimizer = torch.optim.AdamW(model.parameters(), lr=lr) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.95) epochs = 200 for epoch in range(1, epochs + 1): epoch_start_time = time.time() train(train_data) if (epoch % 10 is 0): val_loss = plot_and_loss(model, val_data, epoch) else: val_loss = evaluate(model, val_data) print('-' * 89) print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.5f} | valid ppl {:8.2f}'.format(epoch, ( time.time() - epoch_start_time), val_loss, math.exp(val_loss))) print('-' * 89) scheduler.step()
下面是運(yùn)行的結(jié)果,可以看到loss明顯降低了:
cuda| epoch 1 | 2/ 10 batches | lr 0.005000 | 7.83 ms | loss 39.99368 | ppl 233902099994043520.00| epoch 1 |
4/ 10 batches | lr 0.005000 | 7.81 ms | loss 7.20889 | ppl 1351.39| epoch 1 | 6/ 10 batches | lr 0.005000 | 11.10 ms | loss 1.68758 | ppl 5.41| epoch 1 |
8/ 10 batches | lr 0.005000 | 9.35 ms | loss 0.00833 | ppl 1.01| epoch 1 | 10/ 10 batches | lr 0.005000 | 7.81 ms | loss 1.18041 | ppl 3.26-----------------------------------------------------------------------------------------| end of epoch 1 | time: 1.96s | valid loss 2.58557 | valid ppl 13.27
...
| end of epoch 198 | time: 0.30s | valid loss 0.00032 | valid ppl 1.00-----------------------------------------------------------------------------------------| epoch 199 |
2/ 10 batches | lr 0.000000 | 15.62 ms | loss 0.00057 | ppl 1.00| epoch 199 | 4/ 10 batches | lr 0.000000 | 15.62 ms | loss 0.00184 | ppl 1.00| epoch 199 |
6/ 10 batches | lr 0.000000 | 15.62 ms | loss 0.00212 | ppl 1.00| epoch 199 | 8/ 10 batches | lr 0.000000 | 7.81 ms | loss 0.00073 | ppl 1.00| epoch 199 | 10/ 10 batches | lr 0.000000 |
7.81 ms | loss 0.00057 | ppl 1.00-----------------------------------------------------------------------------------------| end of epoch 199 | time: 0.30s | valid loss 0.00032 | valid ppl 1.00-----------------------------------------------------------------------------------------| epoch 200 | 2/ 10 batches | lr 0.000000 | 15.62 ms | loss 0.00053 | ppl 1.00| epoch 200 |
4/ 10 batches | lr 0.000000 | 7.81 ms | loss 0.00177 | ppl
1.00| epoch 200 | 6/ 10 batches | lr 0.000000 | 7.81 ms | loss 0.00224 | ppl 1.00| epoch 200 | 8/ 10 batches | lr 0.000000 | 15.62 ms | loss 0.00069 | ppl 1.00| epoch 200 | 10/ 10 batches | lr 0.000000 | 7.81 ms | loss 0.00049 | ppl 1.00-----------------------------------------------------------------------------------------| end of epoch 200 | time: 0.62s | valid loss 0.00032 | valid ppl
1.00-----------------------------------------------------------------------------------------
最后是模型的擬合效果,從實(shí)驗(yàn)結(jié)果中可以看出我們搭建的簡(jiǎn)單的Transformer模型可以實(shí)現(xiàn)相對(duì)不錯(cuò)的數(shù)據(jù)擬合效果。
4、總結(jié)
在這篇文章中,我們介紹了如何基于Pytorch框架搭建一個(gè)基于Transformer的股票預(yù)測(cè)模型,并通過(guò)真實(shí)股票數(shù)據(jù)對(duì)模型進(jìn)行了實(shí)驗(yàn),可以看出Transformer模型對(duì)股價(jià)預(yù)測(cè)具有一定的效果。另外,文中只是做了一個(gè)簡(jiǎn)單的demo,其中仍然有很多可以改進(jìn)的地方,如采用更多有意義的輸入數(shù)據(jù),優(yōu)化其中的一些組件等。除此之外,目前基于Transformer的模型層出不窮,其中也有很多值得我們?nèi)W(xué)習(xí),大家也可以采用更先進(jìn)的Transformer模型進(jìn)行試驗(yàn)。
到此這篇關(guān)于python量化之搭建Transformer模型用于股票價(jià)格預(yù)測(cè)的文章就介紹到這了,更多相關(guān)python搭建Transformer模型內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python 模擬購(gòu)物車(chē)的實(shí)例講解
下面小編就為大家?guī)?lái)一篇Python 模擬購(gòu)物車(chē)的實(shí)例講解。小編覺(jué)得挺不錯(cuò)的,現(xiàn)在就分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧2017-09-09Python虛擬機(jī)棧幀對(duì)象及獲取源碼學(xué)習(xí)
這篇文章主要為大家介紹了Python虛擬機(jī)棧幀對(duì)象及獲取源碼學(xué)習(xí),有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-03-03python 爬取京東指定商品評(píng)論并進(jìn)行情感分析
本文主要講述了利用Python網(wǎng)絡(luò)爬蟲(chóng)對(duì)指定京東商城中指定商品下的用戶評(píng)論進(jìn)行爬取,對(duì)數(shù)據(jù)預(yù)處理操作后進(jìn)行文本情感分析,感興趣的朋友可以了解下2021-05-05Keras使用預(yù)訓(xùn)練模型遷移學(xué)習(xí)單通道灰度圖像詳解
這篇文章主要介紹了Keras使用預(yù)訓(xùn)練模型遷移學(xué)習(xí)單通道灰度圖像詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2024-02-02python 根據(jù)列表批量下載網(wǎng)易云音樂(lè)的免費(fèi)音樂(lè)
這篇文章主要介紹了python 根據(jù)列表下載網(wǎng)易云音樂(lè)的免費(fèi)音樂(lè),幫助大家更好的理解和學(xué)習(xí)python,感興趣的朋友可以了解下2020-12-12Python模擬登錄requests.Session應(yīng)用詳解
這篇文章主要介紹了Python模擬登錄requests.Session應(yīng)用詳解,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-11-11python數(shù)字圖像處理skimage讀取顯示與保存圖片
這篇文章主要為大家介紹了python數(shù)字圖像處理使用skimage讀取顯示與保存圖片示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2022-06-06通過(guò)Python實(shí)現(xiàn)對(duì)SQL Server 數(shù)據(jù)文件大小的監(jiān)控告警功能
這篇文章主要介紹了通過(guò)Python實(shí)現(xiàn)對(duì)SQL Server 數(shù)據(jù)文件大小的監(jiān)控告警,本文給大家分享問(wèn)題報(bào)錯(cuò)信息及解決方案,需要的朋友可以參考下2021-04-04