自身の研究において、Transformerモデルを構築しているときに困ったことがあったのでそれの備忘録としてsrc_maskとsrc_key_padding_maskの違いと利用方法について記載します。
かなりニッチな内容になるのですが、早速いきましょう!
なお、最も参考にすべきは公式ドキュメントとソースコードですので、適宜そちらで確認してください。
nn.TransformerEncoderLayerの内容
まずはPyTorch公式ドキュメントのコードを引用します。https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html
forward(src, src_mask=None, src_key_padding_mask=None, is_causal=False)[SOURCE]
Pass the input through the encoder layer.Parameters
- src (Tensor) – the sequence to the encoder layer (required).
- src_mask (Optional[Tensor]) – the mask for the src sequence (optional).
- src_key_padding_mask (Optional[Tensor]) – the mask for the src keys per batch (optional).
- is_causal (bool) – If specified, applies a causal mask as
src mask
. Default:False
. Warning:is_causal
provides a hint thatsrc_mask
is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility.Return type
ここで、src_maskとsrc_key_padding_maskという言葉が出てきて、似ているのでかなりややこしいです。
src_mask
目的
src_maskを導入する目的は現在注目している単語以前の単語にだけattention計算を行うことです。
裏を返せば、注目している単語以降のattentionは無視されます。
実装
PyTorchには言語モデルのチュートリアルがあり、そちらが非常に参考になるので引用しておきます。
https://pytorch.org/tutorials/beginner/transformer_tutorial.html
import math
import os
from tempfile import TemporaryDirectory
from typing import Tuple
import torch
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset
class TransformerModel(nn.Module):
def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int, nlayers: int, dropout: float = 0.5):
super().__init__()
self.model_type = 'Transformer'
self.pos_encoder = PositionalEncoding(d_model, dropout)
encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
self.embedding = nn.Embedding(ntoken, d_model)
self.d_model = d_model
self.linear = nn.Linear(d_model, ntoken)
self.init_weights()
def init_weights(self) -> None:
initrange = 0.1
self.embedding.weight.data.uniform_(-initrange, initrange)
self.linear.bias.data.zero_()
self.linear.weight.data.uniform_(-initrange, initrange)
def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
src = self.embedding(src) * math.sqrt(self.d_model)
src = self.pos_encoder(src)
if src_mask is None:
src_mask = nn.Transformer.generate_square_subsequent_mask(len(src)).to(device)
output = self.transformer_encoder(src, src_mask)
output = self.linear(output)
return output
forwardの部分にsrc_maskが導入されています。
またsrc_maskはnn.Transformerのgenerate_square_subsequent_mask関数を利用して作成されます。
generate_square_subsequent_mask関数は配列長を1辺の長さとした正方行列を作成します。
作成する行列は次の項で記載します。
計算
src_maskの正体は以下のような構造をしています。
tensor([[0., -inf, -inf, -inf, -inf, -inf],
[0., 0., -inf, -inf, -inf, -inf],
[0., 0., 0., -inf, -inf, -inf],
[0., 0., 0., 0., -inf, -inf],
[0., 0., 0., 0., 0., -inf],
[0., 0., 0., 0., 0., 0.]])
1行目は1トークン目に注目しているときのmaskを表しています。
そのため、1トークン目しか注目していません。
同じく4行目を考えると、1から4のトークンにしか注目されないようになっています。
このことから、現在注目している単語以前の単語にだけattention計算を行うようになっています。
繰り返しますが、逆に渡さない場合は全単語に対するattentionを計算します。
適応タスク
チュートリアルのように言語モデルや翻訳などにおいて利用されそうです。
また、パットは思いつかないですが、目的に沿った計算を行いたい場合に利用できます。
では続いてsrc_key_padding_maskに行きます!
src_key_padding_mask
目的
配列に対してpaddingを行った際に、そのpaddingに対してattention計算を行わないことです。
実装
src_key_padding_maskは残念ながら、先ほどのgenerate_square_subsequent_mask関数のように便利関数は用意されていません。
そのため自分で作成する必要がありますが、そこまで難しくありません。
例えば入力シーケンスであるsrcが(4, 6, 512)の次元を持っていたとします。
つまりバッチサイズが4、配列の長さが6、それの特徴が512次元であるということです。
具体的には以下の例を考えます。
I am student . [MASK] [MASK] :トークン4こ+[MASK]2この構成
How old are you ? [MASK] :トークン5こ+[MASK]1この構成
He loves playing video games . :トークン6こ+[MASK]0この構成
Sorry ? [MASK] [MASK] [MASK] [MASK] :トークン2こ+[MASK]4この構成
バッチ数が4 すなわち4つの文章で1つのテンソル
配列の長さが6 すなわち文章が6トークンで構成されている
特徴が512 すなわちそれぞれのトークンに512次元の特徴がついている(src_key_padding_maskなどに挙動に直接関係ありませんが記載しています)という形です。
この場合に作るべきsrc_key_padding_maskは以下のどちらかです。
tensor([[0, 0, 0, 0, 1, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 1, 1, 1, 1]])
tensor([[False, False, False, False, True, True],
[False, False, False, False, False, True],
[False, False, False, False, False, False],
[False, False, True, True, True, True]])
文章中の[MASK]に該当する部分が1またはTrueになるテンソルを用意します。
もし逆にしてしまった(あるいはそうならざるを得ない)場合は
src_key_padding_mask = ~src_key_padding_maskとして逆にすれば問題ないです。
そして、定義できれば
output = self.transformer_encoder(src, src_key_padding_mask = src_key_padding_mask)
これで行けます!
ただし、forwardの部分の引数を忘れないようにしてください!
計算
先ほどと同様に-infにすることでattentionを無視することになります。
つまり、先ほどの例ですと
tensor([[0., 0., 0., 0., -inf, -inf],
[0., 0., 0., 0., 0., -inf],
[0., 0., 0., 0., 0., 0.],
[0., 0., -inf, -inf, -inf, -inf]])
このようになります。
適応タスク
入力データの配列長が異なる場合は汎用しそうです。
逆に利用しないと、[MASK]の部分も計算されてしまいますので注意です!
まとめ
- src_mask
⇒ 現在注目している単語以前の単語にだけattention計算を行う - src_key_padding_mask
⇒ paddingに対してattention計算を行わない
また、この二つのmaskの手法は両方同時に利用することもできますので、目的に合わせて使ってみてください!
別途実装のコードなどを紹介できればと思います!
コメント