Pytorch使用技巧之Dataloader中的collate_fn參數(shù)詳析
以MNIST為例
from torchvision import datasets mnist = datasets.MNIST(root='./data/', train=True, download=True) print(mnist[0])
結(jié)果
(<PIL.Image.Image image mode=L size=28x28 at 0x196E3F1D898>, 5)
MINIST數(shù)據(jù)集的dataset是由一張圖片和一個(gè)label組成的元組
dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=2, shuffle=True,collate_fn=lambda x:x) for each in dataloader: print(each) break
結(jié)果
[(<PIL.Image.Image image mode=L size=28x28 at 0x2CB3B105630>, 0), (<PIL.Image.Image image mode=L size=28x28 at 0x2CB3B105668>, 2)]
collate_fn為lamda x:x時(shí)表示對(duì)傳入進(jìn)來的數(shù)據(jù)不做處理
下面自定義collate_fn看看什么效果
def collate(data): img = [] label = [] for each in data: img.append(each[0]) label.append(each[1]) return img,label dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=2, shuffle=True,collate_fn=lambda x:collate(x)) for each in dataloader: print(each) break
結(jié)果
([<PIL.Image.Image image mode=L size=28x28 at 0x241433A36D8>, <PIL.Image.Image image mode=L size=28x28 at 0x241433A3710>], [9, 3])
說明:若不設(shè)置collate_fn參數(shù)則會(huì)使用默認(rèn)處理函數(shù)
但必須保證傳進(jìn)來的數(shù)據(jù)都是tensor格式否則會(huì)報(bào)錯(cuò)
附:DataLoader完整的參數(shù)表如下:
class torch.utils.data.DataLoader( dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
DataLoader在數(shù)據(jù)集上提供單進(jìn)程或多進(jìn)程的迭代器
幾個(gè)關(guān)鍵的參數(shù)意思:
- shuffle:設(shè)置為True的時(shí)候,每個(gè)世代都會(huì)打亂數(shù)據(jù)集
- collate_fn:如何取樣本的,我們可以定義自己的函數(shù)來準(zhǔn)確地實(shí)現(xiàn)想要的功能
- drop_last:告訴如何處理數(shù)據(jù)集長(zhǎng)度除于batch_size余下的數(shù)據(jù)。True就拋棄,否則保留
總結(jié)
到此這篇關(guān)于Pytorch使用技巧之Dataloader中的collate_fn參數(shù)的文章就介紹到這了,更多相關(guān)Dataloader中的collate_fn參數(shù)內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python算法應(yīng)用實(shí)戰(zhàn)之隊(duì)列詳解
隊(duì)列是一種先進(jìn)先出(First-In-First-Out,F(xiàn)IFO)的數(shù)據(jù)結(jié)構(gòu)。隊(duì)列被用在很多地方,比如提交操作系統(tǒng)執(zhí)行的一系列進(jìn)程、打印任務(wù)池等,一些仿真系統(tǒng)用隊(duì)列來模擬銀行或雜貨店里排隊(duì)的顧客。下面就介紹了Python中隊(duì)列的應(yīng)用實(shí)戰(zhàn),需要的可以參考。2017-02-02pytorch通過自己的數(shù)據(jù)集訓(xùn)練Unet網(wǎng)絡(luò)架構(gòu)
Unet是一個(gè)最近比較火的網(wǎng)絡(luò)結(jié)構(gòu)。它的理論已經(jīng)有很多大佬在討論了。本文主要從實(shí)際操作的層面,講解如何使用pytorch實(shí)現(xiàn)unet圖像分割2022-12-12Python基礎(chǔ)知識(shí)方法重寫+文件處理+異常處理
這篇文章主要介紹了Python基礎(chǔ)知識(shí)方法重寫+文件處理+異常處理,這是基礎(chǔ)知識(shí)分享的第四篇,看到這里了相信大家前幾篇都學(xué)得還不錯(cuò)吧,下面我們繼續(xù)鞏固Python基礎(chǔ)知識(shí),需要的朋友也可以參考一下2022-05-05在Python中通過threshold創(chuàng)建mask方式
今天小編就為大家分享一篇在Python中通過threshold創(chuàng)建mask方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2020-02-02pandas基礎(chǔ)?Series與Dataframe與numpy對(duì)二進(jìn)制文件輸入輸出
這篇文章主要介紹了pandas基礎(chǔ)Series與Dataframe與numpy對(duì)二進(jìn)制文件輸入輸出,series是一種一維的數(shù)組型對(duì)象,它包含了一個(gè)值序列和一個(gè)數(shù)據(jù)標(biāo)簽2022-07-07