ほくそ笑む

R言語と統計解析について

交差検証の k の値はどれくらいにすればいいのか

分類器(識別器)のモデルを評価する手法に交差検証(クロスバリデーション)があります。
交差検証を行うには、データをいくつに分割するかを表す k の値を決めてあげなければなりません。
SVM のチューニングのしかた(1) において、交差検証の k の値を決めるとき、僕は個人的に

k = 1 + log(n)/log(2)

という式を用いていると書きました。
この式は、知っている人ならわかると思いますが、スタージェスの公式です。
スタージェスの公式は、ヒストグラムを描く際にサンプル数から階級数を決めるのに便利な公式です。
しかし、この公式を交差検証の k を決める際に使用するのは、はっきりいって根拠がありません。
そこで、今日は交差検証の k の値をどのくらいにすれば良いのかについて考えてみたいと思います。

準備(予備知識)

k の値は大きければ大きいほど、正確にモデルを評価できます。
k の最大値はサンプル数ですので、「k = サンプル数」とした場合が最も正確な評価だと言えます*1
というわけで、できれば「k = サンプル数」でやりたいところですが、k が大きければそれだけ実行に時間がかかるので、k はなるべく小さい値にしなければなりません。
交差検証の原理上、「k = サンプル数」とした場合、モデルの評価値(正答率)は、交差検証を繰り返したとしても常に同じ値になります。
逆に、「k < サンプル数」の場合は、交差検証を何度かやると、そのたびに違った値が返ってきます。
この違いのばらつき具合(分散)は、k が小さいほど顕著になると考えられます。
そこで、k をいろいろ変えてみて、一つの k に対して交差検証を繰り返し、そのときの分散を調べてみると何かがわかるのではないかと思いました。

シミュレーション

上で書いた実験を、下記のような R のコードでシミュレーションしてみました。

library(e1071)
data(iris)

sds <- c()
means <- c()
cat("k-fold", "Mean", "SD", "\n", sep="\t")
range <- seq(5, 150, by=5)
for(k in range) {
  accs <- c()
  for(j in 1:100) {
    model <- svm(Species ~ ., data = iris, cross=k)
    accs <- c(accs, model$tot.accuracy)
  }
  mean <- mean(accs)
  sd <- sd(accs)
  cat(k, mean, sd, "\n", sep="\t")
  means <- c(means, mean)
  sds <- c(sds, sd)
}

データとしては以前に使った iris を使用しました。k の値を 5, 10, 15, ... というように増やしていって、150(iris のサンプル数)まで調べています。
一つの k に対して、交差検証を 100 回繰り返し、正答率の平均(mean)と標準偏差(sd)を出しています。
結果は下記のような感じで出てきます。

k-fold Mean SD
5 94.73333 1.005038
10 95.86 0.6162411
15 96.18 0.4910421
20 96.27333 0.435207
25 96.43333 0.3834395
... ... ...
150 96.66667 0

グラフ

さて、標準偏差のグラフを描いてみます。

plot(range, sds)
sp <- smooth.spline(range, sds)
pred <- predict(sp, seq(5, 150, length=20))
lines(pred)
abline(v=(1+log2(150))*4, col="red")


やはり k が小さいほど標準偏差は大きいことが見て取れます。
また、標準偏差は k が大きくなるにつれ減少していきますが、k の小さい範囲での減少率と k がある程度大きくなってからの減少率では違いが見えます。
グラフ中に赤い線を引いてありますが、これは

k = (1 + log(n)/log(2)) * 4

のところに引いています。僕が思うに、ここら辺が減少率の境目かなーというところです。
もうひとつ、グラフを作ってみましょう。

plot(range, means, ylim=c(93, 97))
sp <- smooth.spline(range, means)
pred <- predict(sp, seq(5, 150, length=100))
lines(pred)

pred <- predict(sp, range)
ucl <- pred$y + sds
lcl <- pred$y - sds
segments(range, ucl, range, lcl)
abline(v=(1+log2(150))*4, col="red")


これは、正答率の平均値に、エラーバーとして標準偏差を付け加えたものです。
このグラフからも、k が小さい間は正確な評価ができていないのに対し、ある程度大きくなると横ばいになっている様子が見て取れます。

まとめ

上で見たように、サンプル数に対して k が小さすぎると、正答率の誤差が大きくなります。
モデル比較の際には、この誤差に気をつける必要があります。
安易に「こっちのモデルの方が正答率が 1% 高かったので採用する」などというのはナンセンスです。
なぜなら k が小さすぎて誤差が 1% 以上あるかもしれないからです。

*1:LOOCV(leave-one-out cross-validation) と言います