在pytorch中計算準確率,召回率和F1值的操作
看代碼吧~
predict = output.argmax(dim = 1) confusion_matrix =torch.zeros(2,2) for t, p in zip(predict.view(-1), target.view(-1)): confusion_matrix[t.long(), p.long()] += 1 a_p =(confusion_matrix.diag() / confusion_matrix.sum(1))[0] b_p = (confusion_matrix.diag() / confusion_matrix.sum(1))[1] a_r =(confusion_matrix.diag() / confusion_matrix.sum(0))[0] b_r = (confusion_matrix.diag() / confusion_matrix.sum(0))[1]
補充:pytorch 查全率 recall 查準率 precision F1調和平均 準確率 accuracy
看代碼吧~
def eval(): net.eval() test_loss = 0 correct = 0 total = 0 classnum = 9 target_num = torch.zeros((1,classnum)) predict_num = torch.zeros((1,classnum)) acc_num = torch.zeros((1,classnum)) for batch_idx, (inputs, targets) in enumerate(testloader): if use_cuda: inputs, targets = inputs.cuda(), targets.cuda() inputs, targets = Variable(inputs, volatile=True), Variable(targets) outputs = net(inputs) loss = criterion(outputs, targets) # loss is variable , if add it(+=loss) directly, there will be a bigger ang bigger graph. test_loss += loss.data[0] _, predicted = torch.max(outputs.data, 1) total += targets.size(0) correct += predicted.eq(targets.data).cpu().sum() pre_mask = torch.zeros(outputs.size()).scatter_(1, predicted.cpu().view(-1, 1), 1.) predict_num += pre_mask.sum(0) tar_mask = torch.zeros(outputs.size()).scatter_(1, targets.data.cpu().view(-1, 1), 1.) target_num += tar_mask.sum(0) acc_mask = pre_mask*tar_mask acc_num += acc_mask.sum(0) recall = acc_num/target_num precision = acc_num/predict_num F1 = 2*recall*precision/(recall+precision) accuracy = acc_num.sum(1)/target_num.sum(1) #精度調整 recall = (recall.numpy()[0]*100).round(3) precision = (precision.numpy()[0]*100).round(3) F1 = (F1.numpy()[0]*100).round(3) accuracy = (accuracy.numpy()[0]*100).round(3) # 打印格式方便復制 print('recall'," ".join('%s' % id for id in recall)) print('precision'," ".join('%s' % id for id in precision)) print('F1'," ".join('%s' % id for id in F1)) print('accuracy',accuracy)
補充:Python scikit-learn,分類模型的評估,精確率和召回率,classification_report
分類模型的評估標準一般最常見使用的是準確率(estimator.score()),即預測結果正確的百分比。
混淆矩陣:
準確率是相對所有分類結果;精確率、召回率、F1-score是相對于某一個分類的預測評估標準。
精確率(Precision):預測結果為正例樣本中真實為正例的比例(查的準)()
召回率(Recall):真實為正例的樣本中預測結果為正例的比例(查的全)()
分類的其他評估標準:F1-score,反映了模型的穩(wěn)健型
demo.py(分類評估,精確率、召回率、F1-score,classification_report):
from sklearn.datasets import fetch_20newsgroups from sklearn.model_selection import train_test_split from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.naive_bayes import MultinomialNB from sklearn.metrics import classification_report # 加載數(shù)據(jù)集 從scikit-learn官網(wǎng)下載新聞數(shù)據(jù)集(共20個類別) news = fetch_20newsgroups(subset='all') # all表示下載訓練集和測試集 # 進行數(shù)據(jù)分割 (劃分訓練集和測試集) x_train, x_test, y_train, y_test = train_test_split(news.data, news.target, test_size=0.25) # 對數(shù)據(jù)集進行特征抽取 (進行特征提取,將新聞文檔轉化成特征詞重要性的數(shù)字矩陣) tf = TfidfVectorizer() # tf-idf表示特征詞的重要性 # 以訓練集數(shù)據(jù)統(tǒng)計特征詞的重要性 (從訓練集數(shù)據(jù)中提取特征詞) x_train = tf.fit_transform(x_train) print(tf.get_feature_names()) # ["condensed", "condescend", ...] x_test = tf.transform(x_test) # 不需要重新fit()數(shù)據(jù),直接按照訓練集提取的特征詞進行重要性統(tǒng)計。 # 進行樸素貝葉斯算法的預測 mlt = MultinomialNB(alpha=1.0) # alpha表示拉普拉斯平滑系數(shù),默認1 print(x_train.toarray()) # toarray() 將稀疏矩陣以稠密矩陣的形式顯示。 ''' [[ 0. 0. 0. ..., 0.04234873 0. 0. ] [ 0. 0. 0. ..., 0. 0. 0. ] ..., [ 0. 0.03934786 0. ..., 0. 0. 0. ] ''' mlt.fit(x_train, y_train) # 填充訓練集數(shù)據(jù) # 預測類別 y_predict = mlt.predict(x_test) print("預測的文章類別為:", y_predict) # [4 18 8 ..., 15 15 4] # 準確率 print("準確率為:", mlt.score(x_test, y_test)) # 0.853565365025 print("每個類別的精確率和召回率:", classification_report(y_test, y_predict, target_names=news.target_names)) ''' precision recall f1-score support alt.atheism 0.86 0.66 0.75 207 comp.graphics 0.85 0.75 0.80 238 sport.baseball 0.96 0.94 0.95 253 ..., '''
召回率的意義(應用場景):產(chǎn)品的不合格率(不想漏掉任何一個不合格的產(chǎn)品,查全);癌癥預測(不想漏掉任何一個癌癥患者)
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
Python實現(xiàn)網(wǎng)頁截圖(PyQT5)過程解析
這篇文章主要介紹了Python實現(xiàn)網(wǎng)頁截圖(PyQT5)過程解析,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友可以參考下2019-08-08詳解用python -m http.server搭一個簡易的本地局域網(wǎng)
這篇文章主要介紹了詳解用python -m http.server搭一個簡易的本地局域網(wǎng),文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2020-09-09接口自動化多層嵌套json數(shù)據(jù)處理代碼實例
這篇文章主要介紹了接口自動化多層嵌套json數(shù)據(jù)處理代碼實例,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友可以參考下2020-11-11