Pytorch加載圖像數(shù)據(jù)集的方法
1. 簡介
Pytorch深度學(xué)習(xí)框架,加載圖像數(shù)據(jù)集(這里以分類為例),通常都需要經(jīng)過以下兩個步驟:
1、定義數(shù)據(jù)集:torch以及torchvision中提供了多種方法來簡化數(shù)據(jù)集定義的過程。
2、創(chuàng)建Dataloader數(shù)據(jù)加載器:通過torch.utils.data.Dataloader實例化數(shù)據(jù)加載迭代器,傳 入自定義的數(shù)據(jù)集,并配置相關(guān)參數(shù)。
其中,第一個步驟定義數(shù)據(jù)集又包含多種實現(xiàn)方式:
1、torchvision.datasets.ImageFolder:用于加載標(biāo)準(zhǔn)的開源數(shù)據(jù)集。
2、torchvision.datasets.ImageFolder:從文件夾結(jié)構(gòu)加載圖像數(shù)據(jù),自動生成標(biāo)簽。
3、torchvision.datasets.DatasetFolder:更通用的工具,適用于自定義圖像數(shù)據(jù)集,其中,圖像和標(biāo)簽不一定按文件夾結(jié)構(gòu)組織。
4、torch.utils.data.Dataset:一個抽象基類,用戶通過重寫__init__、__len__、和 __getitem__ 方法以提供數(shù)據(jù)和標(biāo)簽。
第二個步驟,實例化數(shù)據(jù)加載迭代器 torch.utils.data.Dataloader 類,涉及到的主要參數(shù):
- dataset :數(shù)據(jù)集(可迭代對象)
- batch_size :批處理數(shù)量
- shuffle :每完成一個epoch,是否需要重新打亂數(shù)據(jù)
- num_worker:采用多進(jìn)程讀取機(jī)制
- collate_fn:可自定義函數(shù),用于將一批數(shù)據(jù)合并成一個批次,默認(rèn)為
None - drop_last :當(dāng)樣本數(shù)不能被batch_size整除時,是否舍棄最后一個batch的數(shù)據(jù)
在了解完數(shù)據(jù)集加載的兩步驟后,其實主要變化的是第一步如何定義數(shù)據(jù)集。所以,接下來都是圍繞不同的數(shù)據(jù)集定義方式,實現(xiàn)最終的數(shù)據(jù)加載。
2. torchvision.datasets.MNIST
目前,torchvision.datasets 庫中已經(jīng)收錄了多種類型的數(shù)據(jù)集,一般都是各個圖像處理領(lǐng)域內(nèi)的開源標(biāo)準(zhǔn)數(shù)據(jù)集,如下列舉了一些較為常見的數(shù)據(jù)集。
- 圖像分類:MNIST,CIFAR10, CIFAR100,ImageNet
- 目標(biāo)檢測:COCO,VOC
- 圖像分割:COCO,VOC
這種開源數(shù)據(jù)集的加載,還是非常簡單的,因為大佬們都已經(jīng)封裝好方法了,直接調(diào)用API就實現(xiàn)了。這里以mnist手寫數(shù)字識別數(shù)據(jù)集為例,代碼如下。
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 數(shù)據(jù)轉(zhuǎn)換
transform = transforms.Compose([transforms.ToTensor()])
# 加載 MNIST 數(shù)據(jù)集,這里設(shè)置了下載數(shù)據(jù)集
train_dataset = datasets.MNIST(root='mnist_datasets', train=True, download=True,transform=transform)
test_dataset = datasets.MNIST(root='mnist_datasets', train=False, download=True,transform=transform)
#打印dataset
print(train_dataset[0])
# 創(chuàng)建數(shù)據(jù)加載迭代器,傳入數(shù)據(jù)集
train_loader = DataLoader(dataset=train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=256, shuffle=False)
# 使用加載器迭代輸出數(shù)據(jù)
for images, labels in train_loader:
print("images:",images.shape)
print("labels",labels.shape)代碼執(zhí)行后的結(jié)果,首先在定義的root目錄下,下載了mnist數(shù)據(jù)集 。

終端打印了train_dataset數(shù)據(jù)集中的第1個元素,前面也講過,定義的數(shù)據(jù)集必須是可迭代的結(jié)構(gòu),也就是使用索引,可檢索出其中的內(nèi)容,其中內(nèi)容的格式如下:
(tensor,label_index),tensor是圖片,label_index是該圖片對應(yīng)的數(shù)字標(biāo)簽(模型中用到的標(biāo)簽,與現(xiàn)實中定義的標(biāo)簽不同,后續(xù)會講)。
另外,終端也迭代輸出了每一批次數(shù)據(jù)的形狀,每一批次喂入的數(shù)據(jù)量 batch_size = 256 ,每一張圖像形狀(1,28,28),單通道的灰色圖像,大小為28*28。

