2016年2月5日金曜日

相関係数を用いてmnistのラベルを認識してみる[Python]

線形代数はわかるけど,ディープラーニングとかわかる気がしねぇ(゜∀。)ワヒャヒャヒャヒャヒャヒャ
という人向けです.
各ラベルごとの平均画像との相関関係からmnistのラベルを予想します.

1.使用する画像

MNIST は 0 から 9 の手書き数字が書かれた,70,000 点の白黒画像です.こんな感じ.
画像は28×28の784ピクセル.今回は学習データを評価用のデータに含めちゃダメなので,この画像データ70,000点のうち,60,000点を学習用のデータとし,10,000点を評価用のデータとします。
とは言うものの学習という学習はしないのであまり関係がないけれど.

2.相関係数の求め方

まぁ今回は一応相関係数を求める部分も実装したのでTeXの練習がてら書いときます.普通はnumpy.corrcoef関数とか使えばいいと思います.
次のようなベクトルx,yがあったとします.
このとき,x,yの相関係数は次のようになります.

ただし,xバー,yバーは各要素の平均です.
同じことですが,これはx,yの平均ベクトルを次のように定義してやるともう少し見通しが良くなります.
とすると,相関係数rは
と表すことができ,x,yベクトルから要素の平均値を引いたベクトルの方向余弦(cosθ)になるわけです.(割とこれが言いたいだけの記事だったりする.線形代数の先生がドヤ顔して話してた)

今回作ったプログラムではxは調べたい画像のベクトル,yは0,1,2,3...9の平均画像に当たります.最も似ている(相関がある)ものを探していきます.

3.ソースコード


4.結果

こんな感じ.
[[ 878    0    3    6    4   40   36    0   14    2]
 [   0 1095   16    4    1    6    2    3   33    1]
 [  26   37  732   44   23    1   43   24   45    4]
 [   8   14   35  809    2   45    6   17   73   25]
 [   4   14    5    0  818    1   24    4   17  117]
 [  29   41   13  127   21  554   24    9   36   23]
 [  21   23   22    1    9   25  862    0    4    0]
 [  17   27   15    0   17    1    0  823   18   68]
 [   7   44   14   85    8   31   10    8  728   38]
 [  19   19    9   12   66    4    1   53   33  820]]
             precision    recall  f1-score   support

        0.0       0.87      0.89      0.88       983
        1.0       0.83      0.94      0.88      1161
        2.0       0.85      0.75      0.79       979
        3.0       0.74      0.78      0.76      1034
        4.0       0.84      0.81      0.83      1004
        5.0       0.78      0.63      0.70       877
        6.0       0.86      0.89      0.87       967
        7.0       0.87      0.83      0.85       986
        8.0       0.73      0.75      0.74       973
        9.0       0.75      0.79      0.77      1036

avg / total       0.81      0.81      0.81     10000
統計はよく知らない民なのでprecision,recallでググるとwikipediaの画像が非常にわかりやすいので見ると良いかと思います.まぁ普段生活する中で言う「的中率」というのはrecallだと思う….

幸いprecision,recall,f1-scoreも81%なので,とりあえず結果は81%ってことですね.mnistの公式サイトを見ると圧倒的最下位ですね.現状1位はConvolutional neural networkの99.77%ッスか.もうこの域まで来ると人よりも精度がいいらしいのでもう関係ないような気がしますがね.

5.感想とか

こうしてみるとやっぱりニューラルネットすげーってなるわけですが,もう少しニューラルネットを使わずに精度を上げられないかと思われるわけですが,ミスした画像を見ればどうすればいいかはある程度わかります.
一番左は平均画像.周りが判定ミスをした画像たちです.数字が書かれている位置が違っていたり,ある程度傾いていたりすると平均画像との相関係数だけだと限界があるわけですね.
もし更に精度を上げるには傾きや位置によらない特徴量を抽出してやる必要がありそうです.その特徴量ベクトルに対して相関係数を用いるのもよし,ユークリッド距離を用いるのもよし,別の手法でやってみるのも良しです.

ちなみにKaggleにはDigit Recognizerというのがあって,Forumとかを見れば結構わかりやすいコードだとかがありますし,自分で実装してmake a submissionで結果を送信するのも楽しいです.
でも大体みんなForumにある頭のいい人が書いたソースコードを実行して送信してるだけみたいな人が多いのでオリジナルのコードで上位に食い込むにはやっぱりディープラーニングとかを実装するしかない気がします.

6.参考

[1] MNIST handwritten digit database, Yann LeCun, CorinnaCortes and Chris Burges
[2] 酒井幸市 著,『画像処理とパターン認識入門』,森北出版株式会社
[3] 多層パーセプトロンで手書き数字認識 - 人工知能に関する断創録