摘要:def query(node, data, answers, K): # 判斷node是否已經訪問過 if node.visited: # 標記訪問 node.visited = True # 計算data到node中點的距離 dis = cal_dis(data, node.point) # 如果小於答案中最大值則更新答案 if dis < max(answers): answers.update(node.point) # 計算data到分割線的距離 dis = cal_dis(data, node.split) # 如果小於最長距離,說明另一側還可能有答案 if dis < max(answers): # 獲取當前節點的兄弟節點 brother = self.get_brother(node) if brother is not None: # 往下搜索到葉子節點,從葉子節點開始尋找 leaf = self.iter_down(brother, data) if leaf is not None: return self.query(leaf, data, answers, K) # 如果已經到了根節點了,退出 if node is root: return answers # 遞歸父親節點 return self.query(node.father, data, answers, K) else: if node is root: return answers return self.query(node.father, data, answers, K)。首先我們先通過 遞歸查找到KD-Tree上的葉子節點 ,也就是找到樣本所在的子空間。

今天是機器學習的 第15篇文章 ,之前的文章當中講了Kmeans的相關優化,還講了大名鼎鼎的EM算法。有些小夥伴表示喜歡看這些硬核的,於是今天上點硬菜,我們來看一個機器學習領域經常用到的數據結構—— KD-Tree

從線段樹到KD樹

在講KD樹之前,我們先來了解一下 線段樹 的概念。線段樹在機器學習領域當中不太常見,作爲高性能維護的數據結構,經常出現在各種算法比賽當中。線段樹的本質是一棵維護一段區間的平衡二叉樹。

比如下圖就是一個經典的線段樹:

從下圖當中我們不難看出來,這棵線段樹維護的是一個 區間內的最大值 。比如樹根是8,維護的是整個區間的最大值,每一箇中間節點的值都是以它爲樹根的子樹中所有元素的最大值。

通過線段樹,我們可以在 的時間內計算出某一個 連續區間的最大值 。比如我們來看下圖:

當我們要求被框起來的區間中的最大值,我們只需要 找到能夠覆蓋這個區間的中間節點 就行。我們可以發現被紅框框起來的兩個節點的子樹剛好覆蓋這個區間,於是整個區間的最大值,就是這兩個元素的最大值。這樣,我們就把一個需要 查找的問題降低成了 ,不但如此,我們也可以 做到 複雜度內的更新 ,也就是說我們不但可以快速查詢,還可以更新線段當中的元素。

當然線段樹的應用非常廣泛,也有 許多種變體 ,這裏我們不過多深入,感興趣的同學可以期待一下週三的算法與數據結構專題,在之後的文章當中會爲大家分享線段樹的相關內容。在這裏,我們只需要有一個大概的印象,線段樹究竟完成的是什麼樣的事情即可。

線段樹維護的是一個線段,也就是區間內的元素,也就是說維護的是一個一維的序列。如果我們將數據的維度擴充一下,擴充到多維呢?

是的,你沒有猜錯,從某種程度上來說,我們可以把KD-Tree看成是線段樹 拓展到多維空間 當中的情況。

KD-Tree定義

我們來看一下KD-Tree的具體定義,這裏的K指的是 K維空間 ,D自然就是dimension,也就是維度,也就是說KD-Tree就是K維度樹的意思。

在我們構建線段樹的時候,其實是一個遞歸的建樹過程,我們每次把當前的線段一分爲二,然後用分成兩半的數據分別構建左右子樹。我們可以簡單寫一下僞代碼,來更直觀地感受一下:

class Node:
    def __init__(self, value, lchild, rchild):
        self.value = value
        self.lchild = lchild
        self.rchild = rchild   
        
