L-BFGS法の更新式を導出してみる。

ニュートン法の復習

最小化したい損失関数をL(w)とする。

wk+1=wkαkHk1gk

ただし、αkはステップ幅、gkは勾配ベクトル、Hkはヘッセ行列である。

詳しくは弊ブログのニュートン法の更新式を導出を参照。

BFGS法

まずBFGS法について。

セカント法 (secant method)によるヘッセ行列の近似

割線法ともいう。求根アルゴリズム(root-finding algorithm)の1種である。
関数f(x)が区間[xn2,xn1]で連続であり、かつ根が1つだけ存在する場合

xn+1=xnxnxn1f(xn)f(xn1)f(xn)

としてxを更新していく。

セカント法を用いてwkの更新式を求めると

wk+1=wkwkwk1gkgk1gk

となる。これは

HkBk=gkgk1wkwk1

とヘッセ行列Hkを正定値対称行列Bkで近似していることになる。

(TODO: なぜgkgk1wkwk1が正定値対称行列になるのか?)

Hkの更新

sk=wk+1wk,yk=gk+1gkとすると、 Hkの更新式は、

Hk+1=[IskykTykTsk]Hk[IykskTykTsk]+skskTykTsk

と表される(TODO: 要導出)

ρk=1ykTsk,Vk=IykskTykTskとすると、

Hk+1=VkTHkVk+ρkskskT

BFGS法ではヘッセ行列を陽に求める必要はなくなったが、
Hkは対称行列なのでO(n2n2+n)=O(n22+n2)のメモリ領域を必要とする。

LBFGS法

更新式を再帰的に展開すると、

Hk+1=(VkTV0T)H0(V0Vk)+ρ0(VkTV1T)s0s0T(V1Vk)+ρ1(VkTV2T)s1s1T(V2Vk)+ρkskskT

となる。
H0は正定値対角行列であり、 kステップ分のwgが必要になるのでメモリ使用量はO((2k+1)n)
つまりステップ数が増えるにつれメモリ使用量も増えてしまう。

そこで、全ステップではなく直近mステップのwgを使用するようにする。
mは通常<10などの小さい値が取られる。

Hk+1=(VkTVkmT)H0(VkmVk)+ρkm(VkTVkm+1T)skmskmT(Vkm+1Vk)+ρkm+1(VkTVkm+2T)skm+1skm+1T(Vkm+2Vk)+ρkskskT

直近mステップのみなのでメモリ使用量はO((2m+1)n)に抑えられる。

参考文献・URL