pytorch中交叉熵損失函數(shù)的使用小細節(jié)
目前pytorch中的交叉熵損失函數(shù)主要分為以下三類,我們將其使用的要點以及場景做一下總結。
類型一:F.cross_entropy()與torch.nn.CrossEntropyLoss()
- 輸入:非onehot label + logit。函數(shù)會自動將logit通過softmax映射為概率。
- 使用場景:都是應用于互斥的分類任務,如典型的二分類以及互斥的多分類。
- 網(wǎng)絡:分類個數(shù)即為網(wǎng)絡的輸出節(jié)點數(shù)
類型二:F.binary_cross_entropy_with_logits()與torch.nn.BCEWithLogitsLoss()
- 輸入:logit。函數(shù)會自動將logit通過sidmoid映射為概率。
- 使用場景:① 二分類 ② 非互斥多分類
- 網(wǎng)絡:使用這類損失函數(shù)需要將網(wǎng)絡輸出的每一個節(jié)點當作一個二分類的節(jié)點
①當為標準的二分類時,網(wǎng)絡的輸出節(jié)點為1
②當為非互斥的多分類時,分類個數(shù)即為網(wǎng)絡的輸出節(jié)點數(shù)
類型三:F.binary_cross_entropy()與torch.nn.BCELoss()
- 輸入:prob(概率)。這個概率可以由softmax計算而來,也可以由sigmoid計算而來。兩種不同的概率映射方式對應不同的分類任務。
- 使用場景:① 二分類 ② 非互斥多分類
- 網(wǎng)絡:①標準的二分類任務:網(wǎng)絡的輸出節(jié)點可以為1,此時概率必須由sigmoid進行映射;
網(wǎng)絡的輸出節(jié)點可以為2,此時概率必須由softmax進行映射。
②當為非互斥的多分類時,分類個數(shù)即為網(wǎng)絡的輸出節(jié)點數(shù),此時概率必須由sigmoid進行映射
1.二分類
類型一:F.cross_entropy()與torch.nn.CrossEntropyLoss()
- 網(wǎng)絡的輸出節(jié)點為2,表示real和fake(類別1和類別2)
類型二:F.binary_cross_entropy_with_logits()與torch.nn.BCEWithLogitsLoss()
- 由于這兩個函數(shù)自帶sigmoid函數(shù),要想完成二分類,網(wǎng)絡的輸出節(jié)點個數(shù)必須設置為1
類型三:F.binary_cross_entropy()與torch.nn.BCELoss(),以下兩種情況都可以使用:
- 當網(wǎng)絡輸出的節(jié)點為2時,一個節(jié)點為real另一個節(jié)點為fake,那么必然要采用softmax將logits映射為概率(兩個節(jié)點的概率和為1),此時該函數(shù)輸入為onehot label + softmax prob,計算出的交叉熵損失與類型一結算結果相同。
- 當網(wǎng)絡的輸出節(jié)點為1時,也就是后面我們要講的GAN的交叉熵損失的實現(xiàn),那么則需要使用sigmoid函數(shù)來進行映射。
這里我們以網(wǎng)絡輸出節(jié)點為2為例,由于類型二要求網(wǎng)絡的輸出節(jié)點為1,因此暫時不納入討論,主要討論類型和類型三。
測試代碼如下:
(網(wǎng)絡輸出節(jié)點為1的二分類就是目前GAN的實現(xiàn)方式,該方式下類型一的函數(shù)不可用,只能采用類型二和類型三,后面將會詳細討論)
softmax = torch.nn.Softmax() logits = np.array([[0.7, -0.1], ? ? ? ? ? ? ? ? ? ? [-1.587, ?-0.5907]]) classes = 2 label = torch.tensor([1, 1]) logits = torch.from_numpy(logits).float() ? #F.cross_entropy loss1 = F.cross_entropy(logits, label) ? print(loss1) ? #nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss() loss2 = criterion(logits, label) print(loss2) ? #可以看到,loss1是等于loss2的 ? prob = softmax(logits) ?#計算概率 one_hot_label = one_hot(label, classes) ? #F.binary_cross_entropy loss3 = F.binary_cross_entropy(prob, one_hot_label) #輸入概率和one-hot print(loss3) ? #torch.nn.BCELoss() adversarial_loss = torch.nn.BCELoss() loss4 = adversarial_loss(prob, one_hot_label) print(loss4) ? #同理,loss3是等于loss4的 ? #手動實現(xiàn)二分類的交叉熵損失 shixian = -torch.mean(torch.sum(one_hot_label * torch.log(prob), axis = 1)) ?#手動實現(xiàn) print(shixian)
2.多分類
此時網(wǎng)絡輸出時多節(jié)點,每一個節(jié)點代表一個類別。
類型一:F.cross_entropy()與torch.nn.CrossEntropyLoss()
- 可以用于多分類的互斥任務,輸入非onehot label + logit。但是不能用于多分類多標簽任務。因為這兩個函數(shù)中自帶的softmax將網(wǎng)絡的每一個節(jié)點都當作時互斥的獨立節(jié)點,每個節(jié)點的概率和為1,因為概率最大的那個節(jié)點的類別會被當為最終的預測類別
類型二:F.binary_cross_entropy_with_logits()與torch.nn.BCEWithLogitsLoss()
- 不能用于多分類的互斥任務,只能用于多分類的非互斥任務
類型三:F.binary_cross_entropy()與torch.nn.BCELoss()
- 與類型二一樣,不能用于多分類的互斥任務,只能用于多分類的非互斥任務。
這里我們首先討論下類型一和類型三,為什么類型三不能用于多分類的互斥任務,只能用于多分類多標簽的分類任務?我們來看一段代碼,這里有三個類別,兩個樣本。
softmax = torch.nn.Softmax() logits = np.array([[0.7, -0.1, 0.2], ? ? ? ? ? ? ? ? ? ? [-1.587, ?-0.5907, 0.3]]) classes = 3 label = torch.tensor([1, 2]) logits = torch.from_numpy(logits).float() ? ### F.cross_entropy loss1 = F.cross_entropy(logits, label) ? print(loss1) ? ### nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss() loss2 = criterion(logits, label) print(loss2) ##loss1 = loss2
上面是采用類型一的兩個函數(shù)計算而來,loss1 = loss2 = 0.9833
然后我們用類型三的函數(shù)來實現(xiàn),同樣將logit通過softmax映射為概率,運行后的結果可以看loss3 =loss4 = 0.5649,不等于類型一的函數(shù)的結果的。
prob_softmax = softmax(logits) ?#計算概率 one_hot_label = one_hot(label, classes) ? ## F.binary_cross_entropy loss3 = F.binary_cross_entropy(prob_softmax, one_hot_label) #輸入概率和one-hot print(loss3) ? ## torch.nn.BCELoss() adversarial_loss = torch.nn.BCELoss() loss4 = adversarial_loss(prob_softmax, one_hot_label) print(loss4)
最后我們再手動實現(xiàn)類型三的損失究竟是怎么得到的:
#手動實現(xiàn) shixian = -torch.mean(one_hot_label * torch.log(prob_softmax) + (1-one_hot_label) * torch.log(1-prob_softmax)) print(shixian)
可以看出來,F(xiàn).binary_cross_entropy()與torch.nn.BCELoss()是將網(wǎng)絡的每個節(jié)點看作是一個二分類的節(jié)點來計算交叉熵損失的。
進一步來討論下類型二和類型三的一致性,代碼如下。由于類型二中函數(shù)自動將logit通過sigloid函數(shù)映射為概率,為了檢驗一致性性,我門也需要通過sigmoid計算類型三所需要的概率。
最后可以看到下面的輸出均為0.6378
sigmoid = nn.Sigmoid() prob_sig = sigmoid(logits) ?#計算概率 ? ##類型二 ##F.binary_cross_entropy_with_logits loss5 = F.binary_cross_entropy_with_logits(logits, one_hot_label) print(loss5) ? ##torch.nn.BCEWithLogitsLoss() BCEWithLogitsLoss = torch.nn.BCEWithLogitsLoss() loss6 = BCEWithLogitsLoss(logits, one_hot_label) print(loss6) ? ##類型三 ##F.binary_cross_entropy loss7 = F.binary_cross_entropy(prob_sig, one_hot_label) #輸入概率和one-hot print(loss7) ? ## torch.nn.BCELoss() adversarial_loss = torch.nn.BCELoss() loss8 = adversarial_loss(prob_sig, one_hot_label) print(loss8) ? #手動實現(xiàn) shixian = -torch.mean(one_hot_label * torch.log(prob_sig) + (1-one_hot_label) * torch.log(1-prob_sig)) print(shixian)
3. GAN中的實現(xiàn):二分類
GAN中的判別器出的損失就是典型的最小化二分類的交叉熵損失。但是在實現(xiàn)上,與二分類網(wǎng)絡不同。
- 一般的二分類網(wǎng)絡,輸出有兩個節(jié)點,分別表示real和fake的logit(或者概率)。
- GAN的判別器,輸出只有一個節(jié)點,表示的是樣本屬于real的logit(或者概率)。
正因為判別器的輸出是一維,類型一的兩個函數(shù)F.cross_entropy()與torch.nn.CrossEntropyLoss()是沒有辦法使用的,因為這兩個函數(shù)要求輸入是二維的,即分別在real和fake的logit。因此只能采用類型二或者類型三的函數(shù)。
很多GAN網(wǎng)絡采用的二分類交叉熵損失函數(shù)如下:
#類型二: adversarial_loss_2 = torch.nn.BCEWithLogitsLoss(logit,y) #類型三: adversarial_loss_3 = torch.nn.BCELoss(p,y)
前面我們講到,類型二和類型三的函數(shù)都是將每一個節(jié)點視為一個二分類的節(jié)點,因此對于每一個給節(jié)點,其具體的表達式可以寫為:
#類型二: torch.nn.BCEWithLogitsLoss(logit,y) = - (ylog(sigmoid(logit)) + (1-y)log(1-sigmoid(logit))) # 其中l(wèi)ogit表示判斷為real的logit # y=1表示real # y=0表示fake ? #類型三: torch.nn.BCELoss(p, y) = - (ylog(p) + (1-y)log(1-p)) # 其中p表示判斷為real的概率 # y=1表示real # y=0表示fake
3.1 判別器損失計算
判別器輸出維度為1,輸出logit,有兩個樣本,都為fake圖像
logits = np.array([1.2, -0.5]) logits = torch.from_numpy(logits).float() sigmoid = nn.Sigmoid() prob_sig = sigmoid(logits) ?#計算概率 ? label = torch.tensor([1, 1]).float() ? #類型二: adversarial_loss_2 = torch.nn.BCEWithLogitsLoss() loss_2 = adversarial_loss_2(logits, 1-label) ?#因為是fake,需要將y設置為0 print(loss_2) ? #類型三: adversarial_loss_3 = torch.nn.BCELoss() loss_3 = adversarial_loss_3(prob_sig, 1-label) #因為是fake,需要將y設置為0 print(loss_3) #輸出均為0.9687
通過上述代碼可以分析如下:
(1)當樣本為fake時,網(wǎng)絡輸出其為real的logit:
- 對于類型二:torch.nn.BCEWithLogitsLoss(logit,0),即直接輸入logit。由于樣本的實際類別為fake,根據(jù)交叉熵損失公式,要將為y設置為0,相當于告訴函數(shù)我輸入的樣本是fake。
- 對于類型三:torch.nn.BCELoss(prob, 0),此時prob等于公式中的p,由于樣本的實際類別為fake,與類型二一致,要將為y設置為0。
(2)樣本為real,網(wǎng)絡輸出其為real的logit:
- 對于類型二:torch.nn.BCEWithLogitsLoss(logit,1),即直接輸入logit。由于樣本的實際類別也為real,根據(jù)交叉熵損失公式,要將為y設置為1,這樣就計算了 ylog(sigmoid(logit))
- 對于類型三:torch.nn.BCELoss(prob, 1),此時prob等于公式中的p,樣本的實際類別也為real,與類型二一致,要將為y設置為1,這樣就計算了 ylog(p)
GAN網(wǎng)絡在更新判別器時,代碼一般如下:
criterion = torch.nn.BCELoss() real_out = D(real_img) ?# 將真實圖片放入判別器中 d_loss_real = criterion(real_out, 1) ?# 真實樣本的損失 ? fake_img = G(z) ?# 隨機噪聲放入生成網(wǎng)絡中,生成一張假的圖片 fake_out = D(fake_img) ?# 判別器判斷假的圖片, d_loss_fake = criterion(fake_out, 0) ?# 生成樣本的損失 ? d_loss = d_loss_real + d_loss_fake ?# ?兩個相加 就是標準的交叉熵損失 ? optimizer_D.zero_grad() d_loss.backward() optimizer_D.step()
3.2 生成器的損失計算
前面判別器處的損失是最小化交叉熵損失:
min - (ylog(p) + (1-y)log(1-p))
那么生成器與之相反就是最大化交叉熵損失:
max - (ylog(p) + (1-y)log(1-p))
因為真實樣本于與生成器無關,因此可以轉變?yōu)閙in log(1-p)
max - ((1-y)log(1-p)) = min (1-y)log(1-p) = min log(1-p)
上述形式為飽和形式,轉變?yōu)榉秋柡腿缦隆?/p>
min -log(p)
可以看到上式子在形式上就是將fake圖像當作real圖像進行優(yōu)化。
可以這么理解:生成器的作用的就是盡可能生成逼近與real的fake,由于判別器判斷的結果p就是表示圖像為real的概率,那么生成器就希望p越高越好。而在訓練判別器時,判別器對real的優(yōu)化就是讓其p越高越好,即盡可能的區(qū)分real和fake。
因此在更新生成器時,fake處的損失與更新判別器在real處的損失在邏輯上是一致的。
criterion = torch.nn.BCELoss() fake_img = G(z) ?# 隨機噪聲放入生成網(wǎng)絡中,生成一張假的圖片 fake_out = D(fake_img) ?# 判別器判斷假的圖片, G_loss = criterion(fake_out, 1) ?# 假樣本的損失 ? ? optimizer_G.zero_grad() G_loss .backward() optimizer_G.step()
3.3 小結
在GAN網(wǎng)絡中,由于輸出網(wǎng)絡只有一個節(jié)點,表示圖像屬于real的logit或者prob,因此一般使用類型二和類型三的損失函數(shù)。
兩類函數(shù)的實現(xiàn)如下:
torch.nn.BCEWithLogitsLoss(logit,y) = - (ylog(sigmoid(logit)) + (1-y)log(1-sigmoid(logit))) torch.nn.BCELoss(p, y) = - (ylog(prob) + (1-y)log(1-prob))
因為上述實現(xiàn):
- 在更新判別器時:real圖像后面label為1,fake圖像后面label為0。分別計算real和fake的損失相加。
- 在更新判別器時:與real圖像無關,fake圖像后面label為1,更新。
總結
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
Pycharm中安裝Pygal并使用Pygal模擬擲骰子(推薦)
這篇文章主要介紹了Pycharm中安裝Pygal并使用Pygal模擬擲骰子,本文給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下2020-04-04Python實現(xiàn)數(shù)字圖像處理染色體計數(shù)示例
這篇文章主要為大家介紹了Python實現(xiàn)數(shù)字圖像處理染色體計數(shù)示例,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪2022-06-06Python面向對象程序設計之繼承、多態(tài)原理與用法詳解
這篇文章主要介紹了Python面向對象程序設計之繼承、多態(tài),結合實例形式分析了Python面向對象程序設計中繼承、多態(tài)的相關概念、原理、用法及操作注意事項,需要的朋友可以參考下2020-03-03