読者です 読者をやめる 読者になる 読者になる

TensorFlowを使用して手書きの数字を学習して判別する


 最近、オープンソース化して話題になっているTensorFlowを使用して多くの人たちが面白いことをしているので、自分は機械学習をやったことがなかったので流行りだし触ってみようかと思い、やってみたら早速つまずいたりしたので、ブログとして取り上げてみました。japanese.engadget.com

環境構築
 TensorFlowはPythonにより動くため、まずPythonのインストールが必要です。MacPythonを使用・管理する場合は、pyenvがオススメです。
 Pythonには2.系と3.系というややこしい内輪揉めがあるため、pyenvにどちらのバージョンを入れ管理することができるのでオススメしています。
 Pythonのインストールはこちらを参考にして導入しました。breakbee.hatenablog.jp
dev.classmethod.jp

 そしてその後にpipというものをインストールするのですが、そこで問題発生。
error 13 permission denied”というものが出てきてpipのインストールができなかったため、以下のサイトを参考にしました。d.hatena.ne.jp
結局、Macの既存のPythonのライブラリーが邪魔をしてきたみたいだったので、

# 従来のPythonの削除
$ rm -rf /Library/Python/*

# Pythonのバージョンをpyenvの2.7.10に指定
$ pyenv global 2.7.10

# pipのインストール
$ easy_install pip

すると、うまくpipのインストールができました。
その後、チュートリアルにあるようにインストールをしてもらうと

 pip install https://storage.googleapis.com/tensorflow/mac/tensorflow-0.5.0-py2-none-any.whl

TensorFlowがマシンにインストールできます。

Hello, World!
実際に、TensorFlowが動いているか確かめてみましょう。
試しに、画面上にHello, World!を表示してみました。

#pythonの起動
$ python
>>> import tensorflow as tf
>>> hello = tf.constant('Hello, World!')
>>> sess = tf.Session()
>>> print sess.run(hello)
Hello, World!

ちゃんと動いていることが確認できますね!

手書きの数字を学習して判別を行う
手書きの数字の画像データを利用して、識別できるようにしてみましょう。下記のような画像から、その数字は何を表しているのか、機械に学習してみます。
f:id:yuta-horn:20151120131235p:plain
このような学習のことをMNISTと呼びます。

MNISTはチュートリアルとして既にGithub上にすべて公開されているため、それを利用してみましょう。

# Githubから取得
$ git clone --recurse-submodules https://github.com/tensorflow/tensorflow

# 実行
$ python tensorflow/g3doc/tutorials/mnist/fully_connected_feed.py
Traceback (most recent call last):
  File "fully_connected_feed.py", line 23, in <module>
    from tensorflow.g3doc.tutorials.mnist import input_data
ImportError: No module named g3doc.tutorials.mnist

実行してみるとあれれ、エラーが出てしまっていますね。
そこで、fully_connected_feed.pyの1部を修正してみます。

#削除
23行目 from tensorflow.g3doc.tutorials.mnist import input_data
24行目 from tensorflow.g3doc.tutorials.mnist import mnist
#追加
23行目 import input_data
24行目 import mnist

再度実行してみましょう。

$ python fully_connected_feed.py
Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting data/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting data/train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting data/t10k-labels-idx1-ubyte.gz
can't determine number of CPU cores: assuming 4
I tensorflow/core/common_runtime/local_device.cc:25] Local device intra op parallelism threads: 4
can't determine number of CPU cores: assuming 4
I tensorflow/core/common_runtime/local_session.cc:45] Local session inter op parallelism threads: 4
Step 0: loss = 2.29 (0.034 sec)
Step 100: loss = 2.15 (0.005 sec)
....
Training Data Eval:
  Num examples: 55000  Num correct: 49288  Precision @ 1: 0.8961
Validation Data Eval:
  Num examples: 5000  Num correct: 4528  Precision @ 1: 0.9056
Test Data Eval:
  Num examples: 10000  Num correct: 9017  Precision @ 1: 0.9017

おぉーできています!
最終的には約90%の制度で識別することができたようです。