歡迎關注“新浪科技”的微信訂閱號:techsina 

編/拉燕 如願 好睏

來源:新智元(ID:AI_era)

【新智元導讀】谷歌Meta之爭看來還沒完!TensorFlow幹不過還有JAX,二番戰能否戰勝PyTorch?

很喜歡有些網友的一句話:

‘這孩子實在不行,咱再要一個吧。’

谷歌還真這麼幹了。

養了七年的TensorFlow終於還是被Meta的PyTorch幹趴下了,在一定程度上。

谷歌眼見不對,趕緊又要了一個——‘JAX’,一款全新的機器學習框架。

最近超級火爆的DALL·E Mini都知道吧,它的模型就是基於JAX進行編程的,從而充分地利用了谷歌TPU帶來的優勢。

TensorFlow的黃昏和PyTorch的崛起

2015年,谷歌開發的機器學習框架——TensorFlow問世。

當時,TensorFlow只是Google Brain的一個小項目。

誰也沒有想到,剛一問世,TensorFlow就變得非常火爆。

優步愛彼迎這種大公司在用,NASA這種國家機構也在用。而且還都是用在他們各自最爲複雜的項目上。

而截止到2020年11月,TensorFlow的下載次數已經達到了1.6億次。

不過,谷歌好像並沒有十分在乎這麼多用戶的感受。

奇奇怪怪的界面和頻繁的更新都讓TensorFlow對用戶越來越不友好,並且越來越難以操作。

甚至,就連谷歌內部,也覺得這個框架在走下坡路。

其實谷歌如此頻繁的更新也實屬無奈,畢竟只有這樣才能追得上機器學習領域快速地迭代。

於是,越來越多的人加入了這個項目,導致整個團隊慢慢失去了重點。

而原本讓TensorFlow成爲首選工具的那些閃光點,也被埋沒在了茫茫多的要素裏,不再受人重視。

這種現象被Insider形容爲一種‘貓鼠遊戲’。公司就像是一隻貓,不斷迭代出現的新需求就像是一隻只老鼠。貓要時刻保持警惕,隨時撲向老鼠。

這種困局對最先打入某一市場的公司來說是避不開的。

舉個例子,就搜索引擎來說,谷歌並不是第一家。所以谷歌能夠從前輩(AltaVista、Yahoo等等)的失敗中總結經驗,應用在自身的發展上。

可惜到了TensorFlow這裏,谷歌是被困住的那一個。

正是因爲上面這些原因,原先給谷歌賣命的開發者,慢慢對老東家失去了信心。

昔日無處不在的TensorFlow漸漸隕落,敗給了Meta的後起之秀——PyTorch。

2017年,PyTorch的測試版開源。

2018年,Facebook的人工智能研究實驗室發佈了PyTorch的完整版本。

值得一提的是,PyTorch和TensorFlow都是基於Python開發的,而Meta則更注重維護開源社區,甚至不惜大量投入資源。

而且,Meta關注到了谷歌的問題所在,認爲不能重蹈覆轍。他們專注於一小部分功能,並把這些功能做到最好。

Meta並沒有步谷歌的後塵。這款首先在Facebook開發出來的框架,慢慢成爲了行業標杆。

一家機器學習初創公司的研究工程師表示,‘我們基本都用PyTorch。它的社羣和開源做得是最出色的。不僅有問必答,給的例子也很實用。’

面對這種局面,谷歌的開發者、硬件專家、雲提供商,以及任何和谷歌機器學習相關的人員在接受採訪時都說了一樣的話,他們認爲TensorFlow失掉了開發者的心。

經歷了一系列的明爭暗鬥,Meta最終佔了上風。

有專家表示,谷歌未來繼續引領機器學習的機會正慢慢流失。

PyTorch逐漸成爲了尋常開發者和研究人員的首選工具。

從Stack Overflow提供的互動數據上看,在開發者論壇上有關PyTorch的提問越來越多,而關於TensorFlow的最近幾年一直處於停滯狀態。

就連文章開始提到的優步等等公司也轉向PyTorch了。

甚至,PyTorch後來的每一次更新,都像是在打TensorFlow的臉。

谷歌機器學習的未來——JAX

就在TensorFlow和PyTorch打得熱火朝天的時候,谷歌內部的一個‘小型黑馬研究團隊’開始致力於開發一個全新的框架,可以更加便捷地利用TPU。

2018年,一篇題爲《Compiling machine learning programs via high-level tracing》的論文,讓JAX項目浮出水面,作者是Roy Frostig、Matthew James Johnson和Chris Leary。

從左至右依次是這三位大神

