亚洲乱码中文字幕综合,中国熟女仑乱hd,亚洲精品乱拍国产一区二区三区,一本大道卡一卡二卡三乱码全集资源,又粗又黄又硬又爽的免费视频

pytorch加載的cifar10數(shù)據(jù)集過程詳解

 更新時間:2023年11月08日 10:01:01   作者:PleaseBrave  
這篇文章主要介紹了pytorch加載的cifar10數(shù)據(jù)集,到底有沒有經過歸一化,本文對這一問題給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友參考下吧

pytorch怎么加載cifar10數(shù)據(jù)集

torchvision.datasets.CIFAR10

pytorch里面的torchvision.datasets中提供了大多數(shù)計算機視覺領域相關任務的數(shù)據(jù)集,可以根據(jù)實際需要加載相關數(shù)據(jù)集——需要cifar10就用torchvision.datasets.CIFAR10(),需要SVHN就調用torchvision.datasets.SVHN()。

針對cifar10數(shù)據(jù)集而言,調用torchvision.datasets.CIFAR10(),其中root是下載數(shù)據(jù)集后保存的位置;train是一個bool變量,為true就是訓練數(shù)據(jù)集,false就是測試數(shù)據(jù)集;download也是一個bool變量,表示是否下載;transform是對數(shù)據(jù)集中的"image"進行一些操作,比如歸一化、隨機裁剪、各種數(shù)據(jù)增強操作等;target_transform是針對數(shù)據(jù)集中的"label"進行一些操作。

示例代碼如下:

# 加載訓練數(shù)據(jù)集
train_data = datasets.CIFAR10(root='../_datasets', train=True, download=True,
                                  transform= transforms.Compose([  
                                                 transforms.ToTensor(),  
                                                 transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 歸一化  
                                                 ])  )
# 加載測試數(shù)據(jù)集
test_data = datasets.CIFAR10(root='../_datasets', train=False,download=True, 
                             transform= transforms.Compose([  
                                               transforms.ToTensor(),  
                                               transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 歸一化  
                                               ])  )

transforms.Normalize()進行歸一化到底在哪里起作用?【CIFAR10源碼分析】

上面的代碼中,我們用transforms.Compose([……])組合了一系列的對image的操作,其中trandforms.ToTensor()transforms.Normalize()都涉及到歸一化操作:

  • 原始的cifar10數(shù)據(jù)集是numpy array的形式,其中數(shù)據(jù)范圍是[0,255],pytorch加載時,并沒有改變數(shù)據(jù)范圍,依舊是[0,255],加載后的數(shù)據(jù)維度是(H, W, C),源碼部分:

  • __getitem__()函數(shù)中進行transforms操作,進行了歸一化:實際上傳入的transform在__getitem__()函數(shù)中被調用,其中transforms.Totensor()會將data(也就是image)的維度變成(C,H, W)的形式,并且歸一化到[0.0,1.0];

  • transforms.Normalize()會根據(jù)z = (x-mean) / std 對數(shù)據(jù)進行歸一化,上述代碼中mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]是可以將3個通道單獨進行歸一化,3個通道可以設置不同的mean和std,最終數(shù)據(jù)范圍變成[-0.5,+0.5] 。

所以如果通過pytorch的cifar10加載數(shù)據(jù)集后,針對traindataset.data,依舊是沒有進行歸一化的;但是比如traindataset[index].data,其中[index]這樣的按下標取元素的操作會直接調用的__getitem__()函數(shù),此時的data就是經過了歸一化的。
除traindataset[index]會隱式自動調用__getitem__()函數(shù)外,還有什么時候會調用這個函數(shù)呢?畢竟……只有調用了這個函數(shù)才會調用transforms中的歸一化處理。——答案是與dataloader搭配使用!

torchvision.datasets加載的數(shù)據(jù)集搭配Dataloader使用

