pytorch加載的cifar10數(shù)據(jù)集過程詳解
pytorch怎么加載cifar10數(shù)據(jù)集
torchvision.datasets.CIFAR10
pytorch里面的torchvision.datasets中提供了大多數(shù)計(jì)算機(jī)視覺領(lǐng)域相關(guān)任務(wù)的數(shù)據(jù)集,可以根據(jù)實(shí)際需要加載相關(guān)數(shù)據(jù)集——需要cifar10就用torchvision.datasets.CIFAR10(),需要SVHN就調(diào)用torchvision.datasets.SVHN()。
針對cifar10數(shù)據(jù)集而言,調(diào)用torchvision.datasets.CIFAR10(),其中root是下載數(shù)據(jù)集后保存的位置;train是一個bool變量,為true就是訓(xùn)練數(shù)據(jù)集,false就是測試數(shù)據(jù)集;download也是一個bool變量,表示是否下載;transform是對數(shù)據(jù)集中的"image"進(jìn)行一些操作,比如歸一化、隨機(jī)裁剪、各種數(shù)據(jù)增強(qiáng)操作等;target_transform是針對數(shù)據(jù)集中的"label"進(jìn)行一些操作。
示例代碼如下:
# 加載訓(xùn)練數(shù)據(jù)集 train_data = datasets.CIFAR10(root='../_datasets', train=True, download=True, transform= transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 歸一化 ]) ) # 加載測試數(shù)據(jù)集 test_data = datasets.CIFAR10(root='../_datasets', train=False,download=True, transform= transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 歸一化 ]) )
transforms.Normalize()進(jìn)行歸一化到底在哪里起作用?【CIFAR10源碼分析】
上面的代碼中,我們用transforms.Compose([……])組合了一系列的對image的操作,其中trandforms.ToTensor()
和transforms.Normalize()
都涉及到歸一化操作:
- 原始的cifar10數(shù)據(jù)集是numpy array的形式,其中數(shù)據(jù)范圍是[0,255],pytorch加載時,并沒有改變數(shù)據(jù)范圍,依舊是[0,255],加載后的數(shù)據(jù)維度是(H, W, C),源碼部分:
__getitem__()
函數(shù)中進(jìn)行transforms操作,進(jìn)行了歸一化:實(shí)際上傳入的transform在__getitem__()
函數(shù)中被調(diào)用,其中transforms.Totensor()
會將data(也就是image)的維度變成(C,H, W)的形式,并且歸一化到[0.0,1.0];
transforms.Normalize()
會根據(jù)z = (x-mean) / std 對數(shù)據(jù)進(jìn)行歸一化,上述代碼中mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]
是可以將3個通道單獨(dú)進(jìn)行歸一化,3個通道可以設(shè)置不同的mean和std,最終數(shù)據(jù)范圍變成[-0.5,+0.5] 。
所以如果通過pytorch的cifar10加載數(shù)據(jù)集后,針對traindataset.data,依舊是沒有進(jìn)行歸一化的;但是比如traindataset[index].data,其中[index]這樣的按下標(biāo)取元素的操作會直接調(diào)用的__getitem__()函數(shù),此時的data就是經(jīng)過了歸一化的。
除traindataset[index]會隱式自動調(diào)用__getitem__()函數(shù)外,還有什么時候會調(diào)用這個函數(shù)呢?畢竟……只有調(diào)用了這個函數(shù)才會調(diào)用transforms中的歸一化處理。——答案是與dataloader搭配使用!
torchvision.datasets加載的數(shù)據(jù)集搭配Dataloader使用
torchvision.datasets實(shí)際上是torch.utils.data.Dataset的子類,那么就能傳入Dataloader中,迭代的按batch-size獲取批量數(shù)據(jù),用于訓(xùn)練或者測試。其中dataloader加載dataset中的數(shù)據(jù)時,就是用到了其__getitem__()函數(shù),所以用dataloader加載數(shù)據(jù)集,得到的是經(jīng)過歸一化后的數(shù)據(jù)。
model.train()和model.eval()
我發(fā)現(xiàn)上面的問題,是我用dataloader加載了訓(xùn)練數(shù)據(jù)集用于訓(xùn)練resnet18模型,訓(xùn)練過程中,我訓(xùn)練好并保存后,順便測試了一下在測試數(shù)據(jù)集上的準(zhǔn)確度。但是在測試的過程中,我沒有用dataloader加載測試數(shù)據(jù)集,而是直接用的dataset.data來進(jìn)行的測試。并且!由于是并沒有將model設(shè)置成model.eval()【其實(shí)我設(shè)置了,但是我對自己很無語,我寫的model.eval,忘記加括號了,無語嗚嗚】……也就是即便我的測試數(shù)據(jù)集沒有經(jīng)過歸一化,由于模型還是在model.train()模式下,因此模型的BN層會自己調(diào)整,使得模型性能不受影響,因此在測試數(shù)據(jù)集上的accuracy達(dá)到了0.86,我就沒有多想。
后來我用模型的時候,設(shè)置了model.eval()后,依舊是直接用的dataset.data(也就是沒有歸一化),不管是在測試數(shù)據(jù)集上還是在訓(xùn)練數(shù)據(jù)集上,accuracy都只有0.10+,我表示非常的迷茫疑惑??!然后才發(fā)現(xiàn)是歸一化的問題。
- 在
model.train()
模式下進(jìn)行預(yù)測時,PyTorch會默認(rèn)啟用一些訓(xùn)練相關(guān)的操作,例如Batch Normalization和Dropout,并且模型的參數(shù)是可變的,能夠根據(jù)輸入進(jìn)行調(diào)整。這些操作在訓(xùn)練模式下可以幫助模型更好地適應(yīng)訓(xùn)練數(shù)據(jù),并產(chǎn)生較高的準(zhǔn)確度。 - 在
model.eval()
模式下進(jìn)行預(yù)測時,PyTorch會將模型切換到評估模式,這會導(dǎo)致一些訓(xùn)練相關(guān)的操作行為發(fā)生變化。具體而言,Batch Normalization層會使用訓(xùn)練集上的統(tǒng)計(jì)信息進(jìn)行歸一化,而不是使用當(dāng)前批次的統(tǒng)計(jì)信息。因此,如果輸入數(shù)據(jù)沒有進(jìn)行歸一化,模型在評估模式下的準(zhǔn)確度可能會顯著下降。
以下是我沒有用dataloader加載數(shù)據(jù)集,進(jìn)行預(yù)測的代碼:
def correctness(model,data,target, device): batchsize = 1000 batch_num = int(len(data) / batchsize) # 對原始的數(shù)據(jù)進(jìn)行操作 從H.W.C變成C.H.W data = torch.tensor(data).permute(0,3,1,2).type(torch.FloatTensor).to(device) # 手動歸一化 data = data/255 data = (data - 0.5) / 0.5 # 求一個batch的correctness def _batch_correctness(i): images, labels = data[i*batchsize : (i+1)*batchsize], target[i*batchsize : (i+1)*batchsize] predict = model(images).detach().cpu() correctness = np.array(torch.argmax(predict, dim = 1).numpy() == np.array(labels) , dtype= np.float32) return correctness result = np.array([_batch_correctness(i) for i in range(batch_num)]) return result.flatten().sum()/data.shape[0]
我后面用上面的代碼測試了四種情況:
- model.eval() + 沒有歸一化:train_accuracy = 0.10,test_accuracy = 0.10;
- model.eval() + 手動歸一化:train_accuracy = 0.95,test_accuracy = 0.84;
- model.train() + 沒有歸一化:train_accuracy = 0.95,test_accuracy = 0.83;
- model.train() + 手動歸一化:train_accuracy = 0.94,test_accuracy = 0.84;
由此可見,在model.eval()模式下,數(shù)據(jù)歸一化對最終的測試結(jié)果有很大影響。
到此這篇關(guān)于pytorch加載的cifar10數(shù)據(jù)集,到底有沒有經(jīng)過歸一化的文章就介紹到這了,更多相關(guān)pytorch加載cifar10數(shù)據(jù)集內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
- Pytorch搭建簡單的卷積神經(jīng)網(wǎng)絡(luò)(CNN)實(shí)現(xiàn)MNIST數(shù)據(jù)集分類任務(wù)
- Pytorch卷積神經(jīng)網(wǎng)絡(luò)遷移學(xué)習(xí)的目標(biāo)及好處
- Pytorch深度學(xué)習(xí)經(jīng)典卷積神經(jīng)網(wǎng)絡(luò)resnet模塊訓(xùn)練
- Pytorch卷積神經(jīng)網(wǎng)絡(luò)resent網(wǎng)絡(luò)實(shí)踐
- pytorch中的模型訓(xùn)練(以CIFAR10數(shù)據(jù)集為例)
- Pytorch使用卷積神經(jīng)網(wǎng)絡(luò)對CIFAR10圖片進(jìn)行分類方式
相關(guān)文章
利用Hyperic調(diào)用Python實(shí)現(xiàn)進(jìn)程守護(hù)
這篇文章主要為大家詳細(xì)介紹了利用Hyperic調(diào)用Python實(shí)現(xiàn)進(jìn)程守護(hù),具有一定的參考價值,感興趣的小伙伴們可以參考一下2018-01-01聊聊基于pytorch實(shí)現(xiàn)Resnet對本地?cái)?shù)據(jù)集的訓(xùn)練問題
本文項(xiàng)目是使用Resnet模型來識別螞蟻和蜜蜂,其一共有三百九十六張的數(shù)據(jù),訓(xùn)練集只有兩百多張(數(shù)據(jù)集很?。?,運(yùn)行十輪后,分別對訓(xùn)練集和測試集在每一輪的準(zhǔn)確率,對pytorch實(shí)現(xiàn)Resnet本地?cái)?shù)據(jù)集的訓(xùn)練感興趣的朋友一起看看吧2022-03-03Python+Selenium實(shí)現(xiàn)在Geoserver批量發(fā)布Mongo矢量數(shù)據(jù)
這篇文章主要為大家詳細(xì)介紹了如何利用Python+Selenium實(shí)現(xiàn)在 Geoserver批量發(fā)布來自Mongo中的矢量數(shù)據(jù),文中的示例代碼講解詳細(xì),感興趣的小伙伴可以了解一下2022-07-07Python表格處理模塊xlrd在Anaconda中的安裝方法
本文介紹在Anaconda環(huán)境下,安裝Python讀取.xls格式表格文件的庫xlrd的方法,xlrd是一個用于讀取Excel文件的Python庫,本文介紹了xlrd庫的一些主要特點(diǎn)和功能,感興趣的朋友一起看看吧2024-04-04linux環(huán)境下安裝pyramid和新建項(xiàng)目的步驟
這篇文章簡單介紹了linux環(huán)境下安裝pyramid和新建項(xiàng)目的步驟,大家參考使用2013-11-11詳解Pytorch+PyG實(shí)現(xiàn)GCN過程示例
這篇文章主要為大家介紹了Pytorch+PyG實(shí)現(xiàn)GCN過程示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-04-04利用python實(shí)現(xiàn)漢字轉(zhuǎn)拼音的2種方法
這篇文章主要給大家介紹了關(guān)于如何利用python實(shí)現(xiàn)漢字轉(zhuǎn)拼音的相關(guān)資料,文中通過示例代碼介紹的非常詳細(xì),對大家學(xué)習(xí)或者使用python具有一定的參考學(xué)習(xí)價值,需要的朋友們下面來一起學(xué)習(xí)學(xué)習(xí)吧2019-08-08