ほくそ笑む

R言語と統計解析について

SVM のチューニングのしかた(1)

SVM のチューニング

SVM(Support Vector Machine) はみなさん御存じ機械学習の手法です。
SVM はデフォルト設定でモデルを作ってもしょうがないです。gamma と cost というパラメータがあるので、これらの値に最適値を設定しなければなりません。R の SVM の Help にもこう書いてあります。

Parameters of SVM-models usually must be tuned to yield sensible results!
(訳) SVM でいい結果出したかったらチューニングしろよな!

というわけで、SVM のチューニングのしかたについて説明したいと思います。

交差検証

おっと、その前に、交差検証の話をしなければなりません。
SVM モデルをチューニングする際、二つのパラメータでグリッドサーチをします。
すなわち、パラメータをいろいろ変えてみて、一番いい SVM モデルとなる組合せを選ぶのです。
この「一番いい SVM」というのはどうやって評価すればよいでしょうか?
この SVM の評価方法として良く使われるのが交差検証です。
交差検証を説明するために、SVM を評価する単純な方法から順を追って説明していきたいと思います。

単純な方法

まず考えられるのは、SVM を作るのに使った学習データを再び評価にも用いるという方法です。*1
学習に使ったデータを判別してみて、うまく判別できた割合(正答率)を出します。
しかし、これはうまいやり方ではありません。
学習に使ったデータがうまく判別できるのは、当たり前のことだからです。
我々がやりたいことは、SVM で学習データ(既知データ)を判別したいのではなく、まだ手元に無いデータ(未知データ)をうまく判別できるかどうかで SVM を評価したいのです。
そこで、次の方法が考えられます。

ホールドアウト検証

ホールドアウト検定は、SVM の評価方法としてはかなり単純です。
学習に使用するために集めたデータのうち、何割かを評価専用のデータにするのです。
このとき、評価に使うデータは学習データに含めてはいけません。
つまり、集めたデータを学習専用と評価専用の二つに分け、学習用データで SVM を作成し、評価用データで評価する、ということです。
これにより、作成した SVM が、どのくらい未知データを判別できるのかが評価できます。
これがホールドアウト検証です。
ただし、ホールドアウト検証には次のような欠点があります。

  • 集めたデータを学習専用と評価専用に分けるため、学習に使えるデータの量が減る。
  • データを学習用と評価用に分ける際、分け方に偶然偏りができた場合にはうまく評価ができない。

これら二つの欠点をうまく解決しているのが、次に説明する交差検証という手法です。

交差検証

交差検証では、学習に使うために集めたデータをいくつかに分割します。
いま、例えば 150サンプルのデータを集めたとします。これを 50サンプルづつ、3個のグループに分割したときのことを考えてみましょう。
まず、第1グループと第2グループの合計 100 サンプルを学習データとして SVM を作成します。
そして、残りの第3グループの 50 サンプルを評価用データとして正答率を出します。
次に、第1グループと第3グループを学習データ、第2グループを評価用データとして正答率を出します。
最後に、第2グループと第3グループを学習データ、第1グループを評価用データとして正答率を出します。
これを表にすると次のようになります。

第1グループ 第2グループ 第3グループ
1回目 学習 学習 評価
2回目 学習 評価 学習
3回目 評価 学習 学習

こうして出された3つの正答率の平均値を、150サンプル全体を学習データとして作成したときの SVM の正答率とみなすのが交差検証という手法です。
3つのグループに分けたということを明示したいときは、3-交差検証といいます。
一般に、k個のグループに分割する交差検証のことを k-交差検証といいます。
この手法には次の特徴があります。

  • 未知データに対して正答率を出している。
  • 最終的な正答率は学習データ全体を使用して作成された SVM に対するものである。
  • 学習用データと評価用データに偶然偏りができても、平均を取ることによって影響を小さくできる。

ホールドアウト検証の欠点をうまく解決していることがわかると思います。

実際にやってみよう

さて、R で交差検証をやるのは、実に簡単です。実際にやってみましょう。

library(e1071)
data(iris)
model <- svm(Species ~ ., data = iris, cross=3)
summary(model)

とまあ、svm の引数に cross=3 と入れれば、3-交差検証が行われ、結果を summary() で見ることができます。

3-fold cross-validation on training data:
Total Accuracy: 94.66667
Single Accuracies:
98 90 96

ちなみに、cross-validation が交差検証、accuracy が正答率です。Single Accuracies がそれぞれの試行の正答率で、Total Accuracy がそれらの平均値になっているのがわかりますね。
ところで、iris は 150 サンプルあるので、グループ 3つでは少ない気がします。
いくつのグループに分けるかで評価も変わってくるので、ここは重要です。
僕がいつも使っているのは次の数式です*2。サンプル数を n とすると、

k = 1 + log(n)/log(2)

iris は 150 サンプルあるので、

k = 1 + log(150)/log(2) = 8.2288...

というわけで、k = 8 でやってみることにします。

model <- svm(Species ~ ., data = iris, cross=8)
summary(model)

8-fold cross-validation on training data:
Total Accuracy: 95.33333
Single Accuracies:
94.44444 89.47368 100 89.47368 100 94.73684 94.73684 100

これで、iris データに対して、SVM のデフォルト設定での正答率は 95.3 ということが求まりました。
次回は、この正答率を上げるべく、SVM をチューニングしていきたいと思います。

追記

Wikipedia の項目が交差検定から交差検証と変更されたため、それに合わせて記事内の用語も交差検証に変更しています。

*1:プラグイン法というらしい

*2:個人的に使っている式です。理論的な背景はまったくありません。