損失関数とは?ニューラルネットワークの学習理論【機械学習】
IT技術
損失関数とは?
ニューラルネットワークの学習フェーズでは、的確な推論を行うために最適な各パラメータ(重みやバイアス)を決定します。
このとき、最適なパラメータに近づくための指標となるのが「損失関数(loss function)」になります。
これは、「目標」と「実際」の出力の誤差を表すものです。
損失関数をイメージで例えると
もし、スイカ割りをする人を、ニューラルネットワークにより制御できるとした場合で考えてみます。
目的は、今いる地点からスイカのある場所に移動し、スイカを割ることです。
この場合、スイカからどれだけズレているか?を表すのが「損失関数」であるということができます。
この、「スイカからのズレ(損失関数)」が、最も小さくなるようなパラメータを探索していく事で、最終的にスイカのある地点に辿り着き、見事スイカを割る事ができるでしょう。
このように、ニューラルネットワークの学習では、損失関数が最も小さくなるパラメータを探し出します。
損失関数で用いる2つの関数
では、損失関数には、どんな関数が使われるのでしょうか?
一般に機械学習では、損失関数として「2乗和誤差」や「交差エントロピー誤差」などが用いられます。
今回は、この2つを解説していきたいと思います!
2乗和誤差
2乗和誤差(sum of squared error)は、以下の式で表されます。
:ニューラルネットワークの出力
:正解データ
:データの次元数
2乗和誤差は、ニューラルネットワークの「出力」と「正解データ」の差の、2乗の総和を求めます。
この総和が小さいほど、出力と正解データの誤差が少ないということになります。
この2乗和誤差を、Python で実装し、実際に損失関数を計算してみましょう。
損失関数をPythonで計算
1import numpy as np
2#2乗和誤差(sum of squared error)の実装
3def sum_squared_error(y, t):
4 return 0.5 * np.sum((y - t) ** 2)
5#正解データ (one-hot表現)
6t = np.array([0, 0, 0, 1, 0])
7#ニューラルネットワークの出力 (y1: 正解, y2: 不正解)
8y1 = np.array([0.1, 0.05, 0.0, 0.7, 0.15])
9y2 = np.array([0.5, 0.2, 0.0, 0.1, 0.2 ])
10#2乗和誤差による損失関数の計算
11y1_loss = sum_squared_error(y1, t)
12y2_loss = sum_squared_error(y2, t)
13print('2乗和誤差による損失関数の計算結果')
14print('y1(出力が正解の場合) : {:.4f}'.format(y1_loss))
15print('y2(出力が不正解の場合) : {:.3f}'.format(y2_loss))
2乗和誤差による損失関数の計算結果
1 (出力が正解の場合) : 0.0625
2 (出力が不正解の場合) : 0.570
5つの要素からなる、「出力 」および、「正解データ 」を用意し、定義した「sum_squared_error関数」で計算しています。
ここで、ニューラルネットワークの「出力 」は、その要素である確率データ(ソフトマックス関数による出力)に。
そして、「正解データ 」は、正解ラベルを「1」として、それ以外を「0」とする、one-hotベクトルとなります。
また「出力 」は、正解の場合と、不正解の場合の2つ用意し、これらの損失関数の値の違いをみています。
出力が正解の場合の方が、損失関数の値が小さく、正解データとの誤差が少ない適切な出力が得られていることがわかりますね。
交差エントロピー誤差
次に、「交差エントロピー誤差(cross entropy error)」です。
計算式は、以下のようになります。
「 」は、底が e の自然対数です。
こちらも実装し、2乗和誤差と同じデータで、同様に計算してみましょう。
損失関数をPythonで計算
1#交差エントロピー誤差(cross entropy error)の実装
2#log(0)の場合に負の無限大に発散することを防ぐため、微小値deltaを導入している
3def cross_entropy_error(y, t):
4 delta = 1e-7
5 return -np.sum(t * np.log(y + delta))
6#交差エントロピー誤差による損失関数の計算
7y1_loss = cross_entropy_error(y1, t)
8y2_loss = cross_entropy_error(y2, t)
9print('交差エントロピー誤差による損失関数の計算結果')
10print('y1(出力が正解の場合) : {:.4f}'.format(y1_loss))
11print('y2(出力が不正解の場合) : {:.3f}'.format(y2_loss))
交差エントロピー誤差による 損失関数の計算結果
1 (出力が正解の場合) : 0.3567
2 (出力が不正解の場合) : 2.303
こちらの計算結果からも、「出力」と「正解データ」との誤差を、損失関数として取得することができています。
この交差エントロピー誤差は、各要素における、「出力の自然対数 」と、「正解データ 」との積を計算しています。
ですが、正解データは、正解ラベルのみが「1」で、その他は「0」の、one-hotベクトルです。
なので実質的には、正解ラベルに対応する出力の値の、自然対数を計算することで得られます。
出力の自然対数を計算し「損失関数」が得られるのはなぜ?
なぜ、出力の自然対数を計算することで、損失関数が得られるか?
それは、自然対数 は、 が1の時に「0」となり、逆に が「0」に近づくと、どんどん小さくなるという性質を有しているからです。
つまり、自然対数の絶対値をとれば、ニューラルネットワークの出力 に対する のグラフは、以下のようになります。
のグラフ
1#-logyのグラフ描画
2%matplotlib inline
3import matplotlib.pyplot as plt
4y = np.arange(0, 1.01, 0.01)
5delta = 1e-7
6loss = -np.log(y + delta)
7plt.plot(y, loss)
8plt.xlim(0, 1)
9plt.xlabel('output y')
10plt.ylim(0, 5)
11plt.ylabel('- log y')
12plt.show()
出力 が「1」に近づくと(確率が1に近づくと)、損失関数の値は「0」に近づきます(正解との誤差が小さい)。
逆に、出力 が小さければ(確率が低ければ)、損失関数の値は大きくなる(正解との誤差が大きい)、ということになります。
このように、「交差エントロピー誤差」は、自然対数の性質を用いて損失関数を導きます。
損失関数の必要性と微分
次に、なぜ損失関数が必要か?という事に、少し触れていきます。
例えば、料理をする時には、「調味料を少しずつ加え」、「味見をしながら」全体の味を調節して完成に近づけていきます。
といった感じで、ニューラルネットにおける学習では、パラメータを微小に変化させた時に起こる、損失関数の微小な変化を確認しながら、損失関数を小さくする方向へとパラメータを更新していきます。
これは、「パラメータの微分」を行っていると言えます。
損失関数は微分が0にならない
重要なのは、損失関数はどの場所においても、微分が「0」(微小変化が0)にならないことです。
単純な認識精度を、損失関数のような指標にすることができないのは、パラメータの微分が、ほとんどの場所で「0」になってしまうからです。
パラメータの微小変化では、認識精度はほとんど変化しないため、パラメータ調整の指標にはならないのです。
勾配法へとつながる
ここで話したことは、ニューラルネットワークの重要な学習理論の基礎です。
また、本記事では触れていませんが、損失関数の最小値を探索するために用いられる「勾配法」に繋がっていきます。
「勾配法」により、損失関数の最小化が完了すると、最適化されたネットワークのパラメータが設定され、学習が完了します。
手書き数字データによる学習と損失関数最小化の実装例
最後に、実際にニューラルネットワークモデルを用いた学習と、損失関数最小化の実装例をご紹介します。
実装には、ディープラーニングフレームワーク「Chainer」を用いて、MNISTの手書き数字データセットの学習を行いました。
実装内容については、コード内コメントや、Chainer チュートリアルを参考にして下さい。
【Chainerチュートリアル】
https://tutorials.chainer.org/ja/
コード
1import chainer
2from chainer import datasets, optimizers
3from chainer.datasets import split_dataset_random
4from chainer.iterators import SerialIterator
5import chainer.links as L
6import chainer.functions as F
7%matplotlib inline
8import matplotlib.pyplot as plt
9import numpy as np
10#mnistデータセットの読み込み
11train, test = datasets.get_mnist()
12#データを訓練用と検証用に分割
13train, valid = split_dataset_random(train, int(len(train) * 0.8), seed=0)
14#訓練データでバッチサイズ100のイテレータを作成
15train_iter = SerialIterator(train, batch_size=100, repeat=True, shuffle=True)
16#ニューラルネットワークモデルの定義
17class Net(chainer.Chain):
18 def __init__(self, n_in=784, n_hidden=100, n_out=10):
19 super().__init__()
20 with self.init_scope():
21 self.l1 = L.Linear(n_in, n_hidden)
22 self.l2 = L.Linear(n_hidden, n_hidden)
23 self.l3 = L.Linear(n_hidden, n_out)
24
25 def forward(self, x):
26 h = F.relu(self.l1(x))
27 h = F.relu(self.l2(h))
28 h = self.l3(h)
29 return h
30
31net = Net()
32#最適化手法にSGDを選択(学習率:0.1)
33optimizer = optimizers.SGD(lr=0.1)
34optimizer.setup(net)
35#学習
36gpu_id = 0 # 使用する GPU 番号
37n_batch = 100 # バッチサイズ
38n_epoch = 30 # エポック数
39# ネットワークを GPU メモリ上に転送
40net.to_gpu(gpu_id)
41# ログ
42results_train, results_valid = {}, {}
43results_train['loss'], results_train['accuracy'] = [], []
44results_valid['loss'], results_valid['accuracy'] = [], []
45for epoch in range(n_epoch):
46 while True:
47 # ミニバッチの取得
48 train_batch = train_iter.next()
49 # x と t に分割
50 # データを GPU に転送するために、concat_examples に gpu_id を渡す
51 x_train, t_train = chainer.dataset.concat_examples(train_batch, gpu_id)
52 # 予測値と損失関数(交差エントロピー誤差)の計算
53 y_train = net(x_train)
54 loss_train = F.softmax_cross_entropy(y_train, t_train)
55 acc_train = F.accuracy(y_train, t_train)
56 # 勾配の初期化と勾配の計算
57 net.cleargrads()
58 loss_train.backward()
59 # パラメータの更新
60 optimizer.update()
61 # 1エポック終えたら、valid データで評価する
62 if train_iter.is_new_epoch:
63 # 検証用データに対する結果の確認
64 with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
65 x_valid, t_valid = chainer.dataset.concat_examples(valid, gpu_id)
66 y_valid = net(x_valid)
67 loss_valid = F.softmax_cross_entropy(y_valid, t_valid)
68 acc_valid = F.accuracy(y_valid, t_valid)
69 # CPU上に転送
70 loss_train.to_cpu()
71 loss_valid.to_cpu()
72 acc_train.to_cpu()
73 acc_valid.to_cpu()
74 # 可視化用に保存
75 results_train['loss'] .append(loss_train.array)
76 results_train['accuracy'] .append(acc_train.array)
77 results_valid['loss'].append(loss_valid.array)
78 results_valid['accuracy'].append(acc_valid.array)
79 break
今回は、「学習データ」と「検証データ」による損失関数と、精度の結果をグラフに示します。
エポック数が増えるにつれて、損失関数の値が小さくなっていき、学習が進んでいることが分かりますね。
同様に精度も向上しています。
損失関数
精度
さいごに
最後に登場したような「機械学習フレームワーク」を用いると、細かい理論や、考え方の理解なしでも、比較的容易にモデルの構築が可能になります。
ですが、「なぜそうなるのか?」という所に踏み込むのも、大切で面白いことだと思います。
今回の、損失関数もその一つです。
ニューラルネットワークの学習の基礎となるので、理解しておくことをオススメします。
こちらの記事もオススメ!
2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...
2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...
ライトコードでは、エンジニアを積極採用中!
ライトコードでは、エンジニアを積極採用しています!社長と一杯しながらお話しする機会もご用意しております。そのほかカジュアル面談等もございますので、くわしくは採用情報をご確認ください。
採用情報へ
「好きを仕事にするエンジニア集団」の(株)ライトコードです! ライトコードは、福岡、東京、大阪の3拠点で事業展開するIT企業です。 現在は、国内を代表する大手IT企業を取引先にもち、ITシステムの受託事業が中心。 いずれも直取引で、月間PV数1億を超えるWebサービスのシステム開発・運営、インフラの構築・運用に携わっています。 システム開発依頼・お見積もり大歓迎! また、現在「WEBエンジニア」「モバイルエンジニア」「営業」「WEBデザイナー」「WEBディレクター」を積極採用中です! インターンや新卒採用も行っております。 以下よりご応募をお待ちしております! https://rightcode.co.jp/recruit
おすすめ記事
浮動小数点について調べてみた
2024.09.09