1. HOME
  2. ブログ
  3. IT技術
  4. 【前編】Pytorchの様々な最適化手法(torch.optim.Optimizer)の更新過程や性能を比較検証してみた!

【前編】Pytorchの様々な最適化手法(torch.optim.Optimizer)の更新過程や性能を比較検証してみた!

最適化手法の更新過程や性能を比較検証してみよう!

機械学習では、学習を行う際に、「損失関数の値」と「勾配」からパラメータを更新し、「モデルの最適化」を行います。

この時の「モデル最適化手法」は様々な方法が考案・活用されていますが、それぞれ最適化される過程は異なってきます。

今回は、Pytorchに用意されている各種最適化手法(torch.optim.Optimizer)の学習過程がどのように異なるのかについて、「損失関数」や「精度の比較」により検証します。

本記事は「前編」「後編」でお届けいたします。

各種torch.optim.Optimizerの紹介と挙動

今回検証する「最適化手法」は、次の6つです。

  1. SGD : torch.optim.SGD
  2. Adagrad : torch.optim.Adagrad
  3. RMSprop : torch.optim.RMSprop
  4. Adadelta : torch.optim.Adadelta
  5. Adam : torch.optim.Adam
  6. AdamW : torch.optim.AdamW

最適化手法の比較方法

まずは、これらの「最適化手法」について、関数 \( f(x, y)=x^2+y^2 \) 平面上での最適化過程を比較し、各手法を簡単に紹介していきます。

関数f(x,y)のグラフと更新方法

関数 f(x, y) のグラフは以下のようになり、x と y が共に「0」の時に最小値「0」をとります。

各 optimizer をそれぞれ定義し、関数 f(x,y) で求めた出力の勾配から、各 optimizer でパラメータ(x,y)の更新を行います

optimizerに渡すパラメータの設定

optimizer に渡すパラメータ (x、y) の初期値は、それぞれ「-75.0」、「-10.0」とします。

学習率(lr)は、 SGD 以外 「1」、SGD は収束させるために「0.1」に調整しました。

それ以外のパラメータは、デフォルト値とします。

コード

各optimizerの更新経路

下図が、各 optimizer 毎のパラメータ更新経路をまとめたグラフです。

それぞれ「x = 0、y = 0」に収束していますが、更新経路が異なっています

厳密には、実用的なデータやモデルで精度などを検証する必要があります。

しかし、これを見るだけでも、それぞれ異なる原理で最適化されています。

例えば、更新経路が短いことや余計な動きが少ない方が、優れた最適化手法であるという見方もできます。

では、一つずつ見ていきましょう。

SGD (Momentum)

「SGD(Stochastic Gradient Decent : 確率的勾配降下法)」は代表的な最適化手法であり、求めた勾配方向にその大きさだけパラメータを更新するという方法です。

更新式

$$ W \leftarrow W - \eta \frac{\partial L}{\partial W} $$
(W:パラメータ、η:学習率、L:損失関数、dL / dW:勾配)

SGDの改良版

また、「SGD」を改良した 「Momentum(モーメンタム)」という手法があり、以下の式で表されます。

$$ v \leftarrow \alpha v - \eta \frac{\partial L}{\partial W} $$
$$ W \leftarrow W + v $$

「Momentum」で導入された「v」と「α」はそれぞれ速度抵抗に相当します。

関数平面での勾配により発生する各方向への力が速度「v」となり、地面の摩擦や空気抵抗のような減速させる方向への力が「α」になります。

これにより「Momentum」では、関数平面上をボールが転がるように最適化されます

torch.optim.SGDのパラメータ

torch.optim.SGD()」では、パラメータに momentumdampeningnesterov を渡すことで設定できますが、今回は単純な「SGD」を用います。

SGDの最適化過程

「SGD」の最適化過程は、以下のようになりました。

学習率が「1」の時は、パラメータの更新が適切に行われませんでした。

しかし、学習率を小さくすることで、最小値へと直線的に収束しています。

他の最適化手法との比較

