Python Opencv使用ann神經(jīng)網(wǎng)絡識別手寫數(shù)字功能
opencv中也提供了一種類似于Keras的神經(jīng)網(wǎng)絡,即為ann,這種神經(jīng)網(wǎng)絡的使用方法與Keras的很接近。
關(guān)于mnist數(shù)據(jù)的解析,讀者可以自己從網(wǎng)上下載相應壓縮文件,用python自己編寫解析代碼,由于這里主要研究knn算法,為了圖簡單,直接使用Keras的mnist手寫數(shù)字解析模塊。
本次代碼運行環(huán)境為:
python 3.6.8
opencv-python 4.4.0.46
opencv-contrib-python 4.4.0.46
下面的代碼為使用ann進行模型的訓練:
from keras.datasets import mnist from keras import utils import cv2 import numpy as np #opencv中ANN定義神經(jīng)網(wǎng)絡層 def create_ANN(): ann=cv2.ml.ANN_MLP_create() #設置神經(jīng)網(wǎng)絡層的結(jié)構(gòu) 輸入層為784 隱藏層為80 輸出層為10 ann.setLayerSizes(np.array([784,64,10])) #設置網(wǎng)絡參數(shù)為誤差反向傳播法 ann.setTrainMethod(cv2.ml.ANN_MLP_BACKPROP) #設置激活函數(shù)為sigmoid ann.setActivationFunction(cv2.ml.ANN_MLP_SIGMOID_SYM) #設置訓練迭代條件 #結(jié)束條件為訓練30次或者誤差小于0.00001 ann.setTermCriteria((cv2.TermCriteria_EPS|cv2.TermCriteria_COUNT,100,0.0001)) return ann #計算測試數(shù)據(jù)上的識別率 def evaluate_acc(ann,test_images,test_labels): #采用的sigmoid激活函數(shù),需要對結(jié)果進行置信度處理 #對于大于0.99的可以確定為1 對于小于0.01的可以確信為0 test_ret=ann.predict(test_images) #預測結(jié)果是一個元組 test_pre=test_ret[1] #可以直接最大值的下標 (10000,) test_pre=test_pre.argmax(axis=1) true_sum=(test_pre==test_labels) return true_sum.mean() if __name__=='__main__': #直接使用Keras載入的訓練數(shù)據(jù)(60000, 28, 28) (60000,) (train_images,train_labels),(test_images,test_labels)=mnist.load_data() #變換數(shù)據(jù)的形狀并歸一化 train_images=train_images.reshape(train_images.shape[0],-1)#(60000, 784) train_images=train_images.astype('float32')/255 test_images=test_images.reshape(test_images.shape[0],-1) test_images=test_images.astype('float32')/255 #將標簽變?yōu)閛ne-hot形狀 (60000, 10) float32 train_labels=utils.to_categorical(train_labels) #測試數(shù)據(jù)標簽不用變?yōu)閛ne-hot (10000,) test_labels=test_labels.astype(np.int) #定義神經(jīng)網(wǎng)絡模型結(jié)構(gòu) ann=create_ANN() #開始訓練 ann.train(train_images,cv2.ml.ROW_SAMPLE,train_labels) #在測試數(shù)據(jù)上測試準確率 print(evaluate_acc(ann,test_images,test_labels)) #保存模型 ann.save('mnist_ann.xml') #加載模型 myann=cv2.ml.ANN_MLP_load('mnist_ann.xml')
訓練100次得到的準確率為0.9376,可以接著增加訓練次數(shù)或者提高神經(jīng)網(wǎng)絡的層次結(jié)構(gòu)深度來提高準確率。
使用ann神經(jīng)網(wǎng)絡的模型結(jié)構(gòu)非常小,因為只是保存了權(quán)重參數(shù)。
可以看到整個模型文件的大小才1M,而svm的大小為十多兆,knn的為幾百兆,因此使用ann神經(jīng)網(wǎng)絡更加適合部署在客戶端上。
接下來使用ann進行圖片的測試識別:
import cv2 import numpy as np if __name__=='__main__': #讀取圖片 img=cv2.imread('shuzi.jpg',0) img_sw=img.copy() #將數(shù)據(jù)類型由uint8轉(zhuǎn)為float32 img=img.astype(np.float32) #圖片形狀由(28,28)轉(zhuǎn)為(784,) img=img.reshape(-1,) #增加一個維度變?yōu)?1,784) img=img.reshape(1,-1) #圖片數(shù)據(jù)歸一化 img=img/255 #載入ann模型 ann=cv2.ml.ANN_MLP_load('minist_ann.xml') #進行預測 img_pre=ann.predict(img) #因為激活函數(shù)sigmoid,因此要進行置信度處理 ret=img_pre[1] ret[ret>0.9]=1 ret[ret<0.1]=0 print(ret) cv2.imshow('test',img_sw) cv2.waitKey(0)
運行程序,結(jié)果如下,可見該模型正確識別了數(shù)字0.
到此這篇關(guān)于Python Opencv使用ann神經(jīng)網(wǎng)絡識別手寫數(shù)字的文章就介紹到這了,更多相關(guān)python opencv識別手寫數(shù)字內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python中元組的基礎(chǔ)介紹及常用操作總結(jié)
元組是一種不可變序列。元組變量的賦值要在定義時就進行,這就像C語言中的const變量或是C++的引用,定義時賦值之后就不允許有修改。元組存在的意義是:元組在映射中可以作為鍵使用,因為要保證鍵的不變性。元組作為很多內(nèi)置函數(shù)和方法的返回值存在2021-09-09Python?sklearn預測評估指標混淆矩陣計算示例詳解
這篇文章主要為大家介紹了Python?sklearn預測評估指標混淆矩陣計算示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪2023-02-02PyTorch中的神經(jīng)網(wǎng)絡 Mnist 分類任務
這篇文章主要介紹了PyTorch中的神經(jīng)網(wǎng)絡 Mnist 分類任務,在本次的分類任務當中,我們使用的數(shù)據(jù)集是 Mnist 數(shù)據(jù)集,這個數(shù)據(jù)集大家都比較熟悉,需要的朋友可以參考下2023-03-03