1. HOME
  2. ブログ
  3. IT技術
  4. 【後編】PyTorchの自動微分を使って線形回帰をやってみた

【後編】PyTorchの自動微分を使って線形回帰をやってみた

後編~PyTorchの自動微分を使って線形回帰に挑戦!~

PyTorch」を使っていると、こんなことありませんか?

model.zero_grad() って何やってるんだろう?」

loss.backward() では、何が計算されているの?」

「Tensor の属性の requires_grad って何?」

そんな方のために、前回に引き続き、「PyTorch」の自動微分による線形回帰を、わかりやすく解説していきます!

前編をまだお読みでない方は、まずはこちらをお読みください。

勾配降下法(Gradient descent method)と「torch.no_grad()」について

まずは、損失カーブを見てみましょう。

良い形で、減少しているのがわかりますね!

これはグラディエントを使って、損失関数の値が下がるように、変数の値を何度もループの中で調節してます。

この手法は、「勾配降下法」と呼ばれます。

先ほど訓練をさせたコードだと、次の部分ですね。

このコードでは、「勾配降下法」で変数の wb を更新するときに、  with torch.no_grad() を使っています。

グラディエントの計算を継続しない」という意味です。

これをやらないと、PyTorch はグラフ上の変数を直接変更しているとみなし、エラーになります。

さらに変数 wb の更新後は、  w.grad.zero_()b.grad.zero_() で、グラディエントの値を「0」にします。

PyTorch が、グラディエントの値を自動で「0」にしないのは、グラディエントの値を継続して貯めていきたい場合もあるため。

つまり、ユーザーが決めたタイミングで、グラディエントの値を「0」に設定する必要があるわけです。

「with torch.no_grad()」を使わないと?

with torch.no_grad() を使わないと、一体どうなるか、実際にやってみましょう!

まず、変数 w2b2 を設定し、予測値 p2 を計算します。

損失値を出して、自動微分を実行していきましょう。

ここまでは、問題ありませんね。

ではグラディエントを使って、変数 wb の値を、 with torch.no_grad() を使わずに更新してみます。

エラーが出ましたね…。

このエラーの意味は、「グラディエントを必要とする leaf Variable の値を、直接変更することはできない」です。

これを許してしまうと、PyTorch がせっかく自動で作り上げてくれたグラフが、無意味になってしまうのです。

線形回帰の結果

線形回帰の結果を描画して、実際に目で確認しましょう!

結果を描画

散乱しているデータに対して、直線の近似ができていますね!

detach()を呼ばないとエラーになる?

ここで、疑問が浮かんだ方もいるかもしれません。

予測値 p に対して、Numpy に変換するときに、 p = p.detach().numpy()detach() を呼んでいます。

実は、 p.numpy() と、直接 Numpy へ変換しようとするとエラーになるのです。

エラーのメッセージを読むと、  detach() を挟んでから numpy() を呼べば、解決できそうです。

なぜ、そんなことが必要なのでしょうか?

これは、 x.numpy()p.numpy() が呼ばれたときに、何が起きているかを考えると理解しやすいです。

numpy()を呼ぶときに何が起きているのか

PyTorch は、効率性を重視するので、不必要なデータのコピーなどは、なるべく行わないようになっています。

なので、 numpy() が呼ばれたときに返されるのは、コピーされたデータではなく、元のデータを参照したものになっているわけです。

コードを見てみると?

具体的に、コードを見てみましょう。

この v1v2 は、同じデータを参照しており、 v1 値を変更すると v2 の値も変更されてしまいます。

両方とも同じデータを指しているので、「両方が変更される」というのは、正確には間違い…。

変更されたデータを、両方が参照している」と言ったほうが正しいです。

実際に、  v2 の値を変更してみましょう。

v1v2 の値が、両方とも同じ値を指しているのが分かりますね!

コピーして利用したい場合は?

もし、切り離したいのであれば、 clone() を呼んでコピーする必要があります。

ちなみに、コピーである  v3 の変更は、 v1v2 には反映されません。

detach()は変数から定数を作る

さて、変数に numpy() を呼ぶと、エラーが起きる現象に戻ります。

エラーは、

