はじめに
機械学習や物理などの勉強していて、ことあるごとに出てくるのがKLダイバージェンス。
重要そうなのはわかるのですが、式の意味も複雑で直感的な意味もわかりにくくて非常に厄介です。
情報理論から考えると、ほんのちょっとだけわかった気になれたのでその喜びをシェアするために
この記事では、KLダイバージェンスについて順番に解説していきます。
間違っている場合は指摘していただけると幸いです。
KLダイバージェンスとは何か
KLダイバージェンスは、簡単に言うと「2つの確率分布がどれくらい違うか」を表すものです。
今はわからなくても問題ないです。最後に戻ってきます。
例えば、ある商品の売れ行きを予測したいとします。過去のデータから作った予測と、実際に売れたデータがあるとします。この2つのデータの分布がどれくらい違うかをKLダイバージェンスで測ることができます。
なぜ機械学習や情報理論で重要なのか
KLダイバージェンスは、機械学習の様々な場面で使われています。
- モデルの学習:モデルが予測する確率分布を、実際のデータの分布に近づけるため
- 特徴量の選択:どの特徴量が重要かを判断するため
- 異常検知:通常とは異なるデータを見つけるため
情報理論では、KLダイバージェンスは2つの情報源がどれくらい違うかを表す指標として使われます。
繰り返しますが、今はわからなくても問題ないです。最後に戻ってきます。
本記事の流れ
この記事では、KLダイバージェンスを理解するために、以下のステップで解説していきます。
- 情報理論の基礎:シャノン情報量
- シャノンエントロピー:平均情報量
- クロスエントロピー:異なる分布間の測定
- KLダイバージェンス:分布間の非対称な距離
- 機械学習におけるKLダイバージェンスの応用
- KLダイバージェンスの実装と計算
情報理論の基礎:シャノン情報量
まず、情報量とは何か、どうやって定量化するのかから説明します。
冒頭で記載したことはまだまだ出てこないので頭を空っぽにして参照いただけると幸いです。
情報量とは何か
「情報」と聞くと、ニュースや天気予報などを思い浮かべるかもしれません。
情報理論における情報とは、「ある出来事が起こる確率が低いほど、その出来事が起きたときに得られる情報量が多い」という考え方です。
(これからの説明は、ざっくりとした説明+例というセットを繰り返しながら説明していきます)
確率と情報量の関係:驚きの度合い
例えば、天気予報とかをみている時に
例1:「大阪は今日は晴れです」という情報
例2:「沖縄は今日は雪です」という情報
どちらの情報が豊富だと感じるでしょうか?
おそらく、例2の「沖縄は今日は雪です」という情報の方が、ずっと驚きが大きいですよね。
なぜなら、雪が降る確率が非常に低いからです。
情報量は、この「驚きの度合い」を数値化したものと考えるとわかりやすいでしょう。
シャノン情報量の定義:\(I(x) = -\log_2 P(x)\)
この驚きの度合いを数値化したものがシャノン情報量で、以下の式で定義されます。
\[I(x) = -\log_2 P(x)\]
\(I(x)\): 事象 \(x\) が起きた時の情報量
\(P(x)\): 事象 \(x\) が起こる確率
この式を見ると、確率 \(P(x)\) が小さくなるほど、情報量 \(I(x)\)が大きくなることがわかります。
対数の底が2なのは、情報量の単位を「ビット (bit)」で表すためです(後述します。余談参照)。
例:コイン投げ、サイコロの目
もう一度、例を考えてみましょう。
コイン投げ:表が出る確率は1/2
サイコロの目:1が出る確率は1/6
どちらの事象が起きた時の情報量が多いでしょうか?
サイコロの目が1であることの方が、起こる確率が低いので、情報量が多いと言えます。
実際に先ほどのシャノン情報量の定義を利用して実際に計算してみましょう。
コイン投げ
表が出る確率: \(P(表) = \frac{1}{2}\)
情報量: \(I(表) = -\log_2 \frac{1}{2} = 1\) bit
サイコロの目
1が出る確率: \(P(1) = \frac{1}{6}\)
情報量: \(I(1) = -\log_2 \frac{1}{6} \approx 2.58\) bit
このように、サイコロの目が1であることの方が確率が低いため、コインの表が出るよりも情報量が多いことがわかります。
bitに関する余談①
シャノン情報量の定義で対数の底を2にすると、情報量の単位が「bit」になる、というお話でした。
情報を記録したり伝えたりする「符号化」という観点からみてみます。
bitというのは、情報の最小単位で、0か1のどちらか一方を表現できます。
私たちが扱う様々な情報を、コンピュータが理解できる0と1の並び(bit列)に変換することを「符号化」と言います。
情報量とビットの関係を考える上で、最もシンプルなのは二つの選択肢を区別する場合です。例えば、「結果は成功か失敗か」という二択があって、それぞれの確率が1/2で等しいとします。
- 出来事1:「結果は成功である」 確率 P(成功) = 1/2
- 出来事2:「結果は失敗である」 確率 P(失敗) = 1/2
「結果は成功である」という出来事の情報量を計算すると、 \(−\log2(1/2)=1\)bit、となります。
この「1ビット」という情報量、符号化の観点から見るとどうなるでしょう? 成功と失敗という二つの異なる状態を区別するためには、たった1ビットの符号で十分です。
例えば、「成功を0」、「失敗を1」と割り当てれば、これだけで二つの状態を完全に区別して記録したり伝えたりできます。
つまり、情報量1ビットの出来事は、ロスなく表現するために1ビットで符号化できる、と考えることができます。
では、もう少し選択肢を増やして、4つのクラスA、B、C、Dがあって、それぞれが等しい確率(1/4)で出現する場合を考えてみます。
- クラスAが出現する確率: P(A) = 1/4
- クラスBが出現する確率: P(B) = 1/4
- クラスCが出現する確率: P(C) = 1/4
- クラスDが出現する確率: P(D) = 1/4
クラスAが出現した時の情報量は、 \(−\log2(1/4)=2 \)bit、となります。クラスB、C、Dもそれぞれ 2 ビットです。
この「2ビット」という情報量、符号化の観点からはどうでしょう? 4つの異なるクラスを完全に区別して表現するためには、少なくとも2ビットの長さの符号語が必要です。
- クラスAを “00”
- クラスBを “01”
- クラスCを “10”
- クラスDを “11”
このように、2ビットの組み合わせを使えば、4つの異なる状態をロスなく表現できます。
つまり、情報量2ビットの出来事は、ロスなく表現するために2ビットで符号化できる、と考えることができます。
これらの余談から見えてくるのは、シャノン情報量の定義\( I(x)=−log2P(x) \)が、「確率 \(P(x) \)の出来事を、他の出来事とロスなく区別して表現するために、少なくとも必要となるビット数」に対応している、ということです。
シャノンエントロピー:平均情報量
ここでは、ある確率分布全体で、平均的にどれくらいの情報量が得られるのかを考えます。
エントロピーの定義:\(H(X) = -\sum_x P(x) \log_2 P(x)\)
シャノンエントロピーは、以下の式で定義されます。
\[H(X) = -\sum_x P(x) \log_2 P(x)\]
\(H(X)\): 確率変数 \(X\) のエントロピー
\(P(x)\): 事象 \(x\) が起こる確率
この式は、各事象の情報量 \(I(x) = -\log_2 P(x)\) に、その事象が起こる確率 \(P(x)\) をかけて、すべての事象について足し合わせたものです。つまり、エントロピーは情報量の期待値(平均値)を表しています。
エントロピーの性質(最大値、最小値)
エントロピーは、確率分布の「不確実さ」や「ランダムさ」を表す指標として解釈できます。
エントロピーが高い:分布が均一に近く、どの事象が起こるか予測しにくい(不確実性が高い)
エントロピーが低い:分布に偏りがあり、特定の事象が起こりやすい(不確実性が低い)
そのようなエントロピーには、以下のような性質があります。
この性質について以下の例を使って理解します。
例:均一な分布 vs 偏った分布のエントロピー
例1:均一な分布(サイコロ)
サイコロの各目の出る確率はすべて \(\frac{1}{6}\) です。
エントロピー: \[H(X) = -\sum_{i=1}^{6} \frac{1}{6} \log_2 \frac{1}{6} = \log_2 6 \approx 2.58 \]
例2:偏った分布
あるイカサマサイコロがあり、1の目が出る確率が \(\frac{5}{6}\)、それ以外の目が \(\frac{1}{30}\) だとします。
エントロピー: \[H(X) = -\frac{5}{6} \log_2 \frac{5}{6} – 5 \cdot \frac{1}{30} \log_2 \frac{1}{30} \approx 0.76 \]
この例から、均一な分布(サイコロ)の方が、偏った分布よりもエントロピーが高いことがわかります。
例3:偏り切った分布
全ての面が1のサイコロがあります。
エントロピー: \[H(X) = -1 \cdot \log_2 1 = 0 \]
このサイコロでは出る結果が完全に確定しており、「驚き」が全くありません。結果を見ても新しい情報が得られないため、エントロピーは0になります。
エントロピーについて理解は深まったでしょうか。
ある確率分布に従う各事象から得られる情報量の期待値がエントロピーであるということでした。
bitに関する余談②
4つのものを区別する際に、2bit必要だと思いますが、その出現確率に偏りがある場合だとbit数を減らせるというのに違和感があるかもしれません。
例を使いながら記載しておきます。
4つのものを区別するための最小限のビット数
4つの異なるもの(例えば、クラス1, クラス2, クラス3, クラス4)を確実に区別するためには、少なくとも 2 ビットが必要でした。
- 00 → クラス1
- 01 → クラス2
- 10 → クラス3
- 11 → クラス4
このように、すべてのクラスに同じ長さ(2ビット)の符号語を割り当てる方法を固定長符号化と呼びます。4つのクラスが等しい確率で出現する場合、この固定長符号化が最も効率的であり、シンボル1つあたりに平均して必要なビット数は 2 ビットになります。
これは、4つのクラスが等確率(各1/4)で出現する場合のエントロピーの2ビットと一致します。情報源が最も不確かな状態なので、平均して多くのビットが必要です。
確率に偏りがある場合:平均ビット数が小さくなる理由
しかし、4つのクラスの出現確率に偏りがある場合です。例えば、先ほどのアルファベットの例のように、特定のクラスが他のクラスよりも圧倒的に高い確率で出現するとします。
- クラス1: 確率 0.7 (よく出る)
- クラス2: 確率 0.1 (めったに出ない)
- クラス3: 確率 0.1 (めったに出ない)
- クラス4: 確率 0.1 (めったに出ない)
この場合のエントロピー(最適な平均ビット数)は、計算すると約1.35 ビットとなります。これは2ビットよりも小さい値です。
情報理論における効率的な符号化(変動長符号化)の考え方は、「よく出るシンボルには短い符号語を割り当て、めったに出ないシンボルには長い符号語を割り当てる」というものです。
例えば、上記の確率分布の場合、以下のような符号化を考えます(これは最適な符号ではありませんが、アイデアとして考えてください)。
- クラス1 (確率 0.7): “0” (1ビット)
- クラス2 (確率 0.1): “10” (2ビット)
- クラス3 (確率 0.1): “110” (3ビット)
- クラス4 (確率 0.1): “111” (3ビット)
この符号化では、それぞれのクラスに異なる長さの符号語を割り当てています。そして、どのアルファベットも一意に復号できるように符号語を決めています(例えば、符号列 “010” があったら、まず “0” で区切ってクラス1、残りの “10” でクラス2、のように復号できます)。
この符号化での平均符号長を計算してみましょう。それぞれのクラスが出現する確率で、そのクラスの符号語の長さを重み付けして合計します。(途中計算は省きます、定義と照らし合わせながら確認してみてください。)
平均符号長 =0.7+0.2+0.3+0.3=1.5ビット
この平均符号長 1.5 ビットは、4つのクラスを固定長で表現した場合の 2 ビットよりも小さくなっています。つまり、確率の高いクラスに短い符号を割り当てる戦略によって、アルファベット1つあたりに平均して必要となるビット数を削減できます。
もし、頻度の高い文字に長い符号を割り当ててしまうと、長い文章を送る際に、トータルで送らなきゃいけない符号の量が多くなってしまいロスが大きくなってしまいます。
この損失が大きい状況というのは次のクロスエントロピーを知ると納得感があるかと思いますので、次に進みましょう。
クロスエントロピー:異なる分布間の測定
2つの異なる確率分布を使って情報を符号化する場合について考えます。
クロスエントロピーの定義: \(H(P,Q) = -\sum_x P(x) \log_2 Q(x)\)
クロスエントロピーは、以下の式で定義されます。
\[H(P,Q) = -\sum_x P(x) \log_2 Q(x)\]
\(H(P,Q)\):確率分布 \(P\) を、確率分布 \(Q\) を使って符号化した場合のクロスエントロピー
\(P(x)\): 事象 \(x\) が、確率分布 \(P\) で起こる確率
\(Q(x)\): 事象 \(x\) が、確率分布 \(Q\) で起こる確率
この式は、確率分布 \(P\) に基づいて事象 \(x\) が起きた時に、確率分布 \(Q\) を使って符号化した場合の符号長 \(-\log_2 Q(x)\) を計算し、それをすべての事象について足し合わせたものです。
わかりにくいと思うので、以下で例を出します。
例:真の分布 vs 間違った予測分布
ここでは、アルファベットが「A」「B」「C」「D」の4種類しかない、非常にシンプルな情報源を考えましょう。そして、「真の情報源の確率分布\(P\)」と、私たちが「誤って仮定してしまった確率分布\(Q\)」を以下のように設定します。
真の確率分布\(P\) (元の情報源の実際の確率)
\(P(A) = 0.7\) (「A」がすごくよく出る)
\(P(B) = 0.1\) (「B」はたまに出る)
\(P(C) = 0.1\) (「C」はたまに出る)
\(P(D) = 0.1\) (「D」はたまに出る)
この情報源では、「A」が出やすい、というハッキリした偏りがありますね。
誤って仮定してしまった確率分布\(Q\)(私たちがなぜか信じている確率)
\(Q(A) = 0.1\) (「A」はたまにしか出ないと思っている)
\(Q(B) = 0.1\) (「B」はたまにしか出ないと思っている)
\(Q(C) = 0.1\) (「C」はたまにしか出ないと思っている)
\(Q(D) = 0.7\) (「D」がすごくよく出ると信じ込んでいる)
こちらは、「D」が一番よく出る、と信じ込んでいるパターンです。真の分布とは全然違いますが、この場合を考えます。
クロスエントロピーの定義式を今一度確認します。
\[H(P,Q) = -\sum_x P(x) \log_2 Q(x)\]
さて、この状況で、クロスエントロピー \(H(P,Q)\) を計算してみましょう。定義式に、設定したPとQの確率値を代入します。\(x\) は「A」「B」「C」「D」です。
\(H(P, Q) = – [ P(A) \log_2 Q(A) + P(B) \log_2 Q(B) + P(C) \log_2 Q(C) + P(D) \log_2 Q(D) ]\)
確率値を代入します。
\(H(P, Q) = – [ 0.7 \log_2 (0.1) + 0.1 \log_2 (0.1) + 0.1 \log_2 (0.1) + 0.1 \log_2 (0.7) ]\)
ここで、\( \log_2 (0.1) \) の値を計算すると、約 -3.32 となります。また \( \log_2 (0.7) \) は約 -0.51 です。
\(H(P, Q) \approx – [ 0.7 \times (-3.32) + 0.1 \times (-3.32) + 0.1 \times (-3.32) + 0.1 \times (-0.51) ]\)
\(H(P, Q) \approx – [ -2.324 + (-0.332) + (-0.332) + (-0.051) ]\)
\(H(P, Q) \approx – [ -3.039 ]\)
\(H(P, Q) \approx 3.04\)
計算の結果、クロスエントロピーは約 3.04 ビットとなりました。
この「3.04ビット」という数値、一体何を意味しているんでしょうか?
これは、真の確率分布 P (「A」が0.7、「B, C, D」が各0.1)に従って出てくるアルファベットを、私たちが誤って信じている確率分布 Q (「D」が0.7、「A, B, C」が各0.1)に基づいて符号化(それぞれのシンボルを \(-\log_2 Q(x)\) ビットで表現)した場合に、アルファベット1つあたりに平均して必要となるビット数なんです。
つまり、分布を見誤ることで、アルファベット1つあたり平均 3.04 ビットも必要になってしまった、という状況を表しています。
真の確率分布を正しく予測できたのであれば、アルファベット1つ当たり平均1.35ビットで済みます。(計算は”bitに関する余談②”を参照してください)
クロスエントロピーは、このように「真実(P)を、別の(間違った)物差し(Q)を使って測った情報の平均的な重さ(コスト)」のようなものだと考えると分かりやすいかもしれません。
ちなみに先出ししてしまうのですが、この3.05と1.35の差分(見誤ったことによってロスしてしまう平均bit数)である約1.7bitがまさにKLダイバージェンスです。
ここで終わってもいいのですがもう少し理解を深めましょう。
P分布をQ分布で符号化する場合の平均符号長
先ほどの例ですでに完全に理解されたかと思いますが、クロスエントロピーは、確率分布 \(P\) に従う情報を、確率分布\(Q\) を使って符号化した場合の平均符号長を表します。
\(P\) と \(Q\) が同じ分布であれば、クロスエントロピーはエントロピーと等しくなります。
言い換えるのであれば、完全に理想的な符号化を行えるので、ロスは0ということです。
\(P\) と\(Q\) が異なる分布であれば、クロスエントロピーはエントロピーよりも大きくなります。
言い換えるのであれば、理想的な符号化はできていないので、ロスは発生してしまいます。
機械学習での使用例:分類問題での損失関数
クロスエントロピーは、機械学習の分類問題で、モデルの予測と正解ラベルの間の誤差を測るために使われます。
正解ラベル(P):実際のデータの分布
モデルの予測(Q):モデルが予測した確率分布
モデルの学習では、クロスエントロピーを最小化するようにパラメータを調整します。これは、モデルの予測が実際のデータに近づくように学習することを意味します。
KLダイバージェンス:分布間の非対称な距離
KLダイバージェンスの定義:\(D_{KL}(P||Q) = \sum_x P(x) \log \frac{P(x)}{Q(x)}\)
KLダイバージェンスは、以下の式で定義されます。一旦わかりにくい方を出します。
\[D_{KL}(P||Q) = \sum_x P(x) \log \frac{P(x)}{Q(x)}\]
\(D_{KL}(P||Q)\): 確率分布 \(P\)から確率分布 \(Q\) へのKLダイバージェンス
\(P(x)\): 事象 \(x\) が、確率分布 \(P\) で起こる確率
* \(Q(x)\): 事象 \(x\) が、確率分布 \(Q\) で起こる確率
この式は、確率分布 \(P\) と \(Q\) の確率の比の対数を、確率分布\(P\) で重み付けして足し合わせたものです。
クロスエントロピーとの関係:\(D_{KL}(P||Q) = H(P,Q) – H(P)\)
KLダイバージェンスは、上記の式を変形すると、先ほどもあったようにクロスエントロピーとエントロピーを使って、以下のように表現できます。
\[D_{KL}(P||Q) = H(P,Q) – H(P)\]
この式から、KLダイバージェンスは、確率分布 \(P\) の情報を、確率分布 \(Q\) を使って符号化した場合の平均符号長(クロスエントロピー)と、確率分布\(P\) の情報を最適に符号化した場合の平均符号長(エントロピー)の差であることがわかります。
つまり、KLダイバージェンスは、確率分布 \(Q\) を使うことによって、どれだけ情報が損失するかを表していると言えます。
性質:非対称性、非負性
KLダイバージェンスには、以下のような重要な性質があります。
非対称性:\(D_{KL}(P||Q) \neq D_{KL}(Q||P)\)
KLダイバージェンスは、確率分布\(P\) から \(Q\) への距離と、確率分布\(Q\) から\(P\) への距離が一般的に異なります。これは、KLダイバージェンスが「距離」という名前がついていますが、数学的な意味での距離の公理を満たさないことを意味します。
これは真の確率分布\(P\) に従って事象が出現することによる期待値計算であるため非対称性が生まれます。
非負性:\(D_{KL}(P||Q) \geq 0\)
KLダイバージェンスは、常に0以上の値を取ります。KLダイバージェンスが0になるのは、\(P = Q\) の場合のみです。
情報損失の観点からの解釈
KLダイバージェンスは、確率分布 \(Q\) で確率分布\(P\) を近似した場合の情報損失を表します。
KLダイバージェンスが大きいほど、情報損失が大きく、確率分布 \(Q\) は確率分布\(P\) をうまく近似できていないと言えます。
KLダイバージェンスは、「真の分布 (P) の情報構造を、別の分布 (Q) を使って表現しようとした際の、効率の悪さ」をビット数という物理量で測ったものです。この「効率の悪さの度合い」が、そのまま「二つの分布 (P) と (Q) がどれだけ異なっているか」という隔たりの度合いに対応しているんです。
つまり、KLダイバージェンスは、情報理論的な「損失」という物差しで、統計的な「分布間の違い」を定量化しているのです。「損失が大きい状態」が「分布が大きく異なる状態」を意味し、「損失がゼロの状態」が「分布が完全に一致している状態」を意味する、という対応関係が、両方の解釈を結びつけています。
KLダイバージェンスは、単なる数値的な違いだけでなく、その背後にある情報表現の効率性や不確かさといった情報理論的な意味を含んだ「分布間の違い」を測る、非常に強力な尺度なんです。
さて、これまでの議論では、KLダイバージェンスを「ビット数」、特に「最適な符号化からの余分な平均ビット数(情報の損失)」という観点から見てきました。このアプローチで、KLダイバージェンスが持つ情報理論的な意味がグッと分かりやすくなったかと思います。
もちろん、情報理論は通信やデータ圧縮といった物理的な「ビット」の効率を考える場面で発展してきましたから、「ビット数」という解釈は非常に重要で正確なものです。
しかし、KLダイバージェンスは、必ずしも物理的なビットの送受信といった狭い範囲に限定されるものではありません。最後にKLダイバージェンスを「ビット数」という具体的な単位から少し離れて、より一般化された視点で捉え直してみましょう。
情報理論的には、私たちが何かを観測したり、何かを予測したりする際に扱うデータや知識は、しばしば確率分布という形で表現できます。
KLダイバージェンスは、まさにこの「確率分布」という形で表現される「情報」そのものの「違い」や「隔たり」を測るための、汎用的な尺度と考えることができます。
- 真の分布 (P) という「情報」。
- 別の分布 (Q) という「情報」。
KLダイバージェンス \(D_{KL}(P || Q)\) は、これら二つの「情報」が、情報理論的に見てどれだけ異なっているか、どれだけ互いをうまく表現(近似)できないか、どれだけ区別できるか、を定量化していると考えます。
「ビット数」というのは、その「違い」や「隔たり」を測る際に、情報理論が提供する便利な物差しの一つとして止めるのが良さそうです。対数の底を2に選べば単位はビットになりますが、対数の底を自然対数 \e\ に選べば「ナット (nat)」という別の単位になるそうです。
単位が変わっても、KLダイバージェンスが二つの分布間の「情報的な違い」を測るという本質は変わりません。
KLダイバージェンスは、単にデータ圧縮の効率の話だけではなく、
といった、より広い意味での「情報」としての確率分布間の関係性を評価するために使われます。
機械学習で、モデルの予測分布と正解分布の間の損失関数として使われたり、異なる統計モデルの比較に使われたりするのは、KLダイバージェンスが、確率分布という形で表現される「情報」の本質的な「違い」を捉えることができる、非常に強力で汎用的なツールとしてよく出現するということのようです。
まとめ
この記事では、情報理論の基礎からKLダイバージェンスへと順を追って解説しました。
- 情報量は、出来事の「驚き」や「情報の重さ」。
- エントロピーは、情報源全体の「不確かさ」や「最適な平均符号長」。
- クロスエントロピーは、真の分布を別の分布で符号化した場合の「平均コスト」。
そして、KLダイバージェンスは、このクロスエントロピーとエントロピーの差であり、「情報の損失」そのものを表します。これは、真の分布 P の情報を、誤った分布 Q を基準として表現した際に発生する、平均的な「余分なビット数」と考えることができました。
KLダイバージェンスが大きいほど情報の損失が大きく、これは二つの確率分布 P と Q の「違い」や「隔たり」が大きいことを意味します。つまり、KLダイバージェンスは、情報理論的な観点から見た、二つの確率分布がどれだけ異なるかを測る尺度なのです。
非対称性を持つKLダイバージェンスは、単なるビット数の話に留まらず、確率分布という形で表現される「情報」の本質的な違いを捉える、強力なツールです。
この記事を通して、KLダイバージェンスの持つ意味が少しでもクリアになり、理解が進んだなら幸いです。
コメント