1. HOME
  2. ブログ
  3. IT技術
  4. ディープラーニングの勾配計算についてまとめてみた【誤差逆伝播法】

ディープラーニングの勾配計算についてまとめてみた【誤差逆伝播法】

はじめに

ChatGPTやDeepLなどディープラーニングを用いた製品やサービスはたくさんありますが、その仕組みを全く知りませんでした。

「その仕組みを知りたい!」という気持ちからディープラーニングについて学習し始めたので、アウトプットも兼ねて今回、誤差逆伝播法の勾配を計算する部分について紹介したいと思います。ディープラーニングを勉強すると最初に躓く部分だと思うので、丁寧に説明できればと思います。

ディープランニング概要

ディープラーニングのモデルの学習の流れは以下の1~5のサイクルになっています。これらを繰り返すことで、より精度の高い出力ができるパラメータに更新し、精度の高いモデルを生成していきます。

  1. データを入力
  2. データを出力
  3. 損失を計算
  4. 勾配を計算
  5. パラメータを更新

今回は上記4番の「計算した損失から勾配を計算する仕組み」を紹介します。

誤差逆伝播法

これから紹介するMNIST画像の分類問題を行う中間層が2層のモデル(入力 → 中間層1 → 中間層2 → 出力 → 損失)を題材にします。

MNISTは手書き数字(0から9までの整数 = 10種類)の画像のデータセットです。

また、MNIST画像は( 28 \times 28) のグレー画像で、各ピクセルは0から255までの値をとります。

1画像につき784ピクセルのデータがあるので、入力データは ( N \times 784) ベクトルになります( Nはバッチ数です )。

入力 ( input ) から1つ目の中間層 ( layer1 ) への重み ( ( w_1 ) ) の次元は ( 784 \times 64) 、バイアス ( ( b_1 ) ) の次元は ( 1 \times 64) 、活性化関数は ReLU関数 としました。

ですので、中間層1の次元は ( N \times 64) になります。

1つ目の中間層から2つ目の中間層( layer2 )への重み ( ( w_2 ) ) の次元は ( 64 \times 64) 、バイアス ( ( b_2 ) ) の次元は ( 1 \times 64) 、活性化関数は ReLU関数 としました。

ですので、中間層2の次元は ( N \times 64) になります。

本モデルは10種類の分類問題なので、出力 ( output ) の次元を ( N \times 10) にする必要があります。

なので、2つ目の中間層から出力への重み ( ( w_3 ) ) の次元は ( 64 \times 10) 、バイアス ( ( b_3 ) ) の次元は ( 1 \times 10) 、活性化関数は softmax関数 としました。

また、損失関数には多値分類タスクによく使われるクロスエントロピー誤差を用います。( t ) は正解データを表しています。

パラメータ \( w_3, b_3 \) の更新のための勾配計算

まず、計算された損失 ( L )  を受けて、パラメータ ( w_3, b3 ) をどのように更新するか、つまり ( \frac{\partial L}{\partial w{3}}, \frac{\partial L}{\partial b_{3}} ) を求めていきます。

( \frac{\partial L}{\partial w{3}}, \frac{\partial L}{\partial b{3}} )を求める流れは以下のようになっています。

  1. \( \frac{\partial L}{\partial output} \) を求める。
  2. \( \frac{\partial L}{\partial y_{3}} \) を求める。
  3. \( \frac{\partial L}{\partial w_{3}}, \frac{\partial L}{\partial b_{3}} \) を求める。

では、最初に( \frac{\partial L}{\partial output} ) を求めます。損失は以下のような式で導出されます。

$$L = fc(output, t) = - \sum{i}t{i} \cdot log(output{i})$$

ただし、(f_c ) はクロスエントロピー誤差を表しています。クロスエントロピーにより ( N \times 10) ベクトルを ( N \times 1) ベクトルの損失に変換しています。

クロスエントロピーの定義
$$ L = f_c(x, t ) - \sum_{k}t_{k} \cdot log(x_{k}) $$
ただし、\( x\) はモデルの出力値、\( t\) はそれに対応する正解値。

