カーネルリッジ回帰の数式的な説明とその実装例について

カーネルリッジ回帰の導出

状況の設定

\(n\)個の入力変数と出力変数のペアの間にノイズを含む以下の関係があるとする。

\[ y_i=f^*(x_i)+\varepsilon_i\quad(i=1,\dots,n)\]

ここで\(\varepsilon\)はノイズを表す変数で\(\mathbb{E}[\varepsilon_i]=0\), \(\mathrm{Var}[\varepsilon_i]=\sigma^2\)を満たす同一の分布に独立に従っているとする(i.i.d.)。ガウスノイズ\(\varepsilon\sim N(0,\sigma^2I_n)\)とする場合も多いがどちらでもよい。

ここでは\(\{x_1,\dots,x_n\}\)は固定されており、出力にノイズが乗るモデルを考えていてこれをfixed designモデルという場合がある。random designの場合にはさらに\(\{x_1,\dots,x_n\}\)もある入力の分布\(P_X\)に従って毎回生成されていてその上で入出力関係に従って\(y_i\)が生成されると考える。

今はfixed designであり\(\{x_1,\dots,x_n\}\)は一貫して変化しないとするので\(X=(x_1^\top,\dots,x_n^\top)^\top\)というようにそれらを各行に並べて\(n\)行の行列を考えてそれをdesign matrix、計画行列と言ったりもする。ちなみに以下計画行列は出てこない。

ただし、圧縮センシングなどでは出力にさえノイズを乗せないものをfixedと言ったりもするのでこの辺りの呼び方には注意が必要。

仮説空間をRKHSに設定

関数\(f^*\)は真の入出力関係を表す関数であり、\(n\)の観測からその関数を推定したい。そのためにまず仮説空間という、\(f^*\)が含まれているであろう関数のクラスを考えて、その中からデータに最もフィットしている関数\(\widehat f\)を選び出すということをする。

カーネルリッジ回帰では、そのような仮説空間\(\mathcal{F}\)として、ある正定値カーネルが定めるRKHSを取る。そこで、考えるカーネル関数を\(k\)、それに対応する再生核ヒルベルト空間を\(\mathcal{H}\)と書いておく。\(\mathcal{F}=\mathcal{H}\)である。

この中からデータに最もフィットする関数を選ぶためには以下を考えるのが自然だろう。

\[ \begin{align} \widehat f &=\mathop{\rm argmin}_{f\in\mathcal{H}} \left[ \frac1n\sum_{i=1}^n\frac12(f(x_i)-y_i)^2 \right]\\ &=\mathop{\rm argmin}_{f\in\mathcal{H}} \left[ \frac1n\sum_{i=1}^n\frac12(\langle f,k(x_i,\cdot)\rangle_\mathcal{H}-y_i)^2 \right] \end{align}\]

突然2行目が出てきたがこれははRKHSの元における再生性によるもので、意味的には1行目と同じなので難しいことは無い。単に平均2乗誤差が最も小さくなるような関数を\(\mathcal{H}\)の中から選び出すだけだ。\(1/n\)\(1/2\)は最適化の結果に影響を及ぼさないが慣例でこうする。

ノイズを考慮して修正

ところがこのスキームには問題がある。RKHSにはデータ点の数だけの自由度があるので\(n\)個のデータ点全てに対して誤差ゼロでフィットしてしまう。つまり補間になり、場合によってはドラスティックな変動をともなう関数になってしまう。いま観測値には誤差が含まれているので実際のところ以下を満たしていれば十分だろう。

\[ (f(x_i)-y_i)^2\le\sigma^2\quad(i=1,\dots,n)\]

右辺は今ノイズの分散\(\sigma^2\)としているが、これ以外に選択肢はないという意味ではない。例えば\(\sigma^2/2\)としてより厳しい条件を与えてもかまわない。しかし、ノイズレベルに合わせて関数値とのずれ具合を制限するやり方には納得がいくはずだ。

この条件の下で\(\|f\|^2_\mathcal{H}\)が最も小さいような\(f\)を選ぶ戦略はサポートベクター回帰といわれる。カーネルという言葉は含まれていないがサポートベクター回帰ではカーネル関数を考えることがほぼ前提となっている。

ここではやや異なる以下の条件を考える。

\[ \frac1n\sum_{i=1}^n\frac12(f(x_i)-y_i)^2\le \frac{\delta^2}2\]

それぞれのデータ点ごとの条件と平均二乗誤差に対する条件なのでさきほどの条件とは異なる。この条件の下で\(\|f\|_\mathcal{H}^2\)が最も小さいような\(f\)を選ぶという方法が考えられる。\(\delta^2\)はさきほどと同様にノイズレベルに応じて適当に決めればよいがその意味はさきほどとは異なる。ゆえに違う文字を使っている。

カーネルリッジ回帰の問題

カーネルリッジ回帰では同じような考え方のもと以下を考える。

\[ \min_{f\in\mathcal{H}} \frac1n\sum_{i=1}^n\frac12(f(x_i)-y_i)^2\;\text{s.t.}\;\|f\|_\mathcal{H}^2\le R\]

