利用Python實現(xiàn)kNN算法的代碼
鄰近算法(k-NearestNeighbor) 是機器學(xué)習(xí)中的一種分類(classification)算法,也是機器學(xué)習(xí)中最簡單的算法之一了。雖然很簡單,但在解決特定問題時卻能發(fā)揮很好的效果。因此,學(xué)習(xí)kNN算法是機器學(xué)習(xí)入門的一個很好的途徑。
kNN算法的思想非常的樸素,它選取k個離測試點最近的樣本點,輸出在這k個樣本點中數(shù)量最多的標簽(label)。我們假設(shè)每一個樣本有m個特征值(property),則一個樣本的可以用一個m維向量表示: X =( x1,x2,... , xm ), 同樣地,測試點的特征值也可表示成:Y =( y1,y2,... , ym )。那我們怎么定義這兩者之間的“距離”呢?
在二維空間中,有:d2 = ( x1 - y1 )2 + ( x2 - y2 )2 , 在三維空間中,兩點的距離被定義為:d2 = ( x1 - y1 )2 + ( x2 - y2 )2 + ( x3 - y3 )2 。我們可以據(jù)此推廣到m維空間中,定義m維空間的距離:d2 = ( x1 - y1 )2 + ( x2 - y2 )2 + ...... + ( xm - ym )2 。要實現(xiàn)kNN算法,我們只需要計算出每一個樣本點與測試點的距離,選取距離最近的k個樣本,獲取他們的標簽(label) ,然后找出k個樣本中數(shù)量最多的標簽,返回該標簽。
在開始實現(xiàn)算法之前,我們要考慮一個問題,不同特征的特征值范圍可能有很大的差別,例如,我們要分辨一個人的性別,一個女生的身高是1.70m,體重是60kg,一個男生的身高是1.80m,體重是70kg,而一個未知性別的人的身高是1.81m, 體重是64kg,這個人與女生數(shù)據(jù)點的“距離”的平方 d2 = ( 1.70 - 1.81 )2 + ( 60 - 64 )2 = 0.0121 + 16.0 = 16.0121,而與男生數(shù)據(jù)點的“距離”的平方d2 = ( 1.80 - 1.81 )2 + ( 70 - 64 )2 = 0.0001 + 36.0 = 36.0001 。可見,在這種情況下,身高差的平方相對于體重差的平方基本可以忽略不計,但是身高對于辨別性別來說是十分重要的。為了解決這個問題,就需要將數(shù)據(jù)標準化(normalize),把每一個特征值除以該特征的范圍,保證標準化后每一個特征值都在0~1之間。我們寫一個normData函數(shù)來執(zhí)行標準化數(shù)據(jù)集的工作:
def normData(dataSet): maxVals = dataSet.max(axis=0) minVals = dataSet.min(axis=0) ranges = maxVals - minVals retData = (dataSet - minVals) / ranges return retData, ranges, minVals
然后開始實現(xiàn)kNN算法:
def kNN(dataSet, labels, testData, k): distSquareMat = (dataSet - testData) ** 2 # 計算差值的平方 distSquareSums = distSquareMat.sum(axis=1) # 求每一行的差值平方和 distances = distSquareSums ** 0.5 # 開根號,得出每個樣本到測試點的距離 sortedIndices = distances.argsort() # 排序,得到排序后的下標 indices = sortedIndices[:k] # 取最小的k個 labelCount = {} # 存儲每個label的出現(xiàn)次數(shù) for i in indices: label = labels[i] labelCount[label] = labelCount.get(label, 0) + 1 # 次數(shù)加一 sortedCount = sorted(labelCount.items(), key=opt.itemgetter(1), reverse=True) # 對label出現(xiàn)的次數(shù)從大到小進行排序 return sortedCount[0][0] # 返回出現(xiàn)次數(shù)最大的label
注意,在testData作為參數(shù)傳入kNN函數(shù)之前,需要經(jīng)過標準化。
我們用幾個小數(shù)據(jù)驗證一下kNN函數(shù)是否能正常工作:
if __name__ == "__main__": dataSet = np.array([[2, 3], [6, 8]]) normDataSet, ranges, minVals = normData(dataSet) labels = ['a', 'b'] testData = np.array([3.9, 5.5]) normTestData = (testData - minVals) / ranges result = kNN(normDataSet, labels, normTestData, 1) print(result)
結(jié)果輸出 a ,與預(yù)期結(jié)果一致。
完整代碼:
import numpy as np from math import sqrt import operator as opt def normData(dataSet): maxVals = dataSet.max(axis=0) minVals = dataSet.min(axis=0) ranges = maxVals - minVals retData = (dataSet - minVals) / ranges return retData, ranges, minVals def kNN(dataSet, labels, testData, k): distSquareMat = (dataSet - testData) ** 2 # 計算差值的平方 distSquareSums = distSquareMat.sum(axis=1) # 求每一行的差值平方和 distances = distSquareSums ** 0.5 # 開根號,得出每個樣本到測試點的距離 sortedIndices = distances.argsort() # 排序,得到排序后的下標 indices = sortedIndices[:k] # 取最小的k個 labelCount = {} # 存儲每個label的出現(xiàn)次數(shù) for i in indices: label = labels[i] labelCount[label] = labelCount.get(label, 0) + 1 # 次數(shù)加一 sortedCount = sorted(labelCount.items(), key=opt.itemgetter(1), reverse=True) # 對label出現(xiàn)的次數(shù)從大到小進行排序 return sortedCount[0][0] # 返回出現(xiàn)次數(shù)最大的label if __name__ == "__main__": dataSet = np.array([[2, 3], [6, 8]]) normDataSet, ranges, minVals = normData(dataSet) labels = ['a', 'b'] testData = np.array([3.9, 5.5]) normTestData = (testData - minVals) / ranges result = kNN(normDataSet, labels, normTestData, 1) print(result)
以上就是本文的全部內(nèi)容,希望對大家的學(xué)習(xí)有所幫助,也希望大家多多支持腳本之家。
相關(guān)文章
numpy如何按條件給元素賦值np.where、np.clip
這篇文章主要介紹了numpy如何按條件給元素賦值np.where、np.clip問題,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2023-06-06python語言中print中加號、減號、乘號的應(yīng)用方式
這篇文章主要介紹了python語言中print中加號、減號、乘號的應(yīng)用方式,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教2024-02-02簡單掌握Python的Collections模塊中counter結(jié)構(gòu)的用法
counter數(shù)據(jù)結(jié)構(gòu)被用來提供技術(shù)功能,形式類似于Python中內(nèi)置的字典結(jié)構(gòu),這里通過幾個小例子來簡單掌握Python的Collections模塊中counter結(jié)構(gòu)的用法:2016-07-07python異步編程之a(chǎn)syncio高階API的使用詳解
asyncio中函數(shù)可以分為高階函數(shù)和低階函數(shù),通常開發(fā)中使用更多的是高階函數(shù),本文主要為大家介紹了asyncio中常用的高階函數(shù),需要的可以參考下2024-01-01