• トップ
  • ブログ一覧
  • ディープラーニングの勾配計算についてまとめてみた【誤差逆伝播法】
  • ディープラーニングの勾配計算についてまとめてみた【誤差逆伝播法】

    みや(エンジニア)みや(エンジニア)
    2023.10.31

    IT技術

    はじめに

    ChatGPTやDeepLなどディープラーニングを用いた製品やサービスはたくさんありますが、その仕組みを全く知りませんでした。

    「その仕組みを知りたい!」という気持ちからディープラーニングについて学習し始めたので、アウトプットも兼ねて今回、誤差逆伝播法の勾配を計算する部分について紹介したいと思います。ディープラーニングを勉強すると最初に躓く部分だと思うので、丁寧に説明できればと思います。

    ディープランニング概要

    ディープラーニングのモデルの学習の流れは以下の1~5のサイクルになっています。これらを繰り返すことで、より精度の高い出力ができるパラメータに更新し、精度の高いモデルを生成していきます。

    1. データを入力
    2. データを出力
    3. 損失を計算
    4. 勾配を計算
    5. パラメータを更新

    今回は上記4番の「計算した損失から勾配を計算する仕組み」を紹介します。

    誤差逆伝播法

    これから紹介するMNIST画像の分類問題を行う中間層が2層のモデル(入力 → 中間層1 → 中間層2 → 出力 → 損失)を題材にします。

    MNISTは手書き数字(0から9までの整数 = 10種類)の画像のデータセットです。

    また、MNIST画像は 28×2828 \times 28 のグレー画像で、各ピクセルは0から255までの値をとります。

    1画像につき784ピクセルのデータがあるので、入力データは N×784N \times 784 ベクトルになります( Nはバッチ数です )。

    入力 inputinput から1つ目の中間層 layer1layer1 への重み ( w1 w_1 ) の次元は 784×64 784 \times 64 、バイアス ( b1 b_1 ) の次元は 1×64 1 \times 64 、活性化関数は ReLU関数 としました。

    ですので、中間層1の次元は N×64 N \times 64 になります。

    1つ目の中間層から2つ目の中間層( layer2 )への重み ( w2 w_2 ) の次元は 64×64 64 \times 64 、バイアス ( b2 b_2 ) の次元は 1×64 1 \times 64 、活性化関数は ReLU関数 としました。

    ですので、中間層2の次元は N×64 N \times 64 になります。

    本モデルは10種類の分類問題なので、出力 ( output ) の次元を N×10 N \times 10 にする必要があります。

    なので、2つ目の中間層から出力への重み ( w3 w_3 ) の次元は 64×10 64 \times 10 、バイアス ( b3 b_3 ) の次元は 1×10 1 \times 10 、活性化関数は softmax関数 としました。

    また、損失関数には多値分類タスクによく使われるクロスエントロピー誤差を用います。t t は正解データを表しています。

    パラメータの更新のための勾配計算

    まず、計算された損失 L L を受けて、パラメータ w3,b3 w_3, b_3 をどのように更新するか、つまり Lw3,Lb3 \frac{\partial L}{\partial w_{3}}, \frac{\partial L}{\partial b_{3}} を求めていきます。

    Lw3,Lb3 \frac{\partial L}{\partial w_{3}}, \frac{\partial L}{\partial b_{3}} を求める流れは以下のようになっています。

    1. Loutput \frac{\partial L}{\partial output} を求める。
    2. Ly3 \frac{\partial L}{\partial y_{3}} を求める。
    3. Lw3,Lb3 \frac{\partial L}{\partial w_{3}}, \frac{\partial L}{\partial b_{3}} を求める。

    では、最初にLoutput \frac{\partial L}{\partial output} を求めます。損失は以下のような式で導出されます。

    L=fc(output,t)=itilog(outputi)L = f_c(output, t) = - \sum_{i}t_{i} \cdot log(output_{i})

    ただし、fcf_c はクロスエントロピー誤差を表しています。クロスエントロピーにより N×10 N \times 10 ベクトルを N×1 N \times 1 ベクトルの損失に変換しています。

    クロスエントロピーの定義
    L=fc(x,t)ktklog(xk) L = f_c(x, t ) - \sum_{k}t_{k} \cdot log(x_{k})
    ただし、x x はモデルの出力値、t t はそれに対応する正解値。

    自然対数の微分(xlogx=1x \frac{\partial}{\partial x} log x = \frac{1}{x})より、損失に対する outputi output_{i} の微分は以下のようにかけます。

    Loutputi=tioutputi\frac{\partial L}{\partial output_{i}} = - \frac{t_{i}}{output_{i}}

    次に、Ly3 \frac{\partial L}{\partial y_{3}} を求めます。output output は以下のような式で導出されます。

    output=fs(y3)output = f_{s}(y_{3})

    ただし、fsf_s はソフトマックス関数を表しています。ソフトマックス関数により N×10 N \times 10 ベクトルの y3 y_3 を10種類の数字の確率を表す N×10 N \times 10 ベクトルに変換しています。

    Ly3=Loutputoutputy3 \frac{\partial L}{\partial y_{3}} = \frac{\partial L}{\partial output} \cdot \frac{\partial output}{\partial y_{3}}

    より、Loutput\frac{\partial L}{\partial output} は先ほど求めたので、ソフトマックス関数の微分 outputy3 \frac{\partial output}{\partial y_{3}} を求めれば、Ly3 \frac{\partial L}{\partial y_{3}} が求まります。

    ソフトマックス関数の微分
    ソフトマックス関数の定義より、1×M 1 \times M ベクトルの出力 x x のある要素 xk x_k からそれに対応する確率を導出する式は以下になっています。
    fs(xk)=exkjexj f_s(x_k) = \frac{e^{x_k}}{\sum_{j}e^{x_{j}}}
    ここで S=jeyj S = \sum_{j}e^{y_{j}} とおくと、Sxl=eyl \frac{\partial S}{\partial x_{l}} = e^{y_{l}} とできます。
    (i)k=l (i) k = l のとき
    分数関数の微分公式 ( g(x)h(x)x=g(x)xh(x)g(x)h(x)xh(x)2\frac{\partial \frac{g(x)}{h(x)}}{\partial x} = \frac{\frac{\partial g(x)}{\partial x}h(x) - g(x)\frac{\partial h(x)}{\partial x}}{h(x)^2} ) を用いることで、以下のように書けます。
    fs(xk)xl=exkSe2xkS2=exkS(1exkS)=fs(xk)(1fs(xk))\frac{\partial f_s(x_k)}{\partial {x_{l}}} = \frac{e^{x_{k}}S - e^{2x_{k}}}{S^2} = \frac{e^{x_{k}}}{S}(1 - \frac{e^{x_{k}}}{S}) = f_s(x_k)(1 - f_s(x_k))
    (ii)kl (ii) k \neq l のとき
    fs(xk)xl=exkexlS2=fs(xk)fs(xl)\frac{\partial f_s(x_k)}{\partial {x_{l}}} = - \frac{e^{x_{k}} \cdot e^{x_{l}}}{S^2} = - f_s(x_k)f_s(x_l)
    これらより、ソフトマックス関数の微分は以下のようになります。
    fs(xk)xl={fs(xk)(1fs(xk))(x=l)fs(xk)fs(xl)(xl)\frac{\partial f_s(x_k)}{\partial {x_{l}}} =\left\{ \begin{array}{ll} f_s(x_k)(1 - f_s(x_k)) (x = l)\\ - f_s(x_k)f_s(x_l) (x \neq l) \end{array} \right.

    前述した式

    Ly3=Loutputoutputy3 \frac{\partial L}{\partial y_{3}} = \frac{\partial L}{\partial output} \cdot \frac{\partial output}{\partial y_{3}}

    は連鎖律より以下のようにかけます。

    Ly3i=jLoutputjoutputjy3i \frac{\partial L}{\partial {y_{3}}_{i}} = \sum_{j}\frac{\partial L}{\partial output_{j}} \cdot \frac{\partial output_{j}}{\partial {y_{3}}_{i}}

    より、これまでの計算結果とソフトマックス関数の微分の場合分けを考慮すると以下のように書けます。

    Ly3i=Loutputioutputi(1outputi)+ \frac{\partial L}{\partial {y_{3}}_{i}} = \frac{\partial L}{\partial output_{i}}output_{i}(1 - output_{i}) +

    kiLoutputk(outputkoutputi) \sum_{k \neq i}\frac{\partial L}{\partial output_{k}}(- output_{k}\cdot output_{i})

    =tioutputioutputi(1outputi)+ = - \frac{t_{i}}{output_{i}}output_{i}(1 - output_{i}) +

    ki(tkoutputk)(outputkoutputi)) \sum_{k \neq i}(- \frac{t_{k}}{output_{k}})(- output_{k} \cdot output_{i}))

    =ti(1outputi)+outputikitk = - t_{i}(1 - output_{i}) + output_{i}\sum_{k \neq i}t_{k}

    =ti(1outputi)+outputi(1ti)=outputiti = - t_{i}(1 - output_{i}) + output_{i}(1 - t_{i}) = output_{i} - t_{i}

    これにより、損失 L L に対する、y3i {y_{3}}_{i}の微分を求めることができました。

    これまでの計算結果より、遂に Lw3,Lb3 \frac{\partial L}{\partial w_{3}}, \frac{\partial L}{\partial b_{3}} を求めることができます。

    y3 y_3 は以下の式で表されます。

    y3=layer2w3+b3 y_3 = layer_2 \cdot w_3 + b_3

    まず、Lw3 \frac{\partial L}{\partial w_{3}} は以下のように導出できます(右上に添えてある ( T ) は転置を表しています)。

    Lw3=jy3jw3Ly3j=layer2T(outputt) \frac{\partial L}{\partial {w_3}} = \sum_{j}\frac{\partial {y_{3}}_{j}}{\partial {w_3}} \cdot \frac{\partial L}{\partial {y_{3}}_{j}} = layer_2^T \cdot (output - t)

    上式は 64×10=64×NN×10 64 \times 10 = 64 \times N \cdot N\times 10 のように内積の順番と次元の整合性が取れていることがわかります。

    次に、Lb3 \frac{\partial L}{\partial b_{3}} は以下のように導出できます。

    Lb3=jy3jb3Ly3j=j(outputjtj) \frac{\partial L}{\partial {b_3}} = \sum_{j}\frac{\partial {y_{3}}_{j}}{\partial {b_3}} \cdot \frac{\partial L}{\partial {y_{3}}_{j}} = \sum_{j}(output_j - t_j)

    上式は 1×10=1×10 1 \times 10 = 1 \times 10 のように内積の順番と次元の整合性が取れていることがわかります。

    また、以降の計算のために Llayer2 \frac{\partial L}{\partial {layer_2}} も求めます。
    Llayer2=Ly3y3layer2=(outputt)w3T \frac{\partial L}{\partial {layer_{2}}} = \frac{\partial L}{\partial {y_{3}}} \cdot \frac{\partial {y_{3}}}{\partial {layer_{2}}} = (output - t) \cdot w_3^T

    上式は N×64=N×1010×64 N \times 64 = N \times 10 \cdot 10 \times 64 のように内積の順番と次元の整合性が取れていることがわかります。

    パラメータの更新のための勾配計算

    Lw3,Lb3)を求めることができたので、次に(Lw2,Lb2 \frac{\partial L}{\partial w_{3}}, \frac{\partial L}{\partial b_{3}} ) を求めることができたので、次に( \frac{\partial L}{\partial w_{2}}, \frac{\partial L}{\partial b_{2}} を求めます。

    先ほどと違うのは、クロスエントロピー誤差を考慮しなくて良いこと。また、活性化関数がソフトマックス関数ではなくReLU関数であることです。

    layer2=fr(y2)=fr(layer1w2+b2)layer_2 = f_r(y_2) = f_r(layer_1 \cdot w_2 + b_2)

    ただし、fr f_r はReLU関数を表しています。

    まず、Ly2 \frac{\partial L}{\partial {y_2}} を求めます。

    Ly2=Llayer2layer2y2\frac{\partial L}{\partial {y_2}} = \frac{\partial L}{\partial {layer_2}} \cdot \frac{\partial {layer_2}}{\partial {y_2}}

    Llayer2 \frac{\partial L}{\partial {layer_2}} については前節で既に求めているので、layer2y2 \frac{\partial {layer_2}}{\partial {y_2}} を求めます。

    ReLU関数の微分
    v=fr(x) v = f_r(x) について、微分 vx\frac{\partial v}{\partial {x}} を求めることを考える。
    v={xamp;(x>0)0amp;(x0)v =\left\{ \begin{array}{ll} x & (x \gt 0)\\ 0 & (x \leq 0) \end{array} \right.
    ReLU関数の定義より v v は上記のように表すことができる。
    より、vx\frac{\partial v}{\partial {x}} は以下のようになる。
    vx={1amp;(x>0)0amp;(x0)\frac{\partial v}{\partial {x}} =\left\{ \begin{array}{ll} 1 & (x \gt 0)\\ 0 & (x \leq 0) \end{array} \right.

    ReLU関数の微分より、layer2y2 \frac{\partial {layer_2}}{\partial {y_2}}y2 y_2 の要素が0より大きい時に1、小さい時に0をとる N×64 N \times 64 ベクトル(y2 y_2 と同じ形)であることがわかりました。

    これを R2 R_2 とおくと、Ly2\frac{\partial L}{\partial {y_2}} は以下のように書けます( はアダマール積を表しています \odot はアダマール積を表しています )。

    Ly2=Llayer2R2 \frac{\partial L}{\partial {y_2}} = \frac{\partial L}{\partial {layer_2}} \odot R_2

    それでは Lw2 \frac{\partial L}{\partial {w_2}} を求めます。

    Lw2=y2w2Ly2\frac{\partial L}{\partial {w_2}} = \frac{\partial {y_2}}{\partial {w_2}} \cdot \frac{\partial L}{\partial {y_2}}

    Ly2 \frac{\partial L}{\partial {y_2}} は先ほど求めたので、y2=layer1w2+b2 y_2 = layer_1 \cdot w_2 + b_2 より、以下のように求まります。

    Lw2=layer1TLlayer2R2 \frac{\partial L}{\partial {w_2}} = layer_1^T \cdot \frac{\partial L}{\partial {layer_2}} \odot R_2

    上式は 64×64=64×NN×64 64 \times 64 = 64 \times N \cdot N \times 64 のように内積の順番と次元の整合性が取れていることがわかります。

    次に Lb2 \frac{\partial L}{\partial {b_2}} を求めます。

    Lb2=jy2jb2Ly2j=jLlayer2jR2 \frac{\partial L}{\partial {b_2}} = \sum_{j}\frac{\partial {y_{2}}_{j}}{\partial {b_2}} \cdot \frac{\partial L}{\partial {y_{2}}_{j}} = \sum_{j}\frac{\partial L}{\partial {layer_2}_{j}} \odot R_2

    上式は 1×64=1×64 1 \times 64 = 1 \times 64 のように内積の順番と次元の整合性が取れていることがわかります。

    パラメータの更新のための勾配計算

    Lw1,Lb1 \frac{\partial L}{\partial w_{1}}, \frac{\partial L}{\partial b_{1}} の計算はLw2,Lb2 \frac{\partial L}{\partial w_{2}}, \frac{\partial L}{\partial b_{2}} の導出と同じ方法で求めることができます。

    計算過程については省略します。

    Ly1=Llayer1R1 \frac{\partial L}{\partial {y_1}} = \frac{\partial L}{\partial {layer_1}} \odot R_1

    Lw1=inputTLlayer1R1 \frac{\partial L}{\partial {w_1}} = input^T \cdot \frac{\partial L}{\partial {layer_1}} \odot R_1

    Lb1=jLlayer1jR1 \frac{\partial L}{\partial {b_1}} = \sum_{j}\frac{\partial L}{\partial {layer_1}_{j}} \odot R_1

    パラメータの更新のための計算まとめ

    これまでの Lw1,Lb1,Lw2,Lb2,Lw3,Lb3 \frac{\partial L}{\partial w_{1}}, \frac{\partial L}{\partial b_{1}}, \frac{\partial L}{\partial w_{2}}, \frac{\partial L}{\partial b_{2}}, \frac{\partial L}{\partial w_{3}}, \frac{\partial L}{\partial b_{3}} Lyi \frac{\partial L}{\partial y_{i}} で以下のように表すことができます。

    Lwi=layeri1TLyi \frac{\partial L}{\partial w_{i}} = layer_{i - 1}^T \cdot \frac{\partial L}{\partial y_{i}}

    Lbi=jLyij \frac{\partial L}{\partial b_{i}} = \sum_{j}\frac{\partial L}{\partial {y_{i}}_{j}}

    ただし、layer0 layer_0input input を表しています。

    活性化関数の種類により、Lyi \frac{\partial L}{\partial y_{i}} を計算する必要はありますが、基本的に yi=layeri1wi+bi y_i = layer_{i - 1} \cdot w_i + b_i のようなニューラルネットワークでは、Lwi,Lbi \frac{\partial L}{\partial w_{i}}, \frac{\partial L}{\partial b_{i}} を上記のように計算できることがわかりました。

    まとめ

    今回はMNISTの分類問題のモデルという具体例を用いて、モデルの中でどのように勾配が計算されているのかを見ていきました。便利なライブラリを用いれば、このような仕組みを意識しなくても実装できると思いますが、理論の部分を理解するのも重要だと思っています。また、今回のタスクは分類タスクでしたが、回帰タスクでも損失関数から逆伝播させるのは出力データ output output と正解データ t t の差であることは変わらないようです。本記事では触れられませんでしたが、勾配を計算した後は、最適化によりパラメータを更新していきます。その最適化手法にも種類が様々あるようで、そちらの学習も今後進められればと思っています。

    ライトコードでは、エンジニアを積極採用中!

    ライトコードでは、エンジニアを積極採用しています!社長と一杯しながらお話しする機会もご用意しております。そのほかカジュアル面談等もございますので、くわしくは採用情報をご確認ください。

    採用情報へ

    みや(エンジニア)
    みや(エンジニア)
    Show more...

    おすすめ記事

    エンジニア大募集中!

    ライトコードでは、エンジニアを積極採用中です。

    特に、WEBエンジニアとモバイルエンジニアは是非ご応募お待ちしております!

    また、フリーランスエンジニア様も大募集中です。

    background