def build(arr):
    n = len(arr):
    left, right = arr[: n//2], arr[n//2:]
    lchild, rchild = self.build(left), self.build(right)
    return Node(max(lchild.value, rchild.value), lchild, rchild)

我們來看一個二維的例子,在一個二維的平面當中分佈着若干個點。

我們首先選擇一個維度 將這些數據一分爲二 ,比如我們選擇x軸。我們對所有數據按照x軸的值排序,選出其中的中點進行一分爲二。

在這根線左右兩側的點被分成了兩棵子樹,對於這兩個部分的數據來說,我們 更換一個維度 ,也就是選擇y軸進行劃分。一樣,我們先排序,然後找到中間的點,再次一分爲二。我們可以得到:

我們 重複上述過程 ,一直將點分到不能分爲止,爲了能更好地看清楚,我們對所有數據標上座標(並不精確)。

如果我們把空間看成是廣義的區間,那麼它和線段樹的原理是一樣的。最後得到的也是一棵 完美二叉樹 ,因爲我們每次都選擇了數據集的中點進行劃分,可以保證從樹根到葉子節點的長度不會超過

我們代入上面的座標之後,我們最終得到的KD-Tree大概是下面這個樣子:

KD-Tree 建樹

在建樹的過程當中,我們的樹深每往下延伸一層,我們就會換一個維度作爲衡量標準。原因也很簡單,因爲我們希望這棵樹對於這K維空間都有很好的表達能力,方便我們根據不同的維度快速查詢。

在一些實現當中,我們會計算每一個維度的方差,然後選擇方差較大的維度進行切分。這樣做自然是因爲方差較大的維度說明數據相對分散,切分之後可以把數據區分得更加明顯。但我個人覺得這樣做意義不是很大,畢竟計算方差也是一筆開銷。所以這裏我們選擇了最樸素的方法——輪流選擇。

也就是說我們從樹根開始,選擇第0維作爲排序和切分數據的依據,然後到了樹深爲1的這一層,我們選擇第一維,樹深爲2的這一層,我們選擇第二維,以此類推。當樹深超過了K的時候,我們就對樹深取模。

明確了這一點之後,我們就可以來寫KD-Tree的建樹代碼了,和上面二叉樹的代碼非常相似,只不過多了維度的處理而已。

    # 外部暴露接口
    def build_model(self, dataset):
        self.root = self._build_model(dataset)
        # 先忽略,容後再講
        self.set_father(self.root, None)

    # 內部實現的接口
    def _build_model(self, dataset, depth=0):
        if len(dataset) == 0:
            return None

        # 通過樹深對K取模來獲得當前對哪一維切分
        axis = depth % self.K
        m = len(dataset) // 2
        # 根據axis這一維排序
        dataset = sorted(dataset, key=lambda x: x[axis])
        # 將數據一分爲二
        left, mid, right = dataset[:m], dataset[m], dataset[m+1:]

        # 遞歸建樹
        return KDTree.Node(
            mid[axis],
            mid,
            axis,
            depth,
            len(dataset),
            self._build_model(left, depth+1),
            self._build_model(right, depth+1)
        )

這樣我們就建好了樹,但是在後序的查詢當中我們需要訪問節點的父節點,所以我們需要爲每一個節點都賦值指向父親節點的指針。這個值我們可以寫在建樹的代碼裏,但是會稍稍複雜一些,所以我把它單獨拆分了出來,作爲一個獨立的函數來給每一個節點賦值。對於根節點來說,由於它沒有父親節點,所以賦值爲None。

我們來看下set_father當中的內容,其實很簡單,就是一個樹的遞歸遍歷:

def set_father(self, node, father):
    if node is None:
        return
    # 賦值
    node.father = father
    # 遞歸左右
    self.set_father(node.lchild, node)
    self.set_father(node.rchild, node)

快速批量查詢

KD-Tree建樹建好了肯定是要來用的,它最大的用處是可以 在單次查詢中獲得距離樣本最近的若干個樣本 。在分散均勻的數據集當中,我們可以在 的時間內完成查詢,但是對於特殊情況可能會長一些,但是也比我們通過樸素的方法查詢要快得多。

我們很容易發現,KD-Tree一個廣泛的使用場景是用來 優化KNN算法 。我們在之前介紹KNN算法的文章當中曾經提到過,KNN算法在預測的時候需要遍歷整個數據集,然後計算數據集中每一個樣本與當前樣本的距離,選出最近的K個來,這需要大量的開銷。而使用KD-Tree,我們可以在一次查詢當中 直接查找到K個最近的樣本 ,因此大大提升KNN算法的效率。

那麼,這個查詢操作又是怎麼實現的呢?

這個查詢基於遞歸實現,因此對於遞歸不熟悉的小夥伴,可能初看會比較困難,可以先閱讀一下之前關於遞歸的文章。

首先我們先通過 遞歸查找到KD-Tree上的葉子節點 ,也就是找到樣本所在的子空間。這個查找應該非常容易,本質上來說我們就是將當前樣本不停地與分割線進行比較,看看是在分割線的左側還是右側。和二叉搜索樹的元素查找是一樣的:

    def iter_down(self, node, data):
        # 如果是葉子節點,則返回
        if node.lchild is None and node.rchild is None:
            return node
        # 如果左節點爲空,則遞歸右節點
        if node.lchild is None:
            return self.iter_down(node.rchild, data)
        # 同理,遞歸左節點
        if node.rchild is None:
            return self.iter_down(node.rchild, data)
        # 都不爲空則和分割線判斷是左還是右
        axis = node.axis
        next_node = node.lchild if data[axis] <= node.boundray else node.rchild
        return self.iter_down(next_node, data)

我們找到了葉子節點,其實 代表樣本空間當中的一小塊空間

我們來實際走一下整個流程,假設我們要查找3個點。首先,我們會創建一個候選集,用來存儲答案。當我們找到葉子節點之後,這個區域當中只有一個點,我們把它加入候選集。

在上圖當中紫色的x代表我們查找的樣本,我們查找到的葉子節點之後,在兩種情況下我們會把當前點加入候選集。第一種情況是 候選集還有空餘 ,也就是還沒有滿K個,這裏的K是我們查詢的數量,也就是3。第二種情況是 當前點到樣本的距離小於候選集中最大的一個 ,那麼我們需要更新候選集。

這個點被我們訪問過之後,我們會 打上標記 ,表示這個點已經訪問過了。這個時候我們需要判斷,整棵樹當中的搜索是否已經結束,如果當前節點已經是根節點了,說明我們的遍歷結束了,那麼返回候選集,否則說明還沒有,我們需要繼續搜索。上圖當中我們用 綠色表示樣本被放入了候選集當中,黃色表示已經訪問過。

由於我們的搜索還沒有結束,所以需要繼續搜索。繼續搜索需要判斷 樣本和當前分割線的距離 來判斷和分割線的另一側有沒有可能存在答案。由於葉子節點沒有另一側,所以作罷,我們往上移動一個,跳轉到它的父親節點。

我們計算距離並且查看候選集,此時候選集未滿,我們加入候選集,標記爲已經訪問過。它雖然存在分割線,但是也 沒有另一側的節點 ,所以也跳過。

我們再往上,遍歷到它的父親節點,我們執行同樣的判斷,發現此時候選集還有空餘,於是將它繼續加入答案:

但是當我們判斷到分割線距離的時候,我們發現這一次 樣本到分割線的舉例要比之前候選集當中的最大距離要小 ,所以分割線的另一側很有可能存在答案:

這裏的d1是樣本到分割線的距離,d2是樣本到候選集當中最遠點的距離。由於到分割線更近,所以分割線的另一側很有可能也存在答案,這個時候我們需要 搜索分割線另一側的子樹 ,一直搜索到葉子節點。

我們找到了葉子節點,計算距離,發現此時 候選集已經滿了 ,並且它的距離大於候選集當中任何一個答案,所以不能構成新的答案。於是我們只是標記它已經訪問過,並不會加入候選集。同樣,我們繼續往上遍歷,到它的父節點:

比較之後發現,data到它的 距離小於候選集當中最大的那個 ,於是我們更新候選集,去掉距離大於它的答案。然後我們重複上述的過程,直到根節點爲止。

由於後面沒有更近的點,所以候選集一直沒有更新,最後上圖當中的三個打了綠標的點就是答案。

我們把上面的流程整理一下,就得到了遞歸函數當中的邏輯,我們用Python寫出來其實已經和代碼差不多了:

def query(node, data, answers, K):
    # 判斷node是否已經訪問過
    if node.visited:
        # 標記訪問
        node.visited = True
        # 計算data到node中點的距離
        dis = cal_dis(data, node.point)
        # 如果小於答案中最大值則更新答案
        if dis < max(answers):
            answers.update(node.point)
        # 計算data到分割線的距離
        dis = cal_dis(data, node.split)
        # 如果小於最長距離,說明另一側還可能有答案
        if dis < max(answers):
            # 獲取當前節點的兄弟節點
            brother = self.get_brother(node)
            if brother is not None:
                # 往下搜索到葉子節點,從葉子節點開始尋找
                leaf = self.iter_down(brother, data)
                if leaf is not None:
                    return self.query(leaf, data, answers, K)
        # 如果已經到了根節點了,退出
        if node is root:
            return answers
        # 遞歸父親節點
        return self.query(node.father, data, answers, K)
    else:
        if node is root:
            return answers
        return self.query(node.father, data, answers, K)

最終寫成的代碼和上面這段並沒有太多的差別,在得到距離之後和答案當中的最大距離進行比較的地方,我們使用了 優先隊列 。其他地方几乎都是一樣的,我也貼上來給大家感受一下:

def _query_nearest_k(self, node, path, data, topK, K):
    # 我們用set記錄訪問過的路徑,而不是直接在節點上標記
    if node not in path:
        path.add(node)
        # 計算歐氏距離
        dis = KDTree.distance(node.value, data)
        if (len(topK) < K or dis < topK[-1]['distance']):
            topK.append({'node': node, 'distance': dis})
            # 使用優先隊列獲取topK
            topK = heapq.nsmallest(K, topK, key=lambda x: x['distance'])
        axis = node.axis
        # 分割線都是直線,直接計算座標差
        dis = abs(data[axis] - node.boundray)
        if len(topK) < K or dis <= topK[-1]['distance']:
            brother = self.get_brother(node, path)
            if brother is not None:
                next_node = self.iter_down(brother, data)
                if next_node is not None:
                    return self._query_nearest_k(next_node, path, data, topK, K)
        if node == self.root:
            return topK
        return self._query_nearest_k(node.father, path, data, topK, K)
    else:
        if node == self.root:
            return topK
        return self._query_nearest_k(node.father, path, data, topK, K)

這段邏輯大家應該都能看明白,但是有一個疑問是,我們 爲什麼不在node裏面加一個visited的字段,而是通過傳入一個set來維護訪問過的節點呢 ?這個邏輯只看代碼是很難想清楚的,必須要親手實驗纔會理解。如果在node當中加入一個字段當然也是可以的,如果這樣做的話,在我們執行查找之後必須得手動再執行一次遞歸,將樹上所有節點的node全部置爲false,否則下一次查詢的時候,會有一些節點已經被標記成了True,顯然會影響結果。查詢之後將這些值手動還原會帶來開銷,所以才轉換思路使用set來進行訪問判斷。

這裏的iter_down函數和我們上面貼的查找葉子節點的函數是一樣的,就是查找當前子樹的葉子節點。如果我沒記錯的話,這也是我們文章當中 第一次出現在遞歸當中調用另一個遞歸 的情況。對於初學者而言,這在理解上可能會相對困難一些。我個人建議可以親自動手試一試在紙上畫一個kd-tree進行手動模擬試一試,自然就能知道其中的運行邏輯了。這也是一個思考和學習非常好用的方法。

優化

當我們理解了整個kd-tree的建樹和查找的邏輯之後,我們來 考慮一下優化

這段代碼看下來初步可以找到兩個可以優化的地方,第一個地方是我們建樹的時候。我們每次遞歸的時候由於要將數據一分爲二,我們是 使用了排序 的方法來實現的,而每次排序都是 的複雜度,這其實是不低的。其實仔細想想,我們沒有必要排序,我們 只需要選出根據某個軸排序前n/2個數 。也就是說這是一個選擇問題,並不是排序問題,所以可以想到我們可以利用之前講過的快速選擇的方法來優化。使用快速選擇,我們可以在 的時間內完成數據的拆分。

另一個地方是我們在查詢K個鄰近點的時候,我們使用了 優先隊列 維護的候選集當中的答案,方便我們對答案進行更新。同樣,優先隊列獲取topK也是 的複雜度。這裏也是可以優化的,比較好的思路是 使用堆來代替 。可以做到 的插入和彈出,相比於heapq的nsmallest方法要效率更高。

總結

到這裏,我們關於KD-tree的原理部分已經差不多講完了,我們有了建樹和查詢功能之後就可以用在KNN算法上進行優化了。但是我們現在的KD-tree只支持建樹以及查詢,如果我們 想要插入或者刪除集合當中的數據應該怎麼辦 ?難道每次修改都重新建樹嗎?這顯然不行,但是插入和刪除節點都會引起樹結構的變化很有可能導致樹不再平衡,這個時候我們應該怎麼辦呢?

我們先賣個關子,相關的內容將會放到下一篇文章當中,感興趣的同學不要錯過哦。

最後,我把KD-tree完整的代碼放在了ubuntu.paste上,想要查看完整源碼的同學請在公衆號內回覆 kd-tree 進行查看。

今天的文章就是這些,如果覺得有所收穫,請順手點個 關注或者轉發 吧,你們的舉手之勞對我來說很重要。

相關文章