# Source code for mxnet.gluon.model_zoo.vision.vgg

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable= arguments-differ
"""VGG, implemented in Gluon."""
from __future__ import division
__all__ = ['VGG',
           'vgg11', 'vgg13', 'vgg16', 'vgg19',
           'vgg11_bn', 'vgg13_bn', 'vgg16_bn', 'vgg19_bn',
           'get_vgg']

from ....context import cpu
from ....initializer import Xavier
from ...block import HybridBlock
from ... import nn


class VGG(HybridBlock):
    r"""VGG model from the `"Very Deep Convolutional Networks for Large-Scale Image Recognition"
    <https://arxiv.org/abs/1409.1556>`_ paper.

    The network is a convolutional feature extractor followed by a
    fully connected classifier:

    * feature block ``i`` holds ``layers[i]`` 3x3 convolutions (padding 1)
      with ``filters[i]`` output channels, each optionally followed by
      batch normalization and then ReLU, and ends with a stride-2 max pooling;
    * the classifier is Dense(4096, relu) -> Dropout(0.5) ->
      Dense(4096, relu) -> Dropout(0.5) -> Dense(classes).

    Parameters
    ----------
    layers : list of int
        Numbers of layers in each feature block.
    filters : list of int
        Numbers of filters in each feature block. List length should match the layers.
    classes : int, default 1000
        Number of classification classes.
    batch_norm : bool, default False
        Use batch normalization.

    Raises
    ------
    ValueError
        If ``layers`` and ``filters`` do not have the same length.
    """
    def __init__(self, layers, filters, classes=1000, batch_norm=False, **kwargs):
        super(VGG, self).__init__(**kwargs)
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, after which a mismatch would surface as an IndexError
        # or silently ignore the extra filter counts.
        if len(layers) != len(filters):
            raise ValueError(
                "layers and filters must have the same length, got %d and %d"
                % (len(layers), len(filters)))
        with self.name_scope():
            # NOTE(review): keep the creation order of sub-layers unchanged --
            # Gluon derives parameter names from it, and the pretrained weight
            # files are presumably matched against those names.
            self.features = self._make_features(layers, filters, batch_norm)
            self.classifier = nn.HybridSequential(prefix='')
            self.classifier.add(nn.Dense(4096, activation='relu',
                                         weight_initializer='normal',
                                         bias_initializer='zeros'))
            self.classifier.add(nn.Dropout(rate=0.5))
            self.classifier.add(nn.Dense(4096, activation='relu',
                                         weight_initializer='normal',
                                         bias_initializer='zeros'))
            self.classifier.add(nn.Dropout(rate=0.5))
            self.classifier.add(nn.Dense(classes,
                                         weight_initializer='normal',
                                         bias_initializer='zeros'))

    def _make_features(self, layers, filters, batch_norm):
        """Build the convolutional feature extractor.

        For every pair ``(num, channels)`` taken from ``zip(layers, filters)``,
        appends ``num`` x [Conv2D(channels, 3x3, padding 1) (+ BatchNorm) + ReLU]
        followed by one MaxPool2D with stride 2.

        Returns
        -------
        HybridSequential
            The stacked feature layers.
        """
        featurizer = nn.HybridSequential(prefix='')
        for num, channels in zip(layers, filters):
            for _ in range(num):
                featurizer.add(nn.Conv2D(channels, kernel_size=3, padding=1,
                                         weight_initializer=Xavier(rnd_type='gaussian',
                                                                   factor_type='out',
                                                                   magnitude=2),
                                         bias_initializer='zeros'))
                if batch_norm:
                    featurizer.add(nn.BatchNorm())
                featurizer.add(nn.Activation('relu'))
            featurizer.add(nn.MaxPool2D(strides=2))
        return featurizer

    def hybrid_forward(self, F, x):
        """Apply the feature extractor, then the classifier, and return the logits.

        ``F`` is not used directly by this method.
        """
        x = self.features(x)
        x = self.classifier(x)
        return x
# Specification
# Maps a VGG depth to (convolutions per feature block, channels per feature block).
# The depth equals the total number of convolutions plus the 3 fully connected
# layers of the classifier (e.g. 16 = 2 + 2 + 3 + 3 + 3 + 3).
vgg_spec = {11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
            13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
            16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
            19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512])}