torchvision.datasets實際上是torch.utils.data.Dataset的子類,那么就能傳入Dataloader中,迭代的按batch-size獲取批量數(shù)據(jù),用于訓練或者測試。其中dataloader加載dataset中的數(shù)據(jù)時,就是用到了其__getitem__()函數(shù),所以用dataloader加載數(shù)據(jù)集,得到的是經過歸一化后的數(shù)據(jù)。

在這里插入圖片描述

model.train()和model.eval()

我發(fā)現(xiàn)上面的問題,是我用dataloader加載了訓練數(shù)據(jù)集用于訓練resnet18模型,訓練過程中,我訓練好并保存后,順便測試了一下在測試數(shù)據(jù)集上的準確度。但是在測試的過程中,我沒有用dataloader加載測試數(shù)據(jù)集,而是直接用的dataset.data來進行的測試。并且!由于是并沒有將model設置成model.eval()【其實我設置了,但是我對自己很無語,我寫的model.eval,忘記加括號了,無語嗚嗚】……也就是即便我的測試數(shù)據(jù)集沒有經過歸一化,由于模型還是在model.train()模式下,因此模型的BN層會自己調整,使得模型性能不受影響,因此在測試數(shù)據(jù)集上的accuracy達到了0.86,我就沒有多想。
后來我用模型的時候,設置了model.eval()后,依舊是直接用的dataset.data(也就是沒有歸一化),不管是在測試數(shù)據(jù)集上還是在訓練數(shù)據(jù)集上,accuracy都只有0.10+,我表示非常的迷茫疑惑啊!然后才發(fā)現(xiàn)是歸一化的問題。

  • model.train()模式下進行預測時,PyTorch會默認啟用一些訓練相關的操作,例如Batch Normalization和Dropout,并且模型的參數(shù)是可變的,能夠根據(jù)輸入進行調整。這些操作在訓練模式下可以幫助模型更好地適應訓練數(shù)據(jù),并產生較高的準確度。
  • model.eval()模式下進行預測時,PyTorch會將模型切換到評估模式,這會導致一些訓練相關的操作行為發(fā)生變化。具體而言,Batch Normalization層會使用訓練集上的統(tǒng)計信息進行歸一化,而不是使用當前批次的統(tǒng)計信息。因此,如果輸入數(shù)據(jù)沒有進行歸一化,模型在評估模式下的準確度可能會顯著下降。

以下是我沒有用dataloader加載數(shù)據(jù)集,進行預測的代碼:

def correctness(model,data,target, device):
    batchsize = 1000
    batch_num = int(len(data) / batchsize)   
    # 對原始的數(shù)據(jù)進行操作 從H.W.C變成C.H.W 
    data = torch.tensor(data).permute(0,3,1,2).type(torch.FloatTensor).to(device)
    # 手動歸一化
    data = data/255
    data = (data - 0.5) / 0.5 
    # 求一個batch的correctness
    def _batch_correctness(i):
        images, labels = data[i*batchsize : (i+1)*batchsize], target[i*batchsize : (i+1)*batchsize]
        predict = model(images).detach().cpu()    
        correctness = np.array(torch.argmax(predict, dim = 1).numpy() == np.array(labels) , dtype= np.float32)
        return correctness
    result = np.array([_batch_correctness(i) for i in range(batch_num)])
    return result.flatten().sum()/data.shape[0]

我后面用上面的代碼測試了四種情況:

  • model.eval() + 沒有歸一化:train_accuracy = 0.10,test_accuracy = 0.10;
  • model.eval() + 手動歸一化:train_accuracy = 0.95,test_accuracy = 0.84;
  • model.train() + 沒有歸一化:train_accuracy = 0.95,test_accuracy = 0.83;
  • model.train() + 手動歸一化:train_accuracy = 0.94,test_accuracy = 0.84;

由此可見,在model.eval()模式下,數(shù)據(jù)歸一化對最終的測試結果有很大影響。

