摘要:FixMatch借鑑了UDA和ReMixMatch的這一思想,應用不同的增強方法,即在未標註的圖像上進行弱增強以生成僞標籤,同時在未標註圖像上進行強增強以進行預測。如圖所示,我們使用交叉熵損失在標註的圖像上訓練了監督模型。

        <div> 

  磐創AI分享   

來源 | AI公園 作者 | amitness 編譯 | ronghuaiyang 【導讀】 僅使用10張帶有標籤的圖像,它在CIFAR-10上的中位精度爲78%,最大精度爲84%,來看看是怎麼做到的。

深度學習在計算機視覺領域展示了非常有前途的結果。但是當將它應用於實際的醫學成像等領域的時候,標籤數據的缺乏是一個主要的挑戰。

在實際環境中,對數據做標註是一個耗時和昂貴的過程。你有很多的圖片,由於資源約束,只有一小部分人可以進行標註。在這樣的情況下,我們如何利用大量未標註的圖像以及部分已標註的圖像來提高我們的模型的性能?答案是semi-supervised學習。

FixMatch是Google Brain的Sohn等人最近開發的一種半監督方法,它改善了半監督學習(SSL)的技術水平。它是對之前的方法(例如UDA和ReMixMatch)的簡單組合。在本文中,我們將瞭解FixMatch的概念,並看到僅使用10張帶有標籤的圖像,它在CIFAR-10上的中位精度爲78%,最大精度爲84%。

FixMatch背後的直覺

假設我們正在對貓與狗進行分類,但是我們的標籤數據有限,並且有很多未標籤的貓狗圖像。

我們通常的“監督學習”方法將是僅在標註圖像上訓練分類器,而忽略未標註的圖像。

除了忽略未標註的圖像,我們還可以應用以下方法。我們知道模型也應該能夠處理圖像的擾動,從而提高泛化能力。

如果我們對未標註的圖像進行圖像增強,並讓監督模型預測這些圖像會怎麼樣?由於是同一張圖片,因此兩者的預測的標籤應該相同。

因此,即使不知道其正確的標籤,我們也可以將未標註的圖像用作訓練流水線的一部分。這是FixMatch及其之前的許多論文背後的核心思想。

FixMatch的Pipeline

憑直覺,讓我們看看如何在實踐中實際應用FixMatch。下圖總結了整個pipeline:

如圖所示,我們使用交叉熵損失在標註的圖像上訓練了監督模型。對於每個未標註的圖像,使用弱增強和強增強獲得兩個圖像。弱增強圖像被傳遞到我們的模型中,我們得到了關於類的預測。將最有信心的類別的概率與閾值進行比較。如果它高於閾值,那麼我們將該類作爲ground truth的標籤,即僞標籤。然後,將經過強增強的圖像傳遞到我們的模型中,獲取類別的預測。使用交叉熵損失將此概率分佈與ground truth僞標籤進行比較。兩種損失組合起來進行模型的更新。

Pipeline的組件

1. 訓練數據和增強

FixMatch借鑑了UDA和ReMixMatch的這一思想,應用不同的增強方法,即在未標註的圖像上進行弱增強以生成僞標籤,同時在未標註圖像上進行強增強以進行預測。

a. 弱增強

對於弱增強,本文使用標準的翻轉和平移策略。它包括兩個簡單的增強:

  • Random Horizontal Flip

應用此增強的概率爲50%。對於SVHN數據集,將跳過此步驟,因爲那些圖像包含與水平翻轉無關的數字。在PyTorch中,可以使用transforms執行以下操作:

from PIL import Image
import torchvision.transforms as transforms

im = Image.open('dog.png')
weak_im = transforms.RandomHorizontalFlip(p=0.5)(im)
  • 隨機水平和垂直移動

    12.5%,在PyTorch中,可以使用以下代碼來實現,其中32是圖像的大小:

import torchvision.transforms as transforms
from PIL import Image

im = Image.open('dog.png')
resized_im = transforms.Resize(32)(im)
translated = transforms.RandomCrop(size=32, 
                                   padding=int(32*0.125), 
                                   padding_mode='reflect')(resized_im)

b. 強增強

其中包括輸出嚴重失真的輸入圖像的增強版本。FixMatch應用RandAugment或CTAugment,然後應用CutOut增強。

1. Cutout

這種增強會隨機刪除圖像的正方形部分,並用灰色或黑色填充。PyTorch沒有內置的Cutout函數,但是我們可以重用其 RandomErasing 函數來達到CutOut的效果。

import torch
import torchvision.transforms as transforms

