Let’s start with the implementation of scaled dot-product attention from The Annotated Transformer.
```python
def attention(query, key, value, mask, dropout):
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = scores.softmax(dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
```

Let’s add some type hints. (See the constants glossary.)
```python
def attention(
    query: Tensor,             # (B x H x L x D_k)
    key: Tensor,               # (B x H x L x D_k)
    value: Tensor,             # (B x H x L x D_k)
    mask: Tensor | None,       # (B x 1 x 1 x L)
    dropout: nn.Dropout | None
) -> tuple[Tensor, Tensor]:
    d_k: int = query.size(-1)
    scores: Tensor = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn: Tensor = scores.softmax(dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
```

We’re passing in that optional dropout layer because this is implemented outside of a class. If it were an instance method, we could use the instance’s dropout layer and just check whether we’re in eval mode. But let’s set that aside. What’s in each of the parameters?
Dimension analysis
Recall that scaled dot-product attention is defined as

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V,$$

where attention is a function $\mathbb{R}^{L \times d_k} \times \mathbb{R}^{L \times d_k} \times \mathbb{R}^{L \times d_v} \to \mathbb{R}^{L \times d_v}$. Now, in Vaswani, et al. (2017), they have $d_{\text{model}} = 512$ and $h = 8$ heads, so $d_k = d_v = d_{\text{model}} / h = 64$. In the implementation above, query, key, and value each carry a batch dimension $B$ and a head dimension $H$ on top of that, so they are $B \times H \times L \times d_k$ tensors.

What about the dimension of mask? This argument is used for masked attention in decoder blocks. It’s going to be of binary type and we only need one bit per sequence position. So mask is going to be $B \times 1 \times 1 \times L$, which broadcasts over the head and query dimensions.

What about the fraction inside the parentheses? The denominator is just a scalar, so we ignore it. So the argument to softmax is $QK^\top / \sqrt{d_k}$, which here is a $B \times H \times L \times L$ tensor: one $L \times L$ score matrix per head per batch element.

The first element of the returned tuple consists of this final attention value, and the second element is the softmax output. So their dimensions are $B \times H \times L \times d_k$ and $B \times H \times L \times L$, respectively.
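To double-check these shapes, here’s a quick sanity check with made-up sizes (it assumes the attention function defined above is in scope):

```python
import torch

B, H, L, D_k = 2, 8, 10, 64  # made-up sizes for illustration

query = torch.randn(B, H, L, D_k)
key = torch.randn(B, H, L, D_k)
value = torch.randn(B, H, L, D_k)
mask = torch.ones(B, 1, 1, L, dtype=torch.bool)  # all ones: nothing is masked

context, p_attn = attention(query, key, value, mask, dropout=None)
print(context.shape)  # torch.Size([2, 8, 10, 64]), i.e. B x H x L x D_k
print(p_attn.shape)   # torch.Size([2, 8, 10, 10]), i.e. B x H x L x L
```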
Syntactic analysis
Now let’s go line by line to ensure that I understand all of the PyTorch syntax.
```python
d_k: int = query.size(-1)
```

The tensor.size() method returns either a torch.Size (a tuple subclass) of dimension sizes or, if passed an integer index corresponding to a dimension, the size of that dimension as an int. The index syntax is the same as that of a Python array, so query.size(-1) returns the size of the last dimension of query, which in this case is 64.
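For instance, with the query tensor from the sanity check above:

```python
print(query.size())     # torch.Size([2, 8, 10, 64])
print(query.size(-1))   # 64
print(query.shape[-1])  # 64 -- indexing .shape is an equivalent spelling
```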
```python
scores: Tensor = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
```

This line corresponds to $QK^\top / \sqrt{d_k}$. tensor.transpose(x, y) is an alias for torch.transpose(tensor, x, y); it returns a view of the tensor with dimensions x and y swapped. And because both operands are four-dimensional, torch.matmul performs a batched matrix multiplication over the last two dimensions, broadcasting over the leading batch and head dimensions.
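Concretely, with the query and key tensors from the sanity check above:

```python
print(key.transpose(-2, -1).shape)                       # torch.Size([2, 8, 64, 10])
print(torch.matmul(query, key.transpose(-2, -1)).shape)  # torch.Size([2, 8, 10, 10]): one L x L score matrix per head
```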
```python
if mask is not None:
    scores = scores.masked_fill(mask == 0, -1e9)
```

The tensor.masked_fill(mask, value) method returns a tensor whose elements have been replaced with value wherever the boolean mask is True at the corresponding location. Here, mask == 0 produces a boolean tensor with the same shape as mask (True wherever a position should be hidden), and masked_fill broadcasts it against scores, so every masked position gets -1e9, which is effectively negative infinity as far as the subsequent softmax is concerned.
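A toy example of that broadcasting (a made-up 1 x 1 x 1 x 3 mask against a 1 x 1 x 3 x 3 score tensor):

```python
scores = torch.zeros(1, 1, 3, 3)              # pretend B = H = 1 and L = 3
mask = torch.tensor([[[[1, 1, 0]]]])          # shape (1, 1, 1, 3): last position is padding

masked = scores.masked_fill(mask == 0, -1e9)  # (1, 1, 1, 3) broadcasts against (1, 1, 3, 3)
print(masked[0, 0])                           # the last entry of every row is now -1e9
```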
```python
p_attn: Tensor = scores.softmax(dim=-1)
```

Again, dim follows the conventions of Python array indexing. So this says that the softmax operation should be applied to the vectors represented by the last dimension of the scores tensor. It will apply softmax to each such vector independently. Let’s unpack why this is what we want.
scores is a $B \times H \times L \times L$ tensor: for each batch element and head, the last two dimensions hold the $L \times L$ score matrix for a sequence, and row $i$ of that matrix contains the scores of query position $i$ against every key position. When we softmax, we want each query’s weights to sum to one, so the softmax should normalize each row, that is, run across the key positions, which index that matrix’ columns. That column index is the last dimension of the tensor, so we are softmaxing over the last dimension.
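A quick check that softmax(dim=-1) does normalize each row of the score matrix:

```python
scores = torch.randn(1, 1, 3, 3)
p_attn = scores.softmax(dim=-1)

print(p_attn[0, 0].sum(dim=-1))  # each row sums to 1 (up to floating point)
print(p_attn[0, 0].sum(dim=-2))  # the column sums are not 1 in general
```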
```python
if dropout is not None:
    p_attn = dropout(p_attn)
```
The only tunable parameter on torch.nn.Dropout is the dropout probability. So in principle, we could just pass in that number and instantiate a Dropout every time we do this operation. Constructing the module is a small lift, but we’re doing it on every forward pass, so it would probably add a nontrivial constant-factor overhead. Still, I find this whole “pass the dropout layer” thing ugly and weird. In a production class, I would have made scaled dot-product attention an instance method of the MultiheadAttention module, which also sidesteps a subtle bug with the pass-the-rate approach: a freshly constructed Dropout starts in training mode, so it would keep dropping activations even at evaluation time.
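A toy demonstration of that last point (assuming the same imports as the type-hinted version above):

```python
drop = nn.Dropout(p=0.5)
drop.eval()                  # a module passed into attention carries its train/eval state
print(drop(torch.ones(4)))   # tensor([1., 1., 1., 1.]) -- identity at eval time

fresh = nn.Dropout(p=0.5)    # constructed on the fly: training mode by default
print(fresh(torch.ones(4)))  # zeroes entries at random and scales the survivors by 1 / (1 - p)
```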
```python
return torch.matmul(p_attn, value), p_attn
```

This should have been spread onto two lines, because it’s doing two things. The first is to multiply the weights by the values, resulting in the context vectors. The second is to return a tuple consisting of the context vectors and the weights.
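Putting those last two complaints together, here’s a rough sketch of the shape I’d reach for (the class name comes from the discussion above, but the method name, constructor arguments, and omitted projections are hypothetical, and the same imports as the type-hinted version are assumed):

```python
class MultiheadAttention(nn.Module):
    def __init__(self, d_model: int, h: int, dropout_rate: float = 0.1):
        super().__init__()
        self.d_k = d_model // h
        self.dropout = nn.Dropout(p=dropout_rate)  # owned by the module, so train/eval is handled for us
        # ... linear projections for Q, K, V, and the output would go here ...

    def scaled_dot_product_attention(
        self, query: Tensor, key: Tensor, value: Tensor, mask: Tensor | None
    ) -> tuple[Tensor, Tensor]:
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        p_attn = self.dropout(scores.softmax(dim=-1))
        context = torch.matmul(p_attn, value)  # first thing: weight the values
        return context, p_attn                 # second thing: return the context vectors and weights
```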