Pytorch從0實(shí)現(xiàn)Transformer的實(shí)踐
摘要
With the continuous development of time series prediction, Transformer-like models have gradually replaced traditional models in the fields of CV and NLP by virtue of their powerful advantages. Among them, the Informer is far superior to the traditional RNN model in long-term prediction, and the Swin Transformer is significantly stronger than the traditional CNN model in image recognition. A deep grasp of Transformer has become an inevitable requirement in the field of artificial intelligence. This article will use the Pytorch framework to implement the position encoding, multi-head attention mechanism, self-mask, causal mask and other functions in Transformer, and build a Transformer network from 0.
隨著時(shí)序預(yù)測(cè)的不斷發(fā)展,Transformer類模型憑借強(qiáng)大的優(yōu)勢(shì),在CV、NLP領(lǐng)域逐漸取代傳統(tǒng)模型。其中Informer在長(zhǎng)時(shí)序預(yù)測(cè)上遠(yuǎn)超傳統(tǒng)的RNN模型,Swin Transformer在圖像識(shí)別上明顯強(qiáng)于傳統(tǒng)的CNN模型。深層次掌握Transformer已經(jīng)成為從事人工智能領(lǐng)域的必然要求。本文將用Pytorch框架,實(shí)現(xiàn)Transformer中的位置編碼、多頭注意力機(jī)制、自掩碼、因果掩碼等功能,從0搭建一個(gè)Transformer網(wǎng)絡(luò)。
一、構(gòu)造數(shù)據(jù)
1.1 句子長(zhǎng)度
# 關(guān)于word embedding,以序列建模為例 # 輸入句子有兩個(gè),第一個(gè)長(zhǎng)度為2,第二個(gè)長(zhǎng)度為4 src_len = torch.tensor([2, 4]).to(torch.int32) # 目標(biāo)句子有兩個(gè)。第一個(gè)長(zhǎng)度為4, 第二個(gè)長(zhǎng)度為3 tgt_len = torch.tensor([4, 3]).to(torch.int32) print(src_len) print(tgt_len)
輸入句子(src_len)有兩個(gè),第一個(gè)長(zhǎng)度為2,第二個(gè)長(zhǎng)度為4
目標(biāo)句子(tgt_len)有兩個(gè)。第一個(gè)長(zhǎng)度為4, 第二個(gè)長(zhǎng)度為3
1.2 生成句子
用隨機(jī)數(shù)生成句子,用0填充空白位置,保持所有句子長(zhǎng)度一致
src_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_src_words, (L, )), (0, max(src_len)-L)), 0) for L in src_len]) tgt_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_tgt_words, (L, )), (0, max(tgt_len)-L)), 0) for L in tgt_len]) print(src_seq) print(tgt_seq)
src_seq為輸入的兩個(gè)句子,tgt_seq為輸出的兩個(gè)句子。
為什么句子是數(shù)字?在做中英文翻譯時(shí),每個(gè)中文或英文對(duì)應(yīng)的也是一個(gè)數(shù)字,只有這樣才便于處理。
1.3 生成字典
在該字典中,總共有8個(gè)字(行),每個(gè)字對(duì)應(yīng)8維向量(做了簡(jiǎn)化了的)。注意在實(shí)際應(yīng)用中,應(yīng)當(dāng)有幾十萬(wàn)個(gè)字,每個(gè)字可能有512個(gè)維度。
# 構(gòu)造word embedding src_embedding_table = nn.Embedding(9, model_dim) tgt_embedding_table = nn.Embedding(9, model_dim) # 輸入單詞的字典 print(src_embedding_table) # 目標(biāo)單詞的字典 print(tgt_embedding_table)
字典中,需要留一個(gè)維度給class token,故是9行。
1.4 得到向量化的句子
通過(guò)字典取出1.2
中得到的句子
# 得到向量化的句子 src_embedding = src_embedding_table(src_seq) tgt_embedding = tgt_embedding_table(tgt_seq) print(src_embedding) print(tgt_embedding)
該階段總程序
import torch # 句子長(zhǎng)度 src_len = torch.tensor([2, 4]).to(torch.int32) tgt_len = torch.tensor([4, 3]).to(torch.int32) # 構(gòu)造句子,用0填充空白處 src_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, 8, (L, )), (0, max(src_len)-L)), 0) for L in src_len]) tgt_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, 8, (L, )), (0, max(tgt_len)-L)), 0) for L in tgt_len]) # 構(gòu)造字典 src_embedding_table = nn.Embedding(9, 8) tgt_embedding_table = nn.Embedding(9, 8) # 得到向量化的句子 src_embedding = src_embedding_table(src_seq) tgt_embedding = tgt_embedding_table(tgt_seq) print(src_embedding) print(tgt_embedding)
二、位置編碼
位置編碼是transformer的一個(gè)重點(diǎn),通過(guò)加入transformer位置編碼,代替了傳統(tǒng)RNN的時(shí)序信息,增強(qiáng)了模型的并發(fā)度。位置編碼的公式如下:(其中pos代表行,i代表列)
2.1 計(jì)算括號(hào)內(nèi)的值
# 得到分子pos的值 pos_mat = torch.arange(4).reshape((-1, 1)) # 得到分母值 i_mat = torch.pow(10000, torch.arange(0, 8, 2).reshape((1, -1))/8) print(pos_mat) print(i_mat)
2.2 得到位置編碼
# 初始化位置編碼矩陣 pe_embedding_table = torch.zeros(4, 8) # 得到偶數(shù)行位置編碼 pe_embedding_table[:, 0::2] =torch.sin(pos_mat / i_mat) # 得到奇數(shù)行位置編碼 pe_embedding_table[:, 1::2] =torch.cos(pos_mat / i_mat) pe_embedding = nn.Embedding(4, 8) # 設(shè)置位置編碼不可更新參數(shù) pe_embedding.weight = nn.Parameter(pe_embedding_table, requires_grad=False) print(pe_embedding.weight)
三、多頭注意力
3.1 self mask
有些位置是空白用0填充的,訓(xùn)練時(shí)不希望被這些位置所影響,那么就需要用到self mask。self mask的原理是令這些位置的值為無(wú)窮小,經(jīng)過(guò)softmax后,這些值會(huì)變?yōu)?,不會(huì)再影響結(jié)果。
3.1.1 得到有效位置矩陣
# 得到有效位置矩陣 vaild_encoder_pos = torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max(src_len) - L)), 0)for L in src_len]), 2) valid_encoder_pos_matrix = torch.bmm(vaild_encoder_pos, vaild_encoder_pos.transpose(1, 2)) print(valid_encoder_pos_matrix)
3.1.2 得到無(wú)效位置矩陣
invalid_encoder_pos_matrix = 1-valid_encoder_pos_matrix mask_encoder_self_attention = invalid_encoder_pos_matrix.to(torch.bool) print(mask_encoder_self_attention)
True
代表需要對(duì)該位置mask
3.1.3 得到mask矩陣
用極小數(shù)填充需要被mask的位置
# 初始化mask矩陣 score = torch.randn(2, max(src_len), max(src_len)) # 用極小數(shù)填充 mask_score = score.masked_fill(mask_encoder_self_attention, -1e9) print(mask_score)
算其softmat
mask_score_softmax = F.softmax(mask_score) print(mask_score_softmax)
可以看到,已經(jīng)達(dá)到預(yù)期效果
到此這篇關(guān)于Pytorch從0實(shí)現(xiàn)Transformer的實(shí)踐的文章就介紹到這了,更多相關(guān)Pytorch Transformer內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python實(shí)現(xiàn)根據(jù)指定端口探測(cè)服務(wù)器/模塊部署的方法
這篇文章主要介紹了Python根據(jù)指定端口探測(cè)服務(wù)器/模塊部署的方法,非常具有實(shí)用價(jià)值,需要的朋友可以參考下2014-08-08Python簡(jiǎn)單實(shí)現(xiàn)圖片轉(zhuǎn)字符畫的實(shí)例項(xiàng)目
這篇文章主要介紹了Python簡(jiǎn)單實(shí)現(xiàn)圖片轉(zhuǎn)字符畫的實(shí)例項(xiàng)目,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2021-04-04Python2 Selenium元素定位的實(shí)現(xiàn)(8種)
這篇文章主要介紹了Python2 Selenium元素定位的實(shí)現(xiàn),小編覺(jué)得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧2019-02-02詳解 Python 與文件對(duì)象共事的實(shí)例
這篇文章主要介紹了詳解 Python 與文件對(duì)象共事的實(shí)例的相關(guān)資料,希望通過(guò)本文大家能掌握這部分內(nèi)容,需要的朋友可以參考下2017-09-09利用4行Python代碼監(jiān)測(cè)每一行程序的運(yùn)行時(shí)間和空間消耗
這篇文章主要介紹了如何使用4行Python代碼監(jiān)測(cè)每一行程序的運(yùn)行時(shí)間和空間消耗,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-04-04