PyTorch Lightning

白交 發自 凹非寺

量子位 報道 | 公衆號 QbitAI

一直以來,PyTorch就以 簡單又好用 的特點,廣受AI研究者的喜愛。

但是,一旦任務複雜化,就可能會發生一系列錯誤,花費的時間更長。

於是,就誕生了這樣一個“友好”的PyTorch Lightning。

直接在GitHub上斬獲6.6k星。

首先,它把研究代碼與工程代碼相分離,還將PyTorch代碼結構化,更加直觀的展現數據操作過程。

這樣,更加易於理解,不易出錯,本來很冗長的代碼一下子就變得輕便了,對AI研究者十分的友好。

話不多說,我們就來看看這個輕量版的“PyTorch”。

關於Lightning

Lightning將DL/ML代碼分爲三種類型:研究代碼、工程代碼、非必要代碼。

針對不同的代碼,Lightning有不同的處理方式。

這裏的研究代碼指的是特定系統及其訓練方式,比如GAN、VAE,這類的代碼將由LightningModule直接抽象出來。

我們以MNIST生成爲例。

l1 = nn.Linear(...)
l2 = nn.Linear(...)
decoder = Decoder()

x1 = l1(x)
x2 = l2(x2)
out = decoder(features, x)

loss = perceptual_loss(x1, x2, x) + CE(out, x)

而工程代碼是與培訓此係統相關的所有代碼,比如提前停止、通過GPU分配、16位精度等。

我們知道,這些代碼在大多數項目中都相同,所以在這裏,直接由Trainer抽象出來。

model.cuda(0)
x = x.cuda(0)

distributed = DistributedParallel(model)

with gpu_zero:
download_data()

dist.barrier()

剩下的就是非必要代碼,有助於研究項目,但是與研究項目無關,可能是檢查梯度、記錄到張量板。此代碼由Callbacks抽象出來。

# log samples
z = Q.rsample()
generated = decoder(z)
self.experiment.log('images', generated)

此外,它還有一些的附加功能,比如你可以在CPU,GPU,多個GPU或TPU上訓練模型,而無需更改PyTorch代碼的一行;你可以進行16位精度訓練,可以使用Tensorboard的五種方式進行記錄。

這樣說,可能不太明顯,我們就來直觀的比較一下PyTorch與PyTorch Lightning之間的差別吧。

PyTorch與PyTorch Lightning比較

直接上圖。

我們就以構建一個簡單的MNIST分類器爲例,從模型、數據、損失函數、優化這四個關鍵部分入手。

模型

首先是構建模型,本次設計一個3層全連接神經網絡,以28×28的圖像作爲輸入,將其轉換爲數字0-9的10類的概率分佈。

兩者的代碼完全相同。意味着,若是要將PyTorch模型轉換爲PyTorch Lightning,我們只需將nn.Module替換爲pl.LightningModule

也許這時候,你還看不出這個Lightning的神奇之處。不着急,我們接着看。

數據

接下來是數據的準備部分,代碼也是完全相同的,只不過Lightning做了這樣的處理。

它將PyTorch代碼組織成了4個函數,prepare_data、train_dataloader、val_dataloader、test_dataloader

prepare_data

這個功能可以確保在你使用多個GPU的時候,不會下載多個數據集或者對數據進行多重操作。這樣所有代碼都確保關鍵部分只從一個GPU調用。

這樣就解決了PyTorch老是重複處理數據的問題,這樣速度也就提上來了。

train_dataloader, val_dataloader, test_dataloader

每一個都負責返回相應的數據分割,這樣就能很清楚的知道數據是如何被操作的,在以往的教程裏,都幾乎看不到它們的是如何操作數據的。

此外,Lightning還允許使用多個dataloaders來測試或驗證。

優化

接着就是優化。

不同的是,Lightning被組織到配置優化器的功能中。如果你想要使用多個優化器,則可同時返回兩者。

損失函數

對於n項分類,我們要計算交叉熵損失。兩者的代碼是完全一樣的。

此外,還有更爲直觀的——驗證和訓練循環。

在PyTorch中,我們知道,需要你自己去構建for循環,可能簡單的項目還好,但是一遇到更加複雜高級的項目就很容易翻車了。

而Lightning裏這些抽象化的代碼,其背後就是由Lightning裏強大的trainer團隊負責了。

PyTorch Lightning安裝教程

看到這裏,是不是也想安裝下來試一試。

PyTorch Lightning安裝十分簡單。

代碼如下:

conda activate my_env
pip install pytorch-lightning

或在沒有conda環境的情況下,可以在任何地方使用pip。

代碼如下:

pip install pytorch-lightning

創建者也有大來頭

William Falcon,PyTorch Lightning 的創建者,現在在紐約大學的人工智能專業攻讀博士學位,在《福布斯》擔任AI特約作者。

2018年,從哥倫比亞大學計算機科學與統計學專業畢業,本科期間,他還曾輔修數學。

現在已獲得Google Deepmind資助攻讀博士學位的獎學金,去年還收到Facebook AI Research實習邀請。

此外,他還曾是一個海軍軍官,接受過美國海軍海豹突擊隊的訓練。

前不久, 華爾街日報 就曾還曾提到這個團隊,他們正在研究呼吸系統疾病與呼吸模式之間的聯繫。可能會應用到的場景,是通過電話在診斷新冠症狀。目前,該團隊還處在數據收集階段。

果然,優秀的人,幹什麼都是優秀的。嘆氣……

怎麼樣,是不是想試一試?趕緊戳下方鏈接下載來看看吧!

上手傳送門

https://github.com/PyTorchLightning/pytorch-lightning

https://pytorch-lightning.readthedocs.io/en/latest/index.html

創建者個人網站:

https://www.williamfalcon.com/

版權所有,未經授權不得以任何形式轉載及使用,違者必究。

相關文章