【PyTorch】nn.Transformerのsrc_maskとsrc_key_padding_maskの違い

Programming

自身の研究において、Transformerモデルを構築しているときに困ったことがあったのでそれの備忘録としてsrc_masksrc_key_padding_maskの違いと利用方法について記載します。

かなりニッチな内容になるのですが、早速いきましょう!

なお、最も参考にすべきは公式ドキュメントとソースコードですので、適宜そちらで確認してください。

nn.TransformerEncoderLayerの内容

まずはPyTorch公式ドキュメントのコードを引用します。https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html

forward(srcsrc_mask=Nonesrc_key_padding_mask=Noneis_causal=False)[SOURCE]

Pass the input through the encoder layer.Parameters

  • src (Tensor) – the sequence to the encoder layer (required).
  • src_mask (Optional[Tensor]) – the mask for the src sequence (optional).
  • src_key_padding_mask (Optional[Tensor]) – the mask for the src keys per batch (optional).
  • is_causal (bool) – If specified, applies a causal mask as src mask. Default: False. Warning: is_causal provides a hint that src_mask is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility.

Return type

Tensor

ここで、src_masksrc_key_padding_maskという言葉が出てきて、似ているのでかなりややこしいです。

src_mask

目的

src_maskを導入する目的は現在注目している単語以前の単語にだけattention計算を行うことです。

裏を返せば、注目している単語以降のattentionは無視されます。

実装

PyTorchには言語モデルのチュートリアルがあり、そちらが非常に参考になるので引用しておきます。
https://pytorch.org/tutorials/beginner/transformer_tutorial.html

Python
import math
import os
from tempfile import TemporaryDirectory
from typing import Tuple

import torch
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset

class TransformerModel(nn.Module):

    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int, nlayers: int, dropout: float = 0.5):
        super().__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.embedding = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.linear = nn.Linear(d_model, ntoken)

        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        if src_mask is None:
            src_mask = nn.Transformer.generate_square_subsequent_mask(len(src)).to(device) 
        output = self.transformer_encoder(src, src_mask)
        output = self.linear(output)
        return output

forwardの部分にsrc_maskが導入されています。
またsrc_masknn.Transformergenerate_square_subsequent_mask関数を利用して作成されます。

generate_square_subsequent_mask関数は配列長を1辺の長さとした正方行列を作成します。
作成する行列は次の項で記載します。

計算

src_maskの正体は以下のような構造をしています。

Python
tensor([[0., -inf, -inf, -inf, -inf, -inf],
        [0.,   0., -inf, -inf, -inf, -inf],
        [0.,   0.,   0., -inf, -inf, -inf],
        [0.,   0.,   0.,   0., -inf, -inf],
        [0.,   0.,   0.,   0.,   0., -inf],
        [0.,   0.,   0.,   0.,   0.,   0.]])

1行目は1トークン目に注目しているときのmaskを表しています。
そのため、1トークン目しか注目していません。

同じく4行目を考えると、1から4のトークンにしか注目されないようになっています。

このことから、現在注目している単語以前の単語にだけattention計算を行うようになっています。
繰り返しますが、逆に渡さない場合は全単語に対するattentionを計算します。

適応タスク

チュートリアルのように言語モデルや翻訳などにおいて利用されそうです。
また、パットは思いつかないですが、目的に沿った計算を行いたい場合に利用できます。

では続いてsrc_key_padding_maskに行きます!

src_key_padding_mask

目的

配列に対してpaddingを行った際に、そのpaddingに対してattention計算を行わないことです。

実装

src_key_padding_maskは残念ながら、先ほどのgenerate_square_subsequent_mask関数のように便利関数は用意されていません。
そのため自分で作成する必要がありますが、そこまで難しくありません。

例えば入力シーケンスであるsrcが(4, 6, 512)の次元を持っていたとします。
つまりバッチサイズが4、配列の長さが6、それの特徴が512次元であるということです。
具体的には以下の例を考えます。

I am student . [MASK] [MASK]       :トークン4こ+[MASK]2この構成
How old are you ? [MASK]         :トークン5こ+[MASK]1この構成
He loves playing video games .       :トークン6こ+[MASK]0この構成
Sorry ? [MASK] [MASK] [MASK] [MASK]  :トークン2こ+[MASK]4この構成

バッチ数が4 すなわち4つの文章で1つのテンソル

配列の長さが6 すなわち文章が6トークンで構成されている

特徴が512 すなわちそれぞれのトークンに512次元の特徴がついている(src_key_padding_maskなどに挙動に直接関係ありませんが記載しています)という形です。

この場合に作るべきsrc_key_padding_maskは以下のどちらかです。

Python
tensor([[0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 1, 1, 1, 1]])

tensor([[False, False, False, False, True, True],
        [False, False, False, False, False, True],
        [False, False, False, False, False, False],
        [False, False, True, True, True, True]])

文章中の[MASK]に該当する部分が1またはTrueになるテンソルを用意します。

もし逆にしてしまった(あるいはそうならざるを得ない)場合は
src_key_padding_mask = ~src_key_padding_maskとして逆にすれば問題ないです。

そして、定義できれば

Python
output = self.transformer_encoder(src, src_key_padding_mask = src_key_padding_mask)

これで行けます!
ただし、forwardの部分の引数を忘れないようにしてください!

計算

先ほどと同様に-infにすることでattentionを無視することになります。
つまり、先ほどの例ですと

Python
tensor([[0.,   0.,   0.,   0., -inf, -inf],
        [0.,   0.,   0.,   0.,   0., -inf],
        [0.,   0.,   0.,   0.,   0.,   0.],
        [0.,   0., -inf, -inf, -inf, -inf]])

このようになります。

適応タスク

入力データの配列長が異なる場合は汎用しそうです。
逆に利用しないと、[MASK]の部分も計算されてしまいますので注意です!

まとめ

  • src_mask
    現在注目している単語以前の単語にだけattention計算を行う
  • src_key_padding_mask
    ⇒ paddingに対してattention計算を行わない

また、この二つのmaskの手法は両方同時に利用することもできますので、目的に合わせて使ってみてください!

別途実装のコードなどを紹介できればと思います!

コメント

タイトルとURLをコピーしました