Pytorch DataLoader 變長(zhǎng)數(shù)據(jù)處理方式
關(guān)于Pytorch中怎么自定義Dataset數(shù)據(jù)集類、怎樣使用DataLoader迭代加載數(shù)據(jù),這篇官方文檔已經(jīng)說(shuō)得很清楚了,這里就不在贅述。
現(xiàn)在的問(wèn)題:有的時(shí)候,特別對(duì)于NLP任務(wù)來(lái)說(shuō),輸入的數(shù)據(jù)可能不是定長(zhǎng)的,比如多個(gè)句子的長(zhǎng)度一般不會(huì)一致,這時(shí)候使用DataLoader加載數(shù)據(jù)時(shí),不定長(zhǎng)的句子會(huì)被胡亂切分,這肯定是不行的。
解決方法是重寫DataLoader的collate_fn,具體方法如下:
# 假如每一個(gè)樣本為: sample = { # 一個(gè)句子中各個(gè)詞的id 'token_list' : [5, 2, 4, 1, 9, 8], # 結(jié)果y 'label' : 5, } # 重寫collate_fn函數(shù),其輸入為一個(gè)batch的sample數(shù)據(jù) def collate_fn(batch): # 因?yàn)閠oken_list是一個(gè)變長(zhǎng)的數(shù)據(jù),所以需要用一個(gè)list來(lái)裝這個(gè)batch的token_list token_lists = [item['token_list'] for item in batch] # 每個(gè)label是一個(gè)int,我們把這個(gè)batch中的label也全取出來(lái),重新組裝 labels = [item['label'] for item in batch] # 把labels轉(zhuǎn)換成Tensor labels = torch.Tensor(labels) return { 'token_list': token_lists, 'label': labels, } # 在使用DataLoader加載數(shù)據(jù)時(shí),注意collate_fn參數(shù)傳入的是重寫的函數(shù) DataLoader(trainset, batch_size=4, shuffle=True, num_workers=4, collate_fn=collate_fn)
使用以上方法,可以保證DataLoader能Load出一個(gè)batch的數(shù)據(jù),load出來(lái)的東西就是重寫的collate_fn函數(shù)最后return出來(lái)的字典。
以上這篇Pytorch DataLoader 變長(zhǎng)數(shù)據(jù)處理方式就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
快速上手基于Anaconda搭建Django環(huán)境的教程
Django具有完整的封裝,開發(fā)者可以高效率的開發(fā)項(xiàng)目,Django將大部分的功能進(jìn)行了封裝,開發(fā)者只需要調(diào)用即可,接下來(lái)通過(guò)本文給大家介紹基于Anaconda搭建Django環(huán)境的教程,需要的朋友可以參考下2021-10-10python簡(jiǎn)單圖片操作:打開\顯示\保存圖像方法介紹
這篇文章主要介紹了python簡(jiǎn)單圖片操作:打開\顯示\保存圖像方法介紹,還涉及將圖片保存為灰度圖的簡(jiǎn)單方法示例,具有一定參考價(jià)值,需要的朋友可以了解下。2017-11-11Python shelve模塊實(shí)現(xiàn)解析
這篇文章主要介紹了Python shelve模塊實(shí)現(xiàn)解析,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-08-08