Pytorch中的model.train()?和?model.eval()?原理與用法解析
Pytorch中的model.train() 和 model.eval() 原理與用法
一、兩種模式
pytorch可以給我們提供兩種方式來切換訓(xùn)練和評估(推斷)的模式,分別是:model.train()
和 model.eval()
。
一般用法是:在訓(xùn)練開始之前寫上 model.trian() ,在測試時(shí)寫上 model.eval() 。
二、功能
1. model.train()
在使用 pytorch 構(gòu)建神經(jīng)網(wǎng)絡(luò)的時(shí)候,訓(xùn)練過程中會(huì)在程序上方添加一句model.train(),作用是 啟用 batch normalization 和 dropout 。
如果模型中有BN層(Batch Normalization)和 Dropout ,需要在 訓(xùn)練時(shí) 添加 model.train()。
model.train() 是保證 BN 層能夠用到 每一批數(shù)據(jù) 的均值和方差。對于 Dropout,model.train() 是 隨機(jī)取一部分 網(wǎng)絡(luò)連接來訓(xùn)練更新參數(shù)。
2. model.eval()
model.eval()的作用是 不啟用 Batch Normalization 和 Dropout。
如果模型中有 BN 層(Batch Normalization)和 Dropout,在 測試時(shí) 添加 model.eval()。
model.eval() 是保證 BN 層能夠用 全部訓(xùn)練數(shù)據(jù) 的均值和方差,即測試過程中要保證 BN 層的均值和方差不變。對于 Dropout,model.eval() 是利用到了 所有 網(wǎng)絡(luò)連接,即不進(jìn)行隨機(jī)舍棄神經(jīng)元。
為什么測試時(shí)要用 model.eval() ?
訓(xùn)練完 train 樣本后,生成的模型 model 要用來測試樣本了。在 model(test) 之前,需要加上model.eval(),否則的話,有輸入數(shù)據(jù),即使不訓(xùn)練,它也會(huì)改變權(quán)值。這是 model 中含有 BN 層和 Dropout 所帶來的的性質(zhì)。
eval() 時(shí),pytorch 會(huì)自動(dòng)把 BN 和 DropOut 固定住,不會(huì)取平均,而是用訓(xùn)練好的值。
不然的話,一旦 test 的 batch_size 過小,很容易就會(huì)被 BN 層導(dǎo)致生成圖片顏色失真極大。
eval() 在非訓(xùn)練的時(shí)候是需要加的,沒有這句代碼,一些網(wǎng)絡(luò)層的值會(huì)發(fā)生變動(dòng),不會(huì)固定,你神經(jīng)網(wǎng)絡(luò)每一次生成的結(jié)果也是不固定的,生成質(zhì)量可能好也可能不好。
也就是說,測試過程中使用model.eval(),這時(shí)神經(jīng)網(wǎng)絡(luò)會(huì) 沿用 batch normalization 的值,而并 不使用 dropout。
3. 總結(jié)與對比
如果模型中有 BN 層(Batch Normalization)和 Dropout,需要在訓(xùn)練時(shí)添加 model.train(),在測試時(shí)添加 model.eval()。
其中 model.train() 是保證 BN 層用每一批數(shù)據(jù)的均值和方差,而 model.eval() 是保證 BN 用全部訓(xùn)練數(shù)據(jù)的均值和方差;
而對于 Dropout,model.train() 是隨機(jī)取一部分網(wǎng)絡(luò)連接來訓(xùn)練更新參數(shù),而 model.eval() 是利用到了所有網(wǎng)絡(luò)連接。
三、Dropout 簡介
dropout 常常用于抑制過擬合。
設(shè)置Dropout時(shí),torch.nn.Dropout(0.5),這里的 0.5 是指該層(layer)的神經(jīng)元在每次迭代訓(xùn)練時(shí)會(huì)隨機(jī)有 50% 的可能性被丟棄(失活),不參與訓(xùn)練。也就是將上一層數(shù)據(jù)減少一半傳播。
參考鏈接
- PyTorch中train()方法的作用是什么
- 【pytorch】model.train()和model.evel()的用法
- pytorch中net.eval() 和net.train()的使用
- Pytorch學(xué)習(xí)筆記11----model.train()與model.eval()的用法、Dropout原理、relu,sigmiod,tanh激活函數(shù)、nn.Linear淺析、輸出整個(gè)tensor的方法
- 好文:Pytorch:model.train()和model.eval()用法和區(qū)別,以及model.eval()和torch.no_grad()的區(qū)別
補(bǔ)充:pytroch:model.train()、model.eval()的使用
前言:最近在把兩個(gè)模型的代碼整合到一起,發(fā)現(xiàn)有一個(gè)模型的代碼整合后性能大不如前,但基本上是源碼遷移,找了一天原因才發(fā)現(xiàn)是因?yàn)閙odel.eval()和model.train()放錯(cuò)了位置?。?!故在此介紹一下pytroch框架下model.train()、model.eval()的作用和不同點(diǎn)。
一、model.train、model.eval
1.model.train和model.eval放在代碼什么位置
簡單的說:
model.train
放在網(wǎng)絡(luò)訓(xùn)練前,model.eval
放在網(wǎng)絡(luò)測試前。
常見的位置擺放錯(cuò)誤(也是我犯的錯(cuò)誤)有把model.train()
放在for epoch in range(epoch):
前面,同時(shí)在test或者val(測試或者評估函數(shù))中只放置model.eval
,這就導(dǎo)致了只有第一個(gè)epoch模型訓(xùn)練是使用了model.train()
,之后的epoch模型訓(xùn)練時(shí)都采用model.eval()
.可能會(huì)影響訓(xùn)練好模型的性能。
修改方式:可以在test函數(shù)里return前面添加model.train()
或者把model.train()
放到for epoch in range(epoch):
語句下面。
model.train() for epoch in range(epoch): for train_batch in train_loader: ... zhibiao = test(epoch, test_loader, model) def test(epoch, test_loader, model): model.eval() for test_batch in test_loader: ... return zhibiao
2.model.train和model.eval有什么作用
model.train()和model.eval()的區(qū)別主要在于Batch Normalization和Dropout兩層。
如果模型中有BN層(Batch Normalization)和Dropout,在測試時(shí)添加model.eval()。model.eval()是保證BN層能夠用全部訓(xùn)練數(shù)據(jù)的均值和方差,即測試過程中要保證BN層的均值和方差不變。對于Dropout,model.eval()是利用到了所有網(wǎng)絡(luò)連接,即不進(jìn)行隨機(jī)舍棄神經(jīng)元。
下面是model.train 和model.eval的源碼,可以看到是利用self.training = mode
來判斷是使用train還是eval。這個(gè)參數(shù)將傳遞到一些常用層,比如dropout、BN層等。
def train(self: T, mode: bool = True) -> T: r"""Sets the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. Args: mode (bool): whether to set training mode (``True``) or evaluation mode (``False``). Default: ``True``. Returns: Module: self """ self.training = mode for module in self.children(): module.train(mode) return self def eval(self: T) -> T: r"""Sets the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`. Returns: Module: self """ return self.train(False)
拿dropout層的源碼舉例,可以看到傳遞了self.training這個(gè)參數(shù)。
class Dropout(_DropoutNd): r"""During training, randomly zeroes some of the elements of the input tensor with probability :attr:`p` using samples from a Bernoulli distribution. Each channel will be zeroed out independently on every forward call. This has proven to be an effective technique for regularization and preventing the co-adaptation of neurons as described in the paper `Improving neural networks by preventing co-adaptation of feature detectors`_ . Furthermore, the outputs are scaled by a factor of :math:`\frac{1}{1-p}` during training. This means that during evaluation the module simply computes an identity function. Args: p: probability of an element to be zeroed. Default: 0.5 inplace: If set to ``True``, will do this operation in-place. Default: ``False`` Shape: - Input: :math:`(*)`. Input can be of any shape - Output: :math:`(*)`. Output is of the same shape as input Examples:: >>> m = nn.Dropout(p=0.2) >>> input = torch.randn(20, 16) >>> output = m(input) .. _Improving neural networks by preventing co-adaptation of feature detectors: https://arxiv.org/abs/1207.0580 """ def forward(self, input: Tensor) -> Tensor: return F.dropout(input, self.p, self.training, self.inplace)
3.為什么主要區(qū)別在于BN層和dropout層
在BN層中,主要涉及到四個(gè)需要更新的參數(shù),分別是running_mean,running_var,weight,bias。這里的weight,bias是Pytorch官方實(shí)現(xiàn)中的叫法,有點(diǎn)誤導(dǎo)人,其實(shí)weight就是gamma,bias就是beta。當(dāng)然它這樣的叫法也符合實(shí)際的應(yīng)用場景。其實(shí)gamma,beta就是對規(guī)范化后的值進(jìn)行一個(gè)加權(quán)求和操作running_mean,running_var是當(dāng)前所求得的所有batch_size下的均值和方差,每經(jīng)過一個(gè)mini_batch我們都會(huì)更新running_mean,running_var.為什么要更新它?因?yàn)闇y試的時(shí)候,往往是一個(gè)一個(gè)的圖像feed至網(wǎng)絡(luò)的,如果你在這里對其進(jìn)行計(jì)算均值方差顯然是不合理的,所以model.eval()這個(gè)語句就是控制BN層中的running_mean,running_std不更新。采用訓(xùn)練結(jié)束后的running_mean,running_std來規(guī)范化該張圖像。
dropout層在訓(xùn)練過程中會(huì)隨機(jī)舍棄一些神經(jīng)元用來提高性能,但測試過程中如果還是測試的模型還是和訓(xùn)練時(shí)一樣隨機(jī)舍棄了一些神經(jīng)元(不是原模型)這就和測試的本意相違背。因?yàn)闇y試的模型應(yīng)該是我們最終得到的模型,而這個(gè)模型應(yīng)該是一個(gè)完整的模型。
4.BN層和dropout層的作用
既然都講到這了,不了解一些BN層和dropout層的作用就說不過去了。
BN層的原理和作用建議讀一下這篇博客:神經(jīng)網(wǎng)絡(luò)中BN層的原理與作用
dropout是指在深度學(xué)習(xí)網(wǎng)絡(luò)的訓(xùn)練過程中,對于神經(jīng)網(wǎng)絡(luò)單元,按照一定的概率將其暫時(shí)從網(wǎng)絡(luò)中丟棄。注意是暫時(shí),對于隨機(jī)梯度下降來說,由于是隨機(jī)丟棄,故而每一個(gè)mini-batch都在訓(xùn)練不同的網(wǎng)絡(luò)。
大規(guī)模的神經(jīng)網(wǎng)絡(luò)有兩個(gè)缺點(diǎn):費(fèi)時(shí)、容易過擬合
Dropout的出現(xiàn)很好的可以解決這個(gè)問題,每次做完dropout,相當(dāng)于從原始的網(wǎng)絡(luò)中找到一個(gè)更瘦的網(wǎng)絡(luò)。因而,對于一個(gè)有N個(gè)節(jié)點(diǎn)的神經(jīng)網(wǎng)絡(luò),有了dropout后,就可以看做是2^n個(gè)模型的集合了,但此時(shí)要訓(xùn)練的參數(shù)數(shù)目卻是不變的,這就解決了費(fèi)時(shí)的問題。
將dropout比作是有性繁殖,將基因隨機(jī)進(jìn)行拆分,可以將優(yōu)秀的基因傳下來,并且降低基因之間的聯(lián)合適應(yīng)性,使得復(fù)雜的大段大段基因聯(lián)合適應(yīng)性變成比較小的一個(gè)一個(gè)小段基因的聯(lián)合適應(yīng)性。
dropout也能達(dá)到同樣的效果,它強(qiáng)迫一個(gè)神經(jīng)單元,和隨機(jī)挑選出來的其他神經(jīng)單元共同工作,達(dá)到好的效果。消除減弱了神經(jīng)元節(jié)點(diǎn)間的聯(lián)合適應(yīng)性,增強(qiáng)了泛化能力。
參考鏈接
pytorch中model.train()和model.eval()的區(qū)別
BN層(Pytorch)
神經(jīng)網(wǎng)絡(luò)中BN層的原理與作用————這篇博客寫的賊棒
深度學(xué)習(xí)中Dropout的作用和原理
到此這篇關(guān)于Pytorch中的model.train() 和 model.eval() 原理與用法的文章就介紹到這了,更多相關(guān)Pytorch model.train() 和 model.eval()內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
解決python字典對值(值為列表)賦值出現(xiàn)重復(fù)的問題
今天小編就為大家分享一篇解決python字典對值(值為列表)賦值出現(xiàn)重復(fù)的問題,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-01-01Python使用logging實(shí)現(xiàn)多進(jìn)程安全的日志模塊
這篇文章主要為大家詳細(xì)介紹了Python如何使用標(biāo)準(zhǔn)庫logging實(shí)現(xiàn)多進(jìn)程安全的日志模塊,文中的示例代碼講解詳細(xì),感興趣的小伙伴可以了解下2024-01-01python dataframe向下向上填充,fillna和ffill的方法
今天小編就為大家分享一篇python dataframe向下向上填充,fillna和ffill的方法,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-11-11Python unittest 簡單實(shí)現(xiàn)參數(shù)化的方法
今天小編就為大家分享一篇Python unittest 簡單實(shí)現(xiàn)參數(shù)化的方法,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-11-11Python 調(diào)用VC++的動(dòng)態(tài)鏈接庫(DLL)
Python下調(diào)用VC++的動(dòng)態(tài)鏈接庫的腳本2008-09-09使用pycharm在本地開發(fā)并實(shí)時(shí)同步到服務(wù)器
這篇文章主要介紹了使用pycharm在本地開發(fā)并實(shí)時(shí)同步到服務(wù)器,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-08-08