亚洲乱码中文字幕综合,中国熟女仑乱hd,亚洲精品乱拍国产一区二区三区,一本大道卡一卡二卡三乱码全集资源,又粗又黄又硬又爽的免费视频

pytorch實(shí)現(xiàn)seq2seq時(shí)對loss進(jìn)行mask的方式

 更新時(shí)間:2020年02月18日 09:55:43   作者:uhauha2929  
今天小編就為大家分享一篇pytorch實(shí)現(xiàn)seq2seq時(shí)對loss進(jìn)行mask的方式,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧

如何對loss進(jìn)行mask

pytorch官方教程中有一個(gè)Chatbot教程,就是利用seq2seq和注意力機(jī)制實(shí)現(xiàn)的,感覺和機(jī)器翻譯沒什么不同啊,如果對話中一句話有下一句,那么就把這一對句子加入模型進(jìn)行訓(xùn)練。其中在訓(xùn)練階段,損失函數(shù)通常需要進(jìn)行mask操作,因?yàn)橐粋€(gè)batch中句子的長度通常是不一樣的,一個(gè)batch中不足長度的位置需要進(jìn)行填充(pad)補(bǔ)0,最后生成句子計(jì)算loss時(shí)需要忽略那些原本是pad的位置的值,即只保留mask中值為1位置的值,忽略值為0位置的值,具體演示如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PAD_token = 0

首先是pad函數(shù)和建立mask矩陣,矩陣的維度應(yīng)該和目標(biāo)一致。

def zeroPadding(l, fillvalue=PAD_token):
 # 輸入:[[1, 1, 1], [2, 2], [3]]
 # 返回:[(1, 2, 3), (1, 2, 0), (1, 0, 0)] 返回已經(jīng)是轉(zhuǎn)置后的 [L, B]
 return list(itertools.zip_longest(*l, fillvalue=fillvalue))


def binaryMatrix(l):
 # 將targets里非pad部分標(biāo)記為1,pad部分標(biāo)記為0
 m = []
 for i, seq in enumerate(l):
 m.append([])
 for token in seq:
  if token == PAD_token:
  m[i].append(0)
  else:
  m[i].append(1)
 return m

假設(shè)現(xiàn)在輸入一個(gè)batch中有三個(gè)句子,我們按照長度從大到小排好序,LSTM或是GRU的輸入和輸出我們需要利用pack_padded_sequence和pad_packed_sequence進(jìn)行打包和解包,感覺也是在進(jìn)行mask操作。

inputs = [[1, 2, 3], [4, 5], [6]] # 輸入句,一個(gè)batch,需要按照長度從大到小排好序
inputs_lengths = [3, 2, 1]
targets = [[1, 2], [1, 2, 3], [1]] # 目標(biāo)句,這里的長度是不確定的,mask是針對targets的
inputs_batch = torch.LongTensor(zeroPadding(inputs))
inputs_lengths = torch.LongTensor(inputs_lengths)
targets_batch = torch.LongTensor(zeroPadding(targets))
targets_mask = torch.ByteTensor(binaryMatrix(zeroPadding(targets))) # 注意這里是ByteTensor
print(inputs_batch)
print(targets_batch)
print(targets_mask)

打印后結(jié)果如下,可見維度統(tǒng)一變成了[L, B],并且mask和target長得一樣。另外,seq2seq模型處理時(shí)for循環(huán)每次讀取一行,預(yù)測下一行的值(即[B, L]時(shí)的一列預(yù)測下一列)。

tensor([[ 1, 4, 6],
 [ 2, 5, 0],
 [ 3, 0, 0]])
tensor([[ 1, 1, 1],
 [ 2, 2, 0],
 [ 0, 3, 0]])
tensor([[ 1, 1, 1],
 [ 1, 1, 0],
 [ 0, 1, 0]], dtype=torch.uint8)

現(xiàn)在假設(shè)我們將inputs輸入模型后,模型讀入sos后預(yù)測的第一行為outputs1, 維度為[B, vocab_size],即每個(gè)詞在詞匯表中的概率,模型輸出之前需要softmax。

outputs1 = torch.FloatTensor([[0.2, 0.1, 0.7], [0.3, 0.6, 0.1], [0.4, 0.5, 0.1]])
print(outputs1)
tensor([[ 0.2000, 0.1000, 0.7000],
 [ 0.3000, 0.6000, 0.1000],
 [ 0.4000, 0.5000, 0.1000]])

先看看兩個(gè)函數(shù)

torch.gather(input, dim, index, out=None)->Tensor

沿著某個(gè)軸,按照指定維度采集數(shù)據(jù),對于3維數(shù)據(jù),相當(dāng)于進(jìn)行如下操作:

out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2

比如在這里,在第1維,選第二個(gè)元素。

# 收集每行的第2個(gè)元素
temp = torch.gather(outputs1, 1, torch.LongTensor([[1], [1], [1]]))
print(temp)
tensor([[ 0.1000],
 [ 0.6000],
 [ 0.5000]])

torch.masked_select(input, mask, out=None)->Tensor

