Backpropしないニューラルネット入門 (2/2)

1. 概要

下記のarXiv論文を紹介します。

Jinshan Zeng, Tim Tsz-Kit Lau, Shaobo Lin, Yuan Yao (2018).
Block Coordinate Descent for Deep Learning: Unified Convergence Guarantees.arXiv:1803.00225

現時点では投稿されて間もない論文ですが、個人的には機械学習の論文を読んでいて久々に楽しい気持ちになれました。

論文の提案手法はgradient-free methodと呼ばれる手法の一種なので、本記事はそのあたりのレビューも少し兼ねます。

2. 勾配法の収束条件

ニューラルネットの構造をひとつ固定し、その構造を使って表せる関数の全体を $\mathcal{F}$ と書きます。ニューラルネットの学習とは、与えられた損失を最小化する関数を見つけることです。例えば、二乗損失なら
$$
\hat{f} \in \arg \min_{f \in \mathcal{F}} \sum_{i=1}^n (y_i - f(x_i))^2
$$
のようになります。Part 1で述べたとおり、この最適化問題が解けるという前提のもとで、いくつかの新しい汎化誤差バウンドが得られています。しかし、これは一般には非凸最適化となり、大域最適解へが得られることは保証できません。

現在、ニューラルネットの学習はSGD, AdaGrad, Adam, RMSPropといった勾配法ベースの最適化手法が使われていることが多く、主要な深層学習ライブラリも、勾配計算 (Backprop) の機能を提供することを主な目的としていると思います。これらの手法は、目的関数が凸であれば大域解への収束が保証されていることが多いです。

が、実際に興味のある目的関数は凸ではなく、もっというと、ReLUを使った場合は微分可能ですらありません。ちなみに、目的関数が凸でも、微分可能性のようなもの (正確にはsmoothness) は収束レートに関わることがあります。

また近年、ハードウェアへの移植などを目的として、Binarized Neural Networksのようにパラメータが離散的な値をとる深層学習モデルがあります。このようなモデルではそもそも勾配が存在しないため、勾配ベースの最適化アルゴリズムを直接適用することはできません。

3. Block Coordinate Descent (BCD, ブロック座標降下法)

だけど、

  • 目的関数が非凸でも
  • 微分不可能でも
  • 勾配を使わなくても
  • 初期値をどこに設定しても

局所解に収束することを理論的に示せるよ。ブロック座標降下法 ならね!!

3.1. アルゴリズムの概要

記号

$T$ 層のニューラルネットを学習する問題を抽象的に定義します。まず、

  • $d_0, d_1, \ldots, d_T$ : ニューラルネットの各層の幅。特に、$d_0$ は入力 $x$ の次元、$d_T$ はラベル $y$ の次元と一致
  • $W_i \in \mathbb{R}^{d_{i} \times d_{i - 1}}$ : 第 $i$ 層のパラメータ
  • $\sigma_i$ : 活性化関数

として、ニューラルネットで表現している関数は
$$
f(x) = W_T \sigma_{T - 1} (W_{T-1} \sigma_{T-2}(\cdots \sigma_{1}(W_1 x) \cdots ))
$$
であるとします。また、

  • $X := [x_1, \ldots, x_n] \in \mathbb{R}^{d_0 \times n}$ : 入力データ
  • $Y := [y_1, \ldots, y_n] \in \mathbb{R}^{d_T \times n}$ : ラベルデータ
  • $\mathcal{L}(f, y)$ : 損失関数 (非負・連続だが凸とは限らない)

として、経験損失を
$$
\ell(f(X); Y) := \frac{1}{n} \sum_{i = 1}^n \mathcal{L}(f(x_i), y_i)
$$
と定義します。$r_i, s_i$ を事前分布や正則化項を表す関数とします。このとき、興味のある最適化問題は、以下のような等式制約のある問題として書けます。

なお、本来最適化すべきパラメータは $W_i$ ($i = 1, \ldots, T$) ですが、スラック変数を導入して $(W_i, V_i, U_i)$ の三つ組を最適化するという意味で、論文中では上の最適化問題を3-splitting formulationと呼んでいます。

BCDアルゴリズム

BCDアルゴリズムのアイデアは簡単です。まず、上記の等式制約をFrobeniusノルムによるペナルティとして表現します。

拡張ラグランジュ関数 から線形のラグランジュ乗数を取り除いた格好ですが、なんと呼ぶのが適切かはちょっと知りません。二次バリア関数とか?

あとは、$V_i, U_i, W_i$ のそれぞれに関する最小化をサイクリックに行うだけです。論文からアルゴリズムを引用します。

それぞれの部分問題はちょうど近接作用素を求めるような形になっており、例えば「損失 $\ell$ が2次関数、$\sigma$ がReLU、$s_i$, $r_i$ が0」といった典型的な場合ではclosed-formの解が得られます (e.g. Lemma 1).

