Consider LeNet, a convolutional network (ConvNet) for classifying MNIST digits. At a very high level, a PyTorch implementation looks something like this:

# Outline only: "..." marks code elided here; the complete script below fills it in.
training_data = datasets.MNIST(...)
test_data = datasets.MNIST(...)

train_dataloader = DataLoader(...)
test_dataloader = DataLoader(...)

class NeuralNetwork(nn.Module):
    def __init__(self):
        ...

    def forward(self, x):
        ...

def train(dataloader, model, loss_fn, optimizer):
    for batch, (X, y) in enumerate(dataloader):
        ...

def test(dataloader, model, loss_fn):
    with torch.no_grad():
        for X, y in dataloader:
            ...
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

for t in range(epochs):
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)

In other words, the program consists of the following pieces:

  1. 01 Data import and preprocessing (PyTorch)
  2. 02 Defining a custom PyTorch module
  3. 03 PyTorch training function
  4. 04 PyTorch test function
  5. 05 PyTorch train-test loop

Each of these is explored in more detail in the linked pages.

Complete script used in the examples

import torch
import torch.nn as nn
import torch.optim as optim

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

batch_size = 64

# Pick the best available accelerator instead of assuming one exists:
# Apple-silicon GPU (MPS) first, then CUDA, otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Pad(2) adds 2 pixels on every side (28x28 digits -> 32x32, the input size
# LeNet below is built for); ToTensor() then converts each image to a tensor.
transform = transforms.Compose([
    transforms.Pad(2),
    transforms.ToTensor(),
])

training_data = datasets.MNIST(
    root="mnist",
    train=True,
    download=True,
    transform=transform,
)

test_data = datasets.MNIST(
    root="mnist",
    train=False,
    download=True,
    transform=transform,
)

# Reshuffle the training set every epoch so SGD does not see the exact same
# batch sequence each time. Evaluation order does not affect the metrics, so
# the test loader stays sequential.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

class LeNet(nn.Module):
    """LeNet-5-style ConvNet mapping 1x32x32 images to 10 class logits.

    The layer order inside ``self.stack`` is significant: state-dict keys
    (``stack.0.weight``, ``stack.3.weight``, ``stack.7.weight``, ...) are
    derived from it, so reordering would break previously saved checkpoints.
    """

    def __init__(self):
        super().__init__()

        # Expects input of shape (N, 1, 32, 32), e.g. 28x28 MNIST digits
        # after transforms.Pad(2).
        self.stack = nn.Sequential(
            # 1x32x32 -> 6x28x28
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),
            nn.ReLU(),

            # 6x28x28 -> 6x14x14
            nn.MaxPool2d(kernel_size=2, stride=2),

            # 6x14x14 -> 16x10x10
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.ReLU(),

            # 16x10x10 -> 16x5x5
            nn.MaxPool2d(kernel_size=2, stride=2),

            # 16x5x5 -> 400 (Flatten keeps the batch dimension)
            nn.Flatten(),

            # 400 -> 120
            nn.Linear(in_features=16 * 5 * 5, out_features=120),
            nn.ReLU(),

            # 120 -> 84
            nn.Linear(in_features=120, out_features=84),
            nn.ReLU(),

            # 84 -> 10 raw logits (no softmax; pair with nn.CrossEntropyLoss)
            nn.Linear(in_features=84, out_features=10),
        )

    def forward(self, x):
        """Return class logits of shape (N, 10) for a batch of shape (N, 1, 32, 32)."""
        return self.stack(x)

# Build the network and move its parameters onto the selected device.
model = LeNet().to(device)

# Cross-entropy over LeNet's raw logits, optimised with momentum SGD.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.9)

def train(dataloader, model, loss_fn, optimizer, device=None):
    """Run one training epoch over ``dataloader``, updating ``model`` in place.

    Args:
        dataloader: iterable of ``(X, y)`` batches with a sized ``.dataset``
            attribute (its length is used only for progress output).
        model: network to train; it is switched to training mode.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
        device: where each batch is moved before the forward pass. Defaults
            to the device of the model's first parameter (CPU if the model
            has none), so the function no longer relies on a global.

    Prints the current batch loss every 100 batches.
    """
    size = len(dataloader.dataset)
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation: clear any stale gradients first, then compute and
        # apply fresh ones for this batch.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss_value = loss.item()
            current = (batch + 1) * len(X)
            print(f"loss: {loss_value:>7f}  [{current:>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

# Alternate one training pass and one evaluation pass per epoch.
epochs = 5
for epoch in range(1, epochs + 1):
    banner = f"Epoch {epoch}\n-------------------------------"
    print(banner)
    train(train_dataloader, model, criterion, optimizer)
    test(test_dataloader, model, criterion)
print("Done!")