# Constructors
def get_vgg(num_layers, pretrained=False, ctx=cpu(), **kwargs):
    r"""VGG model from the `"Very Deep Convolutional Networks for Large-Scale Image Recognition"
    <https://arxiv.org/abs/1409.1556>`_ paper.

    Parameters
    ----------
    num_layers : int
        Number of layers for the variant of VGG. Options are 11, 13, 16, 19.
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    **kwargs
        Forwarded to :class:`VGG` (e.g. ``classes``, ``batch_norm``). When
        ``pretrained`` is True, a truthy ``batch_norm`` selects the
        ``vgg<N>_bn`` weight file instead of ``vgg<N>``.

    Returns
    -------
    VGG
        The constructed (and optionally weight-loaded) network.

    Raises
    ------
    KeyError
        If ``num_layers`` is not one of the supported depths.
    """
    try:
        layers, filters = vgg_spec[num_layers]
    except KeyError:
        # Same exception type as a plain dict lookup, but with a useful message.
        raise KeyError("Invalid number of layers: %s. Options are %s"
                       % (num_layers, sorted(vgg_spec)))
    net = VGG(layers, filters, **kwargs)
    if pretrained:
        # Imported lazily so the model store is only touched when weights are requested.
        from ..model_store import get_model_file
        batch_norm_suffix = '_bn' if kwargs.get('batch_norm') else ''
        net.load_params(get_model_file('vgg%d%s' % (num_layers, batch_norm_suffix)), ctx=ctx)
    return net


def vgg11(**kwargs):
    r"""VGG-11 model from the `"Very Deep Convolutional Networks for Large-Scale Image Recognition"
    <https://arxiv.org/abs/1409.1556>`_ paper.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    **kwargs
        Other keyword arguments are forwarded to :func:`get_vgg`.
    """
    return get_vgg(11, **kwargs)


def vgg13(**kwargs):
    r"""VGG-13 model from the `"Very Deep Convolutional Networks for Large-Scale Image Recognition"
    <https://arxiv.org/abs/1409.1556>`_ paper.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    **kwargs
        Other keyword arguments are forwarded to :func:`get_vgg`.
    """
    return get_vgg(13, **kwargs)


def vgg16(**kwargs):
    r"""VGG-16 model from the `"Very Deep Convolutional Networks for Large-Scale Image Recognition"
    <https://arxiv.org/abs/1409.1556>`_ paper.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    **kwargs
        Other keyword arguments are forwarded to :func:`get_vgg`.
    """
    return get_vgg(16, **kwargs)


def vgg19(**kwargs):
    r"""VGG-19 model from the `"Very Deep Convolutional Networks for Large-Scale Image Recognition"
    <https://arxiv.org/abs/1409.1556>`_ paper.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    **kwargs
        Other keyword arguments are forwarded to :func:`get_vgg`.
    """
    return get_vgg(19, **kwargs)


def vgg11_bn(**kwargs):
    r"""VGG-11 model with batch normalization from the
    `"Very Deep Convolutional Networks for Large-Scale Image Recognition"
    <https://arxiv.org/abs/1409.1556>`_ paper.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    **kwargs
        Other keyword arguments are forwarded to :func:`get_vgg`;
        ``batch_norm`` is always overridden to True.
    """
    kwargs['batch_norm'] = True
    return get_vgg(11, **kwargs)


def vgg13_bn(**kwargs):
    r"""VGG-13 model with batch normalization from the
    `"Very Deep Convolutional Networks for Large-Scale Image Recognition"
    <https://arxiv.org/abs/1409.1556>`_ paper.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    **kwargs
        Other keyword arguments are forwarded to :func:`get_vgg`;
        ``batch_norm`` is always overridden to True.
    """
    kwargs['batch_norm'] = True
    return get_vgg(13, **kwargs)


def vgg16_bn(**kwargs):
    r"""VGG-16 model with batch normalization from the
    `"Very Deep Convolutional Networks for Large-Scale Image Recognition"
    <https://arxiv.org/abs/1409.1556>`_ paper.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    **kwargs
        Other keyword arguments are forwarded to :func:`get_vgg`;
        ``batch_norm`` is always overridden to True.
    """
    kwargs['batch_norm'] = True
    return get_vgg(16, **kwargs)


def vgg19_bn(**kwargs):
    r"""VGG-19 model with batch normalization from the
    `"Very Deep Convolutional Networks for Large-Scale Image Recognition"
    <https://arxiv.org/abs/1409.1556>`_ paper.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    **kwargs
        Other keyword arguments are forwarded to :func:`get_vgg`;
        ``batch_norm`` is always overridden to True.
    """
    kwargs['batch_norm'] = True
    return get_vgg(19, **kwargs)