【PyTorch】Differences between src_mask and src_key_padding_mask

English version

While building a Transformer model for my research, I ran into a few stumbling blocks.

As a memo, I will document the differences between src_mask and src_key_padding_mask and how to use each.

This topic may be somewhat niche, but let’s begin!

* It’s worth noting that the official documentation and source code are the most reliable references, so please consult them as needed.

nn.TransformerEncoderLayer

First, I’ll quote from the PyTorch official documentation.
https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html

forward(src, src_mask=None, src_key_padding_mask=None, is_causal=False)

Pass the input through the encoder layer.

Parameters

  • src (Tensor) – the sequence to the encoder layer (required).
  • src_mask (Optional[Tensor]) – the mask for the src sequence (optional).
  • src_key_padding_mask (Optional[Tensor]) – the mask for the src keys per batch (optional).
  • is_causal (bool) – If specified, applies a causal mask as src mask. Default: False. Warning: is_causal provides a hint that src_mask is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility.

Return type

Tensor

This is where src_mask and src_key_padding_mask appear, and they are easy to confuse because the names are so similar.

src_mask

Objective

The purpose of src_mask is to compute attention only for the focus word and the words that precede it.

Conversely, attention to the words that follow the focus word is ignored.

Implementation

PyTorch’s language modeling tutorial is very informative, and I recommend consulting it: PyTorch Transformer Tutorial

Python
import math
import os
from tempfile import TemporaryDirectory
from typing import Tuple

import torch
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset

class TransformerModel(nn.Module):

    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int, nlayers: int, dropout: float = 0.5):
        super().__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.embedding = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.linear = nn.Linear(d_model, ntoken)

        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
        # src: (seq_len, batch_size) tensor of token indices
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        if src_mask is None:
            # Causal mask: a (seq_len, seq_len) matrix with -inf above the diagonal.
            # `device` and the PositionalEncoding class are defined elsewhere in the tutorial.
            src_mask = nn.Transformer.generate_square_subsequent_mask(len(src)).to(device)
        output = self.transformer_encoder(src, src_mask)
        output = self.linear(output)
        return output

The forward method is where src_mask comes into play.

Additionally, src_mask is created using the generate_square_subsequent_mask function from nn.Transformer.

The generate_square_subsequent_mask function creates a square matrix whose side length equals the sequence length. The structure of the matrix is described in the next section.
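
For a quick sanity check, you can call the function directly and inspect the result. This is a minimal sketch, assuming a sequence length of 6:

Python
import torch
from torch import nn

# Square causal mask for a sequence of length 6
src_mask = nn.Transformer.generate_square_subsequent_mask(6)
print(src_mask)  # 0 on and below the diagonal, -inf above it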

Computation

The structure of src_mask is as follows:

Python
tensor([[0., -inf, -inf, -inf, -inf, -inf],
        [0.,   0., -inf, -inf, -inf, -inf],
        [0.,   0.,   0., -inf, -inf, -inf],
        [0.,   0.,   0.,   0., -inf, -inf],
        [0.,   0.,   0.,   0.,   0., -inf],
        [0.,   0.,   0.,   0.,   0.,   0.]])

The first row represents the mask when focusing on the first token, meaning only the first token is considered.

Similarly, considering the fourth row, it’s apparent that only tokens 1 to 4 are focused on.

This shows that attention is computed only over the focus word and the words that precede it.
Conversely, if no mask is applied, attention is computed over all words.
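
To see why the -inf entries remove attention, recall that the mask is added to the raw attention scores before the softmax, so masked positions end up with a weight of exactly zero. A tiny illustration with made-up scores:

Python
import torch

scores = torch.tensor([1.0, 2.0, 3.0])            # raw attention scores for one query
mask = torch.tensor([0.0, 0.0, float('-inf')])    # block the last (future) position
print(torch.softmax(scores + mask, dim=-1))       # tensor([0.2689, 0.7311, 0.0000])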

Applicable Tasks

Tasks such as language modeling and translation, like those in the tutorial, are natural fits.

It could also be used for other objective-specific computations, though no other applications come to mind at the moment.

Now, let’s move on to src_key_padding_mask !

src_key_padding_mask

Objective

The purpose is to ignore padding in the array during attention computation.

Implementation

Unfortunately, unlike generate_square_subsequent_mask, there is no convenient built-in function for creating src_key_padding_mask, so you need to build it yourself.
However, it is not complicated (a short sketch follows the example tensors below).

For example, consider an input sequence src with shape (4, 6, 512): a batch size of 4, a sequence length of 6, and 512-dimensional features for each token.

Here’s a specific example:

I am student . [MASK] [MASK]           : 4 tokens + 2 [MASK]
How old are you ? [MASK]               : 5 tokens + 1 [MASK]
He loves playing video games .         : 6 tokens + 0 [MASK]
Sorry ? [MASK] [MASK] [MASK] [MASK]    : 2 tokens + 4 [MASK]

  • 4 batches: the tensor holds 4 sentences.
  • A sequence length of 6: each sentence consists of 6 tokens.
  • 512 features: each token carries a 512-dimensional feature vector (this last dimension does not directly affect the behavior of src_key_padding_mask).

In this case, the src_key_padding_mask to create is a tensor where the positions corresponding to [MASK] are marked as 1 or True. In practice, the boolean form (the second tensor below) is the safe choice: recent PyTorch versions accept only boolean or float masks here.

Python
tensor([[0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 1, 1, 1, 1]])

tensor([[False, False, False, False, True, True],
        [False, False, False, False, False, True],
        [False, False, False, False, False, False],
        [False, False, True, True, True, True]])
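
Here is a sketch of how such a mask might be built in practice. It assumes the batch is a (batch_size, seq_len) tensor of token IDs and that a hypothetical ID pad_id marks the [MASK]/padding positions:

Python
import torch

pad_id = 0  # hypothetical padding/[MASK] token ID
# (4, 6) batch of token IDs; 0 marks padded positions
token_ids = torch.tensor([
    [ 5,  6,  7,  8,  0,  0],
    [ 9, 10, 11, 12, 13,  0],
    [14, 15, 16, 17, 18, 19],
    [20, 21,  0,  0,  0,  0],
])
src_key_padding_mask = token_ids == pad_id  # True where attention should be ignored
print(src_key_padding_mask)                 # matches the boolean tensor above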

If your data gives you the opposite convention (True for real tokens, False for padding), you can flip a boolean mask with src_key_padding_mask = ~src_key_padding_mask.

Then, you can define:

Python
output = self.transformer_encoder(src, src_key_padding_mask=src_key_padding_mask)

Don’t forget to pass it in the forward call, and pass it by keyword: the second positional argument of TransformerEncoder.forward is mask (the src_mask), not src_key_padding_mask.

Computation

As with src_mask, the masked positions are treated as -inf when the attention scores are computed, so they are ignored.
For the example above, the effective mask per batch element looks like this:

Python
tensor([[0.,   0.,   0.,   0., -inf, -inf],
        [0.,   0.,   0.,   0.,   0., -inf],
        [0.,   0.,   0.,   0.,   0.,   0.],
        [0.,   0., -inf, -inf, -inf, -inf]])

Applicable Tasks

This is generally needed whenever the sequences in a batch have different lengths.
Without it, attention would also be computed over the [MASK] (padding) positions, so be careful!

Summary

  • src_mask: restricts attention to the focus word and the words that precede it (a causal mask).
  • src_key_padding_mask: excludes padding positions from the attention computation.

Both masking techniques can be used simultaneously, so try them out according to your objectives!
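
As a rough end-to-end sketch of using both at once (hyperparameters are arbitrary, and batch_first=True is assumed so that src has shape (batch_size, seq_len, d_model)):

Python
import torch
from torch import nn

encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

src = torch.rand(4, 6, 512)  # (batch_size, seq_len, d_model)
src_mask = nn.Transformer.generate_square_subsequent_mask(6)  # (6, 6) causal mask
src_key_padding_mask = torch.tensor([                         # (4, 6) padding mask
    [False, False, False, False, True,  True],
    [False, False, False, False, False, True],
    [False, False, False, False, False, False],
    [False, False, True,  True,  True,  True],
])

output = encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
print(output.shape)  # torch.Size([4, 6, 512])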

I hope to introduce implementation code and more details in the future!
