Pytorch Dataset,TensorDataset,Dataloader,Sampler關系解讀
Dataloader
Dataloader是數據加載器,組合數據集和采樣器,并在數據集上提供單線程或多線程的迭代器。
所以Dataloader的參數必然需要指定數據集Dataset和采樣器Sampler。
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
- dataset (Dataset) – 數據集。
- batch_size (int, optional) – 每個batch加載樣本數。
- shuffle (bool, optional) – True則打亂數據.
- sampler (Sampler, optional) – 采樣器,如指定則忽略shuffle參數。
- num_workers (int, optional) – 用多少個子進程加載數據。0表示數據將在主進程中加載
- collate_fn (callable, optional) – 獲取batch數據的回調函數,也就是說可以在這個函數中修改batch的形式
- pin_memory (bool, optional) –
- drop_last (bool, optional) – 如果數據集大小不能被batch size整除,則設置為True后可刪除最后一個不完整的batch。如果設為False并且數據集的大小不能被batch size整除,則最后一個batch將更小。
Dataset和TensorDataset
所有其他數據集都應該進行子類化。所有子類應該override __len__
和 __getitem__
,前者提供了數據集的大小,后者支持整數索引,范圍從0到len(self)。
TensorDataset是Dataset的子類,已經復寫了 __len__
和 __getitem__
方法,只要傳入張量即可,它通過第一個維度進行索引。
所以TensorDataset說白了就是將輸入的tensors捆綁在一起,然后 __len__
是任何一個tensor的維度, __getitem__
表示每個tensor取相同的索引,然后將這個結果組成一個元組,源碼如下,要好好理解它通過第一個維度進行索引的意思(針對tensors里面的每一個tensor而言)。
class TensorDataset(Dataset): def __init__(self,*tensors): assert all(tensors[0].size(0)==tensor.size(0) for tensor in tensors) self.tensors = tensors def __getitem__(self,index): return tuple(tensor[index] for tensor in self.tensors) def __len__(self): return self.tensors[0].size(0)
Sampler和RandomSampler
Sampler與Dataset類似,是采樣器的基礎類。
每個采樣器子類必須提供一個 __iter__
方法,提供一種迭代數據集元素的索引的方法,以及返回迭代器長度的 __len__
方法。
所以Sampler必然是關于索引的迭代器,也就是它的輸出是索引。
而RandomSampler與TensorDataset類似,RandomSamper已經實現了 __iter__
和 __len__
方法,只需要傳入數據集即可。
猜想理解RandomSampler的實現方式,考慮到這個類實現需要傳入Dataset,所以 __len__
就是Dataset的 __len__
,然后 __iter__
就可以隨便搞一個隨機函數對range(length)隨機即可。
綜合示例
結合TensorDataset和RandomSampler使用Dataloader
這里即可理解Dataloader這個數據加載器其實就是組合數據集和采樣器的組合。
所以那就是先根據Sampler隨機拿到一個索引,再用這個索引到Dataset中取tensors里每個tensor對應索引的數據來組成一個元組。
總結
以上為個人經驗,希望能給大家一個參考,也希望大家多多支持腳本之家。