L-BFGS法の更新式を導出してみる。

ニュートン法の復習

最小化したい損失関数を$L(\boldsymbol{w})$とする。

\[\boldsymbol{w}_{k+1} = \boldsymbol{w}_{k} - \alpha_{k}\boldsymbol{H}_{k}^{-1}\boldsymbol{g}_{k}\]

ただし、$\alpha_{k}$はステップ幅、$\boldsymbol{g}_{k}$は勾配ベクトル、$\boldsymbol{H}_{k}$はヘッセ行列である。

詳しくは弊ブログのニュートン法の更新式を導出を参照。

BFGS法

まずBFGS法について。

セカント法 (secant method)によるヘッセ行列の近似

割線法ともいう。求根アルゴリズム(root-finding algorithm)の1種である。
関数$f(x)$が区間$[x_{n-2}, x_{n-1}]$で連続であり、かつ根が1つだけ存在する場合

\[x_{n+1} = x_{n} - \cfrac{x_{n}-x_{n-1}}{f(x_{n})-f(x_{n-1})} f(x_{n})\]

として$x$を更新していく。

セカント法を用いて$\boldsymbol{w}_{k}$の更新式を求めると

\[\boldsymbol{w}_{k+1} = \boldsymbol{w}_{k} - \cfrac{\boldsymbol{w}_{k}-\boldsymbol{w}_{k-1}}{\boldsymbol{g}_{k}-\boldsymbol{g}_{k-1}} \boldsymbol{g}_{k}\]

となる。これは

\[\boldsymbol{H}_{k} \approx \boldsymbol{B}_{k} = \cfrac{\boldsymbol{g}_{k}-\boldsymbol{g}_{k-1}}{\boldsymbol{w}_{k}-\boldsymbol{w}_{k-1}}\]

とヘッセ行列$\boldsymbol{H}_{k}$を正定値対称行列$\boldsymbol{B}_{k}$で近似していることになる。

(TODO: なぜ$\frac{\boldsymbol{g}_{k}-\boldsymbol{g}_{k-1}}{\boldsymbol{w}_{k}-\boldsymbol{w}_{k-1}}$が正定値対称行列になるのか?)

$\boldsymbol{H}_{k}$の更新

$\boldsymbol{s}_{k}=\boldsymbol{w}_{k+1}-\boldsymbol{w}_{k},\quad \boldsymbol{y}_{k}=\boldsymbol{g}_{k+1}-\boldsymbol{g}_{k}$とすると、 $\boldsymbol{H}_{k}$の更新式は、

\[\boldsymbol{H}_{k+1} = \left[\boldsymbol{I}-\cfrac{\boldsymbol{s}_{k}\boldsymbol{y}_{k}^{T}}{\boldsymbol{y}_{k}^{T}\boldsymbol{s}_{k}}\right] \boldsymbol{H}_{k} \left[\boldsymbol{I}-\cfrac{\boldsymbol{y}_{k}\boldsymbol{s}_{k}^{T}}{\boldsymbol{y}_{k}^{T}\boldsymbol{s}_{k}}\right] + \cfrac{\boldsymbol{s}_{k}\boldsymbol{s}_{k}^{T}}{\boldsymbol{y}_{k}^{T}\boldsymbol{s}_{k}}\]

と表される(TODO: 要導出)

$\boldsymbol{\rho}_{k}=\cfrac{1}{\boldsymbol{y}_{k}^{T}\boldsymbol{s}_{k}},\quad \boldsymbol{V}_{k}=\boldsymbol{I}-\cfrac{\boldsymbol{y}_{k}\boldsymbol{s}_{k}^{T}}{\boldsymbol{y}_{k}^{T}\boldsymbol{s}_{k}}$とすると、

\[\boldsymbol{H}_{k+1} = \boldsymbol{V}_{k}^{T} \boldsymbol{H}_{k} \boldsymbol{V}_{k} + \boldsymbol{\rho}_{k}\boldsymbol{s}_{k}\boldsymbol{s}_{k}^{T}\]

BFGS法ではヘッセ行列を陽に求める必要はなくなったが、
$\boldsymbol{H}_{k}$は対称行列なので$O(\frac{n^{2}-n}{2}+n)=O(\frac{n^{2}}{2}+\frac{n}{2})$のメモリ領域を必要とする。

LBFGS法

更新式を再帰的に展開すると、

\[\begin{align} \boldsymbol{H}_{k+1} =&\left(\boldsymbol{V}_{k}^{T}\cdots\boldsymbol{V}_{0}^{T}\right)\boldsymbol{H}_{0}\left(\boldsymbol{V}_{0}\cdots\boldsymbol{V}_{k}\right) \\ &+\rho_{0}\left(\boldsymbol{V}_{k}^{T}\cdots\boldsymbol{V}_{1}^{T}\right)\boldsymbol{s}_{0}\boldsymbol{s}_{0}^{T}\left(\boldsymbol{V}_{1}\cdots\boldsymbol{V}_{k}\right) \\ &+\rho_{1}\left(\boldsymbol{V}_{k}^{T}\cdots\boldsymbol{V}_{2}^{T}\right)\boldsymbol{s}_{1}\boldsymbol{s}_{1}^{T}\left(\boldsymbol{V}_{2}\cdots\boldsymbol{V}_{k}\right) \\ &\quad\vdots \\ &+\rho_{k}\boldsymbol{s}_{k}\boldsymbol{s}_{k}^{T} \end{align}\]

となる。
$\boldsymbol{H}_{0}$は正定値対角行列であり、 $k$ステップ分の$\boldsymbol{w}$と$\boldsymbol{g}$が必要になるのでメモリ使用量は$O((2k+1)n)$。
つまりステップ数が増えるにつれメモリ使用量も増えてしまう。

そこで、全ステップではなく直近$m$ステップの$\boldsymbol{w}$と$\boldsymbol{g}$を使用するようにする。
$m$は通常$<10$などの小さい値が取られる。

\[\begin{align} \boldsymbol{H}_{k+1} =&\left(\boldsymbol{V}_{k}^{T}\cdots\boldsymbol{V}_{k-m}^{T}\right)\boldsymbol{H}_{0}\left(\boldsymbol{V}_{k-m}\cdots\boldsymbol{V}_{k}\right) \\ &+\rho_{k-m}\left(\boldsymbol{V}_{k}^{T}\cdots\boldsymbol{V}_{k-m+1}^{T}\right)\boldsymbol{s}_{k-m}\boldsymbol{s}_{k-m}^{T}\left(\boldsymbol{V}_{k-m+1}\cdots\boldsymbol{V}_{k}\right) \\ &+\rho_{k-m+1}\left(\boldsymbol{V}_{k}^{T}\cdots\boldsymbol{V}_{k-m+2}^{T}\right)\boldsymbol{s}_{k-m+1}\boldsymbol{s}_{k-m+1}^{T}\left(\boldsymbol{V}_{k-m+2}\cdots\boldsymbol{V}_{k}\right) \\ &\quad\vdots \\ &+\rho_{k}\boldsymbol{s}_{k}\boldsymbol{s}_{k}^{T} \end{align}\]

直近$m$ステップのみなのでメモリ使用量は$O((2m+1)n)$に抑えられる。

参考文献・URL