<meta http-equiv="Content-Type" content="text/html; charset=UTF-8"/><div>
<div>
    <h2>
        <img src="https://img1.tuicool.com/vee6RnZ.jpg!web" class="alignCenter" referrerpolicy="no-referrer"/>
    </h2>
    <h3>
        <div>


                    <span>Shared by 磐创AI (PanChuang AI)</span>

        </div>
    </h3>
    Source | GiantPandaCV
    Editor | pprp
    <h4>
        <span>
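            <strong>[Overview]</strong>
        </span>
        <span>HRNet was developed by a team led by Jingdong Wang at Microsoft Research Asia. It serves as a common backbone across image classification, image segmentation, object detection, face alignment, pose estimation, style transfer, image inpainting, super-resolution, optical flow, depth estimation, edge detection, and other tasks.</span>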
    </h4>
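    Jingdong Wang presented HRNet himself in the VALSE Webinar talk "Object and Keypoint Detection" and explained it very thoroughly. This article draws mainly on his explanation in that talk, together with the paper and the code, to introduce HRNet, an all-purpose backbone.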
    Introduction
    <figure>
        <img src="https://img0.tuicool.com/feURFb2.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Network design rationale</figcaption>
    </figure>
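    Tasks such as human pose estimation need a high-resolution heatmap for keypoint detection. This differs from what typical networks such as VGGNet provide: their final feature maps have very low resolution and have lost much of the spatial structure.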
    <figure>
        <img src="https://img1.tuicool.com/rayemqu.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>The conventional approach</figcaption>
    </figure>
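    As shown above, most methods obtain a high-resolution output by first reducing the resolution and then increasing it again. U-Net, SegNet, DeconvNet, and Hourglass all share this structure at heart.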
    <figure>
        <img src="https://img2.tuicool.com/i6Z3Qv3.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Different in appearance, essentially the same</figcaption>
    </figure>
    <h4>Core idea</h4>
    Typical networks follow the structure below, in which the different resolutions are connected in series:
    <figure>
        <img src="https://img1.tuicool.com/EfiABju.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Progressively reducing resolution</figcaption>
    </figure>
    Jingdong Wang's design instead connects feature maps of different resolutions in parallel:
    <figure>
        <img src="https://img0.tuicool.com/JRv6Jr2.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Feature maps of different resolutions in parallel</figcaption>
    </figure>
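    To make the serial/parallel contrast concrete, here is a minimal pure-Python sketch that only tracks which resolutions exist at each stage. It assumes the four-stage layout described in the HRNet paper, where a stem at 1/4 of the input resolution feeds stage 1 and every later stage adds one branch at half the previous lowest resolution; the serial network is given the same stem purely for comparison, and both function names are hypothetical.

```python
from fractions import Fraction


def serial_resolutions(num_stages, stem=Fraction(1, 4)):
    """Serial design: a single resolution per stage, halved at every stage."""
    out = []
    res = stem  # assumed 1/4-resolution stem, shared with the parallel design
    for _ in range(num_stages):
        out.append([res])
        res /= 2
    return out


def parallel_resolutions(num_stages, stem=Fraction(1, 4)):
    """HRNet-style design: stage s keeps every earlier branch and adds one lower-resolution branch."""
    return [[stem / 2 ** b for b in range(s + 1)] for s in range(num_stages)]


for s, (ser, par) in enumerate(zip(serial_resolutions(4), parallel_resolutions(4)), start=1):
    print(f"stage {s}: serial {[str(r) for r in ser]}  parallel {[str(r) for r in par]}")
```

    In the serial case the 1/4 resolution is gone after stage 1, whereas in the parallel case it is the first entry of every stage, which is the property the figure illustrates.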
    On top of the parallel connections, interaction (fusion) between feature maps of different resolutions is added.
    <figure>
        <img src="https://img2.tuicool.com/j2IZZzq.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
    </figure>
    The fusion works as shown below:
    <figure>
        <img src="https://img0.tuicool.com/Qv6VZjE.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
    </figure>
    <ul>
        <li>
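            <p class="no-text-indent">An input at the same resolution is passed through directly.</p>
        </li>
        <li>
            <p class="no-text-indent">A lower-resolution input is upsampled with bilinear upsampling plus a 1x1 convolution that matches the channel count (the classification code analyzed later applies the 1x1 convolution first, followed by nearest-neighbor upsampling).</p>
        </li>
        <li>
            <p class="no-text-indent">A higher-resolution input is downsampled with strided 3x3 convolutions.</p>
        </li>
        <li>
            <p class="no-text-indent">The three transformed feature maps are fused by element-wise summation.</p>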
        </li>
    </ul>
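    The four rules can be checked by hand on a toy example. The sketch below is a simplification under stated assumptions: each branch has a single channel (so a 1x1 convolution reduces to multiplying by one scalar weight), batch normalization and ReLU are omitted, the 3x3 kernel is all ones for readability, and upsampling uses nearest-neighbor repetition rather than bilinear interpolation. All function names are hypothetical.

```python
def upsample_nearest(grid, factor):
    """Nearest-neighbor upsampling: repeat every pixel factor x factor times."""
    out = []
    for row in grid:
        wide = [v for v in row for _ in range(factor)]
        out.extend(list(wide) for _ in range(factor))
    return out


def conv3x3_stride2(grid, kernel):
    """Single-channel 3x3 convolution with stride 2 and zero padding 1 (no bias)."""
    h, w = len(grid), len(grid[0])
    out_h, out_w = (h + 2 - 3) // 2 + 1, (w + 2 - 3) // 2 + 1
    out = []
    for oy in range(out_h):
        row = []
        for ox in range(out_w):
            acc = 0
            for ky in range(3):
                for kx in range(3):
                    y, x = 2 * oy + ky - 1, 2 * ox + kx - 1  # shift by the padding
                    if 0 <= y < h and 0 <= x < w:  # zero padding outside the map
                        acc += kernel[ky][kx] * grid[y][x]
            row.append(acc)
        out.append(row)
    return out


def add(a, b):
    """Element-wise sum: the fusion operator."""
    return [[p + q for p, q in zip(ra, rb)] for ra, rb in zip(a, b)]


high = [[1] * 4 for _ in range(4)]    # toy 4x4 high-resolution branch (one channel)
low = [[1, 2], [3, 4]]                # toy 2x2 low-resolution branch (one channel)
w_1x1 = 2                             # a one-channel "1x1 conv" is just a scalar weight
kernel = [[1] * 3 for _ in range(3)]  # all-ones 3x3 kernel, chosen for readability

# High-resolution output: identity + (1x1 conv, then upsample x2) of the low branch.
fused_high = add(high, upsample_nearest([[w_1x1 * v for v in row] for row in low], 2))
# Low-resolution output: strided 3x3 conv of the high branch + identity.
fused_low = add(conv3x3_stride2(high, kernel), low)

print(fused_high)  # [[3, 3, 5, 5], [3, 3, 5, 5], [7, 7, 9, 9], [7, 7, 9, 9]]
print(fused_low)   # [[5, 8], [9, 13]]
```

    Each output branch receives a contribution from every input branch, brought to its own resolution and channel count before the sum.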
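    Why strided 3x3 convolutions? Downsampling inevitably loses information, and a strided 3x3 convolution learns how to downsample, which reduces that loss. This is why max pooling or combined pooling is not used here.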
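    A small arithmetic sketch of this trade-off, assuming the standard output-size formula for a convolution or pooling window, floor((H + 2p - k) / s) + 1. A stride-2 3x3 convolution with padding 1 and a 2x2 max pool both halve an even-sized map, but only the convolution has weights that training can adjust. The helper names and the 32-to-64-channel example are illustrative assumptions.

```python
def out_size(size, kernel, stride, padding):
    """Spatial output size of a convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


def strided_conv_weights(c_in, c_out, kernel=3):
    """Learnable weights of a bias-free kernel x kernel convolution."""
    return kernel * kernel * c_in * c_out


# Both operators halve a 64x64 map...
print(out_size(64, kernel=3, stride=2, padding=1))  # 32 (strided 3x3 conv)
print(out_size(64, kernel=2, stride=2, padding=0))  # 32 (2x2 max pool)
# ...but max pooling has no parameters, while the convolution learns how to downsample.
print(strided_conv_weights(32, 64))                 # 18432
# Crossing two resolution levels chains two strided convolutions: 64 -> 32 -> 16.
print(out_size(out_size(64, 3, 2, 1), 3, 2, 1))     # 16
```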
    <figure>
        <img src="https://img2.tuicool.com/buINryY.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>HRNet schematic</figcaption>
    </figure>
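    A question that comes up when reading about HRNet: with four branches, how are they actually used at the end? The paper gives several ways of selecting the final features.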
    <figure>
        <img src="https://img1.tuicool.com/EJ3U7vb.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Three feature fusion methods</figcaption>
    </figure>
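    (a) HRNetV1: only the highest-resolution feature map is used.
    (b) HRNetV2: the feature maps of all resolutions are concatenated (the smaller ones are upsampled first). This is mainly used for semantic segmentation and facial landmark detection.
    (c) HRNetV2p: a feature pyramid is built on top of the HRNetV2 output. This is mainly used for object detection networks.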
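    The difference between (a) and (b) is easiest to see from tensor shapes. The sketch below only does the bookkeeping, assuming HRNet-W48 branch widths of 48, 96, 192, and 384 channels at 1/4, 1/8, 1/16, and 1/32 of the input resolution, and a hypothetical 512x1024 input; the function names are illustrative and do not come from the HRNet code.

```python
def branch_shapes(c, img_h, img_w):
    """(channels, H, W) of the four branches: widths C, 2C, 4C, 8C at 1/4 ... 1/32 resolution."""
    return [(c * 2 ** b, img_h // (4 * 2 ** b), img_w // (4 * 2 ** b)) for b in range(4)]


def hrnet_v1_output(shapes):
    """(a) HRNetV1: keep only the highest-resolution branch."""
    return shapes[0]


def hrnet_v2_output(shapes):
    """(b) HRNetV2: upsample every branch to the highest resolution, then concatenate channels."""
    _, h, w = shapes[0]
    return (sum(c for c, _, _ in shapes), h, w)


shapes = branch_shapes(48, 512, 1024)
print(shapes)                   # [(48, 128, 256), (96, 64, 128), (192, 32, 64), (384, 16, 32)]
print(hrnet_v1_output(shapes))  # (48, 128, 256)
print(hrnet_v2_output(shapes))  # (720, 128, 256)
```

    The concatenation keeps the 1/4 resolution of (a) while carrying all four branches' channels, which is why (b) suits dense prediction tasks.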
    One more variant, figure (d):
    <figure>
        <img src="https://img2.tuicool.com/a6N3eyq.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Feature selection in the HRNetV2 classification network</figcaption>
    </figure>
    (d) is also HRNetV2, using the fusion shown above; it is mainly used for training the classification network.
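    To summarize, the main innovations of HRNet are:
    <ul>
        <li>
            <p class="no-text-indent">The connections between high and low resolutions are changed from serial to parallel.</p>
        </li>
        <li>
            <p class="no-text-indent">A high-resolution representation is maintained throughout the entire network (the topmost path).</p>
        </li>
        <li>
            <p class="no-text-indent">Interaction between high and low resolutions is introduced to improve model performance.</p>
        </li>
    </ul>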

    <h4>Results</h4>
    Ablation studies
    <ol>
        <li>
            <p class="no-text-indent">An ablation on the interaction method shows that the proposed cross-resolution fusion is effective.</p>
        </li>
    </ol>
    <figure>
        <img src="https://img2.tuicool.com/Y3QvA3R.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Ablation study of the interaction method</figcaption>
    </figure>
    <ol start="2">
        <li>
            <p class="no-text-indent">Demonstrating the representational power of the high-resolution feature map</p>
        </li>
    </ol>
    <figure>
        <img src="https://img2.tuicool.com/7baMviM.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
    </figure>
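    1x means no downsampling, 2x means the resolution is halved, and 4x means it is reduced to a quarter. In W32 and W48, the numbers 32 and 48 denote the width, i.e., the number of channels, of the convolutions.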
    <h4>Performance on pose estimation</h4>
    <figure>
        <img src="https://img2.tuicool.com/iamIz2U.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
    </figure>
    The pose estimation results above use the top-down approach.
    <figure>
        <img src="https://img2.tuicool.com/Qb2EfqB.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Results on the COCO validation set</figcaption>
    </figure>
    Without increasing parameters or computation, HRNet performs clearly better than comparable networks.
    <figure>
        <img src="https://img0.tuicool.com/JR3UbiZ.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Results on the COCO test set</figcaption>
    </figure>
    On the PoseTrack Leaderboard as of February 28, 2019, HRNet ranked first in two tracks.
    <figure>
        <img src="https://img0.tuicool.com/Nf26vyZ.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>PoseTrack Leaderboard</figcaption>
    </figure>
    <div>
        <h4>Performance on semantic segmentation</h4>
        <p style="text-align:center">
            <img src="https://img0.tuicool.com/RZFJvuJ.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        </p>
    </div>
    <figure>
        <figcaption>Comparison on the Cityscapes validation set</figcaption>
    </figure>
    <figure>
        <img src="https://img1.tuicool.com/vUBv2qm.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Comparison on the Cityscapes test set</figcaption>
    </figure>
    <h4>Performance on object detection</h4>
    <figure>
        <img src="https://img1.tuicool.com/JBNf2un.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Single-model, single-scale comparison</figcaption>
    </figure>
    <figure>
        <img src="https://img0.tuicool.com/e2a26fF.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Results with Mask R-CNN</figcaption>
    </figure>
    <h4>Performance on image classification</h4>
    <figure>
        <img src="https://img0.tuicool.com/nINfEry.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
    </figure>
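    P.S. Jingdong Wang noted here that segmentation networks also need to be initialized from the classification-pretrained model; otherwise the results drop by several points.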
    <figure>
        <img src="https://img2.tuicool.com/FN3Y3mV.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Comparison with ResNet on image classification</figcaption>
    </figure>
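    The figure above compares HRNet with ResNet; models of the same color have roughly the same number of parameters. With a similar or even smaller parameter count, HRNet achieves better results than ResNet.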
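    The HRNet project ( https://github.com/HRNet ) represents a large amount of work: it provides six repositories covering semantic segmentation, human pose estimation, object detection, image classification, facial landmark detection, and Mask R-CNN. The full set is shown below: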
    <figure>
        <img src="https://img0.tuicool.com/ru6nErZ.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
    </figure>
    I was particularly interested in how the HRNet code is built, so I will take the HRNet-Image-Classification repository as an example and walk through its code.
    Let's start with the simplest part, BasicBlock:
    <figure>
        <img src="https://img0.tuicool.com/AB3amym.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>BasicBlock structure</figcaption>
    </figure>
    <pre class="prettyprint">import logging

import torch.nn as nn

# Module-level names used by the snippets in this article (defined near the top of
# the HRNet model file): the BatchNorm momentum and a module logger.
BN_MOMENTUM = 0.1
logger = logging.getLogger(__name__)


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1  # output channels = planes * expansion

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)  # downsamples when stride=2
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample  # optional projection on the residual path
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            # project x so its shape matches out when stride or channels change
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
</pre>
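    Why does BasicBlock accept a `downsample` module? The residual sum `out += residual` only works when `x` and `out` have identical shapes. The shape-only sketch below assumes the `conv3x3` padding of 1 shown above; `conv3x3_out` and `basic_block_shapes` are hypothetical helpers, not part of the HRNet code.

```python
def conv3x3_out(size, stride):
    """Spatial size after conv3x3 (kernel 3, padding 1) with the given stride."""
    return (size + 2 * 1 - 3) // stride + 1


def basic_block_shapes(inplanes, planes, size, stride=1):
    """Return (input shape, output shape, needs_downsample) for one BasicBlock (expansion = 1)."""
    out_size = conv3x3_out(conv3x3_out(size, stride), 1)  # conv1 may stride, conv2 never does
    in_shape = (inplanes, size, size)
    out_shape = (planes * 1, out_size, out_size)
    # Without a projection, `out += residual` would add tensors of different shapes.
    return in_shape, out_shape, in_shape != out_shape


print(basic_block_shapes(32, 32, 64))            # ((32, 64, 64), (32, 64, 64), False)
print(basic_block_shapes(32, 64, 64, stride=2))  # ((32, 64, 64), (64, 32, 32), True)
```

    The projection is needed exactly when the stride is not 1 or the channel count changes, which is also the condition checked in `_make_one_branch` of HighResolutionModule.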
    Bottleneck:
    <figure>
        <img src="https://img1.tuicool.com/367B3um.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Bottleneck structure</figcaption>
    </figure>
    <pre class="prettyprint">class Bottleneck(nn.Module):
    expansion = 4  # the last 1x1 conv expands channels to planes * 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # 1x1 conv: reduce channels to planes
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        # 3x3 conv on the reduced channels (downsamples when stride=2)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        # 1x1 conv: expand channels to planes * expansion
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
                               bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion,
                                  momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
</pre>
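    One rough way to compare the two blocks is to count their convolution weights (every convolution above uses `bias=False`; BatchNorm parameters and any `downsample` projection are ignored). The sketch assumes the layer shapes of the two classes above and compares blocks that both keep 256 channels in and out; the helper names are hypothetical.

```python
def basic_block_weights(inplanes, planes):
    """conv1 (3x3, inplanes -> planes) + conv2 (3x3, planes -> planes); expansion = 1."""
    return 9 * inplanes * planes + 9 * planes * planes


def bottleneck_weights(inplanes, planes, expansion=4):
    """conv1 (1x1) + conv2 (3x3) + conv3 (1x1, planes -> planes * expansion)."""
    return inplanes * planes + 9 * planes * planes + planes * planes * expansion


# Both map 256 channels to 256 channels, but the bottleneck runs its 3x3 conv on only 64.
print(basic_block_weights(256, 256))  # 1179648
print(bottleneck_weights(256, 64))    # 69632
```

    The 1x1 reduce/expand pair is what lets Bottleneck work with wide feature maps cheaply, at the cost of one extra layer per block.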
    HighResolutionModule is the core module. It consists of two main components: the branches and the fuse layers.
    <figure>
        <img src="https://img1.tuicool.com/ZfeU3ej.jpg!web" class="alignCenter" referrerpolicy="no-referrer"/>
    </figure>
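    Before reading the code, it may help to see which operation `_make_fuse_layers` creates for each pair of output branch `i` and input branch `j`. The pure-Python sketch below reproduces only that decision logic as strings, assuming three branches with the HRNet-W32 widths 32, 64, and 128; it ignores the `multi_scale_output` flag (which, when False, keeps only output branch 0), and `fuse_plan` is a hypothetical helper, not part of the repository.

```python
def fuse_plan(num_inchannels):
    """Describe the op mapping input branch j onto output branch i, mirroring _make_fuse_layers."""
    n = len(num_inchannels)
    plan = []
    for i in range(n):  # output branch
        row = []
        for j in range(n):  # input branch
            if j > i:  # lower resolution: 1x1 conv to branch i's width, BN, upsample
                row.append(f"conv1x1({num_inchannels[j]}->{num_inchannels[i]})+BN+up x{2 ** (j - i)}")
            elif j == i:  # same branch: identity (stored as None in the real code)
                row.append(None)
            else:  # higher resolution: (i - j) strided 3x3 convs
                steps = []
                for k in range(i - j):
                    if k == i - j - 1:  # last conv switches to branch i's width, no ReLU
                        steps.append(f"conv3x3s2({num_inchannels[j]}->{num_inchannels[i]})+BN")
                    else:  # intermediate convs keep branch j's width and end with ReLU
                        steps.append(f"conv3x3s2({num_inchannels[j]}->{num_inchannels[j]})+BN+ReLU")
                row.append(" | ".join(steps))
        plan.append(row)
    return plan


for i, row in enumerate(fuse_plan([32, 64, 128])):
    print(i, row)
```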
    <pre class="prettyprint">class HighResolutionModule(nn.Module):
    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
                 num_channels, fuse_method, multi_scale_output=True):
        '''
        Usage: build the high/low-resolution interaction module (stage 2 as an example)
        HighResolutionModule(num_branches,  # 2
                             block,  # 'BASIC'
                             num_blocks,  # [4, 4]
                             num_inchannels,  # output channels of the previous stage
                             num_channels,  # [32, 64]
                             fuse_method,  # SUM
                             reset_multi_scale_output)
        '''
        super(HighResolutionModule, self).__init__()
        self._check_branches(
            # check that the number of branches matches the other arguments
            num_branches, blocks, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        # fusion is done by summation
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        # the two core parts: building the branches and building the fuse layers
        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()

        self.relu = nn.ReLU(False)

    def _check_branches(self, num_branches, blocks, num_blocks,
                        num_inchannels, num_channels):
        # check each argument; judging by the parameters in models.py, `blocks` is redundant here
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) &lt;&gt; NUM_BLOCKS({})'.format(
                num_branches, len(num_blocks))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) &lt;&gt; NUM_CHANNELS({})'.format(
                num_branches, len(num_channels))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = 'NUM_BRANCHES({}) &lt;&gt; NUM_INCHANNELS({})'.format(
                num_branches, len(num_inchannels))
            logger.error(error_msg)
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                         stride=1):
        # build one branch: num_blocks[branch_index] blocks stacked in sequence
        downsample = None

        # if the stride is not 1 or the channel count changes, the residual path
        # needs a 1x1 conv + BN projection so that the shapes match
        if stride != 1 or \
           self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.num_inchannels[branch_index],
                          num_channels[branch_index] * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(num_channels[branch_index] * block.expansion,
                               momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(self.num_inchannels[branch_index],
                            num_channels[branch_index], stride, downsample))

        self.num_inchannels[branch_index] = \
            num_channels[branch_index] * block.expansion

        for i in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_inchannels[branch_index],
                                num_channels[branch_index]))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        branches = []

        # build the branches in a loop; each branch works at a different resolution
        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches  # 2
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            # i: the output branch being produced
            fuse_layer = []
            for j in range(num_branches):
                # j: the input branch being mapped onto branch i
                if j &gt; i:  # upsample, using nearest-neighbor interpolation
                    fuse_layer.append(nn.Sequential(
                        nn.Conv2d(num_inchannels[j],
                                  num_inchannels[i],
                                  1,
                                  1,
                                  0,
                                  bias=False),
                        nn.BatchNorm2d(num_inchannels[i],
                                       momentum=BN_MOMENTUM),
                        nn.Upsample(scale_factor=2**(j-i), mode='nearest')))
                elif j == i:
                    # same branch: no transformation
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    # downsample with strided 3x3 convs; two levels apart means two of them
                    for k in range(i-j):
                        if k == i - j - 1:
                            # last conv: switch to branch i's width, no ReLU
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          3, 2, 1, bias=False),
                                nn.BatchNorm2d(num_outchannels_conv3x3,
                                               momentum=BN_MOMENTUM)))
                        else:
                            # intermediate conv: keep branch j's width, followed by ReLU
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          3, 2, 1, bias=False),
                                nn.BatchNorm2d(num_outchannels_conv3x3,
                                               momentum=BN_MOMENTUM),
                                nn.ReLU(False)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        return self.num_inchannels

    def forward(self, x):
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        <span style="line-height:1.6
!important;font-style:italic;"># 将fuse以后的多个分支结果保存到list中</span><br style="line-height: 1.6 !important;"/>        <span style="line-height:1.6 !important;">return</span> x_fuse<br style="line-height: 1.6 !important;"/></pre>
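The fusion rules encoded by the fuse-layer code above can be listed without PyTorch: for output branch i and input branch j, a lower-resolution input (j > i) gets a 1x1 conv plus nearest upsampling by 2^(j-i), the same branch (j == i) is added as-is, and a higher-resolution input (j < i) goes through i - j stride-2 3x3 convs. The sketch below is a hypothetical, framework-free helper (`fuse_plan` is an illustrative name, not part of the HRNet code) that tabulates these operations and checks that every term summed into branch i ends up with branch i's channel count and spatial size.

```python
from typing import List, Tuple


def fuse_plan(channels: List[int], sizes: List[int]) -> List[List[Tuple[str, int, int]]]:
    """Describe how each input branch j is mapped onto each output branch i.

    channels[k], sizes[k]: channel count and (square) spatial size of branch k;
    sizes are assumed to halve from one branch to the next, as in HRNet.
    Returns plan[i][j] = (operation, out_channels, out_size).
    """
    n = len(channels)
    plan = []
    for i in range(n):
        row = []
        for j in range(n):
            if j > i:
                # Lower resolution: 1x1 conv to channels[i], then nearest upsampling
                factor = 2 ** (j - i)
                row.append((f"conv1x1 + upsample x{factor}", channels[i], sizes[j] * factor))
            elif j == i:
                # Same branch: added unchanged
                row.append(("identity", channels[j], sizes[j]))
            else:
                # Higher resolution: (i - j) stride-2 3x3 convs with padding 1;
                # only the last one switches to channels[i]
                size = sizes[j]
                for _ in range(i - j):
                    size = (size + 2 * 1 - 3) // 2 + 1
                row.append((f"{i - j} x strided conv3x3", channels[i], size))
        plan.append(row)
    return plan


channels, sizes = [32, 64, 128], [64, 32, 16]
for i, row in enumerate(fuse_plan(channels, sizes)):
    for j, (op, c, s) in enumerate(row):
        print(f"branch {j} -> branch {i}: {op}, giving {c} channels at {s}x{s}")
```

Because every term has the same shape, the plain element-wise sum in `forward` (FUSE_METHOD 'SUM') is well defined.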
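    The configuration parameters live in models.py; changing them adjusts the model's capacity (channel widths and block counts), the number of branches, and the feature-fusion method: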
    <pre class="prettyprint">from yacs.config import CfgNode as CN

# high_resolution_net related params for classification
POSE_HIGH_RESOLUTION_NET = CN()
POSE_HIGH_RESOLUTION_NET.PRETRAINED_LAYERS = ['*']
POSE_HIGH_RESOLUTION_NET.STEM_INPLANES = 64
POSE_HIGH_RESOLUTION_NET.FINAL_CONV_KERNEL = 1
POSE_HIGH_RESOLUTION_NET.WITH_HEAD = True

POSE_HIGH_RESOLUTION_NET.STAGE2 = CN()
POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_MODULES = 1
POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BRANCHES = 2
POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BLOCKS = [4, 4]
POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_CHANNELS = [32, 64]
POSE_HIGH_RESOLUTION_NET.STAGE2.BLOCK = 'BASIC'
POSE_HIGH_RESOLUTION_NET.STAGE2.FUSE_METHOD = 'SUM'

POSE_HIGH_RESOLUTION_NET.STAGE3 = CN()
POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_MODULES = 1
POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BRANCHES = 3
POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BLOCKS = [4, 4, 4]
POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_CHANNELS = [32, 64, 128]
POSE_HIGH_RESOLUTION_NET.STAGE3.BLOCK = 'BASIC'
POSE_HIGH_RESOLUTION_NET.STAGE3.FUSE_METHOD = 'SUM'

POSE_HIGH_RESOLUTION_NET.STAGE4 = CN()
POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_MODULES = 1
POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BRANCHES = 4
POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256]
POSE_HIGH_RESOLUTION_NET.STAGE4.BLOCK = 'BASIC'
POSE_HIGH_RESOLUTION_NET.STAGE4.FUSE_METHOD = 'SUM'
</pre>
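To see what these settings mean concretely, the sketch below mirrors the STAGE2 to STAGE4 entries with plain dicts (standing in for the yacs `CN` nodes) and derives each branch's channel count and resolution. It assumes, as described for HRNet, that the stem brings the input to 1/4 resolution and that each additional branch halves it again; input sizes divisible by 32 are assumed so the integer division is exact. `branch_shapes` is a hypothetical helper, not repo code.

```python
# Plain-dict stand-in for the yacs config above (only the fields used here).
STAGES = {
    "STAGE2": {"NUM_BRANCHES": 2, "NUM_BLOCKS": [4, 4], "NUM_CHANNELS": [32, 64]},
    "STAGE3": {"NUM_BRANCHES": 3, "NUM_BLOCKS": [4, 4, 4], "NUM_CHANNELS": [32, 64, 128]},
    "STAGE4": {"NUM_BRANCHES": 4, "NUM_BLOCKS": [4, 4, 4, 4], "NUM_CHANNELS": [32, 64, 128, 256]},
}


def branch_shapes(stage_cfg, input_size):
    """Return [(channels, height, width), ...] for every branch of one stage."""
    n = stage_cfg["NUM_BRANCHES"]
    # Every per-branch list must have exactly one entry per branch
    if not (len(stage_cfg["NUM_BLOCKS"]) == len(stage_cfg["NUM_CHANNELS"]) == n):
        raise ValueError("NUM_BLOCKS and NUM_CHANNELS need NUM_BRANCHES entries")
    h, w = input_size
    shapes = []
    for k in range(n):
        stride = 4 * 2 ** k  # stem gives 1/4; each further branch halves again
        shapes.append((stage_cfg["NUM_CHANNELS"][k], h // stride, w // stride))
    return shapes


for name, cfg in STAGES.items():
    print(name, branch_shapes(cfg, (256, 192)))
```

Widening the network therefore means editing NUM_CHANNELS (keeping one entry per branch), while NUM_BLOCKS and NUM_MODULES control depth.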
    Next, let's look at how the complete HRNet model is assembled. The full code is long, so only its forward function is shown here.
    <pre class="prettyprint">def forward(self, x):

    # Stem: two stride-2 3x3 convs quickly reduce the resolution to 1/4 of the input
    x = self.relu(self.bn1(self.conv1(x)))
    x = self.relu(self.bn2(self.conv2(x)))

    # layer1: a stack of residual units (bottleneck blocks in the paper's setting)
    x = self.layer1(x)

    # Then come several stages, each built around HighResolutionModule. Before each
    # stage, a transition layer adds a new lower-resolution branch from the last output.
    x_list = []
    for i in range(self.stage2_cfg['NUM_BRANCHES']):
        if self.transition1[i] is not None:
            x_list.append(self.transition1[i](x))
        else:
            x_list.append(x)
    y_list = self.stage2(x_list)

    x_list = []
    for i in range(self.stage3_cfg['NUM_BRANCHES']):
        if self.transition2[i] is not None:
            x_list.append(self.transition2[i](y_list[-1]))
        else:
            x_list.append(y_list[i])
    y_list = self.stage3(x_list)

    x_list = []
    for i in range(self.stage4_cfg['NUM_BRANCHES']):
        if self.transition3[i] is not None:
            x_list.append(self.transition3[i](y_list[-1]))
        else:
            x_list.append(y_list[i])
    y_list = self.stage4(x_list)

    # Classification head (shown in the figure earlier in the article);
    # other tasks swap in a different head
    y = self.incre_modules[0](y_list[0])
    for i in range(len(self.downsamp_modules)):
        y = self.incre_modules[i+1](y_list[i+1]) + \
            self.downsamp_modules[i](y)
    y = self.final_layer(y)

    if torch._C._get_tracing_state():
        # Under JIT tracing, global average pooling is written as flatten + mean,
        # presumably so the traced graph does not hard-code a kernel size read from
        # y.size(); both branches compute the same result
        y = y.flatten(start_dim=2).mean(dim=2)
    else:
        y = F.avg_pool2d(y, kernel_size=y.size()
                         [2:]).view(y.size(0), -1)
    y = self.classifier(y)

    return y
</pre>
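The branch bookkeeping in this forward pass can be traced with spatial sizes alone. The sketch below follows an assumed 224x224 ImageNet-style input: the stem leaves 56x56, each transition keeps the existing branches and adds one branch from the lowest-resolution output (`y_list[-1]`) at half its size, and the head starts from branch 0 and alternates "downsample by 2" with "add the next branch". That the incre modules keep resolution and each `downsamp_modules[i]` is a stride-2 3x3 conv are assumptions about code not shown here; `conv_out` and `trace_sizes` are hypothetical helpers, and channel counts are deliberately left out.

```python
def conv_out(size, kernel=3, stride=2, padding=1):
    """Output size of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_sizes(input_size=224, num_branches=(2, 3, 4)):
    """Trace the square spatial size of every branch through HRNet's forward pass."""
    # Stem: conv1 and conv2 are stride-2 3x3 convs -> 1/4 resolution; layer1 keeps it
    y_list = [conv_out(conv_out(input_size))]
    stages = []
    for n in num_branches:
        # Transition: existing branches keep their size; the new branch is a
        # stride-2 conv applied to the lowest-resolution output (y_list[-1])
        x_list = list(y_list)
        for _ in range(n - len(x_list)):
            x_list.append(conv_out(x_list[-1]))
        # HighResolutionModule keeps every branch's resolution unchanged
        y_list = x_list
        stages.append(list(y_list))

    # Head: start at branch 0, then repeatedly downsample by 2 and add the next branch
    y = y_list[0]
    for next_size in y_list[1:]:
        y = conv_out(y)
        assert y == next_size, "sizes must match before the element-wise addition"
    return stages, y


stages, head_size = trace_sizes(224)
print(stages)     # [[56, 28], [56, 28, 14], [56, 28, 14, 7]]
print(head_size)  # 7: size before final_layer and global average pooling
```

The head's sizes line up at every step, which is why `incre_modules[i+1](y_list[i+1])` and `downsamp_modules[i](y)` can simply be added.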
    <h4>Summary</h4>
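    The core of HRNet: keep a high-resolution representation throughout the entire network, while repeatedly letting feature maps of different resolutions exchange information.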
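    HRNet is widely used across computer vision. For example, it played a role in the Amur tiger keypoint detection challenge at ICCV 2019, and the classification experiments show that, at a comparable parameter count, it can replace ResNet as a classification backbone.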
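I previously read an article by Zheng Ankun, "CNN Architecture Design Tips: Balancing Speed, Accuracy, and Engineering Implementation" (CNN结构设计技巧-兼顾速度精度与工程实现), which makes this point:
SENet is a special case of HRNet; HRNet has not only channel attention but also spatial attention.
-- akkaze (Zheng Ankun)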
    <figure>
        <img src="https://img0.tuicool.com/jE7vuyB.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>Core implementation of SELayer</figcaption>
    </figure>
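    SELayer first applies global average pooling to obtain a vector with one value per channel. Two fully connected layers then compress and re-expand this information, and a sigmoid turns the result into a weight for each channel. Finally, the original feature map is multiplied channel-wise by these weights, recalibrating its features.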
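The SE computation just described fits in a few lines. The following is a framework-free NumPy sketch of one squeeze-and-excitation pass over a single (C, H, W) feature map, not the code shown in the figure; the random weights `w1` and `w2`, the reduction ratio `r = 4`, and the helper name `se_forward` are illustrative assumptions, and bias terms are omitted.

```python
import numpy as np


def se_forward(x, w1, w2):
    """Squeeze-and-excitation on one feature map x of shape (C, H, W).

    w1: (C // r, C) weights of the compressing FC layer,
    w2: (C, C // r) weights of the expanding FC layer.
    """
    z = x.mean(axis=(1, 2))               # squeeze: global average pooling -> (C,)
    s = np.maximum(w1 @ z, 0.0)           # FC + ReLU, compresses to C // r values
    a = 1.0 / (1.0 + np.exp(-(w2 @ s)))   # FC + sigmoid -> one weight in (0, 1) per channel
    return x * a[:, None, None], a        # rescale every channel of the original map


rng = np.random.default_rng(0)
C, r = 8, 4
x = rng.standard_normal((C, 5, 5))
w1 = rng.standard_normal((C // r, C))
w2 = rng.standard_normal((C, C // r))
y, a = se_forward(x, w1, w2)
print(y.shape, a.round(3))
```

Note that after the squeeze step no spatial information is left: each channel is scaled by a single number.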
    <figure>
        <img src="https://img2.tuicool.com/j6ziErV.png!web" class="alignCenter" referrerpolicy="no-referrer"/>
        <figcaption>A building block of HRNet</figcaption>
    </figure>
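    Notice that the part chained together by red arrows in the figure above looks a lot like an SELayer. Judging from this structure alone, SENet can be viewed as a special case of HRNet for the following reasons: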
    <ul>
        <li>
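            <p class="no-text-indent">SENet does not halve the resolution the way HRNet does; it reduces it straight to 1x1, which is rather extreme. After reducing to a 1x1 vector, SENet uses two fully connected layers to learn the distribution of channel features, whereas HRNet uses several convolutions (residual blocks) to learn features.</p>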
        </li>
        <li>
            <p class="no-text-indent">SENet places no convolutions on its main path (the high-resolution branch) to learn features, whereas HRNet places several convolutions (residual blocks) there.</p>
        </li>
        <li>
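            <p class="no-text-indent">The fusion step differs considerably: SENet multiplies corresponding channels, while HRNet adds. SENet counts as channel attention because global average pooling discards all spatial information and leaves only per-channel features; HRNet keeps both spatial and channel features, which is why it can be said to have not only channel attention but also spatial attention.</p>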
        </li>
    </ul>
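The contrast in the last point can be checked numerically. In this NumPy sketch, the inputs are random, nearest-neighbour upsampling via `np.repeat` stands in for HRNet's upsampling layer, and the 1x1 conv, the SE fully connected layers, and all learned weights are omitted as simplifying assumptions; `spread` is a hypothetical helper. SE's gate is one scalar per channel, so the output/input ratio is identical at every position of a channel, whereas HRNet's fused term varies from position to position.

```python
import numpy as np


def spread(a):
    """Max minus min over the spatial axes, per channel."""
    return a.max(axis=(1, 2)) - a.min(axis=(1, 2))


rng = np.random.default_rng(0)
C, H, W = 4, 8, 8
high = rng.standard_normal((C, H, W))           # high-resolution branch
low = rng.standard_normal((C, H // 2, W // 2))  # lower-resolution branch (half size)

# SE-style: the side branch is pooled all the way to 1x1 (one value per channel)
# and multiplies the main branch. FC layers omitted for brevity.
gate = 1.0 / (1.0 + np.exp(-high.mean(axis=(1, 2))))
se_out = high * gate[:, None, None]

# HRNet-style: the side branch keeps a coarser spatial layout; it is upsampled
# by nearest-neighbour repetition and added element-wise.
up = np.repeat(np.repeat(low, 2, axis=1), 2, axis=2)
hr_out = high + up

print(spread(se_out / high))  # ~0: one scale factor per channel at every position
print(spread(hr_out - high))  # clearly positive: the added term differs by position
```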
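    The HRNet team has as many as ten members and has built libraries for classification, segmentation, detection, keypoint detection, and more. That is a large amount of work, and their many solid experiments demonstrate the effectiveness of the idea. Can HRNet therefore be regarded as the next, better backbone after SENet? That is something to verify by applying these ideas in your own work.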
    <h4>References</h4>
    https://arxiv.org/pdf/1908.07919
    https://www.bilibili.com/video/BV1WJ41197dh?t=508
    https://github.com/HRNet
</div>
    <span>