In our LeNet example, the training function (“training loop”) is as follows:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
device = "mps"

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")

# In the test-train loop
train(train_dataloader, model, criterion, optimizer)
There’s a lot to unpack here. Let’s go line by line through the key ones.
criterion = nn.CrossEntropyLoss()
This call returns a callable object (really a module whose work is done in its forward method) that applies the cross-entropy loss function. When we call it on a torch.Tensor, the resulting loss computation is added to the computational graph.
When we later call loss.backward(), it follows the computational graph from this loss calculation (which should be, and usually is, the final node of the graph, and hence the starting point for the backward pass) all the way back to the parameters (which are leaf nodes). At each node in the graph, loss.backward() applies the chain rule to compute the gradient of the loss with respect to that node's inputs; the gradients are ultimately stored on the leaf nodes, i.e., the parameters.
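To make this concrete, here's a standalone toy (not part of the LeNet code) showing the loss node appearing in the graph and backward() depositing gradients on a leaf:

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# Toy logits for a batch of 2 examples over 3 classes; requires_grad=True
# stands in for "produced from parameters we want gradients for".
logits = torch.randn(2, 3, requires_grad=True)
labels = torch.tensor([0, 2])

loss = criterion(logits, labels)
print(loss.grad_fn)  # e.g. <NllLossBackward0 ...>: the loss node in the graph

loss.backward()      # walk the graph back from the loss...
print(logits.grad)   # ...depositing gradients on the leaf tensor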
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
As noted, the parameters are leaf nodes of the computational graph. When we later call optimizer.step(), it will go straight to these nodes to retrieve their accumulated gradients, to which it will apply its update rule (or “optimizer”).
In this case, the update rule is the unfortunately named “SGD” (stochastic gradient descent) optimizer, which of course has nothing stochastic about it, since we spoon-fed it the observations in the minibatch. What the name really refers to is the OG update rule used by Rumelhart, Hinton, and Williams (1986), here with a momentum term.
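If you wrote that update rule out by hand, it would look roughly like this (a simplified sketch of the default path; the real torch.optim.SGD also handles weight decay, dampening, and Nesterov momentum):

import torch

lr, momentum = 0.001, 0.9  # same hyperparameters we passed to optim.SGD
velocity = {}              # one momentum buffer per parameter, kept across steps

@torch.no_grad()
def sgd_momentum_step(params):
    for p in params:
        if p.grad is None:
            continue
        if p not in velocity:
            velocity[p] = p.grad.clone()             # first step: buffer = grad
        else:
            velocity[p].mul_(momentum).add_(p.grad)  # buffer = momentum * buffer + grad
        p.add_(velocity[p], alpha=-lr)               # p = p - lr * buffer

# usage: sgd_momentum_step(model.parameters()) in place of optimizer.step()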
def train(dataloader, model, loss_fn, optimizer):
While not actually part of the PyTorch framework itself, this pattern is basically expected. I think of it as a blessed convention akin to the verbose but ever-familiar public static void main(String[] args) in Java.
model.train()
“train” mode has nothing to do with autograd, despite appearances. Rather, it sets a stateful flag on the model instance marking it as being in “training” mode. This propagates to the submodules (layers) and signals layers that behave differently during training, like dropout and batch normalization, to use their training-time behavior (dropout randomly zeroes activations; batch norm updates its running statistics).
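You can watch the flag flip with a throwaway model (a standalone sketch):

import torch.nn as nn

m = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))

m.train()
print(m.training, m[1].training)  # True True: dropout randomly zeroes activations

m.eval()
print(m.training, m[1].training)  # False False: dropout becomes a no-op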
for batch, (X, y) in enumerate(dataloader):
The DataLoader class has an __iter__ method that yields minibatch feature-label tensor pairs.
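For example (a sketch with a made-up TensorDataset standing in for the real dataset):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Fake data: 256 observations with 10 features each, plus integer labels.
features = torch.randn(256, 10)
labels = torch.randint(0, 3, (256,))

loader = DataLoader(TensorDataset(features, labels), batch_size=64, shuffle=True)

for batch, (X, y) in enumerate(loader):
    print(batch, X.shape, y.shape)  # 0 torch.Size([64, 10]) torch.Size([64]), ...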
X, y = X.to(device), y.to(device)
Out of the box, PyTorch can run on CPUs or on GPUs (via Nvidia CUDA or Apple MPS), and a few other backends besides. When you call Tensor.to(device), it returns a copy of the tensor whose data physically lives on that device.
You need to do this for everything involved in your computation! If you look at the full example, you’ll see we did it for model as well.
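Rather than hard-coding "mps" as we did above, a common pattern is to pick the best available device at runtime (the availability checks here are standard PyTorch calls):

import torch

device = (
    "cuda" if torch.cuda.is_available()              # Nvidia GPU
    else "mps" if torch.backends.mps.is_available()  # Apple GPU
    else "cpu"
)
model = model.to(device)  # moves all parameters and buffers as well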
# Compute prediction error
pred = model(X)
loss = loss_fn(pred, y)
# Backpropagation
loss.backward()
optimizer.step()
optimizer.zero_grad()
As described above, there’s a lot of autograd magic going on here. When we call model(X), the model’s forward method runs, and the resulting tensor pred gets added to the computational graph. When we calculate the loss from it and the ground-truth labels y, that computation gets added to the graph as well.
So the computation that produced loss is the last node in the computational graph. Consequently, when we call loss.backward(), the gradient computed at that node gets propagated, via the chain rule, all the way back through the network.
The resulting gradient is accumulated in the leaf nodes of the graph (i.e., the parameters). Recall that optimizer got these parameters as an argument. So when we call optimizer.step() (to apply the update rule), it knows just where to look for those gradients and which parameters to change as a result.
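You can verify this by poking at the parameters right after a backward pass (a sketch using the model and optimizer from the example):

# After loss.backward(), every parameter carries a populated .grad:
for name, p in model.named_parameters():
    print(name, p.grad.shape)

optimizer.step()       # reads each p.grad and updates p in place
optimizer.zero_grad()  # resets each p.grad for the next minibatch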
The gradients accumulate until zero_grad gets called. It’s a bit of a weird process with some side effects, so consult the linked page for the whole story.
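The accumulation itself is easy to demonstrate on a toy parameter (a standalone sketch):

import torch

w = torch.tensor([1.0], requires_grad=True)

(w * 2).sum().backward()
print(w.grad)  # tensor([2.])

(w * 2).sum().backward()
print(w.grad)  # tensor([4.]): gradients add up across backward() calls

w.grad.zero_()  # what optimizer.zero_grad() does for every parameter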