Pytorch中DataLoader的使用方法詳解
在Pytorch中,torch.utils.data中的Dataset與DataLoader是處理數(shù)據(jù)集的兩個(gè)函數(shù),用來處理加載數(shù)據(jù)集。通常情況下,使用的關(guān)鍵在于構(gòu)建dataset類。
一:dataset類構(gòu)建。
在構(gòu)建數(shù)據(jù)集類時(shí),除了__init__(self),還要有__len__(self)與__getitem__(self,item)兩個(gè)方法,這三個(gè)是必不可少的,至于其它用于數(shù)據(jù)處理的函數(shù),可以任意定義。
class dataset:
def __init__(self,...):
...
def __len__(self,...):
return n
def __getitem__(self,item):
return data[item]正常情況下,該數(shù)據(jù)集是要繼承Pytorch中Dataset類的,但實(shí)際操作中,即使不繼承,數(shù)據(jù)集類構(gòu)建后仍可以用Dataloader()加載的。
在dataset類中,__len__(self)返回?cái)?shù)據(jù)集中數(shù)據(jù)個(gè)數(shù),__getitem__(self,item)表示每次返回第item條數(shù)據(jù)。
二:DataLoader使用
在構(gòu)建dataset類后,即可使用DataLoader加載。DataLoader中常用參數(shù)如下:
1.dataset:需要載入的數(shù)據(jù)集,如前面構(gòu)造的dataset類。
2.batch_size:批大小,在神經(jīng)網(wǎng)絡(luò)訓(xùn)練時(shí)我們很少逐條數(shù)據(jù)訓(xùn)練,而是幾條數(shù)據(jù)作為一個(gè)batch進(jìn)行訓(xùn)練。
3.shuffle:是否在打亂數(shù)據(jù)集樣本順序。True為打亂,F(xiàn)alse反之。
4.drop_last:是否舍去最后一個(gè)batch的數(shù)據(jù)(很多情況下數(shù)據(jù)總數(shù)N與batch size不整除,導(dǎo)致最后一個(gè)batch不為batch size)。True為舍去,F(xiàn)alse反之。
三:舉例
兔兔以指標(biāo)為1,數(shù)據(jù)個(gè)數(shù)為100的數(shù)據(jù)為例。
import torch
from torch.utils.data import DataLoader
class dataset:
def __init__(self):
self.x=torch.randint(0,20,size=(100,1),dtype=torch.float32)
self.y=(torch.sin(self.x)+1)/2
def __len__(self):
return 100
def __getitem__(self, item):
return self.x[item],self.y[item]
data=DataLoader(dataset(),batch_size=10,shuffle=True)
for batch in data:
print(batch)當(dāng)然,利用這個(gè)數(shù)據(jù)集可以進(jìn)行簡單的神經(jīng)網(wǎng)絡(luò)訓(xùn)練。
from torch import nn
data=DataLoader(dataset(),batch_size=10,shuffle=True)
bp=nn.Sequential(nn.Linear(1,5),
nn.Sigmoid(),
nn.Linear(5,1),
nn.Sigmoid())
optim=torch.optim.Adam(params=bp.parameters())
Loss=nn.MSELoss()
for epoch in range(10):
print('the {} epoch'.format(epoch))
for batch in data:
yp=bp(batch[0])
loss=Loss(yp,batch[1])
optim.zero_grad()
loss.backward()
optim.step()ps:下面再給大家補(bǔ)充介紹下Pytorch中DataLoader的使用。
前言
最近開始接觸pytorch,從跑別人寫好的代碼開始,今天需要把輸入數(shù)據(jù)根據(jù)每個(gè)batch的最長輸入數(shù)據(jù),填充到一樣的長度(之前是將所有的數(shù)據(jù)直接填充到一樣的長度再輸入)。
剛開始是想偷懶,沒有去認(rèn)真了解輸入的機(jī)制,結(jié)果一直報(bào)錯(cuò)…還是要認(rèn)真學(xué)習(xí)呀!
加載數(shù)據(jù)
pytorch中加載數(shù)據(jù)的順序是:
①創(chuàng)建一個(gè)dataset對象
②創(chuàng)建一個(gè)dataloader對象
③循環(huán)dataloader對象,將data,label拿到模型中去訓(xùn)練
dataset
你需要自己定義一個(gè)class,里面至少包含3個(gè)函數(shù):
①__init__:傳入數(shù)據(jù),或者像下面一樣直接在函數(shù)里加載數(shù)據(jù)
②__len__:返回這個(gè)數(shù)據(jù)集一共有多少個(gè)item
③__getitem__:返回一條訓(xùn)練數(shù)據(jù),并將其轉(zhuǎn)換成tensor
import torch
from torch.utils.data import Dataset
class Mydata(Dataset):
def __init__(self):
a = np.load("D:/Python/nlp/NRE/a.npy",allow_pickle=True)
b = np.load("D:/Python/nlp/NRE/b.npy",allow_pickle=True)
d = np.load("D:/Python/nlp/NRE/d.npy",allow_pickle=True)
c = np.load("D:/Python/nlp/NRE/c.npy")
self.x = list(zip(a,b,d,c))
def __getitem__(self, idx):
assert idx < len(self.x)
return self.x[idx]
def __len__(self):
return len(self.x)
dataloader
參數(shù):
dataset:傳入的數(shù)據(jù)
shuffle = True:是否打亂數(shù)據(jù)
collate_fn:使用這個(gè)參數(shù)可以自己操作每個(gè)batch的數(shù)據(jù)
dataset = Mydata() dataloader = DataLoader(dataset, batch_size = 2, shuffle=True,collate_fn = mycollate)
下面是將每個(gè)batch的數(shù)據(jù)填充到該batch的最大長度
def mycollate(data):
a = []
b = []
c = []
d = []
max_len = len(data[0][0])
for i in data:
if len(i[0])>max_len:
max_len = len(i[0])
if len(i[1])>max_len:
max_len = len(i[1])
if len(i[2])>max_len:
max_len = len(i[2])
print(max_len)
# 填充
for i in data:
if len(i[0])<max_len:
i[0].extend([27] * (max_len-len(i[0])))
if len(i[1])<max_len:
i[1].extend([27] * (max_len-len(i[1])))
if len(i[2])<max_len:
i[2].extend([27] * (max_len-len(i[2])))
a.append(i[0])
b.append(i[1])
d.append(i[2])
c.extend(i[3])
# 這里要自己轉(zhuǎn)成tensor
a = torch.Tensor(a)
b = torch.Tensor(b)
c = torch.Tensor(c)
d = torch.Tensor(d)
data1 = [a,b,d,c]
print("data1",data1)
return data1
結(jié)果:

最后循環(huán)該dataloader ,拿到數(shù)據(jù)放入模型進(jìn)行訓(xùn)練:
for ii, data in enumerate(test_data_loader):
if opt.use_gpu:
data = list(map(lambda x: torch.LongTensor(x.long()).cuda(), data))
else:
data = list(map(lambda x: torch.LongTensor(x.long()), data))
out = model(data[:-1]) #數(shù)據(jù)data[:-1]
loss = F.cross_entropy(out, data[-1])# 最后一列是標(biāo)簽
寫在最后:建議像我一樣剛開始不太熟練的小伙伴,在處理數(shù)據(jù)輸入的時(shí)候可以打印出來仔細(xì)查看。
到此這篇關(guān)于Pytorch中DataLoader的使用方法的文章就介紹到這了,更多相關(guān)Pytorch DataLoader內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python實(shí)現(xiàn)簡易端口掃描器代碼實(shí)例
本篇文章主要介紹了Python實(shí)現(xiàn)簡易端口掃描器的相關(guān)代碼,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下。2017-03-03
K-近鄰算法的python實(shí)現(xiàn)代碼分享
這篇文章主要介紹了K-近鄰算法的python實(shí)現(xiàn)代碼分享,具有一定借鑒價(jià)值,需要的朋友可以參考下。2017-12-12
python使用PyCharm進(jìn)行遠(yuǎn)程開發(fā)和調(diào)試
這篇文章主要介紹了python使用PyCharm進(jìn)行遠(yuǎn)程開發(fā)和調(diào)試,小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧2017-11-11
python3?cookbook解壓可迭代對象賦值給多個(gè)變量的問題及解決方案
這篇文章主要介紹了python3?cookbook-解壓可迭代對象賦值給多個(gè)變量,本文給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2024-01-01