解釋下,前面提到的模型標(biāo)簽與現(xiàn)實中真是標(biāo)簽。debug模式下,調(diào)試上面代碼,可以看到定義的數(shù)據(jù)集train_dataset中的屬性,其中:
classes:真實的標(biāo)簽
class_to_index:影射了真實標(biāo)簽與模型標(biāo)簽的關(guān)系,可以看到模型標(biāo)簽以阿拉伯?dāng)?shù)字命名,從0開始依次遞增+1。

總結(jié):訓(xùn)練時喂入的分類標(biāo)簽,是以阿拉伯?dāng)?shù)字,從0開始依次遞增+1,這樣的命名規(guī)則。所以,在模型訓(xùn)練和推理階段,模型輸出的標(biāo)簽依然是阿拉伯,這時候定義的class_to_index就有作用了,將模型推理出的阿拉伯?dāng)?shù)字標(biāo)簽轉(zhuǎn)化為真正的類名。
3. torchvision.datasets.ImageFolder
torchvision.datasets.ImageFolder 主要用于從文件夾中加載圖像數(shù)據(jù)集,指定根目錄下的每一個子文件夾表示一個類別。該方法通常用于圖像分類任務(wù),并且可以很方便地使用Dataloader來加載批量數(shù)據(jù)。
文件夾的目錄結(jié)構(gòu)如下,root表示根目錄,class_0和class_1是以類名命名的文件夾,里面分別包含屬于該類的圖像。
root/
class_0/
images1.jpg
images2.jpg
....
class_1/
images1.jpg
images2:jpg
....
....我測試的根目錄 root 是mnist數(shù)據(jù)集中的train目錄,共有10類。其中第10類,類名為 ”九“,是我特意修改的,同樣也是為了驗證真實標(biāo)簽與模型標(biāo)簽之間的關(guān)系。

這是第一類 0 文件夾下的數(shù)據(jù),均為手寫數(shù)字0 的圖片。

接下里可直接使用代碼加載該數(shù)據(jù)集。
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 定義數(shù)據(jù)預(yù)處理操作
transform = transforms.Compose([
transforms.ToTensor(), # 將圖像轉(zhuǎn)換為張量
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 標(biāo)準(zhǔn)化
])
# 創(chuàng)建ImageFolder數(shù)據(jù)集,根目錄用了絕對路徑
dataset = datasets.ImageFolder(root='F:\Amode\datasets\mnist\train', transform=transform)
# 打印數(shù)據(jù)集中第一項
print(dataset[0])
# 創(chuàng)建DataLoader數(shù)據(jù)加載迭代器
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
#按照常例,迭代遍歷數(shù)據(jù)
for images,labels in data_loader:
print("images:",images.shape)
print("lables",labels)執(zhí)行代碼,終端打印信息,首先還是數(shù)據(jù)集中的第一項,內(nèi)容格式仍然是:
(tensor,label_index)
同樣,更簡便的方式,大家用debug模式調(diào)試代碼。

個人覺得,對于分類數(shù)據(jù)集,這種加載方式是非常容易和輕松的。前提是需要將數(shù)據(jù)集整理成固定的結(jié)構(gòu) 。
4.torchvision.datasets.DatasetFolder
torchvision.datasets.DatasetFolder 是一個比 ImageFolder 更靈活的類,而ImageFolder繼承的父類就是它,它允許你自定義加載數(shù)據(jù)的方式,自定義數(shù)據(jù)集結(jié)構(gòu)。
因為比較靈活百變,更難理解和掌握。接下來先了解下該方法的源碼,初始化參數(shù)及重要屬性。
這部分內(nèi)容是初始化參數(shù)。