目的関数と条件が入れ替わっている。これはラグランジュ乗数\(\lambda\)を用いて以下のように表現できる。

\[ \begin{align} &\min_{f\in\mathcal{H}}\left[ \frac1n\sum_{i=1}^n\frac12(f(x_i)-y_i)^2+\lambda\|\mathcal{f}\|_\mathcal{H}^2 \right]\\ =&\min_{f\in\mathcal{H}}\left[ \sum_{i=1}^n\frac12(f(x_i)-y_i)^2+n\lambda\|\mathcal{f}\|_\mathcal{H}^2 \right] \end{align}\]

これは、普通のリッジ回帰の問題と同じ形式をしており、このことをカーネルリッジ回帰と呼ぶ。ラグランジュ乗数として導入した\(\lambda\)はリッジ正則化のパラメータといわれる。まとめるとカーネルリッジ回帰の問題は、

\[ \begin{align} \widehat f&=\mathop{\rm argmin}_{f\in\mathcal{H}}\left[ \sum_{i=1}^n\frac12(f(x_i)-y_i)^2+n\lambda\|\mathcal{f}\|_\mathcal{H}^2 \right]\\ &=\mathop{\rm argmin}_{f\in\mathcal{H}}\left[ \sum_{i=1}^n\frac12(\langle f,k(x_i,\cdot)\rangle_\mathcal{H} -y_i)^2 +n\lambda \langle f, f\rangle_\mathcal{H} \right] \end{align}\]

を求めることである。最後の行は敢えてRKHSにおける内積演算のみで表してある。

カーネルリッジ回帰の問題の解

カーネルリッジ回帰の問題は一見して関数空間の中で関数を動かすよくわからない問題となっている。実際には、Representer定理により次のような形式の\(f\)の中から最適な関数を選んでおけばそれが大域的最適化になることがわかっている。

\[ f=\sum_{i=1}^n\alpha_ik(x_i,\cdot)\]

これで関数空間内で関数を動かすという最適化を係数\(\alpha_1,\dots,\alpha_n\)を動かすという最適化に変えることが出来る。まず第1項にこれを代入すると、

\[ \begin{align} \sum_{i=1}^n(\langle f,k(x_i,\cdot)\rangle_\mathcal{H}-y_i)^2 &=\sum_{i=1}^n\left(\left\langle \sum_{j=1}^n\alpha_jk(x_j,\cdot) ,k(x_i,\cdot)\right\rangle_\mathcal{H}-y_i\right)^2\\ &=\sum_{i=1}^n\left( \sum_{j=1}^n\alpha_j\left\langle k(x_j,\cdot) ,k(x_i,\cdot)\right\rangle_\mathcal{H}-y_i\right)^2\\ &=\sum_{i=1}^n\left( \sum_{j=1}^n\alpha_jk(x_i,x_j)-y_i\right)^2 \end{align}\]

次に第2項にこれを代入すると、

\[ \begin{align} \langle f, f\rangle_\mathcal{H} &=\left\langle \sum_{i=1}^n\alpha_ik(x_i,\cdot), \sum_{j=1}^n\alpha_jk(x_j,\cdot)\right\rangle_\mathcal{H}\\ &=\sum_{i=1}^n\sum_{j=1}^n\alpha_i\alpha_j \langle k(x_i,\cdot),k(x_j,\cdot)\rangle_\mathcal{H}\\ &=\sum_{i=1}^n\sum_{j=1}^n\alpha_i\alpha_jk(x_i,x_j) \end{align}\]

となる。いずれも\(k(x_i,x_j)\)というのが登場しており、この値を要素にもつ\(n\times n\)行列\(K\)\(\boldsymbol{\alpha}=(\alpha_i)_{i=1}^n\)考えると次の簡単な形式によってあらわせる(これをグラム行列という)。

\[ \begin{align} \widehat{\boldsymbol{\alpha}} &=\mathop{\rm argmin}_{\boldsymbol{\alpha}\in\mathbb{R}^n} [(K\boldsymbol{\alpha}-\boldsymbol{y})^\top (K\boldsymbol{\alpha}-\boldsymbol{y}) +n\lambda\boldsymbol{\alpha}^\top K\boldsymbol{\alpha}] \end{align}\]

この問題は凸最適化になっており解は陽に求まる。

\[ \boldsymbol{\alpha}=(K+n\lambda I_n)^{-1}\boldsymbol{y}\]

これにより求めたかった関数は次のようにあらわされる。

\[ \widehat f(x)=\boldsymbol{k}(x)^\top(K+n\lambda I_n)^{-1}\boldsymbol{y}\]

ここに、\(\boldsymbol{k}(x)=(k(x_i,x))_{i=1}^n\)というベクトルで\(x\)によって決まる。以上でカーネルリッジ回帰の解を求めることが出来た。

カーネルリッジ回帰の特徴

訓練、推論ともに計算量が大きい。なんの工夫なしには実用上の\(n\)が大きい問題に対しては使えない。逆行列の計算に\(O(n^3)\)が必要になる。新たな\(m\)点に対する推論に\(O(mn)\)の計算が必要になる。とにかく扱うデータが多ければ多いほどもりもり計算量が大きくなっていく手法だといえる。これに関しては工夫がいろいろあるようだ。

