Keras中的多分類損失函數(shù)用法categorical_crossentropy
from keras.utils.np_utils import to_categorical
注意:當(dāng)使用categorical_crossentropy損失函數(shù)時(shí),你的標(biāo)簽應(yīng)為多類模式,例如如果你有10個(gè)類別,每一個(gè)樣本的標(biāo)簽應(yīng)該是一個(gè)10維的向量,該向量在對(duì)應(yīng)有值的索引位置為1其余為0。
可以使用這個(gè)方法進(jìn)行轉(zhuǎn)換:
from keras.utils.np_utils import to_categorical
categorical_labels = to_categorical(int_labels, num_classes=None)
以mnist數(shù)據(jù)集為例:
from keras.datasets import mnist (X_train, y_train), (X_test, y_test) = mnist.load_data() y_train = to_categorical(y_train, 10) y_test = to_categorical(y_test, 10) ... model.compile(loss='categorical_crossentropy', optimizer='adam') model.fit(X_train, y_train, epochs=100, batch_size=1, verbose=2)
補(bǔ)充知識(shí):Keras中損失函數(shù)binary_crossentropy和categorical_crossentropy產(chǎn)生不同結(jié)果的分析
問題
在使用keras做對(duì)心電信號(hào)分類的項(xiàng)目中發(fā)現(xiàn)一個(gè)問題,這個(gè)問題起源于我的一個(gè)使用錯(cuò)誤:
binary_crossentropy 二進(jìn)制交叉熵用于二分類問題中,categorical_crossentropy分類交叉熵適用于多分類問題中,我的心電分類是一個(gè)多分類問題,但是我起初使用了二進(jìn)制交叉熵,代碼如下所示:
sgd = SGD(lr=0.003, decay=0, momentum=0.7, nesterov=False) model.compile(loss='categorical_crossentropy', optimizer='sgd',metrics=['accuracy']) model.fit(X_train, Y_train, validation_data=(X_test,Y_test),batch_size=16, epochs=20) score = model.evaluate(X_test, Y_test, batch_size=16)
注意:我的CNN網(wǎng)絡(luò)模型在最后輸入層正確使用了應(yīng)該用于多分類問題的softmax激活函數(shù)
后來我在另一個(gè)殘差網(wǎng)絡(luò)模型中對(duì)同類數(shù)據(jù)進(jìn)行相同的分類問題中,正確使用了分類交叉熵,令人奇怪的是殘差模型的效果遠(yuǎn)弱于普通卷積神經(jīng)網(wǎng)絡(luò),這一點(diǎn)是不符合常理的,經(jīng)過多次修改分析終于發(fā)現(xiàn)可能是損失函數(shù)的問題,因此我使用二進(jìn)制交叉熵在殘差網(wǎng)絡(luò)中,終于取得了優(yōu)于普通卷積神經(jīng)網(wǎng)絡(luò)的效果。
因此可以斷定問題就出在所使用的損失函數(shù)身上
原理
本人也只是個(gè)只會(huì)使用框架的調(diào)參俠,對(duì)于一些原理也是一知半解,經(jīng)過了學(xué)習(xí)才大致明白,將一些原理記錄如下:
要搞明白分類熵和二進(jìn)制交叉熵先要從二者適用的激活函數(shù)說起
激活函數(shù)
sigmoid, softmax主要用于神經(jīng)網(wǎng)絡(luò)輸出層的輸出。
softmax函數(shù)
softmax可以看作是Sigmoid的一般情況,用于多分類問題。
Softmax函數(shù)將K維的實(shí)數(shù)向量壓縮(映射)成另一個(gè)K維的實(shí)數(shù)向量,其中向量中的每個(gè)元素取值都介于 (0,1) 之間。常用于多分類問題。
sigmoid函數(shù)
Sigmoid 將一個(gè)實(shí)數(shù)映射到 (0,1) 的區(qū)間,可以用來做二分類。Sigmoid 在特征相差比較復(fù)雜或是相差不是特別大時(shí)效果比較好。Sigmoid不適合用在神經(jīng)網(wǎng)絡(luò)的中間層,因?yàn)閷?duì)于深層網(wǎng)絡(luò),sigmoid 函數(shù)反向傳播時(shí),很容易就會(huì)出現(xiàn)梯度消失的情況(在 sigmoid 接近飽和區(qū)時(shí),變換太緩慢,導(dǎo)數(shù)趨于 0,這種情況會(huì)造成信息丟失),從而無法完成深層網(wǎng)絡(luò)的訓(xùn)練。所以Sigmoid主要用于對(duì)神經(jīng)網(wǎng)絡(luò)輸出層的激活。
分析
所以說多分類問題是要softmax激活函數(shù)配合分類交叉熵函數(shù)使用,而二分類問題要使用sigmoid激活函數(shù)配合二進(jìn)制交叉熵函數(shù)適用,但是如果在多分類問題中使用了二進(jìn)制交叉熵函數(shù)最后的模型分類效果會(huì)虛高,即比模型本身真實(shí)的分類效果好。
所以就會(huì)出現(xiàn)我遇到的情況,這里引用了論壇一位大佬的樣例:
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) # WRONG way model.fit(x_train, y_train, batch_size=batch_size, epochs=2, # only 2 epochs, for demonstration purposes verbose=1, validation_data=(x_test, y_test)) # Keras reported accuracy: score = model.evaluate(x_test, y_test, verbose=0) score[1] # 0.9975801164627075 # Actual accuracy calculated manually: import numpy as np y_pred = model.predict(x_test) acc = sum([np.argmax(y_test[i])==np.argmax(y_pred[i]) for i in range(10000)])/10000 acc # 0.98780000000000001 score[1]==acc # False
樣例中模型在評(píng)估中得到的準(zhǔn)確度高于實(shí)際測(cè)算得到的準(zhǔn)確度,網(wǎng)上給出的原因是Keras沒有定義一個(gè)準(zhǔn)確的度量,但有幾個(gè)不同的,比如binary_accuracy和categorical_accuracy,當(dāng)你使用binary_crossentropy時(shí)keras默認(rèn)在評(píng)估過程中使用了binary_accuracy,但是針對(duì)你的分類要求,應(yīng)當(dāng)采用的是categorical_accuracy,所以就造成了這個(gè)問題(其中的具體原理我也沒去看源碼詳細(xì)了解)
解決
所以問題最后的解決方法就是:
對(duì)于多分類問題,要么采用
from keras.metrics import categorical_accuracy model.compile(loss='binary_crossentropy', optimizer='adam', metrics=[categorical_accuracy])
要么采用
model.compile(loss='categorical_crossentropy',
optimizer='adam',metrics=['accuracy'])
以上這篇Keras中的多分類損失函數(shù)用法categorical_crossentropy就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python NumPy實(shí)現(xiàn)數(shù)組排序與過濾示例分析講解
NumPy是Python的一種開源的數(shù)值計(jì)算擴(kuò)展,它支持大量的維度數(shù)組與矩陣運(yùn)算,這篇文章主要介紹了使用NumPy實(shí)現(xiàn)數(shù)組排序與過濾的方法,需要的朋友們下面隨著小編來一起學(xué)習(xí)吧2023-05-05pandas數(shù)據(jù)合并與重塑之merge詳解
這篇文章主要介紹了pandas數(shù)據(jù)合并與重塑之merge,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2024-02-02Python如何利用struct進(jìn)行二進(jìn)制文件或數(shù)據(jù)流
這篇文章主要介紹了Python如何利用struct進(jìn)行二進(jìn)制文件或數(shù)據(jù)流問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2024-01-01python協(xié)程之動(dòng)態(tài)添加任務(wù)的方法
今天小編就為大家分享一篇python協(xié)程之動(dòng)態(tài)添加任務(wù)的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2019-02-02python實(shí)現(xiàn)simhash算法實(shí)例
這篇文章主要介紹了python實(shí)現(xiàn)simhash算法實(shí)例,需要的朋友可以參考下2014-04-04Python?pandas找出、刪除重復(fù)的數(shù)據(jù)實(shí)例
在面試中很可能遇到給定一個(gè)含有重復(fù)元素的列表,刪除其中重復(fù)的元素,下面這篇文章主要給大家介紹了關(guān)于Python?pandas找出、刪除重復(fù)數(shù)據(jù)的相關(guān)資料,文中通過實(shí)例代碼介紹的非常詳細(xì),需要的朋友可以參考下2022-07-07Python tornado隊(duì)列示例-一個(gè)并發(fā)web爬蟲代碼分享
這篇文章主要介紹了Python tornado隊(duì)列示例-一個(gè)并發(fā)web爬蟲代碼分享,具有一定借鑒價(jià)值,需要的朋友可以參考下2018-01-01學(xué)會(huì)迭代器設(shè)計(jì)模式,幫你大幅提升python性能
這篇文章主要介紹了python 迭代器設(shè)計(jì)模式的相關(guān)資料,幫助大家更好的理解和使用python,感興趣的朋友可以了解下2021-01-01