Python中的Dataset和Dataloader詳解
Dataset,Dataloader是什么?
- Dataset:負(fù)責(zé)可被Pytorch使用的數(shù)據(jù)集的創(chuàng)建
- Dataloader:向模型中傳遞數(shù)據(jù)
為什么要了解Dataloader
? 因?yàn)槟愕纳窠?jīng)網(wǎng)絡(luò)表現(xiàn)不佳的主要原因之一可能是由于數(shù)據(jù)不佳或理解不足。
因此,以更直觀的方式理解、預(yù)處理數(shù)據(jù)并將其加載到網(wǎng)絡(luò)中非常重要。
? 通常,我們?cè)谀J(rèn)或知名數(shù)據(jù)集(如 MNIST 或 CIFAR)上訓(xùn)練神經(jīng)網(wǎng)絡(luò),可以輕松地實(shí)現(xiàn)預(yù)測(cè)和分類類型問題的超過 90% 的準(zhǔn)確度。
但是那是因?yàn)檫@些數(shù)據(jù)集組織整齊且易于預(yù)處理。
但是處理自己的數(shù)據(jù)集時(shí),我們常常無法達(dá)到這樣高的準(zhǔn)確率
Dataloader 的使用
載入相關(guān)類
from torch.utils.data import Dataloader
設(shè)置相關(guān)參數(shù)
from torch.utils.data import DataLoader DataLoader( dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=None, pin_memory=False, ) """ dataset:是數(shù)據(jù)集 batch_size:是指一次迭代中使用的訓(xùn)練樣本數(shù)。通常我們將數(shù)據(jù)分成訓(xùn)練集和測(cè)試集,并且我們可能有不同的批量大小。 shuffle:是傳遞給 DataLoader 類的另一個(gè)參數(shù)。該參數(shù)采用布爾值(真/假)。如果 shuffle 設(shè)置為 True,則所有樣本都被打亂并分批加載。否則,它們會(huì)被一個(gè)接一個(gè)地發(fā)送,而不會(huì)進(jìn)行任何洗牌。 num_workers:允許多處理來增加同時(shí)運(yùn)行的進(jìn)程數(shù) collate_fn:合并數(shù)據(jù)集 pin_memory:鎖頁內(nèi)存:將張量固定在內(nèi)存中 """
以minist為例子
# Import MNIST from torchvision.datasets import MNIST # Download and Save MNIST data_train = MNIST('~/mnist_data', train=True, download=True) # Print Data print(data_train) print(data_train[12]) #Dataset MNIST Number of datapoints: 60000 Root location: /Users/viharkurama/mnist_data Split: Train (<PIL.Image.Image image mode=L size=28x28 at 0x11164A100>, 3)
現(xiàn)在讓嘗試提取元組,其中第一個(gè)值對(duì)應(yīng)于圖像,第二個(gè)值對(duì)應(yīng)于其各自的標(biāo)簽。
下面是代碼片段:
import matplotlib.pyplot as plt random_image = data_train[0][0] random_image_label = data_train[0][1] # Print the Image using Matplotlib plt.imshow(random_image) print("The label of the image is:", random_image_label)
讓我們使用 DataLoader 類來加載數(shù)據(jù)集,如下所示。
import torch from torchvision import transforms data_train = torch.utils.data.DataLoader( MNIST( '~/mnist_data', train=True, download=True, transform = transforms.Compose([ transforms.ToTensor() ])), batch_size=64, shuffle=True ) for batch_idx, samples in enumerate(data_train): print(batch_idx, samples)
這就是我們使用 DataLoader 加載簡(jiǎn)單數(shù)據(jù)集的方式。 但是,我們不能總是對(duì)每個(gè)數(shù)據(jù)集都依賴已經(jīng)有的數(shù)據(jù)集,要是自己的數(shù)據(jù)集怎么辦。
定義自己的數(shù)據(jù)集
我們將創(chuàng)建一個(gè)由數(shù)字和文本組成的簡(jiǎn)單自定義數(shù)據(jù)集
先介紹兩個(gè)方法
#__getitem__() 方法通過索引返回?cái)?shù)據(jù)集中選定的樣本。 #__len__() 方法返回?cái)?shù)據(jù)集的總大小。例如,如果您的數(shù)據(jù)集包含 1,00,000 個(gè)樣本,則 len 方法應(yīng)返回 1,00,000。 class Dataset(object): def __getitem__(self, index): raise NotImplementedError def __len__(self): raise NotImplementedError
? 創(chuàng)建自定義數(shù)據(jù)集并不復(fù)雜,但作為加載數(shù)據(jù)的典型過程的附加步驟,有必要構(gòu)建一個(gè)接口以獲得良好的抽象(至少可以說是一個(gè)很好的語法糖)。
現(xiàn)在我們將創(chuàng)建一個(gè)包含數(shù)字及其平方值的新數(shù)據(jù)集。 讓我們將數(shù)據(jù)集稱為 SquareDataset。 其目的是返回 [a,b] 范圍內(nèi)的值的平方。
下面是相關(guān)代碼:
import torch import torchvision from torch.utils.data import Dataset, DataLoader from torchvision import datasets, transforms class SquareDataset(Dataset): def __init__(self, a=0, b=1): super(Dataset, self).__init__() assert a <= b self.a = a self.b = b def __len__(self): return self.b - self.a + 1 def __getitem__(self, index): assert self.a <= index <= self.b return index, index**2 data_train = SquareDataset(a=1,b=64) data_train_loader = DataLoader(data_train, batch_size=64, shuffle=True) print(len(data_train))
? 在上面的代碼塊中,我們創(chuàng)建了一個(gè)名為 SquareDataset 的 Python 類,它繼承了 PyTorch 的 Dataset 類。
接下來,我們調(diào)用了一個(gè) init() 構(gòu)造函數(shù),其中 a 和 b 分別被初始化為 0 和 1。 超類用于從繼承的 Dataset 類中訪問 len 和 get_item 方法。
接下來我們使用 assert 語句來檢查 a 是否小于或等于 b,因?yàn)槲覀兿胍獎(jiǎng)?chuàng)建一個(gè)數(shù)據(jù)集,其中值將位于 a 和 b 之間。
? 然后,我們使用 SquareDataset 類創(chuàng)建了一個(gè)數(shù)據(jù)集,其中數(shù)據(jù)值的范圍為 1 到 64。我們將其加載到名為 data_train 的變量中。
最后,Dataloader 類在 data_train_loader 中存儲(chǔ)的數(shù)據(jù)上創(chuàng)建了一個(gè)迭代器,batch_size 初始化為 64,shuffle 設(shè)置為 True。
如何使用transform
? 當(dāng)你學(xué)會(huì)怎么定義自己的數(shù)據(jù)集的時(shí)候,你可能會(huì)想要更近 一步的操作,對(duì)于你自己的數(shù)據(jù)集進(jìn)行剪切或者變換
? 以CIFAR10為例子
- 將所有圖像調(diào)整為 32×32
- 對(duì)圖像應(yīng)用中心裁剪變換
- 將裁剪后的圖像轉(zhuǎn)換為張量
- 標(biāo)準(zhǔn)化圖像
導(dǎo)入必要的模塊
import torch import torchvision import torchvision.transforms as transforms import matplotlib.pyplot as plt import numpy as np
接下來,我們將定義一個(gè)名為 transforms 的變量,我們?cè)谄渲邪错樞蚓帉懰蓄A(yù)處理步驟。我們使用 Compose 類將所有轉(zhuǎn)換操作鏈接在一起。
transform = transforms.Compose([ # resize transforms.Resize(32), # center-crop transforms.CenterCrop(32), # to-tensor transforms.ToTensor(), # normalize transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) ]) """ resize:此調(diào)整大小轉(zhuǎn)換將所有圖像轉(zhuǎn)換為定義的大小。在這種情況下,我們要將所有圖像的大小調(diào)整為 32×32。因此,我們將 32 作為參數(shù)傳遞。 center-crop:接下來,我們使用 CenterCrop 變換裁剪圖像。 我們發(fā)送的參數(shù)也是分辨率/大小,但由于我們已經(jīng)將圖像大小調(diào)整為 32x32,因此圖像將與此裁剪中心對(duì)齊。 這意味著圖像將從中心裁剪 32 個(gè)單位(垂直和水平)。 to-tensor:我們使用 ToTensor() 方法將圖像轉(zhuǎn)換為張量數(shù)據(jù)類型。 normalize:這將張量中的所有值歸一化,使它們位于 0.5 和 1 之間。 """
在下一步中,在執(zhí)行我們剛剛定義的轉(zhuǎn)換之后,我們將使用 trainloader 將 CIFAR 數(shù)據(jù)集加載到訓(xùn)練集中。
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=False)
到此這篇關(guān)于Python中的Dataset和Dataloader詳解的文章就介紹到這了,更多相關(guān)Dataset和Dataloader詳解內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python實(shí)現(xiàn)多條件篩選目標(biāo)數(shù)據(jù)功能【測(cè)試可用】
這篇文章主要介紹了Python實(shí)現(xiàn)多條件篩選目標(biāo)數(shù)據(jù)功能,結(jié)合實(shí)例形式總結(jié)分析了Python3使用內(nèi)建函數(shù)filter、pandas包以及for循環(huán)三種方法對(duì)比分析了列表進(jìn)行條件篩選操作相關(guān)實(shí)現(xiàn)技巧與運(yùn)行效率,需要的朋友可以參考下2018-06-06深入解析Python中BeautifulSoup4的基礎(chǔ)知識(shí)與實(shí)戰(zhàn)應(yīng)用
BeautifulSoup4正是一款功能強(qiáng)大的解析器,能夠輕松解析HTML和XML文檔,本文將介紹BeautifulSoup4的基礎(chǔ)知識(shí),并通過實(shí)際代碼示例進(jìn)行演示,感興趣的可以了解下2024-02-02Python使用PySimpleGUI打造輕量級(jí)計(jì)算器
PySimpleGUI是一個(gè)跨平臺(tái)的Python GUI庫(kù),它支持Windows、Mac和Linux等多種操作系統(tǒng),本文將利用PySimpleGUI打造一個(gè)輕量級(jí)計(jì)算器,希望對(duì)大家有所幫助2024-03-03Tensorflow 實(shí)現(xiàn)線性回歸模型的示例代碼
這篇文章主要介紹了Tensorflow 實(shí)現(xiàn)線性回歸模型,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2022-05-05Python中關(guān)于面向?qū)ο笾欣^承的詳細(xì)講解
面向?qū)ο缶幊?(OOP) 語言的一個(gè)主要功能就是“繼承”。繼承是指這樣一種能力:它可以使用現(xiàn)有類的所有功能,并在無需重新編寫原來的類的情況下對(duì)這些功能進(jìn)行擴(kuò)展2021-10-10淺談多卡服務(wù)器下隱藏部分 GPU 和 TensorFlow 的顯存使用設(shè)置
這篇文章主要介紹了淺談多卡服務(wù)器下隱藏部分 GPU 和 TensorFlow 的顯存使用設(shè)置,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2020-06-06