以下の内容はhttps://www.anarchive-beta.com/entry/2021/04/10/192703より取得しました。


【R】3.4.1:多次元ガウス分布のベイズ推論の実装:平均が未知の場合【緑ベイズ入門のノート】

はじめに

 『ベイズ推論による機械学習入門』(MLSシリーズ)の独学時のノートです。各種のモデルやアルゴリズムについて「数式・プログラム・図」を用いて解説します。
 本の補助として読んでください。

 この記事では、平均が未知の多次元ガウス分布に対するベイズ推論をR言語でスクラッチ実装します。

【前節の内容】

www.anarchive-beta.com

【他の節の内容】

www.anarchive-beta.com

【この節の内容】

3.4.1 多次元ガウス分布のベイズ推論の実装:平均が未知の場合

 多次元ガウスモデル(multivariate Gaussian model)に対するベイズ推論(Bayesian inference)を実装して、人工的に生成したデータを用いて、パラメータの学習と未知変数の予測を行う。この記事では、生成分布の平均パラメータ(mean parameter)が未知の場合を扱う。平均と精度が未知の多次元ガウスモデルでは、尤度関数を多次元ガウス分布(multivariate Gaussian distribution・多変量正規分布・multivariate Normal distribution)、事前分布をガウス分布とする。この記事では、Rを利用して実装する。
 多次元ガウスモデルについては「3.4.0:多次元ガウスモデルの生成モデルの導出【緑ベイズ入門のノート】 - からっぽのしょこ」、ベイズ推論については「3.4.1:多次元ガウス分布のベイズ推論の導出:平均が未知の場合【緑ベイズ入門のノート】 - からっぽのしょこ」、Pythonを利用する場合は「【Python】3.4.1:多次元ガウス分布の学習と予測:平均が未知の場合【緑ベイズ入門のノート】 - からっぽのしょこ」を参照のこと。

 利用するパッケージを読み込む。

# 利用パッケージ
library(tidyverse)
library(mvnfast)
library(RColorBrewer)

 この記事では、基本的に パッケージ名::関数名() の記法を使うので、パッケージの読み込みは不要である。ただし、作図コードについては(ごちゃごちゃしないように)パッケージ名を省略するため、ggplot2 を読み込む必要がある。
 また、ネイティブパイプ演算子 |> を使う。(基本的には)magrittrパッケージのパイプ演算子 %>% に置き換えられるが、その場合は magrittr を読み込む必要がある。

ベイズ推論の実装

 まずは、平均が未知の多次元ガウス分布に対するベイズ推論における一連の処理を確認する。
 生成モデル(平均が未知の多次元ガウスモデル)を設定して、モデルに従うデータ(トイデータ)を生成する。続いて、生成した観測データを用いて、事後分布の計算(パラメータの推定)を行う。さらに、事後分布のパラメータ(または観測データ)を用いて、予測分布の計算(未観測データの予測)を行う。

生成分布の設定

 データの生成分布(真の分布・ガウス分布)  p(\mathbf{x}_n \mid \boldsymbol{\mu}_{\mathrm{truth}}, \boldsymbol{\Lambda}) = \mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_{\mathrm{truth}}, \boldsymbol{\Lambda}^{-1}) のパラメータ(真のパラメータ)  \boldsymbol{\mu}_{\mathrm{truth}}, \boldsymbol{\Lambda} を設定する。
 この例では、2次元のグラフで可視化するため、次元数を  D = 2 とする。パラメータ推定の処理自体は次元数に関わらず行える。
 ガウス分布については「【R】2次元ガウス分布の作図 - からっぽのしょこ」を参照のこと。

 生成分布のパラメータ  \boldsymbol{\mu}_{\mathrm{truth}}, \boldsymbol{\Lambda} を設定する。

# 次元数を指定
D <- 2

# 真のパラメータを指定
mu_truth_d <- c(25, 50)

# 既知のパラメータを指定
sigma2_dd <- c(
  900, -100, 
  -100, 400
) |> 
  matrix(nrow = D, ncol = D)

# 精度行列に変換
lambda_dd <- solve(sigma2_dd)
mu_truth_d; sigma2_dd; lambda_dd
[1] 25 50
     [,1] [,2]
[1,]  900 -100
[2,] -100  400
             [,1]         [,2]
[1,] 0.0011428571 0.0002857143
[2,] 0.0002857143 0.0025714286

 ガウス分布の平均パラメータ(平均ベクトル・実数ベクトル)  \boldsymbol{\mu}、分散パラメータ(分散共分散行列・正定値行列)  \boldsymbol{\Sigma} を指定して、精度パラメータ(精度行列)  \boldsymbol{\Lambda} = \boldsymbol{\Sigma}^{-1} を計算する。

 \displaystyle
\boldsymbol{\mu}
    = \begin{bmatrix}
        \mu_1 \\
        \mu_2
      \end{bmatrix}
,\ 
\boldsymbol{\Sigma}
    = \begin{bmatrix}
        \sigma_1^2 & \sigma_{1,2} \\
        \sigma_{2,1} & \sigma_2^2
      \end{bmatrix}

  \boldsymbol{\mu}_{\mathrm{truth}} がガウスモデルにおける真のパラメータであり、この値を求めるのがここでの目的(学習)である。

 生成分布の確率変数  x_1, x_2 の作図範囲を設定する。

# x軸の範囲を設定
k <- 3
u <- 5
x_1_size <- sqrt(sigma2_dd[1, 1]) |> # 基準値を指定
  (\(.) {. * k})() |> # 定数倍
  #(\(.) {max(., abs(x_nd[, 1]-mu_truth_d[1]))})() |> # サンプルと比較
  (\(.) {ceiling(. /u)*u})() # u単位で切り上げ
x_2_size <- sqrt(sigma2_dd[2, 2]) |> # 基準値を指定
  (\(.) {. * k})() |> # 定数倍
  #(\(.) {max(., abs(x_nd[, 2]-mu_truth_d[2]))})() |> # サンプルと比較
  (\(.) {ceiling(. /u)*u})() # u単位で切り上げ
