Contents
以下の手書き数字認識では、MNISTデータセットという手書き数字画像のデータセットを用います。
MNISTデータセットは以下のような画像データセットです。
今回使用する手書き数字認識で用いられる学習データやパラメータは、「ゼロから作るDeep Learning」に載っているものを参考にしています。
参照元[https://github.com/oreilly-japan/deep-learning-from-scratch]
バッチ処理の利点
バッチ処理とは、複数の入力を大きな配列とし処理することで効率よく計算できるようにする処理のことである。
バッチ処理を用いることで、1枚あたりの処理時間を大幅に短縮できるという利点がある。
その理由は、数値計算を扱うライブラリの多くは、大きな配列の計算を効率良く処理できるような高度な最適化が行われているからである。ニューラルネットワークの計算において、データ転送がボトルネックになる場合にはバッチ処理を用いることで負荷を軽減することも可能になります。
手書き数字認識のプログラム
- 今回はパラメータの最適化は置いといて、学習済みのパラメータの読み込んで数字認識を実行しています。
以下のプログラムでは、学習済みのパラメータを用いてテスト画像を推論し、テストラベルと一致するか否かの認識精度を求めています。
実行すると、認識精度は「Accuracy : 0.9352」と表示されます。
▼neuralnet_mnist_batch.py
# coding: utf-8 from common.functions import sigmoid, softmax from dataset.mnist import load_mnist import pickle import numpy as np import sys import os sys.path.append(os.pardir) # 親ディレクトリのファイルをインポートするための設定 def get_data(): # 画像データの読み込み (x_train, t_train), (x_test, t_test) = load_mnist( normalize=True, flatten=True, one_hot_label=False) return x_test, t_test def init_network(): # 学習済みデータの読み込み with open(“sample_weight.pkl”, ‘rb’) as f: network = pickle.load(f) return network def predict(network, x): w1, w2, w3 = network[‘W1’], network[‘W2’], network[‘W3’] b1, b2, b3 = network[‘b1’], network[‘b2’], network[‘b3’] a1 = np.dot(x, w1) + b1 z1 = sigmoid(a1) a2 = np.dot(z1, w2) + b2 z2 = sigmoid(a2) a3 = np.dot(z2, w3) + b3 y = softmax(a3) return y x, t = get_data() network = init_network() batch_size = 100 # バッチの数 accuracy_cnt = 0 for i in range(0, len(x), batch_size): x_batch = x[i:i+batch_size] y_batch = predict(network, x_batch) p = np.argmax(y_batch, axis=1) accuracy_cnt += np.sum(p == t[i:i+batch_size]) print(“Accuracy:” + str(float(accuracy_cnt) / len(x)) |
以下のプログラムは、上記プログラムでインポートに必要なプログラムです。
▼functions.py
# coding: utf-8 import numpy as np def sigmoid(x): return 1 / (1 + np.exp(-x)) def softmax(x): if x.ndim == 2: x = x.T x = x – np.max(x, axis=0) y = np.exp(x) / np.sum(np.exp(x), axis=0) return y.T x = x – np.max(x) # オーバーフロー対策 return np.exp(x) / np.sum(np.exp(x)) |
▼mnist.py
# coding: utf-8 try: import urllib.request except ImportError: raise ImportError(‘You should use Python 3.x’) import os.path import gzip import pickle import os import numpy as np url_base = ‘http://yann.lecun.com/exdb/mnist/’ key_file = { ‘train_img’:’train-images-idx3-ubyte.gz’, ‘train_label’:’train-labels-idx1-ubyte.gz’, ‘test_img’:’t10k-images-idx3-ubyte.gz’, ‘test_label’:’t10k-labels-idx1-ubyte.gz’ } dataset_dir = os.path.dirname(os.path.abspath(__file__)) save_file = dataset_dir + “/mnist.pkl” train_num = 60000 #訓練画像 test_num = 10000 #テスト画像 img_dim = (1, 28, 28) #1チャンネル、28 x 28 img_size = 784 def load_mnist(normalize=True, flatten=True, one_hot_label=False): if not os.path.exists(save_file): init_mnist() with open(save_file, ‘rb’) as f: dataset = pickle.load(f) if normalize: for key in (‘train_img’, ‘test_img’): dataset[key] = dataset[key].astype(np.float32) dataset[key] /= 255.0 if one_hot_label: dataset[‘train_label’] = _change_one_hot_label(dataset[‘train_label’]) dataset[‘test_label’] = _change_one_hot_label(dataset[‘test_label’]) if not flatten: for key in (‘train_img’, ‘test_img’): dataset[key] = dataset[key].reshape(-1, 1, 28, 28) return (dataset[‘train_img’], dataset[‘train_label’]), (dataset[‘test_img’], dataset[‘test_label’]) if __name__ == ‘__main__’: init_mnist() |
コメント