Contrastive Divergence 法
Contrastive Divergence法 (Hinton, 2002)について少し勉強したので,そのまとめです.
Contrastive Divergence 法とは
(確率的な)最適化方法です.正確には正規化定数が分からない(求めるのが困難)確率分布のためのパラメータの最尤法です.特にBoltzmann分布(マルコフ確率場)における最尤法を指します.
何がうれしいか
正規化定数が分からない確率分布に対しても最尤推定量に近い推定量が得られることが利点です.
Boltzmann分布の例を挙げます.Boltzmann分布とは
\[
p(\boldsymbol{x};\boldsymbol{\theta},\boldsymbol{W})=\frac{1}{Z(\boldsymbol{\theta},\boldsymbol{W})}\exp(-\boldsymbol{\theta}^{\top}\boldsymbol{x}-\boldsymbol{x}^{\top}\boldsymbol{W}\boldsymbol{x})
\]
という確率分布です.ここで$\boldsymbol{x}=(x_1,\ldots,x_d)\in\{0,1\}^d$, $\boldsymbol{\theta}\in\mathbf{R}^d$, $\boldsymbol{W}\in\mathbf{R}^{d\times d}$であり,
\[
Z(\boldsymbol{\theta},\boldsymbol{W})=\sum_{\boldsymbol{x}\in\{0,1\}^d}\exp(-\boldsymbol{\theta}^{\top}\boldsymbol{x}-\boldsymbol{x}^{\top}\boldsymbol{W}\boldsymbol{x})
\]
は正規化定数です.また
\[
E(\boldsymbol{x};\boldsymbol{\theta},\boldsymbol{W})=\boldsymbol{\theta}^{\top}\boldsymbol{x}+\boldsymbol{x}^{\top}\boldsymbol{W}\boldsymbol{x}
\]
はエネルギー関数と呼ばれます.注意すべきは$d$が比較的多きときに$Z(\boldsymbol{\theta},\boldsymbol{W})$の計算は非常に大変であるということです.例えば$d=10$のときには$
Z(\boldsymbol{\theta},\boldsymbol{W})$を計算するのに$\exp(-\boldsymbol{\theta}^{\top}\boldsymbol{x}-\boldsymbol{x}^{\top}\boldsymbol{W}\boldsymbol{x})$を1024回足さなければなりません.$d=20$なら約100万回足さなければなりません.
Contrastive Divergence 法はどういう方法なのか
この確率分布からの標本$\boldsymbol{x}_1,\ldots,\boldsymbol{x}_n$が得られたときに,最尤推定量$\hat{\boldsymbol{\theta}}_{\rm mle}, \hat{\boldsymbol{W}}_{\rm mle}$を数値的に求めるには対数尤度 $\log p(\boldsymbol{x};\boldsymbol{\theta},\boldsymbol{W})$の勾配が必要です.例えば$\boldsymbol{\theta}=(\theta_1,\ldots,\theta_d)^{\top}$の勾配を計算すると
\[
\frac{\partial}{\partial \theta_i}\log p(\boldsymbol{x};\boldsymbol{\theta},\boldsymbol{W})=-x_i-\frac{\partial}{\partial \theta_i}\log Z(\boldsymbol{\theta},\boldsymbol{W})
\]
となり$Z(\boldsymbol{\theta},\boldsymbol{W})$の計算が必要となります.最尤推定量の計算には勾配法が基本的ですが,Boltzamnn分布の対数尤度の勾配を計算するのは困難であることが上の式からわかります.
そこでContrastive Divergence 法(CD法)が考え出されました(Hinton, 2002).Contrastive Divergence法の考え方はシンプルです.
表記の簡便のため,パラメータ$\boldsymbol{\xi}=(\boldsymbol{\theta},\boldsymbol{W})$とまとめておきます.$\xi_i$方向の勾配は
\[
\frac{\partial}{\partial \xi_i}\log p(\boldsymbol{x};\boldsymbol{\theta},\boldsymbol{W})=-\frac{\partial}{\partial \xi_i}E(\boldsymbol{x};\boldsymbol{\theta},\boldsymbol{W})-\frac{\partial}{\partial \xi_i}\log Z(\boldsymbol{\theta},\boldsymbol{W})
\]
移項して
\[
\frac{\partial}{\partial \xi_i}\log Z(\boldsymbol{\theta},\boldsymbol{W})=-\frac{\partial}{\partial \xi_i}\log p(\boldsymbol{x};\boldsymbol{\theta},\boldsymbol{W})-\frac{\partial}{\partial \xi_i}E(\boldsymbol{x};\boldsymbol{\theta},\boldsymbol{W})
\]
左辺は$\boldsymbol{x}$に依らないです.両辺を$\boldsymbol{x}$ ($\sim p(\boldsymbol{x};\boldsymbol{\theta},\boldsymbol{W})$)で期待値をとって
\[
\frac{\partial}{\partial \xi_i}\log Z(\boldsymbol{\theta},\boldsymbol{W})=-\mathrm{E}_{\tilde{\boldsymbol{x}}\sim p(\tilde{\boldsymbol{x}};\boldsymbol{\theta},\boldsymbol{W})}\bigg[\frac{\partial}{\partial \xi_i}E(\tilde{\boldsymbol{x}};\boldsymbol{\theta},\boldsymbol{W})\bigg]
\]
この関係式を代入して,勾配は
\[
\frac{\partial}{\partial \xi_i}\log p(\boldsymbol{x};\boldsymbol{\theta},\boldsymbol{W})=-\frac{\partial}{\partial \xi_i}E(\boldsymbol{x};\boldsymbol{\theta},\boldsymbol{W})+\mathrm{E}_{\tilde{\boldsymbol{x}}\sim p(\tilde{\boldsymbol{x}};\boldsymbol{\theta},\boldsymbol{W})}\bigg[\frac{\partial}{\partial \xi_i}E(\tilde{\boldsymbol{x}};\boldsymbol{\theta},\boldsymbol{W})\bigg]
\]
となります.このままでは$\mathrm{E}_{\tilde{\boldsymbol{x}}\sim p(\tilde{\boldsymbol{x}};\boldsymbol{\theta},\boldsymbol{W})}$の計算の際に$2^d$回の和が必要なので,計算が困難なことには変わりありません.
そこでCD法では期待値の計算をモンテカルロ近似します.つまり$p(\tilde{\boldsymbol{x}};\boldsymbol{\theta},\boldsymbol{W})$に従うサンプルを$k$個$(\tilde{\boldsymbol{x}}_1,\ldots,\tilde{\boldsymbol{x}}_k)$サンプリングし,
\[
\mathrm{E}_{\tilde{\boldsymbol{x}}\sim p(\tilde{\boldsymbol{x}};\boldsymbol{\theta},\boldsymbol{W})}\bigg[\frac{\partial}{\partial \xi_i}E(\tilde{\boldsymbol{x}};\boldsymbol{\theta},\boldsymbol{W})\bigg]\approx
\frac{1}{k}\sum_{m=1}^k \frac{\partial}{\partial \xi_i}E(\tilde{\boldsymbol{x}}_m;\boldsymbol{\theta},\boldsymbol{W})
\]
とします.これにより勾配の計算に必要な計算コストををおさえます.
サンプリング方法はマルコフ連鎖モンテカルロ法(MCMC),特にギブスサンプリングです.*1
CD法はMCMCが収束するように十分遷移させてサンプルするのではなく,少ない回数だけ遷移させます.
初期値は経験分布です.毎回初期値を経験分布にリセットするのはもったいない気がします.Persistent Cotrastive Divergenceという方法もあるみたいで,これはパラメータを更新するたびにMCMCの初期値を前のステップでMCMCで得た分布を使うものです.
まとめるとCD法は以下のような手順で最尤推定量を求める方法です.
- 初期値 $\boldsymbol{\xi}_0=(\boldsymbol{\theta}_0,\boldsymbol{W}_0)$ を決める
- 勾配 $\frac{\partial}{\partial \boldsymbol{\xi}}\log p(\boldsymbol{x};\boldsymbol{\theta}_k,\boldsymbol{W}_k)$ の近似値をモンテカルロ法で計算
(初期値は経験分布,MCMCの遷移回数は少ない)
- $\boldsymbol{\xi}_{k+1} \leftarrow \boldsymbol{\xi}_{k} + \gamma_k \times \frac{\partial}{\partial \boldsymbol{\xi}}\log p(\boldsymbol{x};\boldsymbol{\theta}_k,\boldsymbol{W}_k)$ ($\gamma_k$: ステップ幅)
- 停止条件を満たしてなければステップ2へ
少し考えると,この方法では勾配の計算の際,少ないMCMCの遷移回数で近似することにしています.
単純に考えると,MCMCは十分多く遷移させると所望の確率分布に収束する方法なので少数の遷移回数でよいのかは疑問です.
しかし,実は近似のための遷移回数が少なくてもCD法はうまく動くということが経験的に知られています.
自分としては近似のための遷移回数が少なくてもCD法はうまく動くというのが不思議だったので,少し勉強&検証してみました.
数値実験
まずは数値実験をして検証してみました.
設定としては$d=3$の小さい次元の場合で検証しました.サンプルサイズ$10,100,200$ごとに100回
$\boldsymbol{\theta}=(\theta_1,\theta_2,\theta_3)$を平均3,分散1の正規分布から,
$\boldsymbol{W}$の上三角かつ非対角成分$(W_{12},W_{13},W_{23})$を平均-3,分散1の正規分布から,
発生させ,それを真の値とするBoltzmann分布からサンプルサイズ分だけ標本を得ます($\boldsymbol{W}$の他のパラメータの値は全て0というモデルを考えています).
各回ごとに,最尤推定量,およびCD1,CD3を計算します.
ここで,CD1,CD3は共にCD法で,それぞれモンテカルロ近似の際に用いた遷移の数が1または3であることを意味します.
CD1,CD3で得られた推定量と最尤推定量との距離を平均二乗誤差$\|\hat{\boldsymbol{\xi}}_{\rm mle}-\hat{\boldsymbol{\xi}}_{\rm cd}\|_2$で計算しました
(各サンプルサイズごとに100回CD1とCD3の平均二乗誤差が計算されます).
結果は以下の通りです.コードは最後に回します.
この結果からCD3の方がCD1よりも最尤推定量に近い推定量を得られることが分かります.
CD3の方が最尤推定量からのばらつきが小さいです.
またサンプルサイズが大きくなるにつれてCD1もCD3も最尤推定量との距離が近くなっています.
個人的にはCD3でも結構最尤推定量に近いのが意外でした.
なぜ少ないサンプルでも良いのか?
やはり経験的に知られているように,今回の数値実験でも
少ない遷移回数による勾配の近似計算でも最尤推定量に近い推定量が得られました.
調べている内に次の文献にたどり着きました:
S. Akaho and K. Takabatake. (2008).
Information geometry of contrastive divergence.
International Conference on Information theory and statistical learning.
情報幾何では確率分布を多様体の一点としてとらえます.
沢山の遷移させなければMCMCで得られる確率分布(1点)はモデルの確率分布(1点)には収束しないですが,
経験分布からMCMCを数ステップ回して得られる確率分布への"方向"が
経験分布からモデル(多様体)への射影の方向に近いためうまくいくということになります.
CD法ではないですが
K. Takabatake. (2004).
Information geometry of Gibbs sampler.
Proceedings of the WSEAS International Conferences.
も面白いです.Gibbs samplerが貪欲的に経験分布からの距離を最小化しているというのは知りませんでした.
コード
#contrastive divergence #maximum likelihood estimation #Boltzmann distribution ####################### #integer to binary dec2bin <- function(num, digit=0){ if(num <= 0 && digit <= 0){ return(NULL) }else{ return(append(Recall(num%/%2,digit-1), num%%2)) } } ######################## #Boltzmann distribution energy_boltz <- function(state, theta, W){ #state: binary vector (length=dim(theta)) #Energy function: Energy(state)=state%*%theta+state%*%W.mat%*%state #prob of state: p(state)=exp(-Energy(state))/Z #Z: normalization const #W: W[1]=W_{12}, W[2]=W_{23},... vardim <- length(theta) W.mat <- matrix(0, ncol=vardim, vardim) W.mat[upper.tri(W.mat)] <- W exp(-state%*%theta-state%*%W.mat%*%state) } dboltz <- function(theta, W){ #distribution function of Boltzmann #p(x|theta,W) = exp(-Energy(state))/Z vardim <- length(theta) states <- 2^vardim probs <- sapply(0:(states-1),dec2bin,digit=vardim) probs <- rbind(probs, apply(probs,2,energy_boltz,theta=theta,W=W)) probs[(vardim+1),] <- probs[(vardim+1),]/sum(probs[vardim+1,]) #normalization t(probs) } rboltz <- function(n, theta, W){ #random n samples from Boltzmann probs <- dboltz(theta, W) vardim <- length(theta) states <- 2^vardim samples <- matrix(0, ncol=vardim, nrow=n) sample10 <- sample(x=0:(states-1),size=n,prob=probs[,vardim+1],replace=T) #sampling samples <- sapply(sample10,dec2bin,digit=vardim) #integer to binary t(samples) } ######################## #Maximum Likelihood Estimation #gradient ascent likegrad <- function(samples,theta,W,gradtheta.emp,gradW.emp){ thetadim <- length(theta) Wdim <- length(W) offdiag <- t(combn(thetadim,2)) #off-diagonal index of W gradtheta <- rep(0, thetadim) gradW <- rep(0, Wdim) probs <- dboltz(theta,W) gradtheta <- -gradtheta.emp+apply(probs[,1:thetadim]*probs[,thetadim+1],2,sum) #gradient of theta gradW.model <- sapply(1:Wdim, function(i) sum(apply(probs[,c(offdiag[i,],(thetadim+1))],1,prod))) #mean of gradient of W gradW <- -gradW.emp + gradW.model return(c(gradtheta, gradW)) } findmle <- function(samples,theta.ini,W.ini,lrate=1,maxiter=100){ pars <- c(theta.ini, W.ini) thetadim <- length(theta.ini) offdiag <- t(combn(thetadim,2)) Wdim <- length(W.ini) theta <- theta.ini W <- W.ini gradtheta.emp <- apply(samples,2,mean) #empirical mean of gradient of theta gradW.emp <- sapply(1:Wdim, function(i) mean(apply(samples[,offdiag[i,]],1,prod))) #empirical mean of gradient of W for(i in 1:maxiter){ pars <- pars + likegrad(samples,theta,W,gradtheta.emp,gradW.emp)*lrate #gradient ascent update theta <- head(pars,thetadim) W <- head(pars,Wdim) } rtrn.ls <- list(theta, W) names(rtrn.ls) <- c("theta.mle", "W.mle") return(rtrn.ls) } ######################## #Contrastive Divergence Method gibbsupdate <- function(state,theta,W,mcmciter){ dimstate <- length(state) state.now <- state #Gibbs sampler update #update only one component of state for(i in 1:mcmciter){ state.new <- state.now updateindex <- sample(dimstate,1) state.new[updateindex] <- 1-state.now[updateindex] accept <- energy_boltz(state.new,theta,W) reject <- energy_boltz(state.now,theta,W) acceptflag <- sample(x=1:0,size=1,prob=c(accept,reject)) state.now <- if(acceptflag) state.new else state.now } state.now } gibbslikegrad <- function(samples,theta,W,gradtheta.emp,gradW.emp,mcmciter){ thetadim <- length(theta) Wdim <- length(W) offdiag <- t(combn(thetadim,2)) gradtheta <- rep(0, thetadim) gradW <- rep(0, Wdim) mcmcsamples <- t(apply(samples,1,function(state)gibbsupdate(state,theta,W,mcmciter))) #samples from Gibbs sampler gradtheta <- -gradtheta.emp+apply(mcmcsamples,2,mean) gradW.model <- sapply(1:Wdim, function(i) mean(apply(mcmcsamples[,offdiag[i,]],1,prod))) gradW <- -gradW.emp + gradW.model return(c(gradtheta, gradW)) } findCD <- function(samples,theta.ini,W.ini,lrate=1,maxiter=100,mcmciter=1){ pars <- c(theta.ini, W.ini) thetadim <- length(theta.ini) offdiag <- t(combn(thetadim,2)) Wdim <- length(W.ini) theta <- theta.ini W <- W.ini gradtheta.emp <- apply(samples,2,mean) gradW.emp <- sapply(1:Wdim, function(i) mean(apply(samples[,offdiag[i,]],1,prod))) for(i in 1:maxiter){ pars <- pars + gibbslikegrad(samples,theta,W,gradtheta.emp,gradW.emp,mcmciter)*lrate #gradient ascent theta <- head(pars,thetadim) W <- tail(pars,Wdim) } rtrn.ls <- list(theta, W) names(rtrn.ls) <- c("theta.cd", "W.cd") return(rtrn.ls) } mle_to_cd <- function(n, n.sample, dim){ print(n) dimW <- (dim-1)*dim/2 risk <- rep(0, 2) theta.true <- rnorm(dim, mean=3) W.true <- rnorm(dimW, mean=-3) par.true <- c(theta.true, W.true) names(par.true) <- c("theta.true", "W.true") samples <- rboltz(n.sample,theta.true,W.true) result.mle <- findmle(samples,theta.ini=rep(0,dim),W.ini=rep(0,dimW),maxiter=500) result.cd1 <- findCD(samples,theta.ini=rep(0,dim),W.ini=rep(0,dimW),maxiter=500) result.cddim <- findCD(samples,theta.ini=rep(0,dim),W.ini=rep(0,dimW),maxiter=500,mcmciter=dim) #Mean Squared Errors risk[1] <- sqrt(sum((unlist(result.mle)-unlist(result.cd1))^2)/(dim+dimW)) risk[2] <- sqrt(sum((unlist(result.mle)-unlist(result.cddim))^2)/(dim+dimW)) risk }