1. HOME
  2. ブログ
  3. IT技術
  4. 【前編】PyTorchの自動微分を使って線形回帰をやってみた

【前編】PyTorchの自動微分を使って線形回帰をやってみた

前編~PyTorchの自動微分を使って線形回帰に挑戦!~

PyTorch」を使っていると、次のような疑問を持つ人は多いはず…。

model.zero_grad() って何やってるんだろう?」

loss.backward() では、何が計算されているの?」

「Tensor の属性の requires_grad って何?」

ここでは、そんな方のために、「PyTorch」の自動微分による線形回帰を、わかりやすく解説していきます!

「自動微分」の理解がカギを握る

冒頭の PyTorch の疑問や、Tensor の以下の属性は、すべて自動微分(automatic differentiation)に関係しています。

  1. requires_grad
  2. grad
  3. grad_fn
  4. is_leaf

つまり、PyTorch の「自動微分」を理解すれば、すべての疑問をスッキリと解消できるわけです。

では次から、そのカギとなる「自動微分」について、深く見ていきましょう!

まずは「自動微分」の準備

まずは、PyTorch で自動微分をするために、準備をしていきましょう!

ライブラリのインポート

はじめに、自動微分に必要なライブラリを、インポートしていきましょう。

「requires_grad=True」を指定する

入力データを x に設定し、変数で  wb も、それぞれTensor(テンソル)として定義します。

このとき、 wb には、「  requires_grad=True 」をつけています。

これを「True」にすることで、「微分の対象にしますよ!」と指定しているわけですね!

ちなみに、デフォルトの状態では、「False」となっています。

まずは、上のコードで代入されたものを、全部プリントしてみましょう!

しっかりと指定されているのが、わかりますね!

自動微分の「グラフ」を理解しよう

PyTorch は、計算の流れを、グラフにして記憶しています。

具体的な例は、次から紹介していきますので、順にみていきましょう!

自動微分のグラフ

それでは、簡単な計算をしてみましょう。

用意した計算式は、次のとおり。

上のコードに、先ほど代入した値を当てはめていくと、答えは次のようになります。

では、これもプリントしてみましょう。

答えは、確かに「11」になりました!

計算内容をグラフにして記憶

さっきのプリント結果では、「11」のあとに、grad_fn=<AddBackward0>がついていました。

これは、 y足し算( Add )によって生まれたものだと、記憶しているということです。

まずは、PyTorch がどんなグラフを持っているのか、以下の図で見てみましょう!

MulBackward0Mul は、「Multiply」の略で、「掛け算」のこと。

つまり、 wx が掛け算だったことも、記憶されているわけですね!

このように、計算に変数が含まれていると、PyTorch は計算の内容をグラフにして記憶していきます。

そうすることで、 y に対して自動微分を実行する時に、PyTorch は逆の順番で計算をたどることができるのです。

y は、  w * x  と b  の足し算なので、微分を w * xb とで、別々に計算できることがわかります。

さらに、 w * x  のところは、「定数」と「変数」の掛け算の微分になるということですね。

グラディエント関数「grad_fn」と末端変数「is_leaf」

「grad_fn」と「is_leaf」についても、理解を深めていきましょう!

グラディエント関数「grad_fn」

記憶された関数は、 grad_fn  を使うことで、参照することができます。

そのため、関数を利用していないユーザーが定義した変数では、「None」となるのです。

  grad_fn  の grad は、あとで出てくるグラディエント(gradient)の略です。

fn  は、関数(function)の略となります。

末端変数「is_leaf」

ちなみに、 wb はユーザーが定義した変数で、「leaf Variableと呼ばれています。

英語の「leaf」は、木の葉っぱのことなので、訳すとすれば「グラフの末端の変数」ですね!

wb は、この末端変数となるので、もちろん以下のように True が返ってきます。

ちなみに、 x も定数ですが末端の値なので、 is_leaf を呼ぶと True が返ってきます。

PyTorchのドキュメントでは、「leaf Tensorと呼ばれています。


そして、「leaf Tensor」grad_fn も、 None を返します。

自動微分を行うタイミング

まず、計算の「終わり → 始まり」へ向かって、微分計算していく手法を、「誤差逆伝播法(Back propagation、あるいはBackprop)」と呼びます。

PyTorch の自動微分の機能では、グラフを自動的に作り、Backprop を行えるように準備してくれるのです。

最終的に、いつ自動微分を行うのかは、ユーザーが決めることができます。

自動微分とグラディエント(grad)

では、ここで y に対して、自動微分を実行してみましょう。

y を計算する際に使われた変数、これに微分が自動で計算されます。