ニューラルネットがconvolution層を含む場合も等式制約を追加すればよいので、アルゴリズムを拡張することは容易であることが述べられています。また、ResNetに拡張した例がAlgorithm 2として提案されています。

なお、後述するように、このアルゴリズムには理論収束保証もあります。しかし、オリジナルの問題についてではなく、この $\bar{\mathcal{L}}$ が局所解に収束することを保証するものですので、その点に関しては注意が必要です。

3.2. 数値例

論文の数値例を紹介します。追実験などはしていません。

$T = 4$ 層のDNNでMNISTとCIFAR-10の学習を、BCDおよびSGDで行った結果の引用です。NNの構造やSGDの学習率の詳細は論文を参照してください。

実時間ではなくエポック数の比較ですが、SGDが学習に約100エポック必要とするのに対し、BCDは5エポック以下(!)で学習がほとんど終了していることがわかります。また、最終的なtest accuracyもSGDと比較して遜色ないものとなっています (Table 1).

定式化から明らかなようにBCDの1回の更新にかかるコストは大きく、メモリに乗らない量のデータをどう扱うかなど疑問点はありますが、promissingな結果であることは間違いないと思われます。

3.3. その他のgradient-free method / 先行研究

先行研究に少し触れておきます。(このあたりはまだ勉強中です)。

ニューラルネットの学習をgradient-freeに行うアルゴリズムとしては、 ADMMベースの手法[1],[2],BCDベースの手法[3]が提案されています。Gradient-freeな方法の動機として、紹介論文や [1] では勾配法との比較が述べられています。計算面での主な違いは、勾配法は小さなミニバッチを使った低コストな更新を多数回行うのに対し、勾配フリー法は多くのデータを使った高コストな更新を小数回行う手法になっており、前者はGPU、後者はCPU上での実行が適しています。勾配法の収束については、勾配消失やプラトーなど既知の問題がありますが、勾配フリー法では紹介論文のような収束保証が得られています。

[1] Gavin Taylor, Ryan Burmeister, Zheng Xu, Bharat Singh, Ankit Patel, Tom Goldstein (2016).
Training Neural Networks Without Gradients: A Scalable ADMM Approach. In ICML2016. arXiv:1605.02026

[2] Ziming Zhang, Yuting Chen, Venkatesh Saligrama (2016).
Efficient Training of Very Deep Neural Networks for Supervised Hashing. In CVPR2016. arXiv:1511.04524

[3] Ziming Zhang, Matthew Brand (2017).
Convergent Block Coordinate Descent for Training Tikhonov Regularized Deep Neural Networks. In NIPS2017. arXiv:1711.07354

4. 収束保証

BCDアルゴリズムでは、$\bar{\mathcal{L}}$ が局所解に収束することが、比較的広い条件のもとで示されています。

4.1. Kurdyka-Łojasiewicz (KL)条件

最近の非凸最適化の論文では、目的関数に課す正則条件としてKurdyka-Łojasiewicz (KL) propertyというものがトレンドになっているようです。まず定義を書きます

記号

(凸とは限らない) 関数 $f: \mathbb{R}^d \to \mathbb{R}\cup \{ + \infty \}$ の $x \in \mathbb{R}^d$ における劣勾配 $g \in \partial f(x)$ とは $(x_k, g_k) \to (x, g)$ となる点列が存在して、$f(x_k) \to f(x)$ かつ
$$
f(x^\prime) \geq f(x_k) + \langle g_k, x^\prime - x_k \rangle + o(\lVert x^\prime - x_k \rVert_2)
$$
が成り立つことをいう (正確な定義は この文献 など)。劣勾配の集合 $\partial f(x)$ を $x$ における $f$ の劣微分という.

$\partial f(x)$ が空でない $x \in \mathbb{R}^d$ の集合を $\mathrm{dom}(\partial f)$ と書く。また、
$$
\mathrm{dist}(0, \partial f(x)) := \min\{ \lVert g \rVert_2: g \in \partial f(x) \}
$$
とする。

定義:KL条件

関数 $f: \mathbb{R}^d \to \mathbb{R} \cup \{ + \infty \}$ が点 $x^* \in \mathrm{dom}(\partial f)$ においてKL propertyをもつとは、3つ組

  • $\eta \in (0, +\infty]$
  • $x^*$ の近傍 $U$
  • 凹関数 $\varphi: [0, \eta) \to \mathbb{R}_+$

が存在して以下の条件を満たすことをいう:

  1. $\varphi(0) = 0$, $\varphi$ は $(0, \eta)$ で $C^1$ 級
  2. 任意の $s \in (0, \eta)$ に対して $\varphi^\prime(s) > 0$
  3. 任意の $x \in U \cap \{ x: f(x^*) < f(x) < f(x^*) + \eta \}$ に対して、次の Kurdyka-Łojasiewicz不等式 が成り立つ
    $$
    \varphi^\prime(f(x) - f(x^*)) \mathrm{dist}(0, \partial f(x)) \geq 1.
    $$

