Constants used

The following uses conventions from the constants glossary.

Analysis

Let’s start with the implementation of multi-head attention in The Annotated Transformer. I have added type hints, but otherwise the code is as it appears there.

class MultiHeadedAttention(nn.Module):
    def __init__(self, h: int, d_model: int, dropout: float = 0.1):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k: int = d_model // h
        self.h: int = h
        self.linears: List[nn.Linear] = clones(nn.Linear(d_model, d_model), 4)
        self.attn: Optional[Tensor] = None
        self.dropout: nn.Dropout = nn.Dropout(p=dropout)
 
    def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches: int = query.size(0)
 
        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = [
            l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for l, x in zip(self.linears, (query, key, value))
        ]
 
        # 2) Apply attention on all the projected vectors in batch.
        x: Tensor
        x, self.attn = attention(
            query, key, value, mask=mask, dropout=self.dropout
        )
 
        # 3) "Concat" using a view and apply a final linear.
        x = (
            x.transpose(1, 2)
            .contiguous()
            .view(nbatches, -1, self.h * self.d_k)
        )
        del query
        del key
        del value
        return self.linears[-1](x)

The arguments query, key, and value here are highly misleading. While scaled dot-product attention is a function of these three tensors, multi-head attention is not. Rather, multi-head attention is a function of two variables: a raw input sequence and a raw memory sequence. For self-attention, these are the same variable; for cross-attention, they are distinct. Indeed, in the actual use of these methods, the key and value are always the same. So let’s rewrite forward in a way that is easier to explain.

def forward(self, input_: Tensor, memory: Tensor, mask: Optional[Tensor] = None) -> Tensor:
	if mask is not None:
		# Same mask applied to all h heads.
		mask = mask.unsqueeze(1)
	nbatches: int = input_.size(0)
 
	# 1) Do all the linear projections in batch from d_model => h x d_k
	query, key, value = [
		l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
		for l, x in zip(self.linears, (input_, memory, memory))
	]
 
	# 2) Apply attention on all the projected vectors in batch.
	x: Tensor
	x, self.attn = attention(
		query, key, value, mask=mask, dropout=self.dropout
	)
 
	# 3) "Concat" using a view and apply a final linear.
	x = (
		x.transpose(1, 2)
		.contiguous()
		.view(nbatches, -1, self.h * self.d_k)
	)
	del query
	del key
	del value
	return self.linears[-1](x)
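
Before going block by block, note how the two call patterns look with this signature. This is an illustrative sketch, not code from the manuscript; self_attn, cross_attn, x, encoder_output, and the masks are names I am assuming.

# Illustrative only: self_attn and cross_attn are assumed to be
# MultiHeadedAttention instances using the rewritten forward above.

# Self-attention: the sequence attends to itself.
x = self_attn(x, x, mask=tgt_mask)

# Cross-attention: the decoder input attends to the encoder's output.
x = cross_attn(x, encoder_output, mask=src_mask)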

Now we can go through each block.

if mask is not None:
	mask = mask.unsqueeze(1)

tensor.unsqueeze(dim) is an alias of torch.unsqueeze(tensor, dim). This inserts a size-1 dimension. The semantics of dim are slightly confusing:

  • For index $\geq 0$, it will insert before the index. So, mask.unsqueeze(0) would add a dimension at the start.
  • For index $< 0$, it will insert after the index. So, mask.unsqueeze(-1) would insert a dimension at the end.

In this case, where mask starts out as $(\mathrm{nbatches}, n, n)$ (writing $n$ for the sequence length), the resulting tensor will be $(\mathrm{nbatches}, 1, n, n)$, which allows PyTorch to broadcast the mask to the $(\mathrm{nbatches}, h, n, n)$ score tensor inside the dot-product attention function.
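
A quick, standalone shape check (the sizes here are arbitrary, and the masked_fill line mirrors how the dot-product attention function applies the mask):

import torch

nbatches, h, n = 2, 8, 5
mask = torch.ones(nbatches, n, n, dtype=torch.bool)   # (nbatches, n, n)
scores = torch.randn(nbatches, h, n, n)               # (nbatches, h, n, n)

print(mask.unsqueeze(0).shape)   # torch.Size([1, 2, 5, 5]) -- inserted at the front
print(mask.unsqueeze(-1).shape)  # torch.Size([2, 5, 5, 1]) -- inserted at the end
print(mask.unsqueeze(1).shape)   # torch.Size([2, 1, 5, 5]) -- what forward does

# The size-1 head dimension broadcasts across all h heads of the score tensor.
masked = scores.masked_fill(mask.unsqueeze(1) == 0, -1e9)
print(masked.shape)              # torch.Size([2, 8, 5, 5])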

query, key, value = [
	l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
	for l, x in zip(self.linears, (input_, memory, memory))
]

This is really weird and ugly. self.linears has length 4, and (input_, memory, memory) obviously has length 3. zip will stop iterating when either supplied sequence ends, but it requires thought to make sense of it.
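
For example, with strings standing in for the layers and tensors:

layers = ["Wq", "Wk", "Wv", "Wo"]
inputs = ["input_", "memory", "memory"]

# zip silently drops the fourth layer once the shorter sequence is exhausted.
print(list(zip(layers, inputs)))
# [('Wq', 'input_'), ('Wk', 'memory'), ('Wv', 'memory')]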

Additionally, there’s nothing naturally sequential about what’s being done here, so wrapping it in a list comprehension seems obfuscatory. The only reason one would even think of such a sequential construction is that four linear layers with clear, non-sequential meanings have been packed anonymously into a single array.

So let’s rewrite the relevant bits like a competent engineer who actually wants future developers to fucking understand her code. At the expense of Python conventions, we can even assign variable names from Vaswani et al. (2017).

class MultiHeadedAttention(nn.Module):
    def __init__(...):
        ...
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.Wo = nn.Linear(d_model, d_model)
        ...

    def forward(...) -> Tensor:
        ...
        def transform(raw_sequence: Tensor, project: nn.Linear) -> Tensor:
            u: Tensor = project(raw_sequence)
            u = u.view(nbatches, -1, self.h, self.d_k)
            u = u.transpose(1, 2)
            return u

        query: Tensor = transform(input_, self.Wq)
        key: Tensor = transform(memory, self.Wk)
        value: Tensor = transform(memory, self.Wv)
        ...
        return self.Wo(x)

Now we can actually unpack what transform is doing to u:

  1. u is defined as raw_sequence projected into a learned vector space.
  2. u is reshaped into a $(\mathrm{nbatches}, n, h, d_k)$ tensor. (See tensor.view.)
  3. Swap the $n$ and $h$ dimensions so that u is $(\mathrm{nbatches}, h, n, d_k)$.

Why (2) and (3)? The linear layer learns the projections for all of the heads at once. We then need to break them up, which we do using u.view. But having done this, the dimensions are not in the right order: each head needs to compute scaled dot-product attention on its own $n \times d_k$ matrices. So we swap them. (Note that PyTorch uses row-major order.)
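
Here is a standalone trace of those shapes, using d_model = 512 and h = 8 (so d_k = 64) with an arbitrary batch size and sequence length:

import torch
import torch.nn as nn

nbatches, n, d_model, h = 2, 10, 512, 8
d_k = d_model // h

Wq = nn.Linear(d_model, d_model)
input_ = torch.randn(nbatches, n, d_model)

u = Wq(input_)                    # (1) project: (nbatches, n, d_model)
print(u.shape)                    # torch.Size([2, 10, 512])

u = u.view(nbatches, -1, h, d_k)  # (2) split d_model into h groups of size d_k
print(u.shape)                    # torch.Size([2, 10, 8, 64])

u = u.transpose(1, 2)             # (3) move the head dimension before the sequence
print(u.shape)                    # torch.Size([2, 8, 10, 64])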

Here we see the above in action:

x, self.attn = attention(
	query, key, value, mask=mask, dropout=self.dropout
)

Something very important happens here. If you look at the implementation of scaled dot-product attention, you see that it doesn’t know anything about batches or heads. It just computes

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V,$$

which we think of as acting on $n \times d_k$ matrices. By using rank-4 tensors, we perform the same work on all such matrices for all heads for all batches. This is a bit mind-blowing, even for someone who’s used multidimensional arrays for years. A rank-$r$ tensor can replace $r - 2$ for loops, where the actual work is performed on the last two dimensions. Here, it replaces an outer loop over batch and an inner loop over head.
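
To make that concrete, here is a minimal sketch (not the manuscript’s attention function) verifying that one batched computation on rank-4 tensors does the work of the two loops:

import math
import torch

nbatches, h, n, d_k = 2, 8, 10, 64
q = torch.randn(nbatches, h, n, d_k)
k = torch.randn(nbatches, h, n, d_k)
v = torch.randn(nbatches, h, n, d_k)

# One batched computation over rank-4 tensors...
scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # (nbatches, h, n, n)
batched = scores.softmax(dim=-1) @ v               # (nbatches, h, n, d_k)

# ...matches an explicit loop over batch and head acting on (n, d_k) matrices.
looped = torch.empty_like(batched)
for b in range(nbatches):
    for i in range(h):
        s = q[b, i] @ k[b, i].T / math.sqrt(d_k)
        looped[b, i] = s.softmax(dim=-1) @ v[b, i]

print(torch.allclose(batched, looped, atol=1e-5))  # True, up to floating-point noise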

Here’s the last meaningful part:

x = (
	x.transpose(1, 2)
	.contiguous()
	.view(nbatches, -1, self.h * self.d_k)
)
...
return self.Wo(x)

The context vectors returned by multi-head attention are supposed to include all of the features from all of the heads. In other words, they represent a concatenation of the $h$ length-$d_k$ context vectors returned by each head, such that the output is once again $(\mathrm{nbatches}, n, d_\mathrm{model})$. The context tensor we get from the attention function is $(\mathrm{nbatches}, h, n, d_k)$. So we can achieve our concatenation just by rejiggering this tensor.

So first we swap our indices back to $(\mathrm{nbatches}, n, h, d_k)$, which, enumerated according to row-major order, is structurally equivalent to the concatenation we want. Then we move it all into a single contiguous block of memory. Finally, we reshape it into a $(\mathrm{nbatches}, n, h \cdot d_k)$ tensor. (Recall that $h \cdot d_k = d_\mathrm{model}$.)
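
A standalone check of that claim (arbitrary sizes; the torch.cat here is only to verify the equivalence, not how the module does it):

import torch

nbatches, h, n, d_k = 2, 8, 10, 64
x = torch.randn(nbatches, h, n, d_k)   # per-head context vectors from attention

# The reshape used in forward...
rejiggered = x.transpose(1, 2).contiguous().view(nbatches, -1, h * d_k)

# ...is exactly a concatenation of the h heads along the feature dimension.
concatenated = torch.cat([x[:, i] for i in range(h)], dim=-1)

print(rejiggered.shape)                      # torch.Size([2, 10, 512])
print(torch.equal(rejiggered, concatenated)) # True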

At this point, we’re done. But there are these mysterious forced deletions:

del query
del key
del value

Obviously, these three local variables are about to go out of scope and get garbage collected. Presumably, there is some optimization need for deleting them here, but it goes unexplained in the manuscript.