pytorch加載的cifar10數(shù)據(jù)集過程詳解
pytorch怎么加載cifar10數(shù)據(jù)集
torchvision.datasets.CIFAR10
pytorch里面的torchvision.datasets中提供了大多數(shù)計算機視覺領域相關任務的數(shù)據(jù)集,可以根據(jù)實際需要加載相關數(shù)據(jù)集——需要cifar10就用torchvision.datasets.CIFAR10(),需要SVHN就調用torchvision.datasets.SVHN()。
針對cifar10數(shù)據(jù)集而言,調用torchvision.datasets.CIFAR10(),其中root是下載數(shù)據(jù)集后保存的位置;train是一個bool變量,為true就是訓練數(shù)據(jù)集,false就是測試數(shù)據(jù)集;download也是一個bool變量,表示是否下載;transform是對數(shù)據(jù)集中的"image"進行一些操作,比如歸一化、隨機裁剪、各種數(shù)據(jù)增強操作等;target_transform是針對數(shù)據(jù)集中的"label"進行一些操作。

示例代碼如下:
# 加載訓練數(shù)據(jù)集
train_data = datasets.CIFAR10(root='../_datasets', train=True, download=True,
transform= transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 歸一化
]) )
# 加載測試數(shù)據(jù)集
test_data = datasets.CIFAR10(root='../_datasets', train=False,download=True,
transform= transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 歸一化
]) )transforms.Normalize()進行歸一化到底在哪里起作用?【CIFAR10源碼分析】
上面的代碼中,我們用transforms.Compose([……])組合了一系列的對image的操作,其中trandforms.ToTensor()和transforms.Normalize()都涉及到歸一化操作:
- 原始的cifar10數(shù)據(jù)集是numpy array的形式,其中數(shù)據(jù)范圍是[0,255],pytorch加載時,并沒有改變數(shù)據(jù)范圍,依舊是[0,255],加載后的數(shù)據(jù)維度是(H, W, C),源碼部分:

__getitem__()函數(shù)中進行transforms操作,進行了歸一化:實際上傳入的transform在__getitem__()函數(shù)中被調用,其中transforms.Totensor()會將data(也就是image)的維度變成(C,H, W)的形式,并且歸一化到[0.0,1.0];

transforms.Normalize()會根據(jù)z = (x-mean) / std 對數(shù)據(jù)進行歸一化,上述代碼中mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]是可以將3個通道單獨進行歸一化,3個通道可以設置不同的mean和std,最終數(shù)據(jù)范圍變成[-0.5,+0.5] 。

所以如果通過pytorch的cifar10加載數(shù)據(jù)集后,針對traindataset.data,依舊是沒有進行歸一化的;但是比如traindataset[index].data,其中[index]這樣的按下標取元素的操作會直接調用的__getitem__()函數(shù),此時的data就是經過了歸一化的。
除traindataset[index]會隱式自動調用__getitem__()函數(shù)外,還有什么時候會調用這個函數(shù)呢?畢竟……只有調用了這個函數(shù)才會調用transforms中的歸一化處理。——答案是與dataloader搭配使用!
torchvision.datasets加載的數(shù)據(jù)集搭配Dataloader使用
torchvision.datasets實際上是torch.utils.data.Dataset的子類,那么就能傳入Dataloader中,迭代的按batch-size獲取批量數(shù)據(jù),用于訓練或者測試。其中dataloader加載dataset中的數(shù)據(jù)時,就是用到了其__getitem__()函數(shù),所以用dataloader加載數(shù)據(jù)集,得到的是經過歸一化后的數(shù)據(jù)。

