# Source code for mxnet.gluon.model_zoo.vision.resnet

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable= arguments-differ
"""ResNets, implemented in Gluon."""

__all__ = ['ResNetV1', 'ResNetV2',
           'BasicBlockV1', 'BasicBlockV2',
           'BottleneckV1', 'BottleneckV2',
           'resnet18_v1', 'resnet34_v1', 'resnet50_v1', 'resnet101_v1', 'resnet152_v1',
           'resnet18_v2', 'resnet34_v2', 'resnet50_v2', 'resnet101_v2', 'resnet152_v2',
           'get_resnet']

import os

from ....device import cpu
from ...block import HybridBlock
from ... import nn
from .... import base
from .... util import use_np, wrap_ctx_to_device_func
from .... import npx

# Helpers
def _conv3x3(channels, stride, in_channels):
    """Build a 3x3 ``nn.Conv2D`` with padding 1 and no bias term.

    Parameters
    ----------
    channels : int
        Number of output channels.
    stride : int
        Stride passed to the convolution.
    in_channels : int
        Number of input channels, passed through to ``nn.Conv2D``.
    """
    conv_kwargs = dict(kernel_size=3, strides=stride, padding=1,
                       use_bias=False, in_channels=in_channels)
    return nn.Conv2D(channels, **conv_kwargs)


# Blocks
@use_np
class BasicBlockV1(HybridBlock):
    r"""BasicBlock V1 from `"Deep Residual Learning for Image Recognition"
    <http://arxiv.org/abs/1512.03385>`_ paper.
    This is used for ResNet V1 for 18, 34 layers.

    The main path is conv3x3 -> BatchNorm -> ReLU -> conv3x3 -> BatchNorm.
    The shortcut is added to its output and the sum goes through a final ReLU.

    Parameters
    ----------
    channels : int
        Number of output channels.
    stride : int
        Stride size.
    downsample : bool, default False
        Whether to downsample the input.
    in_channels : int, default 0
        Number of input channels. Default is 0, to infer from the graph.
    """
    def __init__(self, channels, stride, downsample=False, in_channels=0, **kwargs):
        super(BasicBlockV1, self).__init__(**kwargs)
        # Main path; only the first convolution carries the stride.
        self.body = nn.HybridSequential()
        for layer in (_conv3x3(channels, stride, in_channels),
                      nn.BatchNorm(),
                      nn.Activation('relu'),
                      _conv3x3(channels, 1, channels),
                      nn.BatchNorm()):
            self.body.add(layer)

        if not downsample:
            self.downsample = None
        else:
            # Projection shortcut: strided 1x1 convolution followed by BatchNorm.
            projection = nn.Conv2D(channels, kernel_size=1, strides=stride,
                                   use_bias=False, in_channels=in_channels)
            self.downsample = nn.HybridSequential()
            self.downsample.add(projection)
            self.downsample.add(nn.BatchNorm())

    def forward(self, x):
        """Return ``relu(shortcut + body(x))``.

        The shortcut is ``x`` itself, or ``downsample(x)`` when a
        projection was configured.
        """
        out = self.body(x)
        shortcut = self.downsample(x) if self.downsample else x
        return npx.activation(shortcut + out, act_type='relu')