而後,PyTorch原始作者之一的Adam Paszke,也在2020年初全職加入了JAX團隊。

JAX提供了一個更直接的方法用於處理機器學習中最複雜的問題之一:多核處理器調度問題。

根據所應用的情況,JAX會自動地將若干個芯片組合而成一個小團體,而不是讓一個去單打獨鬥。

如此帶來的好處就是,讓儘可能多的TPU片刻間就能得到響應,從而燃燒我們的‘煉丹小宇宙’。

最終,相比於臃腫的TensorFlow,JAX解決了谷歌內部的一個心頭大患:如何快速訪問TPU。

下面簡單介紹一下構成JAX的Autograd和XLA。

Autograd主要應用於基於梯度的優化,可以自動區分Python和Numpy代碼。

它既可以用來處理Python的一個子集,包括循環、遞歸和閉包,也可以對導數的導數進行求導。

此外,Autograd支持梯度的反向傳播,這也就這意味着它可以有效地獲取標量值函數相對於數組值參數的梯度,以及前向模式微分,並且兩者可以任意組合。

XLA(Accelerated Linear Algebra)可以加速TensorFlow模型而無需更改源代碼。

當一個程序運行時,所有的操作都由執行器單獨執行。每個操作都有一個預編譯的GPU內核實現,執行器會分派到該內核實現。

舉個栗子:

def model_fn(x, y, z): return tf.reduce_sum(x + y * z)

在沒有XLA的情況下運行,該部分會啓動三個內核:一個用於乘法,一個用於加法,一個用於減法。

而XLA可以通過將加法、乘法和減法‘融合’到單個GPU內核中,從而實現優化。

這種融合操作不會將由內存產生的中間值寫入y*z內存x+y*z;相反,它將這些中間計算的結果直接‘流式傳輸’給用戶,同時將它們完全保存在GPU中。

在實踐中,XLA可以實現約7倍的性能改進和約5倍的batch大小改進。

此外,XLA和Autograd可以任意組合,甚至可以利用pmap方法一次使用多個GPU或TPU內核進行編程。

而將JAX與Autograd和Numpy相結合的話,就可以獲得一個面向CPU、GPU和TPU的易於編程且高性能的機器學習系統了。

顯然,谷歌這一次吸取了教訓,除了在自家全面鋪開以外,在推進開源生態的建設方面,也是格外地積極。

2020年DeepMind正式投入JAX的懷抱,而這也宣告了谷歌親自下場,自此之後各種開源的庫層出不窮。

縱觀整場‘明爭暗鬥’,賈揚清表示,在批評TensorFlow的進程中,AI系統認爲Pythonic的科研就是全部需求。

但一方面純Python無法實現高效的軟硬協同設計,另一方面上層分佈式系統依然需要高效的抽象。

而JAX正是在尋找更好的平衡,谷歌這種願意顛覆自己的pragmatism非常值得學習。

causact R軟件包和相關貝葉斯分析教科書的作者表示,自己很高興看到谷歌從TF過渡到JAX,一個更乾淨的解決方案。

谷歌的挑戰

作爲一個新秀,Jax雖然可以借鑑PyTorch和TensorFlow這兩位老前輩的優點,但有的時候後發可能也會帶來劣勢。

首先,JAX還太‘年輕’,作爲實驗性的框架,遠沒有達到一個成熟的谷歌產品的標準。

除了各種隱藏的bug以外,JAX在一些問題上仍然要依賴於其他框架。

拿加載和預處理數據來說,就需要用TensorFlow或PyTorch來處理大部分的設置。

顯然,這和理想的‘一站式’框架還相去甚遠。

其次,JAX主要針對TPU進行了高度的優化,但是到了GPU和CPU上,就要差得多了。

一方面,谷歌在2018年至2021年組織和戰略的混亂,導致在對GPU進行支持上的研發的資金不足,以及對相關問題的處理優先級靠後。

與此同時,大概是過於專注於讓自家的TPU能在AI加速上分得更多的蛋糕,和英偉達的合作自然十分匱乏,更不用說完善對GPU的支持這種細節問題了。

另一方面,谷歌自己的內部研究,不用想肯定都集中在TPU上,這就導致谷歌失去了對GPU使用的良好反饋迴路。

此外,更長的調試時間、並未與Windows兼容、未跟蹤副作用的風險等等,都增加了Jax的使用門檻以及友好程度。

現在,PyTorch已經快6歲了,但完全沒有TensorFlow當年顯現出的頹勢。

如此看來,想要後來者居上的話,Jax還有很長一段路要走。

參考資料:

https://www.businessinsider.com/facebook-pytorch-beat-google-tensorflow-jax-meta-ai-2022-6

相關文章