PyTorch中常見損失函數(shù)的使用詳解
損失函數(shù)
損失函數(shù),又叫目標函數(shù)。在編譯神經(jīng)網(wǎng)絡(luò)模型必須的兩個參數(shù)之一。另一個必不可少的就是優(yōu)化器,我將在后面詳解到。
重點
損失函數(shù)是指計算機標簽值和預測值直接差異的函數(shù)。
這里我們會結(jié)束幾種常見的損失函數(shù)的計算方法,pytorch中也是以及定義了很多類型的預定義函數(shù),具體的公式不需要去深究(學了也不一定remember),這里暫時能做就是了解。
我們先來定義兩個二維的數(shù)組,然后用不同的損失函數(shù)計算其損失值。
import torch from torch.autograd import Variable import torch.nn as nn sample=Variable(torch.ones(2,2)) a=torch.Tensor(2,2) a[0,0]=0 a[0,1]=1 a[1,0]=2 a[1,1]=3 target=Variable(a) print(sample,target)
這里:
sample的值為tensor([[1., 1.],[1., 1.]])
target的值為tensor([[0., 1.],[2., 3.]])
nn.L1Loss

L1Loss計算方法很簡單,取預測值和真實值的絕對誤差的平均數(shù)。
loss=FunLoss(sample,target)['L1Loss'] print(loss)
在控制臺中打印出來是
tensor(1.)
它的計算過程是這樣的:(∣0−1∣+∣1−1∣+∣2−1∣+∣3−1∣)/4=1,先計算的是絕對值求和,然后再平均。
nn.SmoothL1Loss
SmoothL1Loss的誤差在(-1,1)上是平方損失,其他情況是L1損失。

loss=FunLoss(sample,target)['SmoothL1Loss'] print(loss)
在控制臺中打印出來是
tensor(0.6250)
nn.MSELoss
平方損失函數(shù)。其計算公式是預測值和真實值之間的平方和的平均數(shù)。

loss=FunLoss(sample,target)['MSELoss'] print(loss)
在控制臺中打印出來是
tensor(1.5000)
nn.CrossEntropyLoss
交叉熵損失公式

此公式常在圖像分類神經(jīng)網(wǎng)絡(luò)模型中會常常用到。
loss=FunLoss(sample,target)['CrossEntropyLoss'] print(loss)
在控制臺中打印出來是
tensor(2.0794)
nn.NLLLoss
負對數(shù)似然損失函數(shù)

需要注意的是,這里的xlabel和上面的交叉熵損失里的是不一樣的,這里是經(jīng)過log運算后的數(shù)值。這個損失函數(shù)一般用在圖像識別的模型上。
loss=FunLoss(sample,target)['NLLLoss'] print(loss)
這里,控制臺報錯,需要0D或1D目標張量,不支持多目標??赡苄枰渌囊恍l件,這里我們?nèi)绻龅搅嗽僬f。
損失函數(shù)模塊化設(shè)計
class FunLoss():
def __init__(self, sample, target):
self.sample = sample
self.target = target
self.loss = {
'L1Loss': nn.L1Loss(),
'SmoothL1Loss': nn.SmoothL1Loss(),
'MSELoss': nn.MSELoss(),
'CrossEntropyLoss': nn.CrossEntropyLoss(),
'NLLLoss': nn.NLLLoss()
}
def __getitem__(self, loss_type):
if loss_type in self.loss:
loss_func = self.loss[loss_type]
return loss_func(self.sample, self.target)
else:
raise KeyError(f"Invalid loss type '{loss_type}'")
if __name__=="__main__":
loss=FunLoss(sample,target)['NLLLoss']
print(loss)總結(jié)
這篇博客適合那些希望了解在PyTorch中常見損失函數(shù)的讀者。通過FunLoss我們自己也能簡單的去調(diào)用。
到此這篇關(guān)于PyTorch中常見損失函數(shù)的使用詳解的文章就介紹到這了,更多相關(guān)PyTorch損失函數(shù)內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
pytorch 轉(zhuǎn)換矩陣的維數(shù)位置方法
今天小編就為大家分享一篇pytorch 轉(zhuǎn)換矩陣的維數(shù)位置方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-12-12
python spilt()分隔字符串的實現(xiàn)示例
split() 方法可以實現(xiàn)將一個字符串按照指定的分隔符切分成多個子串,本文介紹了spilt的具體使用,感興趣的可以了解一下2021-05-05
Python issubclass和isinstance函數(shù)的具體使用
本文主要介紹了Python issubclass和isinstance函數(shù)的具體使用,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2023-02-02
Python使用Selenium模塊實現(xiàn)模擬瀏覽器抓取淘寶商品美食信息功能示例
這篇文章主要介紹了Python使用Selenium模塊實現(xiàn)模擬瀏覽器抓取淘寶商品美食信息功能,涉及Python基于re模塊的正則匹配及selenium模塊的頁面抓取等相關(guān)操作技巧,需要的朋友可以參考下2018-07-07
通過conda把已有虛擬環(huán)境的python版本進行降級操作指南
當使用conda創(chuàng)建虛擬環(huán)境時,有時候可能會遇到python版本不對的問題,下面這篇文章主要給大家介紹了關(guān)于如何通過conda把已有虛擬環(huán)境的python版本進行降級操作的相關(guān)資料,需要的朋友可以參考下2024-05-05

