pytorch中dataloader 的sampler 參數詳解
1. dataloader() 初始化函數
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None):
其中幾個常用的參數:
- dataset 數據集,map-style and iterable-style 可以用index取值的對象、
- batch_size 大小
- shuffle 取batch是否隨機取, 默認為False
- sampler 定義取batch的方法,是一個迭代器, 每次生成一個key 用于讀取dataset中的值
- batch_sampler 也是一個迭代器, 每次生次一個batch_size的key
- num_workers 參與工作的線程數collate_fn 對取出的batch進行處理
- drop_last 對最后不足batchsize的數據的處理方法
下面看兩段取自DataLoader中的__init__代碼, 幫助我們理解幾個常用參數之間的關系
2. shuffle 與sample 之間的關系
當我們sampler有輸入時,shuffle的值就沒有意義,
if sampler is None: # give default samplers if self._dataset_kind == _DatasetKind.Iterable: # See NOTE [ Custom Samplers and IterableDataset ] sampler = _InfiniteConstantSampler() else: # map-style if shuffle: sampler = RandomSampler(dataset) else: sampler = SequentialSampler(dataset)
當dataset類型是map style時, shuffle其實就是改變sampler的取值
- shuffle為默認值 False時,sampler是SequentialSampler,就是按順序取樣,
- shuffle為True時,sampler是RandomSampler, 就是按隨機取樣
3. sample 的定義方法
3.1 sampler 參數的使用
sampler 是用來定義取batch方法的一個函數或者類,返回的是一個迭代器。
我們可以看下自帶的RandomSampler類中最重要的iter函數
def __iter__(self): n = len(self.data_source) # dataset的長度, 按順序索引 if self.replacement:# 對應的replace參數 return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist()) return iter(torch.randperm(n).tolist())
可以看出,其實就是生成索引,然后隨機的取值, 然后再迭代。
其實還有一些細節(jié)需要注意理解:
比如__len__函數,包括DataLoader的len和sample的len, 兩者區(qū)別, 這部分代碼比較簡單,可以自行閱讀,其實參考著RandomSampler寫也不會出現問題。
比如,迭代器和生成器的使用, 以及區(qū)別
if batch_size is not None and batch_sampler is None: # auto_collation without custom batch_sampler batch_sampler = BatchSampler(sampler, batch_size, drop_last) self.sampler = sampler self.batch_sampler = batch_sampler
BatchSampler的生成過程:
# 略去類的初始化 def __iter__(self): batch = [] for idx in self.sampler: batch.append(idx) if len(batch) == self.batch_size: yield batch batch = [] if len(batch) > 0 and not self.drop_last: yield batch
就是按batch_size從sampler中讀取索引, 并形成生成器返回。
以上可以看出, batch_sampler和sampler, batch_size, drop_last之間的關系
- 如果batch_sampler沒有定義的話且batch_size有定義, 會根據sampler, batch_size, drop_last生成一個batch_sampler
- 自帶的注釋中對batch_sampler有一句話: Mutually exclusive with :attr:batch_size :attr:shuffle, :attr:sampler, and :attr:drop_last.
- 意思就是b
- atch_sampler 與這些參數沖突 ,即 如果你定義了batch_sampler, 其他參數都不需要有
4. batch 生成過程
每個batch都是由迭代器產生的:
# DataLoader中iter的部分 def __iter__(self): if self.num_workers == 0: return _SingleProcessDataLoaderIter(self) else: return _MultiProcessingDataLoaderIter(self) # 再看調用的另一個類 class _SingleProcessDataLoaderIter(_BaseDataLoaderIter): def __init__(self, loader): super(_SingleProcessDataLoaderIter, self).__init__(loader) assert self._timeout == 0 assert self._num_workers == 0 self._dataset_fetcher = _DatasetKind.create_fetcher( self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last) def __next__(self): index = self._next_index() data = self._dataset_fetcher.fetch(index) if self._pin_memory: data = _utils.pin_memory.pin_memory(data) return data
到此這篇關于pytorch中dataloader 的sampler 參數詳解的文章就介紹到這了,更多相關pytorch sampler 內容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!
相關文章
對python3 Serial 串口助手的接收讀取數據方法詳解
今天小編就為大家分享一篇對python3 Serial 串口助手的接收讀取數據方法詳解,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-06-06解讀pandas交叉表與透視表pd.crosstab()和pd.pivot_table()函數
這篇文章主要介紹了pandas交叉表與透視表pd.crosstab()和pd.pivot_table()函數的用法,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教2023-09-09python爬蟲模擬瀏覽器訪問-User-Agent過程解析
這篇文章主要介紹了python爬蟲模擬瀏覽器訪問-User-Agent過程解析,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友可以參考下2019-12-12對python中Matplotlib的坐標軸的坐標區(qū)間的設定實例講解
今天小編就為大家分享一篇對python中Matplotlib的坐標軸的坐標區(qū)間的設定實例講解,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-05-05