KL不等式が何を言っているかというのが問題ですね。KL条件の意味についてはこのスライド がわかりやすいです。(確か、以前tmaeharaさんに教えていただいたものです)

まず、$0 \in \partial f(x)$ というのは $x$ が $f$ の停留点であるということなので、逆に $\mathrm{dist}(0, \partial f(x))$ が大きいならば点 $x$ において $f$ が大きく傾いているということを表します。もし $f$ の等高線でスライスした領域 $\{ x: a < f(x) < b \}$ において一様に $\mathrm{dist}(0, \partial f(x)) \geq c > 0$ が成り立つのであれば、関数の「谷」のような部分に向かって急に落ち込んでいるものと考えられます。

(図は Bolte(2015) から引用)

この性質が成り立つとき、$f$ はこのスライスにおいてsharpである、ということにします。もし $f$ がsharpであれば、近接点法のようなアルゴリズムは有限ステップで「谷底」のような部分に到達できます (ラフな証明は Bolte(2015) p.17)。

KL条件は、本質的には
$$
\mathrm{dist}(0, \partial(\varphi\circ (f(x) - f(x^*)))) \geq 1
$$
と等価です。これは、凹関数 $\varphi$ によって、等高線のスライスを高さ $f(x^*)$ の「谷底」に近づける速さをコントロールしたときにsharpと見做せる、ということを述べています。よって、なんとなくですが、$f$ 自身が凸でなくても、谷底に到達する何らかの手段は手に入るような感じがします。

KL条件を満たす関数

重要な点は、実はかなり多くの関数がKL条件を満たすことがわかります。というか、「KL条件を満たす関数がたくさんある」ということそのものがこの理論の根幹をなす結果になっているようです。

  • 解析関数はKL条件を満たす。
  • グラフが半代数的集合になる関数を半代数的関数 (semialgebraic function) という。すべての半代数的関数はKL条件を満たす。
  • 半代数的関数は有限個の和や積、合成などの操作で閉じる。

例えば、2次関数やReLUなどは半代数的関数なのでKL条件を満たします。よって、それらの和・積・合成によってできている訓練損失もKL条件を満たすことが示せます (Proposition 1)。

KL条件から何が言えるか

ずばり、近接点法や座標降下法のようなアルゴリズムの収束保証と収束レートの導出に使えます[4] [5]。具体的には、近接点法では $\varphi(s) = c s^(1- \theta)$ ($\theta \in [0, 1)$) のような形をしているとき、$\theta = 0$ なら有限ステップで停留点に収束、$0 < \theta \leq 1/2$ ならば誤差が $O(\exp(-k))$ で収束、$1/2 < \theta < 1$ ならば $O(k^{-(1 - \theta)/(2\theta - 1)})$ で収束、という結果が得られており [4]、紹介論文でもこの結果に帰着することで収束レートを出しています。

[4] Hedy Attouch, Jérôme Bolte (2009).
On the convergence of the proximal algorithm for nonsmooth functions involving analytic features. Mathematical Programming, 116, 5–16. プレプリント

[5] Hedy Attouch, Jérôme Bolte, Benar Fux Svaiter (2013).
Convergence of descent methods for semi-algebraic and tame problems: proximal algorithms, forward–backward splitting, and regularized Gauss–Seidel methods. Mathematical Programming, 137, 91–129. URL

4.2. BCDの収束

DNNに対するBCDの収束は次のことが示せます。($\bar{\mathcal{L}}$ の収束であることに注意)

仮定

  • $i = 1, \ldots, T-1$ について、活性化関数 $\sigma_i$ は $L$-Lipschitzである
  • 関数 $\bar{\mathcal{L}}$ はある点でKL条件を満たす。特に、(a) $s_i, r_i$ が半代数的関数であり、(b) 損失関数が2乗損失、ロジスティック損失、ヒンジ損失のいずれかであり、(c) 活性化関数がReLU、シグモイド、線形リンクのいずれかであれば、どの組み合わせでもこの条件は満たされる。

Theorem 2 (要点のみ)
上の仮定が成り立つとする。BCDアルゴリズムの第 $k$ ステップでの出力を $P^k$ とする。

  1. BCDアルゴリズムにおいて、$\mathrm{dist}^2(0, \partial \bar{\mathcal{L}}(P^k))$ のrunning best rate ($k$ステップまでの最良値) は $o(1/k)$ である。
  2. 関数 $\bar{\mathcal{L}}$ がある点において、$\varphi(s) = c s^{1 - \theta}$ に対してKL条件を満たすならば、停留点への収束レートが得られる。

5. 結論

現在、「GPUに廃課金して勾配法で殴る」という手法が主流になっていると思われますが、これを機に今後はgradient-freeな手法にも期待できるかもしれません。

(参考)