変数(Tensor that requires grad)に、 numpy() を呼べません。代わりに、 detach().numpy() を使ってください

といっています。

噛み砕くと、「PyTorchのグラディエントを、Numpy に持っていくことはできません」ということです。

つまり、グラディエントが必要のない「定数の Tensor」にすればいいわけです。

それが、  detach() の役割ということですね!

イメージ的には「変数をグラフから離す(detach)」ですが、実際には元の変数に変更はないので、「変数から定数を作る」という解釈の方がより正しいです。

コピーがされないことに注意!

ここでも、データ自体はコピーされないので注意が必要です。

v6 は、 requires_grad がないので、変数ではないのがわかります。

ただし、 v4v6 は同じデータを参照しているので、 v6 に対する変更が v4 からも参照されているのです。

データ変更するなら、cloneかcopyを使う

もし、データを変更する必要があるならば、「clone」や「copy」を使って、データそのものを複製しましょう!

v7v8 と v9 は、全て別のデータを参照しているのがわかります。

PyTorchの「nnパッケージ」で同じことをやってみる

比較として、PyTorchの「nn パッケージ」でも、線形回帰をやってみました。

早速、結果からみていきましょう!

「nn パッケージ」での線形回帰の結果

weight が「近似直線の傾き」で、 bias が「切片の値」です。

この方法も、「 weight は2に近く、 bias が0に近い」ので、うまく近似できていますね!

損失カーブは?

損失関数カーブも良い感じです。

近似の直線は?

近似の直線も、うまくデータを説明できています。

プログラムを追ってみよう!

あとは、コメントを見ながら、プログラムの流れを追ってみてください。

なぜ model.zero_grad()  が呼ばれているのか、 loss.backward() でどんなことが起きているのかも、想像できるはずです。

プログラムの全体的な流れとしては、「nn パッケージ」を使わない線形回帰のバージョンと、あまり違いがないですよね?

grad とか、  with torch.no_grad() などが出てこないぶん、簡単になっていますよ。

「nn パッケージ」の注意点

ただし、2つだけ注意点があります。

ひとつ目は、 x のデータタイプが torch.float32 になるように、Numpyでは  x を設定したときに np.float32 を指定していました。

これは「nn パッケージ」のモジュールが、Float 64型を受け付けないからです。

2つ目は、 x.reshape(-1, 1)y.reshape(-1, 1) を呼んでいること。

これは、 nn.Linear では x の次元を、「(バッチサイズ、入力値の変数の数)」の形で期待しているからです。

y に関しても、同様に「(バッチサイズ、実測値の次元)」である必要があります。

さいごに

お疲れ様でした!

今回は、PyTorch の自動微分を使って、「線形回帰」を実装してみました。

nn.Linear を、「使わないバージョン」と「使うバージョン」を通して、自動微分についての理解が深まったと思います。

今回紹介したコードでは、入力値が1つの変数である「単回帰」を扱いましたが、ディープラーニングなどでは何百もの変数の値を、 nn.Linear に渡します。

「nn パッケージ」は、とてもありがたいものですね!

こちらの記事もオススメ!



書いた人はこんな人

広告メディア事業部
広告メディア事業部
「好きを仕事にするエンジニア集団」の(株)ライトコードです!

ライトコードは、福岡、東京、大阪の3拠点で事業展開するIT企業です。
現在は、国内を代表する大手IT企業を取引先にもち、ITシステムの受託事業が中心。
いずれも直取引で、月間PV数1億を超えるWebサービスのシステム開発・運営、インフラの構築・運用に携わっています。

システム開発依頼・お見積もり大歓迎!

また、現在「WEBエンジニア」「モバイルエンジニア」「営業」「WEBデザイナー」「WEBディレクター」を積極採用中です!
インターンや新卒採用も行っております。

以下よりご応募をお待ちしております!
https://rightcode.co.jp/recruit

関連記事

採用情報

\ あの有名サービスに参画!? /

バックエンドエンジニア

\ クリエイティブの最前線 /

フロントエンドエンジニア

\ 世界を変える…! /

Androidエンジニア

\ みんなが使うアプリを創る /

iOSエンジニア