@use_np
class BottleneckV1(HybridBlock):
    r"""Bottleneck V1 from `"Deep Residual Learning for Image Recognition"
    <http://arxiv.org/abs/1512.03385>`_ paper.
    This is used for ResNet V1 for 50, 101, 152 layers.

    The main path is conv1x1 -> BatchNorm -> ReLU -> conv3x3 -> BatchNorm ->
    ReLU -> conv1x1 -> BatchNorm, where the first two convolutions produce
    ``channels // 4`` feature maps. The stride is applied by the first 1x1
    convolution.

    Parameters
    ----------
    channels : int
        Number of output channels.
    stride : int
        Stride size.
    downsample : bool, default False
        Whether to downsample the input.
    in_channels : int, default 0
        Number of input channels. Default is 0, to infer from the graph.
    """
    def __init__(self, channels, stride, downsample=False, in_channels=0, **kwargs):
        super(BottleneckV1, self).__init__(**kwargs)
        mid_channels = channels//4  # width of the reduced (bottleneck) path
        self.body = nn.HybridSequential()
        # The two 1x1 convolutions below do not override ``use_bias``.
        for layer in (nn.Conv2D(mid_channels, kernel_size=1, strides=stride),
                      nn.BatchNorm(),
                      nn.Activation('relu'),
                      _conv3x3(mid_channels, 1, mid_channels),
                      nn.BatchNorm(),
                      nn.Activation('relu'),
                      nn.Conv2D(channels, kernel_size=1, strides=1),
                      nn.BatchNorm()):
            self.body.add(layer)

        if not downsample:
            self.downsample = None
        else:
            # Projection shortcut: strided 1x1 convolution followed by BatchNorm.
            projection = nn.Conv2D(channels, kernel_size=1, strides=stride,
                                   use_bias=False, in_channels=in_channels)
            self.downsample = nn.HybridSequential()
            self.downsample.add(projection)
            self.downsample.add(nn.BatchNorm())

    def forward(self, x):
        """Return ``relu(body(x) + shortcut)``.

        The shortcut is ``x`` itself, or ``downsample(x)`` when a
        projection was configured.
        """
        out = self.body(x)
        shortcut = self.downsample(x) if self.downsample else x
        return npx.activation(out + shortcut, act_type='relu')
@use_np
class BasicBlockV2(HybridBlock):
    r"""BasicBlock V2 from
    `"Identity Mappings in Deep Residual Networks"
    <https://arxiv.org/abs/1603.05027>`_ paper.
    This is used for ResNet V2 for 18, 34 layers.

    Each 3x3 convolution is preceded by BatchNorm and ReLU. No activation is
    applied after the shortcut is added.

    Parameters
    ----------
    channels : int
        Number of output channels.
    stride : int
        Stride size.
    downsample : bool, default False
        Whether to downsample the input.
    in_channels : int, default 0
        Number of input channels. Default is 0, to infer from the graph.
    """
    def __init__(self, channels, stride, downsample=False, in_channels=0, **kwargs):
        super(BasicBlockV2, self).__init__(**kwargs)
        self.bn1 = nn.BatchNorm()
        self.conv1 = _conv3x3(channels, stride, in_channels)
        self.bn2 = nn.BatchNorm()
        self.conv2 = _conv3x3(channels, 1, channels)
        if not downsample:
            self.downsample = None
        else:
            # Projection shortcut: a single strided 1x1 convolution.
            self.downsample = nn.Conv2D(channels, kernel_size=1, strides=stride,
                                        use_bias=False, in_channels=in_channels)

    def forward(self, x):
        """Run the block and return the main path plus the shortcut.

        When a projection is configured it is applied to the activated
        ``bn1`` output rather than to the raw input.
        """
        preact = npx.activation(self.bn1(x), act_type='relu')
        shortcut = self.downsample(preact) if self.downsample else x
        out = self.conv1(preact)
        out = npx.activation(self.bn2(out), act_type='relu')
        out = self.conv2(out)
        return out + shortcut