# Image of 520*520
im = torch.rand(3, 520, 520)

# Fill cutout with gray color
gray_code = 127

# ratio=(1, 1) to set aspect ratio of square
# p=1 means probability is 1, so always apply cutout
# scale=(0.01, 0.01) means we want to get cutout of 1% of image area
# Hence: Cuts out gray square of 52*52
cutout_im = transforms.RandomErasing(p=1, 
                                     ratio=(1, 1), 
                                     scale=(0.01, 0.01), 
                                     value=gray_code)(im)

2. AutoAugment的變體

以前的SSL使用的是 AutoAugment ,這個工具訓練了一個強化學習算法來尋找讓代理任務(例如CIFAR-10)得到最佳準確性的增強方法。這是有問題的,因爲我們需要一些標註的數據集來學習增強,並且還受到使用RL的資源限制。

因此,FixMatch使用了AutoAugment的兩個變體之一:

a. RandAugment

Random Augmentation(RandAugment) 思想是非常簡單的。

  • 首先,你有一個14種可能的增強的列表,以及一系列可能的幅度。

  • 你從這個列表裏隨機選出N個增強,這裏我們從列表裏選出兩種。

  • 然後我們選擇一個隨機的幅度M,從1到10。我們可以選擇一個幅度5,這意味着以百分比表示的幅度爲50%,因爲最大可能的M爲10,所以百分比= 5/10 = 50%。

  • 現在,將所選的增強應用於序列中的圖像。每種增強都有50%的可能性被應用。

  • N和M的值可以通過在驗證集上使用網格搜索的超參數優化來找到。在本文中,在每個訓練步驟使用預定義範圍內的隨機幅度,而不是固定幅度。

b. CTAugment

CTAugment是ReMixMatch論文中引入的一種增強技術,它使用控制理論中的思想來消除對AutoAugment中增強學習的需求。運作方式如下:

  • 我們有一組18種可能的變換,類似於RandAugment

  • 變換的幅度值被劃分爲bin,每個bin被分配一個權重。最初,所有bin的權重均爲1。

  • 現在從該集合中以相等的概率隨機選擇兩個變換,它們的序列形成了一條管道。這類似於RandAugment。

  • 對於每個變換,根據歸一化的bin權重隨機選擇一個幅值bin

  • 現在,帶有標記的樣本通過這兩個轉換得到了增強,並傳遞給模型以進行預測

  • 根據模型預測值與實際標籤的接近程度,更新這些變換的bin權重。

  • 因此,它學會選擇具有較高的機會來預測正確的標籤的模型,從而在網絡容差範圍內進行增強。

因此,我們看到,與RandAugment不同,CTAugment可以在訓練過程中動態學習每個變換的幅度。因此,我們無需在某些受監督的代理任務上對其進行優化,並且它沒有敏感的超參數可優化。因此,這非常適合缺少標籤數據的半監督環境。

2. 模型結構

本文使用稱爲Wide Residual Networks的ResNet的更廣和更淺的變體作爲基礎體系結構。使用的確切變體是Wide-Resnet-28-2,深度爲28,擴展因子爲2。因此,此模型的寬度是ResNet的兩倍。它總共有150萬個參數。該模型與輸出層堆疊在一起,輸出層的節點等於所需的類數(例如,用於貓/狗分類的2個類)。

3. 模型訓練和損失函數

  • 步驟1: 準備batches

我們準備了批大小爲B的標記圖像和批大小爲μB的未標記圖像。μ是一個超參數,它決定批中未標註圖像的相對大小。例如,μ=2意味着我們使用的未標註圖像數量是標註圖像的兩倍。

該論文嘗試增加μ的值,發現隨着我們增加未標註圖像的數量,錯誤率會降低。本文將μ=7用於驗證數據集。

  • 步驟2: 監督學習

對於在標註圖像上訓練的監督部分,我們將常規的交叉熵損失H()用於分類任務。batch的總損失由定義,並通過取batch中每個圖像的交叉熵損失的平均值來計算。

  • 步驟3: 僞標籤

對於未標註的圖像,首先我們對未標註的圖像應用弱增強,並通過argmax獲得最高概率的預測類別。這就是僞標籤,將與強增強圖像上的模型輸出進行比較。

  • 步驟4: 一致性正則化

現在,同一張未標註的圖片進行了強增強,並將其輸出與我們的僞標籤進行比較以計算交叉熵損失H()。未標註的總batch損失由表示,並由下式給出:

這裏τ表示閾值,在該閾值之上我們採用僞標籤。該損失類似於僞標籤損失。不同之處在於,我們使用弱增強來生成標籤,而強增強則用於損失。

  • 步驟5: 課程學習

    最後,我們將這兩個損失相結合,以獲得總損失,我們可以對其進行優化以改進模型。是一個固定的標量超參數,它決定了未標註圖像損失相對於標註損失的貢獻量。

有趣的結果來自。以前的工作表明,在訓練過程中增加權重是很好的。但是,在FixMatch中,這是自動內置的。由於最初,該模型對標註的數據沒有把握,因此其對未標註的數據的輸出預測將低於閾值。這樣,將僅在標註的數據上訓練模型。但是隨着訓練的進行,模型對標註的數據變得更加自信,因此,對未標註數據的預測也將開始超過閾值。這樣,損失將很快也開始包含對未標註圖像的預測。這爲我們提供了一種免費的課程學習形式。

從直覺上講,這類似於我們在兒童時期的教學方式。在早期,我們先學習一些簡單的概念,例如字母表及其代表的含義,然後再繼續學習複雜的主題,例如構詞,句子和文章。

論文中的直覺

1. 我們可以每個類別只學習一張圖片嗎?

作者對CIFAR-10數據集進行了一個非常有趣的實驗。他們僅使用10個標註圖像(即每個類別1個標註樣本)在CIFAR-10上訓練了模型。

  • 他們通過從數據集中從每個類中隨機選擇一個樣本來創建4個數據集,並對每個數據集進行4次訓練。他們達到了48.58%至85.32%的測試準確度,中位準確率爲64.28%。準確性的這種差異是由於標註樣本的質量所致。當模型提供低質量的樣本時,很難有效地學習每個類別。

爲了測試這一點,他們創建了8個訓練數據集,其樣本範圍從最有代表性的到最不代表性的。他們遵循了本文的順序並分爲8個bin。第一個bin包含最具代表性的圖像,而最後一個bin包含離羣值。然後,他們從每個bin中隨機抽取每個類別的一個樣本,以創建8個標註過的訓練集並訓練FixMatch模型。結果是:

  • 最具代表性的bin :中位數精度爲78%,最大精度爲84%

  • 中等代表性的bin :準確度達65%

  • 離羣值 :僅僅有10%的精度無法完全收斂

評估和結果

作者對常用的SSL的數據集(例如CIFAR-10,CIFAR-100,SVHN,STL-10和ImageNet)進行了評估。

  • CIFAR-10和SVHN:

FixMatch在CIFAR-10和SVHN基準測試中獲得了state of the art的結果。

  • CIFAR-100:

在CIFAR-100上,ReMixMatch優於FixMatch。爲了理解原因,作者從ReMixMatch中借用了各種組件到FixMatch上,並測量了它們對性能的影響。他們發現,*Distribution Alignment(DA)*組件促使模型以相同的概率預測所有類,這就是原因。因此,當他們將FixMatch與DA結合使用時,他們實現了40.14%的錯誤率,而ReMixMatch的錯誤率爲44.28%。

  • STL-10:

STL-10數據集由100,000個未標註圖像和5000個標註圖像組成。我們需要預測10類(飛機,鳥,汽車,貓,鹿,狗,馬,猴子,船,卡車)。它是半監督學習的更具代表性的評估方法,因爲其未標註的集合具有分佈以外的圖像。在所有方法中,對1000張帶標籤的圖像進行5折評估時,FixMatch的CTAugment可以實現最低的錯誤率。

  • ImageNet:

    作者還評估了ImageNet上的模型,以驗證其是否適用於大型和複雜的數據集。他們將訓練數據的10%作爲標記的圖像,其餘的90%作爲未標記的圖像。同樣,所使用的體系結構是ResNet-50而不是WideResNet,並且RandAugment被用作增強。他們的top-1錯誤率達到28.54%±0.52,比UDA高2.68%。top5的錯誤率是10.87%±0.28%。

代碼實現

官方實現,TensorFlow:https://github.com/google-research/fixmatch。

PyTorch的的非官方實現:

1、https://github.com/kekmodel/FixMatch-pytorch

2、https://github.com/CoinCheung/fixmatch

3、https://github.com/valencebond/FixMatch_pytorch

- End -

✄------------------------------------------------ 看到這裏,說明你喜歡這篇文章,請點擊「 在看 」或順手「 轉發 」「 點贊 」。 歡迎微信搜索「 panchuangxx 」,添加小編磐小小仙微信,每日朋友圈更新一篇高質量推文(無廣告),爲您提供更多精彩內容。 ▼       掃描二維碼添加小編   ▼    ▼  

相關文章