気ままに実装する機械学習

機械学習に興味のある大学院生によるブログです. 機械学習以外のトピック多めです.

交差検定

交差検定

今回はモデル選択の時に使われる手法の、交差検定について軽くまとめて実装して見たいと思います.

交差検定について

訓練とテストに使えるデータには限りがありますが、良いモデルを選択するために得られたデータをできるだけたくさん訓練に使いたいです.

しかし、確認用に使うテストデータが小さいとたまたま良い精度が出てしまったのかもしれないなど、うまく評価ができないことがあります.

そこでよく用いられる手法に交差検定 (cross-validation) がありますので紹介したいと思います.

得られたデータのうち\frac{S-1}{S}の割合のデータを訓練に使い、残りをテストデータに使います.

例えば全データ数がN = 100とし、S = 4とします. この時訓練に使うデータは全体の\frac{4-1}{4}、つまり7.5割(75個)を使い残りの2.5割(25個)をテストデータに使います.

この場合、全データをS分割したことになります. S分割したとき、テストに使える訓練データはS個あるので、学習と評価をS回繰り返すことができます. S回分の精度の平均を学習器の精度として採用することで、比較を行うことができます.

1回目の学習では最初の25個をテストデータ残りを訓練データに、2回目は次のブロックのテストデータを使います. これを繰り返します.

f:id:linearml:20170923033620p:plain

欠点としては、分割数と訓練数が比例することです.

一回の訓練に時間のかかる学習器に対して、交差検定をしようとすると、S回訓練を行うことになるので、膨大な時間がかかることは想像つきますよね...

交差検定の実装例

以下に私が書いたソースコード載せておきます.

分類器はサポートベクタマシン、データセットirisを使っています.

交差検定

16,17,25行目で与えられたデータをシャッフルするためのマスクを定義しています.

28,29行目でs回目のテストデータ、30,31行目では学習に使う訓練データを作成しています.

setdiff1d(a,b)

はaとbの差集合を求めることができるので、それを利用して訓練データとテストデータを分割することにしました.

36,37,38行目で精度(この指標についてはまた今度)を計算して保存しています.

最後にその平均を出力しています.

結果は以下のように得られました.

評価指標
micro precision 0.96
micro recall 0.96
micro F1 0.96