Pytorch使用DataLoader實現(xiàn)批量加載數(shù)據(jù)
在進行模型訓(xùn)練時,需要把數(shù)據(jù)按照固定的形式分批次投喂給模型,在PyTorch中通過torch.utils.data庫的DataLoader
完成分批次返回數(shù)據(jù)。
構(gòu)造DataLoader首先需要一個Dataset
數(shù)據(jù)源,Dataset完成數(shù)據(jù)的讀取并可以返回單個數(shù)據(jù),然后DataLoader在此基礎(chǔ)上完成數(shù)據(jù)清洗、打亂等操作并按批次返回數(shù)據(jù)。
Dataset
PyTorch將數(shù)據(jù)源分為兩種類型:類似Map型(Map-style datasets)和可迭代型(Iterable-style datasets)。
Map風(fēng)格的數(shù)據(jù)源可以通過索引idx對數(shù)據(jù)進行查找:dataset[idx]
,它需要繼承Dataset
類,并且重寫__getitem__()
方法完成根據(jù)索引值獲取數(shù)據(jù)和__len__()
方法返回數(shù)據(jù)的總長度。
可迭代型可以迭代獲取其數(shù)據(jù),但沒有固定的長度,因此也不能通過下標獲得數(shù)據(jù),通常用于無法獲取全部數(shù)據(jù)或者流式返回的數(shù)據(jù)。它繼承自IterableDataset
類,并且需要實現(xiàn)__iter__()
方法來完成對數(shù)據(jù)集的迭代和返回。
如下所示為自定義的數(shù)據(jù)源MySet
,它完成數(shù)據(jù)的讀取,這里假定為[1, 9] 9個數(shù)據(jù),然后重寫了__getitem__() 和__len__() 方法
from torch.utils.data import Dataset, DataLoader, Sampler class MySet(Dataset): # 讀取數(shù)據(jù) def __init__(self): self.data = [1, 2, 3, 4, 5, 6, 7, 8, 9] # 根據(jù)索引返回數(shù)據(jù) def __getitem__(self, idx): return self.data[idx] # 返回數(shù)據(jù)集總長度 def __len__(self): return len(self.data)
DataLoader
其構(gòu)造函數(shù)如下:
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)
dataset
:Dataset類型,從其中加載數(shù)據(jù) batch_size:int,可選。每個batch加載多少樣本batch_size
: 一個批次的數(shù)據(jù)個數(shù)shuffle
:bool,可選。為True時表示每個epoch都對數(shù)據(jù)進行洗牌sampler
:Sampler,可選。獲取下一個數(shù)據(jù)的方法。batch_sampler
:獲取下一批次數(shù)據(jù)的方法num_workers
:int,可選。加載數(shù)據(jù)時使用多少子進程。默認值為0,表示在主進程中加載數(shù)據(jù)。collate_fn
:callable,可選,自定義處理數(shù)據(jù)并返回。pin_memory
:bool,可選,True代表將數(shù)據(jù)Tensor放入CUDA的pin儲存drop_last
:bool,可選。True表示如果最后剩下不完全的batch,丟棄。False表示不丟棄。
Sampler索引
既然DataLoader根據(jù)索引值從Dataset中獲取數(shù)據(jù),那么如何獲取一個批次數(shù)據(jù)的索引,索引值應(yīng)該如何排列才能實現(xiàn)隨機的效果?這就需要Sampler
了,它可以對索引進行shuffle操作來打亂順序,并且根據(jù)batch size一次返回指定個數(shù)的索引序列。
在初始化DataLoader時通過sampler
屬性指定獲取下一個數(shù)據(jù)的索引的方法,或者batch_sampler
屬性指定獲取下一個批次數(shù)據(jù)的索引。
當我們設(shè)置DataLoader的shuffle
屬性為True時,會根據(jù)batch_size
屬性傳入的批次大小自動構(gòu)造sample返回下一個批次的索引。
當我們不啟用shuffle屬性時,就可以通過batch_sampler
屬性自定義sample來返回下一批的索引,注意這時候不可用使用 batch_size
, shuffle
, sampler
, 和drop_last
屬性。
如下所示為自定義MySampler
,它繼承自Sampler
,由傳入dataset
的長度產(chǎn)生對應(yīng)的索引,例如上面有9個數(shù)據(jù),那么產(chǎn)生索引[0, 8]。
根據(jù)批次大小batch_size
計算出總批次數(shù),例如當batchsize是3,那么9/3=3,即總共有3個批次。
重寫__iter__()
方法按批次返回索引,即第一批返回[0, 1, 2],第二批返回[3, 4, 5]以此類推。
__len__()
方法返回總的批次數(shù),即3個批次。
class MySampler(Sampler): def __init__(self, dataset, batchsize): super(Sampler, self).__init__() self.dataset = dataset self.batch_size = batchsize # 每一批數(shù)據(jù)量 self.indices = range(len(dataset)) # 生成數(shù)據(jù)集的索引 self.count = int(len(dataset) / self.batch_size) # 一共有多少批 def __iter__(self): for i in range(self.count): yield self.indices[i * self.batch_size: (i + 1) * self.batch_size] def __len__(self): return self.count
collate處理數(shù)據(jù)
當我們拿到數(shù)據(jù)如果希望進行一些預(yù)處理而不是直接返回,這時候就需要collate_fn屬性來指定處理和返回數(shù)據(jù)的方法,如果不指定該屬性,默認會將普通的NumPy數(shù)組轉(zhuǎn)換為PyTorch的tensor并直接返回。
如下所示為自定義的my_collate()
函數(shù),默認傳入獲得的一個批次的數(shù)據(jù)data,例如之前返回一批數(shù)據(jù)[1, 2, 3],這里遍歷數(shù)據(jù)并平方之后放在res數(shù)組中返回[1, 4, 9]
def my_collate(data): res = [] for d in data: res.append(d ** 2) return res
有了上面的索引獲取類MySampler
和數(shù)據(jù)處理函數(shù)my_collate()
,就可以使用DataLoader自定義獲取批數(shù)據(jù)了。
首先DataLoader通過my_sampler
返回的索引[0, 1, 2]去dataset
拿到數(shù)據(jù)[1, 2, 3],然后傳遞給my_collate進行平方操作,然后返回一個批次的結(jié)果為[1, 4, 9],一共有三個批次的數(shù)據(jù)。
dataset = MySet() # 定義數(shù)據(jù)集 my_sampler = MySampler(dataset, 3) # 實例化MySampler data_loader = DataLoader(dataset, batch_sampler=my_sampler, collate_fn=my_collate) for data in data_loader: # 按批次獲取數(shù)據(jù) print(data) ''' [1, 4, 9] [16, 25, 36] [49, 64, 81] '''
總結(jié)
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Pandas數(shù)值排序 sort_values()的使用
本文主要介紹了Pandas數(shù)值排序 sort_values()的使用,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2022-07-07python numpy.power()數(shù)組元素求n次方案例
這篇文章主要介紹了python numpy.power()數(shù)組元素求n次方案例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2021-03-03使用Python代碼實現(xiàn)Linux中的ls遍歷目錄命令的實例代碼
這次我就要試著用 Python 來實現(xiàn)一下 Linux 中的 ls 命令, 小小地證明下 Python 的不簡單,需要的朋友可以參考下2019-09-09