自然対数の微分(( \frac{\partial}{\partial x} log x = \frac{1}{x}))より、損失に対する ( output_{i}) の微分は以下のようにかけます。

$$\frac{\partial L}{\partial output{i}} = - \frac{t{i}}{output_{i}}$$

次に、( \frac{\partial L}{\partial y_{3}} ) を求めます。( output ) は以下のような式で導出されます。

$$output = f{s}(y{3}) $$

ただし、(f_s ) はソフトマックス関数を表しています。ソフトマックス関数により ( N \times 10) ベクトルの ( y_3 ) を10種類の数字の確率を表す ( N \times 10) ベクトルに変換しています。

$$ \frac{\partial L}{\partial y{3}} =  \frac{\partial L}{\partial output} \cdot \frac{\partial output}{\partial y{3}} $$

より、(\frac{\partial L}{\partial output} ) は先ほど求めたので、ソフトマックス関数の微分 ( \frac{\partial output}{\partial y{3}}) を求めれば、( \frac{\partial L}{\partial y{3}} ) が求まります。

ソフトマックス関数の微分
ソフトマックス関数の定義より、\( 1 \times M\) ベクトルの出力 \( x \) のある要素 \( x_k \) からそれに対応する確率を導出する式は以下になっています。
$$ f_s(x_k) = \frac{e^{x_k}}{\sum_{j}e^{x_{j}}} $$
ここで \( S = \sum_{j}e^{y_{j}} \) とおくと、\( \frac{\partial S}{\partial x_{l}} = e^{y_{l}}\) とできます。
\( (i) k = l \) のとき
分数関数の微分公式 ( \(\frac{\partial \frac{g(x)}{h(x)}}{\partial x} = \frac{\frac{\partial g(x)}{\partial x}h(x) - g(x)\frac{\partial h(x)}{\partial x}}{h(x)^2}\) ) を用いることで、以下のように書けます。
$$\frac{\partial f_s(x_k)}{\partial {x_{l}}} = \frac{e^{x_{k}}S - e^{2x_{k}}}{S^2} = \frac{e^{x_{k}}}{S}(1 - \frac{e^{x_{k}}}{S}) = f_s(x_k)(1 - f_s(x_k))$$
\( (ii) k \neq l \) のとき
$$\frac{\partial f_s(x_k)}{\partial {x_{l}}} = - \frac{e^{x_{k}} \cdot e^{x_{l}}}{S^2} = - f_s(x_k)f_s(x_l)$$
これらより、ソフトマックス関数の微分は以下のようになります。
$$\frac{\partial f_s(x_k)}{\partial {x_{l}}} =\left\{
\begin{array}{ll}
f_s(x_k)(1 - f_s(x_k)) & (x = l)\\
- f_s(x_k)f_s(x_l) & (x \neq  l)
\end{array}
\right. $$

前述した式

$$ \frac{\partial L}{\partial y{3}} =  \frac{\partial L}{\partial output} \cdot \frac{\partial output}{\partial y{3}} $$

は連鎖律より以下のようにかけます。

$$ \frac{\partial L}{\partial {y{3}}{i}} = \sum{j}\frac{\partial L}{\partial output{j}} \cdot \frac{\partial output{j}}{\partial {y{3}}_{i}}$$

より、これまでの計算結果とソフトマックス関数の微分の場合分けを考慮すると以下のように書けます。

$$ \frac{\partial L}{\partial {y{3}}{i}} = \frac{\partial L}{\partial output{i}}output{i}(1 - output_{i}) + $$

$$ \sum{k \neq i}\frac{\partial L}{\partial output{k}}(- output{k}\cdot output{i}) $$

$$ = - \frac{t{i}}{output{i}}output{i}(1 - output{i}) +$$

$$ \sum{k \neq i}(- \frac{t{k}}{output{k}})(- output{k} \cdot output_{i})) $$

$$ = - t{i}(1 - output{i}) + output{i}\sum{k \neq i}t_{k} $$

$$ = - t{i}(1 - output{i}) + output{i}(1 - t{i}) = output{i} - t{i} $$

これにより、損失 ( L )に対する、( {y{3}}{i})の微分を求めることができました。

