Pytorch加載圖像數據集的方法
1. 簡介
Pytorch深度學習框架,加載圖像數據集(這里以分類為例),通常都需要經過以下兩個步驟:
1、定義數據集:torch以及torchvision中提供了多種方法來簡化數據集定義的過程。
2、創(chuàng)建Dataloader數據加載器:通過torch.utils.data.Dataloader實例化數據加載迭代器,傳 入自定義的數據集,并配置相關參數。
其中,第一個步驟定義數據集又包含多種實現方式:
1、torchvision.datasets.ImageFolder:用于加載標準的開源數據集。
2、torchvision.datasets.ImageFolder:從文件夾結構加載圖像數據,自動生成標簽。
3、torchvision.datasets.DatasetFolder:更通用的工具,適用于自定義圖像數據集,其中,圖像和標簽不一定按文件夾結構組織。
4、torch.utils.data.Dataset:一個抽象基類,用戶通過重寫__init__、__len__、
和 __getitem__
方法以提供數據和標簽。
第二個步驟,實例化數據加載迭代器 torch.utils.data.Dataloader 類,涉及到的主要參數:
- dataset :數據集(可迭代對象)
- batch_size :批處理數量
- shuffle :每完成一個epoch,是否需要重新打亂數據
- num_worker:采用多進程讀取機制
- collate_fn:可自定義函數,用于將一批數據合并成一個批次,默認為
None
- drop_last :當樣本數不能被batch_size整除時,是否舍棄最后一個batch的數據
在了解完數據集加載的兩步驟后,其實主要變化的是第一步如何定義數據集。所以,接下來都是圍繞不同的數據集定義方式,實現最終的數據加載。
2. torchvision.datasets.MNIST
目前,torchvision.datasets 庫中已經收錄了多種類型的數據集,一般都是各個圖像處理領域內的開源標準數據集,如下列舉了一些較為常見的數據集。
- 圖像分類:MNIST,CIFAR10, CIFAR100,ImageNet
- 目標檢測:COCO,VOC
- 圖像分割:COCO,VOC
這種開源數據集的加載,還是非常簡單的,因為大佬們都已經封裝好方法了,直接調用API就實現了。這里以mnist手寫數字識別數據集為例,代碼如下。
from torchvision import datasets, transforms from torch.utils.data import DataLoader # 數據轉換 transform = transforms.Compose([transforms.ToTensor()]) # 加載 MNIST 數據集,這里設置了下載數據集 train_dataset = datasets.MNIST(root='mnist_datasets', train=True, download=True,transform=transform) test_dataset = datasets.MNIST(root='mnist_datasets', train=False, download=True,transform=transform) #打印dataset print(train_dataset[0]) # 創(chuàng)建數據加載迭代器,傳入數據集 train_loader = DataLoader(dataset=train_dataset, batch_size=256, shuffle=True) test_loader = DataLoader(dataset=test_dataset, batch_size=256, shuffle=False) # 使用加載器迭代輸出數據 for images, labels in train_loader: print("images:",images.shape) print("labels",labels.shape)
代碼執(zhí)行后的結果,首先在定義的root目錄下,下載了mnist數據集 。
終端打印了train_dataset數據集中的第1個元素,前面也講過,定義的數據集必須是可迭代的結構,也就是使用索引,可檢索出其中的內容,其中內容的格式如下:
(tensor,label_index),tensor是圖片,label_index是該圖片對應的數字標簽(模型中用到的標簽,與現實中定義的標簽不同,后續(xù)會講)。
另外,終端也迭代輸出了每一批次數據的形狀,每一批次喂入的數據量 batch_size = 256 ,每一張圖像形狀(1,28,28),單通道的灰色圖像,大小為28*28。
解釋下,前面提到的模型標簽與現實中真是標簽。debug模式下,調試上面代碼,可以看到定義的數據集train_dataset中的屬性,其中:
classes:真實的標簽
class_to_index:影射了真實標簽與模型標簽的關系,可以看到模型標簽以阿拉伯數字命名,從0開始依次遞增+1。
總結:訓練時喂入的分類標簽,是以阿拉伯數字,從0開始依次遞增+1,這樣的命名規(guī)則。所以,在模型訓練和推理階段,模型輸出的標簽依然是阿拉伯,這時候定義的class_to_index就有作用了,將模型推理出的阿拉伯數字標簽轉化為真正的類名。
3. torchvision.datasets.ImageFolder
torchvision.datasets.ImageFolder
主要用于從文件夾中加載圖像數據集,指定根目錄下的每一個子文件夾表示一個類別。該方法通常用于圖像分類任務,并且可以很方便地使用Dataloader來加載批量數據。
文件夾的目錄結構如下,root表示根目錄,class_0和class_1是以類名命名的文件夾,里面分別包含屬于該類的圖像。
root/ class_0/ images1.jpg images2.jpg .... class_1/ images1.jpg images2:jpg .... ....
我測試的根目錄 root 是mnist數據集中的train目錄,共有10類。其中第10類,類名為 ”九“,是我特意修改的,同樣也是為了驗證真實標簽與模型標簽之間的關系。
這是第一類 0 文件夾下的數據,均為手寫數字0 的圖片。
接下里可直接使用代碼加載該數據集。
from torchvision import datasets, transforms from torch.utils.data import DataLoader # 定義數據預處理操作 transform = transforms.Compose([ transforms.ToTensor(), # 將圖像轉換為張量 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 標準化 ]) # 創(chuàng)建ImageFolder數據集,根目錄用了絕對路徑 dataset = datasets.ImageFolder(root='F:\Amode\datasets\mnist\train', transform=transform) # 打印數據集中第一項 print(dataset[0]) # 創(chuàng)建DataLoader數據加載迭代器 data_loader = DataLoader(dataset, batch_size=32, shuffle=True) #按照常例,迭代遍歷數據 for images,labels in data_loader: print("images:",images.shape) print("lables",labels)
執(zhí)行代碼,終端打印信息,首先還是數據集中的第一項,內容格式仍然是:
(tensor,label_index)
同樣,更簡便的方式,大家用debug模式調試代碼。
個人覺得,對于分類數據集,這種加載方式是非常容易和輕松的。前提是需要將數據集整理成固定的結構 。
4.torchvision.datasets.DatasetFolder
torchvision.datasets.DatasetFolder
是一個比 ImageFolder
更靈活的類,而ImageFolder繼承的父類就是它,它允許你自定義加載數據的方式,自定義數據集結構。
因為比較靈活百變,更難理解和掌握。接下來先了解下該方法的源碼,初始化參數及重要屬性。
這部分內容是初始化參數。
- root 是數據集的根目錄。
- loader 可自定義讀取數據樣本的方法,該方法傳入參數是樣本的路徑。
- extension 擴展名,指的是圖片的后綴類型,以元組形式入參。
- is_valid_file (可調用對象,可選項參數),獲取文件路徑并核實文件是否有效,它和extension必須有一個。
- allow_empty True 允許空文件被認為是一個類,False反之。
既然ImageFolder的父類就是它,可以先用它實現ImageFolder中要求的數據集目錄結構(結構在第3部門有說明)。以下代碼和ImagesFolde的r實現效果一致。
from torchvision import datasets, transforms from torch.utils.data import DataLoader from PIL import Image #自定義的圖像讀取方式 def custom_load(path): return Image.open(path).convert("RGB") # 定義數據預處理操作 transform = transforms.Compose([ transforms.ToTensor(), # 將圖像轉換為張量 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 標準化 ]) # 創(chuàng)建ImageFolder數據集,根目錄用了絕對路徑 dataset = datasets.DatasetFolder( root=r'F:\Amode\datasets\mnist\train', loader= custom_load, transform=transform, extensions=("jpg","png") ) # 打印數據集中第一項 print(dataset[0]) # 創(chuàng)建DataLoader數據加載迭代器 data_loader = DataLoader(dataset, batch_size=32, shuffle=True) #按照常例,迭代遍歷數據 for images,labels in data_loader: print("images:",images.shape) print("lables",labels)
假設,換種數據集的目錄結構呢,這里舉例一種比較常見的結構,如下圖所示。
所有圖片都在同一目錄下,且圖片文件名稱以 label_name的格式命名,即標簽在文件名中體現。
接下來是實現的代碼,新定義了一個類,繼承DatasetsFolder類,重新定義了父類中的find_class,make_dataset函數。想具體了解這兩個函數的可點進父類源碼中去看。
find_class:輸入根目錄root,輸出classes(列表),所有的真實標簽(str),輸出class_to_idx(字典),鍵為真實標簽,值為類別索引值。
make_dataset:輸入仍是初始化那些參數;輸出樣本列表,格式為[(file_path,class_indx),.......]
import os from torchvision import datasets, transforms from torch.utils.data import DataLoader from PIL import Image #自定義的圖像加載方式 def custom_load(path): return Image.open(path).convert("RGB") # 定義數據預處理操作 transform = transforms.Compose([ transforms.ToTensor(), # 將圖像轉換為張量 transforms.Resize((224, 224)), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 標準化 ]) class custom_DatasetFolder(datasets.DatasetFolder): #重寫find_classes函數 def find_classes(self, directory): """ 傳參:根目錄; 輸出:classes = [] ,classes_to_idx = {class:index} """ lables = set() lables_to_indexs = {} #獲取目錄下文件列表 file_list = os.listdir(self.root) #遍歷文件列表 for f in file_list: #從文件名中分離出標簽 lable = f.split('_')[0] #添加到集合中,集合不允許重復元素 lables.add(str(lable)) #生成真實標簽label與類別索引class的映射字典 for i,l in enumerate(list(lables)): lables_to_indexs[l] = int(i) return list(lables),lables_to_indexs def make_dataset(self,directory,class_to_idx,extensions,is_valid_file,allow_empty,): """ 傳參; 輸出:sample[(path,class),......] """ #獲取目錄下的文件列表 file = os.listdir(directory) samp = [] #遍歷文件 for f in file: #分離出標簽和文件后綴 lab = f.split('_')[0] sufix = f.split('.')[-1] #文件后綴滿足擴展要求 if sufix in extensions: #根據標簽找到類別class cls = class_to_idx[lab] #文件完整路徑 file_path = os.path.join(directory,f) #每個樣本以(path,class)格式添加到列表中 samp.append((str(file_path),cls)) return samp # 創(chuàng)建ImageFolder數據集,根目錄用了絕對路徑 dataset = custom_DatasetFolder( root=r'F:\Amode\datasets\image_data', loader= custom_load, transform=transform, extensions=("jpg","png") ) # 打印數據集中第一項 print(dataset[0]) # 創(chuàng)建DataLoader數據加載迭代器 data_loader = DataLoader(dataset, batch_size=32, shuffle=True) #按照常例,迭代遍歷數據 for images,labels in data_loader: print("images:",images.shape) print("lables",labels)
任意結構的數據集,都可以使用基類DatasetFolder實現,主要還是通過覆蓋上面兩個函數,實現獲取標簽類別屬性,以及樣本的路徑和類別,還有自定義的加載圖片函數。
5. torch.utils.data.Datasets
繼上面內容,這是唯一一個使用torch,定義數據集的方式。
翻譯一下上面的內容:
該類是一個抽象類,所有表示從鍵到數據樣本映射的數據集都應繼承此類。所有子類應重寫 __getitem__
方法,以支持根據給定的鍵獲取數據樣本。子類還可以選擇性地重寫 __len__
方法,這通常會返回數據集的大小,torch.utils.data.Sampler
實現和 torch.utils.data.DataLoader
的默認選項都期望這個方法的存在。子類還可以選擇性地實現 __getitems__
方法,以加速批量樣本的加載。該方法接受一個樣本索引的列表,并返回這些樣本的列表。
那什么叫抽象類呢?
抽象類是一種不能直接實例化的類,主要用于定義方法的基本結構和要求,其作為父類呢,通常讓子類去繼承它,并且在子類中必須實現這個抽象類中定義的方法,也就是具體的實現交給子類。
本節(jié)中用到的基類torch.utils.data.Datasets,需要實現以下三種方法。
__init__
: 初始化數據集對象,通常在這里加載和處理數據。__len__
: 返回數據集的大?。颖緮盗浚?/li>__getitem__
: 根據給定的索引返回數據集中的樣本和標簽。
這部分的演示代碼,使用的是上一小節(jié)中的數據集 ,數據集和實現代碼如下。
rom torch.utils.data import Dataset from PIL import Image import os class CustomDataset(Dataset): def __init__(self, image_folder, transform=None): """ Args: image_folder : 圖像所在文件夾的路徑 transform : 應用于樣本的轉換操作 """ self.image_folder = image_folder self.transform = transform self.class_to_idx = {} self.image_files = [f for f in os.listdir(image_folder) if f.endswith('.jpg')] self.__class_to_idx() def __len__(self): """返回數據集中的樣本數量""" return len(self.image_files) def __class_to_idx(self): labels = set() for file in os.listdir(self.image_folder): if file.endswith('.jpg'): label = file.split('_')[0] labels.add(str(label)) for i,l in enumerate(labels): self.class_to_idx[l] = int(i) def __getitem__(self, idx): """ Args: idx (int): 索引 Returns: dict: 包含圖像和標簽的字典 """ img_name = os.path.join(self.image_folder, self.image_files[idx]) image = Image.open(img_name).convert('RGB') if self.transform: image = self.transform(image) # 標簽從文件名中提取 lab_name = self.image_files[idx].split('_')[0] label = self.class_to_idx[lab_name] return image, label from torch.utils.data import DataLoader from torchvision import transforms # 定義轉換操作 transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor() ]) # 實例化自定義數據集 dataset = CustomDataset(image_folder='F:\Amode\datasets\image_data', transform=transform) # 創(chuàng)建 DataLoader data_loader = DataLoader(dataset, batch_size=4, shuffle=True) print(dataset[0]) # 使用 DataLoader 遍歷數據 for images, labels in data_loader: # 在這里進行訓練或測試操作 print(images.size(), labels)
以上就是Pytorch加載圖像數據集的方法的詳細內容,更多關于Pytorch加載圖像數據集的資料請關注腳本之家其它相關文章!
相關文章
集成開發(fā)環(huán)境Pycharm的安裝及模板設置圖文教程
PyCharm是一種Python的集成開發(fā)環(huán)境,帶有一整套可以幫助用戶在使用Python語言開發(fā)時提高效率的工具,這篇文章主要介紹了集成開發(fā)環(huán)境Pycharm的安裝及模板設置,需要的朋友可以參考下2022-07-07