Let’s start with the implementation of scaled dot-product attention from The Annotated Transformer.

def attention(query, key, value, mask, dropout):  
    d_k = query.size(-1)  
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)  
    if mask is not None:  
        scores = scores.masked_fill(mask == 0, -1e9)  
    p_attn = scores.softmax(dim=-1)  
    if dropout is not None:  
        p_attn = dropout(p_attn)  
    return torch.matmul(p_attn, value), p_attn

Let’s add some type hints. (See the constants glossary.)
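We'll also need a few imports; these cover the type annotations and every snippet in this post:

import math

import torch
from torch import Tensor, nn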

def attention(
	query: Tensor,                # (B x H x L x D_k)
	key: Tensor,                  # (B x H x L x D_k)
	value: Tensor,                # (B x H x L x D_k)
	mask: Tensor | None,          # (B x 1 x 1 x L)
	dropout: nn.Dropout | None
) -> tuple[Tensor, Tensor]:
	d_k: int = query.size(-1)
	scores: Tensor = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
	if mask is not None:
		scores = scores.masked_fill(mask == 0, -1e9)
	p_attn: Tensor = scores.softmax(dim=-1)
	if dropout is not None:
		p_attn = dropout(p_attn)
	return torch.matmul(p_attn, value), p_attn

We’re passing in that optional dropout layer because this is implemented outside of a class. If it were an instance method, we could use the instance’s dropout layer, and just check if we’re in eval mode. But let’s set that aside. What’s in each of the parameters?

Dimension analysis

Recall that scaled dot-product attention is defined as

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{D_k}}\right) V$$

where $Q$, $K$, and $V$ are all matrices whose rows are the positions of the input sequence projected into learned query, key, and value vector spaces respectively. The input to attention is a $B \times H \times L \times D_k$ tensor, where $B$ is the batch size, $H$ is the number of attention heads, $L$ is the length of the input, and $D_k$ is the dimension of the vector space.

Now, in Vaswani, et al. (2017), they have $H = 8$ attention heads, and they fix all model dimensions to $d_{\text{model}} = 512$. They divide this model dimensionality across the heads so that the final, concatenated result has the right dimension. Hence the per-position vectors in these matrices are each $d_{\text{model}} / H = 512 / 8 = 64$ elements long. So $D_k = 64$.

What about the dimension of mask? This argument is used for masked attention, for example to hide padding or (in decoder blocks) future positions. It’s going to be of binary type, and for a padding mask we only need one bit per sequence position. So mask is going to be $B \times 1 \times 1 \times L$. This is “unsqueezed” to that shape in the implementation of multi-head attention in order to make broadcasting against all of the heads’ score matrices possible.
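To make the broadcasting concrete, here is a small sketch; the token IDs and the pad-token-is-zero convention are made up for illustration:

src_tokens = torch.tensor([[3, 7, 9, 0, 0],
                           [4, 2, 0, 0, 0]])          # (B x L) = (2 x 5); 0 is assumed to be padding
mask = (src_tokens != 0).unsqueeze(1).unsqueeze(1)    # (B x 1 x 1 x L)
scores = torch.randn(2, 8, 5, 5)                      # (B x H x L x L)
masked = scores.masked_fill(mask == 0, -1e9)          # mask broadcasts over H and the query positions
print(mask.shape, masked.shape)                       # torch.Size([2, 1, 1, 5]) torch.Size([2, 8, 5, 5])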

What about the fraction inside the parentheses? The denominator is just a scalar, so we can ignore it for dimension purposes. $Q$ and $K$ are both $B \times H \times L \times D_k$, so $Q K^\top$ is going to be $B \times H \times L \times L$. This should make intuitive sense: the argument to the softmax is an all-by-all comparison of the “relevance” of each position in the sequence to each other position in the sequence.

So the argument to softmax is $B \times H \times L \times L$. Softmax doesn’t change the dimension, so it stays that way. Finally, we multiply this whole quantity by $V$, which has dimension $B \times H \times L \times D_k$. So the result will once again end up with dimension $B \times H \times L \times D_k$.

The first element of the returned tuple consists of this final attention value, and the second element is the softmax output. So the dimensions of the returned pair are $B \times H \times L \times D_k$ and $B \times H \times L \times L$, respectively.
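We can check this dimension analysis empirically, using the attention function defined above with arbitrary values for B, H, and L:

B, H, L, D_k = 2, 8, 10, 64
query = torch.randn(B, H, L, D_k)
key = torch.randn(B, H, L, D_k)
value = torch.randn(B, H, L, D_k)

context, p_attn = attention(query, key, value, mask=None, dropout=None)
print(context.shape)   # torch.Size([2, 8, 10, 64]) -- B x H x L x D_k
print(p_attn.shape)    # torch.Size([2, 8, 10, 10]) -- B x H x L x L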

Syntactic analysis

Now let’s go line by line to ensure that I understand all of the PyTorch syntax.

d_k: int = query.size(-1)

The tensor.size() method returns either a torch.Size (a tuple subclass) of dimension sizes or, if passed an integer index corresponding to a dimension, a plain int. The index follows the same conventions as Python sequence indexing, so query.size(-1) returns the size of the last dimension of query, which in this case is $D_k = 64$.
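For example, with a tensor shaped like query above:

x = torch.randn(2, 8, 10, 64)   # B x H x L x D_k
print(x.size())                 # torch.Size([2, 8, 10, 64]) -- behaves like a tuple
print(x.size(-1))               # 64, as a plain int
print(x.shape[-1])              # 64 -- the .shape attribute is an equivalent spelling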

scores: Tensor = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

This line corresponds to $Q K^\top / \sqrt{D_k}$. tensor.transpose(x, y) is equivalent to torch.transpose(tensor, x, y). This method returns a view of the tensor (no data is copied) whose x and y dimensions have been swapped.
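A quick check of both the view behavior and the resulting matmul shape:

q = torch.randn(2, 8, 10, 64)          # B x H x L x D_k
k = torch.randn(2, 8, 10, 64)
k_t = k.transpose(-2, -1)              # B x H x D_k x L
print(k_t.shape)                       # torch.Size([2, 8, 64, 10])
print(k_t.data_ptr() == k.data_ptr())  # True: a view over the same storage, not a copy
print(torch.matmul(q, k_t).shape)      # torch.Size([2, 8, 10, 10]) -- B x H x L x L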

if mask is not None:
	scores = scores.masked_fill(mask == 0, -1e9)

The tensor.masked_fill(bitmask, value) method returns a tensor whose elements have been replaced with value wherever the Boolean bitmask is True at the corresponding location. The expression mask == 0 produces a Boolean tensor with mask’s shape; inside masked_fill it is broadcast against scores, so every head and every query position sees the same per-position mask. Filling masked positions with -1e9 means they contribute essentially zero weight after the softmax.
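A one-row sketch of the effect (the keep vector here is a hypothetical padding indicator):

scores_row = torch.tensor([[2.0, 1.0, 0.5, 3.0]])
keep = torch.tensor([[1, 1, 0, 0]])              # 0 marks padded positions (illustrative)
filled = scores_row.masked_fill(keep == 0, -1e9)
print(filled)                                    # tensor([[ 2.0000e+00,  1.0000e+00, -1.0000e+09, -1.0000e+09]])
print(filled.softmax(dim=-1))                    # tensor([[0.7311, 0.2689, 0.0000, 0.0000]])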

p_attn: Tensor = scores.softmax(dim=-1)

Again, dim follows the conventions of Python array indexing. So this says that the softmax operation should be applied to the vectors represented by the last dimension of the scores tensor. It will apply softmax to each such vector independently. Let’s unpack why this is what we want.

scores is a $B \times H \times L \times L$ tensor, which we can think of as a batch of $L \times L$ matrices. Each matrix represents the all-vs-all relevance scores between each position in an input sequence and each other position in that input sequence (including itself). Put another way, a given matrix in scores represents a bunch of row vectors $s_i$, one for each position $i$, with the elements of each row laid out along the last dimension. Each element $s_{ij}$ in row vector $s_i$ is a representation of the relevance of the token at position $j$ to the position $i$. (Not the token at $i$, mind you; the position $i$. A context vector is an expected value for a position.)

So when we softmax, we want to be softmaxing over these row vectors: for each position $i$, the relevance scores over all tokens $j$ should be normalized into a probability distribution. Since the last two dimensions of the tensor represent the score matrix for a sequence, and each row’s elements run along the last dimension, we are softmaxing over the last dimension.
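We can verify that softmaxing over dim=-1 produces one probability distribution per query position:

scores = torch.randn(2, 8, 5, 5)        # B x H x L x L
p_attn = scores.softmax(dim=-1)
row_sums = p_attn.sum(dim=-1)           # one sum per (batch, head, query position)
print(row_sums.shape)                   # torch.Size([2, 8, 5])
print(torch.allclose(row_sums, torch.ones_like(row_sums)))  # True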

if dropout is not None:  
	p_attn = dropout(p_attn)

The only tunable parameter on torch.nn.Dropout is the dropout rate p. So in principle, we could just pass in that rate and instantiate a Dropout every time we do this operation. It’s a small lift, but we’re doing it a lot of times, so it would probably add a nontrivial prefactor, and a freshly constructed Dropout wouldn’t track the surrounding module’s train/eval mode, so we’d be dropping weights at inference time too. Still, I find this whole “pass the dropout layer” thing ugly and weird. In a production class, I would have made scaled dot-product attention an instance method of the MultiheadAttention module.
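For what it’s worth, here is a minimal sketch of that restructuring. This is my own illustration, not the Annotated Transformer’s MultiHeadedAttention: the Q/K/V projections and head reshaping are elided, and the method name _attention is just a placeholder.

class MultiheadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        self.d_k = d_model // num_heads
        self.dropout = nn.Dropout(dropout)
        # ... Q/K/V and output projections elided ...

    def _attention(
        self,
        query: Tensor,        # (B x H x L x D_k)
        key: Tensor,          # (B x H x L x D_k)
        value: Tensor,        # (B x H x L x D_k)
        mask: Tensor | None = None,
    ) -> tuple[Tensor, Tensor]:
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        # nn.Dropout is a no-op after self.eval(), so no explicit mode check is needed.
        p_attn = self.dropout(scores.softmax(dim=-1))
        return torch.matmul(p_attn, value), p_attn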

return torch.matmul(p_attn, value), p_attn

This should have been spread onto two lines, because it’s doing two things. The first is to multiply the weights by the values, resulting in the context vectors. The second is to return a tuple consisting of the context vectors and the weights.
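Something like:

context = torch.matmul(p_attn, value)   # (B x H x L x D_k): the context vectors
return context, p_attn                  # also return the attention weights (B x H x L x L)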