- root 是數(shù)據(jù)集的根目錄。
- loader 可自定義讀取數(shù)據(jù)樣本的方法,該方法傳入?yún)?shù)是樣本的路徑。
- extension 擴(kuò)展名,指的是圖片的后綴類型,以元組形式入?yún)ⅰ?/li>
- is_valid_file (可調(diào)用對象,可選項參數(shù)),獲取文件路徑并核實文件是否有效,它和extension必須有一個。
- allow_empty True 允許空文件被認(rèn)為是一個類,False反之。
既然ImageFolder的父類就是它,可以先用它實現(xiàn)ImageFolder中要求的數(shù)據(jù)集目錄結(jié)構(gòu)(結(jié)構(gòu)在第3部門有說明)。以下代碼和ImagesFolde的r實現(xiàn)效果一致。
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image
#自定義的圖像讀取方式
def custom_load(path):
return Image.open(path).convert("RGB")
# 定義數(shù)據(jù)預(yù)處理操作
transform = transforms.Compose([
transforms.ToTensor(), # 將圖像轉(zhuǎn)換為張量
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 標(biāo)準(zhǔn)化
])
# 創(chuàng)建ImageFolder數(shù)據(jù)集,根目錄用了絕對路徑
dataset = datasets.DatasetFolder(
root=r'F:\Amode\datasets\mnist\train',
loader= custom_load,
transform=transform,
extensions=("jpg","png")
)
# 打印數(shù)據(jù)集中第一項
print(dataset[0])
# 創(chuàng)建DataLoader數(shù)據(jù)加載迭代器
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
#按照常例,迭代遍歷數(shù)據(jù)
for images,labels in data_loader:
print("images:",images.shape)
print("lables",labels)假設(shè),換種數(shù)據(jù)集的目錄結(jié)構(gòu)呢,這里舉例一種比較常見的結(jié)構(gòu),如下圖所示。
所有圖片都在同一目錄下,且圖片文件名稱以 label_name的格式命名,即標(biāo)簽在文件名中體現(xiàn)。

接下來是實現(xiàn)的代碼,新定義了一個類,繼承DatasetsFolder類,重新定義了父類中的find_class,make_dataset函數(shù)。想具體了解這兩個函數(shù)的可點進(jìn)父類源碼中去看。
find_class:輸入根目錄root,輸出classes(列表),所有的真實標(biāo)簽(str),輸出class_to_idx(字典),鍵為真實標(biāo)簽,值為類別索引值。
make_dataset:輸入仍是初始化那些參數(shù);輸出樣本列表,格式為[(file_path,class_indx),.......]
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image
#自定義的圖像加載方式
def custom_load(path):
return Image.open(path).convert("RGB")
# 定義數(shù)據(jù)預(yù)處理操作
transform = transforms.Compose([
transforms.ToTensor(), # 將圖像轉(zhuǎn)換為張量
transforms.Resize((224, 224)),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 標(biāo)準(zhǔn)化
])
class custom_DatasetFolder(datasets.DatasetFolder):
#重寫find_classes函數(shù)
def find_classes(self, directory):
"""
傳參:根目錄;
輸出:classes = [] ,classes_to_idx = {class:index}
"""
lables = set()
lables_to_indexs = {}
#獲取目錄下文件列表
file_list = os.listdir(self.root)
#遍歷文件列表
for f in file_list:
#從文件名中分離出標(biāo)簽
lable = f.split('_')[0]
#添加到集合中,集合不允許重復(fù)元素
lables.add(str(lable))
#生成真實標(biāo)簽label與類別索引class的映射字典
for i,l in enumerate(list(lables)):
lables_to_indexs[l] = int(i)
return list(lables),lables_to_indexs
def make_dataset(self,directory,class_to_idx,extensions,is_valid_file,allow_empty,):
"""
傳參;
輸出:sample[(path,class),......]
"""
#獲取目錄下的文件列表
file = os.listdir(directory)
samp = []
#遍歷文件
for f in file:
#分離出標(biāo)簽和文件后綴
lab = f.split('_')[0]
sufix = f.split('.')[-1]
#文件后綴滿足擴(kuò)展要求
if sufix in extensions:
#根據(jù)標(biāo)簽找到類別class
cls = class_to_idx[lab]
#文件完整路徑
file_path = os.path.join(directory,f)
#每個樣本以(path,class)格式添加到列表中
samp.append((str(file_path),cls))
return samp
# 創(chuàng)建ImageFolder數(shù)據(jù)集,根目錄用了絕對路徑
dataset = custom_DatasetFolder(
root=r'F:\Amode\datasets\image_data',
loader= custom_load,
transform=transform,
extensions=("jpg","png")
)
# 打印數(shù)據(jù)集中第一項
print(dataset[0])
# 創(chuàng)建DataLoader數(shù)據(jù)加載迭代器
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
#按照常例,迭代遍歷數(shù)據(jù)
for images,labels in data_loader:
print("images:",images.shape)
print("lables",labels)任意結(jié)構(gòu)的數(shù)據(jù)集,都可以使用基類DatasetFolder實現(xiàn),主要還是通過覆蓋上面兩個函數(shù),實現(xiàn)獲取標(biāo)簽類別屬性,以及樣本的路徑和類別,還有自定義的加載圖片函數(shù)。
5. torch.utils.data.Datasets
繼上面內(nèi)容,這是唯一一個使用torch,定義數(shù)據(jù)集的方式。