x_1_min <- mu_truth_d[1] - x_1_size
x_1_max <- mu_truth_d[1] + x_1_size
x_2_min <- mu_truth_d[2] - x_2_size
x_2_max <- mu_truth_d[2] + x_2_size

# x軸の値を作成
x_1_vec <- seq(from = x_1_min, to = x_1_max, length.out = 251)
x_2_vec <- seq(from = x_2_min, to = x_2_max, length.out = 251)
x_1_min; x_1_max; head(x_1_vec); x_2_min; x_2_max; head(x_2_vec)
[1] -65
[1] 115
[1] -65.00 -64.28 -63.56 -62.84 -62.12 -61.40
[1] -10
[1] 110
[1] -10.00  -9.52  -9.04  -8.56  -8.08  -7.60

 この例では、指定したパラメータ(または生成したデータ)を使って、範囲を設定している。

 生成分布の確率密度を計算する。

# 生成分布の確率密度を計算:式(2.72)
model_df <- tidyr::expand_grid(
  x_1 = x_1_vec, # 1軸の確率変数
  x_2 = x_2_vec  # 2軸の確率変数
) |> # 格子点を作成
  dplyr::mutate(
    dens = mvnfast::dmvn(X = cbind(x_1, x_2), mu = mu_truth_d, sigma = sigma2_dd) # 確率密度
  )
model_df
# A tibble: 63,001 × 3
     x_1    x_2          dens
   <dbl>  <dbl>         <dbl>
 1   -65 -10    0.00000000549
 2   -65  -9.52 0.00000000598
 3   -65  -9.04 0.00000000652
 4   -65  -8.56 0.00000000709
 5   -65  -8.08 0.00000000772
 6   -65  -7.6  0.00000000839
 7   -65  -7.12 0.00000000912
 8   -65  -6.64 0.00000000990
 9   -65  -6.16 0.0000000107 
10   -65  -5.68 0.0000000117 
# ℹ 62,991 more rows

  x_1, x_2 の値の全ての組み合わせ(格子状の点)を expand_grid() で作成して、 \mathbf{x} = (x_1, x_2)^{\top} の点ごとに、ガウス分布に従う確率密度  \mathcal{N}(\mathbf{x} \mid \boldsymbol{\mu}_{\mathrm{truth}}, \boldsymbol{\Lambda}^{-1}) を計算する。
 多次元ガウス分布の確率密度関数 mvnfast::dmvn() の確率変数の引数 X \mathbf{x}、平均ベクトルの引数 mu \boldsymbol{\mu}_{\mathrm{truth}}、分散共分散行列の引数 sigma \boldsymbol{\Sigma} = \boldsymbol{\Lambda}^{-1} を指定する。

 生成分布のグラフを作成する。

# p(x)軸の範囲を設定
u <- 0.0003
dens_max <- model_df |> 
  dplyr::pull(dens) |> 
  max() |> 
  (\(.) {ceiling(. /u)*u})() # u単位で切り上げ

# 等高線を設定
level_num  <- 16 # 等高線の数を指定
dens_vals  <- seq(from = 0, to = dens_max, length.out = level_num)
color_name <- "YlOrRd" # カラーマップを指定
color_num  <- 9        # カラーマップの色数を設定
color_vals <- colorRampPalette(colors = RColorBrewer::brewer.pal(n = color_num, name = color_name))(n = level_num-1) |> # 色数を拡張
  rev()

# 生成分布のラベルを作成
model_param_lbl <- paste0(
  "list(", 
    "mu[truth] == bgroup('(', atop(", 
      paste(round(mu_truth_d, digits = 2), collapse = ", "), 
    "), ')'), ", 
    "Lambda == bgroup('(', atop(",
      "list(", paste(round(lambda_dd[1, ], digits = 5), collapse = ", "), "), ", 
      "list(", paste(round(lambda_dd[2, ], digits = 5), collapse = ", "), ")", 
    "), ')')", 
  ")"
) |> 
  parse(text = _)

# 生成分布を作図
ggplot() + 
  geom_contour_filled(
    data    = model_df, 
    mapping = aes(x = x_1, y = x_2, z = dens, fill = after_stat(level), linetype = "model"), 
    breaks = dens_vals, # 色数の拡張用
    alpha = 0.6
  ) + # 生成分布
  scale_linetype_manual(
    breaks = "model", 
    values = "blank", 
    labels = "true model", 
    name   = ""
  ) + # (凡例表示用)
  scale_fill_manual(values = color_vals) + # 色数の拡張用
  guides(
    linetype = guide_legend(order = 1), 
    fill     = guide_legend(order = 2)
  ) + 
  labs(
    title = "multivariate Gaussian distribution", 
    subtitle = model_param_lbl, 
    fill = "density", 
    x = expression(x[1]), 
    y = expression(x[2])
  )

生成分布(真の分布・多次元ガウス分布)のグラフ

 真の分布(ガウス分布)を等高線(グラデーション)で示す。

 真のパラメータ  \boldsymbol{\mu}_{\mathrm{truth}} を求めることは、真の分布  \mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_{\mathrm{truth}}, \boldsymbol{\Lambda}^{-1}) を求めることを意味する。

データの生成

 設定した生成分布(ガウス分布)  p(\mathbf{x}_n \mid \boldsymbol{\mu}_{\mathrm{truth}}, \boldsymbol{\Lambda}) = \mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_{\mathrm{truth}}, \boldsymbol{\Lambda}^{-1}) に従うデータ(観測データ)  \mathbf{X} を作成する。
 ガウスモデルのデータ生成については「【R】多次元ガウス分布の乱数生成 - からっぽのしょこ」を参照のこと。

 生成分布からデータ  \mathbf{X} を生成する。

# データ数を指定
N <- 300

# 観測データを生成
x_nd <- mvnfast::rmvn(n = N, mu = mu_truth_d, sigma = sigma2_dd)
head(x_nd)
           [,1]     [,2]
