はじめに
ニューラルネットワークの強みは「特徴学習」を行うことにある。特徴学習とは、データから適応的に基底関数(特徴)を学習すること指す。カーネル法のような従来手法では基底関数の学習をしなかったわけだが、データに適応的な基底をとることによって精度の面でも学習速度の面でもしばしば得をすることが経験的にわかっている。これらの事実に対して理論的に気になるのは
- ニューラルネットは勾配降下法によって証明可能な特徴学習をするか?
- 勾配ベースの学習はどれだけ効率がよいか?最適か?
という点である。
この文章では、最近のニューラルネットにおける特徴学習研究の最前線を追ってみてまとめたものを紹介してみる。わかったことは、ニューラルネットはいい感じのアルゴリズムを使えば理論的にもよいサンプル・計算複雑性を達成するということである。ここでは特に、学習対象としてSingle Index Modelとよばれるものを利用したものを取り扱う。ベースは[1]の鈴木先生のスライドに基づくが、原論文を追って自分なりに再構築したものである。
Single Index Modelの定義と問題設定
ニューラルネットワークの特徴学習能力について調べるために、しばしば学習対象として使われるのがSingle Index Modelとよばれるものである:
$$ f _ \ast(\boldsymbol{x})=σ _ \ast(\langle \boldsymbol{\theta},\boldsymbol{x}\rangle) $$
多くの場合与える入力データは $\boldsymbol{x}\sim\mathcal{N}(0,I _ d)$ ととられる。推定する必要があるのは $\boldsymbol{\theta}\in\mathbb{R} ^ d$ とリンク関数 $σ _ \ast$ である。多くの場合、情報指数(詳しくは後述)が $k _ \ast$ となるような多項式がとられる:
$$ σ _ \ast(\boldsymbol{x})=\sum _ {i=k _ \ast} ^ p \alpha _ i\mathrm{He} _ i(\boldsymbol{x}) $$
この場合 $i<k _ \ast$ については $\alpha _ i=0$ であり、はじめて $\alpha _ i\neq 0$ となるのが $k _ \ast$ である。なお、リンク関数は非線形なReLU関数に設定する場合もあるが、ここでは多項式のみを考えることにする。また、観測する出力としては $f _ \ast(\boldsymbol{x})$ にノイズが乗ったものを考える場合もあるが、ここであまり深く考えないことにする。
簡単のために、自分が準備する学習関数は $f(\boldsymbol{x})=σ(\langle \boldsymbol{x},\boldsymbol{w}\rangle)$ とした設定を考える。ここで $σ$ は次のように書けるとする。
$$ σ(\boldsymbol{x})=\sum _ {i=0} ^ \infty \beta _ i\mathrm{He} _ i(\boldsymbol{x}) $$
学習するパラメータは $\boldsymbol{w}$ である。 $\lbrace \beta _ i\rbrace $ を2層目とみるなら、2層目を固定したニューラルネットワークとも見れるかもしれない。学習パラメータはたかが1層であるが、活性化関数の中身を学習するという意味では、どのように基底を学習するかを解析するためのトイモデルとなっている。
この文章を通して考えたい問いは、このSingle Index Modelを学習するために、どれくらいのサンプル数 $n$ が必要か?ということである。ここでの学習は、学習対象とほぼ無相関な初期状態から、それなりに相関をもつような状態へ移り変わることを指している(weak recovery)。
学習に必要なサンプル数の情報理論的下限
ここではアルゴリズムの計算時間のことを考えず、ベースラインとして最低限サンプルがこれだけあれば何らかの方法で学習できるという知られた事実を紹介する。まず、カーネル法の場合次の事実が知られている。
一方、より広いクラスでは次の事実が知られており、ニューラルネットワークもこれに該当する。
この必要サンプル数の乖離が特徴学習の恩恵であるといえる。ただし、一般の $σ _ \ast$ に対して何らかのアルゴリズムで学習したとき、非凸最適化を行うこともあり、サンプル数は少なくても指数的な計算量がかかる可能性があることに注意。
ここで問題となるのは、これらの情報理論的下限を達成するような多項式時間アルゴリズムが実装できるか?ということである。
(Correlational) Statistical Queryの枠組みによる計算複雑さの下限
多項式時間で関数を学習するにあたっては少なくともこれだけのサンプルが必要あればよい、という計算量的な下限はSQ学習、CSQ学習の枠組みを利用することで(ヒューリスティックに)捉えることができる。(何かの論文で以下に類似の定義を見てメモしたのたが、どこで見たか忘れてしまった)
余談だが、この枠組みはしばしば差分プライバシーの文脈で見ることもあり、PAC学習とも関連があるようだ。
ここで二乗和誤差での勾配降下を思い出す。このとき
$$ \mathbb{E}[\nabla l(\boldsymbol{w})]=\frac{1}{2}\mathbb{E}[\nabla(y-f(\boldsymbol{x})) ^ 2]=\mathbb{E}[y\nabla f(\boldsymbol{x})]-\mathbb{E}[f(\boldsymbol{x})\nabla f(\boldsymbol{x})] $$
となる。
- $\mathbb{E}[y\nabla f(\boldsymbol{x})]$ この項は、未知のターゲット $y$ の情報を含む。これは、$\tilde{\phi}(x) = \nabla _ {\boldsymbol{w}} f _ {\boldsymbol{w}}(\boldsymbol{x})$とおくことで、まさにCSQクエリの形 $\mathbb{E}[\tilde{\phi}(\boldsymbol{x})y]$ に対応する。学習アルゴリズムは、CSQオラクルに問い合わせることでこの値(に最大で $τ$ だけ誤差ののった値)を得る。
- $\mathbb{E}[f(\boldsymbol{x})\nabla f(\boldsymbol{x})]$ : この項にはターゲット $y$ が含まれていないので、学習アルゴリズムは(入力データ $\boldsymbol{x}$ の分布に関する情報があれば)この期待値を自分自身で評価できる。
この観察から、(ワンパスの)SGDはCSQのフレームワークで捉えることができる。
当然だが、現実世界にCSQオラクルは存在しないので、現実世界ではこのCSQオラクルを模倣するシミュレータを使うことになる。
シミュレータのうち最も単純なものとして、我々は $n$ 個のサンプルを使って $\frac{1}{n}\sum _ {i=1} ^ ny _ i\nabla f (\boldsymbol{x} _ i)$ を返すものを考えることができる。サンプル $n$ 個を利用して近似するため、このシミュレータは測度の集中の考え方から $τ=O(n ^ {-1/2})$ となるようなCSQオラクルと同じものとみなせる。別のシミュレータを考えることもできる(他の手法により)が、累積される誤差を考えれば結局同じようなことがいえるはずである。したがって、この枠組みで学習複雑度を議論する論文では $τ=O(n ^ {-1/2})$ とヒューリスティックに定めたCSQオラクルを使うような学習者が、学習対象をどれだけのサンプル数で学習可能といえるかが議論される。これらの証明の中では、学習が可能となるような $τ$ の条件を $d$ の関係する不等式で与えることができるため、 $τ=O(n ^ {-1/2})$ の関係式によって $n$ と $d$ の関係性をみるというアプローチがとられる。
SQ, CSQ学習の枠組みをSingle Index Modelの学習に適用した一連の論文によれば、以下のような結果が得られている。
- CSQ学習法は $n\gtrsim d ^ {k _ \ast/2}$
- SQ学習法は $n\gtrsim d$
ただし、これらの結果はある種最悪ケース的な証明手法を利用することで得られるため、平均ケースインスタンスに対して主張するためには、LDP法(Low Degree Polynomial Method)というものを利用することになる。
ここで、情報指数 $k _ \ast$ という言葉が出てきたが、ここで定義しておこう。
あとでみるように、 $k _ \ast$ という指数は通常のSGDを含むCSQ学習の枠組みで説明されるアルゴリズムの計算限界を示す指標として顔を出す。
SQ学習法、CSQ学習法はある種この形式で捉えられるアルゴリズムクラスのことを指し示しているため、以下では実際のアルゴリズムによって得られる上限がこの下限を達成するかを確かめていくことになる。
SGDの達成するサンプル複雑度
ここでは、具体的なアルゴリズム、特にSGDが達成するサンプル複雑度について議論する[2]。
典型的には二乗和誤差最小 $\frac{1}{2}(f(\boldsymbol{x})-f _ \ast(\boldsymbol{x})) ^ 2$ もしくは自己相互作用を考えないために相関誤差最小 $1-f(\boldsymbol{x})f _ \ast(\boldsymbol{x})$ にするような設定が取られる。勾配降下による学習を行うにあたって、しばしば理論の簡単のため $\boldsymbol{w}\in\mathcal{S} ^ {d-1}$ に制約する関係で球面勾配とすることが多いが、このセクションのうちはあまり気にしなくてもいいような状況になっている(後から見返すとこのセクションは少し雑かもしれない)。
二乗和誤差は、エルミート多項式を使って次のように展開できる。ただし $m=\langle \boldsymbol{\theta},\boldsymbol{w}\rangle$ とする。
$$ \begin{split} \mathbb{E}[l(\boldsymbol{w})]&=\frac{1}{2}\mathbb{E}[(f(\boldsymbol{x})-f _ \ast(\boldsymbol{x})) ^ 2]\cr &=\frac{1}{2}\mathbb{E}[f(\boldsymbol{x}) ^ 2]+\frac{1}{2}\mathbb{E}[f _ \ast(\boldsymbol{x}) ^ 2]-\mathbb{E}[f(\boldsymbol{x})f _ \ast(\boldsymbol{x})]\cr &=\frac{1}{2}\mathbb{E}[σ(\langle \boldsymbol{w},\boldsymbol{x} \rangle) ^ 2]+\frac{1}{2}\mathbb{E}[σ _ \ast(\langle \boldsymbol{\theta},\boldsymbol{x} \rangle) ^ 2]-\mathbb{E}[σ(\langle \boldsymbol{w},\boldsymbol{x} \rangle)σ _ \ast(\langle \boldsymbol{\theta},\boldsymbol{x} \rangle) ]\cr &=\frac{1}{2}\sum _ {j=0} ^ \infty\beta _ j ^ 2\lVert\boldsymbol{w}\rVert ^ {2j}+\frac{1}{2}\sum _ {j=0} ^ \infty\alpha _ j ^ 2\lVert\boldsymbol{\theta}\rVert ^ {2j}-\sum _ {j=0} ^ \infty\alpha _ j\beta _ jm ^ j \end{split} $$
少々乱暴な議論だが、 $m $ が大きくなる前には $\lVert\boldsymbol{w}\rVert$ はほとんど変化がないと仮定する。 $\boldsymbol{w}\in\mathcal{S} ^ {d-1}$ に制約していれば $\lVert\boldsymbol{w}\rVert=1$ のため、また相関誤差を考えている場合にはそもそも $\lVert\boldsymbol{w}\rVert$ の項は存在しないため仮定することなく成り立つ。そうでない場合でも、 $\boldsymbol{w}$ を小さな値で初期化していれば、スケールは変えずに $\boldsymbol{\theta}$ とのアラインメントをある程度して(この段階ではロスの値の変化はほぼ観察されない)、その後に $\lVert\boldsymbol{w}\rVert$ のスケールを調節する、といった挙動が観察されることがあるため、ある程度妥当な仮定なはずである。この仮定のもとでは、ロスは相関部分 $-\sum _ {j=0} ^ \infty\alpha _ j\beta _ jm ^ j$ の挙動で決まる。もし、この相関部分だけをロスとみれば、 $m=0$ をみたす $\boldsymbol{w}$ の集合が鞍点であり、 $m=1$ となる $\boldsymbol{w}=\boldsymbol{\theta}$ が大域的最適解である。
一般にネットワークの初期化をランダムに行うと、 $m\sim 1/\sqrt{d}$ くらいの小さなオーダーになる。したがって、このロスランドスケープ内を移動する時、はじめは $\alpha _ j\beta _ j$ が存在する最低次の $j$ の影響が大きくなる。そのような $j$ が情報指数 $k _ \ast$ というものだった。
ところで、たとえば $\alpha _ k\beta _ k<0$ だった場合、 $m<0$ の方に動いていってしまい、局所固定点に収束してしまう。したがって、多くの場合 $\alpha _ k\beta _ k>0$ となっていることを仮定する。もしくは、 $\beta$ を学習することによって $\beta\to \alpha$ とすることで $\alpha _ k\beta _ k>0$ とできることを示す論文もある[3]。
勾配降下法を考えるにあたって、ロスの勾配の期待値を考えると(学習したい方向 $\boldsymbol{\theta}$ への変化量が知りたいので $\boldsymbol{w}$ 方向の変化は学習する方向とは無関係と思って無視すると)
$$ \begin{split} -\mathbb{E}[\nabla _ {\boldsymbol{w}}l(\boldsymbol{w})] &\approx \mathbb{E}[\nabla _ {\boldsymbol{w}}(f(\boldsymbol{x})f _ \ast(\boldsymbol{x}))] \cr &=\mathbb{E}[\boldsymbol{x}σ ^ {\prime}(\langle \boldsymbol{w},\boldsymbol{x} \rangle)σ _ \ast(\langle \boldsymbol{\theta},\boldsymbol{x} \rangle) ]\cr &= \boldsymbol{\theta} \mathbb{E}[σ ^ {\prime}(\langle \boldsymbol{w},\boldsymbol{x} \rangle)σ _ \ast ^ {\prime}(\langle \boldsymbol{\theta},\boldsymbol{x} \rangle)] + \boldsymbol{w} \mathbb{E}[σ''(\langle\boldsymbol{w},\boldsymbol{x}\rangle)σ _ \ast(\langle\boldsymbol{\theta},\boldsymbol{x}\rangle)] \cr &= \boldsymbol{\theta} \sum _ {i=0} ^ {\infty} (i+1) ^ 2 \alpha _ {i+1}\beta _ {i+1} \langle \boldsymbol{w}, \boldsymbol{\theta} \rangle ^ i + \dots \end{split} $$
となる(3つめの等号はスタインの補題で、そのあとはエルミート展開)。初期化時点では $\langle \boldsymbol{w}, \boldsymbol{\theta} \rangle\approx d ^ {-1/2}$ のオーダーであるはずである。もし $σ _ \ast$ が情報指数 $k _ \ast$ の多項式なら、この和にでてくる最低次数は $\langle \boldsymbol{w}, \boldsymbol{\theta} \rangle ^ {k _ \ast-1}$ になる。したがって、初期化時点でのオーダーは $O(d ^ {-(k _ \ast-1)/2})$ である。
ここでは、極端なケース2つをみることで直感的に必要となるサンプル数を見てることにする。
単純なケースその1: 1stepで有意に部分空間を回復するには?
勾配のバッチサイズ $n$ の平均と期待値の差分は測度の集中から
$$ \left\lVert\mathbb{E}[\nabla _ {\boldsymbol{w}}(f(\boldsymbol{x})f _ \ast(\boldsymbol{x}))]-\frac{1}{n}\sum _ {i=1} ^ n\nabla _ {\boldsymbol{w}}f(\boldsymbol{x} _ i)f _ \ast(\boldsymbol{x} _ i)\right\rVert\lesssim\sqrt{\frac{d}{n}} $$
と評価できる。だいたいオーダーとして $\sqrt{d/n}$ となるところから、シグナルの $O(d ^ {-(k _ \ast-1)/2})$ のオーダーがノイズに埋もれないようにするには、 $n=\Omega(d ^ {k _ \ast})$ でなければならないことがわかる。この場合、毎回のイテレーションで非自明な信号が常に取り出せることになる。
単純なケースその2: オンライン勾配降下で有意に部分空間を回復するには?
一方で、バッチサイズ1の確率的勾配降下法(オンライン勾配降下法)の場合
$$ \boldsymbol{w} _ {t+1}=\boldsymbol{w} _ t-η\nabla l(\boldsymbol{w} _ t)=\boldsymbol{w} _ t-η\mathbb{E} _ {\boldsymbol{w}}[\nabla l(\boldsymbol{w})]-η(l(\boldsymbol{w} _ t)-\mathbb{E} _ {\boldsymbol{w}}[\nabla l(\boldsymbol{w})]) $$
とみると、2項目のシグナルは $\mathbb{E} _ {\boldsymbol{w}}[\nabla l(\boldsymbol{w})]=O(d ^ {-(k _ \ast-1)/2})$ であり、ノイズ $l(\boldsymbol{w} _ t)-\mathbb{E} _ {\boldsymbol{w}}[\nabla l(\boldsymbol{w})]=O(1)$ を仮定する。すると、 $T$ ステップあとにはシグナルは累積されていくので $O(Td ^ {-(k _ \ast-1)/2})$ になり、ノイズはランダムに動いているのでスケールとしては $O(\sqrt{T})$ になる。したがって、シグナルがノイズを優越するには $T\gtrsim d ^ {k _ \ast-1}$ が必要となる。この設定ではサンプル数とステップ数は同じなので $n=\Omega(d ^ {{k _ \ast-1}})$ とわかる。
これら以外の中間のケースについては、バッチサイズとステップ数についての相図を描くような研究がなされている[4]。
ロスの平滑化はCSQ下限を達成する
ここまでの話では通常のSGDではCSQ下限よりも大きな上限しか得られないことがわかった。CSQ下限を達成するような方法として、ロスを平滑化するという方法がある。少し詳しく導出をする(というか論文[5]をほぼそのままなぞる)ので、まずはここまでで説明していた直感を粒度を細かくして再導出する。ここでは相関損失 $l(\boldsymbol{w})=1-yf(\boldsymbol{x})$ を損失関数に採用して話を進める。必要に応じてデータで期待値をとりエルミート展開 $L(\boldsymbol{w}):=\mathbb{E}[l(\boldsymbol{w})]=\sum _ {k\geq 0}\frac{c _ k ^ 2}{k!}(1-m ^ k)$ を行っていることに注意。
球面上でのSGDの1ステップを考える。 $\boldsymbol{v} _ t$ を勾配により更新する方向と定めることにする。球面上で勾配降下をする場合、更新ルールは次のようになる。
$$ \boldsymbol{w} _ {t+1}=\frac{\boldsymbol{w} _ t+η _ t\boldsymbol{v} _ t}{\lVert\boldsymbol{w} _ t+η _ t\boldsymbol{v} _ t\rVert} $$
以下では相関 $m _ t=\langle\boldsymbol{w} _ t,\boldsymbol{\theta}\rangle$ を追跡することを考え、はじめは $m _ t$ が小さいことを考慮して近似したり、 $\boldsymbol{w} _ t\perp \boldsymbol{v} _ t$ に注意したりすれば
$$ \begin{split} m _ {t+1}&=\frac{m _ t+η _ t\langle \boldsymbol{v} _ t,\boldsymbol{\theta}\rangle}{\lVert\boldsymbol{w} _ t+η _ t\boldsymbol{v} _ t\rVert}\cr &=\frac{m _ t+η _ t\langle \boldsymbol{v} _ t,\boldsymbol{\theta}\rangle}{\sqrt{1+η _ t ^ 2\lVert\boldsymbol{v} _ t\rVert ^ 2}}\cr &\approx m _ t+η _ t\langle \boldsymbol{v} _ t,\boldsymbol{\theta}\rangle-\frac{η _ t ^ 2\lVert\boldsymbol{v} _ t\rVert ^ 2m _ t}{2}+O(η _ t ^ 3) \end{split} $$
とできる。ここで、 $m _ {t+1}\geq m _ t$ となってほしい制約を課すことにすると(数学的にはフィルトレーション $\mathcal{F} _ t=σ(\lbrace (x _ {τ},y _ {τ})\rbrace _ {τ=0} ^ {t-1})$ 上での期待値を考えることでSNR $\frac{\langle \boldsymbol{v} _ t,\boldsymbol{\theta}\rangle ^ 2}{\lVert\boldsymbol{v} _ t\rVert ^ 2}$ を評価するが、以下では厳密に書かないことにする)条件として $η _ t\leq\frac{2\langle \boldsymbol{v} _ t,\boldsymbol{\theta}\rangle}{\lVert\boldsymbol{v} _ t\rVert ^ 2m _ t}$ が得られる。したがってこの範囲で最大更新してくれるような $η _ t$ を選んでしまうことにすれば
$$ m _ {t+1}\approx m _ t+\frac{1}{2m _ t}\frac{\langle \boldsymbol{v} _ t,\boldsymbol{\theta}\rangle ^ 2}{\lVert\boldsymbol{v} _ t\rVert ^ 2} $$
によって学習が進行していく。以下では、SNR $\frac{\langle \boldsymbol{v} _ t,\boldsymbol{\theta}\rangle ^ 2}{\lVert\boldsymbol{v} _ t\rVert ^ 2}$ の大きさを評価することで、勾配の更新に必要なサンプル複雑度を評価する。直感的には、この離散更新を微分方程式だと思うことで
$$ \frac{d}{dt}m=\frac{\mathrm{S N R}}{m} $$
から、必要なサンプル複雑度(学習時間)が評価できる。評価には $m = \Theta(d ^ {-1/2})$ から $m = \Theta(1)$ となるまでの時間 $T$ を計算すればよい。
通常のSGDにおける必要サンプル数
まずは通常のSGD(オンライン勾配降下法)について考える。このとき $\boldsymbol{v} _ t=-\nabla l(\boldsymbol{w} _ {t})$ である。期待値をとって評価を考える。 $\mathbb{E}[\lVert\boldsymbol{v} _ t\rVert ^ 2]=d$ くらいである。また $\mathbb{E}[\langle\boldsymbol{v} _ t,\boldsymbol{\theta}\rangle ^ 2]\approx\mathbb{E}[\langle\boldsymbol{v} _ t,\boldsymbol{\theta}\rangle] ^ 2$ を考えると $L(\boldsymbol{w})=\mathbb{E}[l(\boldsymbol{w})]=\mathbb{E}[1-yf(\boldsymbol{x})]=\sum _ {k=0} ^ \infty\frac{c _ k ^ 2}{k!}(1-m ^ k)$ のように展開することで $\nabla L(\boldsymbol{w})=-(I-\boldsymbol{w}\boldsymbol{w} ^ \top)\boldsymbol{w} ^ \ast \sum _ {k = 0} ^ \infty\frac{c _ k ^ 2}{(k - 1)!}m ^ {k - 1}$ であるが、情報指数の定義を思い出せば $\nabla L(\boldsymbol{w}) \approx m ^ {k _ \ast - 1}$ となる。これより
$$ \mathbb{E}[\langle\boldsymbol{v} _ t,\boldsymbol{\theta}\rangle]=-\langle\nabla L(\boldsymbol{w} _ {t}),\boldsymbol{\theta}\rangle\approx -m ^ {k _ \ast-1} $$
したがって
$$ \frac{dm}{dt}=-\frac{m ^ {2k _ \ast-3}}{d} $$
をとくと
$$ T=-\int _ {\Theta(d ^ {-1/2})} ^ {\Theta(1)}\frac{dm}{m ^ {2k _ \ast-3}}=\Theta(d ^ {k _ \ast-1}) $$
となり、以前の節で導出したものとコンシステントである。
平滑化ロスの場合の必要サンプル数
平滑化したロスをここから考える。平滑化というのは $\mathcal{L} _ \lambda$ というオペレータを使って
$$ l _ \lambda(\boldsymbol{w})=\mathcal{L} _ {\lambda}l(\boldsymbol{w})=\mathbb{E} _ {z\sim\mu(\mathbb{S} ^ {d-1})}\left[l\left(\frac{\boldsymbol{w}+\lambda\boldsymbol{z}}{\lVert\boldsymbol{w}+\lambda\boldsymbol{z}\rVert}\right)\right] $$
とすることをいう。同様に
$$ L _ \lambda(\boldsymbol{w})=\mathcal{L} _ {\lambda}L(\boldsymbol{w})=\mathbb{E} _ {z\sim\mu(\mathbb{S} ^ {d-1})}\left[L\left(\frac{\boldsymbol{w}+\lambda\boldsymbol{z}}{\lVert\boldsymbol{w}+\lambda\boldsymbol{z}\rVert}\right)\right] $$
も定義される。このオペレータは $\boldsymbol{w}$ のところ(もしくは $m $)のところにかかることに注意して、主要項だけ取り出すと、情報指数 $k _ \ast$ のところだけ考えればよいことがわかる。したがって、ここの次数をどれくらい下げられるかが焦点となる。
$$ \begin{split} \mathcal{L} _ \lambda(m ^ {k _ \ast})&=\mathbb{E} _ {z\sim\mu(\mathbb{S} ^ {d-1})}\left[\left\langle\frac{\boldsymbol{w}+\lambda\boldsymbol{z}}{\lVert\boldsymbol{w}+\lambda\boldsymbol{z}\rVert},\boldsymbol{\theta}\right\rangle ^ {k _ \ast}\right]\cr &\approx\lambda ^ {-k _ \ast}\mathbb{E} _ {z\sim\mu(\mathbb{S} ^ {d-1})}[(m+\lambda\langle \boldsymbol{z},\boldsymbol{\theta}\rangle) ^ {k _ \ast}]\cr &=\lambda ^ {-k _ \ast}\sum _ {j=0} ^ {k _ \ast} {} _ {k _ \ast}C _ j\ m ^ {k _ \ast-j}\lambda ^ j\mathbb{E} _ {z\sim\mu(\mathbb{S} ^ {d-1})}[\langle\boldsymbol{z},\boldsymbol{\theta}\rangle ^ j]\cr &\approx \lambda ^ {-k _ \ast}\sum _ {j=0} ^ {\lfloor k _ \ast/2\rfloor}m ^ {k _ \ast-2j}\left(\frac{\lambda ^ 2}{d}\right) ^ j=\sum _ {j=0} ^ {\lfloor k _ \ast/2\rfloor}\left(\frac{m}{\lambda}\right) ^ {k _ \ast}\left(\frac{\lambda ^ 2}{m ^ 2 d}\right) ^ j \end{split} $$
2つ目の近似では $\boldsymbol{z}\perp \boldsymbol{w}$ と $\lVert\boldsymbol{z}\rVert=1$ から $\lVert\boldsymbol{w}+\lambda \boldsymbol{z}\rVert=\sqrt{1+\lambda ^ 2}\approx\lambda$ となることを使っている。最後は、 $\boldsymbol{z}$ が $-\boldsymbol{z}$ と分布として等しいからゼロになることと、 $\langle\boldsymbol{z},\boldsymbol{\theta}\rangle=O(d ^ {-1/2})$ であることを使った上で、定数をすべて取り払った。したがって( $\boldsymbol{v} _ t=-\nabla l _ {\lambda _ t}(\boldsymbol{w} _ {t})$ として)
$$ \begin{split} \mathbb{E}[\langle\boldsymbol{v},\boldsymbol{\theta}\rangle]&\approx -\langle\nabla L _ \lambda(\boldsymbol{w}),\boldsymbol{\theta}\rangle\approx-\langle \nabla\mathcal{L} _ \lambda(m ^ {k _ \ast}),\boldsymbol{\theta}\rangle\cr &\approx \lambda ^ {-1}\sum _ {j=0} ^ {\lfloor(k _ \ast-1)/2\rfloor}\left(\frac{m}{\lambda}\right) ^ {k _ \ast-1}\left(\frac{\lambda ^ 2}{m ^ 2 d}\right) ^ j\cr &\approx\lambda ^ {-1}\begin{cases} \left(\frac{m}{\lambda}\right) ^ {k _ \ast-1} & (m\geq \lambda d ^ {-1/2})\cr d ^ {-(k _ \ast-1)/2} & (m\leq \lambda d ^ {-1/2}, k _ \ast\textrm{ is odd})\cr \frac{m}{\lambda}d ^ {-k _ \ast/2-1} & (m\leq \lambda d ^ {-1/2}, k _ \ast\textrm{ is even})\cr \end{cases} \end{split} $$
ここで、勾配の分散を考える。 $\nabla L _ \lambda(\boldsymbol{w})=-y\nabla\mathcal{L} _ \lambda(σ(\langle\boldsymbol{w},\boldsymbol{x}\rangle)) \approx \lambda ^ {-1}y\boldsymbol{x}\mathcal{L} _ \lambda(σ'(\langle\boldsymbol{w},\boldsymbol{x}\rangle))$ とみると $y=\tilde{O}(1)$ でありまた $\lVert\boldsymbol{x}\rVert=O(\sqrt{d})$ であることから以後は $\mathcal{L} _ \lambda(σ'(\langle\boldsymbol{w},\boldsymbol{x}\rangle))$ の評価を考える。
$$ \begin{split} \mathbb{E} _ {\boldsymbol{x}}[\mathcal{L} _ \lambda(σ'(\langle\boldsymbol{w},\boldsymbol{x}\rangle)) ^ 2] &=\mathbb{E} _ {\boldsymbol{x}}\left[\mathbb{E} _ {\boldsymbol{z}\sim\mu(\mathbb{S} ^ {d-1})}\left[σ'\left(\left\langle\frac{\boldsymbol{w}+\lambda\boldsymbol{z}}{\sqrt{1+\lambda ^ 2}},\boldsymbol{x}\right\rangle\right)\right] ^ 2\right]\cr &=\mathbb{E} _ {\boldsymbol{x}}\left[\mathbb{E} _ {\boldsymbol{z},\boldsymbol{z}'\sim\mu(\mathbb{S} ^ {d-1})}\left[σ'\left(\left\langle\frac{\boldsymbol{w}+\lambda\boldsymbol{z}}{\sqrt{1+\lambda ^ 2}},\boldsymbol{x}\right\rangle\right)σ'\left(\left\langle\frac{\boldsymbol{w}+\lambda\boldsymbol{z}'}{\sqrt{1+\lambda ^ 2}},\boldsymbol{x}\right\rangle\right)\right]\right]\cr &\approx\mathbb{E} _ {\boldsymbol{z},\boldsymbol{z}'\sim\mu(\mathbb{S} ^ {d-1})}\left[\left\langle\frac{\boldsymbol{w}+\lambda\boldsymbol{z}}{\sqrt{1+\lambda ^ 2}},\frac{\boldsymbol{w}+\lambda\boldsymbol{z}'}{\sqrt{1+\lambda ^ 2}}\right\rangle ^ {k _ \ast-1}\right]\cr &=\mathbb{E} _ {\boldsymbol{z},\boldsymbol{z}'\sim\mu(\mathbb{S} ^ {d-1})}\left[\left(\frac{1+\lambda ^ 2\langle\boldsymbol{z},\boldsymbol{z}'\rangle}{1+\lambda ^ 2}\right) ^ {k _ \ast-1}\right]\cr &\approx\begin{cases} (1+\lambda ^ 2) ^ {-(k _ \ast-1)}\approx\lambda ^ {-2(k _ \ast-1)} & (\lambda\ll d ^ {1/4})\cr \langle\boldsymbol{z},\boldsymbol{z}'\rangle ^ {k _ \ast-1}=O(d ^ {-(k _ \ast-1)/2}) & (\lambda\gg d ^ {1/4}) \end{cases} \end{split} $$
ここで、近似の部分では $x$ についてのエルミート多項式展開の期待値計算を利用して主要項だけを取り出している。この評価において、 $\langle\boldsymbol{z},\boldsymbol{z}'\rangle=O(d ^ {-1/2})$ 程度であることを考慮した。これより $\mathbb{E} _ {\boldsymbol{x}}[\mathcal{L} _ \lambda(σ'(\langle\boldsymbol{w},\boldsymbol{x}\rangle)) ^ 2]\lesssim\min(\lambda,d ^ {1/4}) ^ {-2(k _ \ast-1)}$ とわかる。
$y=\tilde{O}(1)$ と $\lVert\boldsymbol{x}\rVert=O(\sqrt{d})$ をあわせれば
$$ \mathbb{E}[\lVert\boldsymbol{v}\rVert ^ 2]=\lambda ^ {-2}\cdot 1\cdot d\cdot \min(\lambda,d ^ {1/4}) ^ {-2(k _ \ast-1)}\leq d\lambda ^ {-2k _ \ast} $$
と評価でき、特に $\lambda\leq d ^ {1/4}$ のときにノイズが $\mathbb{E}[\lVert\boldsymbol{v}\rVert ^ 2]\leq d\lambda ^ {-2k _ \ast}$ をみたし、信号よりもノイズが小さくなれる。これより結局知りたかったSNRは
$$ \mathrm{S N R}=\frac{\langle \boldsymbol{v} _ t,\boldsymbol{\theta}\rangle ^ 2}{\lVert\boldsymbol{v} _ t\rVert ^ 2}\approx\frac{1}{d}\begin{cases} m ^ {2(k _ \ast-1)} & (m\geq \lambda d ^ {-1/2})\cr \left(\frac{\lambda ^ 2}{d}\right) ^ {k _ \ast-1} & (m\leq \lambda d ^ {-1/2}, k _ \ast\textrm{ is odd})\cr m ^ 2\left(\frac{\lambda ^ 2}{d}\right) ^ {k _ \ast-2} & (m\leq \lambda d ^ {-1/2}, k _ \ast\textrm{ is even})\cr \end{cases} $$
をみたすと評価できる。なお特に $\lambda=d ^ {1/4}$ と定めた場合
$$ \mathrm{S N R}=\frac{\langle \boldsymbol{v} _ t,\boldsymbol{\theta}\rangle ^ 2}{\lVert\boldsymbol{v} _ t\rVert ^ 2}\approx\frac{1}{d}\begin{cases} m ^ {2(k _ \ast-1)} & (m\geq d ^ {-1/4})\cr d ^ {-(k _ \ast-1)/2} & (m\leq d ^ {-1/4}, k _ \ast\textrm{ is odd})\cr m ^ 2d ^ {-(k _ \ast-2)/2} & (m\leq d ^ {-1/4}, k _ \ast\textrm{ is even})\cr \end{cases} $$
となる。注目すべきは $m\leq d ^ {-1/4}$ の領域で、この領域ではSNRが $O(d ^ {-(k _ \ast-1)/2})$ になっており、このスケールで勾配が更新されるという点である。したがって、その分weak recoveryまでのサンプル複雑度を減らすことができるということである。
これを使って、微分方程式としてみた式に代入することで真面目に計算すると
$$ \begin{split} T&=\int _ {\Theta(d ^ {-1/2})} ^ {\lambda d ^ {-1/2}} m ^ {-1}\frac{d ^ {k _ \ast-3}}{\lambda ^ {2k _ \ast-4}} \cdot dm+\int _ {\lambda d ^ {-1/2}} ^ {\Theta(1)}dm ^ {-2k _ \ast+3}\cdot dm\cr &=\left[-\log m\right] _ {\Theta(d ^ {-1/2})} ^ {\lambda d ^ {-1/2}}d ^ {k _ \ast-3}\lambda d ^ {-1/2}+d\cdot \lambda ^ {-2k _ \ast+4}d ^ {k _ \ast-2}+\Theta(d)\cr &=\tilde{O}(d ^ {k _ \ast-1}\lambda ^ {-2k _ \ast+4}) \end{split} $$
となる。 $\lambda=d ^ {1/4}$ を代入すれば $T=\tilde{O}(d ^ {k _ \ast/2})$ となり、対数分を除いてCSQ下限に一致する。したがってCSQの意味で最適なアルゴリズムは平滑化によって達成できることがわかる。
バッチ再利用はSQ下限を達成する
このように、SGDを使った学習はCSQ学習法で捉えることができ、情報指数が計算量的困難性を特徴づけていることがわかった。しかしながら、CSQ学習法では情報論的下限のサンプル数を達成できない。
一方でSQ学習法では情報論的下限を達成できている。ここで得られる示唆は、SQ学習法は、(その定義をみると)何らかの方法で情報指数を下げるように、 $y$ に対して非線形な変換を行っているのではないか?ということである。SGDに対してなんらかの工夫を取り入れることで、この非線形な変換を達成できるだろうか?となるわけだが、この非線形な変換を行うひとつの方法がバッチ再利用(同じ訓練データ点を2回使う)である。バッチ再利用を行うSGDアルゴリズムはSQ学習のフレームワークで捉えることができる。さらには、SQ下限を達成できることが証明されている[6][7]。
バッチ再利用を行う次のような更新方式 $\boldsymbol{w} _ t\mapsto \boldsymbol{w} _ {t+1}$ を考える。 $l(\boldsymbol{w})=\frac{1}{2}(y-f _ {\boldsymbol{w}}(\boldsymbol{x})) ^ 2$ とする。
$$ \tilde{\boldsymbol{w}} _ {t+1}=\boldsymbol{w} _ t-η\nabla l(\boldsymbol{w} _ t),\quad\boldsymbol{w} _ {t+1}=\boldsymbol{w} _ t-\gamma\nabla l(\tilde{\boldsymbol{w}} _ {t+1}) $$
実務的に使われるバッチ再利用は次のエポックで行われ、それを明示的に連続的にして取り入れたものと見ることができる気がする。なお、明示的にこのような戦略をとるアルゴリズムにはSharpness Aware Minimization[8]があるが、平坦性に注目するアルゴリズムとして設計されているものであった。
初期化を $\boldsymbol{w} _ 0\approx 0$ にとったとすれば、そのスケールが初期化付近と変わらない限りは $-\nabla l(\boldsymbol{w})=y\nabla f _ {\boldsymbol{w}}(\boldsymbol{x})\approx y\boldsymbol{x}σ'(\langle\boldsymbol{w},\boldsymbol{x}\rangle)$ となることに注意して
$$ \begin{split} \boldsymbol{w} _ {t+1} &=\boldsymbol{w} _ t-\gamma\nabla l(\tilde{\boldsymbol{w}} _ {t+1})\cr &=\boldsymbol{w} _ t-\gamma\nabla l(\boldsymbol{w} _ t-η\nabla l(\boldsymbol{w} _ t))\cr &=\boldsymbol{w} _ t+\gamma y\boldsymbol{x}σ'(\langle\boldsymbol{w} _ t-η\nabla l(\boldsymbol{w} _ t),\boldsymbol{x}\rangle)\cr &=\boldsymbol{w} _ t+\gamma y\boldsymbol{x}σ'(\langle\boldsymbol{w} _ t,\boldsymbol{x}\rangle+η y\lVert\boldsymbol{x}\rVert ^ 2σ'(\langle\boldsymbol{w} _ t,\boldsymbol{x}\rangle))\cr &=\boldsymbol{w} _ t+\gamma y\boldsymbol{x}\left[σ'(\langle\boldsymbol{w} _ t,\boldsymbol{x}\rangle)+σ''(\langle\boldsymbol{w} _ t,\boldsymbol{x}\rangle)(ησ'(\langle\boldsymbol{w} _ t,\boldsymbol{x}\rangle) y\lVert\boldsymbol{x}\rVert ^ 2)+\frac{1}{2}σ'''(\langle\boldsymbol{w} _ t,\boldsymbol{x}\rangle)(ησ'(\langle\boldsymbol{w} _ t,\boldsymbol{x}\rangle) y\lVert\boldsymbol{x}\rVert ^ 2) ^ 2+\cdots \right]\cr &=\boldsymbol{w} _ t+\gamma\sum _ {i=0} ^ {\infty}\frac{1}{i!}(η\lVert\boldsymbol{x}\rVert ^ 2) ^ iy ^ {i+1}σ'(\langle\boldsymbol{w} _ t,\boldsymbol{x}\rangle) ^ iσ ^ {(i+1)}(\langle\boldsymbol{w} _ t,\boldsymbol{x}\rangle)\boldsymbol{x} \end{split} $$
のように書ける。このとき、 $i=1$ の $\gamma yσ'(\langle\boldsymbol{w} _ t,\boldsymbol{x}\rangle)\boldsymbol{x}$ の部分は $y$ と $\boldsymbol{x}$ の相関であるため、これまでの議論を考えればCSQ学習の枠組みで捉えることができた。しかしながら、それ以降の $i>1$ の部分は $y$ について2次以上の式となっており、CSQ学習の枠組みでは捉えられず、SQ学習の枠組みで捉えることとなる。
これに対して $\boldsymbol{\theta}$ との内積をとり、右の和の部分を期待値にすれば $m _ {t}=\langle\boldsymbol{\theta},\boldsymbol{w} _ t\rangle$ とし、 $u=\langle\boldsymbol{\theta},\boldsymbol{x}\rangle,v _ t=\langle\boldsymbol{w} _ t,\boldsymbol{x}\rangle$ とすれば
$$ m _ {t+1}=m _ t+\gamma\sum _ {i=0} ^ {\infty}\frac{η ^ i}{i!}\mathbb{E}\left[\lVert\boldsymbol{x}\rVert ^ {2i}σ _ \ast(u) ^ {i+1}σ'(v _ t) ^ iσ ^ {(i+1)}(v _ t)u\right] $$
となる。さて、この期待値をざっくり評価していくことにする。まず、$\mathbb{E}[\lVert\boldsymbol{x}\rVert ^ 2]\approx d$ となることを使い、この部分が他の部分とほとんど独立だと思うことにすれば
$$ \begin{split} \mathbb{E}\left[\lVert\boldsymbol{x}\rVert ^ {2i}σ _ \ast(u) ^ {i+1}σ'(v _ t) ^ iσ ^ {(i+1)}(v _ t)u\right]&\approx\mathbb{E}[\lVert\boldsymbol{x}\rVert ^ {2i}]\mathbb{E}[σ _ \ast(u) ^ {i+1}uσ'(v _ t) ^ iσ ^ {(i+1)}(v _ t)]\cr &=d ^ i\mathbb{E}[σ _ \ast(u) ^ {i+1}uσ'(v _ t) ^ iσ ^ {(i+1)}(v _ t)] \end{split} $$
となる。次に $\mathbb{E}[σ _ \ast(u) ^ {i+1}uσ'(v _ t) ^ iσ ^ {(i+1)}(v _ t)]$ を評価する。$σ _ \ast(u) ^ {i+1}=\sum _ {j=\tilde{k} _ \ast} ^ \infty\alpha _ j ^ {(i+1)}\mathrm{He} _ j(u)$ とエルミート多項式展開する( $\tilde{k} _ \ast$ はこの展開における最低次数であり、 $i$ に依存することに注意する)。同様に $σ'(v _ t) ^ iσ ^ {(i+1)}(v _ t)=\sum _ {k=0} ^ \infty\beta _ k ^ {(i)}\mathrm{He} _ k(v _ t)$ と展開する。このとき、エルミート多項式の性質である $u\mathrm{He} _ j(u)=\mathrm{He} _ {j+1}(u)+j\mathrm{He} _ {j-1}(u)$ と $\mathbb{E}[\mathrm{He} _ a(u)\mathrm{He} _ b(v _ t)]=\delta _ {ab}m _ t ^ a$ を利用すると
$$ \begin{split} \mathbb{E}[σ _ \ast(u) ^ {i+1}uσ'(v _ t) ^ iσ ^ {(i+1)}(v _ t)] &=\mathbb{E}\left[\left(\sum _ {j=\tilde{k} _ \ast} ^ \infty\alpha _ j ^ {(i+1)}\mathrm{He} _ j(u)u\right)\left(\sum _ {k=0} ^ \infty\beta _ k ^ {(i)}\mathrm{He} _ k(v _ t)\right)\right]\cr &=\mathbb{E}\left[\left(\sum _ {j=\tilde{k} _ \ast} ^ \infty\alpha _ j ^ {(i+1)}(\mathrm{He} _ {j+1}(u)+j\mathrm{He} _ {j-1}(u))\right)\left(\sum _ {k=0} ^ \infty\beta _ k ^ {(i)}\mathrm{He} _ k(v _ t)\right)\right]\cr &=\sum _ {j=\tilde{k} _ \ast} ^ \infty\alpha _ j ^ {(i+1)}(\beta _ {j+1} ^ {(i)}m _ t ^ {j+1}+j\beta _ {j-1} ^ {(i)}m _ t ^ {j-1}) \end{split} $$
となる。したがって、この更新式で最も $m _ t$ の次数が小さい部分だけ取り出したとすれば、それは $p _ \ast=\min _ {i\geq 1}\tilde{k} _ \ast$ と、$i _ \ast=\arg\min _ {i\geq 1}\tilde{k} _ \ast$ として
$$ m _ {t+1}\approx m _ t+\gamma\frac{η ^ {i _ \ast}d ^ {i _ \ast}}{i _ \ast!}\alpha _ {p _ \ast} ^ {(i _ \ast+1)}p _ \ast\beta _ {p _ \ast-1} ^ {(i _ \ast)}m _ t ^ {p _ \ast-1} $$
となる。ここまで来れば、あとは $p _ \ast$ がいくつであるかを評価すればよい。 $p _ \ast$ の出処をたどると $σ _ \ast(u) ^ {i _ \ast+1}$ にたどり着く。したがって、これについての評価を行う。ここでは見やすさのために $\mathcal{T}(σ _ \ast(z))=σ _ \ast(z) ^ i$ と定めて、これについて考えることにする。結論としては、 $\mathcal{T}(σ _ \ast(z))$ の $p _ \ast$ は $1$ もしくは $2$ になることが、以下のように示すことができる。
まずは、もしこの関数の情報指数が $1$ であるならば( $z\sim\mathcal{N}(0,1)$ とする)
$$ \mathbb{E}[\mathcal{T}(σ _ \ast(z))\mathrm{He} _ 1(z)]=\mathbb{E}[T(σ _ \ast(z))z]\neq 0 $$
となるはずであるなので、こうなる場合はいつかを考える。
ここですべての $i$ について $\mathbb{E}[\mathcal{T}(σ _ \ast(z))\mathrm{He} _ 1(z)]=0$ となる状況を考える。これは純粋な偶関数となる場合だけであることを背理法で示す。純粋な偶関数ではないとすれば、グラフを描いたときに $y$ 軸非対称になっていることに注意する。いま、期待値をとる範囲を $(-\infty,-z _ \ast],(-z _ \ast,z _ \ast],(z _ \ast,\infty]$ と分割する。ここで $(-z _ \ast,z _ \ast]$ の区間は $T(σ _ \ast(z))=0$ の解がすべて含まれるようにとる。そうすると、さらに $σ _ \ast$ の非対称性から $(-\infty,-z _ \ast]$ と $(z _ \ast,\infty]$ の範囲の積分は打ち消し合わず、多項式のため $z\to\pm\infty$ で発散することを考えれば $z _ 0\in(z _ \ast,\infty]$ として、定数 $M=\max _ {z\in(-z _ \ast,z _ \ast]}|σ _ \ast(z)|$ を使って $|σ _ \ast(z _ 0) ^ i-σ(-z _ 0) ^ i|>(2M) ^ i$ とできるはずである。このことは、 $i\to\infty$ とすれば $\mathbb{E}[|T(σ _ \ast(z))\mathrm{He} _ 1(z)|\mathbb{1} _ {\lbrace z\in(-z _ \ast,z _ \ast]\rbrace }]$ については $M ^ i$ のように増加するのに対して、 $\mathbb{E}[|\mathcal{T}(σ _ \ast(z))\mathrm{He} _ 1(z)|\mathbb{1} _ {\lbrace z\notin(-z _ \ast,z _ \ast]\rbrace }]$ は $(2M) ^ i$ のように増加することを示している。すなわち $\mathbb{E}[|\mathcal{T}(σ _ \ast(z))\mathrm{He} _ 1(z)|\mathbb{1} _ {\lbrace z\in(-z _ \ast,z _ \ast]\rbrace }]<\mathbb{E}[|\mathcal{T}(σ _ \ast(z))\mathrm{He} _ 1(z)|\mathbb{1} _ {\lbrace z\notin(-z _ \ast,z _ \ast]\rbrace }]$ となることを意味しており、これは $\mathbb{E}[\mathcal{T}(σ _ \ast(z))\mathrm{He} _ 1(z)]=0$ となるはずであったことに矛盾する。
したがって、背理法からすべての $i$ について $\mathbb{E}[\mathcal{T}(σ _ \ast(z))\mathrm{He} _ 1(z)]=0$ ならば $σ _ \ast(z)$ は純粋な偶関数であることが示されたので、対偶をとって、 $σ _ \ast(z)$ が偶関数でないならば、ある $i$ に対して $\mathbb{E}[\mathcal{T}(σ _ \ast(z))\mathrm{He} _ 1(z)]\neq0$ となることがわかった。これは、 $σ _ \ast(z)$ が偶関数でないならば $p _ \ast=1$ であることを意味する。
次に偶関数の場合について考える。このとき
$$ \mathbb{E}[\mathcal{T}(σ _ \ast(z))\mathrm{He} _ 2(z)]=\mathbb{E}[\mathcal{T}(σ _ \ast(z))(z ^ 2-1)]\neq 0 $$
となることを示す。これも偶関数でない場合と同様に考えて、期待値をとる範囲を $(-\infty,-z _ \ast],(-z _ \ast,z _ \ast],(z _ \ast,\infty]$ と分割する。ここで $z _ \ast$ は十分大きくとる。これによって、 $σ _ \ast(z)$ の $y$ 軸対称性と $|z|\to\infty$ では $σ _ \ast(z)$ がいくらでも大きくなることから $i\to\infty$ では $\mathbb{E}[|\mathcal{T}(σ _ \ast(z))\mathrm{He} _ 2(z)|\mathbb{1} _ {\lbrace z\in(-z _ \ast,z _ \ast]\rbrace }]<\mathbb{E}[|\mathcal{T}(σ _ \ast(z))\mathrm{He} _ 2(z)|\mathbb{1} _ {\lbrace z\notin(-z _ \ast,z _ \ast]\rbrace }]$ とできる。これはある $i$ が存在して $\mathbb{E}[\mathcal{T}(σ _ \ast(z))\mathrm{He} _ 2(z)]\neq 0$ を意味する。つまり、 $σ _ \ast(z)$ が偶関数ならば $p _ \ast=2$ であることを意味する。
偶関数であるかどうかで場合分けしたので、 $σ _ \ast(z)$ の関数形についての場合分けは尽くされていて、示したかったことが示せた。
というわけで、長かったが $η d=\Theta(1)$ になるようにパラメータを設定し、さらに $\lVert\nabla L\rVert=O(\sqrt{d})$ 程度の大きさになることが期待されるため $\gamma=O(1/d)$ とスケールしておくことに注意すれば、適当な定数 $C$ があって
$$ m _ {t+1}\approx m _ t+\frac{C}{d}m _ t ^ {p _ \ast-1} $$
となることがわかる。これを近似的に微分方程式とみなせば
$$ \frac{dm}{dt}=\frac{C}{d}m ^ {p _ \ast-1} $$
となるので、これをとけばよい。 $t$ が $m _ t=\Theta(1/\sqrt{d})$ から $m _ t=\Theta(1)$ に到達した時間 $T$ を計算すると
$$ T=\begin{cases} O(d) & (p _ \ast=1)\cr O(d\log d) & (p _ \ast=2) \end{cases}=\tilde{O}(d) $$
となることがわかる。
なお、ここで出てきた $p _ \ast$ というものが生成的指数(Generative Exponent)とよばれるものである。
生成的指数は、定義にあるように、考えられるあらゆるラベル変換を試したうえで、達成可能な最小の情報指数として定められている。これはまさにCSQ学習をこえるSQ学習の枠組みから示唆されるアイデアを数式に起こしたものといえよう。ここで紹介している定義は変分表現のような最適化アルゴリズム的な見方をしたものになっているが、そうでない場合には $\chi _ 2$ 情報量を使って定義される[9]。情報指数が相関ベースの手法という特定のアプローチで $L _ 2$ 空間的な鞍点の形状に注目しており、ラベルをうまく前処理することによって見かけ上の困難性が変化していた。生成的指数ではこのあらゆる変換を考慮して見かけ上の部分を排除し、情報量を測っていることから、より手法に依存しない、問題固有の計算論的な困難さの度合いを示していそうに見える。
ここまでのまとめ
長かったのでまとめることにしよう。入力データの次元を $d$ とし、学習に必要なサンプル数を $n$ とかく。情報理論的なサンプル数の下限として、カーネル法は学習対象の多項式の次数 $p$ として $n\gtrsim d ^ p$ 程度のサンプル数が必要だったが、ニューラルネット等では $n\gtrsim d$ 程度あればよいことが調べられていた。このことはニューラルネットの情報理論的な性能の優越を意味しているが、具体的にこれを達成できる多項式アルゴリズムが存在するかはまた別の話である。理論的にどれくらいのサンプル数であるアルゴリズムのクラスが関数を学習可能かをチェックする枠組みが、SQ学習やCSQ学習といわれるクラスである。前者は $n\gtrsim d$ 程度、後者は $n \gtrsim d ^ {k _ \ast /2}$ 程度であることが知られている。ひとまず通常のSGDを考えると、簡単な式変形でCSQ学習の枠組みで捉えることができることがチェックできる。通常のSGDがどの程度この下限に迫れるかを確かめると、だいたい $n \gtrsim d ^ {k _ \ast -1}$ 程度であった。ここで $k _ \ast $ は情報指数とよばれる学習関数のエルミート多項式展開における最低次数である。下限までにはギャップがあるが、これはロスの平滑化によって達成可能なことが示された。とはいえ、CSQ学習のサンプル数の下限は情報理論的な下限には程遠い。そこで、SQ学習の枠組みを思い出し、ラベルに対する何らかの変換をアルゴリズムに組み込むことで情報理論的な下限を達成できる可能性が示唆される。これを自然な形で実装するのがバッチの再利用であり、この方法では情報理論的な下限を達成できることが示された。なお厳密には生成的指数というものがこのアルゴリズムにおけるサンプル数を支配している。
以下はおまけである。
教師信号にノイズのある場合とSQの枠組みを超えたアルゴリズムの可能性
学習対象がSingle Index Modelの場合にはこのように多項式時間で学習可能となるアルゴリズムが確かにあり、これらは(C)SQ学習とも密接に関係することがわかった。しかし、層を増やして、たとえば2層のReLUを活性化関数に利用したものに対しては、暗号学的困難性に絡めた次のような結果も知られている[10]。ReLUで2層のネットワークの出力 $\sum _ {i=1} ^ n a _ i\mathrm{ReLU}(\gamma\langle w _ i,x\rangle+b _ i)$ に有界なノイズが乗ったものが観測されるとした場合の問題設定を考える。このとき、ノイズのスケールによって学習可能性が変化することが知られている。具体的には、ノイズのスケールが入力次元に対して指数的に小さい場合には多項式時間で解けるが、ノイズのスケールが入力次元に対して多項式の逆数程度である場合には、もし解けてしまうと暗号学的な困難性の仮定に矛盾するというものである。
このことを示すには、1層のSingle Periodic Neuronの特定インスタンスが2層ReLUネットワークに変換できることが示せるため、Single Periodic Neuron $z=\cos(\gamma\langle w,x\rangle)+\xi$ のパラメータが推定可能かどうかを考えることになる[11]。
まず、ノイズのスケールが入力次元に対して指数的に小さい場合には多項式時間で解けることについて説明する。特に、次元 $d$ に対してサンプル数は $d+1$ 個だけで十分である。具体的に解くには、まず arccos を作用させて $1/2π$ をかけ、適当に $2 ^ N$ 倍する(LLLアルゴリズムは整数に対して適用するアルゴリズムのため、適当な精度ではあるが整数に変換する必要がある)ことでSingle Periodic Neuronの問題を $z _ i$ を変換した $\tilde{z} _ i$ と、ある $\epsilon _ i\in\lbrace \pm 1\rbrace $ と $K _ i\in\mathbb{Z}$ を用いて $\langle w,x _ i\rangle=\epsilon _ i \tilde{z} _ i+K _ i$ のような形式に変換する。この時点では $\epsilon _ i$ や $K _ i$ がわからないため、これを求める必要がある。そのためにまず $\sum _ {i=1} ^ {d+1}\lambda _ i x _ i=0$ をみたすような $\lbrace \lambda _ i\rbrace _ {i=1} ^ {d+1}$ を見つける。この操作はただちに可能で、 $\lambda _ 1=1$ と $(\lambda _ 2,\ldots,\lambda _ {d+1}) ^ \top=X ^ {-1}x _ 1$ とすればよい。これを使うと $\sum _ {i=1} ^ {d+1}\epsilon _ i(\lambda _ i\tilde{z} _ i)+\sum _ {i=1} ^ {d+1}K _ i\lambda _ i=\sum _ {i=1} ^ {d+1}\lambda _ i\langle x _ i,w\rangle\approx 0$ という関係式が得られ、これは左辺と右辺をみることで変数が $K _ 1,\ldots,K _ {d+1},\epsilon _ 1,\ldots,\epsilon _ {d+1}$ であるような一次不定方程式となる。したがって、これに対してLLLアルゴリズムを適用することによりこの値を求めることができる。これができれば、 $\langle w,x _ i\rangle=\epsilon _ i \tilde{z} _ i+K _ i$ を並べたものに対して $X ^ +=(x _ 1,\ldots,x _ {d+1}) ^ +$ を作用させることによって $w$ を復元することができる。
ここで注意すべきは、LLLアルゴリズムは上記で説明した(C)SQ学習の枠組みでは捉えられないということである。SQ学習の枠組みは、なんらかの関数 $\phi(x,y)$ の形のクエリの、分布にわたる期待値に基づいた値を返すため、サンプルひとつひとつを捉えるような枠組みにはなっておらず、分布として変換したものを返すイメージである。一方でLLLアルゴリズムは、実際に入力されたサンプルを変換して出力するアルゴリズムであるため、本質的に分布で捉えることのできないものとなっている。この結果は、現在使われているような勾配降下法以外のなんらかの方法で、ニューラルネットワークをよりサンプル効率良く学習できるアルゴリズムが存在する可能性を示唆しているといえるかもしれない。
なお、LLLアルゴリズムはCTFでも利用されることがある。使われ方としてはまさに上記のような一次不定方程式を解くために利用されることもある。また特に、現代暗号である格子暗号の文脈で、その安全性を保証するNP困難問題であるSVP問題の近似版を多項式時間で解くアルゴリズムとして知られている。
次に、ノイズのスケールが入力次元に対して多項式の逆数程度である場合、もしこれを推定できるアルゴリズムが存在すれば、NP困難問題であるSVP問題が解けることになることが示されている。このことは、いかなる多項式時間アルゴリズムを利用しても推定できないことを示唆する。具体的には、CLWEのインスタンスがSingle Periodic Neuronのインスタンスに帰着できることがわかれば、CLWEより前のインスタンスの帰着の連鎖から困難性を示すことができる。
CLWE問題 $\tilde{z}=\gamma\langle w,x\rangle+\xi'\bmod 1$ に対して $\cos$ を作用させることによって $z=\cos(\gamma\langle w,x\rangle+\xi')=\cos(\gamma\langle w,x\rangle)+\xi$ に帰着することができる。ここで、うまくこのような $\xi$ が定まり、この状況ではリプシッツ性から適当な定数 $L$ で $|\xi|\leq L|\xi'|$ とできることに注意する。さらにCLWEにおける $\xi'$ はガウスノイズであるため、適当な閾値を定めれば十分に高い確率でノイズの値が閾値以下となるようにでき、高確率でノイズの有界性がみたされる。したがってSingle Periodic Neuronのインスタンスになることがわかる。したがって、もしSingle Periodic Neuronの問題が解ければ、CLWE問題が解け、するとLWE問題が解け、するとSVP問題が解け、ということになり、SVP問題が多項式時間で解けないことを仮定していれば矛盾することがわかる。
参考文献
[1] 2025年度人工知能学会全国大会チュートリアル講演「深層基盤モデルの数理」
[2] G. B. Arous, R. Gheissari, and A. Jagannath. Online Stochastic gradient descent on non-convex losses from high-dimensional inference. Journal of Machine Learning Research, Vol. 22, No. 106, pp. 1–51, 2021.
[3] A. Bietti, J. Bruna, C. Sanford, and M. J. Song. Learning Single-index Models with Shallow Neural Networks. Advances in neural information processing systems, Vol. 35, pp. 9768–9783, 2022.
[4] L. Arnaboldi, Y. Dandi, F. Krzakala, B. Loureiro, L. Pesce, and L. Stephan. Online Learning and Information Exponents: the Importance of Batch Size & Time / Complexity Tradeoffs. In International Conference on Machine Learning, pp. 1730–1762. PMLR, 2024.
[5] A. Damian, E. Nichani, R. Ge, and J. D. Lee. Smoothing the Landscape Boosts the Signal for SGD Optimal Sample Complexity for Learning Single Index Models. Advances in Neural Information Processing Systems, Vol. 36, pp. 752–784, 2023.
[6] J. D. Lee, K. Oko, T. Suzuki, and D. Wu. Neural Network Learns Low-dimensional Polynomials with Sgd Near the Information-theoretic Limit. Advances in Neural Information Processing Systems, Vol. 37, pp. 58716–58756, 2024.
[7] L. Arnaboldi, Y. Dandi, F. Krzakala, L. Pesce, and L. Stephan. Repetita Iuvant: Data Repetition Allows Sgd to Learn High-dimensional Multi-index Functions. arXiv preprint arXiv:2405.15459, 2024.
[8] P. Foret, A. Kleiner, H. Mobahi, and B. Neyshabur. Sharpness-aware Minimization for Efficiently Improving Generalization. arXiv preprint arXiv:2010.01412, 2020.
[9] A. Damian, L. Pillaud-Vivien, J. D. Lee, and J. Bruna. Computational-statistical gaps in gaussian single-index models. arXiv preprint arXiv:2403.05529, 2024.
[10] S. Li, I. Zadik, and M. Zampetakis. On the Hardness of Learning One Hidden Layer Neural Networks. In Algorithmic Learning Theory, pp. 700–701. PMLR, 2025.
[11] M. J. Song, I. Zadik, and J. Bruna. On the Cryptographic Hardness of Learning Single Periodic Neurons. Advances in neural information processing systems, Vol. 34, pp. 29602–29615, 2021.