@use_np
class BottleneckV2(HybridBlock):
    r"""Bottleneck V2 from
    `"Identity Mappings in Deep Residual Networks"
    <https://arxiv.org/abs/1603.05027>`_ paper.
    This is used for ResNet V2 for 50, 101, 152 layers.

    Each of the three convolutions (1x1, 3x3, 1x1) is preceded by BatchNorm
    and ReLU. The first two produce ``channels // 4`` feature maps and the
    stride is applied by the 3x3 convolution. No activation is applied after
    the shortcut is added.

    Parameters
    ----------
    channels : int
        Number of output channels.
    stride : int
        Stride size.
    downsample : bool, default False
        Whether to downsample the input.
    in_channels : int, default 0
        Number of input channels. Default is 0, to infer from the graph.
    """
    def __init__(self, channels, stride, downsample=False, in_channels=0, **kwargs):
        super(BottleneckV2, self).__init__(**kwargs)
        mid_channels = channels//4  # width of the reduced (bottleneck) path
        self.bn1 = nn.BatchNorm()
        self.conv1 = nn.Conv2D(mid_channels, kernel_size=1, strides=1, use_bias=False)
        self.bn2 = nn.BatchNorm()
        self.conv2 = _conv3x3(mid_channels, stride, mid_channels)
        self.bn3 = nn.BatchNorm()
        self.conv3 = nn.Conv2D(channels, kernel_size=1, strides=1, use_bias=False)
        if not downsample:
            self.downsample = None
        else:
            # Projection shortcut: a single strided 1x1 convolution.
            self.downsample = nn.Conv2D(channels, kernel_size=1, strides=stride,
                                        use_bias=False, in_channels=in_channels)

    def forward(self, x):
        """Run the block and return the main path plus the shortcut.

        When a projection is configured it is applied to the activated
        ``bn1`` output rather than to the raw input.
        """
        preact = npx.activation(self.bn1(x), act_type='relu')
        shortcut = self.downsample(preact) if self.downsample else x
        out = self.conv1(preact)
        out = npx.activation(self.bn2(out), act_type='relu')
        out = self.conv2(out)
        out = npx.activation(self.bn3(out), act_type='relu')
        out = self.conv3(out)
        return out + shortcut
# Nets
@use_np
class ResNetV1(HybridBlock):
    r"""ResNet V1 model from
    `"Deep Residual Learning for Image Recognition"
    <http://arxiv.org/abs/1512.03385>`_ paper.

    Parameters
    ----------
    block : gluon.HybridBlock
        Class for the residual block. Options are BasicBlockV1, BottleneckV1.
    layers : list of int
        Numbers of layers in each block
    channels : list of int
        Numbers of channels in each block. Length should be one larger than layers list.
    classes : int, default 1000
        Number of classification classes.
    thumbnail : bool, default False
        Enable thumbnail. When set, the stem is a single stride-1 3x3
        convolution instead of the 7x7 convolution, BatchNorm, ReLU and
        max-pooling sequence.
    """
    def __init__(self, block, layers, channels, classes=1000, thumbnail=False, **kwargs):
        super(ResNetV1, self).__init__(**kwargs)
        # channels[0] is the stem width; channels[1:] are the per-stage widths.
        assert len(layers) == len(channels) - 1
        self.features = nn.HybridSequential()
        if thumbnail:
            self.features.add(_conv3x3(channels[0], 1, 0))
        else:
            for stem_layer in (nn.Conv2D(channels[0], kernel_size=7, strides=2,
                                         padding=3, use_bias=False),
                               nn.BatchNorm(),
                               nn.Activation('relu'),
                               nn.MaxPool2D(pool_size=3, strides=2, padding=1)):
                self.features.add(stem_layer)

        for idx, num_layer in enumerate(layers):
            # The first stage keeps the resolution; later stages use stride 2.
            stage_stride = 2 if idx != 0 else 1
            stage = self._make_layer(block, num_layer, channels[idx+1],
                                     stage_stride, in_channels=channels[idx])
            self.features.add(stage)
        self.features.add(nn.GlobalAvgPool2D())

        self.output = nn.Dense(classes, in_units=channels[-1])

    def _make_layer(self, block, layers, channels, stride, in_channels=0):
        """Assemble one stage made of ``layers`` residual blocks.

        The first block receives ``stride`` and gets a projection shortcut
        when ``channels`` differs from ``in_channels``. The remaining blocks
        use stride 1 and no projection.
        """
        stage = nn.HybridSequential()
        needs_projection = channels != in_channels
        stage.add(block(channels, stride, needs_projection, in_channels=in_channels))
        for _ in range(1, layers):
            stage.add(block(channels, 1, False, in_channels=channels))
        return stage

    def forward(self, x):
        """Apply the feature extractor followed by the output layer."""
        return self.output(self.features(x))
