ほくそ笑む

R言語と統計解析について

SVM のチューニングのしかた(2)

さて、前回は交差検証の説明で終わってしまいましたが、今回はちゃんと SVM のチューニングの話をします。
チューニングの手順としては、

  1. グリッドサーチで大雑把に検索する。
  2. 最適なパラメータがありそうなところを絞って再びグリッドサーチを行う。

という2段階のグリッドサーチを行います。

1段階目:グリッドサーチで大雑把に検索する

SVM のチューニングは tune.svm() という関数を用いて行います。
チューニングのやり方は、単純にグリッドサーチを行っているだけです。
パラメータの値をいろいろ変えてみて、正答率の一番いい値をベストパラメータとして出力します。
プログラムは下記のようになります。

gammaRange = 10^(-5:5)
costRange = 10^(-2:2)
t <- tune.svm(Species ~ ., data = iris, gamma=gammaRange, cost=costRange,
              tunecontrol = tune.control(sampling="cross", cross=8))
cat("- best parameters:\n")
cat("gamma =", t$best.parameters$gamma, "; cost =", t$best.parameters$cost, ";\n")
cat("accuracy:", 100 - t$best.performance * 100, "%\n\n")
plot(t, transform.x=log10, transform.y=log10)

まずはグリッドサーチするために、gamma と cost の範囲を決めています。ここでは gamma は 10^{-5}10^5 まで、cost は 10^{-2}10^2 までを指定しています。
次に tune.svm() でグリッドサーチを行います。svm() と同じような使い方ですが、交差検証を行うために tunecontrol = tune.control(sampling="cross", cross=8) と指定しています。これで 8-交差検証で評価を行うということになります。
あとはグリッドサーチの結果としてベストパラメータを出力しています。上記のプログラムを実行すると、結果は下記のように出ました。

- best parameters:
gamma = 0.1 ; cost = 1 ;
accuracy: 97.33187 %

gamma=0.1, cost=1 の組合せのとき、正答率 97.3 % を出していることがわかります。
ただし、1段階目ではベストパラメータの値はそれほど重視しません。
重要なのは最後の行、plot() 関数でグリッドサーチの結果を等高線図にしています。
等高線図は次のようになりました。

この図の中で色の濃い部分に最適なパラメータがありそうです。
そこで、gamma=10^{-1},cost=10^{0.5} のあたりと gamma=10^{-1.5},cost=10^{1.5} のあたりを再びグリッドサーチで調べてみましょう。これが2段階目となります。

2段階目:最適なパラメータがありそうなところを絞って再びグリッドサーチ

まずは gamma=10^{-1},cost=10^{0.5} のあたりをグリッドサーチしてみましょう。

gamma <- 10^(-1)
cost  <- 10^(0.5)
gammaRange <- 10^seq(log10(gamma)-1,log10(gamma)+1,length=11)[2:10]
costRange  <- 10^seq(log10(cost)-1 ,log10(cost)+1 ,length=11)[2:10]
t <- tune.svm(Species ~ ., data = iris, gamma=gammaRange, cost=costRange,
              tunecontrol = tune.control(sampling="cross", cross=8))
cat("[gamma =", gamma, ", cost =" , cost , "]\n")
cat("- best parameters:\n")
cat("gamma =", t$best.parameters$gamma, "; cost =", t$best.parameters$cost, ";\n")
cat("accuracy:", 100 - t$best.performance * 100, "%\n\n")
plot(t, transform.x=log10, transform.y=log10, zlim=c(0,0.1))

around [gamma = 0.1 , cost = 3.162278 ]
- best parameters:
gamma = 0.06309573 ; cost = 5.011872 ;
accuracy: 97.33187 %

gamma, cost, 正答率が出ました。
次は、gamma=10^{-1.5},cost=10^{1.5} のあたりをグリッドサーチしてみましょう。

gamma <- 10^(-1.5)
cost  <- 10^(1.5)
gammaRange <- 10^seq(log10(gamma)-1,log10(gamma)+1,length=11)[2:10]
costRange  <- 10^seq(log10(cost)-1 ,log10(cost)+1 ,length=11)[2:10]
t <- tune.svm(Species ~ ., data = iris, gamma=gammaRange, cost=costRange,
              tunecontrol = tune.control(sampling="cross", cross=8))
cat("[gamma =", gamma, ", cost =" , cost , "]\n")
cat("- best parameters:\n")
cat("gamma =", t$best.parameters$gamma, "; cost =", t$best.parameters$cost, ";\n")
cat("accuracy:", 100 - t$best.performance * 100, "%\n\n")
plot(t, transform.x=log10, transform.y=log10, zlim=c(0,0.1))

around [gamma = 0.03162278 , cost = 31.62278 ]
- best parameters:
gamma = 0.05011872 ; cost = 5.011872 ;
accuracy: 98.64766 %

こうして2回グリッドサーチをやりましたが、1回目の正答率は 97.3%、2回目の正答率は 98.6% と2回目の方が高いです。
というわけで、最適なパラメータとして、gamma = 0.05011872, cost = 5.011872 を選ぶことにします。

おわりに

以上で、SVM のチューニングが終わりました。上記のパラメータで SVM を作成するには下記のようにします。

gamma = 0.05011872 ; cost = 5.011872 ;
model <- svm(Species ~ ., data = iris, gamma=gamma, cost=cost)

これで、学習データを判別してみると

pred <- predict(model, iris)
table(pred, iris[,5])
pred         setosa versicolor virginica
  setosa         50          0         0
  versicolor      0         49         1
  virginica       0          1        49

本に載っているより、判別精度が上がっていることもわかります。
以上です。

追記

Wikipedia の項目が交差検定から交差検証に変更されたため、それに合わせて記事内の用語も交差検証に変更しています。