背景

在閒魚的很多業務場景中有大量需要利用算法進行分類的需求,例如圖片分類、組件識別、商品分層、糾紛類別預測等。這些場景往往需要模型識別出的結果具備可解釋性,也就是識別不能只得到其類別,最好能在識別過程中同時解釋類別的層級和來源。如何進行有解釋的圖片分類成爲了項目研發中的一個需求,基於此我對NBDT算法進行了調研。

NBDT是UC伯克利和波斯頓大學最新(2020年4月)發的一篇paper中的模型。NBDT全稱“ Neural-Backed Decision Trees ”,翻譯爲“神經支持決策樹”,特別強調此處“B”不代表“Boosting”,以免熟悉GBDT的同學可能會誤以爲NBDT又是一種新型的梯度提升樹模型。NBDT只是一顆決策樹,而不是多棵樹。

介紹

NBDT的特點在於它在決策樹中(準確說是決策樹 )融入了神經網絡NN,這裏NN通常是CNN即卷積神經網絡。個人理解,NBDT的結構可以大致認爲是“前面的CNN + 後面的DT”。DT=決策樹。NBDT目前的使用場景是在圖像分類領域。它的優勢不在於準確率有多高,事實上在作者的實驗中,它的準確率是略低於“前面的CNN”的。它的真正優勢是能夠很好的平衡 模型準確率模型解釋性 。具體來講,它可以在略微犧牲CNN的準確率的前提下,取得比任何樹模型都高的多的(分類)準確率,同時因爲它融入了決策樹,還可以顯式的、逐級的給出模型推斷的依據,也就是說,NBDT不但可以把一張狗的圖片識別爲“狗”,還可以告訴你它是如何一步一步識別的:比如,先把該圖片以99.49%的概率識別爲“動物”,再以99.63%的概率識別成“脊椎動物(Chordate)”,然後以99.4%的概率識別成脊椎動物下的“食肉動物(Carnivore)”,最後以99.88%的概率判斷成食肉動物下的“狗”。這種推斷方式無疑增強了模型的解釋力。

圖1 - 狗狗分類 (引用自官方Demo)

原理

NBDT採用了“預訓練+finetune”的框架。整個流程大致分爲以下三步:

預訓練一個CNN模型,並拿CNN最後一層的權重作爲“每種類別”的隱向量

比如先拿cifar10(一個圖片分類數據集,有“貓”、“狗”之類的10種類別)訓練一個resnet18的CNN。這類CNN的最後一層通常是全連接層(Fully Connected layer, FC),設倒數第二層輸出的向量維度爲d,則該全連接層W的維度爲W,那麼W的每一個列向量正好對應了每一個類別,可以將其視作每一種類別的隱向量。這種做法有點類似於Word2Vec。

利用類別的隱向量做層次聚類(hierarchical clustering)並利用wordnet形成層次樹結構。

論文中將該樹結構稱之爲“誘導層級”(Induced Hierarchy)。具體地,首先對類別隱向量做層次聚類,源碼中是直接調用sklearn模塊的AgglomerativeClustering類實現。聚類的分層結構有了之後,帶來了兩個問題:(1)兩個子節點可以被聚類算法聚到一起,子節點都表示一類實體,但它們的父節點並沒有一個實體的描述。(2)假設兩個子節點被聚到了一起,子節點都有隱向量,它們的父節點的隱向量該怎麼表示?

針對問題(1),作者使用了WordNet,一種包含名詞之間上下位關係的詞網絡,python裏面可以直接在nltk模塊中導入wordnet模塊調用。由於葉節點是存在實體描述的,比方說cifar10的10個類別,那麼通過WordNet,可以找到兩個葉節點“最鄰近的共同祖先”,e.g. “貓”和“狗”在WordNet中可能最近的歸屬是都位於“哺乳動物”下,那麼“哺乳動物”就被作爲“貓”和“狗”的父節點。因此,可以按照層次聚類的結果,自底向上依次爲父節點“命名”,直到只有一個根節點,這就形成了所謂的“誘導層級”,即下圖中的“ Step 1 ”。這個誘導層級也就是上面狗狗圖片中的決策樹。