到此這篇關于pytorch加載的cifar10數(shù)據(jù)集,到底有沒有經過歸一化的文章就介紹到這了,更多相關pytorch加載cifar10數(shù)據(jù)集內容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!

相關文章

  • 利用Hyperic調用Python實現(xiàn)進程守護

    利用Hyperic調用Python實現(xiàn)進程守護

    這篇文章主要為大家詳細介紹了利用Hyperic調用Python實現(xiàn)進程守護,具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2018-01-01
  • 聊聊基于pytorch實現(xiàn)Resnet對本地數(shù)據(jù)集的訓練問題

    聊聊基于pytorch實現(xiàn)Resnet對本地數(shù)據(jù)集的訓練問題

    本文項目是使用Resnet模型來識別螞蟻和蜜蜂,其一共有三百九十六張的數(shù)據(jù),訓練集只有兩百多張(數(shù)據(jù)集很?。?,運行十輪后,分別對訓練集和測試集在每一輪的準確率,對pytorch實現(xiàn)Resnet本地數(shù)據(jù)集的訓練感興趣的朋友一起看看吧
    2022-03-03
  • Python+Selenium實現(xiàn)在Geoserver批量發(fā)布Mongo矢量數(shù)據(jù)

    Python+Selenium實現(xiàn)在Geoserver批量發(fā)布Mongo矢量數(shù)據(jù)

    這篇文章主要為大家詳細介紹了如何利用Python+Selenium實現(xiàn)在 Geoserver批量發(fā)布來自Mongo中的矢量數(shù)據(jù),文中的示例代碼講解詳細,感興趣的小伙伴可以了解一下
    2022-07-07
  • python3自動更新緩存類的具體使用

    python3自動更新緩存類的具體使用

    本文介紹了使用一個自動更新緩存的Python類AutoUpdatingCache,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧
    2025-01-01
  • Python表格處理模塊xlrd在Anaconda中的安裝方法

    Python表格處理模塊xlrd在Anaconda中的安裝方法

    本文介紹在Anaconda環(huán)境下,安裝Python讀取.xls格式表格文件的庫xlrd的方法,xlrd是一個用于讀取Excel文件的Python庫,本文介紹了xlrd庫的一些主要特點和功能,感興趣的朋友一起看看吧
    2024-04-04
  • linux環(huán)境下安裝pyramid和新建項目的步驟

    linux環(huán)境下安裝pyramid和新建項目的步驟

    這篇文章簡單介紹了linux環(huán)境下安裝pyramid和新建項目的步驟,大家參考使用
    2013-11-11
  • 詳解Pytorch+PyG實現(xiàn)GCN過程示例

    詳解Pytorch+PyG實現(xiàn)GCN過程示例

    這篇文章主要為大家介紹了Pytorch+PyG實現(xiàn)GCN過程示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪
    2023-04-04
  • 利用python實現(xiàn)漢字轉拼音的2種方法

    利用python實現(xiàn)漢字轉拼音的2種方法

    這篇文章主要給大家介紹了關于如何利用python實現(xiàn)漢字轉拼音的相關資料,文中通過示例代碼介紹的非常詳細,對大家學習或者使用python具有一定的參考學習價值,需要的朋友們下面來一起學習學習吧
    2019-08-08
  • 基于PyQt5自制簡單的文件內容檢索小工具

    基于PyQt5自制簡單的文件內容檢索小工具

    這篇文章主要為大家詳細介紹了如何基于PyQt5自制一個簡單的文件內容檢索小工具,文中的示例代碼講解詳細,感興趣的小伙伴可以跟隨小編一起學習一下
    2023-05-05
  • 詳細分析Python垃圾回收機制

    詳細分析Python垃圾回收機制

    這篇文章主要介紹了Python垃圾回收機制的相關資料,文中講解非常詳細,示例代碼幫助大家更好的理解和學習,感興趣的朋友可以了解下
    2020-07-07

最新評論