Though the transformer’s decoder is similar to its encoder, the differences are very important. In The Annotated Transformer, the decoder is invoked as part of the EncoderDecoder module:
class EncoderDecoder(nn.Module):
    ...
    def forward(self, src, tgt, src_mask, tgt_mask):
        return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)
    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)
    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

Notice that the decode method adds two new arguments, memory and tgt_mask, compared with encode. The first, memory, contains the output from the encoder; it is used for cross-attention. The second, tgt_mask, contains a mask that prevents each position from attending to later positions, so that information about future tokens does not leak during self-attention at training time. (See “Masked Attention”.)
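In case it helps to see one concretely, tgt_mask can be produced by a small helper along the lines of the subsequent_mask function in The Annotated Transformer. The sketch below follows that convention, with True marking positions that may be attended to:

import torch

def subsequent_mask(size):
    # True where a query position may attend to a key position:
    # position i can see positions 0..i, but nothing after it.
    attn_shape = (1, size, size)
    return torch.triu(torch.ones(attn_shape), diagonal=1) == 0

print(subsequent_mask(3))
# tensor([[[ True, False, False],
#          [ True,  True, False],
#          [ True,  True,  True]]])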
The decoder itself is implemented as follows:
class Decoder(nn.Module):
    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)
    def forward(self, x, memory, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask)
        return self.norm(x)
At this level of abstraction, the only difference between the decoder and the encoder is the inclusion of the memory and tgt_mask arguments discussed above.
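As an aside, the clones helper used here (and again in DecoderLayer below) is a small utility; a minimal version, in the spirit of The Annotated Transformer, simply deep-copies a module N times into an nn.ModuleList:

import copy
import torch.nn as nn

def clones(module, N):
    # Produce N identical, independently parameterized copies of a module.
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])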
The differences become apparent inside the decoder block (“layer”).
class DecoderLayer(nn.Module):
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)
    def forward(self, x, memory, src_mask, tgt_mask):
        m = memory
        # 1. Masked self-attention over the decoder's own states.
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        # 2. Cross-attention: queries from the decoder, keys/values from memory.
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
        # 3. Position-wise feed-forward network.
        return self.sublayer[2](x, self.feed_forward)

There are three differences compared to the encoder:
- We have an extra instance of multi-head attention, src_attn, used for cross-attention (see the sketch below).
- We have three sublayer connections instead of two, because each block runs self-attention → cross-attention → position-wise feed-forward.
- The forward method takes the additional memory and tgt_mask arguments, as discussed above.
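To make the self-attention/cross-attention distinction concrete, here is a minimal sketch using the scaled dot-product attention formulation from The Annotated Transformer; the tensor shapes are made up purely for illustration:

import math
import torch

def attention(query, key, value, mask=None):
    # Scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V.
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    return torch.matmul(scores.softmax(dim=-1), value)

x = torch.randn(1, 4, 8)       # decoder states: batch=1, tgt_len=4, d_model=8
memory = torch.randn(1, 6, 8)  # encoder output: batch=1, src_len=6, d_model=8
tgt_mask = torch.triu(torch.ones(1, 4, 4), diagonal=1) == 0  # causal mask

self_out = attention(x, x, x, tgt_mask)    # queries, keys, values all from x
cross_out = attention(x, memory, memory)   # queries from x; keys/values from memory
print(self_out.shape, cross_out.shape)     # both torch.Size([1, 4, 8])

In the real decoder block, cross-attention also receives src_mask so that padded source positions are ignored; the sketch omits padding for brevity.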