欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Python利用 SVM 算法實現(xiàn)識別手寫數(shù)字

 更新時間:2021年12月20日 10:31:04   作者:盼小輝丶  
支持向量機 (Support Vector Machine, SVM) 是一種監(jiān)督學習技術(shù),它通過根據(jù)指定的類對訓練數(shù)據(jù)進行最佳分離,從而在高維空間中構(gòu)建一個或一組超平面。本文將介紹通過SVM算法實現(xiàn)手寫數(shù)字的識別,需要的可以了解一下

前言

支持向量機 (Support Vector Machine, SVM) 是一種監(jiān)督學習技術(shù),它通過根據(jù)指定的類對訓練數(shù)據(jù)進行最佳分離,從而在高維空間中構(gòu)建一個或一組超平面。在博文《OpenCV-Python實戰(zhàn)(13)——OpenCV與機器學習的碰撞》中,我們已經(jīng)學習了如何在 OpenCV 中實現(xiàn)和訓練 SVM 算法,同時通過簡單的示例了解了如何使用 SVM 算法。在本文中,我們將學習如何使用 SVM 分類器執(zhí)行手寫數(shù)字識別,同時也將探索不同的參數(shù)對于模型性能的影響,以獲取具有最佳性能的 SVM 分類器。

使用 SVM 進行手寫數(shù)字識別

我們已經(jīng)在《利用 KNN 算法識別手寫數(shù)字》中介紹了 MNIST 手寫數(shù)字數(shù)據(jù)集,以及如何利用 KNN 算法識別手寫數(shù)字。并通過對數(shù)字圖像進行預處理( desew() 函數(shù))并使用高級描述符( HOG 描述符)作為用于描述每個數(shù)字的特征向量來獲得最佳分類準確率。因此,對于相同的內(nèi)容不再贅述,接下來將直接使用在《利用 KNN 算法識別手寫數(shù)字》中介紹預處理和 HOG 特征,利用 SVM 算法對數(shù)字圖像進行分類。

首先加載數(shù)據(jù),并將其劃分為訓練集和測試集:

# 加載數(shù)據(jù)
(train_dataset, train_labels), (test_dataset, test_labels) = keras.datasets.mnist.load_data()
SIZE_IMAGE = train_dataset.shape[1]
train_labels = np.array(train_labels, dtype=np.int32)
# 預處理函數(shù)
def deskew(img):
    m = cv2.moments(img)
    if abs(m['mu02']) < 1e-2:
        return img.copy()
    skew = m['mu11'] / m['mu02']
    M = np.float32([[1, skew, -0.5 * SIZE_IMAGE * skew], [0, 1, 0]])
    img = cv2.warpAffine(img, M, (SIZE_IMAGE, SIZE_IMAGE), flags=cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR)

    return img
# HOG 高級描述符
def get_hog():
    hog = cv2.HOGDescriptor((SIZE_IMAGE, SIZE_IMAGE), (8, 8), (4, 4), (8, 8), 9, 1, -1, 0, 0.2, 1, 64, True)

    print("hog descriptor size: {}".format(hog.getDescriptorSize()))

    return hog
# 數(shù)據(jù)打散
shuffle = np.random.permutation(len(train_dataset))
train_dataset, train_labels = train_dataset[shuffle], train_labels[shuffle]

hog = get_hog()

hog_descriptors = []
for img in train_dataset:
    hog_descriptors.append(hog.compute(deskew(img)))
hog_descriptors = np.squeeze(hog_descriptors)

results = defaultdict(list)
# 數(shù)據(jù)劃分
split_values = np.arange(0.1, 1, 0.1)

接下來,初始化 SVM,并進行訓練:

# 模型初始化函數(shù)
def svm_init(C=12.5, gamma=0.50625):
    model = cv2.ml.SVM_create()
    model.setGamma(gamma)
    model.setC(C)
    model.setKernel(cv2.ml.SVM_RBF)
    model.setType(cv2.ml.SVM_C_SVC)
    model.setTermCriteria((cv2.TERM_CRITERIA_MAX_ITER, 100, 1e-6))

    return model
# 模型訓練函數(shù)
def svm_train(model, samples, responses):
    model.train(samples, cv2.ml.ROW_SAMPLE, responses)
    return model
# 模型預測函數(shù)
def svm_predict(model, samples):
    return model.predict(samples)[1].ravel()
# 模型評估函數(shù)
def svm_evaluate(model, samples, labels):
    predictions = svm_predict(model, samples)
    acc = (labels == predictions).mean()
    print('Percentage Accuracy: %.2f %%' % (acc * 100))
    return acc *100
# 使用不同訓練集、測試集劃分方法進行訓練和測試
for split_value in split_values:
    partition = int(split_value * len(hog_descriptors))
    hog_descriptors_train, hog_descriptors_test = np.split(hog_descriptors, [partition])
    labels_train, labels_test = np.split(train_labels, [partition])

    print('Training SVM model ...')
    model = svm_init(C=12.5, gamma=0.50625)
    svm_train(model, hog_descriptors_train, labels_train)

    print('Evaluating model ... ')
    acc = svm_evaluate(model, hog_descriptors_test, labels_test)
    results['svm'].append(acc)