[1,] 117.783503 21.05768
[2,]   4.004877 32.81244
[3,]  68.123504 35.78109
[4,]  25.771419 33.71867
[5,] -18.427680 58.99636
[6,] -31.524175 64.56169

 データ数  N を指定して、多次元ガウス分布に従う乱数  \mathbf{X} = \{\mathbf{x}_1, \cdots, \mathbf{x}_N\} \mathbf{x}_n = (x_{n,1}, x_{n,2})^{\top} を生成する。
 多次元ガウス分布の乱数生成関数 mvnfast::rmvnm() のサンプルサイズの引数 n N、平均ベクトルの引数 mu \boldsymbol{\mu}_{\mathrm{truth}}、分散共分散行列の引数 sigma \boldsymbol{\Sigma} = \boldsymbol{\Lambda}^{-1} を指定する。

 観測データ  \mathbf{X} を集計する。

# 階級数を指定
bin_num <- 40

# 階級幅を計算
bin_1_size <- (x_1_max - x_1_min) / bin_num
bin_2_size <- (x_2_max - x_2_min) / bin_num

# 境界値の範囲を設定
bin_1_min <- x_1_min - 0.5*bin_1_size
bin_2_min <- x_2_min - 0.5*bin_2_size

# 観測データを集計
obs_df <- tibble::tibble(
  x_1 = x_nd[, 1], # 1軸のサンプル値
  x_2 = x_nd[, 2]  # 2軸のサンプル値
) |> 
  dplyr::mutate(
    bin_1_i  = (x_1 - bin_1_min) %/% bin_1_size, # 1軸の階級番号
    bin_2_i  = (x_2 - bin_2_min) %/% bin_2_size, # 2軸の階級番号
    center_1 = bin_1_min + (bin_1_i + 0.5) * bin_1_size, # 1軸の階級値
    center_2 = bin_2_min + (bin_2_i + 0.5) * bin_2_size  # 2軸の階級値
  ) |> 
  dplyr::count(
    center_1, center_2, name = "freq" # 度数
  ) |> 
  dplyr::mutate(
    dens = freq / (bin_1_size * bin_2_size * N) # 密度
  ) |> 
  tidyr::complete(
    center_1 = seq(from = x_1_min, to = x_1_max, by = bin_1_size), 
    center_2 = seq(from = x_2_min, to = x_2_max, by = bin_2_size), 
    fill = list(freq = 0, dens = 0)
  ) # 未観測値を補完
obs_df
# A tibble: 1,685 × 4
   center_1 center_2  freq  dens
      <dbl>    <dbl> <int> <dbl>
 1      -65      -10     0     0
 2      -65       -7     0     0
 3      -65       -4     0     0
 4      -65       -1     0     0
 5      -65        2     0     0
 6      -65        5     0     0
 7      -65        8     0     0
 8      -65       11     0     0
 9      -65       14     0     0
10      -65       17     0     0
# ℹ 1,675 more rows

 階級数を指定して、階級幅  w_1, w_2 を設定する。
  x_1, x_2 の範囲で階級値を作成して、観測データ x_nd に含まれる要素数をカウントして、度数  N_{x_1,x_2} と密度  \frac{N_{x_1,x_2}}{w_1 w_2 N} を求める。

 観測データ  \mathbf{X} のグラフを作成する。

# 観測データのラベルを作成
obs_param_lbl <- paste0(
  "list(", 
    "N == ", N, ", ", 
    "mu[truth] == bgroup('(', atop(", 
      paste(round(mu_truth_d, digits = 2), collapse = ", "), 
    "), ')'), ", 
    "Lambda == bgroup('(', atop(",
      "list(", paste(round(lambda_dd[1, ], digits = 5), collapse = ", "), "), ", 
      "list(", paste(round(lambda_dd[2, ], digits = 5), collapse = ", "), ")", 
    "), ')')", 
  ")"
) |> 
  parse(text = _)

# 観測データを作図
ggplot() + 
  geom_tile(
    data    = obs_df, 
    mapping = aes(x = center_1, y = center_2, fill = dens), 
    alpha = 0.5
  ) + # 観測データ
  geom_contour(
    data    = model_df, 
    mapping = aes(x = x_1, y = x_2, z = dens, color = after_stat(level), linetype = "model"), 
    breaks = dens_vals, # 軸目盛の対応用
    linewidth = 1
  ) + # 生成分布
  scale_colour_distiller(palette = color_name, direction = -1) + # 軸目盛の対応用
  scale_fill_gradientn(colors = color_vals) + # 軸目盛の対応用
  scale_linetype_manual(
    breaks = "model", 
    values = "dashed", 
    labels = "true model", 
    name = ""
  ) + # (凡例の表示用)
  guides(
    linetype = guide_legend(override.aes = list(color = "red", linewidth = 0.5), order = 1), 
    color    = guide_colorbar(order = 2)
  ) + 
  coord_cartesian(
    xlim = c(x_1_min, x_1_max), 
    ylim = c(x_2_min, x_2_max)
  ) + 
  labs(
    title = "multivariate Gaussian distribution", 
    subtitle = obs_param_lbl, 
    fill  = "density", 
    color = "density", 
    x = expression(x[1]), 
    y = expression(x[2])
  )

観測データ(多次元ガウス乱数)のグラフ

  \mathbf{x} の値ごとの確率密度(生成分布)を等高線(赤色・破線)、観測データ  \mathbf{X} の度数  N_{x_1,x_2} (密度  \frac{N_{x_1,x_2}}{w_1 w_2 N} )をヒートマップ(グラデーション)で示す。

 データ数  N が十分に大きいと、観測データ  \mathbf{X} のヒートマップの形状が生成分布  \mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_{\mathrm{truth}}, \boldsymbol{\Lambda}^{-1}) に近付く。

事前分布の設定

 パラメータ  \boldsymbol{\mu} の事前分布(ガウス分布)  p(\boldsymbol{\mu} \mid \mathbf{m}, \boldsymbol{\Lambda}_{\mu}) = \mathcal{N}(\boldsymbol{\mu} \mid \mathbf{m}, \boldsymbol{\Lambda}_{\mu}^{-1}) のパラメータ(超パラメータ)  \mathbf{m}, \boldsymbol{\Lambda} を設定する。

 事前分布のパラメータ  \mathbf{m}, \boldsymbol{\Lambda}_{\mu} を設定する。