圖2 - 訓練和推斷 (引用自原Paper)

針對問題(2),作者使用了子節點隱向量的 均值 ,來代表父節點的隱向量。如下圖中的“ Step C ”描述。

圖3 - 構造層次結構 (引用自原Paper)

在總損失中加入誘導層級的分類損失,finetune模型

在誘導層級(樹結構,下稱DT)有了之後,完整的模型不再是CNN,而是CNN+DT。爲了迫使模型對新樣本的預測能夠遵循樹結構從根節點一路推斷至葉節點,就需要在總損失中加入樹結構的分類損失,並對模型做finetune。

這裏首先要理解完整模型預測所採用的方式,我認爲作者在這裏的思路是非常之精髓的。一個新的樣本(一張圖片)進來,首先要經過前面的CNN,在最後一層的全連接層W之前,CNN給該圖片輸出的是一個d維向量x。將x與W做矩陣乘法(實質上是與各列向量做內積),即得到該樣本在各個類別的logits分佈,如果再softmax則得到了概率分佈。由於W的各列向量代表着DT葉節點的隱向量,那麼完全可以用該DT來替換W,不再直接把x與W做矩陣乘法,而是從DT的根節點開始遍歷,讓x依次與DT各節點的子節點隱向量計算內積。這裏遍歷DT各節點有兩種模式:“ Hard ”和“ Soft ”。以DT是二叉樹爲例,若是Hard模式,那麼每次x會與左右兩邊的子節點分別算內積,哪邊大就把x歸爲哪一邊,一直計算到葉節點爲止,最後x落到的葉節點,即爲x所屬的最終類別。若是Soft模式,則x會自頂向下遍歷全部中間節點並計算內積,然後葉節點的最終概率是到達葉節點的路徑上各中間節點的概率之乘積,最後通過比較各葉節點上的最終概率值的大小,即可確定x所屬類別。

圖4 - 節點概率計算 (引用自原Paper)

在理解了完整模型預測的細節之後,就可以來解釋“誘導層級(樹結構)的分類損失”。相對應的,損失函數同樣有“Hard”和“Soft”兩種模式,如下圖所示。若是Hard模式的損失,那麼Loss只會累加樣本所屬葉節點在DT中真實路徑上的每個節點的分類損失(以一定權重),非真實路徑(下圖A虛線節點w3/w4)則不會計入,此處每個節點的分類損失使用交叉熵計算。若是Soft模式的損失,則是直接計算葉節點上的最終概率分佈與真實onehot分佈的交叉熵作爲Loss。 簡言之,Hard模式損失函數計算的是“路徑交叉熵”,Soft模式則計算的是“葉節點交叉熵” 。在pytorch中的交叉熵計算方式爲:

最終模型的總損失還會考慮原始CNN的分類損失Lossoriginal,因此最後交由finetune階段進行優化的總損失爲:

根據我對源碼的閱讀,Loss進行BP反向傳播時優化的依然是CNN的網絡權重,直觀上理解:就是迫使前面CNN的輸出能夠符合後面DT的預期,儘可能使得樣本按照DT的推斷路徑輸出的預測類別符合其真實類別。

圖5 - Hard和Soft模式下的損失 (引用自原Paper)

源碼解析

NBDT的python代碼開源在github ,整體上使用pytorch和networkx實現,我統計了下總共大概有4000+行,核心腳本是 model.py/loss.py/graph.py/hierarchy.py  四個。代碼基本沒有註釋和參數釋義,讀起來頗爲費力,花了好幾天纔看完。以下對最核心的幾段代碼做解析。

生成誘導層級

核心函數爲buildinducedgraph,其作用是輸入葉節點的WordNet ID和CNN模型,通過從CNN模型獲取到FC的權重,然後做層次聚類,利用WordNet對聚類結果“命名”,形成樹節點有實體含義的DT。此函數對應本文原理細節的②部分。詳細解釋如下:

前向計算節點概率

前面提到新樣本進來後會先經過CNN,在FC之前會輸出d維向量x,然後x與DT的各個節點的隱向量做內積,而各節點的隱向量又等於其子節點隱向量的均值。getnodelogits方法在這裏做了一個優化:考慮到向量均值的內積等於向量內積的均值(如下圖公式),因此不必顯示的去求隱向量再做內積,而是對某個節點,直接把其子節點的logits求均值作爲它本身的logits。具體代碼如下:

總損失函數

前面提到,總損失=原始CNN損失+樹結構損失。具體地,以Hard模式爲例,如下代碼解釋瞭如何計算決策路徑上的樹結構損失,併合併到總損失當中。

論文實驗

在多個數據集上,作者拿原始CNN(WiderResnet28×10)和多個“可解釋”的神經網絡模型做了對比,從下表可以看到,NBDT精度僅僅比原始CNN略低,但已經遠遠超過其它模型,說明NBDT已達到SOTA。 而在NBDT中,Soft模式的分數要高於Hard模式 ,這個好理解,因爲Soft考慮的是全局最優,Hard考慮的則是連續多次局部最優。

圖6 - 實驗結果 (引用自原Paper)

使用

安裝和使用詳見官方github,此處僅對常用方式做總結

命令行預測

直接調用  nbdt 命令,後面跟圖片路徑(url或本地路徑)。第一次執行會下載WordNet和官方預訓練模型。由於該預訓練模型是針對cifar10數據集的,因此儘量輸入一張屬於這十類之一的圖片。從輸出中可以看到,預測行爲是“逐級進行”的。

在phthon中預測

完整使用方式

後續計劃

調研NBDT的目的是尋找一種讓分類問題變得可解釋的方法,這種可解釋性可以應用在任何分類過程中需要給出決策路徑的場景。儘管作者在論文中介紹的應用場景是圖片分類,但只要把前面的CNN替換成其他網絡,那麼實際上任何分類問題都可以利用NBDT做出解釋。比如在閒魚優質商品分層項目中,我們可以基於業務知識構造商品間的誘導層級(例如第一層分爲專業賣家/個人賣家、第二層分爲動銷率高/中/低...最後一層分爲商品不同的優質等級等等),然後基於層級結構訓練NBDT做分類。再比如一個典型的圖片分類場景,賣家在閒魚上上傳一張圖片,希望算法能自動判斷出他想賣什麼類別的商品,有可能他上傳了一張“椅子”和一張“桌子”的圖片,但其實他想賣的是“傢俱”。那麼基於層級結構的NBDT就能自動把他發佈的商品識別爲“傢俱”,或者提供推薦的備選項讓用戶自己選擇他想要賣的是哪一層大類,省去了手動填寫的麻煩。這些都是NBDT可以在後續中嘗試的實踐。

參考

  • 論文:https://arxiv.org/abs/2004.00221

  • 源碼:https://github.com/alvinwan/neural-backed-decision-trees

閒魚技術團隊不僅是阿里巴巴集團旗下閒置交易社區的創造者,更是移動與高併發大數據應用新技術的引導者與創新者。我們與 Google Flutter/Dart 小組密切合作,爲社區貢獻了多個高 star 的項目和大量 PR 。我們正在積極探索深度學習和視覺技術在互動、交易、社區場景的創新應用。閒魚技術與集團中間件團隊共同打造的 FaaS 平臺每天支持數以千萬級用戶的高併發訪問場景。  

就是現在! 客戶端/服務端java/架構/前端/質量 工程師 面向社會+校園招聘,base杭州阿里巴巴西溪園區,一起做有創想空間的社區產品、做深度頂級的開源項目,一起拓展技術邊界成就極致!

*投餵簡歷給小閒魚→ [email protected]

開源項目、峯會直擊、關鍵洞察、深度解讀

請認準 閒魚技術

相關文章