從上圖所示,使用默認參數(shù)的 SVM 模型在使用 70% 的數(shù)字圖像訓練算法時準確率可以達到 98.60%,接下來我們通過修改 SVM 模型的參數(shù) C 和 γ 來測試模型是否還有提升空間。

參數(shù) C 和 γ 對識別手寫數(shù)字精確度的影響

SVM 模型在使用 RBF 核時,有兩個重要參數(shù)——C 和 γ,上例中我們使用 C=12.5 和 γ=0.50625 作為參數(shù)值,C 和 γ 的設(shè)定依賴于特定的數(shù)據(jù)集。因此,必須使用某種方法進行參數(shù)搜索,本例中使用網(wǎng)格搜索合適的參數(shù) C 和 γ。

for C in [1, 10, 100, 1000]:
    for gamma in [0.1, 0.15, 0.25, 0.3, 0.35, 0.45, 0.5, 0.65]:
        model = svm_init(C, gamma)
        svm_train(model, hog_descriptors_train, labels_train)
        acc = svm_evaluate(model, hog_descriptors_test, labels_test)
        print(" {}".format("%.2f" % acc))
        results[C].append(acc)

最后,可視化結(jié)果:

fig = plt.figure(figsize=(10, 6))
plt.suptitle("SVM handwritten digits recognition", fontsize=14, fontweight='bold')
ax = plt.subplot(1, 1, 1)
ax.set_xlim(0, 0.65)
dim = [0.1, 0.15, 0.25, 0.3, 0.35, 0.45, 0.5, 0.65]

for key in results:
    ax.plot(dim, results[key], linestyle='--', marker='o', label=str(key))

plt.legend(loc='upper left', title="C")
plt.title('Accuracy of the SVM model varying both C and gamma')
plt.xlabel("gamma")
plt.ylabel("accuracy")
plt.show()

程序的運行結(jié)果如下所示:

如圖所示,通過使用不同參數(shù),準確率可以達到 99.25% 左右。通過比較 KNN 分類器和 SVM 分類器在手寫數(shù)字識別任務(wù)中的表現(xiàn),我們可以得出在手寫數(shù)字識別任務(wù)中 SVM 優(yōu)于 KNN 分類器的結(jié)論。

完整代碼

程序的完整代碼如下所示:

import cv2
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import keras

(train_dataset, train_labels), (test_dataset, test_labels) = keras.datasets.mnist.load_data()
SIZE_IMAGE = train_dataset.shape[1]
train_labels = np.array(train_labels, dtype=np.int32)

def deskew(img):
    m = cv2.moments(img)
    if abs(m['mu02']) < 1e-2:
        return img.copy()
    skew = m['mu11'] / m['mu02']
    M = np.float32([[1, skew, -0.5 * SIZE_IMAGE * skew], [0, 1, 0]])
    img = cv2.warpAffine(img, M, (SIZE_IMAGE, SIZE_IMAGE), flags=cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR)

    return img

def get_hog():
    hog = cv2.HOGDescriptor((SIZE_IMAGE, SIZE_IMAGE), (8, 8), (4, 4), (8, 8), 9, 1, -1, 0, 0.2, 1, 64, True)

    print("hog descriptor size: {}".format(hog.getDescriptorSize()))

    return hog

def svm_init(C=12.5, gamma=0.50625):
    model = cv2.ml.SVM_create()
    model.setGamma(gamma)
    model.setC(C)
    model.setKernel(cv2.ml.SVM_RBF)
    model.setType(cv2.ml.SVM_C_SVC)
    model.setTermCriteria((cv2.TERM_CRITERIA_MAX_ITER, 100, 1e-6))

    return model

def svm_train(model, samples, responses):
    model.train(samples, cv2.ml.ROW_SAMPLE, responses)
    return model

def svm_predict(model, samples):
    return model.predict(samples)[1].ravel()

def svm_evaluate(model, samples, labels):
    predictions = svm_predict(model, samples)
    acc = (labels == predictions).mean()
    return acc * 100
# 數(shù)據(jù)打散
shuffle = np.random.permutation(len(train_dataset))
train_dataset, train_labels = train_dataset[shuffle], train_labels[shuffle]
# 使用 HOG 描述符
hog = get_hog()
hog_descriptors = []
for img in train_dataset:
    hog_descriptors.append(hog.compute(deskew(img)))
hog_descriptors = np.squeeze(hog_descriptors)

# 訓練數(shù)據(jù)與測試數(shù)據(jù)劃分
partition = int(0.9 * len(hog_descriptors))
hog_descriptors_train, hog_descriptors_test = np.split(hog_descriptors, [partition])
labels_train, labels_test = np.split(train_labels, [partition])