「SGD」は、勾配方向に一定の学習率で更新を繰り返す方法です。

今回のような単純な関数では、学習率の設定が適切であれば、勾配の大きさと方向に従いパラメータを更新していくため、効率よく最適化されます。

これに対し、「Momentum」や後述する他の最適化手法は、最適化過程で学習率を調整したり、参考にする勾配の影響を考慮した設計を導入しています。

そのため、さらに複雑な関数の場合では、「SGD」よりも優れた結果を示す場合があります。

Adagrad

Adagrad(Adaptive Gradient Algorithm)」は学習率を学習過程の中で更新していく方法です。

パラメータの更新度合いにより、次の学習率を各パラメータ毎に調整します。

すなわち、大きく更新されたパラメータの学習率は、より小さく調整されます。

そのため、大きな勾配の影響が小さくなり、効率的に最適化を行うことができます。

更新式

$$ h \leftarrow h + \left(\frac{\partial L}{\partial W}\right)^2 $$
$$ W \leftarrow W - \frac{\eta}{\sqrt{h}} \frac{\partial L}{\partial W} $$

「h」は学習率調整の変数で、過去の勾配の二乗和を全て記憶していきます。

Adagradの更新経路

「Adagrad」は、勾配の大きさに対し学習率が調整されています。

そのため、「SGD」とは更新経路が異なっており、各方向における勾配の影響がより考慮された形になっています。

少しわかり辛いですが、更新過程を見ると初期の更新量が大きく、徐々に学習率は小さくなるため更新量も小さくなっていきます

Adagradの問題点

一方で「Adagrad」の問題点は、学習が進むと、いずれ学習率が非常に小さくなり更新されなくなるということです。

「Adagrad」は学習率を調整しながら最適化します。

しかし、全ての勾配情報を記憶していくため、学習が進んでいくと同時に学習率が小さくなります。

そのため、最小値がさらに先にある場合、勾配があっても更新量が小さく更新されにくくなってしまいます

RMSprop

この「Adagrad」の問題点を受けて提案されたのが、「RMSprop(Root Mean Square propagation)」になります。

更新式

$$ h \leftarrow \alpha h + (1 - \alpha) \left( \frac{\partial L}{\partial W} \right)^2 $$
$$ W \leftarrow W - \frac{\eta}{\sqrt{h + \epsilon}} \frac{\partial L}{\partial W} $$

「RMSprop」では勾配情報の記憶を行いますが、古い情報を落として、新しい勾配情報がより反映されるように記憶していきます。

これにより、学習が進んでいっても、その場での勾配情報を反映し、学習率の調整を適切に行うことができます。

「α」は学習率を調整する係数、「ε」はゼロ徐算を防ぐための極めて小さい値です。

torch.optim.RMSprop のパラメータ

torch.optim.RMSprop() 」では、デフォルトのパラメータとして「α = 0.99ε = 1e - 08」が与えられています。

RMSpropの更新経路

今回の更新過程は以下のようになり、y がすぐに「0」となり、そこから x 方向に更新されていくという挙動を示しています。

Adadelta

これまで紹介した最適化手法は、基本的に単位が整っていません

勾配とパラメータでは単位が異なっていますが、更新式の中で単位の違いは特に考慮されず計算されていません。

そのため結果的に、パラメータを勾配の単位で更新するということになります。

そこで「Adadelta」は、単位が整うように「Adagrad」や「RMSprop」を改良したものになります。

更新式

$$ \Delta W = - \frac{\eta}{\sqrt{h + \epsilon}}\frac{\partial L}{\partial W} $$ (RMSpropの更新量)
$$ \Delta W \leftarrow \alpha \Delta W + (1 - \alpha) \Delta W^2 $$
$$ W \leftarrow W - \frac{\sqrt{\Delta W + \epsilon}}{\sqrt{h + \epsilon}}\frac{\partial L}{\partial W} $$

最後の「W」の更新式を見ると、まず「過去の更新量 ΔW の移動平均」を「過去0の勾配 h の移動平均」で徐算しています。