# 事前分布のパラメータを指定
m_d          <- rep(0, times = D)
sigma2_mu_dd <- diag(D) * 100^2

# 精度行列に変換
lambda_mu_dd <- solve(sigma2_mu_dd)
m_d; sigma2_mu_dd; lambda_mu_dd
[1] 0 0
      [,1]  [,2]
[1,] 10000     0
[2,]     0 10000
      [,1]  [,2]
[1,] 1e-04 0e+00
[2,] 0e+00 1e-04

 ガウス分布の平均パラメータ(平均パラメータ・実数ベクトル)  \mathbf{m}、分散パラメータ(分散共分散行列・正定値行列)  \boldsymbol{\Sigma}_{\mu} を指定して、精度パラメータ(精度行列)  \boldsymbol{\Lambda}_{\mu} = \boldsymbol{\Sigma}_{\mu}^{-1} を計算する。

 \displaystyle
\mathbf{m}
    = \begin{bmatrix}
        m_1 \\
        m_2
      \end{bmatrix}
,\ 
\boldsymbol{\Sigma}_{\mu}
    = \begin{bmatrix}
        \sigma_{\mu,1}^2 & \sigma_{\mu,1,2} \\
        \sigma_{\mu,2,1} & \sigma_{\mu,2}^2
      \end{bmatrix}

 事前分布の確率変数  \mu_1, \mu_2 の作図範囲を設定する。

# μ軸の範囲を設定
mu_1_min <- x_1_min
mu_1_max <- x_1_max
mu_2_min <- x_2_min
mu_2_max <- x_2_max

# μ軸の値を作成
mu_1_vec <- seq(from = mu_1_min, to = mu_1_max, length.out = 251)
mu_2_vec <- seq(from = mu_2_min, to = mu_2_max, length.out = 251)
mu_1_min; mu_1_max; head(mu_1_vec); mu_2_min; mu_2_max; head(mu_2_vec)
[1] -65
[1] 115
[1] -65.00 -64.28 -63.56 -62.84 -62.12 -61.40
[1] -10
[1] 110
[1] -10.00  -9.52  -9.04  -8.56  -8.08  -7.60

 この例では、生成分布の確率変数  x_1, x_2 の範囲を設定している。

 事前分布の確率密度を計算する。

# 生成分布の確率密度を計算:式(2.72)
prior_df <- tidyr::expand_grid(
  mu_1 = mu_1_vec, # 1軸の確率変数
  mu_2 = mu_2_vec  # 2軸の確率変数
) |> # 格子点を作成
  dplyr::mutate(
    dens = mvnfast::dmvn(X = cbind(mu_1, mu_2), mu = m_d, sigma = sigma2_mu_dd) # 確率密度
  )
prior_df
# A tibble: 63,001 × 3
    mu_1   mu_2      dens
   <dbl>  <dbl>     <dbl>
 1   -65 -10    0.0000128
 2   -65  -9.52 0.0000128
 3   -65  -9.04 0.0000128
 4   -65  -8.56 0.0000128
 5   -65  -8.08 0.0000128
 6   -65  -7.6  0.0000128
 7   -65  -7.12 0.0000129
 8   -65  -6.64 0.0000129
 9   -65  -6.16 0.0000129
10   -65  -5.68 0.0000129
# ℹ 62,991 more rows

  \mu_1, \mu_2 の値の全ての組み合わせ(格子状の点)を expand_grid() で作成して、 \boldsymbol{\mu} = (\mu_1, \mu_2)^{\top} の点ごとに、ガウス分布に従う確率密度  \mathcal{N}(\boldsymbol{\mu} \mid \mathbf{m}, \boldsymbol{\Lambda}^{-1}) を計算する。
 多次元ガウス分布の確率密度関数 mvnfast::dmvn() の確率変数の引数 X \boldsymbol{\mu}、平均ベクトルの引数 mu \mathbf{m}、分散共分散行列の引数 sigma \boldsymbol{\Sigma}_{\mu} = \boldsymbol{\Lambda}_{\mu}^{-1} を指定する。

 事前分布のグラフを作成する。

# 事前分布のラベルを作成
prior_param_lbl <- paste0(
  "list(", 
    "mu[truth] == bgroup('(', atop(", 
      paste(round(mu_truth_d, digits = 2), collapse = ", "), 
    "), ')'), ", 
    "m == bgroup('(', atop(", 
      paste(round(m_d, digits = 2), collapse = ", "), 
    "), ')'), ", 
    "Lambda[mu] == bgroup('(', atop(", 
      "list(", paste(round(lambda_mu_dd[1, ], digits = 5), collapse = ", "), "), ", 
      "list(", paste(round(lambda_mu_dd[2, ], digits = 5), collapse = ", "), ")", 
    "), ')')", 
  ")"
) |> 
  parse(text = _)

# 事前分布を作図
ggplot() + 
  geom_contour_filled(
    data    = prior_df, 
    mapping = aes(x = mu_1, y = mu_2, z = dens, fill = after_stat(level), linetype = "prior"), 
    alpha = 0.6
  ) + # 事前分布
  geom_vline(
    mapping = aes(xintercept = mu_truth_d[1], linetype = "model"), 
    color = "red", linewidth = 1
  ) + # 真のパラメータ
  geom_hline(
    mapping = aes(yintercept = mu_truth_d[2], linetype = "model"), 
    color = "red", linewidth = 1
  ) + # 真のパラメータ
  scale_x_continuous(
    sec.axis = sec_axis(
      transform = ~ ., 
      breaks    = mu_truth_d[1], 
      labels    = expression(mu[1]^{truth})
    ) # パラメータラベル
  ) + 
  scale_y_continuous(
    sec.axis = sec_axis(
      transform = ~ ., 
      breaks    = mu_truth_d[2], 
      labels    = expression(mu[2]^{truth})
    ) # パラメータラベル
  ) + 
  scale_linetype_manual(
    breaks = c("model", "prior"), 
    values = c("dashed", "blank"), 
    labels = c("true parameter", "prior distribution"), 
    name = ""
  ) + # (凡例の表示用)
  guides(
    linetype = guide_legend(override.aes = list(linewidth = 0.5), order = 1), 
    fill     = guide_legend(order = 2)
  ) + 
  labs(
    title = "multivariate Gaussian distribution", 
    subtitle = prior_param_lbl, 
    fill = "density", 
    x = expression(mu[1]), 
    y = expression(mu[2])
  )

