Python語(yǔ)言描述KNN算法與Kd樹
最近鄰法和k-近鄰法
下面圖片中只有三種豆,有三個(gè)豆是未知的種類,如何判定他們的種類?
提供一種思路,即:未知的豆離哪種豆最近就認(rèn)為未知豆和該豆是同一種類。由此,我們引出最近鄰算法的定義:為了判定未知樣本的類別,以全部訓(xùn)練樣本作為代表點(diǎn),計(jì)算未知樣本與所有訓(xùn)練樣本的距離,并以最近鄰者的類別作為決策未知樣本類別的唯一依據(jù)。但是,最近鄰算法明顯是存在缺陷的,比如下面的例子:有一個(gè)未知形狀(圖中綠色的圓點(diǎn)),如何判斷它是什么形狀?
顯然,最近鄰算法的缺陷——對(duì)噪聲數(shù)據(jù)過(guò)于敏感,為了解決這個(gè)問(wèn)題,我們可以可以把未知樣本周邊的多個(gè)最近樣本計(jì)算在內(nèi),擴(kuò)大參與決策的樣本量,以避免個(gè)別數(shù)據(jù)直接決定決策結(jié)果。由此,我們引進(jìn)K-最近鄰算法。K-最近鄰算法是最近鄰算法的一個(gè)延伸?;舅悸肥牵哼x擇未知樣本一定范圍內(nèi)確定個(gè)數(shù)的K個(gè)樣本,該K個(gè)樣本大多數(shù)屬于某一類型,則未知樣本判定為該類型。如何選擇一個(gè)最佳的K值取決于數(shù)據(jù)。一般情況下,在分類時(shí)較大的K值能夠減小噪聲的影響,但會(huì)使類別之間的界限變得模糊。待測(cè)樣本(綠色圓圈)既可能分到紅色三角形類,也可能分到藍(lán)色正方形類。如果k取3,從圖可見,待測(cè)樣本的3個(gè)鄰居在實(shí)線的內(nèi)圓里,按多數(shù)投票結(jié)果,它屬于紅色三角形類。但是如果k取5,那么待測(cè)樣本的最鄰近的5個(gè)樣本在虛線的圓里,按表決法,它又屬于藍(lán)色正方形類。在實(shí)際應(yīng)用中,K先取一個(gè)比較小的數(shù)值,再采用交叉驗(yàn)證法來(lái)逐步調(diào)整K值,最終選擇適合該樣本的最優(yōu)的K值。
KNN算法實(shí)現(xiàn)
算法基本步驟:
1)計(jì)算待分類點(diǎn)與已知類別的點(diǎn)之間的距離
2)按照距離遞增次序排序
3)選取與待分類點(diǎn)距離最小的k個(gè)點(diǎn)
4)確定前k個(gè)點(diǎn)所在類別的出現(xiàn)次數(shù)
5)返回前k個(gè)點(diǎn)出現(xiàn)次數(shù)最高的類別作為待分類點(diǎn)的預(yù)測(cè)分類
下面是一個(gè)按照算法基本步驟用python實(shí)現(xiàn)的簡(jiǎn)單例子,根據(jù)已分類的4個(gè)樣本點(diǎn)來(lái)預(yù)測(cè)未知點(diǎn)(圖中的灰點(diǎn))的分類:
from numpy import * # create a dataset which contains 4 samples with 2 classes def createDataSet(): # create a matrix: each row as a sample group = array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]]) labels = ['A', 'A', 'B', 'B'] # four samples and two classes return group, labels # classify using kNN (k Nearest Neighbors ) # Input: newInput: 1 x N # dataSet: M x N (M samples N, features) # labels: 1 x M # k: number of neighbors to use for comparison # Output: the most popular class label def kNNClassify(newInput, dataSet, labels, k): numSamples = dataSet.shape[0] # shape[0] stands for the num of row ## step 1: calculate Euclidean distance # tile(A, reps): Construct an array by repeating A reps times # the following copy numSamples rows for dataSet diff = tile(newInput, (numSamples, 1)) - dataSet # Subtract element-wise squaredDiff = diff ** 2 # squared for the subtract squaredDist = sum(squaredDiff, axis = 1) # sum is performed by row distance = squaredDist ** 0.5 ## step 2: sort the distance # argsort() returns the indices that would sort an array in a ascending order sortedDistIndices = argsort(distance) classCount = {} # define a dictionary (can be append element) for i in xrange(k): ## step 3: choose the min k distance voteLabel = labels[sortedDistIndices[i]] ## step 4: count the times labels occur # when the key voteLabel is not in dictionary classCount, get() # will return 0 classCount[voteLabel] = classCount.get(voteLabel, 0) + 1 ## step 5: the max voted class will return maxCount = 0 for key, value in classCount.items(): if value > maxCount: maxCount = value maxIndex = key return maxIndex if __name__== "__main__": dataSet, labels = createDataSet() testX = array([1.2, 1.0]) k = 3 outputLabel = kNNClassify(testX, dataSet, labels, 3) print "Your input is:", testX, "and classified to class: ", outputLabel testX = array([0.1, 0.3]) outputLabel = kNNClassify(testX, dataSet, labels, 3) print "Your input is:", testX, "and classified to class: ", outputLabel
結(jié)果如下:
Your input is: [ 1.2 1. ] and classified to class: A
Your input is: [ 0.1 0.3] and classified to class: B
OpenCV中也提供了機(jī)器學(xué)習(xí)的相關(guān)算法,其中KNN算法的最基本例子如下
import numpy as np import matplotlib.pyplot as plt import cv2 # Feature set containing (x,y) values of 25 known/training data trainData = np.random.randint(0,100,(25,2)).astype(np.float32) # Labels each one either Red or Blue with numbers 0 and 1 responses = np.random.randint(0,2,(25,1)).astype(np.float32) # Take Red families and plot them red = trainData[responses.ravel()==0] plt.scatter(red[:,0],red[:,1],80,'r','^') # Take Blue families and plot them blue = trainData[responses.ravel()==1] plt.scatter(blue[:,0],blue[:,1],80,'b','s') # Testing data newcomer = np.random.randint(0,100,(1,2)).astype(np.float32) plt.scatter(newcomer[:,0],newcomer[:,1],80,'g','o') knn = cv2.KNearest() knn.train(trainData,responses) # Trains the model # Finds the neighbors and predicts responses for input vectors. ret, results, neighbours ,dist = knn.find_nearest(newcomer, 3) print "result: ", results,"\n" print "neighbours: ", neighbours,"\n" print "distance: ", dist plt.show()
>>> result: [[ 0.]] neighbours: [[ 0. 0. 0.]] distance: [[ 65. 145. 178.]]
可以看到KNN算法將未知點(diǎn)分到第0組(紅色三角形組),從上圖中也可看出3個(gè)距離未知點(diǎn)最近的樣本都屬于第0組,因此算法返回分類標(biāo)簽也為0。
KNN算法的缺陷
觀察下面的例子,我們看到對(duì)于樣本X,通過(guò)KNN算法,我們顯然可以得到X應(yīng)屬于紅點(diǎn),但對(duì)于樣本Y,通過(guò)KNN算法我們似乎得到了Y應(yīng)屬于藍(lán)點(diǎn)的結(jié)論,而這個(gè)結(jié)論直觀來(lái)看并沒有說(shuō)服力。
由上面的例子可見:該算法在分類時(shí)有個(gè)重要的不足是,當(dāng)樣本不平衡時(shí),即:一個(gè)類的樣本容量很大,而其他類樣本數(shù)量很小時(shí),很有可能導(dǎo)致當(dāng)輸入一個(gè)未知樣本時(shí),該樣本的K個(gè)鄰居中大數(shù)量類的樣本占多數(shù)。 但是這類樣本并不接近目標(biāo)樣本,而數(shù)量小的這類樣本很靠近目標(biāo)樣本。這個(gè)時(shí)候,我們有理由認(rèn)為該位置樣本屬于數(shù)量小的樣本所屬的一類,但是,KNN卻不關(guān)心這個(gè)問(wèn)題,它只關(guān)心哪類樣本的數(shù)量最多,而不去把距離遠(yuǎn)近考慮在內(nèi),因此,我們可以采用權(quán)值的方法來(lái)改進(jìn)。和該樣本距離小的鄰居權(quán)值大,和該樣本距離大的鄰居權(quán)值則相對(duì)較小,由此,將距離遠(yuǎn)近的因素也考慮在內(nèi),避免因一個(gè)樣本過(guò)大導(dǎo)致誤判的情況。
從算法實(shí)現(xiàn)的過(guò)程可以發(fā)現(xiàn),該算法存兩個(gè)嚴(yán)重的問(wèn)題,第一個(gè)是需要存儲(chǔ)全部的訓(xùn)練樣本,第二個(gè)是計(jì)算量較大,因?yàn)閷?duì)每一個(gè)待分類的樣本都要計(jì)算它到全體已知樣本的距離,才能求得它的K個(gè)最近鄰點(diǎn)。KNN算法的改進(jìn)方法之一是分組快速搜索近鄰法。其基本思想是:將樣本集按近鄰關(guān)系分解成組,給出每組質(zhì)心的位置,以質(zhì)心作為代表點(diǎn),和未知樣本計(jì)算距離,選出距離最近的一個(gè)或若干個(gè)組,再在組的范圍內(nèi)應(yīng)用一般的KNN算法。由于并不是將未知樣本與所有樣本計(jì)算距離,故該改進(jìn)算法可以減少計(jì)算量,但并不能減少存儲(chǔ)量。
KD樹
實(shí)現(xiàn)k近鄰法時(shí),主要考慮的問(wèn)題是如何對(duì)訓(xùn)練數(shù)據(jù)進(jìn)行快速k近鄰搜索。這在特征空間的維數(shù)大及訓(xùn)練數(shù)據(jù)容量大時(shí)尤其必要。k近鄰法最簡(jiǎn)單的實(shí)現(xiàn)是線性掃描(窮舉搜索),即要計(jì)算輸入實(shí)例與每一個(gè)訓(xùn)練實(shí)例的距離。計(jì)算并存儲(chǔ)好以后,再查找K近鄰。當(dāng)訓(xùn)練集很大時(shí),計(jì)算非常耗時(shí)。為了提高kNN搜索的效率,可以考慮使用特殊的結(jié)構(gòu)存儲(chǔ)訓(xùn)練數(shù)據(jù),以減小計(jì)算距離的次數(shù)。
kd樹(K-dimension tree)是一種對(duì)k維空間中的實(shí)例點(diǎn)進(jìn)行存儲(chǔ)以便對(duì)其進(jìn)行快速檢索的樹形數(shù)據(jù)結(jié)構(gòu)。kd樹是是一種二叉樹,表示對(duì)k維空間的一個(gè)劃分,構(gòu)造kd樹相當(dāng)于不斷地用垂直于坐標(biāo)軸的超平面將K維空間切分,構(gòu)成一系列的K維超矩形區(qū)域。kd樹的每個(gè)結(jié)點(diǎn)對(duì)應(yīng)于一個(gè)k維超矩形區(qū)域。利用kd樹可以省去對(duì)大部分?jǐn)?shù)據(jù)點(diǎn)的搜索,從而減少搜索的計(jì)算量。
對(duì)一個(gè)三維空間,kd樹按照一定的劃分規(guī)則把這個(gè)三維空間劃分了多個(gè)空間,如下圖所示
類比“二分查找”:給出一組數(shù)據(jù):[9 1 4 7 2 5 0 3 8],要查找8。如果挨個(gè)查找(線性掃描),那么將會(huì)把數(shù)據(jù)集都遍歷一遍。而如果排一下序那數(shù)據(jù)集就變成了:[0 1 2 3 4 5 6 7 8 9],按前一種方式我們進(jìn)行了很多沒有必要的查找,現(xiàn)在如果我們以5為分界點(diǎn),那么數(shù)據(jù)集就被劃分為了左右兩個(gè)“簇” [0 1 2 3 4]和[6 7 8 9]。因此,根本久沒有必要進(jìn)入第一個(gè)簇,可以直接進(jìn)入第二個(gè)簇進(jìn)行查找。把二分查找中的數(shù)據(jù)點(diǎn)換成k維數(shù)據(jù)點(diǎn),這樣的劃分就變成了用超平面對(duì)k維空間的劃分。空間劃分就是對(duì)數(shù)據(jù)點(diǎn)進(jìn)行分類,“挨得近”的數(shù)據(jù)點(diǎn)就在一個(gè)空間里面。
構(gòu)造kd樹的方法如下:構(gòu)造根結(jié)點(diǎn),使根結(jié)點(diǎn)對(duì)應(yīng)于K維空間中包含所有實(shí)例點(diǎn)的超矩形區(qū)域;通過(guò)下面的遞歸的方法,不斷地對(duì)k維空間進(jìn)行切分,生成子結(jié)點(diǎn)。在超矩形區(qū)域上選擇一個(gè)坐標(biāo)軸和在此坐標(biāo)軸上的一個(gè)切分點(diǎn),確定一個(gè)超平面,這個(gè)超平面通過(guò)選定的切分點(diǎn)并垂直于選定的坐標(biāo)軸,將當(dāng)前超矩形區(qū)域切分為左右兩個(gè)子區(qū)域(子結(jié)點(diǎn));這時(shí),實(shí)例被分到兩個(gè)子區(qū)域,這個(gè)過(guò)程直到子區(qū)域內(nèi)沒有實(shí)例時(shí)終止(終止時(shí)的結(jié)點(diǎn)為葉結(jié)點(diǎn))。在此過(guò)程中,將實(shí)例保存在相應(yīng)的結(jié)點(diǎn)上。通常,循環(huán)的擇坐標(biāo)軸對(duì)空間切分,選擇訓(xùn)練實(shí)例點(diǎn)在坐標(biāo)軸上的中位數(shù)為切分點(diǎn),這樣得到的kd樹是平衡的(平衡二叉樹:它是一棵空樹,或其左子樹和右子樹的深度之差的絕對(duì)值不超過(guò)1,且它的左子樹和右子樹都是平衡二叉樹)?!?/p>
KD樹中每個(gè)節(jié)點(diǎn)是一個(gè)向量,和二叉樹按照數(shù)的大小劃分不同的是,KD樹每層需要選定向量中的某一維,然后根據(jù)這一維按左小右大的方式劃分?jǐn)?shù)據(jù)。在構(gòu)建KD樹時(shí),關(guān)鍵需要解決2個(gè)問(wèn)題:(1)選擇向量的哪一維進(jìn)行劃分;(2)如何劃分?jǐn)?shù)據(jù)。第一個(gè)問(wèn)題簡(jiǎn)單的解決方法可以是選擇隨機(jī)選擇某一維或按順序選擇,但是更好的方法應(yīng)該是在數(shù)據(jù)比較分散的那一維進(jìn)行劃分(分散的程度可以根據(jù)方差來(lái)衡量)。好的劃分方法可以使構(gòu)建的樹比較平衡,可以每次選擇中位數(shù)來(lái)進(jìn)行劃分,這樣問(wèn)題2也得到了解決。
構(gòu)造平衡kd樹算法:
輸入:kk維空間數(shù)據(jù)集T={x1,x2,...,xN},其中
輸出:kd樹
(1)開始:構(gòu)造根結(jié)點(diǎn),根結(jié)點(diǎn)對(duì)應(yīng)于包含T的k維空間的超矩形區(qū)域。選擇x(1)x(1)為坐標(biāo)軸,以T中所有實(shí)例的x(1)x(1)坐標(biāo)的中位數(shù)為切分點(diǎn),將根結(jié)點(diǎn)對(duì)應(yīng)的超矩形區(qū)域切分為兩個(gè)子區(qū)域。切分由通過(guò)切分點(diǎn)并與坐標(biāo)軸x(1)x(1)垂直的超平面實(shí)現(xiàn)。由根結(jié)點(diǎn)生成深度為1的左、右子結(jié)點(diǎn):左子結(jié)點(diǎn)對(duì)應(yīng)坐標(biāo)x(1)x(1)小于切分點(diǎn)的子區(qū)域,右子結(jié)點(diǎn)對(duì)應(yīng)于坐標(biāo)x(1)x(1)大于切分點(diǎn)的子區(qū)域。將落在切分超平面上的實(shí)例點(diǎn)保存在根結(jié)點(diǎn)。
(2)重復(fù)。對(duì)深度為j的結(jié)點(diǎn),選擇x(l)x(l)為切分的坐標(biāo)軸,l=j%k+1l=j%k+1,以該結(jié)點(diǎn)的區(qū)域中所有實(shí)例的x(l)x(l)坐標(biāo)的中位數(shù)為切分點(diǎn),將該結(jié)點(diǎn)對(duì)應(yīng)的超矩形區(qū)域切分為兩個(gè)子區(qū)域。切分由通過(guò)切分點(diǎn)并與坐標(biāo)軸x(l)x(l)垂直的超平面實(shí)現(xiàn)。由該結(jié)點(diǎn)生成深度為j+1的左、右子結(jié)點(diǎn):左子結(jié)點(diǎn)對(duì)應(yīng)坐標(biāo)x(l)x(l)小于切分點(diǎn)的子區(qū)域,右子結(jié)點(diǎn)對(duì)應(yīng)坐標(biāo)x(l)x(l)大于切分點(diǎn)的子區(qū)域。將落在切分超平面上的實(shí)例點(diǎn)保存在該結(jié)點(diǎn)。
下面用一個(gè)簡(jiǎn)單的2維平面上的例子來(lái)進(jìn)行說(shuō)明。
例. 給定一個(gè)二維空間數(shù)據(jù)集:T={(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)}T={(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)},構(gòu)造一個(gè)平衡kd樹。
解:根結(jié)點(diǎn)對(duì)應(yīng)包含數(shù)據(jù)集T的矩形,選擇x(1)x(1)軸,6個(gè)數(shù)據(jù)點(diǎn)的x(1)x(1)坐標(biāo)中位數(shù)是6,這里選最接近的(7,2)點(diǎn),以平面x(1)=7x(1)=7將空間分為左、右兩個(gè)子矩形(子結(jié)點(diǎn));接著左矩形以x(2)=4x(2)=4分為兩個(gè)子矩形(左矩形中{(2,3),(5,4),(4,7)}點(diǎn)的x(2)x(2)坐標(biāo)中位數(shù)正好為4),右矩形以x(2)=6x(2)=6分為兩個(gè)子矩形,如此遞歸,最后得到如下圖所示的特征空間劃分和kd樹。
下面的代碼用遞歸的方式構(gòu)建了kd樹,通過(guò)前序遍歷可以進(jìn)行驗(yàn)證。這里只是簡(jiǎn)單地采用坐標(biāo)輪換方式選取分割軸,為了更高效的分割空間,也可以計(jì)算所有數(shù)據(jù)點(diǎn)在每個(gè)維度上的數(shù)值的方差,然后選擇方差最大的維度作為當(dāng)前節(jié)點(diǎn)的劃分維度。方差越大,說(shuō)明這個(gè)維度上的數(shù)據(jù)越不集中(稀疏、分散),也就說(shuō)明了它們就越不可能屬于同一個(gè)空間,因此需要在這個(gè)維度上進(jìn)行劃分。
# -*- coding: utf-8 -*- #from operator import itemgetter import sys reload(sys) sys.setdefaultencoding('utf8') # kd-tree每個(gè)結(jié)點(diǎn)中主要包含的數(shù)據(jù)結(jié)構(gòu)如下 class KdNode(object): def __init__(self, dom_elt, split, left, right): self.dom_elt = dom_elt # k維向量節(jié)點(diǎn)(k維空間中的一個(gè)樣本點(diǎn)) self.split = split # 整數(shù)(進(jìn)行分割維度的序號(hào)) self.left = left # 該結(jié)點(diǎn)分割超平面左子空間構(gòu)成的kd-tree self.right = right # 該結(jié)點(diǎn)分割超平面右子空間構(gòu)成的kd-tree class KdTree(object): def __init__(self, data): k = len(data[0]) # 數(shù)據(jù)維度 def CreateNode(split, data_set): # 按第split維劃分?jǐn)?shù)據(jù)集exset創(chuàng)建KdNode if not data_set: # 數(shù)據(jù)集為空 return None # key參數(shù)的值為一個(gè)函數(shù),此函數(shù)只有一個(gè)參數(shù)且返回一個(gè)值用來(lái)進(jìn)行比較 # operator模塊提供的itemgetter函數(shù)用于獲取對(duì)象的哪些維的數(shù)據(jù),參數(shù)為需要獲取的數(shù)據(jù)在對(duì)象中的序號(hào) #data_set.sort(key=itemgetter(split)) # 按要進(jìn)行分割的那一維數(shù)據(jù)排序 data_set.sort(key=lambda x: x[split]) split_pos = len(data_set) // 2 # //為Python中的整數(shù)除法 median = data_set[split_pos] # 中位數(shù)分割點(diǎn) split_next = (split + 1) % k # cycle coordinates # 遞歸的創(chuàng)建kd樹 return KdNode(median, split, CreateNode(split_next, data_set[:split_pos]), # 創(chuàng)建左子樹 CreateNode(split_next, data_set[split_pos + 1:])) # 創(chuàng)建右子樹 self.root = CreateNode(0, data) # 從第0維分量開始構(gòu)建kd樹,返回根節(jié)點(diǎn) # KDTree的前序遍歷 def preorder(root): print root.dom_elt if root.left: # 節(jié)點(diǎn)不為空 preorder(root.left) if root.right: preorder(root.right) if __name__ == "__main__": data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]] kd = KdTree(data) preorder(kd.root)
進(jìn)行前序遍歷(前序遍歷首先訪問(wèn)根結(jié)點(diǎn)然后遍歷左子樹,最后遍歷右子樹)的結(jié)果如下,可見已經(jīng)正確構(gòu)建了kd樹:
搜索kd樹
利用kd樹可以省去對(duì)大部分?jǐn)?shù)據(jù)點(diǎn)的搜索,從而減少搜索的計(jì)算量。下面以搜索最近鄰點(diǎn)為例加以敘述:給定一個(gè)目標(biāo)點(diǎn),搜索其最近鄰,首先找到包含目標(biāo)點(diǎn)的葉節(jié)點(diǎn);然后從該葉結(jié)點(diǎn)出發(fā),依次回退到父結(jié)點(diǎn);不斷查找與目標(biāo)點(diǎn)最近鄰的結(jié)點(diǎn),當(dāng)確定不可能存在更近的結(jié)點(diǎn)時(shí)終止。這樣搜索就被限制在空間的局部區(qū)域上,效率大為提高。
用kd樹的最近鄰搜索:
輸入: 已構(gòu)造的kd樹;目標(biāo)點(diǎn)xx;
輸出:xx的最近鄰。
(1) 在kd樹中找出包含目標(biāo)點(diǎn)xx的葉結(jié)點(diǎn):從根結(jié)點(diǎn)出發(fā),遞歸的向下訪問(wèn)kd樹。若目標(biāo)點(diǎn)當(dāng)前維的坐標(biāo)值小于切分點(diǎn)的坐標(biāo)值,則移動(dòng)到左子結(jié)點(diǎn),否則移動(dòng)到右子結(jié)點(diǎn)。直到子結(jié)點(diǎn)為葉結(jié)點(diǎn)為止;
(2) 以此葉結(jié)點(diǎn)為“當(dāng)前最近點(diǎn)”;
(3) 遞歸的向上回退,在每個(gè)結(jié)點(diǎn)進(jìn)行以下操作:
?。╝) 如果該結(jié)點(diǎn)保存的實(shí)例點(diǎn)比當(dāng)前最近點(diǎn)距目標(biāo)點(diǎn)更近,則以該實(shí)例點(diǎn)為“當(dāng)前最近點(diǎn)”;
?。╞) 當(dāng)前最近點(diǎn)一定存在于該結(jié)點(diǎn)一個(gè)子結(jié)點(diǎn)對(duì)應(yīng)的區(qū)域。檢查該子結(jié)點(diǎn)的父結(jié)點(diǎn)的另一個(gè)子結(jié)點(diǎn)對(duì)應(yīng)的區(qū)域是否有更近的點(diǎn)。具體的,檢查另一個(gè)子結(jié)點(diǎn)對(duì)應(yīng)的區(qū)域是否與以目標(biāo)點(diǎn)為球心、以目標(biāo)點(diǎn)與“當(dāng)前最近點(diǎn)”間的距離為半徑的超球體相交。如果相交,可能在另一個(gè)子結(jié)點(diǎn)對(duì)應(yīng)的區(qū)域內(nèi)存在距離目標(biāo)更近的點(diǎn),移動(dòng)到另一個(gè)子結(jié)點(diǎn)。接著,遞歸的進(jìn)行最近鄰搜索。如果不相交,向上回退。
(4) 當(dāng)回退到根結(jié)點(diǎn)時(shí),搜索結(jié)束。最后的“當(dāng)前最近點(diǎn)”即為xx的最近鄰點(diǎn)。
以先前構(gòu)建好的kd樹為例,查找目標(biāo)點(diǎn)(3,4.5)的最近鄰點(diǎn)。同樣先進(jìn)行二叉查找,先從(7,2)查找到(5,4)節(jié)點(diǎn),在進(jìn)行查找時(shí)是由y = 4為分割超平面的,由于查找點(diǎn)為y值為4.5,因此進(jìn)入右子空間查找到(4,7),形成搜索路徑:(7,2)→(5,4)→(4,7),取(4,7)為當(dāng)前最近鄰點(diǎn)。以目標(biāo)查找點(diǎn)為圓心,目標(biāo)查找點(diǎn)到當(dāng)前最近點(diǎn)的距離2.69為半徑確定一個(gè)紅色的圓。然后回溯到(5,4),計(jì)算其與查找點(diǎn)之間的距離為2.06,則該結(jié)點(diǎn)比當(dāng)前最近點(diǎn)距目標(biāo)點(diǎn)更近,以(5,4)為當(dāng)前最近點(diǎn)。用同樣的方法再次確定一個(gè)綠色的圓,可見該圓和y = 4超平面相交,所以需要進(jìn)入(5,4)結(jié)點(diǎn)的另一個(gè)子空間進(jìn)行查找。(2,3)結(jié)點(diǎn)與目標(biāo)點(diǎn)距離為1.8,比當(dāng)前最近點(diǎn)要更近,所以最近鄰點(diǎn)更新為(2,3),最近距離更新為1.8,同樣可以確定一個(gè)藍(lán)色的圓。接著根據(jù)規(guī)則回退到根結(jié)點(diǎn)(7,2),藍(lán)色圓與x=7的超平面不相交,因此不用進(jìn)入(7,2)的右子空間進(jìn)行查找。至此,搜索路徑回溯完,返回最近鄰點(diǎn)(2,3),最近距離1.8。
如果實(shí)例點(diǎn)是隨機(jī)分布的,kd樹搜索的平均計(jì)算復(fù)雜度是O(logN)O(logN),這里N是訓(xùn)練實(shí)例數(shù)。kd樹更適用于訓(xùn)練實(shí)例數(shù)遠(yuǎn)大于空間維數(shù)時(shí)的k近鄰搜索。當(dāng)空間維數(shù)接近訓(xùn)練實(shí)例數(shù)時(shí),它的效率會(huì)迅速下降,幾乎接近線性掃描。
下面的代碼對(duì)構(gòu)建好的kd樹進(jìn)行搜索,尋找與目標(biāo)點(diǎn)最近的樣本點(diǎn):
from math import sqrt from collections import namedtuple # 定義一個(gè)namedtuple,分別存放最近坐標(biāo)點(diǎn)、最近距離和訪問(wèn)過(guò)的節(jié)點(diǎn)數(shù) result = namedtuple("Result_tuple", "nearest_point nearest_dist nodes_visited") def find_nearest(tree, point): k = len(point) # 數(shù)據(jù)維度 def travel(kd_node, target, max_dist): if kd_node is None: return result([0] * k, float("inf"), 0) # python中用float("inf")和float("-inf")表示正負(fù)無(wú)窮 nodes_visited = 1 s = kd_node.split # 進(jìn)行分割的維度 pivot = kd_node.dom_elt # 進(jìn)行分割的“軸” if target[s] <= pivot[s]: # 如果目標(biāo)點(diǎn)第s維小于分割軸的對(duì)應(yīng)值(目標(biāo)離左子樹更近) nearer_node = kd_node.left # 下一個(gè)訪問(wèn)節(jié)點(diǎn)為左子樹根節(jié)點(diǎn) further_node = kd_node.right # 同時(shí)記錄下右子樹 else: # 目標(biāo)離右子樹更近 nearer_node = kd_node.right # 下一個(gè)訪問(wèn)節(jié)點(diǎn)為右子樹根節(jié)點(diǎn) further_node = kd_node.left temp1 = travel(nearer_node, target, max_dist) # 進(jìn)行遍歷找到包含目標(biāo)點(diǎn)的區(qū)域 nearest = temp1.nearest_point # 以此葉結(jié)點(diǎn)作為“當(dāng)前最近點(diǎn)” dist = temp1.nearest_dist # 更新最近距離 nodes_visited += temp1.nodes_visited if dist < max_dist: max_dist = dist # 最近點(diǎn)將在以目標(biāo)點(diǎn)為球心,max_dist為半徑的超球體內(nèi) temp_dist = abs(pivot[s] - target[s]) # 第s維上目標(biāo)點(diǎn)與分割超平面的距離 if max_dist < temp_dist: # 判斷超球體是否與超平面相交 return result(nearest, dist, nodes_visited) # 不相交則可以直接返回,不用繼續(xù)判斷 #---------------------------------------------------------------------- # 計(jì)算目標(biāo)點(diǎn)與分割點(diǎn)的歐氏距離 temp_dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(pivot, target))) if temp_dist < dist: # 如果“更近” nearest = pivot # 更新最近點(diǎn) dist = temp_dist # 更新最近距離 max_dist = dist # 更新超球體半徑 # 檢查另一個(gè)子結(jié)點(diǎn)對(duì)應(yīng)的區(qū)域是否有更近的點(diǎn) temp2 = travel(further_node, target, max_dist) nodes_visited += temp2.nodes_visited if temp2.nearest_dist < dist: # 如果另一個(gè)子結(jié)點(diǎn)內(nèi)存在更近距離 nearest = temp2.nearest_point # 更新最近點(diǎn) dist = temp2.nearest_dist # 更新最近距離 return result(nearest, dist, nodes_visited) return travel(tree.root, point, float("inf")) # 從根節(jié)點(diǎn)開始遞歸
下面結(jié)合前面寫的代碼來(lái)進(jìn)行一下測(cè)試:
from time import clock from random import random # 產(chǎn)生一個(gè)k維隨機(jī)向量,每維分量值在0~1之間 def random_point(k): return [random() for _ in range(k)] # 產(chǎn)生n個(gè)k維隨機(jī)向量 def random_points(k, n): return [random_point(k) for _ in range(n)] if __name__ == "__main__": data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]] # samples kd = KdTree(data) ret = find_nearest(kd, [3,4.5]) print ret N = 400000 t0 = clock() kd2 = KdTree(random_points(3, N)) # 構(gòu)建包含四十萬(wàn)個(gè)3維空間樣本點(diǎn)的kd樹 ret2 = find_nearest(kd2, [0.1,0.5,0.8]) # 四十萬(wàn)個(gè)樣本點(diǎn)中尋找離目標(biāo)最近的點(diǎn) t1 = clock() print "time: ",t1-t0, "s" print ret2
結(jié)果如下圖所示。先是測(cè)試了之前例子中距離(3,4.5)最近的點(diǎn),可以看出正確返回了最近點(diǎn)(2,3)以及最近距離。然后隨機(jī)生成了四十萬(wàn)個(gè)三維空間樣本點(diǎn),并構(gòu)建kd樹,然后搜索離(0.1,0.5,0.8)最近的樣本點(diǎn),并測(cè)試用時(shí)。為了進(jìn)行對(duì)比我先是使用numpy算出全部四十萬(wàn)個(gè)距離后尋找最近點(diǎn),結(jié)果耗時(shí)0.5s左右?。。≡趺茨苓@么快(⊙▽⊙),然后不用numpy自己在python中計(jì)算全部距離,結(jié)果耗時(shí)2s左右,還是比自己寫的KD樹要快得多...
可能是這種使用遞歸方式創(chuàng)建和搜索的kd樹本身效率就不是很高(知乎:為什么說(shuō)遞歸效率低?)。而且深層遞歸一定要盡量避免,一是不安全,容易導(dǎo)致棧溢出;二是調(diào)用代價(jià)高(遞歸函數(shù)調(diào)用的代價(jià))??梢钥紤]轉(zhuǎn)換為循環(huán)結(jié)構(gòu)。循環(huán)結(jié)構(gòu)的kd樹實(shí)現(xiàn)參考:KDTree example in scipy
總結(jié)
以上就是本文關(guān)于Python語(yǔ)言描述KNN算法與Kd樹的全部?jī)?nèi)容,希望對(duì)大家有所幫助。感興趣的朋友可以繼續(xù)參閱本站其他相關(guān)專題。如有不足之處,歡迎留言指出。感謝朋友們對(duì)本站的支持!
相關(guān)文章
如何解決tensorflow恢復(fù)模型的特定值時(shí)出錯(cuò)
今天小編就為大家分享一篇如何解決tensorflow恢復(fù)模型的特定值時(shí)出錯(cuò),具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-02-02Python閉包之返回函數(shù)的函數(shù)用法示例
這篇文章主要介紹了 Python閉包之返回函數(shù)的函數(shù)用法示例,小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧2018-01-01Python模仿POST提交HTTP數(shù)據(jù)及使用Cookie值的方法
這篇文章主要介紹了Python模仿POST提交HTTP數(shù)據(jù)及使用Cookie值的方法,通過(guò)兩種不同的實(shí)現(xiàn)方法較為詳細(xì)的講述了HTTP數(shù)據(jù)通信及cookie的具體用法,需要的朋友可以參考下2014-11-11python tkinter控件treeview的數(shù)據(jù)列表顯示的實(shí)現(xiàn)示例
本文主要介紹了python tkinter控件treeview的數(shù)據(jù)列表顯示的實(shí)現(xiàn)示例,文中通過(guò)示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2022-01-01python中的信號(hào)通信 blinker的使用小結(jié)
信號(hào)是一種通知或者說(shuō)通信的方式,信號(hào)分為發(fā)送方和接收方,信號(hào)的特點(diǎn)就是發(fā)送端通知訂閱者發(fā)生了什么,今天通過(guò)本文給大家介紹python中的信號(hào)通信 blinker的相關(guān)知識(shí),感興趣的朋友一起看看吧2021-10-10