ざっくりML

機械学習に興味ある大学院生によるブログです.

交差検定

今回はモデル選択の時に使われる手法の、交差検定について軽くまとめて実装して見たいと思う.

交差検定

訓練とテストに使えるデータには限りがあり、良いモデルを選択するために得られたデータをできるだけたくさん訓練に使いたい. しかし、確認用に使うテストデータが小さいとたまたま良い精度が出てしまったのかもしれないなど、うまく評価ができない. そこでよく用いられる手法に交差検定 (cross-validation) がある. 得られたデータのうち\frac{S-1}{S}の割合のデータを訓練に使い、残りをテストデータに使う. 例えば全データ数が[N = 100]とし、S = 4とする. この時訓練に使うデータは全体の\frac{4-1}{4}、つまり7.5割(75個)を使い残りの2.5割(25個)をテストデータに使う. この場合全データをS分割したことになる. したがって、テストに使える訓練データはS個あるので、学習と評価をS回繰り返す. S回分の精度の平均を学習器の精度として採用することで、比較を行うことができる. 1回目の学習では最初の25個をテストデータ、残りを訓練データに、2回目は次のブロックのテストデータを使う. これを繰り返す.
f:id:linearml:20170923033620p:plain 欠点としては、分割数と訓練数が比例することである. 一回の訓練に時間のかかる学習器に対して、交差検定をしようとすると、S回訓練を行うことになるので、膨大な時間がかかることは想像つく.
以下に私が書いたソースコード載せる. 分類器はサポートベクタマシン、データセットはirisを使う.

交差検定

16,17,25行目で与えられたデータをシャッフルするためのマスクを定義している. 28,29行目でs回目のテストデータ、30,31行目では学習に使う訓練データを作成している. setdiff1d(a,b)はaとbの差集合を求めることができるので、それを利用して訓練データとテストデータを分割することにした. 36,37,38行目で精度(この指標についてはまた今度)を計算して保存しておく. 最後にその平均を出力することにしている

結果は以下のように得られた.
micro precision : 0.96
micro recall : 0.96
micro F1 : 0.96