根據(jù)mask(ByteTensor)選取對應(yīng)位置的值,返回一維張量。

例如在這里我們選取temp大于等于0.5的值。

mask = temp.ge(0.5) # 大于等于0.5
print(mask)
print(torch.masked_select(temp, temp.ge(0.5)))
tensor([[ 0],
 [ 1],
 [ 1]], dtype=torch.uint8)
tensor([ 0.6000, 0.5000])

然后我們就可以計(jì)算loss了,這里是負(fù)對數(shù)損失函數(shù),之前模型的輸出要進(jìn)行softmax。

# 計(jì)算一個(gè)batch內(nèi)的平均負(fù)對數(shù)似然損失,即只考慮mask為1的元素
def maskNLLLoss(inp, target, mask):
 nTotal = mask.sum()
 # 收集目標(biāo)詞的概率,并取負(fù)對數(shù)
 crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)))
 # 只保留mask中值為1的部分,并求均值
 loss = crossEntropy.masked_select(mask).mean()
 loss = loss.to(DEVICE)
 return loss, nTotal.item()

這里我們計(jì)算第一行的平均損失。

# 計(jì)算預(yù)測的第一行和targets的第一行的loss
maskNLLLoss(outputs1, targets_batch[0], targets_mask[0])

(tensor(1.1689, device='cuda:0'), 3)

最后進(jìn)行最后把所有行的loss累加起來變?yōu)閠otal_loss.backward()進(jìn)行反向傳播就可以了。

以上這篇pytorch實(shí)現(xiàn)seq2seq時(shí)對loss進(jìn)行mask的方式就是小編分享給大家的全部內(nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • python3學(xué)習(xí)之Splash的安裝與實(shí)例教程

    python3學(xué)習(xí)之Splash的安裝與實(shí)例教程

    splash 是一個(gè)python語言編寫的用于配合scrapy解析js的庫,下面這篇文章主要給大家介紹了關(guān)于python3學(xué)習(xí)之Splash的安裝與使用的一些相關(guān)資料,文中通過示例代碼介紹的非常詳細(xì),需要的朋友可以參考下
    2018-07-07
  • Python 堆疊柱狀圖繪制方法

    Python 堆疊柱狀圖繪制方法

    這篇文章主要介紹了Python 堆疊柱狀圖繪制方法,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2019-07-07
  • PyQt Qt Designer工具的布局管理詳解

    PyQt Qt Designer工具的布局管理詳解

    這篇文章主要介紹了PyQt Qt Designer工具的布局管理詳解,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下
    2019-08-08
  • Python自動(dòng)化操作Excel方法詳解(xlrd,xlwt)

    Python自動(dòng)化操作Excel方法詳解(xlrd,xlwt)

    Excel是Windows環(huán)境下流行的、強(qiáng)大的電子表格應(yīng)用。本文將詳解用Python利用xlrd和xlwt實(shí)現(xiàn)自動(dòng)化操作Excel的方法詳細(xì),需要的可以參考一下
    2022-06-06
  • Python機(jī)器學(xué)習(xí)logistic回歸代碼解析

    Python機(jī)器學(xué)習(xí)logistic回歸代碼解析

    這篇文章主要介紹了Python機(jī)器學(xué)習(xí)logistic回歸代碼解析,具有一定借鑒價(jià)值,需要的朋友可以參考下
    2018-01-01
  • python如何將兩張圖片生成為全景圖片

    python如何將兩張圖片生成為全景圖片

    這篇文章主要為大家詳細(xì)介紹了python如何將兩張圖片生成為全景圖片,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2020-03-03
  • matplotlib畫圖之修改坐標(biāo)軸刻度問題

    matplotlib畫圖之修改坐標(biāo)軸刻度問題

    這篇文章主要介紹了matplotlib畫圖之修改坐標(biāo)軸刻度問題,具有很好的參考價(jià)值,希望對大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教
    2022-11-11
  • 簡單了解Python中的幾種函數(shù)

    簡單了解Python中的幾種函數(shù)

    這篇文章主要介紹了簡單了解Python中的幾種函數(shù),具有一定參考價(jià)值。需要的朋友可以了解下。
    2017-11-11
  • Pycharm新建項(xiàng)目時(shí)報(bào)錯(cuò)解決辦法

    Pycharm新建項(xiàng)目時(shí)報(bào)錯(cuò)解決辦法

    pycharm可以很方便的管理Python的解釋器(如果安裝了多個(gè)的話),以及第三方模塊,包,下面這篇文章主要給大家介紹了關(guān)于Pycharm新建項(xiàng)目時(shí)報(bào)錯(cuò)解決的相關(guān)資料,需要的朋友可以參考下
    2023-06-06
  • pytorch實(shí)現(xiàn)保證每次運(yùn)行使用的隨機(jī)數(shù)都相同

    pytorch實(shí)現(xiàn)保證每次運(yùn)行使用的隨機(jī)數(shù)都相同

    今天小編就為大家分享一篇pytorch實(shí)現(xiàn)保證每次運(yùn)行使用的隨機(jī)數(shù)都相同,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2020-02-02

最新評論