pytorch中model.train()和model.eval()用法及說明
model.train()和model.eval()用法
1.1 model.train()
model.train()的作用是啟用 Batch Normalization 和 Dropout。
如果模型中有BN層(Batch Normalization)和Dropout,需要在訓練時添加model.train()。
model.train()是保證BN層能夠用到每一批數(shù)據(jù)的均值和方差。
對于Dropout,model.train()是隨機取一部分網(wǎng)絡連接來訓練更新參數(shù)。
1.2 model.eval()
model.eval()的作用是不啟用 Batch Normalization 和 Dropout。
如果模型中有BN層(Batch Normalization)和Dropout,在測試時添加model.eval()。
model.eval()是保證BN層能夠用全部訓練數(shù)據(jù)的均值和方差,即測試過程中要保證BN層的均值和方差不變。
對于Dropout,model.eval()是利用到了所有網(wǎng)絡連接,即不進行隨機舍棄神經(jīng)元。
訓練完train樣本后,生成的模型model要用來測試樣本。
在model(test)之前,需要加上model.eval(),否則的話,有輸入數(shù)據(jù),即使不訓練,它也會改變權(quán)值。這是model中含有BN層和Dropout所帶來的的性質(zhì)。
在做one classification的時候,訓練集和測試集的樣本分布是不一樣的,尤其需要注意這一點。
1.3 分析原因
使用PyTorch進行訓練和測試時一定注意要把實例化的model指定train/eval。
model.eval()時,框架會自動把BN和Dropout固定住,不會取平均,而是用訓練好的值,
不然的話,一旦test的batch_size過小,很容易就會被BN層導致生成圖片顏色失真極大?。。。。。?/p>
總結(jié)
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
Python全面解析json數(shù)據(jù)并保存為csv文件
這篇文章主要介紹了Python全面解析json數(shù)據(jù)并保存為csv文件,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2022-07-07Python pandas 的索引方式 data.loc[],data[][]示例詳解
這篇文章主要介紹了Python pandas 的索引方式 data.loc[], data[][]的相關資料,其中data.loc[index,column]使用.loc[ ]第一個參數(shù)是行索引,第二個參數(shù)是列索引,本文結(jié)合實例代碼講解的非常詳細,需要的朋友可以參考下2023-02-02Python實現(xiàn)刪除list列表重復元素的方法總結(jié)
在Python編程中,我們經(jīng)常需要處理列表中的重復元素,這篇文章為大家介紹了五種高效的方法來刪除列表中的重復元素,希望對大家有所幫助2023-07-07Python列表排序方法reverse、sort、sorted詳解
這篇文章主要介紹了Python列表排序方法reverse、sort、sorted詳解,需要的朋友可以參考下2021-04-04