Pytorch建模過(guò)程中的DataLoader與Dataset示例詳解
處理數(shù)據(jù)樣本的代碼會(huì)因?yàn)樘幚磉^(guò)程繁雜而變得混亂且難以維護(hù),在理想情況下,我們希望數(shù)據(jù)預(yù)處理過(guò)程代碼與我們的模型訓(xùn)練代碼分離,以獲得更好的可讀性和模塊化,為此,PyTorch提供了torch.utils.data.DataLoader
和 torch.utils.data.Dataset
兩個(gè)類用于數(shù)據(jù)處理。其中torch.utils.data.DataLoader
用于將數(shù)據(jù)集進(jìn)行打包封裝成一個(gè)可迭代對(duì)象,torch.utils.data.Dataset
存儲(chǔ)有一些常用的數(shù)據(jù)集示例以及相關(guān)標(biāo)簽。
同時(shí)PyTorch針對(duì)不同的專業(yè)領(lǐng)域,也提供有不同的模塊,例如 TorchText
(自然語(yǔ)言處理), TorchVision
(計(jì)算機(jī)視覺(jué)), TorchAudio
(音頻),這些模塊中也都包含一些真實(shí)數(shù)據(jù)集示例。例如TorchVision
模塊中提供了CIFAR, COCO, FashionMNIST 數(shù)據(jù)集。
1 定義數(shù)據(jù)集
pytorch中提供兩種風(fēng)格的數(shù)據(jù)集定義方式:
- 字典映射風(fēng)格。之所以稱為映射風(fēng)格,是因?yàn)樵诤罄m(xù)加載數(shù)據(jù)迭代時(shí),pytorch將自動(dòng)使用迭代索引作為key,通過(guò)字典索引的方式獲取value,本質(zhì)就是將數(shù)據(jù)集定義為一個(gè)字典,使用這種風(fēng)格時(shí),需要繼承
Dataset
類。
In [54]:
from torch.utils.data import Dataset from torch.utils.data import DataLoader
In [56]:
dataset = {0: '張三', 1:'李四', 2:'王五', 3:'趙六', 4:'陳七'} dataloader = DataLoader(dataset, batch_size=2) for i, value in enumerate(dataloader): print(i, value)
0 ['張三', '李四'] 1 ['王五', '趙六'] 2 ['陳七']
- 迭代器風(fēng)格。在自定義數(shù)據(jù)集類中,實(shí)現(xiàn)
__iter__
和__next__
方法,即定義為迭代器,在后續(xù)加載數(shù)據(jù)迭代時(shí),pytorch將依次獲取value,使用這種風(fēng)格時(shí),需要繼承IterableDataset
類。這種方法在數(shù)據(jù)量巨大,無(wú)法一下全部加載到內(nèi)存時(shí)非常實(shí)用。
In [57]:
from torch.utils.data import DataLoader from torch.utils.data import IterableDataset
In [58]:
dataset = [i for i in range(10)] dataloader = DataLoader(dataset=dataset, batch_size=3, shuffle=True) for i, item in enumerate(dataloader): # 迭代輸出 print(i, item)
0 tensor([3, 1, 2]) 1 tensor([9, 7, 5]) 2 tensor([0, 8, 4]) 3 tensor([6])
如下所示,我們有一個(gè)螞蟻蜜蜂圖像分類數(shù)據(jù)集,目錄結(jié)構(gòu)如下所示,下面我們結(jié)合這個(gè)數(shù)據(jù)集,分別介紹如何使用這兩個(gè)類定義真實(shí)數(shù)據(jù)集。
data └── hymenoptera_data ├── train │?? ├── ants │?? │?? ├── 0013035.jpg │ │ …… │?? └── bees │?? ├── 1092977343_cb42b38d62.jpg │ …… └── val ├── ants │?? ├── 10308379_1b6c72e180.jpg │?? …… └── bees ├── 1032546534_06907fe3b3.jpg ……
1.2 Dataset類
自定義一個(gè)Dataset類,繼承torch.utils.data.Dataset,且必須實(shí)現(xiàn)下面三個(gè)方法:
Dataset類里面的
__init__
函數(shù)初始化一些參數(shù),如讀取外部數(shù)據(jù)源文件。Dataset類里面的
__getitem__
函數(shù),映射取值是調(diào)用的方法,獲取單個(gè)的數(shù)據(jù),訓(xùn)練迭代時(shí)將會(huì)調(diào)用這個(gè)方法。Dataset類里面的
__len__
函數(shù)獲取數(shù)據(jù)的總量。
In [211]:
import os import pandas as pd from PIL import Image from torchvision.transforms import ToTensor, Lambda from torchvision import transforms import torchvision class AntBeeDataset(Dataset): # 把圖片所在的文件夾路徑分成兩個(gè)部分,一部分是根目錄,一部分是標(biāo)簽?zāi)夸洠@是因?yàn)闃?biāo)簽?zāi)夸浀拿Q我們需要用到 def __init__(self, root_dir, transform=None, target_transform=None): """ root_dir:存放數(shù)據(jù)的根目錄,即:data/hymenoptera_data transform: 對(duì)圖像數(shù)據(jù)進(jìn)行處理,例如,將圖片轉(zhuǎn)換為Tensor、圖片的維度可能不一致需要進(jìn)行resize target_transform:對(duì)標(biāo)簽數(shù)據(jù)進(jìn)行處理,例如,將文本標(biāo)簽轉(zhuǎn)換為數(shù)值 """ self.root_dir = root_dir self.transform = transform self.target_transform = target_transform # 獲取文件夾下所有圖片的名稱和對(duì)應(yīng)的標(biāo)簽 self.img_lst = [] for label in ['ants', 'bees']: path = os.path.join(root_dir, label) for img_name in os.listdir(path): self.img_lst.append((os.path.join(root_dir, label, img_name), label)) def __getitem__(self, idx): img_path, label = self.img_lst[idx] img = Image.open(img_path).convert('RGB') if self.transform: img = self.transform(img) if self.target_transform: label = self.target_transform(label) # 這個(gè)地方要注意,我們?cè)谟?jì)算loss的時(shí)候用交叉熵nn.CrossEntropyLoss() # 交叉熵的輸入有兩個(gè),一個(gè)是模型的輸出outputs,一個(gè)是標(biāo)簽targets,注意targets是一維tensor # 例如batchsize如果是2,ants的targets的應(yīng)該[0,0],而不是[[0][0]] # 因此label要返回0,而不是[0] return img, label def __len__(self): return len(self.img_lst)
In [310]:
train_transform = transforms.Compose([ transforms.RandomResizedCrop(224), # 將給定圖像隨機(jī)裁剪為不同的大小和寬高比,然后縮放所裁剪得到的圖像為制定的大小 transforms.RandomHorizontalFlip(), # 以給定的概率隨機(jī)水平旋轉(zhuǎn)給定的PIL的圖像,默認(rèn)為0.5 transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) # 驗(yàn)證集并不需要做與訓(xùn)練集相同的處理,所有,通常使用更加簡(jiǎn)單的transformer val_transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) # 根據(jù)標(biāo)簽?zāi)夸浀拿Q來(lái)確定圖片是哪一類,如果是"ants",標(biāo)簽設(shè)置為0,如果是"bees",標(biāo)簽設(shè)置為1 target_transform = transforms.Lambda(lambda y: 0 if y == "ants" else 1)
In [311]:
train_dataset = AntBeeDataset('data/hymenoptera_data/train', transform=train_transform, target_transform=target_transform) val_dataset = AntBeeDataset('data/hymenoptera_data/val', transform=val_transform, target_transform=target_transform)
1.2 Dataset數(shù)據(jù)集常用操作
1. 查看數(shù)據(jù)集大?。?/h4>
In [221]:
len(train_dataset), len(val_dataset)
Out[221]:
(245, 153)
2. 合并數(shù)據(jù)集
In [222]:
dataset = train_dataset + val_dataset
In [223]:
len(dataset)
Out[223]:
398
3. 劃分訓(xùn)練集、測(cè)試集
In [224]:
from torch.utils.data import random_split # random_split 不能直接使用百分比劃分,必須指定具體數(shù)字 train_size = int( len(dataset) * 0.8) test_size = len(dataset) - train_size
In [225]:
train_dataset, val_dataset = random_split(dataset, [train_size, test_size])
In [226]:
len(train_dataset), len(val_dataset)
Out[226]:
(318, 80)
1.3 IterableDataset類
使用迭代器風(fēng)格時(shí),必須繼承IterableDataset
類,且實(shí)現(xiàn)下面兩個(gè)方法:
__init__
,函數(shù)初始化一些參數(shù),如讀取外部數(shù)據(jù)源文件,在數(shù)據(jù)量過(guò)大時(shí),通常只是獲取操作句柄、數(shù)據(jù)庫(kù)連接。__iter__
,獲取迭代器。
雖然只需要實(shí)現(xiàn)這兩個(gè)方法,但是通常還需要在迭代過(guò)程中對(duì)數(shù)據(jù)進(jìn)行處理。IterableDataset類實(shí)現(xiàn)自定義數(shù)據(jù)集,本質(zhì)就是創(chuàng)建一個(gè)數(shù)據(jù)集類,且實(shí)現(xiàn)__iter__
返回一個(gè)迭代器。一下提供兩種方法通過(guò)IterableDataset類自定義數(shù)據(jù)集:
方法一:
In [289]:
class AntBeeIterableDataset(IterableDataset): # 把圖片所在的文件夾路徑分成兩個(gè)部分,一部分是根目錄,一部分是標(biāo)簽?zāi)夸洠@是因?yàn)闃?biāo)簽?zāi)夸浀拿Q我們需要用到 def __init__(self, root_dir, transform=None, target_transform=None): """ root_dir:存放數(shù)據(jù)的根目錄,即:data/hymenoptera_data transform: 對(duì)圖像數(shù)據(jù)進(jìn)行處理,例如,將圖片轉(zhuǎn)換為Tensor、圖片的維度可能不一致需要進(jìn)行resize target_transform:對(duì)標(biāo)簽數(shù)據(jù)進(jìn)行處理,例如,將文本標(biāo)簽轉(zhuǎn)換為數(shù)值 """ self.root_dir = root_dir self.transform = transform self.target_transform = target_transform # 獲取文件夾下所有圖片的名稱和對(duì)應(yīng)的標(biāo)簽 self.img_lst = [] for label in ['ants', 'bees']: path = os.path.join(root_dir, label) for img_name in os.listdir(path): self.img_lst.append((os.path.join(root_dir, label, img_name), label)) def __iter__(self): for img_path, label in self.img_lst: img = Image.open(img_path).convert('RGB') if self.transform: img = self.transform(img) if self.target_transform: label = self.target_transform(label) yield img, label
方法二:
In [285]:
class AntBeeIterableDataset(IterableDataset): # 把圖片所在的文件夾路徑分成兩個(gè)部分,一部分是根目錄,一部分是標(biāo)簽?zāi)夸?,這是因?yàn)闃?biāo)簽?zāi)夸浀拿Q我們需要用到 def __init__(self, root_dir, transform=None, target_transform=None): """ root_dir:存放數(shù)據(jù)的根目錄,即:data/hymenoptera_data transform: 對(duì)圖像數(shù)據(jù)進(jìn)行處理,例如,將圖片轉(zhuǎn)換為Tensor、圖片的維度可能不一致需要進(jìn)行resize target_transform:對(duì)標(biāo)簽數(shù)據(jù)進(jìn)行處理,例如,將文本標(biāo)簽轉(zhuǎn)換為數(shù)值 """ self.root_dir = root_dir self.transform = transform self.target_transform = target_transform # 獲取文件夾下所有圖片的名稱和對(duì)應(yīng)的標(biāo)簽 self.img_lst = [] for label in ['ants', 'bees']: path = os.path.join(root_dir, label) for img_name in os.listdir(path): self.img_lst.append((os.path.join(root_dir, label, img_name), label)) self.index = 0 def __iter__(self): return self def __next__(self): try: img_path, label = self.img_lst[self.index] self.index += 1 img = Image.open(img_path).convert('RGB') if self.transform: img = self.transform(img) if self.target_transform: label = self.target_transform(label) return img, label except IndexError: raise StopIteration()
In [290]:
train_dataset = AntBeeIterableDataset('data/hymenoptera_data/train', transform=train_transform, target_transform=target_transform) val_dataset = AntBeeIterableDataset('data/hymenoptera_data/val', transform=val_transform, target_transform=target_transform)
在處理大數(shù)據(jù)集時(shí),IterableDataset會(huì)比Dataset更有優(yōu)勢(shì),例如數(shù)據(jù)存儲(chǔ)在文件或者數(shù)據(jù)庫(kù)中,只需要在自定義的IterableDataset之類中獲取文件操作句柄或者數(shù)據(jù)庫(kù)連接和游標(biāo)驚喜迭代,每次只返回一條數(shù)據(jù)即可。我們把上文中螞蟻蜜蜂數(shù)據(jù)集的所有圖片、標(biāo)簽這里后寫入hymenoptera_data.txt中,內(nèi)容如下所示,假設(shè)有數(shù)億行,那么,就不能直接將數(shù)據(jù)加載到內(nèi)存了:
data/hymenoptera_data/train/ants/2288481644_83ff7e4572.jpg, ants data/hymenoptera_data/train/ants/2278278459_6b99605e50.jpg, ants data/hymenoptera_data/train/ants/543417860_b14237f569.jpg, ants ... ...
可以參考一下方式定義IterableDataset子類:
In [299]:
class AntBeeIterableDataset(IterableDataset): # 把圖片所在的文件夾路徑分成兩個(gè)部分,一部分是根目錄,一部分是標(biāo)簽?zāi)夸?,這是因?yàn)闃?biāo)簽?zāi)夸浀拿Q我們需要用到 def __init__(self, filepath, transform=None, target_transform=None): """ filepath:hymenoptera_data.txt完整路徑 transform: 對(duì)圖像數(shù)據(jù)進(jìn)行處理,例如,將圖片轉(zhuǎn)換為Tensor、圖片的維度可能不一致需要進(jìn)行resize target_transform:對(duì)標(biāo)簽數(shù)據(jù)進(jìn)行處理,例如,將文本標(biāo)簽轉(zhuǎn)換為數(shù)值 """ self.filepath = filepath self.transform = transform self.target_transform = target_transform def __iter__(self): with open(self.filepath, 'r') as f: for line in f: img_path, label = line.replace('\n', '').split(', ') img = Image.open(img_path).convert('RGB') if self.transform: img = self.transform(img) if self.target_transform: label = self.target_transform(label) yield img, label
In [307]:
train_dataset = AntBeeIterableDataset('hymenoptera_data.txt', transform=train_transform, target_transform=target_transform)
注意,IterableDataset方法在處理大數(shù)據(jù)集時(shí)確實(shí)比Dataset更有優(yōu)勢(shì),但是,IterableDataset在迭代過(guò)程中,樣本輸出順序是固定的,在使用DataLoader進(jìn)行加載時(shí),無(wú)法使用shuffle進(jìn)行打亂,同時(shí),因?yàn)樵贗terableDataset中并未強(qiáng)制限定必須實(shí)現(xiàn)__len__()
方法(很多時(shí)候確實(shí)也沒(méi)法獲取數(shù)據(jù)總量),不能通過(guò)len()
方法獲取數(shù)據(jù)總量。
2 DataLoad
DataLoader的功能是構(gòu)建可迭代的數(shù)據(jù)裝載器,在訓(xùn)練的時(shí)候,每一個(gè)for循環(huán),每一次Iteration,就是從DataLoader中獲取一個(gè)batch_size大小的數(shù)據(jù),節(jié)省內(nèi)存的同時(shí),它還可以實(shí)現(xiàn)多進(jìn)程、數(shù)據(jù)打亂等處理。我們通過(guò)一張圖來(lái)了解DataLoader數(shù)據(jù)讀取機(jī)制:
首先,在for循環(huán)中使用了DataLoader,進(jìn)入DataLoader后,首先根據(jù)是否使用多進(jìn)程DataLoaderIter,做出判斷之后單線程還是多線程,接著使用Sampler得索引Index,然后將索引給到DatasetFetcher,在這里面調(diào)用Dataset,根據(jù)索引,通過(guò)getitem得到實(shí)際的數(shù)據(jù)和標(biāo)簽,得到一個(gè)batch size大小的數(shù)據(jù)后,通過(guò)collate_fn函數(shù)整理成一個(gè)Batch Data的形式輸入到模型去訓(xùn)練。
在pytorch建模的數(shù)據(jù)處理、加載流程中,DataLoader應(yīng)該算是最核心的一步操作DataLoader有很多參數(shù),這里我們列出常用的幾個(gè):
- dataset:表示Dataset類,它決定了數(shù)據(jù)從哪讀取以及如何讀??;
- batch_size:表示批大??;
- num_works:表示是否多進(jìn)程讀取數(shù)據(jù);
- shuffle:表示每個(gè)epoch是否亂序;
- drop_last:表示當(dāng)樣本數(shù)不能被batch_size整除時(shí),是否舍棄最后一批數(shù)據(jù);
- num_workers:?jiǎn)?dòng)多少個(gè)進(jìn)程來(lái)加載數(shù)據(jù)。
我們重點(diǎn)說(shuō)說(shuō)多進(jìn)程模式下使用DataLoader,在多進(jìn)程模式下,每次 DataLoader 創(chuàng)建 iterator 時(shí)(遍歷DataLoader時(shí),例如,當(dāng)調(diào)用時(shí)enumerate(dataloader)),都會(huì)創(chuàng)建 num_workers 工作進(jìn)程。dataset, collate_fn, worker_init_fn 都會(huì)被傳到每個(gè)worker中,每個(gè)worker都用獨(dú)立的進(jìn)程。
對(duì)于映射風(fēng)格的數(shù)據(jù)集,即Dataset子類,主線程會(huì)用Sampler(采樣器)產(chǎn)生indice,并將它們送到進(jìn)程里。因此,shuffle是在主線程做的
對(duì)于迭代器風(fēng)格的數(shù)據(jù)集,即IterableDataset子類,因?yàn)槊總€(gè)進(jìn)程都有相同的data復(fù)制樣本,并在各個(gè)進(jìn)程里進(jìn)行不同的操作,以防止每個(gè)進(jìn)程輸出的數(shù)據(jù)是重復(fù)的,所以一般用 torch.utils.data.get_worker_info() 來(lái)進(jìn)行輔助處理。
這里,torch.utils.data.get_worker_info() 返回worker進(jìn)程的一些信息(id, dataset, num_workers, seed),如果在主線程跑的話返回None
注意,通常不建議在多進(jìn)程加載中返回CUDA張量,因?yàn)樵谑褂肅UDA和在多處理中共享CUDA張量時(shí)存在許多微妙之處(文檔中提出:只要接收過(guò)程保留張量的副本,就需要發(fā)送過(guò)程來(lái)保留原始張量)。建議采用 pin_memory=True ,以將數(shù)據(jù)快速傳輸?shù)街С諧UDA的GPU。簡(jiǎn)而言之,不建議在使用多線程的情況下返回CUDA的tensor。
In [313]:
dataload = DataLoader(train_dataset, batch_size=2)
In [315]:
img, label = next(iter(dataload))
In [316]:
img.shape, label
Out[316]:
(torch.Size([2, 3, 224, 224]), tensor([0, 0]))
到此這篇關(guān)于Pytorch建模過(guò)程中的DataLoader與Dataset的文章就介紹到這了,更多相關(guān)Pytorch建模內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
- Pytorch建模過(guò)程中的DataLoader與Dataset示例詳解
- Pytorch如何加載自己的數(shù)據(jù)集(使用DataLoader讀取Dataset)
- PyTorch?Dataset與DataLoader使用超詳細(xì)講解
- Pytorch數(shù)據(jù)讀取之Dataset和DataLoader知識(shí)總結(jié)
- Pytorch自定義Dataset和DataLoader去除不存在和空數(shù)據(jù)的操作
- pytorch Dataset,DataLoader產(chǎn)生自定義的訓(xùn)練數(shù)據(jù)案例
- PyTorch實(shí)現(xiàn)重寫/改寫Dataset并載入Dataloader
- 一文弄懂Pytorch的DataLoader, DataSet, Sampler之間的關(guān)系
- PyTorch 解決Dataset和Dataloader遇到的問(wèn)題
相關(guān)文章
Django 登陸驗(yàn)證碼和中間件的實(shí)現(xiàn)
這篇文章主要介紹了Django 登陸驗(yàn)證碼和中間件的實(shí)現(xiàn),小編覺(jué)得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧2018-08-08Pipenv輕量級(jí)虛擬環(huán)境管理工具使用指南
這篇文章主要為大家介紹了Pipenv輕量級(jí)虛擬環(huán)境管理工具使用指南,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-02-02python中的線程threading.Thread()使用詳解
這篇文章主要介紹了python中的線程threading.Thread()使用詳解,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2019-12-12Python海象運(yùn)算符代碼分析及知識(shí)點(diǎn)總結(jié)
在本篇內(nèi)容里小編給大家總結(jié)了關(guān)于Python海象運(yùn)算符的使用的相關(guān)內(nèi)容及代碼,有興趣的朋友們跟著學(xué)習(xí)下。2022-11-11python lambda表達(dá)式(匿名函數(shù))寫法解析
這篇文章主要介紹了python lambda表達(dá)式(匿名函數(shù))寫法解析,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-09-09python實(shí)現(xiàn)自動(dòng)化辦公郵件合并功能
這篇文章主要介紹了python實(shí)現(xiàn)自動(dòng)化辦公郵件合并功能,本文通過(guò)實(shí)例代碼給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2021-07-07