PyTorch strongly recommends defining a custom nn.Module subclass for all nontrivial models. In our example, we implement LeNet using the very convenient nn.Sequential container.

The main things to know are:

  • Be sure to call the superclass constructor
  • Define all your layers in the constructor so that their parameters can be registered
  • Define a forward operation, which is often pretty lightweight

More detail on each of these points is discussed below.
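
Here is a minimal sketch of those three rules before we get to LeNet itself (TinyModel and its single linear layer are placeholders, not part of LeNet):

import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        # (1) call the superclass constructor
        super(TinyModel, self).__init__()
        # (2) define layers in the constructor so their parameters are registered
        self.linear = nn.Linear(8, 2)

    def forward(self, x):
        # (3) forward just composes the layers; here it is a single call
        return self.linear(x)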

LeNet has the following architecture: a 1x32x32 grayscale input passes through two 5x5 convolutions, each followed by a ReLU and 2x2 max pooling, and then through three fully connected layers that produce 10 output scores.

We implement it as follows:

import torch.nn as nn

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()

        # Input is 1x32x32
        self.stack = nn.Sequential(
            # 1x32x32 -> 6x28x28
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),
            nn.ReLU(),

            # 6x28x28 -> 6x14x14
            nn.MaxPool2d(kernel_size=2, stride=2),

            # 6x14x14 -> 16x10x10
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.ReLU(),

            # 16x10x10 -> 16x5x5
            nn.MaxPool2d(2),

            # 16x5x5 -> 400x1
            nn.Flatten(),

            # 400x1 -> 120x1
            nn.Linear(in_features=16*5*5, out_features=120),
            nn.ReLU(),

            # 120x1 -> 84x1
            nn.Linear(in_features=120, out_features=84),
            nn.ReLU(),

            # 84x1 -> 10x1
            nn.Linear(in_features=84, out_features=10)
        )

    def forward(self, x):
        return self.stack(x)
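
As a quick sanity check (the variable names here are just illustrative), we can push a dummy batch through the network and confirm the output shape:

import torch

net = LeNet()
x = torch.randn(4, 1, 32, 32)   # a batch of 4 single-channel 32x32 images
out = net(x)
print(out.shape)                # torch.Size([4, 10]) -- one row of class scores per image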

A more instructive example

While the above is how these things should really be written, the tutorial provides a (somewhat more contrived) version of the same thing, which sheds a bit more light on what’s happening:

import torch.nn.functional as F

class LeNet(nn.Module):

    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel (black & white), 6 output channels, 5x5 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5*5 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square you can only specify a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
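
In more recent PyTorch code, the x.view plus num_flat_features combination is often replaced by torch.flatten, which keeps the batch dimension and flattens everything else. A small standalone check (the random tensor here just stands in for the pooled activations):

import torch

x = torch.randn(4, 16, 5, 5)    # shape of the activations after the second pooling layer
flat = torch.flatten(x, 1)      # keep dim 0 (the batch dimension), flatten the rest
print(flat.shape)               # torch.Size([4, 400])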

Things to notice:

  • The parameters are registered in __init__ (the sketch after this list prints them).
  • Operations in the forward method are composed sequentially.
  • We can define arbitrary helper methods, such as num_flat_features, and that’s fine.
  • As revealed by the extra argument to the x.view operation, PyTorch assumes that the input to forward is a batch of independent samples: the first dimension is always the batch dimension.
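
To make the first point concrete, here is a small sketch (assuming the class above; net is just an illustrative name) that prints the parameters PyTorch now tracks:

net = LeNet()

# Ten parameter tensors are registered: a weight and a bias for each of
# conv1, conv2, fc1, fc2, and fc3.
for name, p in net.named_parameters():
    print(name, tuple(p.shape))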