@use_np
class ResNetV2(HybridBlock):
    r"""ResNet V2 model from
    `"Identity Mappings in Deep Residual Networks"
    <https://arxiv.org/abs/1603.05027>`_ paper.

    Parameters
    ----------
    block : gluon.HybridBlock
        Class for the residual block. Options are BasicBlockV2, BottleneckV2.
    layers : list of int
        Numbers of layers in each block
    channels : list of int
        Numbers of channels in each block. Length should be one larger than layers list.
    classes : int, default 1000
        Number of classification classes.
    thumbnail : bool, default False
        Enable thumbnail. When set, the stem is a single stride-1 3x3
        convolution instead of the 7x7 convolution, BatchNorm, ReLU and
        max-pooling sequence.
    """
    def __init__(self, block, layers, channels, classes=1000, thumbnail=False, **kwargs):
        super(ResNetV2, self).__init__(**kwargs)
        # channels[0] is the stem width; channels[1:] are the per-stage widths.
        assert len(layers) == len(channels) - 1
        self.features = nn.HybridSequential()
        # The raw input goes through a BatchNorm without scale or shift first.
        self.features.add(nn.BatchNorm(scale=False, center=False))
        if thumbnail:
            self.features.add(_conv3x3(channels[0], 1, 0))
        else:
            for stem_layer in (nn.Conv2D(channels[0], kernel_size=7, strides=2,
                                         padding=3, use_bias=False),
                               nn.BatchNorm(),
                               nn.Activation('relu'),
                               nn.MaxPool2D(pool_size=3, strides=2, padding=1)):
                self.features.add(stem_layer)

        in_channels = channels[0]
        for idx, num_layer in enumerate(layers):
            # The first stage keeps the resolution; later stages use stride 2.
            stage_stride = 2 if idx != 0 else 1
            out_channels = channels[idx+1]
            stage = self._make_layer(block, num_layer, out_channels,
                                     stage_stride, in_channels=in_channels)
            self.features.add(stage)
            in_channels = out_channels
        # The last block ends on a bare sum, so normalise and activate here.
        for tail_layer in (nn.BatchNorm(),
                           nn.Activation('relu'),
                           nn.GlobalAvgPool2D(),
                           nn.Flatten()):
            self.features.add(tail_layer)

        self.output = nn.Dense(classes, in_units=in_channels)

    def _make_layer(self, block, layers, channels, stride, in_channels=0):
        """Assemble one stage made of ``layers`` residual blocks.

        The first block receives ``stride`` and gets a projection shortcut
        when ``channels`` differs from ``in_channels``. The remaining blocks
        use stride 1 and no projection.
        """
        stage = nn.HybridSequential()
        needs_projection = channels != in_channels
        stage.add(block(channels, stride, needs_projection, in_channels=in_channels))
        for _ in range(1, layers):
            stage.add(block(channels, 1, False, in_channels=channels))
        return stage

    def forward(self, x):
        """Apply the feature extractor followed by the output layer."""
        return self.output(self.features(x))
# Specification
# Maps the number of layers to a tuple of
# (block type key, residual blocks per stage, channel widths).
# The channel list has one more entry than the stage list: its first entry
# is the stem width.
resnet_spec = {
    18: ('basic_block', [2, 2, 2, 2], [64, 64, 128, 256, 512]),
    34: ('basic_block', [3, 4, 6, 3], [64, 64, 128, 256, 512]),
    50: ('bottle_neck', [3, 4, 6, 3], [64, 256, 512, 1024, 2048]),
    101: ('bottle_neck', [3, 4, 23, 3], [64, 256, 512, 1024, 2048]),
    152: ('bottle_neck', [3, 8, 36, 3], [64, 256, 512, 1024, 2048]),
}