事前分布(多次元ガウス分布)のグラフ

 真のパラメータを直線(赤色・破線)、事前分布(ガウス分布)を等高線(グラデーション)で示す。

 真のパラメータ(真の分布のパラメータ)  \boldsymbol{\mu}_{\mathrm{truth}} と、パラメータ  \boldsymbol{\mu} の事前分布  \mathcal{N}(\boldsymbol{\mu} \mid \mathbf{m}, \boldsymbol{\Lambda}_{\mu}) の位置関係を図で確認する。

事後分布の計算

 観測データ  \mathbf{X} から、パラメータ  \boldsymbol{\mu} の事後分布(ガウス分布)  p(\boldsymbol{\mu} \mid \mathbf{X}, \mathbf{m}, \boldsymbol{\Lambda}_{\mu}) = \mathcal{N}(\boldsymbol{\mu} \mid \hat{\mathbf{m}}, \hat{\boldsymbol{\Lambda}}_{\mu}^{-1}) のパラメータ(超パラメータ)  \hat{\mathbf{m}}, \hat{\boldsymbol{\Lambda}}_{\mu} を求める(真のパラメータ  \boldsymbol{\mu}_{\mathrm{truth}} を分布推定する)。

 事後分布のパラメータ  \hat{\mathbf{m}}, \hat{\boldsymbol{\Lambda}}_{\mu} を計算する。

# 事後分布のパラメータを計算:式(3.102, 3.103)
lambda_mu_hat_dd <- N * lambda_dd + lambda_mu_dd
m_hat_d          <- (solve(lambda_mu_hat_dd) %*% (lambda_dd %*% colSums(x_nd) + lambda_mu_dd %*% m_d)) |> 
  as.vector()
m_hat_d; lambda_mu_hat_dd
[1] 24.29844 49.18142
           [,1]       [,2]
[1,] 0.34295714 0.08571429
[2,] 0.08571429 0.77152857

 事後分布のパラメータは、次の式で計算できる。

 \displaystyle
\begin{align}
\hat{\boldsymbol{\Lambda}}_{\mu}
   &= N \boldsymbol{\Lambda}
      + \boldsymbol{\Lambda}_{\mu}
\tag{3.102}\\
\hat{\mathbf{m}}
   &= \hat{\boldsymbol{\Lambda}}_{\mu}^{-1} \left(
          \boldsymbol{\Lambda}
          \sum_{n=1}^N \mathbf{x}_n
          + \boldsymbol{\Lambda}_{\mu} \mathbf{m}
      \right)
\tag{3.103}
\end{align}

 事後分布の確率密度を計算する。

# 事後分布の確率密度を計算:式(2.72)
posterior_df <- tidyr::expand_grid(
  mu_1 = mu_1_vec, # 1軸の確率変数
  mu_2 = mu_2_vec  # 2軸の確率変数
) |> # 格子点を作成
  dplyr::mutate(
    dens = mvnfast::dmvn(X = cbind(mu_1, mu_2), mu = m_hat_d, sigma = solve(lambda_mu_hat_dd)) # 確率密度
  )
posterior_df
# A tibble: 63,001 × 3
    mu_1   mu_2  dens
   <dbl>  <dbl> <dbl>
 1   -65 -10        0
 2   -65  -9.52     0
 3   -65  -9.04     0
 4   -65  -8.56     0
 5   -65  -8.08     0
 6   -65  -7.6      0
 7   -65  -7.12     0
 8   -65  -6.64     0
 9   -65  -6.16     0
10   -65  -5.68     0
# ℹ 62,991 more rows

 事前分布のときと同様にして、ガウス分布に従う確率密度  \mathcal{N}(\boldsymbol{\mu} \mid \hat{\mathbf{m}}, \hat{\boldsymbol{\Lambda}}_{\mu}) を計算する。

 事後分布のグラフを作成する。

# 事後分布のラベルを作成
posterior_param_lbl <- paste0(
  "list(", 
    "N == ", N, ", ", 
    "mu[truth] == bgroup('(', atop(", 
      paste(round(mu_truth_d, digits = 2), collapse = ", "), 
    "), ')'), ", 
    "hat(m) == bgroup('(', atop(", 
      paste(round(m_hat_d, digits = 2), collapse = ", "), 
    "), ')'), ", 
    "hat(Lambda)[mu] == bgroup('(', atop(", 
      "list(", paste(round(lambda_mu_hat_dd[1, ], digits = 5), collapse = ", "), "), ", 
      "list(", paste(round(lambda_mu_hat_dd[2, ], digits = 5), collapse = ", "), ")", 
    "), ')')", 
  ")"
) |> 
  parse(text = _)

# 事後分布を作図
ggplot() + 
  geom_contour_filled(
    data    = posterior_df, 
    mapping = aes(x = mu_1, y = mu_2, z = dens, fill = after_stat(level), linetype = "posterior"), 
    alpha = 0.6
  ) + # 事後分布
  geom_vline(
    mapping = aes(xintercept = mu_truth_d[1], linetype = "model"), 
    color = "red", linewidth = 1
  ) + # 真のパラメータ
  geom_hline(
    mapping = aes(yintercept = mu_truth_d[2], linetype = "model"), 
    color = "red", linewidth = 1
  ) + # 真のパラメータ
  scale_x_continuous(
    sec.axis = sec_axis(
      transform = ~ ., 
      breaks    = mu_truth_d[1], 
      labels    = expression(mu[1]^{truth})
    ) # パラメータラベル
  ) + 
  scale_y_continuous(
    sec.axis = sec_axis(
      transform = ~ ., 
      breaks    = mu_truth_d[2], 
      labels    = expression(mu[2]^{truth})
    ) # パラメータラベル
  ) + 
  scale_linetype_manual(
    breaks = c("model", "posterior"), 
    values = c("dashed", "blank"), 
    labels = c("true parameter", "posterior distribution"), 
    name = ""
  ) + # (凡例の表示用)
  guides(
    linetype = guide_legend(override.aes = list(linewidth = 0.5), order = 1), 
    fill     = guide_legend(order = 2)
  ) + 
  labs(
    title = "multivariate Gaussian distribution", 
    subtitle = posterior_param_lbl, 
    fill = "density", 
    x = expression(mu[1]), 
    y = expression(mu[2])
  )

