亚洲乱码中文字幕综合,中国熟女仑乱hd,亚洲精品乱拍国产一区二区三区,一本大道卡一卡二卡三乱码全集资源,又粗又黄又硬又爽的免费视频

PyTorch實(shí)現(xiàn)重寫(xiě)/改寫(xiě)Dataset并載入Dataloader

 更新時(shí)間:2020年07月14日 16:11:53   作者:全員鱷魚(yú)  
這篇文章主要介紹了PyTorch實(shí)現(xiàn)重寫(xiě)/改寫(xiě)Dataset并載入Dataloader,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧

前言

眾所周知,Dataset和Dataloder是pytorch中進(jìn)行數(shù)據(jù)載入的部件。必須將數(shù)據(jù)載入后,再進(jìn)行深度學(xué)習(xí)模型的訓(xùn)練。在pytorch的一些案例教學(xué)中,常使用torchvision.datasets自帶的MNIST、CIFAR-10數(shù)據(jù)集,一般流程為:

# 下載并存放數(shù)據(jù)集
train_dataset = torchvision.datasets.CIFAR10(root="數(shù)據(jù)集存放位置",download=True)
# load數(shù)據(jù)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset)

但是,在我們自己的模型訓(xùn)練中,需要使用非官方自制的數(shù)據(jù)集。這時(shí)應(yīng)該怎么辦呢?

我們可以通過(guò)改寫(xiě)torch.utils.data.Dataset中的__getitem____len__來(lái)載入我們自己的數(shù)據(jù)集。
__getitem__獲取數(shù)據(jù)集中的數(shù)據(jù),__len__獲取整個(gè)數(shù)據(jù)集的長(zhǎng)度(即個(gè)數(shù))。

改寫(xiě)

采用pytorch官網(wǎng)案例中提供的一個(gè)臉部landmark數(shù)據(jù)集。數(shù)據(jù)集中含有存放landmark的csv文件,但是我們?cè)谶@篇文章中不使用(其實(shí)也可以隨便下載一些圖片作數(shù)據(jù)集來(lái)實(shí)驗(yàn))。

import os
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

plt.ion()  # interactive mode

torch.utils.data.Dataset是一個(gè)抽象類(lèi),我們自己的數(shù)據(jù)集需要繼承Dataset,然后改寫(xiě)上述兩個(gè)函數(shù):

class ImageLoader(Dataset):
  def __init__(self, file_path, transform=None):
    super(ImageLoader,self).__init__()
    self.file_path = file_path
    self.transform = transform # 對(duì)輸入圖像進(jìn)行預(yù)處理,這里并沒(méi)有做,預(yù)設(shè)為None
    self.image_names = os.listdir(self.file_path) # 文件名的列表
    
  def __getitem__(self,idx):
    image = self.image_names[idx]
    image = io.imread(os.path.join(self.file_path,image))
#    if self.transform:
#    	image= self.transform(image)
    return image
         
  def __len__(self):
    return len(self.image_names)

# 設(shè)置自己存放的數(shù)據(jù)集位置,并plot展示    
imageloader = ImageLoader(file_path="D:\\Projects\\datasets\\faces\\")
# imageloader.__len__()       # 輸出數(shù)據(jù)集長(zhǎng)度(個(gè)數(shù)),應(yīng)為71
# print(imageloader.__getitem__(0)) # 以數(shù)據(jù)形式展示
plt.imshow(imageloader.__getitem__(0)) # 以圖像形式展示
plt.show()

得到的圖片輸出:


得到的數(shù)據(jù)輸出,:

array([[[ 66, 59, 53],
    [ 66, 59, 53],
    [ 66, 59, 53],
    ...,
    [ 59, 54, 48],
    [ 59, 54, 48],
    [ 59, 54, 48]],
    ...,
    [153, 141, 129],
    [158, 146, 134],
    [158, 146, 134]]], dtype=uint8)

上面看到dytpe=uint8,實(shí)際進(jìn)行訓(xùn)練的時(shí)候,常常需要更改成float的數(shù)據(jù)類(lèi)型。可以使用:

# 直接改成pytorch中的tensor下的float格式 
# 也可以用numpy的改成普通的float格式
to_float= torch.from_numpy(imageloader.__getitem__(0)).float() 

改寫(xiě)完成后,直接使用train_loader =torch.utils.data.DataLoader(dataset=imageloader)載入到Dataloader中,就可以使用了。
下面的代碼可以試著運(yùn)行一下,產(chǎn)生的是一模一樣的圖片結(jié)果。

train_loader = torch.utils.data.DataLoader(dataset=imageloader)
train_loader.dataset[0]
plt.imshow(train_loader.dataset[0])
plt.show()

到此這篇關(guān)于PyTorch實(shí)現(xiàn)重寫(xiě)/改寫(xiě)Dataset并載入Dataloader的文章就介紹到這了,更多相關(guān)PyTorch重寫(xiě)/改寫(xiě)Dataset 內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

最新評(píng)論