# Both lists are indexed by ``version - 1``.
resnet_net_versions = [ResNetV1, ResNetV2]
resnet_block_versions = [
    {'basic_block': BasicBlockV1, 'bottle_neck': BottleneckV1},
    {'basic_block': BasicBlockV2, 'bottle_neck': BottleneckV2},
]


# Constructor
@wrap_ctx_to_device_func
def get_resnet(version, num_layers, pretrained=False, device=cpu(),
               root=os.path.join(base.data_dir(), 'models'), **kwargs):
    r"""ResNet V1 model from `"Deep Residual Learning for Image Recognition"
    <http://arxiv.org/abs/1512.03385>`_ paper.
    ResNet V2 model from `"Identity Mappings in Deep Residual Networks"
    <https://arxiv.org/abs/1603.05027>`_ paper.

    Parameters
    ----------
    version : int
        Version of ResNet. Options are 1, 2.
    num_layers : int
        Numbers of layers. Options are 18, 34, 50, 101, 152.
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    device : Device, default CPU
        The device in which to load the pretrained weights.
    root : str, default $MXNET_HOME/models
        Location for keeping the model parameters.
    **kwargs
        Forwarded unchanged to the constructor of the selected network
        class (for example ``classes`` or ``thumbnail``).

    Returns
    -------
    HybridBlock
        The constructed network, an instance of ``ResNetV1`` or ``ResNetV2``.

    Raises
    ------
    AssertionError
        If ``num_layers`` is not a key of ``resnet_spec`` or ``version`` is
        outside the range 1..2.
    """
    # Validate with explicit raises instead of ``assert`` statements, which
    # are removed under ``python -O``. Without the check, ``version=0`` would
    # silently pick the last entry of the version tables through negative
    # indexing. AssertionError is kept so existing callers that catch it
    # still work.
    if num_layers not in resnet_spec:
        raise AssertionError(
            f"Invalid number of layers: {num_layers}. Options are {str(resnet_spec.keys())}")
    if not 1 <= version <= 2:
        raise AssertionError(
            f"Invalid resnet version: {version}. Options are 1 and 2.")

    block_type, layers, channels = resnet_spec[num_layers]
    resnet_class = resnet_net_versions[version-1]
    block_class = resnet_block_versions[version-1][block_type]
    net = resnet_class(block_class, layers, channels, **kwargs)
    if pretrained:
        # Imported lazily, only when pretrained weights are requested.
        from ..model_store import get_model_file
        net.load_parameters(get_model_file(f'resnet{num_layers}_v{version}',
                                           root=root), device=device)
    return net
@wrap_ctx_to_device_func
def resnet18_v1(**kwargs):
    r"""ResNet-18 V1 model from `"Deep Residual Learning for Image Recognition"
    <http://arxiv.org/abs/1512.03385>`_ paper.

    All keyword arguments are forwarded to ``get_resnet``.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    device : Device, default CPU
        The device in which to load the pretrained weights.
    root : str, default '$MXNET_HOME/models'
        Location for keeping the model parameters.
    """
    version, depth = 1, 18
    return get_resnet(version, depth, **kwargs)
@wrap_ctx_to_device_func
def resnet34_v1(**kwargs):
    r"""ResNet-34 V1 model from `"Deep Residual Learning for Image Recognition"
    <http://arxiv.org/abs/1512.03385>`_ paper.

    All keyword arguments are forwarded to ``get_resnet``.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    device : Device, default CPU
        The device in which to load the pretrained weights.
    root : str, default '$MXNET_HOME/models'
        Location for keeping the model parameters.
    """
    version, depth = 1, 34
    return get_resnet(version, depth, **kwargs)
@wrap_ctx_to_device_func
def resnet50_v1(**kwargs):
    r"""ResNet-50 V1 model from `"Deep Residual Learning for Image Recognition"
    <http://arxiv.org/abs/1512.03385>`_ paper.

    All keyword arguments are forwarded to ``get_resnet``.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    device : Device, default CPU
        The device in which to load the pretrained weights.
    root : str, default '$MXNET_HOME/models'
        Location for keeping the model parameters.
    """
    version, depth = 1, 50
    return get_resnet(version, depth, **kwargs)
