Pytorch中DataLoader的使用方法詳解
在Pytorch中,torch.utils.data中的Dataset與DataLoader是處理數(shù)據(jù)集的兩個函數(shù),用來處理加載數(shù)據(jù)集。通常情況下,使用的關(guān)鍵在于構(gòu)建dataset類。
一:dataset類構(gòu)建。
在構(gòu)建數(shù)據(jù)集類時,除了__init__(self),還要有__len__(self)與__getitem__(self,item)兩個方法,這三個是必不可少的,至于其它用于數(shù)據(jù)處理的函數(shù),可以任意定義。
class dataset: def __init__(self,...): ... def __len__(self,...): return n def __getitem__(self,item): return data[item]
正常情況下,該數(shù)據(jù)集是要繼承Pytorch中Dataset類的,但實際操作中,即使不繼承,數(shù)據(jù)集類構(gòu)建后仍可以用Dataloader()加載的。
在dataset類中,__len__(self)返回數(shù)據(jù)集中數(shù)據(jù)個數(shù),__getitem__(self,item)表示每次返回第item條數(shù)據(jù)。
二:DataLoader使用
在構(gòu)建dataset類后,即可使用DataLoader加載。DataLoader中常用參數(shù)如下:
1.dataset:需要載入的數(shù)據(jù)集,如前面構(gòu)造的dataset類。
2.batch_size:批大小,在神經(jīng)網(wǎng)絡(luò)訓練時我們很少逐條數(shù)據(jù)訓練,而是幾條數(shù)據(jù)作為一個batch進行訓練。
3.shuffle:是否在打亂數(shù)據(jù)集樣本順序。True為打亂,F(xiàn)alse反之。
4.drop_last:是否舍去最后一個batch的數(shù)據(jù)(很多情況下數(shù)據(jù)總數(shù)N與batch size不整除,導致最后一個batch不為batch size)。True為舍去,F(xiàn)alse反之。
三:舉例
兔兔以指標為1,數(shù)據(jù)個數(shù)為100的數(shù)據(jù)為例。
import torch from torch.utils.data import DataLoader class dataset: def __init__(self): self.x=torch.randint(0,20,size=(100,1),dtype=torch.float32) self.y=(torch.sin(self.x)+1)/2 def __len__(self): return 100 def __getitem__(self, item): return self.x[item],self.y[item] data=DataLoader(dataset(),batch_size=10,shuffle=True) for batch in data: print(batch)
當然,利用這個數(shù)據(jù)集可以進行簡單的神經(jīng)網(wǎng)絡(luò)訓練。
from torch import nn data=DataLoader(dataset(),batch_size=10,shuffle=True) bp=nn.Sequential(nn.Linear(1,5), nn.Sigmoid(), nn.Linear(5,1), nn.Sigmoid()) optim=torch.optim.Adam(params=bp.parameters()) Loss=nn.MSELoss() for epoch in range(10): print('the {} epoch'.format(epoch)) for batch in data: yp=bp(batch[0]) loss=Loss(yp,batch[1]) optim.zero_grad() loss.backward() optim.step()
ps:下面再給大家補充介紹下Pytorch中DataLoader的使用。
前言
最近開始接觸pytorch,從跑別人寫好的代碼開始,今天需要把輸入數(shù)據(jù)根據(jù)每個batch的最長輸入數(shù)據(jù),填充到一樣的長度(之前是將所有的數(shù)據(jù)直接填充到一樣的長度再輸入)。
剛開始是想偷懶,沒有去認真了解輸入的機制,結(jié)果一直報錯…還是要認真學習呀!
加載數(shù)據(jù)
pytorch中加載數(shù)據(jù)的順序是:
①創(chuàng)建一個dataset對象
②創(chuàng)建一個dataloader對象
③循環(huán)dataloader對象,將data,label拿到模型中去訓練
dataset
你需要自己定義一個class,里面至少包含3個函數(shù):
①__init__:傳入數(shù)據(jù),或者像下面一樣直接在函數(shù)里加載數(shù)據(jù)
②__len__:返回這個數(shù)據(jù)集一共有多少個item
③__getitem__:返回一條訓練數(shù)據(jù),并將其轉(zhuǎn)換成tensor
import torch from torch.utils.data import Dataset class Mydata(Dataset): def __init__(self): a = np.load("D:/Python/nlp/NRE/a.npy",allow_pickle=True) b = np.load("D:/Python/nlp/NRE/b.npy",allow_pickle=True) d = np.load("D:/Python/nlp/NRE/d.npy",allow_pickle=True) c = np.load("D:/Python/nlp/NRE/c.npy") self.x = list(zip(a,b,d,c)) def __getitem__(self, idx): assert idx < len(self.x) return self.x[idx] def __len__(self): return len(self.x)
dataloader
參數(shù):
dataset:傳入的數(shù)據(jù)
shuffle = True:是否打亂數(shù)據(jù)
collate_fn:使用這個參數(shù)可以自己操作每個batch的數(shù)據(jù)
dataset = Mydata() dataloader = DataLoader(dataset, batch_size = 2, shuffle=True,collate_fn = mycollate)
下面是將每個batch的數(shù)據(jù)填充到該batch的最大長度
def mycollate(data): a = [] b = [] c = [] d = [] max_len = len(data[0][0]) for i in data: if len(i[0])>max_len: max_len = len(i[0]) if len(i[1])>max_len: max_len = len(i[1]) if len(i[2])>max_len: max_len = len(i[2]) print(max_len) # 填充 for i in data: if len(i[0])<max_len: i[0].extend([27] * (max_len-len(i[0]))) if len(i[1])<max_len: i[1].extend([27] * (max_len-len(i[1]))) if len(i[2])<max_len: i[2].extend([27] * (max_len-len(i[2]))) a.append(i[0]) b.append(i[1]) d.append(i[2]) c.extend(i[3]) # 這里要自己轉(zhuǎn)成tensor a = torch.Tensor(a) b = torch.Tensor(b) c = torch.Tensor(c) d = torch.Tensor(d) data1 = [a,b,d,c] print("data1",data1) return data1
結(jié)果:
最后循環(huán)該dataloader ,拿到數(shù)據(jù)放入模型進行訓練:
for ii, data in enumerate(test_data_loader): if opt.use_gpu: data = list(map(lambda x: torch.LongTensor(x.long()).cuda(), data)) else: data = list(map(lambda x: torch.LongTensor(x.long()), data)) out = model(data[:-1]) #數(shù)據(jù)data[:-1] loss = F.cross_entropy(out, data[-1])# 最后一列是標簽
寫在最后:建議像我一樣剛開始不太熟練的小伙伴,在處理數(shù)據(jù)輸入的時候可以打印出來仔細查看。
到此這篇關(guān)于Pytorch中DataLoader的使用方法的文章就介紹到這了,更多相關(guān)Pytorch DataLoader內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
python使用PyCharm進行遠程開發(fā)和調(diào)試
這篇文章主要介紹了python使用PyCharm進行遠程開發(fā)和調(diào)試,小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧2017-11-11python3?cookbook解壓可迭代對象賦值給多個變量的問題及解決方案
這篇文章主要介紹了python3?cookbook-解壓可迭代對象賦值給多個變量,本文給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下2024-01-01