In conducting my research, I encountered some challenges while building a Transformer model.
As a memo to myself, I will document the differences between src_mask and src_key_padding_mask and how to use each.
This topic may be somewhat niche, but let's begin!
* It’s worth noting that the official documentation and source code are the most reliable references, so please consult them as needed.
nn.TransformerEncoderLayer
First, I’ll quote from the PyTorch official documentation.
https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html
forward(src, src_mask=None, src_key_padding_mask=None, is_causal=False)
Pass the input through the encoder layer.

Parameters
- src (Tensor) – the sequence to the encoder layer (required).
- src_mask (Optional[Tensor]) – the mask for the src sequence (optional).
- src_key_padding_mask (Optional[Tensor]) – the mask for the src keys per batch (optional).
- is_causal (bool) – If specified, applies a causal mask as src mask. Default: False. Warning: is_causal provides a hint that src_mask is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility.

Return type
Tensor
Here the terms src_mask and src_key_padding_mask appear, and they are easy to confuse because they look so similar.
src_mask
Objective
The purpose of src_mask is to compute attention only over the focus token itself and the tokens preceding it.
Conversely, attention to the tokens that follow the focus token is ignored.
Implementation
There is a language-modeling tutorial in PyTorch that is very informative, and I recommend consulting it (the PositionalEncoding class and the device variable used in the code below are defined there): PyTorch Transformer Tutorial
import math
import os
from tempfile import TemporaryDirectory
from typing import Tuple
import torch
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset
class TransformerModel(nn.Module):

    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int, nlayers: int, dropout: float = 0.5):
        super().__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.embedding = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.linear = nn.Linear(d_model, ntoken)

        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        if src_mask is None:
            src_mask = nn.Transformer.generate_square_subsequent_mask(len(src)).to(device)
        output = self.transformer_encoder(src, src_mask)
        output = self.linear(output)
        return output
The forward method is where src_mask is applied.
Additionally, src_mask is created with the generate_square_subsequent_mask function of nn.Transformer.
generate_square_subsequent_mask creates a square matrix whose side length is the sequence length. The structure of that matrix is described in the next section.
Computation
The structure of src_mask
is as follows:
tensor([[0., -inf, -inf, -inf, -inf, -inf],
[0., 0., -inf, -inf, -inf, -inf],
[0., 0., 0., -inf, -inf, -inf],
[0., 0., 0., 0., -inf, -inf],
[0., 0., 0., 0., 0., -inf],
[0., 0., 0., 0., 0., 0.]])
The first row is the mask used when focusing on the first token: only the first token is attended to.
Similarly, looking at the fourth row, only tokens 1 through 4 are attended to.
In other words, attention is computed only over the focus token and the tokens preceding it.
Conversely, if the mask is not applied, attention is computed over all tokens.
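As a quick check, the mask shown above can be reproduced directly with the function mentioned earlier (a minimal sketch; the sequence length 6 matches the example):

    from torch import nn

    # Causal mask for a sequence of length 6: 0 keeps a position, -inf blocks it.
    src_mask = nn.Transformer.generate_square_subsequent_mask(6)
    print(src_mask)
    # Rows correspond to the focus token, columns to the tokens it may attend to.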
Applicable Tasks
Tasks such as language modeling and translation, like the one in the tutorial, are natural fits for this mask.
It can also be used for other objective-specific computations, though no further applications come to mind at the moment.
Now, let's move on to src_key_padding_mask!
src_key_padding_mask
Objective
The purpose is to ignore padding in the array during attention computation.
Implementation
Unfortunately, unlike generate_square_subsequent_mask, there is no convenient function for creating src_key_padding_mask, so you need to create it yourself.
However, it's not overly complicated.
For example, consider an input sequence src with shape (4, 6, 512): a batch size of 4, a sequence length of 6, and 512-dimensional features per token. (This assumes batch_first=True; with the default batch_first=False the shape would be (6, 4, 512). The padding mask itself is always (batch, seq_len).)
Here’s a specific example:
I am student . [MASK] [MASK] :4 tokens+2 [MASK]
How old are you ? [MASK] :5 tokens+1 [MASK]
He loves playing video games . :6 tokens+0 [MASK]
Sorry ? [MASK] [MASK] [MASK] [MASK] :2 tokens+4 [MASK]
- 4 batches, meaning 4 sentences in one tensor
- A sequence length of 6, meaning each sentence consists of 6 tokens
- 512 features, meaning each token carries a 512-dimensional feature vector (this last point does not directly affect the behavior of src_key_padding_mask)
In this case, the src_key_padding_mask to create is a tensor in which the positions corresponding to [MASK] are set to 1 or True:
tensor([[0, 0, 0, 0, 1, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 1, 1, 1, 1]])
tensor([[False, False, False, False, True, True],
[False, False, False, False, False, True],
[False, False, False, False, False, False],
[False, False, True, True, True, True]])
If your mask ends up the other way around (True/1 for real tokens and False/0 for padding), you can invert it with src_key_padding_mask = ~src_key_padding_mask.
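For instance, if the tokenized input is available as integer IDs, the mask can be built with a single comparison. A minimal sketch, where token_ids and pad_idx are hypothetical names for your ID tensor and padding index (the IDs below are made up to mirror the four example sentences):

    import torch

    pad_idx = 0   # assumed padding index
    # (batch, seq_len) = (4, 6); zeros stand for the [MASK]/padding positions
    token_ids = torch.tensor([[5, 9, 7, 2, 0, 0],    # 4 tokens + 2 [MASK]
                              [8, 3, 6, 4, 1, 0],    # 5 tokens + 1 [MASK]
                              [5, 2, 9, 9, 3, 2],    # 6 tokens + 0 [MASK]
                              [7, 4, 0, 0, 0, 0]])   # 2 tokens + 4 [MASK]

    src_key_padding_mask = (token_ids == pad_idx)    # True where the token is padding
    print(src_key_padding_mask)                      # matches the boolean tensor above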
Then, you can call:
output = self.transformer_encoder(src, src_key_padding_mask=src_key_padding_mask)
Make sure not to forget to add this argument to the forward method!
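For reference, the tutorial model's forward shown earlier could be extended along these lines so the padding mask can be passed through (a sketch of one possible modification, not the tutorial's own code; it assumes the class definition above):

    def forward(self, src: Tensor, src_mask: Tensor = None,
                src_key_padding_mask: Tensor = None) -> Tensor:
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        # Both masks are optional; None simply means "no masking".
        output = self.transformer_encoder(
            src, src_mask, src_key_padding_mask=src_key_padding_mask
        )
        output = self.linear(output)
        return output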
Computation
As before, the masked positions are treated as -inf inside the attention computation, so they are ignored.
For example, the mask above corresponds to:
tensor([[0., 0., 0., 0., -inf, -inf],
[0., 0., 0., 0., 0., -inf],
[0., 0., 0., 0., 0., 0.],
[0., 0., -inf, -inf, -inf, -inf]])
Applicable Tasks
This mask is generally applicable whenever input sequence lengths vary within a batch.
Without it, attention would also be computed over the [MASK] (padding) positions, so be careful!
Summary
- src_mask: computes attention only over the focus token and the tokens preceding it.
- src_key_padding_mask: excludes padding positions from the attention computation.
Both masking techniques can be used simultaneously, so try them out according to your objectives!
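As a closing illustration, here is a minimal sketch that passes both masks to a TransformerEncoder at once. The layer sizes and batch_first=True are assumptions made for this example, and note that TransformerEncoder's argument is named mask rather than src_mask:

    import torch
    from torch import nn

    encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
    encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

    src = torch.randn(4, 6, 512)   # (batch, seq_len, d_model)
    causal_mask = nn.Transformer.generate_square_subsequent_mask(6)   # (6, 6), 0 / -inf
    padding_mask = torch.tensor([[False, False, False, False, True,  True],
                                 [False, False, False, False, False, True],
                                 [False, False, False, False, False, False],
                                 [False, False, True,  True,  True,  True]])

    output = encoder(src, mask=causal_mask, src_key_padding_mask=padding_mask)
    print(output.shape)   # torch.Size([4, 6, 512])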
I hope to introduce implementation code and more details in the future!