これに現在の勾配を掛けたものが「Adadelta」での更新量となり、これによりパラメータ W は同じ単位で更新されるようになります。

また、学習率「η」が存在しないことから、学習率の設定が不要となることも特徴の一つです。

Adadeltaの更新経路

今回の関数では更新経路は以下のようになり、「Adagrad」と同じような経路となっています。

Adam

Adam(Adaptive Moment Estimation)」は「Momentum」と「Adagrad」の融合というアイディアにより考案されました。

更新式

勾配を記憶する「Momentum」と、勾配の二乗を記憶する「Adagrad」の項により構成されます。

$$ m \leftarrow \beta_1 m + (1 - \beta_1)\frac{\partial L}{\partial W} $$
$$ v \leftarrow \beta_2 v + (1 - \beta_2)\left(\frac{\partial L}{\partial W}\right)^2 $$
$$ \hat{m} = \frac{m}{1 - \beta_1} $$
$$ \hat{v} = \frac{v}{1 - \beta_2} $$
$$ W \leftarrow W - \frac{\eta \hat{m}}{\sqrt{\hat{v} + \epsilon}} $$

m」が 「Momentum」、「v」が「Adagrad」に相当し、それぞれ移動平均を用いて学習率の調整を行います。

torch.optim.Adamのパラメータ

torch.optim.Adam()」では、デフォルトで以下のように「β」などのパラメータが与えられています。

Adamの更新経路

今回の更新経路は、以下のようになりました。

「Momentum」のようにボールが転がるような動きをしますが、学習率を調整していくので、徐々に振れ幅は小さくなり、ゼロに近づいていきます

AdamW

「AdamW」は、「Adam」の Weight decay(重み減衰)に関する式について変更を行ったものです。

損失関数計算時およびパラメータ更新時において、L2 正則化項を追加することで、「Adam」よりも適切な Weight decay を得ることを目的としています。

更新式

$$ \frac{\partial L}{\partial W} \leftarrow \frac{\partial L}{\partial W} + \eta W $$(損失関数へのL2正則化項追加)
$$ W \leftarrow W - \frac{\eta \hat{m}}{\sqrt{\hat{v} + \epsilon}} + \eta W $$(Adam更新式へのL2正則化項追加)

Adamの更新経路

更新経路は「Adam」同様に蛇行していきますが、「AdamW」の方が経路が短くなっており、速く収束しています。

更新経路の比較

「Adam」と「AdamW」は更新経路が蛇行しており、他の方法に比べて非効率であるように見えます。

実際には、適用する関数や設定するハイパーパラメータ、学習データの種類などによりこれらの結果は変わってきます

そのため、一概にどれが良い方法かを決めることは難しいですが、各手法によって最適化経路が異なることは体感できたと思います。

後半へつづく

この記事は、【後半】へ続きます。

後編の目次は、

  1. 全結合層ニューラルネットワークによる検証
  2. CNNモデルによる検証

となっています。

後編はこちら

こちらの記事もオススメ!

書いた人はこんな人

広告メディア事業部
広告メディア事業部
「好きを仕事にするエンジニア集団」の(株)ライトコードです!

ライトコードは、福岡、東京、大阪の3拠点で事業展開するIT企業です。
現在は、国内を代表する大手IT企業を取引先にもち、ITシステムの受託事業が中心。
いずれも直取引で、月間PV数1億を超えるWebサービスのシステム開発・運営、インフラの構築・運用に携わっています。

システム開発依頼・お見積もり大歓迎!

また、現在「WEBエンジニア」「モバイルエンジニア」「営業」「WEBデザイナー」「WEBディレクター」を積極採用中です!
インターンや新卒採用も行っております。

以下よりご応募をお待ちしております!
https://rightcode.co.jp/recruit

関連記事

採用情報

\ あの有名サービスに参画!? /

バックエンドエンジニア

\ クリエイティブの最前線 /

フロントエンドエンジニア

\ 世界を変える…! /

Androidエンジニア

\ みんなが使うアプリを創る /

iOSエンジニア