Layers and Blocks

As network complexity increases, we move from designing single neurons to designing entire layers of neurons.

Neural network designs like ResNet-152 have a fair degree of regularity. They consist of blocks of repeated (or at least similarly designed) layers; these blocks then form the basis of more complex network designs.

In this section, we’ll talk about how to write code that makes such blocks on demand, just like a Lego factory generates blocks which can be combined to produce terrific artifacts.

We start with a very simple block, namely the block for a multilayer perceptron. A common strategy would be to construct a two-layer network as follows:

[1]:
import mxnet as mx
from mxnet import np, npx
from mxnet.gluon import nn, Block, Parameter, Constant
npx.set_np()  # enable NumPy-compatible array semantics for Gluon


x = np.random.uniform(size=(2, 20))

net = nn.Sequential()
net.add(nn.Dense(256, activation='relu'))
net.add(nn.Dense(10))
net.initialize()
net(x)
[1]:
array([[-0.0133777 , -0.01564407, -0.06515459, -0.03360514, -0.02242954,
        -0.00461456, -0.03745391,  0.06270804,  0.01931533,  0.00364962],
       [-0.00507566, -0.00843352, -0.01791815, -0.07735927, -0.05121508,
        -0.01204855,  0.00963113,  0.07037181, -0.0080541 ,  0.00850688]])

This generates a network with a hidden layer of \(256\) units and a ReLU activation, followed by an output layer of \(10\) units. In particular, we used the nn.Sequential constructor to generate an empty network into which we then inserted both layers. What exactly happens inside nn.Sequential has remained rather mysterious so far. Below we will see that it simply constructs a block that acts as a container for other blocks. These blocks can be combined into larger artifacts, often recursively. The diagram below shows how:

Blocks can be used recursively to form larger artifacts
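To make the container idea concrete without relying on any framework, here is a toy sketch in plain Python. The classes `Block`, `Scale`, `AddOne`, and `Sequential` are hypothetical stand-ins that only mimic the calling convention (calling a block runs its `forward`). They are not Gluon's classes, and plain lists stand in for arrays.

```python
class Block:
    """Toy stand-in for a block: calling it runs forward()."""

    def __call__(self, x):
        return self.forward(x)

    def forward(self, x):
        raise NotImplementedError


class Scale(Block):
    """Hypothetical 'layer' holding one scalar parameter."""

    def __init__(self, factor):
        self.factor = factor

    def forward(self, x):
        return [self.factor * v for v in x]


class AddOne(Block):
    """Hypothetical parameter-free 'layer'."""

    def forward(self, x):
        return [v + 1 for v in x]


class Sequential(Block):
    """A block whose only job is to hold other blocks and chain them."""

    def __init__(self, *blocks):
        self.blocks = list(blocks)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)  # the output of one block is the input to the next
        return x


# Containers are blocks too, so they can be nested recursively.
inner = Sequential(Scale(2), AddOne())
outer = Sequential(inner, Scale(10), inner)
print(outer([1.0, 2.0]))  # [61.0, 101.0]
```

Because `Sequential` is itself a `Block`, the same object (`inner`) can appear anywhere a single layer could.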

To go from defining layers to defining blocks (of one or more layers), it helps to spell out what a block does:

  1. Blocks take data as input.

  2. Blocks store state in the form of parameters that are inherent to the block. For instance, the network above contains two fully connected layers (a hidden layer and an output layer), and we need a place to store their parameters.

  3. Blocks produce meaningful output. This is typically encoded in what we will call the forward function. It allows us to invoke a block via net(x) to obtain the desired output. Behind the scenes, calling the block invokes forward to perform forward propagation (also called forward computation).

  4. Blocks initialize their parameters lazily, as part of the first forward call, once the shape of the input is known.

  5. Blocks calculate gradients with respect to their input when backward is invoked. Typically this happens automatically.
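Items 2 to 4 can be illustrated with a framework-free toy. In the sketch below, `LazyDense` is a hypothetical class, not Gluon's `nn.Dense`. It stores its weight as a parameter, dispatches `layer(x)` to `forward`, and defers creating the weight until the first call reveals the number of input features. The initialization range is chosen arbitrarily for illustration.

```python
import random


class LazyDense:
    """Toy dense layer (hypothetical) whose weight shape is inferred lazily.

    Only the number of outputs is known at construction time; the number of
    inputs is read off the first batch passed to forward().
    """

    def __init__(self, units):
        self.units = units
        self.weight = None  # parameter storage; created on the first call

    def _init_params(self, in_features):
        # Deferred initialization: weight has shape (in_features, units).
        self.weight = [[random.uniform(-0.07, 0.07) for _ in range(self.units)]
                       for _ in range(in_features)]

    def forward(self, x):
        # x is a batch given as a list of rows, each with in_features entries.
        if self.weight is None:
            self._init_params(len(x[0]))
        n_in = len(self.weight)
        return [[sum(row[i] * self.weight[i][j] for i in range(n_in))
                 for j in range(self.units)]
                for row in x]

    def __call__(self, x):  # layer(x) dispatches to forward
        return self.forward(x)


layer = LazyDense(10)
print(layer.weight)  # None: the input dimension is not known yet
out = layer([[1.0] * 20, [0.5] * 20])
print(len(layer.weight), len(layer.weight[0]))  # 20 10
print(len(out), len(out[0]))                    # 2 10
```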

A Sequential Block

The Block class is a generic component describing data flow. Data can flow through a sequence of blocks, with the first block applied to the input data itself and each later block applied to the output of the one before it. Such a chain is itself a special kind of block, namely the Sequential block.

Sequential has helper methods to manage the sequence. The main one of interest is add, which appends a block to the end of the sequence. Once the blocks have been added, the forward computation of the model applies them to the input data in the order they were added. Below, we implement a MySequential class that has the same functionality as the Sequential class. Building it ourselves should clarify how the Sequential class works.

[2]:
class MySequential(Block):
    def __init__(self):
        super(MySequential, self).__init__()
        self._layers = []

    def add(self, block):
        # `block` is an instance of a Block subclass. We keep it in the Python
        # list `_layers` to remember the order of application, and register it
        # as a child so that methods such as `initialize` and `collect_params`
        # can find its parameters.
        self._layers.append(block)
        self.register_child(block)

    def forward(self, x):
        # Apply the blocks in the order in which they were added.
        for block in self._layers:
            x = block(x)
        return x

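At its core is the add method. It appends the block to the list _layers and registers it as a child of the MySequential instance, so that initialize can reach its parameters. When forward propagation is invoked, the blocks are executed in the order in which they were added. Let’s see what the MLP looks like now.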
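Registering children matters because parameter-related methods walk the tree of registered blocks. The sketch below illustrates this in plain Python. `ToyBlock` and its `collect_params` are hypothetical stand-ins that only borrow Gluon's method names, not its implementation. In this toy, a block that is merely stored in an ordinary list, without being registered, is invisible to the traversal.

```python
class ToyBlock:
    """Hypothetical block that tracks its own params and registered children."""

    def __init__(self, name, **params):
        self.name = name
        self.params = dict(params)  # parameters owned directly by this block
        self._children = []         # blocks added via register_child

    def register_child(self, block):
        self._children.append(block)

    def collect_params(self, prefix=""):
        """Recursively gather {qualified_name: value} from registered children."""
        path = prefix + self.name
        found = {f"{path}.{key}": value for key, value in self.params.items()}
        for child in self._children:
            found.update(child.collect_params(path + "/"))
        return found


net = ToyBlock("net")
net.register_child(ToyBlock("dense0", weight="W0", bias="b0"))
net.register_child(ToyBlock("dense1", weight="W1", bias="b1"))
net.extra = [ToyBlock("dense2", weight="W2")]  # stored, but never registered
print(sorted(net.collect_params()))
# ['net/dense0.bias', 'net/dense0.weight', 'net/dense1.bias', 'net/dense1.weight']
```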

[3]:
net = MySequential()
net.add(nn.Dense(256, activation='relu'))
net.add(nn.Dense(10))
net.initialize()
net(x)
[3]:
array([[ 0.03636155,  0.00339161, -0.00965453, -0.03355846, -0.02563929,
         0.06889853, -0.08230178, -0.02361742, -0.0281327 , -0.02201413],
       [ 0.021716  , -0.0254443 , -0.005876  , -0.01355301,  0.02687372,
         0.06052493, -0.07821415,  0.00131516, -0.03081927, -0.03834141]])

Indeed, using the MySequential class is no different from using the Sequential class. The numerical outputs differ only because the parameters are initialized randomly.

A Custom Block

It is easy to go beyond simple chaining with Sequential. The Block class, available in the nn module as nn.Block, provides the functionality required for such customizations, and we can inherit from it to define the model we want. Below, we inherit from Block to construct the multilayer perceptron from the beginning of this section. The MLP class defined here overrides the __init__ and forward methods of Block, which create the model’s layers and define its forward computation (that is, forward propagation), respectively.

[4]:
class MLP(nn.Block):
    # Declare a layer with model parameters. Here, we declare two fully
    # connected layers.

    def __init__(self, **kwargs):
        # Call the constructor of the parent class Block to perform the
        # necessary initialization. Any extra keyword arguments are passed
        # on to Block unchanged.
        super(MLP, self).__init__(**kwargs)
        self.hidden = nn.Dense(256, activation='relu')  # Hidden layer
        self.output = nn.Dense(10)  # Output layer

    # Define the forward computation of the model, that is, how to return the
    # required model output based on the input x.

    def forward(self, x):
        hidden_out = self.hidden(x)
        return self.output(hidden_out)

Let’s look at it a bit more closely. The forward method invokes a network simply by evaluating the hidden layer self.hidden(x) and subsequently by evaluating the output layer self.output( ... ). This is what we expect in the forward pass of this block.

In order for the block to know what it needs to evaluate, we first need to define the layers. This is what the __init__ method does. It first calls the constructor of the parent class Block to perform the necessary initialization. It then constructs the requisite layers, attaching them and their parameters to the instance. Note that there is no need to define a backpropagation method in the class: the system automatically generates the backward computation by differentiating the forward computation (see the tutorial on autograd). Likewise, we do not need to write an initialize method, since it is inherited from Block. Let’s try this out:

[5]:
net = MLP()
net.initialize()
net(x)
[5]:
array([[-0.05392137, -0.06907995,  0.01572804, -0.0730975 , -0.08156237,
        -0.01220089, -0.00389173, -0.01574439,  0.05399816, -0.00557903],
       [-0.05604443, -0.00363856,  0.0165119 , -0.07483605, -0.05082021,
        -0.02666114, -0.0373126 , -0.07433479,  0.02280258,  0.04191535]])

As explained above, the Block class can be quite versatile in terms of what it does. For instance, its subclass can be a layer (such as the Dense class provided by Gluon), it can be a model (such as the MLP class we just derived), or it can be a part of a model (this is what typically happens when designing very deep networks). Throughout this chapter we will see how to use this with great flexibility.

Coding with Blocks

Blocks

The Sequential class can make model construction easier and does not require you to define the forward method; however, directly inheriting from its parent class, Block, can greatly expand the flexibility of model construction. For example, implementing the forward method means you can introduce control flow in the network.
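As a framework-free illustration, the following forward-style function uses a loop whose number of iterations depends on the input values. Plain Python lists stand in for tensors, and the name `rescale` is hypothetical. A fixed stack of layers cannot express this. The logic mirrors the normalization loop in the FancyMLP example later in this section.

```python
import math


def norm(v):
    """Euclidean norm of a list of numbers."""
    return math.sqrt(sum(t * t for t in v))


def rescale(v):
    """Data-dependent control flow in a toy forward pass."""
    while norm(v) > 1:           # iteration count depends on the data
        v = [t / 2 for t in v]
    if norm(v) < 0.8:            # so does this branch
        v = [t * 10 for t in v]
    return v


print(rescale([3.0, 4.0]))  # halved three times, then scaled: [3.75, 5.0]
print(rescale([0.9, 0.0]))  # no iterations, no scaling: [0.9, 0.0]
```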

Constant parameters

Now we’d like to introduce the notion of a constant parameter. These are parameters that are not updated during backpropagation, since no gradient is computed for them. This sounds very abstract, but here’s what’s really going on. Assume that we have some function

\[f(\mathbf{x},\mathbf{w}) = 3 \cdot \mathbf{w}^\top \mathbf{x}.\]

In this case \(3\) is a constant parameter. We could change \(3\) to something else, say \(c\) via

\[f(\mathbf{x},\mathbf{w}) = c \cdot \mathbf{w}^\top \mathbf{x}.\]

Nothing has really changed, except that we can adjust the value of \(c\). It is still a constant as far as \(\mathbf{w}\) and \(\mathbf{x}\) are concerned. However, Gluon doesn’t know about this unless we wrap the value in a Constant. This makes the code go faster, too, since we’re not sending the Gluon engine on a wild goose chase after gradients for a parameter that doesn’t change.
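A plain-Python sketch of what "constant" means for \(f(\mathbf{x},\mathbf{w}) = c \cdot \mathbf{w}^\top \mathbf{x}\). The helper names `f` and `sgd_step` are hypothetical. The gradient with respect to \(\mathbf{w}\) is \(c \cdot \mathbf{x}\), so \(c\) shapes the update of \(\mathbf{w}\), but \(c\) itself is never updated.

```python
def f(x, w, c):
    """f(x, w) = c * w^T x, with c treated as a constant."""
    return c * sum(wi * xi for wi, xi in zip(w, x))


def sgd_step(x, w, c, lr):
    """One gradient-descent step on f with respect to w only.

    df/dw = c * x. No gradient is computed or applied for c.
    """
    grad_w = [c * xi for xi in x]
    return [wi - lr * gi for wi, gi in zip(w, grad_w)]


x, w, c = [1.0, 2.0], [0.5, -0.5], 3.0
print(f(x, w, c))         # 3 * (0.5 - 1.0) = -1.5
w_new = sgd_step(x, w, c, lr=0.5)
print(w_new, c)           # [-1.0, -3.5] 3.0  (w moved, c unchanged)
```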

[6]:
class FancyMLP(nn.Block):
    def __init__(self, **kwargs):
        super(FancyMLP, self).__init__(**kwargs)

        # A random weight wrapped in Constant is not updated during training
        # (i.e., it is a constant parameter).
        self.rand_weight = Constant(np.random.uniform(size=(20, 20)))
        self.dense = nn.Dense(20, activation='relu')

    def forward(self, x):
        x = self.dense(x)
        # Use the constant parameter, together with the relu and dot
        # functions.
        x = npx.relu(np.dot(x, self.rand_weight.data()) + 1)
        # Re-use the fully connected layer. This is equivalent to sharing
        # parameters between two fully connected layers.
        x = self.dense(x)
        # Control flow: call `item` to obtain a Python scalar for the
        # comparisons.
        while np.linalg.norm(x).item() > 1:
            x /= 2
        if np.linalg.norm(x).item() < 0.8:
            x *= 10
        return x.sum()

In this FancyMLP model, we used the constant weight rand_weight (note that it is not a trainable model parameter), performed a matrix multiplication (np.dot), and reused the same Dense layer. Note that this is very different from using two dense layers with different sets of parameters: instead, we applied the same layer, with the same parameters, twice. In deep networks one often says that parameters are tied when multiple parts of a network share the same parameters. Let’s see what happens if we construct the model and feed data through it.
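The difference between tied and untied parameters can be seen in a framework-free toy. `Linear`, `run`, and `num_params` below are hypothetical helpers. Reusing one layer object means there is only one set of parameters, so changing it affects every place the layer is applied.

```python
class Linear:
    """Toy 1-D linear layer y = w * x + b (hypothetical, plain Python)."""

    def __init__(self, w, b):
        self.w, self.b = w, b

    def __call__(self, x):
        return self.w * x + self.b


def run(layers, x):
    for layer in layers:
        x = layer(x)
    return x


def num_params(layers):
    unique = {id(layer): layer for layer in layers}  # dedupe shared objects
    return 2 * len(unique)                           # w and b per layer


shared = Linear(2.0, 1.0)
tied = [shared, shared]                        # the same object, used twice
untied = [Linear(2.0, 1.0), Linear(2.0, 1.0)]  # two independent copies

print(run(tied, 1.0), run(untied, 1.0))      # 7.0 7.0 (same values at first)
print(num_params(tied), num_params(untied))  # 2 4
shared.w = 3.0         # updating the shared layer changes both applications
print(run(tied, 1.0))  # 3 * (3 * 1 + 1) + 1 = 13.0
```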

[7]:
net = FancyMLP()
net.initialize()
net(x)
[7]:
array(22.585115)

There’s no reason why we couldn’t mix and match these ways of building a network. Admittedly, the example below resembles a Rube Goldberg machine. That said, it shows how a block can be built from individual blocks, which in turn may be composed of blocks themselves. Furthermore, we can combine multiple strategies inside the same forward function. Here’s the network:

[8]:
class NestMLP(nn.Block):
    def __init__(self, **kwargs):
        super(NestMLP, self).__init__(**kwargs)
        self.net = nn.Sequential()
        self.net.add(nn.Dense(64, activation='relu'),
                     nn.Dense(32, activation='relu'))
        self.dense = nn.Dense(16, activation='relu')

    def forward(self, x):
        return self.dense(self.net(x))

chimera = nn.Sequential()
chimera.add(NestMLP(), nn.Dense(20), FancyMLP())

chimera.initialize()
chimera(x)
[8]:
array(21.727867)

Hybridization

The reader may be starting to wonder about the efficiency of this Python code. After all, we have dictionary lookups, interpreted code execution, and plenty of other Pythonic overhead in what is supposed to be a high-performance deep learning library. The problems of Python’s Global Interpreter Lock are also well known.

In the context of deep learning, we often have highly performant GPUs that depend on CPUs running Python to tell them what to do. This mismatch can manifest as GPU starvation when the CPUs cannot issue instructions fast enough. We can improve this situation by deferring to a more performant language than Python where possible.

Gluon does this through hybridization, which is enabled by calling hybridize() on a block derived from nn.HybridBlock. The Python interpreter executes the block the first time it is invoked, the Gluon runtime records what happens, and the next time around it short-circuits any calls to Python. This can accelerate things considerably in some cases, but care needs to be taken with control flow that depends on the values of the data.
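To see why control flow needs care, here is a deliberately simplified trace-and-replay sketch in plain Python. `Traced` and `forward` are hypothetical names, and this is not how Gluon implements hybridization: it does not build a backend graph. The sketch only captures the record-once, replay-later idea and one hazard of it. In this toy, a branch that depends on the data is frozen to whichever path the first input took.

```python
class Traced:
    """Toy trace-and-replay wrapper (hypothetical; not Gluon's mechanism).

    On the first call the Python function runs normally while every
    operation is recorded on a tape. Later calls replay the tape without
    re-running the Python function body.
    """

    def __init__(self, fn):
        self.fn = fn
        self.tape = None

    def __call__(self, x):
        if self.tape is None:
            self.tape = []
            return self.fn(x, self.tape)
        for op in self.tape:  # replay: Python branches are not re-evaluated
            x = op(x)
        return x


def forward(x, tape):
    """A toy 'block' whose operations are recorded as they execute."""
    def record(op, value):
        tape.append(op)
        return op(value)

    x = record(lambda v: v * 2, x)
    if x > 10:  # data-dependent branch
        x = record(lambda v: v - 10, x)
    else:
        x = record(lambda v: v + 100, x)
    return x


net = Traced(forward)
print(net(1))   # traced via the 'else' branch: 1 * 2 + 100 = 102
print(net(20))  # replays 'else': 140, although eager execution would give 30
```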