事後分布(多次元ガウス分布)のグラフ

真の値の付近の様子

 真のパラメータを直線(赤色・破線)、事後分布(ガウス分布)を等高線(グラデーション)で示す。
 下の図は、パラメータの真の値から標準偏差1つ分の範囲を拡大したものである。(そのため、作図用の点の数によっては、等高線が粗くなる。)

 データ数  N が十分に大きいと、パラメータ  \boldsymbol{\mu} の事後分布  \mathcal{N}(\boldsymbol{\mu} \mid \hat{\mathbf{m}}, \hat{\boldsymbol{\Lambda}}_{\mu}^{-1}) のピークが真のパラメータ  \boldsymbol{\mu}_{\mathrm{truth}} に近付く。

予測分布の計算

 観測データ  \mathbf{X} から、未観測のデータ  \mathbf{x}_{*} の予測分布(ガウス分布)  p(\mathbf{x}_{*} \mid \mathbf{X}, \mathbf{m}, \boldsymbol{\Lambda}_{\mu}) = \mathcal{N}(\mathbf{x}_{*} \mid \hat{\boldsymbol{\mu}}_{*}, \hat{\boldsymbol{\Lambda}}_{*}^{-1}) を求める。

 予測分布のパラメータ  \hat{\boldsymbol{\mu}}_{*}, \hat{\boldsymbol{\Lambda}}_{*} を計算する。

# 事後分布のパラメータにより予測分布のパラメータを計算:式(3.109', 3.110')
mu_star_hat_d      <- m_hat_d
lambda_star_hat_dd <- solve(solve(lambda_dd) + solve(lambda_mu_hat_dd))
mu_star_hat_d; lambda_star_hat_dd
[1] 24.29844 49.18142
             [,1]         [,2]
[1,] 0.0011390614 0.0002847651
[2,] 0.0002847651 0.0025628867
# 観測データにより予測分布のパラメータを計算:式(3.109', 3.110')
mu_star_hat_d      <- (solve(lambda_mu_hat_dd) %*% (lambda_dd %*% colSums(x_nd) + lambda_mu_dd %*% m_d)) |> 
  as.vector()
lambda_star_hat_dd <- solve(solve(lambda_dd) + solve(N * lambda_dd + lambda_mu_dd))
mu_star_hat_d; lambda_star_hat_dd
[1] 24.29844 49.18142
             [,1]         [,2]
[1,] 0.0011390614 0.0002847651
[2,] 0.0002847651 0.0025628867

 予測分布のパラメータは、事後分布のパラメータ  \hat{\mathbf{m}}, \hat{\boldsymbol{\Lambda}}_{\mu} または観測データ  \mathbf{X} を用いて、次の式で計算できる。

 \displaystyle
\begin{align}
\hat{\boldsymbol{\Lambda}}_{*}
   &= \Bigl(
          \boldsymbol{\Lambda}^{-1}
          + \hat{\boldsymbol{\Lambda}}_{\mu}^{-1}
      \Bigr)^{-1}
\\
   &= \Bigl\{
          \boldsymbol{\Lambda}^{-1}
          + (N \boldsymbol{\Lambda} + \boldsymbol{\Lambda}_{\mu})^{-1}
      \Bigr\}^{-1}
