The implementation of The Annotated Transformer is structured primarily to help map the text of Vaswani et al. (2017) onto a comprehensible worked implementation. That goal forced the authors to make some compromises, such as implementing scaled dot-product attention separately from multi-head attention. Even granting that concession, some decisions remain confusing, such as the one analyzed here.
Their interface for multi-head attention is as follows. (Type hints and comments are my own in this and all subsequent snippets.)
class MultiHeadedAttention(nn.Module):
    def __init__(
        self,
        h: int,  # The number of attention heads. In the paper, h=8.
        d_model: int,  # The dimension of all sequences in the model. In the paper, d_model=512.
        dropout: float = 0.1  # The dropout rate. Used only during training.
    ) -> None:
        ...

    def forward(
        self,
        query: Tensor,  # The raw query sequence. For self-attention, query=key=value.
        key: Tensor,  # The raw memory sequence. Always equal to value.
        value: Tensor,  # Redundant; always equal to "key."
        mask: Tensor | None = None  # If supplied, a mask over the attention scores (e.g., for masked self-attention in the decoder).
    ) -> Tensor:  # Returns the concatenated output of the heads, projected into a learned vector space.
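        # Roughly what the body does (my paraphrase, not the verbatim source):
        # each of query, key, and value is passed through its own learned linear
        # projection and reshaped to (batch, h, seq_len, d_k); attention(...) is
        # applied per head; the head outputs are concatenated and sent through a
        # final learned projection back to d_model.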
        ...

There is a stand-alone function for scaled dot-product attention, which ends up behaving a lot like an instance method of MultiHeadedAttention. Its signature is as follows. Note that query, key, and value are NOT the same as the ones passed into multi-head attention because they are first projected into learned vector spaces.
def attention(
    query: Tensor,  # AFTER projection
    key: Tensor,  # AFTER projection
    value: Tensor,  # AFTER projection; no longer identical to key!
    mask: Tensor | None,  # Unchanged from MHA
    dropout: nn.Dropout | None  # Instance variable from MHA
) -> tuple[Tensor, Tensor]:  # Returns (context vectors, attention weights).
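    # In outline (my paraphrase, not the verbatim source):
    #   scores = query @ key.transpose(-2, -1) / sqrt(d_k)
    #   scores = scores.masked_fill(mask == 0, -1e9)  # only if mask is given
    #   p_attn = dropout(softmax(scores, dim=-1))     # dropout only if given
    #   return p_attn @ value, p_attn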
    ...

The only time we directly instantiate MultiHeadedAttention is in a factory function called make_model, which in turn makes many copies of the original instance. We have:
def make_model(...) -> EncoderDecoder:
    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model)
    ...
    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
        ...
    )
    ...
    return model

We see that deep copies of attn are supplied as constructor arguments to EncoderLayer and DecoderLayer. Their constructor signatures are:
class EncoderLayer(nn.Module):
    def __init__(self, size, self_attn, feed_forward, dropout):
        ...

    def forward(self, x, mask):
        ...

class DecoderLayer(nn.Module):
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        ...

    def forward(self, x, memory, src_mask, tgt_mask):
        ...

We have self-attention in each encoder block, and both cross-attention and masked self-attention in each decoder block. These all need to be separate instances, because each needs to learn its own weight matrices independently of the others. Since the number of blocks N is configurable, make_model uses copy.deepcopy to stamp out a fresh instance of MultiHeadedAttention for each of these roles.
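As a quick check on why the deep copies matter, here is a small, self-contained sketch (mine, not from the notebook) showing that copy.deepcopy of a module yields independent parameters, whereas reusing the same instance would tie the weights together:

import copy

import torch.nn as nn

proto = nn.Linear(512, 512)    # stand-in for the prototype attn instance
copied = copy.deepcopy(proto)  # what make_model's c(attn) does
aliased = proto                # what passing attn itself would do

# The deep copy starts with the same values but owns a distinct parameter
# tensor, so it receives its own gradients and diverges during training.
print(copied.weight is proto.weight)   # False
print(aliased.weight is proto.weight)  # True

If make_model passed attn itself into every layer, all of the attention sublayers would share a single set of projection matrices, which is not what the paper describes.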
In the encoder layer, the same tensor x is used for query, key, and value:
class EncoderLayer(nn.Module):
    ...

    def forward(self, x, mask):
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](x, self.feed_forward)

In the decoder layer, we pass in a memory tensor separately. For cross-attention, this gets used for both key and value. Self-attention still passes the same tensor x for all three arguments, just as in the encoder.
class DecoderLayer(nn.Module):
    ...

    def forward(self, x, memory, src_mask, tgt_mask):
        "Follow Figure 1 (right) for connections."
        m = memory
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
        return self.sublayer[2](x, self.feed_forward)

It seems plain that the three arguments query, key, and value should have been two arguments, input and memory. This is especially true because, in attention(...), the three same-named arguments refer instead to projections of these original values. Multi-head attention is not a function of three inputs; it is a function of two. This seems like a pure mistake to me.
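To make the point concrete, here is a rough sketch of what the two-argument interface could look like. This is my own code, not a drop-in patch for the notebook: the class name, the parameter names, and the assumptions of batch-first (batch, seq_len, d_model) tensors and a mask that broadcasts against the per-head scores are all mine. The projections into query, key, and value spaces stay inside the module, and the caller only distinguishes the sequence being transformed from the sequence being attended over:

import torch.nn as nn
from torch import Tensor


class TwoArgMultiHeadedAttention(nn.Module):
    """Hypothetical variant of MultiHeadedAttention with an (input, memory) interface."""

    def __init__(self, h: int, d_model: int, dropout: float = 0.1) -> None:
        super().__init__()
        assert d_model % h == 0
        self.h, self.d_k = h, d_model // h
        self.q_proj = nn.Linear(d_model, d_model)  # applied to `input`
        self.k_proj = nn.Linear(d_model, d_model)  # applied to `memory`
        self.v_proj = nn.Linear(d_model, d_model)  # applied to `memory`
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        input: Tensor,                 # the sequence being transformed (the old "query")
        memory: Tensor | None = None,  # the sequence attended over; None means self-attention
        mask: Tensor | None = None,    # assumed to broadcast against (batch, h, seq, seq) scores
    ) -> Tensor:
        memory = input if memory is None else memory
        n = input.size(0)
        # Project, then split into heads: (batch, h, seq_len, d_k).
        q = self.q_proj(input).view(n, -1, self.h, self.d_k).transpose(1, 2)
        k = self.k_proj(memory).view(n, -1, self.h, self.d_k).transpose(1, 2)
        v = self.v_proj(memory).view(n, -1, self.h, self.d_k).transpose(1, 2)
        scores = q @ k.transpose(-2, -1) / self.d_k**0.5
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = self.dropout(scores.softmax(dim=-1))
        # Concatenate the heads and apply the final learned projection.
        out = (weights @ v).transpose(1, 2).contiguous().view(n, -1, self.h * self.d_k)
        return self.out_proj(out)

The call sites then read the way the architecture diagram does: self_attn(x, mask=tgt_mask) for self-attention in either stack, and src_attn(x, memory, mask=src_mask) for cross-attention in the decoder.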