亚洲乱码中文字幕综合,中国熟女仑乱hd,亚洲精品乱拍国产一区二区三区,一本大道卡一卡二卡三乱码全集资源,又粗又黄又硬又爽的免费视频

PyTorch中常見損失函數(shù)的使用詳解

 更新時間:2023年06月18日 11:34:38   作者:夏天是冰紅茶  
損失函數(shù),又叫目標函數(shù),是指計算機標簽值和預測值直接差異的函數(shù),本文為大家整理了PyTorch中常見損失函數(shù)的簡單解釋和使用,希望對大家有所幫助

損失函數(shù)

損失函數(shù),又叫目標函數(shù)。在編譯神經(jīng)網(wǎng)絡(luò)模型必須的兩個參數(shù)之一。另一個必不可少的就是優(yōu)化器,我將在后面詳解到。

重點

損失函數(shù)是指計算機標簽值和預測值直接差異的函數(shù)。

這里我們會結(jié)束幾種常見的損失函數(shù)的計算方法,pytorch中也是以及定義了很多類型的預定義函數(shù),具體的公式不需要去深究(學了也不一定remember),這里暫時能做就是了解。

我們先來定義兩個二維的數(shù)組,然后用不同的損失函數(shù)計算其損失值。

import torch
from torch.autograd import Variable
import torch.nn as nn
sample=Variable(torch.ones(2,2))
a=torch.Tensor(2,2)
a[0,0]=0
a[0,1]=1
a[1,0]=2
a[1,1]=3
target=Variable(a)
print(sample,target)

這里:

sample的值為tensor([[1., 1.],[1., 1.]])

target的值為tensor([[0., 1.],[2., 3.]])

nn.L1Loss

L1Loss計算方法很簡單,取預測值和真實值的絕對誤差的平均數(shù)。

loss=FunLoss(sample,target)['L1Loss']
print(loss)

在控制臺中打印出來是

tensor(1.)

它的計算過程是這樣的:(∣0−1∣+∣1−1∣+∣2−1∣+∣3−1∣)/4=1,先計算的是絕對值求和,然后再平均。

nn.SmoothL1Loss

SmoothL1Loss的誤差在(-1,1)上是平方損失,其他情況是L1損失。

loss=FunLoss(sample,target)['SmoothL1Loss']
print(loss)

在控制臺中打印出來是

tensor(0.6250)

nn.MSELoss

平方損失函數(shù)。其計算公式是預測值和真實值之間的平方和的平均數(shù)。

loss=FunLoss(sample,target)['MSELoss']
print(loss)

在控制臺中打印出來是

tensor(1.5000)

nn.CrossEntropyLoss

交叉熵損失公式

此公式常在圖像分類神經(jīng)網(wǎng)絡(luò)模型中會常常用到。

loss=FunLoss(sample,target)['CrossEntropyLoss']
print(loss)

在控制臺中打印出來是

tensor(2.0794)

nn.NLLLoss

負對數(shù)似然損失函數(shù)

需要注意的是,這里的xlabel和上面的交叉熵損失里的是不一樣的,這里是經(jīng)過log運算后的數(shù)值。這個損失函數(shù)一般用在圖像識別的模型上。

loss=FunLoss(sample,target)['NLLLoss']
print(loss)

這里,控制臺報錯,需要0D或1D目標張量,不支持多目標??赡苄枰渌囊恍l件,這里我們?nèi)绻龅搅嗽僬f。

損失函數(shù)模塊化設(shè)計

class FunLoss():
    def __init__(self, sample, target):
        self.sample = sample
        self.target = target
        self.loss = {
            'L1Loss': nn.L1Loss(),
            'SmoothL1Loss': nn.SmoothL1Loss(),
            'MSELoss': nn.MSELoss(),
            'CrossEntropyLoss': nn.CrossEntropyLoss(),
            'NLLLoss': nn.NLLLoss()
        }
    def __getitem__(self, loss_type):
        if loss_type in self.loss:
            loss_func = self.loss[loss_type]
            return loss_func(self.sample, self.target)
        else:
            raise KeyError(f"Invalid loss type '{loss_type}'")
if __name__=="__main__":
    loss=FunLoss(sample,target)['NLLLoss']
    print(loss)

總結(jié)

這篇博客適合那些希望了解在PyTorch中常見損失函數(shù)的讀者。通過FunLoss我們自己也能簡單的去調(diào)用。

到此這篇關(guān)于PyTorch中常見損失函數(shù)的使用詳解的文章就介紹到這了,更多相關(guān)PyTorch損失函數(shù)內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

最新評論