翻譯一下上面的內(nèi)容:
該類是一個抽象類,所有表示從鍵到數(shù)據(jù)樣本映射的數(shù)據(jù)集都應(yīng)繼承此類。所有子類應(yīng)重寫 __getitem__ 方法,以支持根據(jù)給定的鍵獲取數(shù)據(jù)樣本。子類還可以選擇性地重寫 __len__ 方法,這通常會返回數(shù)據(jù)集的大小,torch.utils.data.Sampler 實現(xiàn)和 torch.utils.data.DataLoader 的默認(rèn)選項都期望這個方法的存在。子類還可以選擇性地實現(xiàn) __getitems__ 方法,以加速批量樣本的加載。該方法接受一個樣本索引的列表,并返回這些樣本的列表。
那什么叫抽象類呢?
抽象類是一種不能直接實例化的類,主要用于定義方法的基本結(jié)構(gòu)和要求,其作為父類呢,通常讓子類去繼承它,并且在子類中必須實現(xiàn)這個抽象類中定義的方法,也就是具體的實現(xiàn)交給子類。
本節(jié)中用到的基類torch.utils.data.Datasets,需要實現(xiàn)以下三種方法。
__init__: 初始化數(shù)據(jù)集對象,通常在這里加載和處理數(shù)據(jù)。__len__: 返回數(shù)據(jù)集的大?。颖緮?shù)量)。__getitem__: 根據(jù)給定的索引返回數(shù)據(jù)集中的樣本和標(biāo)簽。
這部分的演示代碼,使用的是上一小節(jié)中的數(shù)據(jù)集 ,數(shù)據(jù)集和實現(xiàn)代碼如下。

rom torch.utils.data import Dataset
from PIL import Image
import os
class CustomDataset(Dataset):
def __init__(self, image_folder, transform=None):
"""
Args:
image_folder : 圖像所在文件夾的路徑
transform : 應(yīng)用于樣本的轉(zhuǎn)換操作
"""
self.image_folder = image_folder
self.transform = transform
self.class_to_idx = {}
self.image_files = [f for f in os.listdir(image_folder) if f.endswith('.jpg')]
self.__class_to_idx()
def __len__(self):
"""返回數(shù)據(jù)集中的樣本數(shù)量"""
return len(self.image_files)
def __class_to_idx(self):
labels = set()
for file in os.listdir(self.image_folder):
if file.endswith('.jpg'):
label = file.split('_')[0]
labels.add(str(label))
for i,l in enumerate(labels):
self.class_to_idx[l] = int(i)
def __getitem__(self, idx):
"""
Args:
idx (int): 索引
Returns:
dict: 包含圖像和標(biāo)簽的字典
"""
img_name = os.path.join(self.image_folder, self.image_files[idx])
image = Image.open(img_name).convert('RGB')
if self.transform:
image = self.transform(image)
# 標(biāo)簽從文件名中提取
lab_name = self.image_files[idx].split('_')[0]
label = self.class_to_idx[lab_name]
return image, label
from torch.utils.data import DataLoader
from torchvision import transforms
# 定義轉(zhuǎn)換操作
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor()
])
# 實例化自定義數(shù)據(jù)集
dataset = CustomDataset(image_folder='F:\Amode\datasets\image_data', transform=transform)
# 創(chuàng)建 DataLoader
data_loader = DataLoader(dataset, batch_size=4, shuffle=True)
print(dataset[0])
# 使用 DataLoader 遍歷數(shù)據(jù)
for images, labels in data_loader:
# 在這里進(jìn)行訓(xùn)練或測試操作
print(images.size(), labels)以上就是Pytorch加載圖像數(shù)據(jù)集的方法的詳細(xì)內(nèi)容,更多關(guān)于Pytorch加載圖像數(shù)據(jù)集的資料請關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
集成開發(fā)環(huán)境Pycharm的安裝及模板設(shè)置圖文教程
PyCharm是一種Python的集成開發(fā)環(huán)境,帶有一整套可以幫助用戶在使用Python語言開發(fā)時提高效率的工具,這篇文章主要介紹了集成開發(fā)環(huán)境Pycharm的安裝及模板設(shè)置,需要的朋友可以參考下2022-07-07
python語法之語言元素和分支循環(huán)結(jié)構(gòu)詳解
這篇文章主要介紹了Python的語言元素和分支循環(huán)結(jié)構(gòu),本文通過實例代碼給大家介紹的非常詳細(xì),具有一定的參考借鑒價值,需要的朋友可以參考下2021-10-10
Python編程pygame模塊實現(xiàn)移動的小車示例代碼
這篇文章主要介紹了Python編程pygame模塊實現(xiàn)移動的小車示例代碼,具有一定借鑒價值,需要的朋友可以參考下2018-01-01