model.train()和model.eval()
我發(fā)現(xiàn)上面的問題,是我用dataloader加載了訓練數(shù)據(jù)集用于訓練resnet18模型,訓練過程中,我訓練好并保存后,順便測試了一下在測試數(shù)據(jù)集上的準確度。但是在測試的過程中,我沒有用dataloader加載測試數(shù)據(jù)集,而是直接用的dataset.data來進行的測試。并且!由于是并沒有將model設置成model.eval()【其實我設置了,但是我對自己很無語,我寫的model.eval,忘記加括號了,無語嗚嗚】……也就是即便我的測試數(shù)據(jù)集沒有經過歸一化,由于模型還是在model.train()模式下,因此模型的BN層會自己調整,使得模型性能不受影響,因此在測試數(shù)據(jù)集上的accuracy達到了0.86,我就沒有多想。
后來我用模型的時候,設置了model.eval()后,依舊是直接用的dataset.data(也就是沒有歸一化),不管是在測試數(shù)據(jù)集上還是在訓練數(shù)據(jù)集上,accuracy都只有0.10+,我表示非常的迷茫疑惑啊!然后才發(fā)現(xiàn)是歸一化的問題。
- 在
model.train()模式下進行預測時,PyTorch會默認啟用一些訓練相關的操作,例如Batch Normalization和Dropout,并且模型的參數(shù)是可變的,能夠根據(jù)輸入進行調整。這些操作在訓練模式下可以幫助模型更好地適應訓練數(shù)據(jù),并產生較高的準確度。 - 在
model.eval()模式下進行預測時,PyTorch會將模型切換到評估模式,這會導致一些訓練相關的操作行為發(fā)生變化。具體而言,Batch Normalization層會使用訓練集上的統(tǒng)計信息進行歸一化,而不是使用當前批次的統(tǒng)計信息。因此,如果輸入數(shù)據(jù)沒有進行歸一化,模型在評估模式下的準確度可能會顯著下降。
以下是我沒有用dataloader加載數(shù)據(jù)集,進行預測的代碼:
def correctness(model,data,target, device):
batchsize = 1000
batch_num = int(len(data) / batchsize)
# 對原始的數(shù)據(jù)進行操作 從H.W.C變成C.H.W
data = torch.tensor(data).permute(0,3,1,2).type(torch.FloatTensor).to(device)
# 手動歸一化
data = data/255
data = (data - 0.5) / 0.5
# 求一個batch的correctness
def _batch_correctness(i):
images, labels = data[i*batchsize : (i+1)*batchsize], target[i*batchsize : (i+1)*batchsize]
predict = model(images).detach().cpu()
correctness = np.array(torch.argmax(predict, dim = 1).numpy() == np.array(labels) , dtype= np.float32)
return correctness
result = np.array([_batch_correctness(i) for i in range(batch_num)])
return result.flatten().sum()/data.shape[0]我后面用上面的代碼測試了四種情況:
- model.eval() + 沒有歸一化:train_accuracy = 0.10,test_accuracy = 0.10;
- model.eval() + 手動歸一化:train_accuracy = 0.95,test_accuracy = 0.84;
- model.train() + 沒有歸一化:train_accuracy = 0.95,test_accuracy = 0.83;
- model.train() + 手動歸一化:train_accuracy = 0.94,test_accuracy = 0.84;
由此可見,在model.eval()模式下,數(shù)據(jù)歸一化對最終的測試結果有很大影響。
到此這篇關于pytorch加載的cifar10數(shù)據(jù)集,到底有沒有經過歸一化的文章就介紹到這了,更多相關pytorch加載cifar10數(shù)據(jù)集內容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!
相關文章
聊聊基于pytorch實現(xiàn)Resnet對本地數(shù)據(jù)集的訓練問題
本文項目是使用Resnet模型來識別螞蟻和蜜蜂,其一共有三百九十六張的數(shù)據(jù),訓練集只有兩百多張(數(shù)據(jù)集很?。?,運行十輪后,分別對訓練集和測試集在每一輪的準確率,對pytorch實現(xiàn)Resnet本地數(shù)據(jù)集的訓練感興趣的朋友一起看看吧2022-03-03
Python+Selenium實現(xiàn)在Geoserver批量發(fā)布Mongo矢量數(shù)據(jù)
這篇文章主要為大家詳細介紹了如何利用Python+Selenium實現(xiàn)在 Geoserver批量發(fā)布來自Mongo中的矢量數(shù)據(jù),文中的示例代碼講解詳細,感興趣的小伙伴可以了解一下2022-07-07
Python表格處理模塊xlrd在Anaconda中的安裝方法
本文介紹在Anaconda環(huán)境下,安裝Python讀取.xls格式表格文件的庫xlrd的方法,xlrd是一個用于讀取Excel文件的Python庫,本文介紹了xlrd庫的一些主要特點和功能,感興趣的朋友一起看看吧2024-04-04
linux環(huán)境下安裝pyramid和新建項目的步驟
這篇文章簡單介紹了linux環(huán)境下安裝pyramid和新建項目的步驟,大家參考使用2013-11-11