このとき、 y の計算に使われた変数には、 grad という属性が作られます

こちらも、プリントしてみましょう!

grad は、「gradient」の略で、日本語では「グラディエント」とか「勾配」と呼ばれています。

グラディエント(勾配)の例

数学や物理では、「勾配」とはある関数の「最大傾斜を表すベクトルのことで、 y の全微分から導くことができます。

参考までに、全微分から勾配を求める式も見ていきましょう!

「\(y = f(w, b)\)」とし、 y は、 wb の関数だとします。

$$dy = \frac{\partial{f}}{\partial{w}}dw + \frac{\partial{f}}{\partial{b}}db
=
\begin{pmatrix}
\frac{\partial{f}}{\partial{w}} \\
\frac{\partial{f}}{\partial{b}}
\end{pmatrix}
\cdot
\begin{pmatrix}
dw \\
db
\end{pmatrix}
$$

よって、

$$\mathrm{grad} \, y =
\begin{pmatrix}
\frac{\partial{f}}{\partial{w}} \\
\frac{\partial{f}}{\partial{b}}
\end{pmatrix}
$$

となります。

つまり「グラディエント(勾配)」とは

グラディエントとは、簡単にいうと w や  b の値を増やすときに、yの値がどの程度変わるのかを表したもの。

先ほどの y は、直線の式なので、次のように簡単に表現できます。

  1. w.grad = 5 は 、 w が1増えると、  y が5増える」
  2. b.grad = 1 は 、 b が1増えると 、 y が1増える」

もともとの式が、  y = w * x + bx = 5 なので、正しいことがわかりますね!

線形回帰をやってみる

いよいよ、与えられたデータに対して、「線形回帰」を適用してみましょう!

「線形回帰」とは、データの分布を直線によって、近似させる手法です。

y = w * x + b では、 x が入力値で、  y x に対するデータの実測値となります。

これらデータを直線で最も近似させて表すとき、最適なパラメータ値「 w」と「 b」は、一体いくつなのかを求めるのです。

教師データを作る

まず、データを作ります。

ここでは、 y 2x という直線の式に、ノイズを加えたものを用意しました。

また、データ型は「フロート32」にしておきます。

理由は、あとで使う PyTorch のモジュール( nn.Linea など)がフロート32対応のものが多く、フロート64型のままだとエラーになるためです。

データ描画

では、データをプロットしましょう。

直線のまわりに、散乱したデータができましたね!

線形回帰のモデルとパラメータを設定する

では、データと変数を PyTorch の「Tensor」として定義しましょう。

xy は、与えられたデータであり、定数なので from_numpy で Tensor に変換します。

wb は求めたい値なので、変数として、適当に初期化しておきましょう!

この変数 wb を使って、入力データ x から y の値を予測する、「線形モデル」を定義します。

線形回帰の損失関数

線形回帰を数値計算するときは、最適なパラメータを「最小2乗法」を使って求めていきます。

最小2乗法では、2乗誤差(予測値と実測値の差の2乗)が最小になるように、パラメータを調整します。

損失関数」として、予測値 p と実測値 y との「平均2乗誤差(MSE: Mean Squared Error)」を定義しておきましょう。

上の mse は、2乗誤差の平均を計算したものです。

線形回帰モデルを訓練する

準備ができたので、トレーニングをしましょう!

データは小さいので、ミニバッチは考えずに、全てのデータを与えて何度もエポックを繰り返します。

すると、「 w は2に近い値、 b は0に近い値」になったので、ほぼ正解が得られましたね!

後編へつづく!

後編はこちらです!

こちらの記事もオススメ!

書いた人はこんな人

広告メディア事業部
広告メディア事業部
「好きを仕事にするエンジニア集団」の(株)ライトコードです!

ライトコードは、福岡、東京、大阪の3拠点で事業展開するIT企業です。
現在は、国内を代表する大手IT企業を取引先にもち、ITシステムの受託事業が中心。
いずれも直取引で、月間PV数1億を超えるWebサービスのシステム開発・運営、インフラの構築・運用に携わっています。

システム開発依頼・お見積もり大歓迎!

また、現在「WEBエンジニア」「モバイルエンジニア」「営業」「WEBデザイナー」「WEBディレクター」を積極採用中です!
インターンや新卒採用も行っております。

以下よりご応募をお待ちしております!
https://rightcode.co.jp/recruit

関連記事

採用情報

\ あの有名サービスに参画!? /

バックエンドエンジニア

\ クリエイティブの最前線 /

フロントエンドエンジニア

\ 世界を変える…! /

Androidエンジニア

\ みんなが使うアプリを創る /

iOSエンジニア