假設我們有一個臥室圖像數據集和一個在這個數據集上訓練的圖像分類器CNN,它告訴我們給定的輸入圖像是否是臥室。假設圖像大小爲16 * 16。每個像素可以有256個可能的值。所以存在無限大量的可能輸入(即25616*16或~10616可能的組合)。這使得我們的分類器模型成爲一個高維概率分佈函數,它給出了來自這個大輸入空間的給定輸入作爲臥室的概率。

那麼,如果我們可以從臥室圖像的數據分佈中學習這種高維知識來進行分類,我們肯定能夠利用相同的知識甚至生成全新的臥室圖像。

雖然有多種生成建模方法,但我們將在本文中探討生成對抗網絡。GAN的原始論文(arxiv.org/pdf/1406.2661.pdf)發表於2014年,在這篇論文(arxiv.org/pdf/1511.06434.pdf)引入了深度卷積生成對抗網絡(deep tional generate adversarial networks, DCGAN),並作爲一種流行的參考。這篇文章是基於這兩篇論文的研究,對GAN做了很好的介紹。

GAN是同時訓練生成模型G和判別模型D的網絡。生成模型將通過捕獲與訓練數據集相關的數據分佈來生成新的臥室圖像。訓練判別模型將給定的輸入圖像正確分類爲真實(即來自訓練數據集圖像)或虛假(即生成模型生成的合成圖像)。簡單的說,判別模型是典型的CNN圖像分類器模型,或者更具體地說是二元圖像分類器。

生成模型與判別模型略有不同。它的目標不是分類而是生成。當判別模型給出一個代表不同的類的激活向量時給出一個輸入圖像,生成模型就會反向執行。

瞭解GAN背後的設計,訓練,損失函數和算法

生成與判別模型

它可以被認爲是反向CNN,在某種意義上它將隨機數的向量作爲輸入並生成圖像作爲輸出,而正常的CNN則相反地將圖像作爲輸入並生成數字向量或激活(對應於不同的類)作爲輸出。

但是這些不同的模型如何協同工作?下圖給出了網絡的圖示。首先,我們將隨機噪聲向量作爲生成模型的輸入,生成模型生成圖像輸出。我們將這些生成的圖像稱爲僞圖像或合成圖像。然後判別模型將訓練數據集中的假圖像和真圖像都作爲輸入,並生成一個輸出來分類圖像是假圖像還是真圖像。

瞭解GAN背後的設計,訓練,損失函數和算法

生成敵對網絡的說明

使用這兩個模型對該網絡的參數進行訓練和優化。判別模型的目標是最大限度地正確分類圖像的真僞。相反,生成模型的目標是最小化判別器正確地將假圖像分類爲假的。

反向傳播和普通卷積神經網絡(CNN)一樣,是用來訓練網絡參數的,但是由於涉及到兩個目標不同的模型,使得反向傳播的應用有所不同。更具體地說,涉及的損失函數和在每個模型上執行的迭代次數是GAN不同的兩個關鍵領域。

判別模型的損失函數將是一個與二元分類器相關的正則交叉熵損失函數。根據輸入圖像,損失函數中的一項或另一項將爲0,結果將是模型預測圖像被正確分類的概率的負對數。換句話說,在我們的上下文中,對於真實圖像,“y”將爲“1”,對於假圖像,“1-y”將爲“1”。“p”是圖像是真實圖像的預測概率,“1-p”是圖像是假圖像的預測概率。

瞭解GAN背後的設計,訓練,損失函數和算法

二元分類器的交叉熵損失

上面的概率p可以表示爲D(x),即判別器D估計的圖像“x”是真實圖像的概率。重寫,如下圖所示:

瞭解GAN背後的設計,訓練,損失函數和算法

根據我們如何分配上下文,方程的第一部分將被激活,第二部分對於真實圖像將爲零。反之亦然。第二部分中圖像“x”的表示可以用“G(z)”代替。也就是說,在給定輸入z的情況下,將假圖像表示爲模型G的輸出。“z”只不過是建模“G”產生“G(z)”的隨機噪聲輸入向量。這些符號在初看時令人困惑,但論文中的算法通過“ascending”其隨機梯度來更新判別器,這與上文所述的最小化損失函數相同。下面是論文中函數的快照:

瞭解GAN背後的設計,訓練,損失函數和算法

回到生成函數G,G的損失函數將反過來,即最大化D的損失函數。但是等式的第一部分對生成器沒有任何意義,所以我們真正說的是第二部分應該最大化。所以G的損失函數與D的損失函數相同,只是符號顛倒了,第一項被忽略了。

瞭解GAN背後的設計,訓練,損失函數和算法

生成器的損失函數

以下是論文中生成器損失函數的快照:

瞭解GAN背後的設計,訓練,損失函數和算法

正如DCGAN內容所示,這是通過重塑和轉置卷積的組合來實現的。以下是生成器的表示:

瞭解GAN背後的設計,訓練,損失函數和算法

DCGAN生成器

轉置卷積與卷積的逆不同,它不能恢復給定卷積輸出的輸入,只是改變了卷積的形狀。下面的例子說明了上述生成器模型背後的數學原理,特別是卷積層。

瞭解GAN背後的設計,訓練,損失函數和算法

在CNN中使用的常規卷積圖示,以及通過轉置卷積實現的上採樣的兩個示例。第一示例的結果用作第二和第三示例中具有相同內核的輸入,以證明轉換與反捲積不相同,並不是爲了恢復原始輸入。

論文算法的內部for循環。這意味着,對於k> 1,我們在G的每次迭代中對判別器D執行多次訓練迭代。這是爲了確保D'被充分訓練並且比G更早地學習。我們需要一個好的D來欺騙G。

瞭解GAN背後的設計,訓練,損失函數和算法

另一個相關的重點是生成器可能記憶輸入示例的問題,DCGAN通過使用3072-128-3072、降噪、dropout、正則化、RELU和自動編碼器來解決,基本上是減少和重構機制,以最小化記憶。

DCGAN中還重點介紹了生成器在操作時如何忘記它正在生成的臥室圖像中的某些對象。他們通過從第二層卷積層特徵集中刪除對應窗口的特徵映射來實現這一點,並展示了網絡是如何用其他對象替換窗口空間的。

相關文章