これまでの計算結果より、遂に ( \frac{\partial L}{\partial w{3}}, \frac{\partial L}{\partial b{3}} )  を求めることができます。

( y_3 ) は以下の式で表されます。

$$ y_3 = layer_2 \cdot w_3 + b_3$$

まず、( \frac{\partial L}{\partial w_{3}}) は以下のように導出できます(右上に添えてある ( T ) は転置を表しています)。

$$ \frac{\partial L}{\partial {w3}} =  \sum{j}\frac{\partial {y{3}}{j}}{\partial {w3}} \cdot \frac{\partial L}{\partial {y{3}}_{j}} = layer_2^T \cdot (output - t)$$

上式は ( 64 \times 10 = 64 \times N \cdot N\times 10 ) のように内積の順番と次元の整合性が取れていることがわかります。

次に、( \frac{\partial L}{\partial b_{3}}) は以下のように導出できます。

$$ \frac{\partial L}{\partial {b3}} =  \sum{j}\frac{\partial {y{3}}{j}}{\partial {b3}} \cdot \frac{\partial L}{\partial {y{3}}{j}} = \sum{j}(output_j - t_j) $$

上式は ( 1 \times 10 = 1 \times 10 ) のように内積の順番と次元の整合性が取れていることがわかります。

また、以降の計算のために ( \frac{\partial L}{\partial {layer2}} ) も求めます。
$$ \frac{\partial L}{\partial {layer
{2}}} = \frac{\partial L}{\partial {y{3}}} \cdot \frac{\partial {y{3}}}{\partial {layer_{2}}} = (output - t) \cdot w_3^T $$

上式は ( N \times 64 = N \times 10 \cdot 10 \times 64 ) のように内積の順番と次元の整合性が取れていることがわかります。

パラメータ \( w_2, b_2 \) の更新のための勾配計算

( \frac{\partial L}{\partial w{3}}, \frac{\partial L}{\partial b{3}} ) を求めることができたので、次に( \frac{\partial L}{\partial w{2}}, \frac{\partial L}{\partial b{2}} ) を求めます。

先ほどと違うのは、クロスエントロピー誤差を考慮しなくて良いこと。また、活性化関数がソフトマックス関数ではなくReLU関数であることです。

$$layer_2 = f_r(y_2) = f_r(layer_1 \cdot w_2 + b_2)$$

ただし、( f_r) はReLU関数を表しています。

まず、( \frac{\partial L}{\partial {y_2}}) を求めます。

$$\frac{\partial L}{\partial {y_2}} = \frac{\partial L}{\partial {layer_2}} \cdot \frac{\partial {layer_2}}{\partial {y_2}}$$

( \frac{\partial L}{\partial {layer_2}}) については前節で既に求めているので、( \frac{\partial {layer_2}}{\partial {y_2}}) を求めます。

