Python利用 SVM 算法實(shí)現(xiàn)識(shí)別手寫數(shù)字
前言
支持向量機(jī) (Support Vector Machine, SVM) 是一種監(jiān)督學(xué)習(xí)技術(shù),它通過(guò)根據(jù)指定的類對(duì)訓(xùn)練數(shù)據(jù)進(jìn)行最佳分離,從而在高維空間中構(gòu)建一個(gè)或一組超平面。在博文《OpenCV-Python實(shí)戰(zhàn)(13)——OpenCV與機(jī)器學(xué)習(xí)的碰撞》中,我們已經(jīng)學(xué)習(xí)了如何在 OpenCV 中實(shí)現(xiàn)和訓(xùn)練 SVM 算法,同時(shí)通過(guò)簡(jiǎn)單的示例了解了如何使用 SVM 算法。在本文中,我們將學(xué)習(xí)如何使用 SVM 分類器執(zhí)行手寫數(shù)字識(shí)別,同時(shí)也將探索不同的參數(shù)對(duì)于模型性能的影響,以獲取具有最佳性能的 SVM 分類器。
使用 SVM 進(jìn)行手寫數(shù)字識(shí)別
我們已經(jīng)在《利用 KNN 算法識(shí)別手寫數(shù)字》中介紹了 MNIST 手寫數(shù)字?jǐn)?shù)據(jù)集,以及如何利用 KNN 算法識(shí)別手寫數(shù)字。并通過(guò)對(duì)數(shù)字圖像進(jìn)行預(yù)處理( desew() 函數(shù))并使用高級(jí)描述符( HOG 描述符)作為用于描述每個(gè)數(shù)字的特征向量來(lái)獲得最佳分類準(zhǔn)確率。因此,對(duì)于相同的內(nèi)容不再贅述,接下來(lái)將直接使用在《利用 KNN 算法識(shí)別手寫數(shù)字》中介紹預(yù)處理和 HOG 特征,利用 SVM 算法對(duì)數(shù)字圖像進(jìn)行分類。
首先加載數(shù)據(jù),并將其劃分為訓(xùn)練集和測(cè)試集:
# 加載數(shù)據(jù) (train_dataset, train_labels), (test_dataset, test_labels) = keras.datasets.mnist.load_data() SIZE_IMAGE = train_dataset.shape[1] train_labels = np.array(train_labels, dtype=np.int32) # 預(yù)處理函數(shù) def deskew(img): m = cv2.moments(img) if abs(m['mu02']) < 1e-2: return img.copy() skew = m['mu11'] / m['mu02'] M = np.float32([[1, skew, -0.5 * SIZE_IMAGE * skew], [0, 1, 0]]) img = cv2.warpAffine(img, M, (SIZE_IMAGE, SIZE_IMAGE), flags=cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR) return img # HOG 高級(jí)描述符 def get_hog(): hog = cv2.HOGDescriptor((SIZE_IMAGE, SIZE_IMAGE), (8, 8), (4, 4), (8, 8), 9, 1, -1, 0, 0.2, 1, 64, True) print("hog descriptor size: {}".format(hog.getDescriptorSize())) return hog # 數(shù)據(jù)打散 shuffle = np.random.permutation(len(train_dataset)) train_dataset, train_labels = train_dataset[shuffle], train_labels[shuffle] hog = get_hog() hog_descriptors = [] for img in train_dataset: hog_descriptors.append(hog.compute(deskew(img))) hog_descriptors = np.squeeze(hog_descriptors) results = defaultdict(list) # 數(shù)據(jù)劃分 split_values = np.arange(0.1, 1, 0.1)
接下來(lái),初始化 SVM,并進(jìn)行訓(xùn)練:
# 模型初始化函數(shù) def svm_init(C=12.5, gamma=0.50625): model = cv2.ml.SVM_create() model.setGamma(gamma) model.setC(C) model.setKernel(cv2.ml.SVM_RBF) model.setType(cv2.ml.SVM_C_SVC) model.setTermCriteria((cv2.TERM_CRITERIA_MAX_ITER, 100, 1e-6)) return model # 模型訓(xùn)練函數(shù) def svm_train(model, samples, responses): model.train(samples, cv2.ml.ROW_SAMPLE, responses) return model # 模型預(yù)測(cè)函數(shù) def svm_predict(model, samples): return model.predict(samples)[1].ravel() # 模型評(píng)估函數(shù) def svm_evaluate(model, samples, labels): predictions = svm_predict(model, samples) acc = (labels == predictions).mean() print('Percentage Accuracy: %.2f %%' % (acc * 100)) return acc *100 # 使用不同訓(xùn)練集、測(cè)試集劃分方法進(jìn)行訓(xùn)練和測(cè)試 for split_value in split_values: partition = int(split_value * len(hog_descriptors)) hog_descriptors_train, hog_descriptors_test = np.split(hog_descriptors, [partition]) labels_train, labels_test = np.split(train_labels, [partition]) print('Training SVM model ...') model = svm_init(C=12.5, gamma=0.50625) svm_train(model, hog_descriptors_train, labels_train) print('Evaluating model ... ') acc = svm_evaluate(model, hog_descriptors_test, labels_test) results['svm'].append(acc)
從上圖所示,使用默認(rèn)參數(shù)的 SVM 模型在使用 70% 的數(shù)字圖像訓(xùn)練算法時(shí)準(zhǔn)確率可以達(dá)到 98.60%,接下來(lái)我們通過(guò)修改 SVM 模型的參數(shù) C 和 γ 來(lái)測(cè)試模型是否還有提升空間。
參數(shù) C 和 γ 對(duì)識(shí)別手寫數(shù)字精確度的影響
SVM 模型在使用 RBF 核時(shí),有兩個(gè)重要參數(shù)——C 和 γ,上例中我們使用 C=12.5 和 γ=0.50625 作為參數(shù)值,C 和 γ 的設(shè)定依賴于特定的數(shù)據(jù)集。因此,必須使用某種方法進(jìn)行參數(shù)搜索,本例中使用網(wǎng)格搜索合適的參數(shù) C 和 γ。
for C in [1, 10, 100, 1000]: for gamma in [0.1, 0.15, 0.25, 0.3, 0.35, 0.45, 0.5, 0.65]: model = svm_init(C, gamma) svm_train(model, hog_descriptors_train, labels_train) acc = svm_evaluate(model, hog_descriptors_test, labels_test) print(" {}".format("%.2f" % acc)) results[C].append(acc)
最后,可視化結(jié)果:
fig = plt.figure(figsize=(10, 6)) plt.suptitle("SVM handwritten digits recognition", fontsize=14, fontweight='bold') ax = plt.subplot(1, 1, 1) ax.set_xlim(0, 0.65) dim = [0.1, 0.15, 0.25, 0.3, 0.35, 0.45, 0.5, 0.65] for key in results: ax.plot(dim, results[key], linestyle='--', marker='o', label=str(key)) plt.legend(loc='upper left', title="C") plt.title('Accuracy of the SVM model varying both C and gamma') plt.xlabel("gamma") plt.ylabel("accuracy") plt.show()
程序的運(yùn)行結(jié)果如下所示:
如圖所示,通過(guò)使用不同參數(shù),準(zhǔn)確率可以達(dá)到 99.25% 左右。通過(guò)比較 KNN 分類器和 SVM 分類器在手寫數(shù)字識(shí)別任務(wù)中的表現(xiàn),我們可以得出在手寫數(shù)字識(shí)別任務(wù)中 SVM 優(yōu)于 KNN 分類器的結(jié)論。
完整代碼
程序的完整代碼如下所示:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import keras
(train_dataset, train_labels), (test_dataset, test_labels) = keras.datasets.mnist.load_data()
SIZE_IMAGE = train_dataset.shape[1]
train_labels = np.array(train_labels, dtype=np.int32)
def deskew(img):
m = cv2.moments(img)
if abs(m['mu02']) < 1e-2:
return img.copy()
skew = m['mu11'] / m['mu02']
M = np.float32([[1, skew, -0.5 * SIZE_IMAGE * skew], [0, 1, 0]])
img = cv2.warpAffine(img, M, (SIZE_IMAGE, SIZE_IMAGE), flags=cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR)
return img
def get_hog():
hog = cv2.HOGDescriptor((SIZE_IMAGE, SIZE_IMAGE), (8, 8), (4, 4), (8, 8), 9, 1, -1, 0, 0.2, 1, 64, True)
print("hog descriptor size: {}".format(hog.getDescriptorSize()))
return hog
def svm_init(C=12.5, gamma=0.50625):
model = cv2.ml.SVM_create()
model.setGamma(gamma)
model.setC(C)
model.setKernel(cv2.ml.SVM_RBF)
model.setType(cv2.ml.SVM_C_SVC)
model.setTermCriteria((cv2.TERM_CRITERIA_MAX_ITER, 100, 1e-6))
return model
def svm_train(model, samples, responses):
model.train(samples, cv2.ml.ROW_SAMPLE, responses)
return model
def svm_predict(model, samples):
return model.predict(samples)[1].ravel()
def svm_evaluate(model, samples, labels):
predictions = svm_predict(model, samples)
acc = (labels == predictions).mean()
return acc * 100
# 數(shù)據(jù)打散
shuffle = np.random.permutation(len(train_dataset))
train_dataset, train_labels = train_dataset[shuffle], train_labels[shuffle]
# 使用 HOG 描述符
hog = get_hog()
hog_descriptors = []
for img in train_dataset:
hog_descriptors.append(hog.compute(deskew(img)))
hog_descriptors = np.squeeze(hog_descriptors)
# 訓(xùn)練數(shù)據(jù)與測(cè)試數(shù)據(jù)劃分
partition = int(0.9 * len(hog_descriptors))
hog_descriptors_train, hog_descriptors_test = np.split(hog_descriptors, [partition])
labels_train, labels_test = np.split(train_labels, [partition])
print('Training SVM model ...')
results = defaultdict(list)
for C in [1, 10, 100, 1000]:
for gamma in [0.1, 0.15, 0.25, 0.3, 0.35, 0.45, 0.5, 0.65]:
model = svm_init(C, gamma)
svm_train(model, hog_descriptors_train, labels_train)
acc = svm_evaluate(model, hog_descriptors_test, labels_test)
print(" {}".format("%.2f" % acc))
results[C].append(acc)
fig = plt.figure(figsize=(10, 6))
plt.suptitle("SVM handwritten digits recognition", fontsize=14, fontweight='bold')
ax = plt.subplot(1, 1, 1)
ax.set_xlim(0, 0.65)
dim = [0.1, 0.15, 0.25, 0.3, 0.35, 0.45, 0.5, 0.65]
for key in results:
ax.plot(dim, results[key], linestyle='--', marker='o', label=str(key))
plt.legend(loc='upper left', title="C")
plt.title('Accuracy of the SVM model varying both C and gamma')
plt.xlabel("gamma")
plt.ylabel("accuracy")
plt.show()
以上就是Python利用 SVM 算法實(shí)現(xiàn)識(shí)別手寫數(shù)字的詳細(xì)內(nèi)容,更多關(guān)于Python SVM算法識(shí)別手寫數(shù)字的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
帶你徹底搞懂python操作mysql數(shù)據(jù)庫(kù)(cursor游標(biāo)講解)
這篇文章主要介紹了帶你徹底搞懂python操作mysql數(shù)據(jù)庫(kù)(cursor游標(biāo)講解),文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-01-01Python多進(jìn)程multiprocessing用法實(shí)例分析
這篇文章主要介紹了Python多進(jìn)程multiprocessing用法,結(jié)合實(shí)例形式分析了Python多線程的概念以及進(jìn)程的創(chuàng)建、守護(hù)進(jìn)程、終止、退出進(jìn)程、進(jìn)程間消息傳遞等相關(guān)操作技巧,需要的朋友可以參考下2017-08-08如何利用opencv訓(xùn)練自己的模型實(shí)現(xiàn)特定物體的識(shí)別
在Python中通過(guò)OpenCV自己訓(xùn)練分類器進(jìn)行特定物體實(shí)時(shí)識(shí)別,下面這篇文章主要給大家介紹了關(guān)于如何利用opencv訓(xùn)練自己的模型實(shí)現(xiàn)特定物體的識(shí)別,文中通過(guò)實(shí)例代碼介紹的非常詳細(xì),需要的朋友可以參考下2022-10-10談?wù)凱ython進(jìn)行驗(yàn)證碼識(shí)別的一些想法
關(guān)于python驗(yàn)證碼識(shí)別,主要方法有幾類:一類是通過(guò)對(duì)圖片進(jìn)行處理,然后利用字庫(kù)特征匹配的方法,一類是圖片處理后建立字符對(duì)應(yīng)字典,還有一類是直接利用ocr模塊進(jìn)行識(shí)別。不管是用什么方法,都需要首先對(duì)圖片進(jìn)行處理,于是試著對(duì)下面的驗(yàn)證碼進(jìn)行分析2016-01-01解決pycharm工程啟動(dòng)卡住沒(méi)反應(yīng)的問(wèn)題
今天小編就為大家分享一篇解決pycharm工程啟動(dòng)卡住沒(méi)反應(yīng)的問(wèn)題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-01-01Python基于whois模塊簡(jiǎn)單識(shí)別網(wǎng)站域名及所有者的方法
這篇文章主要介紹了Python基于whois模塊簡(jiǎn)單識(shí)別網(wǎng)站域名及所有者的方法,簡(jiǎn)單分析了Python whois模塊的安裝及使用相關(guān)操作技巧,需要的朋友可以參考下2018-04-04