<meta http-equiv="Content-Type" content="text/html; charset=UTF-8"/><div>
<div>
    <h2>
        <img src="https://img1.tuicool.com/vee6RnZ.jpg!web" class="alignCenter" referrerpolicy="no-referrer"/>
    </h2>
    <h3>
        <div>


                    <span>Shared by 磐創AI</span>

        </div>
    </h3>
    Source | GiantPandaCV
    Editor | pprp
    <h4>
        <span>
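            <strong>[Overview]</strong>
        </span>
        <span>HRNet was developed by a team led by Jingdong Wang at Microsoft Research Asia. It provides a single backbone architecture that works across image classification, image segmentation, object detection, face alignment, pose estimation, style transfer, image inpainting, super-resolution, optical flow, depth estimation, edge detection and more.</span>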
    </h4>
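    Jingdong Wang explained HRNet himself, and very thoroughly, in his VALSE Webinar talk "Object and Keypoint Detection". This article mainly follows his explanation from that talk, together with the paper and the code, to introduce this all-round backbone, HRNet.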
    <h4>Introduction</h4>
    <figure>
        <img src="https://img0.tuicool.com/feURFb2.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Design rationale of the network architecture</figcaption>
    </figure>
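    In tasks such as human pose estimation, a high-resolution heatmap has to be generated for keypoint detection. This differs from what a typical network such as VGGNet provides: VGGNet's final feature map has a very low resolution, so much of the spatial structure is lost.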
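To make the resolution argument concrete, here is a minimal sketch of plain argmax heatmap decoding, assuming one heatmap per keypoint and no sub-pixel refinement. The helper `decode_heatmap` is hypothetical, not taken from the HRNet code. It places the same keypoint on a stride-4 heatmap and on a stride-32 heatmap (the stride of VGGNet's final pooled feature map) for a 256x192 input.

```python
def decode_heatmap(heatmap, stride):
    """Return the (x, y) peak of one heatmap, mapped back to input pixels.

    heatmap: H x W nested list of scores; stride: input-to-heatmap downsampling factor.
    Hypothetical helper: plain argmax, no sub-pixel refinement.
    """
    width = len(heatmap[0])
    flat = [v for row in heatmap for v in row]
    idx = max(range(len(flat)), key=flat.__getitem__)  # position of the highest score
    y, x = divmod(idx, width)                           # row-major index to (row, col)
    return x * stride, y * stride                       # back to input-pixel coordinates


# The same keypoint, at input pixel (101, 57), on a 256x192 input image:
fine = [[0.0] * 48 for _ in range(64)]   # stride 4: 64x48 heatmap
fine[57 // 4][101 // 4] = 1.0
coarse = [[0.0] * 6 for _ in range(8)]   # stride 32: 8x6 heatmap
coarse[57 // 32][101 // 32] = 1.0

print(decode_heatmap(fine, 4))     # (100, 56)
print(decode_heatmap(coarse, 32))  # (96, 32)
```

The coarse map can only report positions on a 32-pixel grid, so here it is off by (5, 25) pixels, while the stride-4 map is off by (1, 1).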
    <figure>
        <img src="https://img1.tuicool.com/rayemqu.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>The conventional approach</figcaption>
    </figure>
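    As the figure above shows, most approaches obtain a high-resolution output by first lowering the resolution and then raising it again. U-Net, SegNet, DeconvNet and Hourglass are all essentially built this way.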
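A bare-bones PyTorch sketch of this high-to-low-to-high pattern, for illustration only: the class `HighLowHigh` and its layer sizes are hypothetical, and the skip connections that U-Net and Hourglass add between matching resolutions are deliberately left out.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class HighLowHigh(nn.Module):
    """Toy network: reduce resolution, then recover it (hypothetical, no skip connections)."""

    def __init__(self, c: int = 16):
        super().__init__()
        self.down1 = nn.Conv2d(3, c, 3, stride=2, padding=1)      # to 1/2 resolution
        self.down2 = nn.Conv2d(c, 2 * c, 3, stride=2, padding=1)  # to 1/4 resolution
        self.up1 = nn.Conv2d(2 * c, c, 3, padding=1)
        self.up2 = nn.Conv2d(c, c, 3, padding=1)

    def forward(self, x):
        x = F.relu(self.down1(x))
        x = F.relu(self.down2(x))  # lowest resolution
        x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
        x = F.relu(self.up1(x))
        x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
        return self.up2(x)         # back at the input resolution


net = HighLowHigh()
print(tuple(net(torch.randn(1, 3, 64, 48)).shape))  # (1, 16, 64, 48)
```

The output is back at the input resolution, but all of its information passed through the 1/4-resolution stage and had to be recovered from there.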
    <figure>
        <img src="https://img2.tuicool.com/i6Z3Qv3.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>They look different, but are essentially the same</figcaption>
    </figure>
    <h4>Core idea</h4>
    Ordinary networks all follow this structure, with the different resolutions connected in series:
    <figure>
        <img src="https://img1.tuicool.com/EfiABju.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>The resolution keeps decreasing</figcaption>
    </figure>
    Jingdong Wang's design instead connects feature maps of different resolutions in parallel:
    <figure>
        <img src="https://img0.tuicool.com/JRv6Jr2.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Feature maps of different resolutions in parallel</figcaption>
    </figure>
    On top of the parallel connection, exchange (fusion) is added between the feature maps of different resolutions.
    <figure>
        <img src="https://img2.tuicool.com/j2IZZzq.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
    </figure>
    The fusion works as shown in the figure below:
    <figure>
        <img src="https://img0.tuicool.com/Qv6VZjE.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
    </figure>
    <ul>
        <li>
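            <p class="no-text-indent">A feature map at the same resolution is copied directly.</p>
        </li>
        <li>
            <p class="no-text-indent">Where the resolution has to increase, bilinear upsampling is used together with a 1x1 convolution that brings the number of channels in line (the code analyzed later uses nearest-neighbor upsampling instead).</p>
        </li>
        <li>
            <p class="no-text-indent">Where the resolution has to decrease, a strided 3x3 convolution is used.</p>
        </li>
        <li>
            <p class="no-text-indent">The three feature maps are fused by element-wise summation.</p>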
        </li>
    </ul>
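The fusion rules above can be sketched as a small PyTorch module, assuming three parallel branches in which each branch has half the resolution and twice the channels of the previous one. This is a simplified illustration: `ExchangeUnit` is a hypothetical name, BatchNorm and intermediate ReLUs are omitted, and upsampling is bilinear as in the list. The 1x1 convolution is applied before upsampling; since both operations are linear and bias-free here, the order does not change the result, and convolving first is cheaper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ExchangeUnit(nn.Module):
    """Fuse every branch into every output branch (simplified sketch, no BatchNorm).

    channels[k] is the width of branch k; branch k has 1 / 2**k of branch 0's resolution.
    """

    def __init__(self, channels):
        super().__init__()
        self.n = len(channels)
        self.paths = nn.ModuleList()
        for i in range(self.n):              # i: target (output) branch
            row = nn.ModuleList()
            for j in range(self.n):          # j: source (input) branch
                if j == i:                   # same resolution: copy as is
                    row.append(nn.Identity())
                elif j > i:                  # lower resolution: 1x1 conv to match channels
                    row.append(nn.Conv2d(channels[j], channels[i], kernel_size=1, bias=False))
                else:                        # higher resolution: one strided 3x3 conv per halving
                    convs = [
                        nn.Conv2d(channels[j],
                                  channels[i] if k == i - j - 1 else channels[j],
                                  kernel_size=3, stride=2, padding=1, bias=False)
                        for k in range(i - j)
                    ]
                    row.append(nn.Sequential(*convs))
            self.paths.append(row)

    def forward(self, xs):
        outs = []
        for i in range(self.n):
            y = 0
            for j in range(self.n):
                z = self.paths[i][j](xs[j])
                if j > i:  # bilinear upsampling to branch i's spatial size
                    z = F.interpolate(z, size=xs[i].shape[-2:], mode="bilinear",
                                      align_corners=False)
                y = y + z  # fusion by summation
            outs.append(F.relu(y))
        return outs


unit = ExchangeUnit([8, 16, 32])
xs = [torch.randn(1, 8, 32, 32), torch.randn(1, 16, 16, 16), torch.randn(1, 32, 8, 8)]
print([tuple(o.shape) for o in unit(xs)])  # [(1, 8, 32, 32), (1, 16, 16, 16), (1, 32, 8, 8)]
```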
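    Why a strided 3x3 convolution? Reducing the resolution inevitably loses information; with a strided 3x3 convolution the downsampling is learned, which reduces that loss. This is why max pooling or combined pooling is not used here.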
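A quick shape and parameter check illustrates the difference, assuming a 32-channel branch that has to be fused into a branch with half the resolution and 64 channels (the sizes are illustrative):

```python
import torch
import torch.nn as nn

# A 32-channel branch fused into a branch with half the resolution and 64 channels
x = torch.randn(1, 32, 64, 48)

strided = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias=False)
pool = nn.MaxPool2d(kernel_size=2, stride=2)

y_conv = strided(x)  # halves H and W and maps 32 to 64 channels
y_pool = pool(x)     # halves H and W but keeps 32 channels

n_conv = sum(p.numel() for p in strided.parameters())  # 32 * 64 * 3 * 3 learnable weights
n_pool = sum(p.numel() for p in pool.parameters())     # nothing to learn

print(tuple(y_conv.shape), n_conv)  # (1, 64, 32, 24) 18432
print(tuple(y_pool.shape), n_pool)  # (1, 32, 32, 24) 0
```

Both layers halve the resolution, but only the strided convolution has weights to learn, and it produces the target channel count in the same step; max pooling would need an extra 1x1 convolution for that.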
    <figure>
        <img src="https://img2.tuicool.com/buINryY.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>HRNet overview</figcaption>
    </figure>
    Another question comes up when reading HRNet: with four branches, how are they actually used? The paper gives several ways of selecting the final features.
    <figure>
        <img src="https://img1.tuicool.com/EJ3U7vb.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Three ways of fusing the features</figcaption>
    </figure>
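    (a) shows the feature selection of HRNetV1: only the highest-resolution feature map is used.
    (b) shows HRNetV2: the feature maps of all resolutions are concatenated (the smaller ones are upsampled first). It is mainly used for semantic segmentation and facial landmark detection.
    (c) shows HRNetV2p: a feature pyramid is built on top of the HRNetV2 output. It is mainly used in object detection networks.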
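The three options can be sketched as functions over a list of branch outputs ordered from highest to lowest resolution. This is a sketch under stated assumptions: the function names are hypothetical, and the HRNetV2p pyramid is built here only by average pooling the HRNetV2 output, without the extra convolutions a real detection neck would add.

```python
import torch
import torch.nn.functional as F


def hrnet_v1_head(feats):
    """(a) HRNetV1: keep only the highest-resolution feature map."""
    return feats[0]


def hrnet_v2_head(feats):
    """(b) HRNetV2: upsample all maps to the highest resolution, then concatenate."""
    size = feats[0].shape[-2:]
    ups = [feats[0]] + [
        F.interpolate(f, size=size, mode="bilinear", align_corners=False)
        for f in feats[1:]
    ]
    return torch.cat(ups, dim=1)


def hrnet_v2p_head(feats, num_levels=4):
    """(c) HRNetV2p (simplified): average-pool the HRNetV2 output into a pyramid."""
    x = hrnet_v2_head(feats)
    return [x] + [F.avg_pool2d(x, kernel_size=2 ** k, stride=2 ** k)
                  for k in range(1, num_levels)]


# Branch outputs of an HRNet-W32-like model, highest resolution first
feats = [torch.randn(1, 32 * 2 ** k, 64 // 2 ** k, 48 // 2 ** k) for k in range(4)]
print(tuple(hrnet_v1_head(feats).shape))  # (1, 32, 64, 48)
print(tuple(hrnet_v2_head(feats).shape))  # (1, 480, 64, 48)
print([tuple(p.shape) for p in hrnet_v2p_head(feats)])
# [(1, 480, 64, 48), (1, 480, 32, 24), (1, 480, 16, 12), (1, 480, 8, 6)]
```

With branch widths 32, 64, 128 and 256, the HRNetV2 concatenation has 480 channels at the highest resolution.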
    There is also a fourth variant, (d):
    <figure>
        <img src="https://img2.tuicool.com/a6N3eyq.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Feature selection in the HRNetV2 classification network</figcaption>
    </figure>
    (d) is also HRNetV2, using the fusion scheme in the figure above; it is mainly used to train the classification network.
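A rough sketch of head (d), following the paper's description: each branch is widened to 128, 256, 512 and 1024 channels, the highest-resolution result is repeatedly downsampled by a stride-2 3x3 convolution and added to the next branch, and a 1x1 convolution lifts 1024 channels to 2048 before global average pooling and the classifier. To keep it short, plain 1x1 convolutions stand in for the bottleneck blocks and BatchNorm/ReLU are omitted; `ClsHeadSketch` is a hypothetical name, and the 56/28/14/7 input sizes assume a 224x224 image with a stride-4 stem.

```python
import torch
import torch.nn as nn


class ClsHeadSketch(nn.Module):
    """Hypothetical sketch of the HRNetV2 classification head in figure (d)."""

    def __init__(self, in_channels=(32, 64, 128, 256),
                 head_channels=(128, 256, 512, 1024), num_classes=1000):
        super().__init__()
        # Widen every branch (the paper uses a bottleneck block; a 1x1 conv stands in here)
        self.widen = nn.ModuleList([
            nn.Conv2d(c_in, c_out, kernel_size=1)
            for c_in, c_out in zip(in_channels, head_channels)
        ])
        # Stride-2 3x3 convs that carry branch k down into branch k + 1
        self.down = nn.ModuleList([
            nn.Conv2d(head_channels[k], head_channels[k + 1],
                      kernel_size=3, stride=2, padding=1)
            for k in range(len(head_channels) - 1)
        ])
        # 1x1 conv from 1024 to 2048 channels, then a linear classifier
        self.final = nn.Conv2d(head_channels[-1], 2048, kernel_size=1)
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, feats):
        # feats: branch outputs ordered from highest to lowest resolution
        y = self.widen[0](feats[0])
        for k in range(1, len(feats)):
            y = self.widen[k](feats[k]) + self.down[k - 1](y)
        y = self.final(y)
        y = y.mean(dim=(2, 3))  # global average pooling
        return self.fc(y)


head = ClsHeadSketch(num_classes=10)
feats = [torch.randn(2, 32, 56, 56), torch.randn(2, 64, 28, 28),
         torch.randn(2, 128, 14, 14), torch.randn(2, 256, 7, 7)]
print(tuple(head(feats).shape))  # (2, 10)
```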
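    To summarize, HRNet's innovations are:
    <ul>
        <li>
            <p class="no-text-indent">The connections between high and low resolutions are changed from serial to parallel.</p>
        </li>
        <li>
            <p class="no-text-indent">A high-resolution representation is maintained throughout the whole network (the topmost path).</p>
        </li>
        <li>
            <p class="no-text-indent">Exchange between the high and low resolutions is introduced to improve model performance.</p>
        </li>
    </ul>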

    <h4>Results</h4>
    Ablation studies
    <ol>
        <li>
            <p class="no-text-indent">An ablation over the exchange methods shows that the cross-resolution fusion used here is effective.</p>
        </li>
    </ol>
    <figure>
        <img src="https://img2.tuicool.com/Y3QvA3R.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Ablation of the exchange methods</figcaption>
    </figure>
    <ol start="2">
        <li>
            <p class="no-text-indent">Demonstrating the representational power of the high-resolution feature map</p>
        </li>
    </ol>
    <figure>
        <img src="https://img2.tuicool.com/7baMviM.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
    </figure>
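    1x means no downsampling, 2x means the resolution is halved, and 4x means it is reduced to a quarter. The 32 and 48 in W32 and W48 denote the width, i.e. the number of channels, of the convolutions in the high-resolution branch.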
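To make the W naming concrete: assuming the usual HRNet layout, where a stem reduces the input to 1/4 resolution and each further branch halves the resolution and doubles the width, a small helper can list the (channels, height, width) of every branch. The helper `hrnet_branches` and the 256x192 input size are illustrative assumptions.

```python
def hrnet_branches(width, input_hw, num_branches=4, stem_stride=4):
    """List (channels, height, width) for each branch of a hypothetical HRNet-W{width}.

    Assumes a stride-4 stem and that each branch halves the resolution and
    doubles the channel count of the previous one.
    """
    h, w = input_hw
    shapes = []
    for k in range(num_branches):
        stride = stem_stride * 2 ** k
        shapes.append((width * 2 ** k, h // stride, w // stride))
    return shapes


print(hrnet_branches(32, (256, 192)))  # [(32, 64, 48), (64, 32, 24), (128, 16, 12), (256, 8, 6)]
print(hrnet_branches(48, (256, 192)))  # [(48, 64, 48), (96, 32, 24), (192, 16, 12), (384, 8, 6)]
```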
    <h4>Performance on pose estimation</h4>
    <figure>
        <img src="https://img2.tuicool.com/iamIz2U.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
    </figure>
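    The pose estimation results above use the top-down approach (people are detected first, then keypoints are estimated inside each person box).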
    <figure>
        <img src="https://img2.tuicool.com/Qb2EfqB.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Results on the COCO validation set</figcaption>
    </figure>
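    Without increasing the number of parameters or the amount of computation, HRNet performs considerably better than comparable networks.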
    <figure>
        <img src="https://img0.tuicool.com/JR3UbiZ.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Results on the COCO test set</figcaption>
    </figure>
    On the PoseTrack leaderboard as of February 28, 2019, HRNet held first place in two tracks.
    <figure>
        <img src="https://img0.tuicool.com/Nf26vyZ.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>PoseTrack Leaderboard</figcaption>
    </figure>
    <div>
        <h4>Performance on semantic segmentation</h4>
        <p style="text-align:center">
            <img src="https://img0.tuicool.com/RZFJvuJ.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        </p>
    </div>
    <figure>
        <div>Comparison on the Cityscapes validation set</div>
    </figure>
    <figure>
        <img src="https://img1.tuicool.com/vUBv2qm.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Comparison on the Cityscapes test set</figcaption>
    </figure>
    <h4>Performance on object detection</h4>
    <figure>
        <img src="https://img1.tuicool.com/JBNf2un.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Comparison of single-model, single-scale results</figcaption>
    </figure>
    <figure>
        <img src="https://img0.tuicool.com/e2a26fF.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Results with Mask R-CNN</figcaption>
    </figure>
    <h4>Performance on image classification</h4>
    <figure>
        <img src="https://img0.tuicool.com/nINfEry.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
    </figure>
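    P.S. Jingdong Wang mentioned in this part that segmentation networks also need to start from a classification-pretrained model; otherwise the results are a few points worse.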
    <figure>
        <img src="https://img2.tuicool.com/FN3Y3mV.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Comparison with ResNet on image classification</figcaption>
    </figure>
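    The figure above compares HRNet with ResNet; models shown in the same color have roughly the same number of parameters. With a similar or even smaller parameter count, HRNet achieves better results than ResNet.
    The HRNet project (https://github.com/HRNet) is a substantial body of work: it comprises six repositories, covering semantic segmentation, human pose estimation, object detection, image classification, facial landmark detection and Mask R-CNN. The full list is shown below: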
    <figure>
        <img src="https://img0.tuicool.com/ru6nErZ.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
    </figure>
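    I am very interested in how the HRNet code is put together, so let's walk through it using the HRNet-Image-Classification repository as an example.
    Let's start with the simple part, BasicBlock: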
    <figure>
        <img src="https://img0.tuicool.com/AB3amym.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>BasicBlock structure</figcaption>
    </figure>
    <pre class="prettyprint">import torch.nn as nn

BN_MOMENTUM = 0.1  # BatchNorm momentum used in the HRNet code


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1  # output channels = planes * 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # project the shortcut when the shape changes
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
</pre>
    Bottleneck:
    <figure>
        <img src="https://img1.tuicool.com/367B3um.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Bottleneck structure</figcaption>
    </figure>
    <pre class="prettyprint">class Bottleneck(nn.Module):
    expansion = 4  # output channels = planes * 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # 1x1 conv reduces the channels to planes
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        # 3x3 conv carries the stride
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        # 1x1 conv expands the channels to planes * expansion
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
                               bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion,
                                  momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # project the shortcut when the shape changes
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
</pre>
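To see why Bottleneck is the cheaper block at a given output width, here is a small worked count of convolution weights (BatchNorm parameters ignored). It compares a BasicBlock on 256 channels with a Bottleneck that takes 256 channels in, uses planes=64, and outputs 64 * 4 = 256 channels; the helper names are illustrative.

```python
def conv_weights(c_in, c_out, k):
    """Number of weights in a bias-free k x k convolution."""
    return c_in * c_out * k * k


def basic_block_weights(channels):
    # two 3x3 convs, channels to channels
    return 2 * conv_weights(channels, channels, 3)


def bottleneck_weights(inplanes, planes, expansion=4):
    # 1x1 reduce, 3x3, 1x1 expand
    return (conv_weights(inplanes, planes, 1)
            + conv_weights(planes, planes, 3)
            + conv_weights(planes, planes * expansion, 1))


print(basic_block_weights(256))     # 1179648
print(bottleneck_weights(256, 64))  # 69632
```

Most of the saving comes from running the expensive 3x3 convolution on 64 instead of 256 channels.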
    HighResolutionModule is the core module. It consists of two main components: the branches and the fuse layers.
    <figure>
        <img src="https://img1.tuicool.com/ZfeU3ej.jpg!web" class="alignCenter" referrerpolicy="no-referrer"/>
    </figure>
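Before reading `_make_fuse_layers` in the code below, it helps to see the table it builds. This sketch (a hypothetical helper `fuse_plan`) mirrors the three cases of that method, source branch index greater than, equal to, or less than the target index, and only returns a description of which operation maps source branch j into target branch i.

```python
def fuse_plan(num_branches):
    """Describe the operation that maps source branch j into target branch i."""
    plan = []
    for i in range(num_branches):          # target (output) branch
        row = []
        for j in range(num_branches):      # source (input) branch
            if j > i:                      # lower resolution: bring it up
                row.append(f"1x1 conv + BN + upsample x{2 ** (j - i)}")
            elif j == i:                   # same branch: passed through unchanged
                row.append("identity")
            else:                          # higher resolution: bring it down
                row.append(f"{i - j} x strided 3x3 conv")
        plan.append(row)
    return plan


for i, row in enumerate(fuse_plan(3)):
    print(f"into branch {i}:", row)
```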
    <pre class="prettyprint">import logging

logger = logging.getLogger(__name__)  # used by _check_branches


class HighResolutionModule(nn.Module):
    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
                 num_channels, fuse_method, multi_scale_output=True):
        '''
        Usage:
        # build the high/low-resolution exchange module, taking stage2 as an example
        HighResolutionModule(num_branches, # 2
                             block, # 'BASIC'
                             num_blocks, # [4, 4]
                             num_inchannels, # out channels of the previous stage
                             num_channels, # [32, 64]
                             fuse_method, # SUM
                             reset_multi_scale_output)
        '''
        super(HighResolutionModule, self).__init__()
        self._check_branches(
            # check that the number of branches matches the other arguments
            num_branches, blocks, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        # the branches are fused by summation
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        # the two core parts: building the branches and building the fuse layers
        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()

        self.relu = nn.ReLU(False)

    def _check_branches(self, num_branches, blocks, num_blocks,
                        num_inchannels, num_channels):
        # check each argument; judging from the calls in models.py, blocks is redundant here
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) &lt;&gt; NUM_BLOCKS({})'.format(
                num_branches, len(num_blocks))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) &lt;&gt; NUM_CHANNELS({})'.format(
                num_branches, len(num_channels))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = 'NUM_BRANCHES({}) &lt;&gt; NUM_INCHANNELS({})'.format(
                num_branches, len(num_inchannels))
            logger.error(error_msg)
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                         stride=1):
        # build one branch: num_blocks[branch_index] blocks in sequence
        downsample = None

        # if the stride is not 1 or the channel count changes, project the shortcut with a 1x1 conv
        if stride != 1 or \
           self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.num_inchannels[branch_index],
                          num_channels[branch_index] * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(num_channels[branch_index] * block.expansion,
                               momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(self.num_inchannels[branch_index],
                            num_channels[branch_index], stride, downsample))

        self.num_inchannels[branch_index] = \
            num_channels[branch_index] * block.expansion

        for i in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_inchannels[branch_index],
                                num_channels[branch_index]))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        branches = []

        # build the branches in a loop; each branch has its own resolution
        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches # 2
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            # i: the target (output) branch
            fuse_layer = []
            for j in range(num_branches):
                # j: the source branch being fused into branch i
                if j &gt; i: # lower resolution: 1x1 conv + BN, then nearest-neighbor upsampling
                    fuse_layer.append(nn.Sequential(
                        nn.Conv2d(num_inchannels[j],
                                  num_inchannels[i],
                                  1,
                                  1,
                                  0,
                                  bias=False),
                        nn.BatchNorm2d(num_inchannels[i],
                                       momentum=BN_MOMENTUM),
                        nn.Upsample(scale_factor=2**(j-i), mode='nearest')))
                elif j == i:
                    # same branch: no transformation
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    # downsample with strided 3x3 convs, one per halving (two levels apart means two convs)
                    for k in range(i-j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          3, 2, 1, bias=False),
                                nn.BatchNorm2d(num_outchannels_conv3x3,
                                               momentum=BN_MOMENTUM)))
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          3, 2, 1, bias=False),
                                nn.BatchNorm2d(num_outchannels_conv3x3,
                                               momentum=BN_MOMENTUM),
                                nn.ReLU(False)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        return self.num_inchannels

    def forward(self, x):
        if self.num_branches == 1:
            <span
style="line-height:1.6 !important;">return</span> [self.branches[<span style="line-height:1.6 !important;">0</span>](x[<span style="line-height:1.6 !important;">0</span>])]<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/>        <span style="line-height:1.6 !important;">for</span> i <span style="line-height:1.6 !important;">in</span> range(self.num_branches):<br style="line-height: 1.6 !important;"/>            x[i]=self.branches[i](x[i])<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/>        x_fuse=[]<br style="line-height: 1.6 !important;"/>        <span style="line-height:1.6 !important;">for</span> i <span style="line-height:1.6 !important;">in</span> range(len(self.fuse_layers)):<br style="line-height: 1.6 !important;"/>            y=x[<span style="line-height:1.6 !important;">0</span>] <span style="line-height:1.6 !important;">if</span> i == <span style="line-height:1.6 !important;">0</span> <span style="line-height:1.6 !important;">else</span> self.fuse_layers[i][<span style="line-height:1.6 !important;">0</span>](x[<span style="line-height:1.6 !important;">0</span>])<br style="line-height: 1.6 !important;"/>            <span style="line-height:1.6 !important;">for</span> j <span style="line-height:1.6 !important;">in</span> range(<span style="line-height:1.6 !important;">1</span>, self.num_branches):<br style="line-height: 1.6 !important;"/>                <span style="line-height:1.6 !important;">if</span> i == j:<br style="line-height: 1.6 !important;"/>                    y=y + x[j]<br style="line-height: 1.6 !important;"/>                <span style="line-height:1.6 !important;">else</span>:<br style="line-height: 1.6 !important;"/>                    y=y + self.fuse_layers[i][j](x[j])<br style="line-height: 1.6 !important;"/>            x_fuse.append(self.relu(y))<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/>        <span style="line-height:1.6 
!important;font-style:italic;"># Collect the fused outputs of all branches in a list</span><br style="line-height: 1.6 !important;"/>        <span style="line-height:1.6 !important;">return</span> x_fuse<br style="line-height: 1.6 !important;"/></pre>
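The wiring that `_make_fuse_layers` builds can be summarised without PyTorch. The sketch below is a framework-free illustration, not the repository code. `fuse_plan` is a hypothetical helper that returns text descriptions of the operations rather than real layers. The `j > i` case is assumed to be a 1x1 conv + BN followed by nearest-neighbour upsampling, as in the official HRNet repository.

```python
def fuse_plan(num_inchannels):
    """Describe the ops that map input branch j to output branch i.

    Branch 0 has the highest resolution; branch k is 2**k times smaller.
    Returns plan[i][j], a list of human-readable op descriptions.
    """
    n = len(num_inchannels)
    plan = []
    for i in range(n):
        row = []
        for j in range(n):
            c_in, c_out = num_inchannels[j], num_inchannels[i]
            if j == i:
                # Same resolution: the branch output is added unchanged.
                row.append(["identity"])
            elif j > i:
                # Lower-resolution input (assumed): 1x1 conv to match channels,
                # then nearest-neighbour upsampling by 2**(j - i).
                row.append([f"conv1x1 {c_in}->{c_out} + BN",
                            f"upsample x{2 ** (j - i)} (nearest)"])
            else:
                # Higher-resolution input: one stride-2 3x3 conv per level.
                ops = []
                for k in range(i - j):
                    if k == i - j - 1:
                        # Last conv switches to the target width; no ReLU.
                        ops.append(f"conv3x3/s2 {c_in}->{c_out} + BN")
                    else:
                        # Intermediate convs keep the input width and add ReLU.
                        ops.append(f"conv3x3/s2 {c_in}->{c_in} + BN + ReLU")
                row.append(ops)
        plan.append(row)
    return plan


if __name__ == "__main__":
    # Three branches, e.g. the stage-3 widths [32, 64, 128].
    for i, row in enumerate(fuse_plan([32, 64, 128])):
        for j, ops in enumerate(row):
            print(f"branch {j} -> branch {i}: " + " | ".join(ops))
```

Each output branch `i` then sums the transformed inputs over all `j` and applies ReLU, which is what the `forward` loop of `HighResolutionModule` does.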
    The default parameters are stored in models.py. Changing them adjusts the model's capacity, the number of branches and the feature-fusion method:
    <pre class="prettyprint"><span style="line-height:1.6 !important;font-style:italic;"># high_resoluton_net related params for classification</span><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET = CN()<br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.PRETRAINED_LAYERS = [<span style="line-height:1.6 !important;">'*'</span>]<br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STEM_INPLANES = <span style="line-height:1.6 !important;">64</span><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.FINAL_CONV_KERNEL = <span style="line-height:1.6 !important;">1</span><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.WITH_HEAD = <span style="line-height:1.6 !important;">True</span><br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE2 = CN()<br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_MODULES = <span style="line-height:1.6 !important;">1</span><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BRANCHES = <span style="line-height:1.6 !important;">2</span><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BLOCKS = [<span style="line-height:1.6 !important;">4</span>, <span style="line-height:1.6 !important;">4</span>]<br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_CHANNELS = [<span style="line-height:1.6 !important;">32</span>, <span style="line-height:1.6 !important;">64</span>]<br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE2.BLOCK = <span style="line-height:1.6 !important;">'BASIC'</span><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE2.FUSE_METHOD = <span style="line-height:1.6 !important;">'SUM'</span><br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE3 = CN()<br style="line-height: 1.6 
!important;"/>POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_MODULES = <span style="line-height:1.6 !important;">1</span><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BRANCHES = <span style="line-height:1.6 !important;">3</span><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BLOCKS = [<span style="line-height:1.6 !important;">4</span>, <span style="line-height:1.6 !important;">4</span>, <span style="line-height:1.6 !important;">4</span>]<br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_CHANNELS = [<span style="line-height:1.6 !important;">32</span>, <span style="line-height:1.6 !important;">64</span>, <span style="line-height:1.6 !important;">128</span>]<br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE3.BLOCK = <span style="line-height:1.6 !important;">'BASIC'</span><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE3.FUSE_METHOD = <span style="line-height:1.6 !important;">'SUM'</span><br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE4 = CN()<br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_MODULES = <span style="line-height:1.6 !important;">1</span><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BRANCHES = <span style="line-height:1.6 !important;">4</span><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BLOCKS = [<span style="line-height:1.6 !important;">4</span>, <span style="line-height:1.6 !important;">4</span>, <span style="line-height:1.6 !important;">4</span>, <span style="line-height:1.6 !important;">4</span>]<br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_CHANNELS = [<span style="line-height:1.6 !important;">32</span>, <span style="line-height:1.6 !important;">64</span>, <span style="line-height:1.6 !important;">128</span>, <span style="line-height:1.6 
!important;">256</span>]<br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE4.BLOCK = <span style="line-height:1.6 !important;">'BASIC'</span><br style="line-height: 1.6 !important;"/>POSE_HIGH_RESOLUTION_NET.STAGE4.FUSE_METHOD = <span style="line-height:1.6 !important;">'SUM'</span><br style="line-height: 1.6 !important;"/></pre>
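To see what these numbers mean, the sketch below turns a stage config into per-branch feature-map shapes. It rests on several assumptions:

- The config is copied into a plain dict instead of a yacs `CN` node.
- The stem downsamples the input by 4x (two stride-2 convs).
- Each additional branch halves the resolution again.

`branch_shapes` is a hypothetical helper.

```python
# Plain-dict copy of POSE_HIGH_RESOLUTION_NET.STAGE4 from the config above.
STAGE4 = {
    "NUM_BRANCHES": 4,
    "NUM_CHANNELS": [32, 64, 128, 256],
}


def branch_shapes(stage_cfg, height, width):
    """Return one (channels, height, width) tuple per parallel branch.

    Assumes the stem downsamples by 4 and every extra branch halves the
    resolution again; exact for inputs divisible by 32.
    """
    if len(stage_cfg["NUM_CHANNELS"]) != stage_cfg["NUM_BRANCHES"]:
        raise ValueError("NUM_CHANNELS needs one entry per branch")
    shapes = []
    for k, channels in enumerate(stage_cfg["NUM_CHANNELS"]):
        factor = 4 * 2 ** k  # 4x from the stem, then 2x per extra branch
        shapes.append((channels, height // factor, width // factor))
    return shapes


print(branch_shapes(STAGE4, 256, 192))
# [(32, 64, 48), (64, 32, 24), (128, 16, 12), (256, 8, 6)]
```

Swapping in `[48, 96, 192, 384]` gives the channel layout of the wider HRNet-W48. The "W" number is the width of the highest-resolution branch.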
    Next comes the construction of the full HRNet model. The complete code is too long to include, so only the forward function is shown here.
    <pre class="prettyprint"><span style="line-height: 1.6 !important;"><span style="line-height:1.6 !important;">def</span> <span style="line-height:1.6 !important;">forward</span><span style="line-height: 1.6 !important;">(self, x)</span>:</span><br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/>    <span style="line-height:1.6 !important;font-style:italic;"># Stem: two stride-2 3x3 convs quickly reduce the resolution to 1/4 of the input</span><br style="line-height: 1.6 !important;"/>    x=self.relu(self.bn1(self.conv1(x)))<br style="line-height: 1.6 !important;"/>    x=self.relu(self.bn2(self.conv2(x)))<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/>    <span style="line-height:1.6 !important;font-style:italic;"># layer1: a stack of residual blocks (Bottleneck blocks in the official configs)</span><br style="line-height: 1.6 !important;"/>    x=self.layer1(x)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/>    <span style="line-height:1.6 !important;font-style:italic;"># Then come several stages; each stage is built around HighResolutionModule.</span><br style="line-height: 1.6 !important;"/>    <span style="line-height:1.6 !important;font-style:italic;"># transitionN creates the new lower-resolution branch (and adapts channels where needed)</span><br style="line-height: 1.6 !important;"/>    x_list=[]<br style="line-height: 1.6 !important;"/>    <span style="line-height:1.6 !important;">for</span> i <span style="line-height:1.6 !important;">in</span> range(self.stage2_cfg[<span style="line-height:1.6 !important;">'NUM_BRANCHES'</span>]):<br style="line-height: 1.6 !important;"/>        <span style="line-height:1.6 !important;">if</span> self.transition1[i] <span style="line-height:1.6 !important;">is</span> <span style="line-height:1.6 !important;">not</span> <span style="line-height:1.6 !important;">None</span>:<br style="line-height: 1.6 !important;"/>            x_list.append(self.transition1[i](x))<br style="line-height: 1.6 !important;"/>        <span style="line-height:1.6 !important;">else</span>:<br style="line-height: 1.6 !important;"/>            x_list.append(x)<br style="line-height: 1.6 !important;"/>    y_list=self.stage2(x_list)<br style="line-height: 1.6 !important;"/><br 
style="line-height: 1.6 !important;"/>    x_list=[]<br style="line-height: 1.6 !important;"/>    <span style="line-height:1.6 !important;">for</span> i <span style="line-height:1.6 !important;">in</span> range(self.stage3_cfg[<span style="line-height:1.6 !important;">'NUM_BRANCHES'</span>]):<br style="line-height: 1.6 !important;"/>        <span style="line-height:1.6 !important;">if</span> self.transition2[i] <span style="line-height:1.6 !important;">is</span> <span style="line-height:1.6 !important;">not</span> <span style="line-height:1.6 !important;">None</span>:<br style="line-height: 1.6 !important;"/>            x_list.append(self.transition2[i](y_list[<span style="line-height:1.6 !important;">-1</span>]))<br style="line-height: 1.6 !important;"/>        <span style="line-height:1.6 !important;">else</span>:<br style="line-height: 1.6 !important;"/>            x_list.append(y_list[i])<br style="line-height: 1.6 !important;"/>    y_list=self.stage3(x_list)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/>    x_list=[]<br style="line-height: 1.6 !important;"/>    <span style="line-height:1.6 !important;">for</span> i <span style="line-height:1.6 !important;">in</span> range(self.stage4_cfg[<span style="line-height:1.6 !important;">'NUM_BRANCHES'</span>]):<br style="line-height: 1.6 !important;"/>        <span style="line-height:1.6 !important;">if</span> self.transition3[i] <span style="line-height:1.6 !important;">is</span> <span style="line-height:1.6 !important;">not</span> <span style="line-height:1.6 !important;">None</span>:<br style="line-height: 1.6 !important;"/>            x_list.append(self.transition3[i](y_list[<span style="line-height:1.6 !important;">-1</span>]))<br style="line-height: 1.6 !important;"/>        <span style="line-height:1.6 !important;">else</span>:<br style="line-height: 1.6 !important;"/>            x_list.append(y_list[i])<br style="line-height: 1.6 !important;"/>    
y_list=self.stage4(x_list)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/>    <span style="line-height:1.6 !important;font-style:italic;"># Classification head (shown earlier), used for the classification task;</span><br style="line-height: 1.6 !important;"/>    <span style="line-height:1.6 !important;font-style:italic;"># other tasks swap in a different head</span><br style="line-height: 1.6 !important;"/>    y=self.incre_modules[<span style="line-height:1.6 !important;">0</span>](y_list[<span style="line-height:1.6 !important;">0</span>])<br style="line-height: 1.6 !important;"/>    <span style="line-height:1.6 !important;">for</span> i <span style="line-height:1.6 !important;">in</span> range(len(self.downsamp_modules)):<br style="line-height: 1.6 !important;"/>        y=self.incre_modules[i+<span style="line-height:1.6 !important;">1</span>](y_list[i+<span style="line-height:1.6 !important;">1</span>]) + \<br style="line-height: 1.6 !important;"/>            self.downsamp_modules[i](y)<br style="line-height: 1.6 !important;"/>    y=self.final_layer(y)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/>    <span style="line-height:1.6 !important;">if</span> torch._C._get_tracing_state():<br style="line-height: 1.6 !important;"/>        <span style="line-height:1.6 !important;font-style:italic;"># Both branches do global average pooling; under JIT tracing (e.g. ONNX export)</span><br style="line-height: 1.6 !important;"/>        <span style="line-height:1.6 !important;font-style:italic;"># flatten + mean avoids recording the feature-map size as a constant kernel_size</span><br style="line-height: 1.6 !important;"/>        y=y.flatten(start_dim=<span style="line-height:1.6 !important;">2</span>).mean(dim=<span style="line-height:1.6 !important;">2</span>)<br style="line-height: 1.6 !important;"/>    <span style="line-height:1.6 !important;">else</span>:<br style="line-height: 1.6 !important;"/>        y=F.avg_pool2d(y, kernel_size=y.size()<br style="line-height: 1.6 !important;"/>                            [<span style="line-height:1.6 !important;">2</span>:]).view(y.size(<span style="line-height:1.6 !important;">0</span>), <span style="line-height:1.6 !important;">-1</span>)<br style="line-height: 1.6 
!important;"/>    y=self.classifier(y)<br style="line-height: 1.6 !important;"/><br style="line-height: 1.6 !important;"/>    <span style="line-height:1.6 !important;">return</span> y<br style="line-height: 1.6 !important;"/></pre>
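The classification head in `forward` merges the four branches from high to low resolution. The shape-only sketch below traces that merge for a 224x224 input. The channel widths are assumptions taken from the official HRNet-Image-Classification code, not from this article:

- `incre_modules` are Bottleneck blocks of width 32/64/128/256 with expansion 4.
- `downsamp_modules` are stride-2 3x3 convs.
- `final_layer` is a 1x1 conv to 2048 channels.
- `classifier` maps 2048 features to 1000 classes.

`conv_out` and `head_shapes` are hypothetical helpers.

```python
def conv_out(size, kernel=3, stride=2, padding=1):
    """Output size of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def head_shapes(branch_hw, head_widths=(32, 64, 128, 256), expansion=4,
                final_channels=2048, num_classes=1000):
    """Trace (C, H, W) through the head, given each branch's (H, W)."""
    widths = [w * expansion for w in head_widths]  # 128, 256, 512, 1024
    h, w = branch_hw[0]
    y = (widths[0], h, w)  # incre_modules[0](y_list[0])
    trace = [y]
    for i in range(len(branch_hw) - 1):
        # downsamp_modules[i](y): stride-2 3x3 conv to the next width
        down = (widths[i + 1], conv_out(y[1]), conv_out(y[2]))
        # incre_modules[i+1](y_list[i+1]): widen branch i+1 in place
        incre = (widths[i + 1],) + tuple(branch_hw[i + 1])
        assert down == incre, "element-wise sum needs identical shapes"
        y = down
        trace.append(y)
    trace.append((final_channels, y[1], y[2]))  # final_layer (1x1 conv)
    trace.append((final_channels,))             # global average pooling
    trace.append((num_classes,))                # classifier (linear)
    return trace


# Stem: 224 -> 112 -> 56; transitions halve again: 56 -> 28 -> 14 -> 7.
sizes = [conv_out(conv_out(224))]
for _ in range(3):
    sizes.append(conv_out(sizes[-1]))
for shape in head_shapes([(s, s) for s in sizes]):
    print(shape)
```

The `assert` inside the loop checks the invariant that makes the `+` in `forward` legal. The downsampled running sum and the widened next branch must have the same shape.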
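    <h4>Summary</h4>
    The core idea of HRNet is to keep a high-resolution representation throughout the entire network while letting feature maps of different resolutions repeatedly exchange information.
    HRNet is widely used across many computer-vision tasks. For example, it played a part in the Amur tiger keypoint detection challenge at ICCV 2019. Moreover, the classification experiments show that, with a comparable number of parameters, HRNet can replace ResNet as a classification backbone.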
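I previously read an article by Zheng Ankun, "CNN Architecture Design Tips: Balancing Speed, Accuracy and Engineering Implementation", which makes this point:
SENet is a special case of HRNet; HRNet has not only channel attention but also spatial attention.
-- akkaze (Zheng Ankun)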
    <figure>
        <img src="https://img0.tuicool.com/jE7vuyB.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Core implementation of SELayer</figcaption>
    </figure>
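    SELayer first applies global average pooling to obtain a 1-D vector with one value per channel. It then passes this vector through two fully connected layers that compress and then expand the information. A sigmoid turns the result into a weight for each channel, and these weights are multiplied with the original feature map to recalibrate it channel by channel.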
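The three SELayer steps just described (squeeze, excitation, scale) can be written out without a framework. This is a minimal sketch, not the code in the figure. It omits biases from the two fully connected layers and uses hypothetical weight matrices: `w1` of shape C/r x C and `w2` of shape C x C/r, where r is the reduction ratio.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def matvec(matrix, vec):
    """Matrix (list of rows) times vector: a fully connected layer without bias."""
    return [sum(m * v for m, v in zip(row, vec)) for row in matrix]


def se_layer(x, w1, w2):
    """x: list of C channels, each an H x W nested list. Returns (output, weights)."""
    # Squeeze: global average pooling gives one value per channel.
    z = [sum(map(sum, ch)) / (len(ch) * len(ch[0])) for ch in x]
    # Excitation: FC (compress to C/r) -> ReLU -> FC (expand to C) -> sigmoid.
    hidden = [max(0.0, h) for h in matvec(w1, z)]
    weights = [sigmoid(s) for s in matvec(w2, hidden)]
    # Scale: every position of channel c is multiplied by the same weights[c].
    out = [[[v * wc for v in row] for row in ch] for ch, wc in zip(x, weights)]
    return out, weights


# Two 2x2 channels, reduction ratio r = 2 (so the hidden layer has 1 unit).
x = [[[1.0, 1.0], [1.0, 1.0]], [[2.0, 2.0], [2.0, 2.0]]]
out, weights = se_layer(x, w1=[[0.5, 0.25]], w2=[[2.0], [-2.0]])
print(weights)  # sigmoid(2) and sigmoid(-2), about [0.881, 0.119]
```

Because one scalar per channel multiplies every spatial position, the output keeps the input's shape and only the channel emphasis changes.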
    <figure>
        <img src="https://img2.tuicool.com/j6ziErV.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>A building block of HRNet</figcaption>
    </figure>
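    Notice that the path linked by red arrows in the figure above closely resembles an SELayer. Why can SENet be called a special case of HRNet? Judging from this structure alone, it can be seen as follows: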
    <ul>
        <li>
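            <p class="no-text-indent">Instead of halving the resolution as HRNet does, SENet reduces it directly to 1x1, which is rather extreme. After reducing to a 1x1 vector, SENet uses two fully connected layers to learn the distribution of channel features, whereas HRNet uses several convolutions (residual blocks) to learn features.</p>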
        </li>
        <li>
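            <p class="no-text-indent">SENet places no convolutions on its main path (the high-resolution branch) to learn features, whereas HRNet places several convolutions (residual blocks) on its main path (the high-resolution branch).</p>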
        </li>
        <li>
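            <p class="no-text-indent">The feature-fusion step differs considerably. SENet multiplies each channel by its weight, whereas HRNet adds the features. SENet counts as a channel-attention mechanism because global average pooling discards all spatial information, leaving only per-channel features. HRNet, by contrast, can be seen as preserving both spatial and channel features, which is why it is said to have spatial attention as well as channel attention.</p>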
        </li>
    </ul>
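The third point can be made concrete with a toy example. This is a deliberate simplification, not HRNet's actual fusion code. It omits the 1x1 conv and BN that HRNet applies before upsampling, and `hrnet_style_fuse` and `se_style_fuse` are hypothetical names. The sketch fuses a 4x4 high-resolution map with a 2x2 low-resolution map in both ways. SE-style fusion pools the low-resolution map to a single number and applies the same factor everywhere. HRNet-style fusion upsamples and adds it, so the contribution differs from position to position.

```python
import math


def upsample_nearest(m, factor):
    """Nearest-neighbour upsampling of an H x W nested list by an integer factor."""
    return [[m[r // factor][c // factor] for c in range(len(m[0]) * factor)]
            for r in range(len(m) * factor)]


def hrnet_style_fuse(high, low):
    """Upsample the low-resolution map and add it: varies with position."""
    up = upsample_nearest(low, len(high) // len(low))
    return [[h + u for h, u in zip(h_row, u_row)] for h_row, u_row in zip(high, up)]


def se_style_fuse(high, low):
    """Pool low to 1x1, apply sigmoid, multiply: one factor for every position."""
    mean = sum(map(sum, low)) / (len(low) * len(low[0]))
    weight = 1.0 / (1.0 + math.exp(-mean))
    return [[h * weight for h in row] for row in high]


high = [[1.0] * 4 for _ in range(4)]
low = [[1.0, 0.0], [0.0, 1.0]]
for row in hrnet_style_fuse(high, low):
    print(row)  # 2.0 in the top-left and bottom-right 2x2 blocks, 1.0 elsewhere
for row in se_style_fuse(high, low):
    print(row)  # every entry is sigmoid(0.5), about 0.622
```

The additive path preserves where the low-resolution signal came from, while the pooled multiplicative path keeps only how strong it was on average.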
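    The HRNet team has as many as ten members. They built codebases for classification, segmentation, detection, keypoint detection and more, a very large amount of work, and ran many solid experiments showing that the approach is effective. So can HRNet be regarded as another, better backbone after SENet? That still has to be verified by applying these ideas in your own practice.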
    <h4>References</h4>
    https://arxiv.org/pdf/1908.07919
    https://www.bilibili.com/video/BV1WJ41197dh?t=508
    https://github.com/HRNet
</div>
    <span>
Related articles