HRNet: An All-Round Backbone That Bridges Multiple Vision Tasks
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8"/><div>
<div>
<h2>
<img src="https://img1.tuicool.com/vee6RnZ.jpg!web" class="alignCenter" referrerpolicy="no-referrer"/>
</h2>
Source | GiantPandaCV
Editor | pprp
<h4>
<span>
<strong>[Overview]</strong>
</span>
<span>HRNet, built by the team led by Jingdong Wang at Microsoft Research Asia, is a network structure that spans image classification, semantic segmentation, object detection, facial landmark detection, pose estimation, style transfer, image inpainting, super-resolution, optical flow, depth estimation, edge detection, and more.</span>
</h4>
Jingdong Wang presented HRNet himself in the Valse Webinar talk "Object and Keypoint Detection," and his explanation was very thorough. The following article draws mainly on his interpretation in that talk, combined with the paper and the code, to introduce this all-round backbone, HRNet.
Introduction
<figure>
<img src="https://img0.tuicool.com/feURFb2.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
<figcaption>The design rationale behind the network structure</figcaption>
</figure>
In tasks such as human pose estimation, a high-resolution heatmap is needed for keypoint detection. This requirement differs from that of typical networks such as VGGNet, whose final feature maps have very low resolution and have therefore lost spatial structure.
<figure>
<img src="https://img1.tuicool.com/rayemqu.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
<figcaption>The conventional solution</figcaption>
</figure>
As the figure above shows, high resolution is usually recovered by first reducing the resolution and then raising it again. U-Net, SegNet, DeconvNet, and Hourglass are all essentially this structure.
<figure>
<img src="https://img2.tuicool.com/i6Z3Qv3.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
<figcaption>They look different, but are essentially the same</figcaption>
</figure>
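The down-then-up pattern shared by these networks can be sketched minimally in PyTorch (an illustration of the general idea only, not any particular architecture; all layer sizes here are my own choices):

```python
import torch
import torch.nn as nn

# Minimal "reduce resolution first, then recover it" structure: the pattern
# shared by U-Net / SegNet / Hourglass style networks (illustrative only).
encoder = nn.Sequential(
    nn.Conv2d(3, 16, 3, stride=2, padding=1),   # 1/2 resolution
    nn.ReLU(inplace=True),
    nn.Conv2d(16, 32, 3, stride=2, padding=1),  # 1/4 resolution
    nn.ReLU(inplace=True),
)
decoder = nn.Sequential(
    nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1),  # back to 1/2
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(16, 1, 4, stride=2, padding=1),   # full-resolution heatmap
)

x = torch.randn(1, 3, 64, 64)
heatmap = decoder(encoder(x))
print(heatmap.shape)  # (1, 1, 64, 64)
```

The spatial detail discarded inside the encoder is exactly what HRNet tries to avoid losing in the first place.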
<h4>核心</h4>
Ordinary networks share this structure: the different resolutions are connected in series.
<figure>
<img src="https://img1.tuicool.com/EfiABju.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
<figcaption>Progressively reducing resolution</figcaption>
</figure>
Jingdong Wang's idea is instead to connect feature maps of different resolutions in parallel:
<figure>
<img src="https://img0.tuicool.com/JRv6Jr2.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
<figcaption>Feature maps of different resolutions connected in parallel</figcaption>
</figure>
On top of the parallel layout, interactions (fusion) are added between feature maps of different resolutions.
<figure>
<img src="https://img2.tuicool.com/j2IZZzq.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
</figure>
The specific fusion method is shown in the figure below:
<figure>
<img src="https://img0.tuicool.com/Qv6VZjE.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
</figure>
<ul>
<li>
<p class="no-text-indent">同分辨率的層直接複製。</p>
</li>
<li>
<p class="no-text-indent">需要升分辨率的使用bilinear upsample + 1x1卷積將channel數統一。</p>
</li>
<li>
<p class="no-text-indent">需要降分辨率的使用strided 3x3 卷積。</p>
</li>
<li>
<p class="no-text-indent">三個feature map融合的方式是相加。</p>
</li>
</ul>
As for why strided 3x3 convolutions are used: downsampling inevitably loses information, and a strided 3x3 convolution reduces that loss in a learnable way. This is why neither max pooling nor pooling combinations are used here.
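These fusion rules can be sketched as follows (my own simplified PyTorch illustration, not the repository's exact code; the helper names are invented):

```python
import torch
import torch.nn as nn

def make_upsample_path(in_ch, out_ch, scale):
    # A 1x1 conv matches the channel count, then the map is upsampled
    # (bilinear here, following the fusion rules listed above).
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.Upsample(scale_factor=scale, mode='bilinear', align_corners=False),
    )

def make_downsample_path(in_ch, out_ch):
    # A strided 3x3 conv halves the resolution while learning what to keep,
    # which is why no max pooling appears in this path.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
    )

high = torch.randn(1, 32, 64, 64)  # high-resolution branch
low = torch.randn(1, 64, 32, 32)   # low-resolution branch

# Same-resolution maps pass through unchanged; the others are resized,
# and everything is fused by element-wise addition.
fused_high = high + make_upsample_path(64, 32, 2)(low)
fused_low = low + make_downsample_path(32, 64)(high)
print(fused_high.shape, fused_low.shape)
```

Each branch thus receives contributions from every other resolution at every fusion point.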
<figure>
<img src="https://img2.tuicool.com/buINryY.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
<figcaption>HRNet schematic</figcaption>
</figure>
Another question comes up when reading HRNet: with four branches, how exactly should they be used? The paper presents several options for the final feature selection.
<figure>
<img src="https://img1.tuicool.com/EJ3U7vb.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
<figcaption>Three feature fusion methods</figcaption>
</figure>
(a) shows the feature selection of HRNetV1: only the highest-resolution feature map is used.
(b) shows the feature selection of HRNetV2: the feature maps of all resolutions are concatenated (the smaller ones are upsampled first); it is mainly used for semantic segmentation and facial landmark detection.
(c) shows the feature selection of HRNetV2p: a feature pyramid is built on top of HRNetV2; it is mainly used for object detection networks.
One more variant, (d), is shown below:
<figure>
<img src="https://img2.tuicool.com/a6N3eyq.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
<figcaption>Feature selection for the HRNetV2 classification network</figcaption>
</figure>
(d) also shows HRNetV2; this fusion scheme is mainly used for training classification networks.
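The HRNetV2 head of variant (b) can be sketched as follows (my own simplified code; the branch shapes assume an HRNet-W32-style backbone and are only illustrative):

```python
import torch
import torch.nn.functional as F

# Hypothetical branch outputs of an HRNet-W32-style backbone.
branches = [
    torch.randn(1, 32, 64, 64),
    torch.randn(1, 64, 32, 32),
    torch.randn(1, 128, 16, 16),
    torch.randn(1, 256, 8, 8),
]

def hrnetv2_features(feats):
    # Upsample every map to the highest resolution, then concatenate
    # along the channel dimension.
    target = feats[0].shape[-2:]
    ups = [feats[0]] + [
        F.interpolate(f, size=target, mode='bilinear', align_corners=False)
        for f in feats[1:]
    ]
    return torch.cat(ups, dim=1)

out = hrnetv2_features(branches)  # 32 + 64 + 128 + 256 = 480 channels
print(out.shape)
```

HRNetV2p then builds its feature pyramid by downsampling this concatenated representation to multiple scales.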
To summarize, HRNet's innovations are:
<ul>
<li>
<p class="no-text-indent">The links between high and low resolutions are changed from serial to parallel.</p>
</li>
<li>
<p class="no-text-indent">A high-resolution representation (the topmost path) is maintained throughout the entire network.</p>
</li>
<li>
<p class="no-text-indent">Interactions between high and low resolutions are introduced to improve model performance.</p>
</li>
</ul>
<h4>Results</h4>
Ablation studies
<ol>
<li>
<p class="no-text-indent">對交互方法進行消融實驗,證明了當前跨分辨率的融合的有效性。</p>
</li>
</ol>
<figure>
<img src="https://img2.tuicool.com/Y3QvA3R.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
<figcaption>Ablation on the interaction method</figcaption>
</figure>
<ol start="2">
<li>
<p class="no-text-indent">證明高分辨率feature map的表徵能力</p>
</li>
</ol>
<figure>
<img src="https://img2.tuicool.com/7baMviM.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
</figure>
1x means no downsampling, 2x means the resolution is halved, and 4x means it is reduced to a quarter. The 32 and 48 in W32 and W48 denote the width of the convolutions, i.e., the channel count.
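In other words, W32 means the highest-resolution branch has 32 channels, and each additional (lower-resolution) branch doubles that width, which can be written out as:

```python
def branch_widths(width, num_branches=4):
    # Each extra (lower-resolution) branch doubles the channel count.
    return [width * 2 ** i for i in range(num_branches)]

print(branch_widths(32))  # [32, 64, 128, 256]
print(branch_widths(48))  # [48, 96, 192, 384]
```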
<h4>Performance on pose estimation</h4>
<figure>
<img src="https://img2.tuicool.com/iamIz2U.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
</figure>
The pose estimation results above use the top-down approach.
<figure>
<img src="https://img2.tuicool.com/Qb2EfqB.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
<figcaption>Results on the COCO validation set</figcaption>
</figure>
With no increase in parameters or computation, it performs considerably better than comparable networks.
<figure>
<img src="https://img0.tuicool.com/JR3UbiZ.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
<figcaption>Results on the COCO test set</figcaption>
</figure>
On the PoseTrack leaderboard as of February 28, 2019, HRNet held first place in two of the tracks.
<figure>
<img src="https://img0.tuicool.com/Nf26vyZ.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
<figcaption>PoseTrack Leaderboard</figcaption>
</figure>
<div>
<h4>Performance on semantic segmentation</h4>
<p style="text-align:center">
<img src="https://img0.tuicool.com/RZFJvuJ.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
</p>
</div>
<figure>
<figcaption>Comparison on the Cityscapes validation set</figcaption>
</figure>
<figure>
<img src="https://img1.tuicool.com/vUBv2qm.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
<figcaption>Comparison on the Cityscapes test set</figcaption>
</figure>
<h4>Performance on object detection</h4>
<figure>
<img src="https://img1.tuicool.com/JBNf2un.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
<figcaption>Single-model, single-scale comparison</figcaption>
</figure>
<figure>
<img src="https://img0.tuicool.com/e2a26fF.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
<figcaption>Results with Mask R-CNN</figcaption>
</figure>
<h4>Performance on classification</h4>
<figure>
<img src="https://img0.tuicool.com/nINfEry.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
</figure>
P.S.: Jingdong Wang mentioned in this part that segmentation networks also need a classification-pretrained model; without one, results drop by a few points.
<figure>
<img src="https://img2.tuicool.com/FN3Y3mV.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
<figcaption>Comparison with ResNet on image classification</figcaption>
</figure>
The figure above compares HRNet with ResNet; models shown in the same color have roughly the same parameter count. With comparable or even fewer parameters, HRNet achieves better results than ResNet.
The HRNet project ( https://github.com/HRNet ) represents a substantial amount of work: six repositories cover semantic segmentation, human pose estimation, object detection, image classification, facial landmark detection, Mask R-CNN, and more. The full set is shown below:
<figure>
<img src="https://img0.tuicool.com/ru6nErZ.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
</figure>
The author is particularly interested in how the HRNet code is structured, so the HRNet-Image-Classification repository is used below as an example to walk through this part of the code.
Let's start with the simplest part: BasicBlock.
<figure>
<img src="https://img0.tuicool.com/AB3amym.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
<figcaption>BasicBlock structure</figcaption>
</figure>
<pre class="prettyprint"><span style="line-height: 1.6 !important;"><span style="line-height:1.6 !important;">def</span> <span style="line-height:1.6 !important;">conv3x3</span><span style="line-height: 1.6 !important;">(in_planes, out_planes, stride=<span style="line-height:1.6 !important;">1</span>)</span>:</span><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">"""3x3 convolution with padding"""</span><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">return</span> nn.Conv2d(in_planes, out_planes, kernel_size=<span style="line-height:1.6 !important;">3</span>, stride=stride,<br style="line-height: 1.6 !important;"/> padding=<span style="line-height:1.6 !important;">1</span>, bias=<span style="line-height:1.6 !important;">False</span>)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/><span style="line-height: 1.6 !important;"><span style="line-height:1.6 !important;">class</span> <span style="line-height:1.6 !important;">BasicBlock</span><span style="line-height: 1.6 !important;">(nn.Module)</span>:</span><br style="line-height: 1.6 !important;"/> expansion = <span style="line-height:1.6 !important;">1</span><br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height: 1.6 !important;"><span style="line-height:1.6 !important;">def</span> <span style="line-height:1.6 !important;">__init__</span><span style="line-height: 1.6 !important;">(self, inplanes, planes, stride=<span style="line-height:1.6 !important;">1</span>, downsample=None)</span>:</span><br style="line-height: 1.6 !important;"/> super(BasicBlock, self).__init__()<br style="line-height: 1.6 !important;"/> self.conv1 = conv3x3(inplanes, planes, stride)<br style="line-height: 1.6 !important;"/> self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)<br style="line-height: 1.6 !important;"/> self.relu = 
nn.ReLU(inplace=<span style="line-height:1.6 !important;">True</span>)<br style="line-height: 1.6 !important;"/> self.conv2 = conv3x3(planes, planes)<br style="line-height: 1.6 !important;"/> self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)<br style="line-height: 1.6 !important;"/> self.downsample = downsample<br style="line-height: 1.6 !important;"/> self.stride = stride<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height: 1.6 !important;"><span style="line-height:1.6 !important;">def</span> <span style="line-height:1.6 !important;">forward</span><span style="line-height: 1.6 !important;">(self, x)</span>:</span><br style="line-height: 1.6 !important;"/> residual = x<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> out = self.conv1(x)<br style="line-height: 1.6 !important;"/> out = self.bn1(out)<br style="line-height: 1.6 !important;"/> out = self.relu(out)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> out = self.conv2(out)<br style="line-height: 1.6 !important;"/> out = self.bn2(out)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">if</span> self.downsample <span style="line-height:1.6 !important;">is</span> <span style="line-height:1.6 !important;">not</span> <span style="line-height:1.6 !important;">None</span>:<br style="line-height: 1.6 !important;"/> residual = self.downsample(x)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> out += residual<br style="line-height: 1.6 !important;"/> out = self.relu(out)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">return</span> out<br style="line-height: 1.6 !important;"/></pre>
Bottleneck:
<figure>
<img src="https://img1.tuicool.com/367B3um.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
<figcaption>Bottleneck structure</figcaption>
</figure>
<pre class="prettyprint"><span style="line-height: 1.6 !important;"><span style="line-height:1.6 !important;">class</span> <span style="line-height:1.6 !important;">Bottleneck</span><span style="line-height: 1.6 !important;">(nn.Module)</span>:</span><br style="line-height: 1.6 !important;"/> expansion = <span style="line-height:1.6 !important;">4</span><br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height: 1.6 !important;"><span style="line-height:1.6 !important;">def</span> <span style="line-height:1.6 !important;">__init__</span><span style="line-height: 1.6 !important;">(self, inplanes, planes, stride=<span style="line-height:1.6 !important;">1</span>, downsample=None)</span>:</span><br style="line-height: 1.6 !important;"/> super(Bottleneck, self).__init__()<br style="line-height: 1.6 !important;"/> self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=<span style="line-height:1.6 !important;">1</span>, bias=<span style="line-height:1.6 !important;">False</span>)<br style="line-height: 1.6 !important;"/> self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)<br style="line-height: 1.6 !important;"/> self.conv2 = nn.Conv2d(planes, planes, kernel_size=<span style="line-height:1.6 !important;">3</span>, stride=stride,<br style="line-height: 1.6 !important;"/> padding=<span style="line-height:1.6 !important;">1</span>, bias=<span style="line-height:1.6 !important;">False</span>)<br style="line-height: 1.6 !important;"/> self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)<br style="line-height: 1.6 !important;"/> self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=<span style="line-height:1.6 !important;">1</span>,<br style="line-height: 1.6 !important;"/> bias=<span style="line-height:1.6 !important;">False</span>)<br style="line-height: 1.6 !important;"/> self.bn3 = nn.BatchNorm2d(planes * self.expansion,<br style="line-height: 1.6 !important;"/> momentum=BN_MOMENTUM)<br 
style="line-height: 1.6 !important;"/> self.relu = nn.ReLU(inplace=<span style="line-height:1.6 !important;">True</span>)<br style="line-height: 1.6 !important;"/> self.downsample = downsample<br style="line-height: 1.6 !important;"/> self.stride = stride<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height: 1.6 !important;"><span style="line-height:1.6 !important;">def</span> <span style="line-height:1.6 !important;">forward</span><span style="line-height: 1.6 !important;">(self, x)</span>:</span><br style="line-height: 1.6 !important;"/> residual = x<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> out = self.conv1(x)<br style="line-height: 1.6 !important;"/> out = self.bn1(out)<br style="line-height: 1.6 !important;"/> out = self.relu(out)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> out = self.conv2(out)<br style="line-height: 1.6 !important;"/> out = self.bn2(out)<br style="line-height: 1.6 !important;"/> out = self.relu(out)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> out = self.conv3(out)<br style="line-height: 1.6 !important;"/> out = self.bn3(out)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">if</span> self.downsample <span style="line-height:1.6 !important;">is</span> <span style="line-height:1.6 !important;">not</span> <span style="line-height:1.6 !important;">None</span>:<br style="line-height: 1.6 !important;"/> residual = self.downsample(x)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> out += residual<br style="line-height: 1.6 !important;"/> out = self.relu(out)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">return</span> out<br style="line-height: 1.6 !important;"/></pre>
HighResolutionModule is the core module. It consists mainly of two components: the branches and the fuse layers.
<figure>
<img src="https://img1.tuicool.com/ZfeU3ej.jpg!web" class="alignCenter" referrerpolicy="no-referrer"/>
</figure>
<pre class="prettyprint"><span style="line-height: 1.6 !important;"><span style="line-height:1.6 !important;">class</span> <span style="line-height:1.6 !important;">HighResolutionModule</span><span style="line-height: 1.6 !important;">(nn.Module)</span>:</span><br style="line-height: 1.6 !important;"/> <span style="line-height: 1.6 !important;"><span style="line-height:1.6 !important;">def</span> <span style="line-height:1.6 !important;">__init__</span><span style="line-height: 1.6 !important;">(self, num_branches, blocks, num_blocks, num_inchannels,<br style="line-height: 1.6 !important;"/> num_channels, fuse_method, multi_scale_output=True)</span>:</span><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">'''<br style="line-height: 1.6 !important;"/> 調用:<br style="line-height: 1.6 !important;"/> # 調用高低分辨率交互模塊, stage2 爲例<br style="line-height: 1.6 !important;"/> HighResolutionModule(num_branches, # 2<br style="line-height: 1.6 !important;"/> block, # 'BASIC'<br style="line-height: 1.6 !important;"/> num_blocks, # [4, 4]<br style="line-height: 1.6 !important;"/> num_inchannels, # 上個stage的out channel<br style="line-height: 1.6 !important;"/> num_channels, # [32, 64]<br style="line-height: 1.6 !important;"/> fuse_method, # SUM<br style="line-height: 1.6 !important;"/> reset_multi_scale_output)<br style="line-height: 1.6 !important;"/> '''</span><br style="line-height: 1.6 !important;"/> super(HighResolutionModule, self).__init__()<br style="line-height: 1.6 !important;"/> self._check_branches(<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;font-style:italic;"># 檢查分支數目是否合理</span><br style="line-height: 1.6 !important;"/> num_branches, blocks, num_blocks, num_inchannels, num_channels)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> self.num_inchannels = num_inchannels<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 
!important;font-style:italic;"># 融合選用相加的方式</span><br style="line-height: 1.6 !important;"/> self.fuse_method = fuse_method<br style="line-height: 1.6 !important;"/> self.num_branches = num_branches<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> self.multi_scale_output = multi_scale_output<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;font-style:italic;"># 兩個核心部分,一個是branches構建,一個是融合layers構建</span><br style="line-height: 1.6 !important;"/> self.branches = self._make_branches(<br style="line-height: 1.6 !important;"/> num_branches, blocks, num_blocks, num_channels)<br style="line-height: 1.6 !important;"/> self.fuse_layers = self._make_fuse_layers()<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> self.relu = nn.ReLU(<span style="line-height:1.6 !important;">False</span>)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height: 1.6 !important;"><span style="line-height:1.6 !important;">def</span> <span style="line-height:1.6 !important;">_check_branches</span><span style="line-height: 1.6 !important;">(self, num_branches, blocks, num_blocks,<br style="line-height: 1.6 !important;"/> num_inchannels, num_channels)</span>:</span><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;font-style:italic;"># 分別檢查參數是否符合要求,看models.py中的參數,blocks參數冗餘了</span><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">if</span> num_branches != len(num_blocks):<br style="line-height: 1.6 !important;"/> error_msg = <span style="line-height:1.6 !important;">'NUM_BRANCHES({}) <> NUM_BLOCKS({})'</span>.format(<br style="line-height: 1.6 !important;"/> num_branches, len(num_blocks))<br style="line-height: 1.6 !important;"/> logger.error(error_msg)<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 
!important;">raise</span> ValueError(error_msg)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">if</span> num_branches != len(num_channels):<br style="line-height: 1.6 !important;"/> error_msg = <span style="line-height:1.6 !important;">'NUM_BRANCHES({}) <> NUM_CHANNELS({})'</span>.format(<br style="line-height: 1.6 !important;"/> num_branches, len(num_channels))<br style="line-height: 1.6 !important;"/> logger.error(error_msg)<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">raise</span> ValueError(error_msg)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">if</span> num_branches != len(num_inchannels):<br style="line-height: 1.6 !important;"/> error_msg = <span style="line-height:1.6 !important;">'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'</span>.format(<br style="line-height: 1.6 !important;"/> num_branches, len(num_inchannels))<br style="line-height: 1.6 !important;"/> logger.error(error_msg)<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">raise</span> ValueError(error_msg)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height: 1.6 !important;"><span style="line-height:1.6 !important;">def</span> <span style="line-height:1.6 !important;">_make_one_branch</span><span style="line-height: 1.6 !important;">(self, branch_index, block, num_blocks, num_channels,<br style="line-height: 1.6 !important;"/> stride=<span style="line-height:1.6 !important;">1</span>)</span>:</span><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;font-style:italic;"># 構建一個分支,一個分支重複num_blocks個block</span><br style="line-height: 1.6 !important;"/> downsample = <span style="line-height:1.6 !important;">None</span><br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 
!important;"/> <span style="line-height:1.6 !important;font-style:italic;"># 這裏判斷,如果通道變大(分辨率變小),則使用下采樣</span><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">if</span> stride != <span style="line-height:1.6 !important;">1</span> <span style="line-height:1.6 !important;">or</span> \<br style="line-height: 1.6 !important;"/> self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:<br style="line-height: 1.6 !important;"/> downsample = nn.Sequential(<br style="line-height: 1.6 !important;"/> nn.Conv2d(self.num_inchannels[branch_index],<br style="line-height: 1.6 !important;"/> num_channels[branch_index] * block.expansion,<br style="line-height: 1.6 !important;"/> kernel_size=<span style="line-height:1.6 !important;">1</span>, stride=stride, bias=<span style="line-height:1.6 !important;">False</span>),<br style="line-height: 1.6 !important;"/> nn.BatchNorm2d(num_channels[branch_index] * block.expansion,<br style="line-height: 1.6 !important;"/> momentum=BN_MOMENTUM),<br style="line-height: 1.6 !important;"/> )<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> layers = []<br style="line-height: 1.6 !important;"/> layers.append(block(self.num_inchannels[branch_index],<br style="line-height: 1.6 !important;"/> num_channels[branch_index], stride, downsample))<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> self.num_inchannels[branch_index] = \<br style="line-height: 1.6 !important;"/> num_channels[branch_index] * block.expansion<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">for</span> i <span style="line-height:1.6 !important;">in</span> range(<span style="line-height:1.6 !important;">1</span>, num_blocks[branch_index]):<br style="line-height: 1.6 !important;"/> layers.append(block(self.num_inchannels[branch_index],<br style="line-height: 1.6 !important;"/> 
num_channels[branch_index]))<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">return</span> nn.Sequential(*layers)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height: 1.6 !important;"><span style="line-height:1.6 !important;">def</span> <span style="line-height:1.6 !important;">_make_branches</span><span style="line-height: 1.6 !important;">(self, num_branches, block, num_blocks, num_channels)</span>:</span><br style="line-height: 1.6 !important;"/> branches = []<br style="line-height: 1.6 !important;"/> <br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;font-style:italic;"># 通過循環構建多分支,每個分支屬於不同的分辨率</span><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">for</span> i <span style="line-height:1.6 !important;">in</span> range(num_branches):<br style="line-height: 1.6 !important;"/> branches.append(<br style="line-height: 1.6 !important;"/> self._make_one_branch(i, block, num_blocks, num_channels))<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">return</span> nn.ModuleList(branches)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height: 1.6 !important;"><span style="line-height:1.6 !important;">def</span> <span style="line-height:1.6 !important;">_make_fuse_layers</span><span style="line-height: 1.6 !important;">(self)</span>:</span><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">if</span> self.num_branches == <span style="line-height:1.6 !important;">1</span>:<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">return</span> <span style="line-height:1.6 !important;">None</span><br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> 
num_branches = self.num_branches <span style="line-height:1.6 !important;font-style:italic;"># 2</span><br style="line-height: 1.6 !important;"/> num_inchannels = self.num_inchannels<br style="line-height: 1.6 !important;"/> fuse_layers = []<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">for</span> i <span style="line-height:1.6 !important;">in</span> range(num_branches <span style="line-height:1.6 !important;">if</span> self.multi_scale_output <span style="line-height:1.6 !important;">else</span> <span style="line-height:1.6 !important;">1</span>):<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;font-style:italic;"># i代表枚舉所有分支</span><br style="line-height: 1.6 !important;"/> fuse_layer = []<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">for</span> j <span style="line-height:1.6 !important;">in</span> range(num_branches):<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;font-style:italic;"># j代表處理的當前分支</span><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">if</span> j > i: <span style="line-height:1.6 !important;font-style:italic;"># 進行上採樣,使用最近鄰插值</span><br style="line-height: 1.6 !important;"/> fuse_layer.append(nn.Sequential(<br style="line-height: 1.6 !important;"/> nn.Conv2d(num_inchannels[j],<br style="line-height: 1.6 !important;"/> num_inchannels[i],<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">1</span>,<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">1</span>,<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">0</span>,<br style="line-height: 1.6 !important;"/> bias=<span style="line-height:1.6 !important;">False</span>),<br style="line-height: 1.6 !important;"/> nn.BatchNorm2d(num_inchannels[i],<br style="line-height: 1.6 !important;"/> momentum=BN_MOMENTUM),<br 
style="line-height: 1.6 !important;"/> nn.Upsample(scale_factor=<span style="line-height:1.6 !important;">2</span>**(j-i), mode=<span style="line-height:1.6 !important;">'nearest'</span>)))<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">elif</span> j == i:<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;font-style:italic;"># 本層不做處理</span><br style="line-height: 1.6 !important;"/> fuse_layer.append(<span style="line-height:1.6 !important;">None</span>)<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">else</span>:<br style="line-height: 1.6 !important;"/> conv3x3s = []<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;font-style:italic;"># 進行strided 3x3 conv下采樣,如果跨兩層,就使用兩次strided 3x3 conv</span><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">for</span> k <span style="line-height:1.6 !important;">in</span> range(i-j):<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">if</span> k == i - j - <span style="line-height:1.6 !important;">1</span>:<br style="line-height: 1.6 !important;"/> num_outchannels_conv3x3 = num_inchannels[i]<br style="line-height: 1.6 !important;"/> conv3x3s.append(nn.Sequential(<br style="line-height: 1.6 !important;"/> nn.Conv2d(num_inchannels[j],<br style="line-height: 1.6 !important;"/> num_outchannels_conv3x3,<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">3</span>, <span style="line-height:1.6 !important;">2</span>, <span style="line-height:1.6 !important;">1</span>, bias=<span style="line-height:1.6 !important;">False</span>),<br style="line-height: 1.6 !important;"/> nn.BatchNorm2d(num_outchannels_conv3x3,<br style="line-height: 1.6 !important;"/> momentum=BN_MOMENTUM)))<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">else</span>:<br style="line-height: 1.6 
!important;"/> num_outchannels_conv3x3 = num_inchannels[j]<br style="line-height: 1.6 !important;"/> conv3x3s.append(nn.Sequential(<br style="line-height: 1.6 !important;"/> nn.Conv2d(num_inchannels[j],<br style="line-height: 1.6 !important;"/> num_outchannels_conv3x3,<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">3</span>, <span style="line-height:1.6 !important;">2</span>, <span style="line-height:1.6 !important;">1</span>, bias=<span style="line-height:1.6 !important;">False</span>),<br style="line-height: 1.6 !important;"/> nn.BatchNorm2d(num_outchannels_conv3x3,<br style="line-height: 1.6 !important;"/> nn.ReLU(<span style="line-height:1.6 !important;">False</span>)))<br style="line-height: 1.6 !important;"/> fuse_layer.append(nn.Sequential(*conv3x3s))<br style="line-height: 1.6 !important;"/> fuse_layers.append(nn.ModuleList(fuse_layer))<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">return</span> nn.ModuleList(fuse_layers)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height: 1.6 !important;"><span style="line-height:1.6 !important;">def</span> <span style="line-height:1.6 !important;">get_num_inchannels</span><span style="line-height: 1.6 !important;">(self)</span>:</span><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">return</span> self.num_inchannels<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height: 1.6 !important;"><span style="line-height:1.6 !important;">def</span> <span style="line-height:1.6 !important;">forward</span><span style="line-height: 1.6 !important;">(self, x)</span>:</span><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">if</span> self.num_branches == <span style="line-height:1.6 !important;">1</span>:<br style="line-height: 1.6 
!important;"/> <span style="line-height:1.6 !important;">return</span> [self.branches[<span style="line-height:1.6 !important;">0</span>](x[<span style="line-height:1.6 !important;">0</span>])]<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">for</span> i <span style="line-height:1.6 !important;">in</span> range(self.num_branches):<br style="line-height: 1.6 !important;"/> x[i]=self.branches[i](x[i])<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> x_fuse=[]<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">for</span> i <span style="line-height:1.6 !important;">in</span> range(len(self.fuse_layers)):<br style="line-height: 1.6 !important;"/> y=x[<span style="line-height:1.6 !important;">0</span>] <span style="line-height:1.6 !important;">if</span> i == <span style="line-height:1.6 !important;">0</span> <span style="line-height:1.6 !important;">else</span> self.fuse_layers[i][<span style="line-height:1.6 !important;">0</span>](x[<span style="line-height:1.6 !important;">0</span>])<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">for</span> j <span style="line-height:1.6 !important;">in</span> range(<span style="line-height:1.6 !important;">1</span>, self.num_branches):<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">if</span> i == j:<br style="line-height: 1.6 !important;"/> y=y + x[j]<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">else</span>:<br style="line-height: 1.6 !important;"/> y=y + self.fuse_layers[i][j](x[j])<br style="line-height: 1.6 !important;"/> x_fuse.append(self.relu(y))<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;font-style:italic;"># 將fuse以後的多個分支結果保存到list中</span><br style="line-height: 1.6 !important;"/> <span 
style="line-height:1.6 !important;">return</span> x_fuse<br style="line-height: 1.6 !important;"/></pre>
These are the parameters kept in models.py; they control the model's capacity, the number of branches, and the feature-fusion method:
<pre class="prettyprint"><span style="line-height:1.6 !important;font-style:italic;"># high_resoluton_net related params for classification</span><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET = CN()<br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.PRETRAINED_LAYERS = [<span style="line-height:1.6 !important;">'*'</span>]<br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STEM_INPLANES = <span style="line-height:1.6 !important;">64</span><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.FINAL_CONV_KERNEL = <span style="line-height:1.6 !important;">1</span><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.WITH_HEAD = <span style="line-height:1.6 !important;">True</span><br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE2 = CN()<br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_MODULES = <span style="line-height:1.6 !important;">1</span><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BRANCHES = <span style="line-height:1.6 !important;">2</span><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BLOCKS = [<span style="line-height:1.6 !important;">4</span>, <span style="line-height:1.6 !important;">4</span>]<br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_CHANNELS = [<span style="line-height:1.6 !important;">32</span>, <span style="line-height:1.6 !important;">64</span>]<br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE2.BLOCK = <span style="line-height:1.6 !important;">'BASIC'</span><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE2.FUSE_METHOD = <span style="line-height:1.6 !important;">'SUM'</span><br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE3 = CN()<br style="line-height: 1.6 
!important;"/>POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_MODULES = <span style="line-height:1.6 !important;">1</span><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BRANCHES = <span style="line-height:1.6 !important;">3</span><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BLOCKS = [<span style="line-height:1.6 !important;">4</span>, <span style="line-height:1.6 !important;">4</span>, <span style="line-height:1.6 !important;">4</span>]<br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_CHANNELS = [<span style="line-height:1.6 !important;">32</span>, <span style="line-height:1.6 !important;">64</span>, <span style="line-height:1.6 !important;">128</span>]<br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE3.BLOCK = <span style="line-height:1.6 !important;">'BASIC'</span><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE3.FUSE_METHOD = <span style="line-height:1.6 !important;">'SUM'</span><br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE4 = CN()<br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_MODULES = <span style="line-height:1.6 !important;">1</span><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BRANCHES = <span style="line-height:1.6 !important;">4</span><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BLOCKS = [<span style="line-height:1.6 !important;">4</span>, <span style="line-height:1.6 !important;">4</span>, <span style="line-height:1.6 !important;">4</span>, <span style="line-height:1.6 !important;">4</span>]<br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_CHANNELS = [<span style="line-height:1.6 !important;">32</span>, <span style="line-height:1.6 !important;">64</span>, <span style="line-height:1.6 !important;">128</span>, <span style="line-height:1.6 
!important;">256</span>]<br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE4.BLOCK = <span style="line-height:1.6 !important;">'BASIC'</span><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE4.FUSE_METHOD = <span style="line-height:1.6 !important;">'SUM'</span><br style="line-height: 1.6 !important;"/></pre>
Next, the construction of the full HRNet model. Since the complete code is long, we look only at the forward function.
<pre class="prettyprint"><span style="line-height: 1.6 !important;"><span style="line-height:1.6 !important;">def</span> <span style="line-height:1.6 !important;">forward</span><span style="line-height: 1.6 !important;">(self, x)</span>:</span><br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;font-style:italic;"># 使用兩個strided 3x3conv進行快速降維</span><br style="line-height: 1.6 !important;"/> x=self.relu(self.bn1(self.conv1(x)))<br style="line-height: 1.6 !important;"/> x=self.relu(self.bn2(self.conv2(x)))<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;font-style:italic;"># 構建了一串BasicBlock構成的模塊</span><br style="line-height: 1.6 !important;"/> x=self.layer1(x)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;font-style:italic;"># 然後是多個stage,每個stage核心是調用HighResolutionModule模塊</span><br style="line-height: 1.6 !important;"/> x_list=[]<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">for</span> i <span style="line-height:1.6 !important;">in</span> range(self.stage2_cfg[<span style="line-height:1.6 !important;">'NUM_BRANCHES'</span>]):<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">if</span> self.transition1[i] <span style="line-height:1.6 !important;">is</span> <span style="line-height:1.6 !important;">not</span> <span style="line-height:1.6 !important;">None</span>:<br style="line-height: 1.6 !important;"/> x_list.append(self.transition1[i](x))<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">else</span>:<br style="line-height: 1.6 !important;"/> x_list.append(x)<br style="line-height: 1.6 !important;"/> y_list=self.stage2(x_list)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> x_list=[]<br style="line-height: 
1.6 !important;"/> <span style="line-height:1.6 !important;">for</span> i <span style="line-height:1.6 !important;">in</span> range(self.stage3_cfg[<span style="line-height:1.6 !important;">'NUM_BRANCHES'</span>]):<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">if</span> self.transition2[i] <span style="line-height:1.6 !important;">is</span> <span style="line-height:1.6 !important;">not</span> <span style="line-height:1.6 !important;">None</span>:<br style="line-height: 1.6 !important;"/> x_list.append(self.transition2[i](y_list[<span style="line-height:1.6 !important;">-1</span>]))<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">else</span>:<br style="line-height: 1.6 !important;"/> x_list.append(y_list[i])<br style="line-height: 1.6 !important;"/> y_list=self.stage3(x_list)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> x_list=[]<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">for</span> i <span style="line-height:1.6 !important;">in</span> range(self.stage4_cfg[<span style="line-height:1.6 !important;">'NUM_BRANCHES'</span>]):<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">if</span> self.transition3[i] <span style="line-height:1.6 !important;">is</span> <span style="line-height:1.6 !important;">not</span> <span style="line-height:1.6 !important;">None</span>:<br style="line-height: 1.6 !important;"/> x_list.append(self.transition3[i](y_list[<span style="line-height:1.6 !important;">-1</span>]))<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">else</span>:<br style="line-height: 1.6 !important;"/> x_list.append(y_list[i])<br style="line-height: 1.6 !important;"/> y_list=self.stage4(x_list)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;font-style:italic;"># 
添加分類頭,上文中有顯示,在分類問題中添加這種頭</span><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;font-style:italic;"># 在其他問題中換用不同的頭</span><br style="line-height: 1.6 !important;"/> y=self.incre_modules[<span style="line-height:1.6 !important;">0</span>](y_list[<span style="line-height:1.6 !important;">0</span>])<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">for</span> i <span style="line-height:1.6 !important;">in</span> range(len(self.downsamp_modules)):<br style="line-height: 1.6 !important;"/> y=self.incre_modules[i+<span style="line-height:1.6 !important;">1</span>](y_list[i+<span style="line-height:1.6 !important;">1</span>]) + \<br style="line-height: 1.6 !important;"/> self.downsamp_modules[i](y)<br style="line-height: 1.6 !important;"/> y=self.final_layer(y)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">if</span> torch._C._get_tracing_state():<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;font-style:italic;"># 在不寫C代碼的情況下執行forward,直接用python版本</span><br style="line-height: 1.6 !important;"/> y=y.flatten(start_dim=<span style="line-height:1.6 !important;">2</span>).mean(dim=<span style="line-height:1.6 !important;">2</span>)<br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">else</span>:<br style="line-height: 1.6 !important;"/> y=F.avg_pool2d(y, kernel_size=y.size()<br style="line-height: 1.6 !important;"/> [<span style="line-height:1.6 !important;">2</span>:]).view(y.size(<span style="line-height:1.6 !important;">0</span>), <span style="line-height:1.6 !important;">-1</span>)<br style="line-height: 1.6 !important;"/> y=self.classifier(y)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/> <span style="line-height:1.6 !important;">return</span> y<br style="line-height: 1.6 !important;"/></pre>
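The head loop above merges the branches from high to low resolution: start from branch 0, downsample the running result, add the next branch, repeat, then apply the final layer. As a shape-free sketch with plain functions standing in for the modules (the name `head_aggregate` is ours):

```python
def head_aggregate(branches, incre, down, final):
    """Mirror of the classification-head loop: fold the branch outputs
    together, downsampling the running result before each addition."""
    y = incre[0](branches[0])
    for i in range(len(down)):
        y = incre[i + 1](branches[i + 1]) + down[i](y)
    return final(y)

# Scalar stand-ins: identity "incre"/"final", doubling as "downsample"
out = head_aggregate([1, 2, 3],
                     incre=[lambda v: v] * 3,
                     down=[lambda v: 2 * v] * 2,
                     final=lambda v: v)
print(out)  # 3 + 2*(2 + 2*1) = 11
```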
<h4>Summary</h4>
HRNet's core method: keep a high-resolution representation throughout the entire network, while letting feature maps of different resolutions interact (fuse) with one another.
HRNet has been applied across many CV areas; for instance, it played a part in the ICCV 2019 Amur tiger keypoint detection challenge. The classification experiments further show that, with a comparable number of parameters, it can replace ResNet for classification.
In his article "CNN結構設計技巧-兼顧速度精度與工程實現" (CNN architecture design tips: balancing speed, accuracy, and engineering), 鄭安坤 made a point worth quoting:
SENet is a special case of HRNet; HRNet has not only channel attention but also spatial attention. -- akkaze-鄭安坤
<figure>
<img src="https://img0.tuicool.com/jE7vuyB.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
<figcaption>Core implementation of SELayer</figcaption>
</figure>
SELayer first obtains a one-dimensional vector via global average pooling, then passes it through two fully connected layers that squeeze and re-expand the channel information; after a sigmoid this yields a weight per channel, which is multiplied back onto the original feature map to recalibrate it.
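As a concrete, if simplified, sketch of that pipeline: the function below runs squeeze-excite on a C x H x W nested list in plain Python. For readability the two FC layers are replaced by per-channel scalars `w1`, `w2` (real SE uses learned weight matrices with a channel-reduction bottleneck), so treat this as an illustration of the data flow only:

```python
import math

def se_layer(fmap, w1, w2):
    # squeeze: global average pooling -> one scalar per channel
    z = [sum(sum(row) for row in ch) / (len(ch) * len(ch[0]))
         for ch in fmap]
    # excite: "FC + ReLU", then "FC + sigmoid" -> per-channel gates
    h = [max(0.0, zi * a) for zi, a in zip(z, w1)]
    gates = [1.0 / (1.0 + math.exp(-hi * b)) for hi, b in zip(h, w2)]
    # scale: reweight every value of each channel by its gate
    return [[[v * g for v in row] for row in ch]
            for ch, g in zip(fmap, gates)]

# One 2x2 channel of ones; w2 = 0 forces the gate to sigmoid(0) = 0.5
out = se_layer([[[1.0, 1.0], [1.0, 1.0]]], w1=[1.0], w2=[0.0])
print(out[0][0])  # [0.5, 0.5]
```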
<figure>
<img src="https://img2.tuicool.com/j6ziErV.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
<figcaption>One HRNet module</figcaption>
</figure>
Notice that the path chained together by the red arrows above looks very much like SELayer. Why can SENet be seen as a special case of HRNet? Judging purely from this structure:
<ul>
<li>
<p class="no-text-indent">SENet沒有像HRNet這樣分辨率變爲原來的一半,分辨率直接變爲1x1,比較極端。變爲1x1向量以後,SENet中使用了兩個全連接網絡來學習通道的特徵分佈;但是在HRNet中,使用了幾個卷積(Residual block)來學習特徵。</p>
</li>
<li>
<p class="no-text-indent">SENet在主幹部分(高分辨率分支)沒有安排卷積進行特徵的學習;HRNet在主幹部分(高分辨率分支)安排了幾個卷積(Residual block)來學習特徵。</p>
</li>
<li>
<p class="no-text-indent">特徵融合部分SENet和HRNet區分比較大,SENet使用的對應通道相乘的方法,HRNet則使用的是相加。之所以說SENet是通道注意力機制是因爲通過全局平均池化後沒有了空間特徵,只剩通道的特徵;HRNet則可以看作同時保留了空間特徵和通道特徵,所以說HRNet不僅有通道注意力,同時也有空間注意力。</p>
</li>
</ul>
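The fusion difference in the last point can be written in one line each (a toy 1-D sketch with our own function names):

```python
def se_combine(x, gates):
    # SENet: per-channel multiplicative gating (channel attention only)
    return [xi * g for xi, g in zip(x, gates)]

def hrnet_combine(x, y):
    # HRNet: element-wise addition of the resized branch features,
    # keeping both spatial and channel information in play
    return [xi + yi for xi, yi in zip(x, y)]

print(se_combine([2.0, 4.0], [0.5, 0.25]))  # [1.0, 1.0]
print(hrnet_combine([1, 2], [3, 4]))        # [4, 6]
```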
The HRNet team numbers more than ten people. They built libraries for classification, segmentation, detection, and keypoint detection, a substantial engineering effort, and ran many solid experiments demonstrating that the idea works. So can HRNet be regarded as the next backbone to surpass SENet? That is best verified by applying the idea in your own practice.
<h4>References</h4>
https://arxiv.org/pdf/1908.07919
https://www.bilibili.com/video/BV1WJ41197dh?t=508
https://github.com/HRNet
</div>
- End -
<p style="text-align:center">
<img src="https://img0.tuicool.com/6Fj63a6.jpg!web" class="alignCenter" referrerpolicy="no-referrer"/>
</p>
<mpcps frameborder="0"/>
<span>京東滿100-50優惠活動等你來參與!</span>
<span>✄------------------------------------------------</span>
<span>看到這裏,說明你喜歡這篇文章,請點擊「</span>
<span>在看</span>
<span>」或順手「</span>
<span>轉發</span>
<span>」「</span>
<span>點贊</span>
<span>」。</span>
<span>歡迎微信搜索「</span>
<span>panchuangxx</span>
<span>」,添加小編</span>
<span>磐小小仙</span>
<span>微信,每日朋友圈更新一篇高質量推文(無廣告),爲您提供更多精彩內容。</span>
<span>▼ </span>
<span>▼</span>
<span> </span>
<span> 掃描二維碼添加小編</span>
<span> ▼</span>
<span>
<span> </span>
▼
<p style="text-align:center">
<img src="https://img0.tuicool.com/ym2Yny6.jpg!web" class="alignCenter" referrerpolicy="no-referrer"/>
</p>