\tag{3.109'}\\
\hat{\boldsymbol{\mu}}_{*}
   &= \hat{\mathbf{m}}
\\
   &= \hat{\boldsymbol{\Lambda}}_{\mu}^{-1} \left(
          \boldsymbol{\Lambda}
          \sum_{n=1}^N \mathbf{x}_n
          + \boldsymbol{\Lambda}_{\mu}
            \mathbf{m}
      \right)
\tag{3.110'}
\end{align}

 予測分布の確率密度を計算する。

# 予測分布の確率密度を計算:式(2.72)
predict_df <- tidyr::expand_grid(
  x_1 = x_1_vec, # 1軸の確率変数
  x_2 = x_2_vec  # 2軸の確率変数
) |> # 格子点を作成
  dplyr::mutate(
    dens = mvnfast::dmvn(X = cbind(x_1, x_2), mu = mu_star_hat_d, sigma = solve(lambda_star_hat_dd)) # 確率密度
  )
predict_df
# A tibble: 63,001 × 3
     x_1    x_2          dens
   <dbl>  <dbl>         <dbl>
 1   -65 -10    0.00000000713
 2   -65  -9.52 0.00000000776
 3   -65  -9.04 0.00000000844
 4   -65  -8.56 0.00000000918
 5   -65  -8.08 0.00000000997
 6   -65  -7.6  0.0000000108 
 7   -65  -7.12 0.0000000118 
 8   -65  -6.64 0.0000000127 
 9   -65  -6.16 0.0000000138 
10   -65  -5.68 0.0000000150 
# ℹ 62,991 more rows

 生成分布のときと同様にして、ガウス分布に従う確率密度  \mathcal{N}(\mathbf{x}_{*} \mid \boldsymbol{\mu}_{*}, \boldsymbol{\Lambda}_{*}^{-1}) を計算する。

 予測分布のグラフを作成する。

# 予測分布のラベルを作成
predict_param_lbl <- paste0(
  "list(", 
    "N == ", N, ", ", 
    "mu[truth] == bgroup('(', atop(", 
      paste(round(mu_truth_d, digits = 2), collapse = ", "), 
    "), ')'), ", 
    "Lambda == bgroup('(', atop(",
      "list(", paste(round(lambda_dd[1, ], digits = 5), collapse = ", "), "), ", 
      "list(", paste(round(lambda_dd[2, ], digits = 5), collapse = ", "), ")", 
    "), ')'), ", 
    "mu['*'] == bgroup('(', atop(", 
      paste(round(mu_star_hat_d, digits = 2), collapse = ", "), 
    "), ')'), ", 
    "Lambda['*'] == bgroup('(', atop(",
      "list(", paste(round(lambda_star_hat_dd[1, ], digits = 5), collapse = ", "), "), ", 
      "list(", paste(round(lambda_star_hat_dd[2, ], digits = 5), collapse = ", "), ")", 
    "), ')')", 
  ")"
) |> 
  parse(text = _)

# 予測分布を作図
ggplot() + 
  geom_contour_filled(
    data    = predict_df, 
    mapping = aes(x = x_1, y = x_2, z = dens, fill = after_stat(level), linetype = "predict"), 
    breaks = dens_vals, # (等高線の位置の共通化用)
    alpha = 0.6
  ) + # 予測分布
  geom_contour(
    data    = model_df, 
    mapping = aes(x = x_1, y = x_2, z = dens, color = after_stat(level), linetype = "model"), 
    breaks = dens_vals, # (等高線の位置の共通化用)
    linewidth = 1
  ) + # 真の分布
  scale_colour_distiller(
    palette = "YlOrRd", direction = -1
  ) + 
  scale_linetype_manual(
    breaks = c("model", "predict"), 
    values = c("dashed", "blank"), 
    labels = c("true model", "predict distribution"), 
    name = ""
  ) + # (凡例の表示用)
  guides(
    linetype = guide_legend(
      override.aes = list(color = c("red", NA), linewidth = 0.5), 
      order = 1
    ), 
    fill  = guide_legend(order = 2), 
    color = "none"
  ) + 
  labs(
    title = "multivariate Gaussian distribution", 
    subtitle = predict_param_lbl, 
    fill = "density", 
    x = expression(x[1]), 
    y = expression(x[2])
  )

事後予測分布(多次元ガウス分布)のグラフ

 真の分布(ガウス分布)を等高線(赤色・破線)、予測分布(ガウス分布)を等高線(グラデーション)で示す。

 データ数  N が十分に大きいと、未観測データ  \mathbf{x}_{*} の予測分布  \mathcal{N}(\mathbf{x}_{*} \mid \hat{\boldsymbol{\mu}}_{*}, \hat{\boldsymbol{\Lambda}}_{*}^{-1}) の形状が真の分布  \mathcal{N}(\mathbf{x}_{n,d} \mid \hat{\boldsymbol{\mu}}_{*}, \hat{\boldsymbol{\Lambda}}_{*}^{-1}) に近付く。

 以上で、平均が未知の多次元ガウスモデルのベイズ推論を実装した。

学習の推移

 次は、平均が未知の多次元ガウス分布に対するベイズ推論を図で確認する。
 データ数を増やして分布の変化をアニメーションで確認する。
 作図コードについては「Suyama-Bayes/code/multivariate_gaussian_model/bayesian_inference/plot_parameter_updates_unknown_mean.R at master · anemptyarchive/Suyama-Bayes · GitHub」を参照のこと。

データ数と分布の関係

 データ数  N を変化させたときの事後分布  p(\boldsymbol{\mu} \mid \mathbf{X}, \boldsymbol{\Lambda}, \mathbf{m}, \boldsymbol{\Lambda}_{\mu}) の変化をアニメーションにする。

  n 個のデータから求めた(  n 回更新した)事後分布(ガウス分布)  \mathcal{N}(\boldsymbol{\mu} \mid \mathbf{m}^{(n)}, \boldsymbol{\Lambda}_{\mu}^{(n)}) を紫色の曲線(実線)、 n 番目の観測データ  \mathbf{x}_n に対応する位置  (\mu_1, \mu_2) = (x_{n,1}, x_{n,2}) を桃色の点で示す。

 「ベイズ推論の実装」では、 N 個(複数)のデータ  \mathbf{X} を用いて、事後分布のパラメータ  \hat{\mathbf{m}}, \hat{\boldsymbol{\Lambda}}_{\mu} を一括更新した。
  n 番目(1つ)のデータ  \mathbf{x}_n を用いて逐次更新する場合、 n 回目の事後分布のパラメータ  \mathbf{m}^{(n)}, \boldsymbol{\Lambda}_{\mu}^{(n)} は、次の式で計算できる。

 \displaystyle
\begin{aligned}
\mathbf{m}^{(n)}
   &= (\boldsymbol{\Lambda}_{\mu}^{(n)})^{-1} \left(
          \boldsymbol{\Lambda}_{\mu}^{(n-1)} \mathbf{m}^{(n-1)}
          + \boldsymbol{\Lambda} \mathbf{x}_n
      \right)
\\
\boldsymbol{\Lambda}_{\mu}^{(n)}
   &= \boldsymbol{\Lambda}_{\mu}^{(n-1)}
      + \boldsymbol{\Lambda}
\end{aligned}

 超パラメータの初期値(1回目の更新における事前分布のパラメータ)を  \mathbf{m}^{(0)}, \boldsymbol{\Lambda}_{\mu}^{(0)} として、 \mathbf{m}^{(n-1)}, \boldsymbol{\Lambda}_{\mu}^{(n-1)} は、 n-1 回目の更新における事後分布のパラメータであり、また  n 回目の更新における事前分布のパラメータを表す。

 データ数  N が大きくなるのに応じて、パラメータ  \boldsymbol{\mu} の事後分布  \mathcal{N}(\boldsymbol{\mu} \mid \hat{\mathbf{m}}, \hat{\boldsymbol{\Lambda}}_{\mu}) のピークが真のパラメータ  \mu_{\mathrm{truth}} に近付いていくのを確認できる。

 データ数  N を変化させたときの予測分布  p(\mathbf{x}_{*} \mid \mathbf{X}, \boldsymbol{\Lambda}, \mathbf{m}, \boldsymbol{\Lambda}_{\mu}) の変化をアニメーションにする。

  n 個のデータから求めた(  n 回更新した)予測分布(ガウス分布)  \mathcal{N}(\mathbf{x}_{*} \mid \hat{\boldsymbol{\mu}}_{*}, \hat{\boldsymbol{\Lambda}}_{*}^{-1}) を紫色の曲線(実線)、 n 番目の観測データ  \mathbf{x}_n を桃色の点で示す。

 「ベイズ推論の実装」では、 N 個(複数)のデータ  \mathbf{X} を用いて、予測分布のパラメータ  \hat{\boldsymbol{\mu}}_{*}, \hat{\boldsymbol{\Lambda}}_{*} を一括更新した。
  n 番目(1つ)のデータ  \mathbf{x}_n を用いて逐次更新する場合、 n 回目の予測分布のパラメータ  \boldsymbol{\mu}_{*}^{(n)}, \boldsymbol{\Lambda}_{*}^{(n)} は、次の式で計算できる。

 \displaystyle
\begin{aligned}
\boldsymbol{\mu}_{*}^{(n)}
   &= \mathbf{m}^{(n)}
\\
\boldsymbol{\Lambda}_{*}^{(n)}
   &= \Bigl(
          (\boldsymbol{\Lambda}_{\mu}^{(n)})^{-1}
          + \boldsymbol{\Lambda}^{-1}
      \Bigr)^{-1}
\end{aligned}

 超パラメータの初期値  \mathbf{m}^{(0)}, \boldsymbol{\Lambda}^{(0)} を用いて求めたパラメータを  \boldsymbol{\mu}_{*}^{(0)}, \boldsymbol{\Lambda}_{*}^{(0)} として、 \boldsymbol{\mu}_{*}^{(n-1)}, \boldsymbol{\Lambda}_{*}^{(n-1)} n-1 回目の更新値を表す。

 データ数  N が大きくなるのに応じて、未観測データ  \mathbf{x}_{*} の予測分布  \mathcal{N}(\mathbf{x}_{*} \mid \hat{\boldsymbol{\mu}}_{*}, \hat{\boldsymbol{\Lambda}}_{*}^{-1}) の形状が真の分布  \mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_{\mathrm{truth}}, \boldsymbol{\Lambda}^{-1}) に近付いていくのを確認できる。

観測データと分布の関係

 データ数  N の変化による観測データと事後分布、予測分布の関係をアニメーションにする。

 真の分布の期待値  \mathbb{E}[\mathbf{x}_n] と真のパラメータ  \boldsymbol{\mu}_{\mathrm{truth}} の各図(分布)に対応する位置を赤色の直線(破線)、観測データの標本平均  \bar{\mathbf{x}} を桃色の直線(破線)、事後分布の期待値  \mathbb{E}[\boldsymbol{\mu}] と予測分布の期待値  \mathbb{E}[\mathbf{x}_{*}] を紫色の直線(破線)で示す。

 生成分布(ガウス分布)、事後分布(ガウス分布)、予測分布(ガウス分布)の期待値は、それぞれ次の式で計算できる。

 \displaystyle
\begin{aligned}
\mathbb{E}_{\mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_{\mathrm{truth}}, \boldsymbol{\Lambda}^{-1})}[\mathbf{x}_n]
   &= \boldsymbol{\mu}_{\mathrm{truth}}
