PyTorch highly recommends developing a custom nn.Module subclass for any nontrivial model. In our example, we implement LeNet as such a module, holding its layers in the very convenient nn.Sequential container.
The main things to know are:
- Be sure to call the superclass constructor
- Define all your layers in the constructor so that their parameters can be registered
- Define a forward operation, which is often pretty lightweight
More detail on each of these points is given below.
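As a minimal sketch of those three points (the class and layer here are hypothetical, purely for illustration):

```python
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        # 1. Call the superclass constructor first.
        super().__init__()
        # 2. Define layers in the constructor so their parameters are
        #    registered (and show up in model.parameters() for the optimizer).
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        # 3. The forward pass is usually a thin composition of the layers.
        return self.fc(x)
```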
LeNet has the following architecture:
*(Figure: LeNet architecture)*
We implement it as follows:
```python
import torch.nn as nn

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # Input is 1x32x32
        self.stack = nn.Sequential(
            # 1x32x32 -> 6x28x28
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),
            nn.ReLU(),
            # 6x28x28 -> 6x14x14
            nn.MaxPool2d(kernel_size=2, stride=2),
            # 6x14x14 -> 16x10x10
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.ReLU(),
            # 16x10x10 -> 16x5x5
            nn.MaxPool2d(2),
            # 16x5x5 -> 400
            nn.Flatten(),
            # 400 -> 120
            nn.Linear(in_features=16 * 5 * 5, out_features=120),
            nn.ReLU(),
            # 120 -> 84
            nn.Linear(in_features=120, out_features=84),
            nn.ReLU(),
            # 84 -> 10
            nn.Linear(in_features=84, out_features=10),
        )

    def forward(self, x):
        return self.stack(x)
```
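As a quick sanity check (not part of the original example), we can push a dummy batch through the model and confirm the output shape:

```python
import torch

model = LeNet()
# A dummy batch containing one 1x32x32 grayscale image
x = torch.randn(1, 1, 32, 32)
out = model(x)
print(out.shape)  # torch.Size([1, 10])
# All parameters defined in the constructor are registered automatically
print(sum(p.numel() for p in model.parameters()))
```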
A more instructive example
While the above is how these things should really be written, the tutorial provides a (somewhat more contrived) version of the same thing, which sheds a bit more light on what’s happening:
```python
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel (black & white), 6 output channels,
        # 5x5 square convolution kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5*5 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square, you can specify only a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
```
Things to notice:
- The parameters are registered in `__init__`.
- Operations in the `forward` method are composed sequentially.
- We have an arbitrary helper function, `num_flat_features`, and that's fine.
- As revealed by the extra argument to the `x.view` operation, PyTorch assumes that the input to `forward` is a batch of unrelated tensors.
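To make that last point concrete, here is a small illustrative check (not from the tutorial) that the `-1` in `x.view` preserves the batch dimension, so the same model works for any batch size:

```python
import torch

model = LeNet()
for batch_size in (1, 4, 64):
    x = torch.randn(batch_size, 1, 32, 32)
    out = model(x)
    # num_flat_features(x) is 16*5*5 = 400 after the second pooling layer,
    # and the leading batch dimension is carried through untouched.
    assert out.shape == (batch_size, 10)
```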