ReLU関数の微分
\( v = f_r(x) \) について、微分 \(\frac{\partial v}{\partial {x}}\) を求めることを考える。
$$v =\left\{
\begin{array}{ll}
x & (x > 0)\\
0 & (x \leq  0)
\end{array}
\right. $$
ReLU関数の定義より \( v\) は上記のように表すことができる。
より、\(\frac{\partial v}{\partial {x}}\) は以下のようになる。
$$\frac{\partial v}{\partial {x}} =\left\{
\begin{array}{ll}
1 & (x > 0)\\
0 & (x \leq  0)
\end{array}
\right. $$

ReLU関数の微分より、( \frac{\partial {layer_2}}{\partial {y_2}}) は ( y_2) の要素が0より大きい時に1、小さい時に0をとる ( N \times 64) ベクトル(( y_2) と同じ形)であることがわかりました。

これを ( R_2) とおくと、(\frac{\partial L}{\partial {y_2}}) は以下のように書けます( ( \odot はアダマール積を表しています ) )。

$$ \frac{\partial L}{\partial {y_2}} = \frac{\partial L}{\partial {layer_2}} \odot R_2 $$

それでは ( \frac{\partial L}{\partial {w_2}}) を求めます。

$$\frac{\partial L}{\partial {w_2}} = \frac{\partial {y_2}}{\partial {w_2}} \cdot \frac{\partial L}{\partial {y_2}} $$

( \frac{\partial L}{\partial {y_2}}) は先ほど求めたので、( y_2 = layer_1 \cdot w_2 + b_2) より、以下のように求まります。

$$ \frac{\partial L}{\partial {w_2}} =  layer_1^T \cdot \frac{\partial L}{\partial {layer_2}} \odot R_2 $$

上式は ( 64 \times 64 = 64 \times N \cdot N \times 64 ) のように内積の順番と次元の整合性が取れていることがわかります。

次に ( \frac{\partial L}{\partial {b_2}}) を求めます。

$$ \frac{\partial L}{\partial {b2}} =  \sum{j}\frac{\partial {y{2}}{j}}{\partial {b2}} \cdot \frac{\partial L}{\partial {y{2}}{j}} = \sum{j}\frac{\partial L}{\partial {layer2}{j}} \odot R_2 $$

上式は ( 1 \times 64 = 1 \times 64 ) のように内積の順番と次元の整合性が取れていることがわかります。

パラメータ \( w_1, b_1 \) の更新のための勾配計算

( \frac{\partial L}{\partial w{1}}, \frac{\partial L}{\partial b{1}} ) の計算は( \frac{\partial L}{\partial w{2}}, \frac{\partial L}{\partial b{2}} ) の導出と同じ方法で求めることができます。

計算過程については省略します。

$$ \frac{\partial L}{\partial {y_1}} = \frac{\partial L}{\partial {layer_1}} \odot R_1 $$

$$ \frac{\partial L}{\partial {w_1}} =  input^T \cdot \frac{\partial L}{\partial {layer_1}} \odot R_1 $$

$$ \frac{\partial L}{\partial {b1}} =  \sum{j}\frac{\partial L}{\partial {layer1}{j}} \odot R_1 $$

パラメータの更新のための計算まとめ

これまでの ( \frac{\partial L}{\partial w{1}}, \frac{\partial L}{\partial b{1}}, \frac{\partial L}{\partial w{2}}, \frac{\partial L}{\partial b{2}}, \frac{\partial L}{\partial w{3}}, \frac{\partial L}{\partial b{3}} ) を ( \frac{\partial L}{\partial y_{i}}) で以下のように表すことができます。

$$ \frac{\partial L}{\partial w{i}} = layer{i - 1}^T \cdot \frac{\partial L}{\partial y_{i}}$$

$$ \frac{\partial L}{\partial b{i}} = \sum{j}\frac{\partial L}{\partial {y{i}}{j}}$$

ただし、( layer_0) は ( input ) を表しています。

活性化関数の種類により、( \frac{\partial L}{\partial y_{i}} ) を計算する必要はありますが、基本的に ( yi = layer{i - 1} \cdot w_i + bi) のようなニューラルネットワークでは、( \frac{\partial L}{\partial w{i}}, \frac{\partial L}{\partial b_{i}}) を上記のように計算できることがわかりました。

まとめ

今回はMNISTの分類問題のモデルという具体例を用いて、モデルの中でどのように勾配が計算されているのかを見ていきました。便利なライブラリを用いれば、このような仕組みを意識しなくても実装できると思いますが、理論の部分を理解するのも重要だと思っています。また、今回のタスクは分類タスクでしたが、回帰タスクでも損失関数から逆伝播させるのは出力データ ( output ) と正解データ ( t ) の差であることは変わらないようです。本記事では触れられませんでしたが、勾配を計算した後は、最適化によりパラメータを更新していきます。その最適化手法にも種類が様々あるようで、そちらの学習も今後進められればと思っています。

書いた人はこんな人

みや(エンジニア)
みや(エンジニア)
社会人一年目のみやです。大学時代はエンジニアとは無縁の生活。ひょんな事からこの業界に興味を持ち、エンジニアとしてライトコードで働いています。未経験ですが、どんどん吸収して成長していきます!

関連記事

採用情報

\ あの有名サービスに参画!? /

バックエンドエンジニア

\ クリエイティブの最前線 /

フロントエンドエンジニア

\ 世界を変える…! /

Androidエンジニア

\ みんなが使うアプリを創る /

iOSエンジニア