@wrap_ctx_to_device_func
def resnet101_v1(**kwargs):
    r"""ResNet-101 V1 model from `"Deep Residual Learning for Image Recognition"
    <http://arxiv.org/abs/1512.03385>`_ paper.

    All keyword arguments are forwarded to ``get_resnet``.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    device : Device, default CPU
        The device in which to load the pretrained weights.
    root : str, default '$MXNET_HOME/models'
        Location for keeping the model parameters.
    """
    version, depth = 1, 101
    return get_resnet(version, depth, **kwargs)
@wrap_ctx_to_device_func
def resnet152_v1(**kwargs):
    r"""ResNet-152 V1 model from `"Deep Residual Learning for Image Recognition"
    <http://arxiv.org/abs/1512.03385>`_ paper.

    All keyword arguments are forwarded to ``get_resnet``.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    device : Device, default CPU
        The device in which to load the pretrained weights.
    root : str, default '$MXNET_HOME/models'
        Location for keeping the model parameters.
    """
    version, depth = 1, 152
    return get_resnet(version, depth, **kwargs)
@wrap_ctx_to_device_func
def resnet18_v2(**kwargs):
    r"""ResNet-18 V2 model from `"Identity Mappings in Deep Residual Networks"
    <https://arxiv.org/abs/1603.05027>`_ paper.

    All keyword arguments are forwarded to ``get_resnet``.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    device : Device, default CPU
        The device in which to load the pretrained weights.
    root : str, default '$MXNET_HOME/models'
        Location for keeping the model parameters.
    """
    version, depth = 2, 18
    return get_resnet(version, depth, **kwargs)
@wrap_ctx_to_device_func
def resnet34_v2(**kwargs):
    r"""ResNet-34 V2 model from `"Identity Mappings in Deep Residual Networks"
    <https://arxiv.org/abs/1603.05027>`_ paper.

    All keyword arguments are forwarded to ``get_resnet``.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    device : Device, default CPU
        The device in which to load the pretrained weights.
    root : str, default '$MXNET_HOME/models'
        Location for keeping the model parameters.
    """
    version, depth = 2, 34
    return get_resnet(version, depth, **kwargs)
@wrap_ctx_to_device_func
def resnet50_v2(**kwargs):
    r"""ResNet-50 V2 model from `"Identity Mappings in Deep Residual Networks"
    <https://arxiv.org/abs/1603.05027>`_ paper.

    All keyword arguments are forwarded to ``get_resnet``.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    device : Device, default CPU
        The device in which to load the pretrained weights.
    root : str, default '$MXNET_HOME/models'
        Location for keeping the model parameters.
    """
    version, depth = 2, 50
    return get_resnet(version, depth, **kwargs)
@wrap_ctx_to_device_func
def resnet101_v2(**kwargs):
    r"""ResNet-101 V2 model from `"Identity Mappings in Deep Residual Networks"
    <https://arxiv.org/abs/1603.05027>`_ paper.

    All keyword arguments are forwarded to ``get_resnet``.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    device : Device, default CPU
        The device in which to load the pretrained weights.
    root : str, default '$MXNET_HOME/models'
        Location for keeping the model parameters.
    """
    version, depth = 2, 101
    return get_resnet(version, depth, **kwargs)
@wrap_ctx_to_device_func
def resnet152_v2(**kwargs):
    r"""ResNet-152 V2 model from `"Identity Mappings in Deep Residual Networks"
    <https://arxiv.org/abs/1603.05027>`_ paper.

    All keyword arguments are forwarded to ``get_resnet``.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    device : Device, default CPU
        The device in which to load the pretrained weights.
    root : str, default '$MXNET_HOME/models'
        Location for keeping the model parameters.
    """
    version, depth = 2, 152
    return get_resnet(version, depth, **kwargs)