• トップ
  • ブログ一覧
  • 【前編】自作の誤差逆伝播学習法で手書き数字を認識させてみよう!【機械学習】
  • 【前編】自作の誤差逆伝播学習法で手書き数字を認識させてみよう!【機械学習】

    広告メディア事業部広告メディア事業部
    2019.06.11

    IT技術

    前編〜手書き数字を認識するプログラムを作る~

    誤差逆伝播学習法は、教師信号とネットワークの実際の出力との誤差情報と勾配降下法を用いてネットワークを学習させる代表的な機械学習手法です。

    今回は、機械学習の要「誤差逆伝播学習法」を解説・実装してみる【人工知能】の記事で作成したコードを元に「手書き数字を認識するプログラム」を作ってみましょう!

    この記事での最終的なプログラムのイメージとしては、以下のように、自分で書いた手書き数字をネットワークに判別させるといった具合です。

    【成果物イメージ】

    手書き数字データセット

    今回用いるデータセットは、「UCI Optical Recognition of Handwritten Digits」と呼ばれる、8×8の小さな手書き数字データセットです。

    よく見る28×28のMNISTデータセットと違い、小規模なので扱いやすいのが特徴です。

    (※本当はMNISTを使うと学習に時間がかかってしまうからという理由もあります...)

    UCI Optical Recognition of Handwritten Digits

    それでは上記URLから、「optdigits.tra」(訓練データ)と「optdigits.tes」(テストデータ)をダウンロードして「OptDigits」というディレクトリに入れておきましょう!

    それでは、実際にコードを加筆・修正していきましょう!

    修正点1:データセット読み込み

    まずは、データセットが前回とは異なるので、関数を書き換えましょう。

    Optical Recognition of Handwritten Digits(以下OptDigits)は、3823枚の訓練画像、1797枚のテストデータからなります。

    今回は訓練データ全てを使って学習させ、最後にテストデータ全てを使ってテスト精度を計測する形を取ろうかと思います。

    また、学習が進むごとに訓練データを使って訓練精度を計測して学習進行具合も可視化してみましょう。

    前回の「Irisデータセット」を読み込む関数と大きく違う部分は、2種類のデータセットを扱うことぐらいでしょうか。

    また、OptDigitsは8×8で[0,16]の画素値が特徴パターンとなりますが、数値に差がないようにするために更に[0,1.0]の実数値に正規化して用いることにします。

    以上の正規化は、ラムダ式を利用してコーディングしてみました。

    (string型をfloat型に変更する必要もあるため)

    【OptDigitsの特徴パターン】

    実装

    実装は以下のようになります。

    1    def load_optdigits(self):
    2        train = open('OptDigits/optdigits.tra', 'r')  # 訓練データ
    3        test = open('OptDigits/optdigits.tes', 'r')  # テストデータ
    4
    5        # 訓練データ
    6        lines = train.read().split()
    7
    8        # データをランダムにシャッフル
    9        random.shuffle(lines)
    10
    11        dataset = ([])
    12        for line in lines:
    13            pattern = line.split(',')
    14            dataset.append(pattern)
    15        train.close()
    16
    17        for pat in dataset:
    18            # 入力は[0,1]に正規化する
    19            self.patterns.append(list(map(lambda x: float(x)*(1.0/16.0), pat[0:-1])))
    20
    21            # OptDigitsは最後にラベル[0,1,2,...,9]がある
    22            self.labels.append(int(pat[-1]))
    23
    24        # テストデータ
    25        lines = test.read().split()
    26
    27        # データをレンダムにシャッフル
    28        random.shuffle(lines)
    29
    30        dataset = ([])
    31        for line in lines:
    32            pattern = line.split(',')
    33            dataset.append(pattern)
    34        test.close()
    35
    36        for pat in dataset:
    37            self.test_patterns.append(list(map(lambda x: float(x)*(1.0/16.0), pat[0:-1])))
    38
    39            # OptDigitsは最後にラベル[0,1,2,...,9]がある
    40            self.test_labels.append(int(pat[-1]))

    修正点2:自分で書いた手書き数字をネットワークに流すための関数

    次に、自分で書いた手書き数字をネットワークに流すために、単純な画像処理を施す必要があります。

    今回は、「おえかきボード - ブラウザでかんたんお絵かき -」というサイトで200×200のキャンパスに以下のように数字を書いて、pngファイルで保存した画像を使うことにしましょう。

    【おえかきボードで手書き数字を書く】

    このような方法ですので、もちろん自分で書いた手書き数字は200×200のサイズでネットワークには流せません。

    したがって画像をリサイズします。

    また、一応カラー画像扱いなので、グレースケールにも変換します。

    これらの処理は、Pillow(PIL)という画像処理ライブラリを使います。

    (未インストールであれば、pip等でインストールしてください。pip install pillow )

    まずは、ライブラリのインポートを書き加えます。

    1from PIL import Image, ImageOps

    実装

    そして、以下の工程を実装します。

    1、画像のグレースケール化
    2、画像のリサイズ
    3、ネガポジ(白黒)反転
    4、ネットワークに流すためにList化

    ネガポジ変換する理由は、OptDigits(背景: 黒 数字: 白)に合わせるためです。

    実装は以下のようになります。

    1    def prop_my_digits(self, img_path):
    2        """
    3        自分で作った画像をネットワークに流して出力を得る関数。
    4        :param img_path:
    5        :return:
    6        """
    7        img = Image.open(img_path).convert('L')  # グレースケールで画像を読み込む
    8        resized_img = img.resize((8, 8))  # 画像リサイズ
    9        input_img = ImageOps.invert(resized_img)  # ネガポジ(白黒)反転
    10        array = np.array(input_img) * (1.0/255.0)  # [0,1]に変換
    11
    12        input_pattern = ([])
    13        for h in array:
    14            for w in h:
    15                input_pattern.append(w)  # 1次元の配列に変換
    16        ans = np.array(self.forward(input_pattern)).argmax()  # 出力値の大きいニューロンのインデックスを取得
    17        print(img_path + ' is ', ans)  # ネットワークの識別結果を出力

    修正点3:部分修正

    あとは少しだけコードを修正します。

    まず、δを計算するcalc_delta() で、教師ニューロンをIrisのときは3つでしたが、10クラス分に変更します。

    1#  teacher = ([0.1, 0.1, 0.1])
    2teacher = [0.1] * 10  # 10クラス分の教師ニューロンを作成

    次に、test() は検証(訓練)精度を計測する関数になるので、名前をvalidate() に変更します。

    そして新たに、テスト精度を計測するtest() 関数を作成します。

    内容は、ほとんど変わりません。

    1    # 関数名を変更
    2    def validate(self):
    3        """
    4        訓練精度を計算
    5        :return: accuracy (%)
    6        """
    7        correct = 0
    8        for p in range(len(self.patterns)):
    9            self.forward(self.patterns[p])
    10            max = 0
    11            ans = -1
    12            # 一番出力値の高いニューロンを取ってくる
    13            for o, out in enumerate(self.outputs[len(self.layers)-1]):
    14                if max < out:
    15                    max = out
    16                    ans = o
    17            # もしそのニューロンの番号とラベルの番号があっていれば正解!
    18            if ans == self.labels[p]:
    19                correct += 1
    20
    21        accuracy = correct / len(self.patterns) * 100
    22        return accuracy
    23
    24    # New!
    25    def test(self):
    26        """
    27        テスト精度を計算
    28        :return: accuracy (%)
    29        """
    30        correct = 0
    31        for p in range(len(self.test_patterns)):  # テストパターン
    32            self.forward(self.test_patterns[p])
    33            max = 0
    34            ans = -1
    35            # 一番出力値の高いニューロンを取ってくる
    36            for o, out in enumerate(self.outputs[len(self.layers)-1]):
    37                if max < out:
    38                    max = out
    39                    ans = o
    40            # もしそのニューロンの番号とラベルの番号があっていれば正解!
    41            if ans == self.test_labels[p]:
    42                correct += 1
    43
    44        accuracy = correct / len(self.patterns) * 100
    45        return accuracy

    変更点は以上になります!

    後編へつづく!

    後編はこちら

    featureImg2019.06.14【後編】自作の誤差逆伝播学習法で手書き数字を認識させてみよう!【機械学習】後編〜手書き数字を認識するプログラムを作る~今回は、機械学習の要「誤差逆伝播学習法」を解説・実装してみる【人工知能】の...

    こちらの記事もオススメ!

    featureImg2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...

    featureImg2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...

    広告メディア事業部

    広告メディア事業部

    おすすめ記事