次にカーネル関数の選択が難しい。リッジパラメータをCVで調整するだけならともかく考えうる正定値カーネルはいくらでも存在しており、選び方の指針も十分に確立されているとは言えないため。さらにカーネル関数がハイパーパラメータを含んでいる場合などもある。

理論的には\(f^*\in\mathcal{H}\)ととれるか\(f^*\notin\mathcal{H}\)となってしまうかの2択のように見えるが実際\(f^*\)なんてものがどういう関数なのか数値実験の場合を除いてわからない。

Numpyによる計算例

カーネルリッジ回帰の計算例

これは入力出力ともに1次元の場合に、\(f^*(x)=x(x-1)(x-3/4)/2\)という3次多項式に対して様々なカーネル関数でカーネルリッジ回帰を行った例である。左上から順番に、

  • 2次の多項式カーネル
  • 3次の多項式カーネル
  • Sobolevカーネル
  • ガウシアンカーネル

となっている。データ点数\(n=20\), ノイズレベルは\(\sigma=1/200\)としている。入力のデータ点については\([0,1]\)の範囲を\(n\)等分し、区間内に各1点を一様分布によって生成している。

意地悪なことに\([0,3/2]\)の範囲を表示しているが後半\([1,3/2]\)部分にはデータ点が存在しないため外挿が必要になる。いかにカーネル関数の設定による仮説空間の設定が真の関数の出やすさにマッチしているかが、この部分のフィッティングにおいては試されているといえるだろう。

予想される通り、真の関数が3次多項式であるので3次の多項式カーネルが全ての区間でよいマッチングを達成している。ガウシアンカーネルも全ての多項式成分を持つ特徴空間を生成することから悪くない。しかし実際にはハイパーパラメータの調整が必要になるのでだからといっていつで3次多項式とわかっているなら多項式カーネルの方が手間はかからない。

完全にリッジ正則化パラメータをゼロとしている例は示していないが、この場合は数値的不安定性から解が発散することがあるので代わりに十分に小さな正則化パラメータを「正則化無し」の場合と見なすのが普通だ。この場合には得られる関数はジャギジャギとしてしまっている。

逆に正則化パラメータが大きすぎる場合にはいずれの場合も真の関数の性質をキャプチャできているとは言えない状態になっている。Sobolevカーネルだけは単なる補間のようになっていて特殊だ。なぜこうなるか考えてみよう。

用いたPythonコード

上の実験を行って図まで作成する完全なPythonコードを以下に示しておく。

import numpy as np
import matplotlib.pyplot as plt

def Ftrue(x):
    return x*(x-1)*(x-3/4)/2

def k1(x,y):
    return (1+np.inner(x,y))**2

def k2(x,y):
    return (1+np.inner(x,y))**3

def k3(x,y):
    return 1+min(x,y)

def k4(x,y):
    return np.exp(-(x-y)**2/5)

def ker_ridge(X, Y, k, xt, lmd):
    K = np.array([[k(X[i],X[j]) for i in range(n)] for j in range(n)])
    a = np.linalg.solve(K+lmd*np.eye(n), Y)
    yp = np.array([np.sum([a[i]*k(X[i], x) for i in range(n)]) for x in xt])
    return K, a, yp

n = 20
sigma = 1/200

xt = np.linspace(0,1.5,150)
yt = Ftrue(xt)

rs = np.random.RandomState(10)
X = np.linspace(0,1,n+1)[:-1] + rs.uniform(0,1,size=n)/n
Y = Ftrue(X) + rs.randn(n)*sigma

fig, ax = plt.subplots(2,2,figsize=(8,8))

ks = [k1, k2, k3, k4]
axs = [ax[0,0], ax[0,1], ax[1,0], ax[1,1]]
for k, axx in zip(ks, axs):
    axx.set_ylim([-0.02, 0.06])
    axx.plot(xt, yt, lw=0.8, ls='--')
    axx.scatter(X, Y, color='red', s=3)
    for lmd in [1e-14, 1e-6, 1e-3, 1e-1]:
        K, a, y = ker_ridge(X, Y, k, xt, lmd)
        axx.plot(xt, y, lw=1, label='$n\lambda$=%.0e' % lmd)
        axx.legend()

ax[0,0].set_title('$k(x,y)=(1+x\cdot y)^2$')
ax[0,1].set_title('$k(x,y)=(1+x\cdot y)^3$')
ax[1,0].set_title('$k(x,y)=1+\min(x,y)$')
ax[1,1].set_title('$k(x,y)=\exp(-\|x-y\|^2/5)$')

fig.suptitle('kernel ridge regression $f^*(x)=x(x-1)(x-3/4)/2$')
fig.tight_layout(rect=[0, 0.03, 1, 0.95])

plt.show()
plt.close(fig)

そんなに効率よく書いていなくても50行程度でこれくらいの実験ができるPython+Numpyは素晴らしい。

参考文献

Martin J. Wainwright, High-dimensional Statistics: A Non-asymptotic Viewpoint. p. 407-408. Fitting via kernel ridge regression.