print('Training SVM model ...')
results = defaultdict(list)

for C in [1, 10, 100, 1000]:
    for gamma in [0.1, 0.15, 0.25, 0.3, 0.35, 0.45, 0.5, 0.65]:
        model = svm_init(C, gamma)
        svm_train(model, hog_descriptors_train, labels_train)
        acc = svm_evaluate(model, hog_descriptors_test, labels_test)
        print(" {}".format("%.2f" % acc))
        results[C].append(acc)

fig = plt.figure(figsize=(10, 6))
plt.suptitle("SVM handwritten digits recognition", fontsize=14, fontweight='bold')
ax = plt.subplot(1, 1, 1)
ax.set_xlim(0, 0.65)
dim = [0.1, 0.15, 0.25, 0.3, 0.35, 0.45, 0.5, 0.65]
for key in results:
    ax.plot(dim, results[key], linestyle='--', marker='o', label=str(key))
plt.legend(loc='upper left', title="C")
plt.title('Accuracy of the SVM model varying both C and gamma')
plt.xlabel("gamma")
plt.ylabel("accuracy")
plt.show() 

以上就是Python利用 SVM 算法實現(xiàn)識別手寫數(shù)字的詳細內(nèi)容,更多關(guān)于Python SVM算法識別手寫數(shù)字的資料請關(guān)注腳本之家其它相關(guān)文章!

相關(guān)文章

  • Python類的用法實例淺析

    Python類的用法實例淺析

    這篇文章主要介紹了Python類的用法,以實例形式簡單分析了Python中類的定義、構(gòu)造函數(shù)及使用技巧,需要的朋友可以參考下
    2015-05-05
  • 帶你徹底搞懂python操作mysql數(shù)據(jù)庫(cursor游標講解)

    帶你徹底搞懂python操作mysql數(shù)據(jù)庫(cursor游標講解)

    這篇文章主要介紹了帶你徹底搞懂python操作mysql數(shù)據(jù)庫(cursor游標講解),文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧
    2020-01-01
  • Python多進程multiprocessing用法實例分析

    Python多進程multiprocessing用法實例分析

    這篇文章主要介紹了Python多進程multiprocessing用法,結(jié)合實例形式分析了Python多線程的概念以及進程的創(chuàng)建、守護進程、終止、退出進程、進程間消息傳遞等相關(guān)操作技巧,需要的朋友可以參考下
    2017-08-08
  • python文件拆分與重組實例

    python文件拆分與重組實例

    今天小編就為大家分享一篇python文件拆分與重組實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2018-12-12
  • Python使用docx模塊處理word文檔流程詳解

    Python使用docx模塊處理word文檔流程詳解

    這篇文章主要介紹了Python使用docx模塊處理word文檔流程,docx模塊是用于創(chuàng)建和更新Microsoft Word文件的Python庫,用于辦公可以顯著提升工作效率,感興趣的同學可以參考下文
    2023-05-05
  • 如何利用opencv訓練自己的模型實現(xiàn)特定物體的識別

    如何利用opencv訓練自己的模型實現(xiàn)特定物體的識別

    在Python中通過OpenCV自己訓練分類器進行特定物體實時識別,下面這篇文章主要給大家介紹了關(guān)于如何利用opencv訓練自己的模型實現(xiàn)特定物體的識別,文中通過實例代碼介紹的非常詳細,需要的朋友可以參考下
    2022-10-10
  • 談?wù)凱ython進行驗證碼識別的一些想法

    談?wù)凱ython進行驗證碼識別的一些想法

    關(guān)于python驗證碼識別,主要方法有幾類:一類是通過對圖片進行處理,然后利用字庫特征匹配的方法,一類是圖片處理后建立字符對應(yīng)字典,還有一類是直接利用ocr模塊進行識別。不管是用什么方法,都需要首先對圖片進行處理,于是試著對下面的驗證碼進行分析
    2016-01-01
  • 解決pycharm工程啟動卡住沒反應(yīng)的問題

    解決pycharm工程啟動卡住沒反應(yīng)的問題

    今天小編就為大家分享一篇解決pycharm工程啟動卡住沒反應(yīng)的問題,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2019-01-01
  • Python 打印中文字符的三種方法

    Python 打印中文字符的三種方法

    本文給大家分享三種方法實現(xiàn)python打印中文字符的方法,代碼簡單易懂,非常不錯,具有一定的參考借鑒價值,需要的朋友參考下吧
    2018-08-08
  • Python基于whois模塊簡單識別網(wǎng)站域名及所有者的方法

    Python基于whois模塊簡單識別網(wǎng)站域名及所有者的方法

    這篇文章主要介紹了Python基于whois模塊簡單識別網(wǎng)站域名及所有者的方法,簡單分析了Python whois模塊的安裝及使用相關(guān)操作技巧,需要的朋友可以參考下
    2018-04-04

最新評論