\\
\mathbb{E}_{\mathcal{N}(\boldsymbol{\mu} \mid \hat{\mathbf{m}}, \hat{\boldsymbol{\Lambda}}_{\mu}^{-1})}[\boldsymbol{\mu}]
   &= \hat{\mathbf{m}}
\\
\mathbb{E}_{\mathcal{N}(\mathbf{x}_{*} \mid \hat{\boldsymbol{\mu}}_{*}, \hat{\boldsymbol{\Lambda}}_{*}^{-1})}[\mathbf{x}_{*}]
   &= \hat{\boldsymbol{\mu}}_{*}
\\
   &= \hat{\mathbf{m}}
\end{aligned}

 真の分布  \mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_{\mathrm{truth}}, \boldsymbol{\Lambda}^{-1}) や真のパラメータ  \boldsymbol{\mu}_{\mathrm{truth}} と、観測データ  \mathbf{X}、事後分布  \mathcal{N}(\boldsymbol{\mu} \mid \hat{\mathbf{m}}, \hat{\boldsymbol{\Lambda}}_{\mu}^{-1})、予測分布  \mathcal{N}(\mathbf{x}_{*} \mid \hat{\boldsymbol{\mu}}_{*}, \hat{\boldsymbol{\Lambda}}_{*}^{-1})、またそれぞれの統計量の対応関係を確認できる。

 以上で、平均が未知の多次元ガウスモデルのベイズ推論における学習推移を確認した。

 この記事では、平均が未知の場合の多次元ガウス分布に対するベイズ推論を扱った。次の記事では、精度が未知の場合を扱う。

参考文献

おわりに

 その時々では理解したつもりで記事にしているのですが、勉強を続けていると勘違いしていたことに気付くことも多々あります。適宜修正していきたいですけど結構大変。独学あるある?

  • 2021/04/10:加筆修正しました。その際に数式読解編と記事を分割しました。
  • 2022/09/20:加筆修正しました。

 多次元編でもfor()からtidyverseの関数に置き換えられそうで良かったです。推移の可視化の方で自分でもよく分からない処理をしていますがまぁ。

  • 2026.02.02:加筆修正しました。

 パラメータの計算処理はほとんど変わらず、分布の計算処理はシンプルな実装になり、グラフの作成処理は装飾用のコードが過多になりました。

 2026年2月2日は、ばってん少女隊の元メンバーの瀬田さくらさんの24歳のお誕生日です。

 卒業されてから1年が過ぎました。いかがお過ごしでしょうか。楽しい日々であればいいのですが。

 (新規記事を用意したかったけどダメでした。なんなら加筆修正も完全には間に合いませんでした。無念。)

【次節の内容】

  • 数式読解編

 多次元ガウスモデルに対するベイズ推論を数式で確認します。

www.anarchive-beta.com


  • スクラッチ実装編

 多次元ガウスモデルに対するベイズ推論をプログラムで確認します。

www.anarchive-beta.com




以上の内容はhttps://www.anarchive-beta.com/entry/2021/04/10/192703より取得しました。
このページはhttp://font.textar.tv/のウェブフォントを使用してます

不具合報告/要望等はこちらへお願いします。
モバイルやる夫Viewer Ver0.14