• トップ
  • ブログ一覧
  • 教師なし機械学習「VAE」による連続的な手書き文字の生成
  • 教師なし機械学習「VAE」による連続的な手書き文字の生成

    広告メディア事業部広告メディア事業部
    2020.06.19

    IT技術

    認識モデルから生成モデルへ

    以前まで、手書き文字の認識は、難しいタスクであると考えられてきました。

    しかし、ニューラルネットワークの開発が進んだ2020年現在では、比較的、簡単な課題へと変化しました。

    手書き文字の認識モデル

    このような「認識モデル」では、手書き文字などのデータの分布を考慮せず、与えられたデータそのものから、直接識別境界を引いていきます

    手書き文字の生成モデル

    それに対し「生成モデル」は与えられたデータ群の分布そのものを求めていきます。

    生成モデルのメリット

    生成モデルでは、分布が分かっているため、任意のデータを生成できます。

    これは、機械学習タスクの弱点の一つ、データ不足を補えます

    また、データを連続的に変化させることが可能となり、モーフィングを行うことができるようになります。

    Variational AutoEncoder(VAE)を用いて顔画像の表情・向きを変化させる

    例えば、映画やアニメなどで、キャラクターが「普通の顔」から「笑顔」に変化していくアニメーションを作成するとしましょう。

    このとき、様々な表情を学習した生成モデルを用いれば、イラストレーターなしに、アニメーションの連続的な変化を作成することができます。

    VAE で同一人物の顔画像を学習させ変化させた図

    以下は、VAE により、同一人物の顔画像を学習させ、連続的な表情・顔の向きを変化させた図です。

    AutoEncoder と Variational AutoEncoder(VAE)

    前回、AutoEncoder を用いた次元削減について話しましたが、AutoEncoder では、データと潜在変数の関係を 1:11:1 として次元埋め込みをしていました。

    「Variational AutoEncoder(VAE)」では、潜在変数に確率的なブレを与えることで、与えられたデータ群の分布を推定し、連続的な画像生成を可能にします。

    つまり、VAE は、生成モデルの一つとなります。

    AutoEncoder による次元削減の記事

    なお、AutoEncoder を用いた次元削減については、以下の記事で解説しています。

    featureImg2020.06.01「AutoEncoder」から見る機械学習の次元削減の意味AutoEncoder から見る機械学習の次元削減の意味とは「オッカムの剃刀」という言葉をご存知ですか?「オッカムの剃...

    Variational AutoEncoder(VAE) の理論

    ネットワーク構造

    さて、「生成モデルは、データの分布を求めること」と述べました。

    今後は、データを

    X=x1,x2,,xnxii番目のピクセルの画素値}X = { x_{1} , x_{2} , \ldots , x_{n} | x_{i} は i 番目のピクセルの画素値 }

    として、その確率分布を P(X)P(X) とします。

    そして、AutoEncoder の時にもお話したように、画像のような高次元データは、ほとんどの画素値が冗長であり、周りの画素値で補完できます

    そのため、実際には、データはより低次元に分布するはずです。

    「0~9」の手書き文字画像を3次元の潜在変数に次元削減する

    以下は、AutoEncoder により、「0~9」の手書き文字画像を3次元の潜在変数に次元削減したときの例です。

    このような低次元の潜在変数を zz とし、その確率分布を P(z)P(z) とします。

    潜在変数 z をニューラルネットワークで求める

    AutoEncoder 同様、潜在変数 zz をニューラルネットワークによって求めていきます

    しかし、AutoEncoder のように、直接 zz を求めるわけではありません

    分布 P(zX)P(z|X) が正規分布に従うとし、ニューラルネットワークは、zz をサンプリングするため、正規分布の平均 μ(X)μ(X) と、分散 σ(X)σ(X) を出力します。

    (詳しくは後で解説します)

    ニューラルネットワークの出力から平均と分散を求め、N(μ(X),σ(X))N(μ(X),σ(X)) から zz をサンプリングし、その zz から入力データ XX の復元を行います

    そもそもの目的

    ネットワーク構造は、比較的簡単にお話しできましたが、損失関数の理解は難しいです。

    VAE を実装するにあたって、私たちが考えるべき損失関数とは何でしょうか。

    生成モデルの目的は、データの分布、すなわち P(X)P(X) を求めることでした。

    尤度(ゆうど)の最大化をすることで確率分布を推定する

    確率分布を求める方法として、尤度の最大化があげられます。

    確率分布 P(X)P(X) が、何らかのパラメータ θθ(普通は平均とか分散をあらわす)で表されているとします。

    最尤推定で確率分布を求める際の尤度関数は、同時確率分布に等しく、以下のようになります。

    Pθ(X)=i=1nPθ(xi)P_{\theta}(X) = \prod_{i = 1}^n P_{\theta} (x_{i})

    また、尤度関数は、普通対数を取ります。

    尤度が確率の掛け算であるため、対数を取ることで、微分の計算が和で済むからです。

    log(Pθ(X))=i=1nlog(Pθ(xi))\log (P_{\theta} (X)) = \sum_{i = 1}^n \log (P_{\theta} (x_{i})) \dots ①

    この尤度関数を最大化するようなパラメータθθを求めることで、最もデータXXの生成分布に近い分布が得られます。

    argmaxθi=1nlog(Pθ(xi))ar g \max_{\theta} \sum_{i =1}^n \log (P_{\theta}(x_{i}))

    損失関数

    対数尤度関数の最大化が目的となったので、さらに式を変形します。

    観測データは、潜在変数により生成されたと考えることもできるため、zz により P(X)P(X) を周辺化し、式変形すると以下のようになります。

    log(Pθ(X))=i=1nlog(Pθ(xi))=i=1nlog(Pθ(xi,z)dz)=i=1nlog(Pθ(zxi)Pθ(xi,z)Pθ(zxi)dz)\begin{aligned} \log(P_{\theta}(X)) &=& \sum_{i = 1}^n \log(P_{\theta}(x_{i})) \dots ① \\ &=& \sum_{i = 1}^n \log \begin{pmatrix} \displaystyle \int P_{\theta} (x_{i} , z) dz \end{pmatrix} \\ &=& \sum_{i = 1}^n \log \begin{pmatrix} \displaystyle \int \frac{P_{\theta} (z|x_{i})P_{\theta}(x_{i} , z)}{P_{\theta}(z|x_{i})} dz \end{pmatrix} \\ \end{aligned}

    ここで P(Xz)P(X|z) は、潜在変数 zz が与えられたときの復元データの分布です。

    そのため、ニューラルネットワークでいう Decoder 部は、この分布に基づいてサンプリングされたものが出力されます。


    また逆に、P(zX)P(z|X) はデータが与えられたときの潜在変数の分布であり、Decoder 同様、ネットワークの Encoder 部分は、この分布に基づいて zz が観測されます。

    AutoEncoder では、データ XX と潜在変数 zz の関係を点で求めていました。

    しかし、VAE では、 P(zX)P(z|X) を求めることでデータの低次元な分布を得ることができます

    そして、このP(zX)P(z|X)を近似することで、対数尤度の最大化を解いていきます

    近似分布の平均と分散を推定する

    近似した分布 q(zX)q(z|X) は、正規分布に従うとし、先ほど述べたように、その平均と分散パラメータを Encoder により推定していきます。


    q(zX)q(z|X) で分布を仮定したら、イェンセンの不等式から対数尤度の下限を求めます

    i=1nlog(Pθ(zxi)Pθ(xi,z)Pθ(zxi)dz)i=1nlog(qφ(zxi)Pθ(xi,z)qφ(zxi)dz)i=1nqφ(zxi)log(Pθ(xi,z)qφ(zxi))dz=i=1nL(xi,θ,φ)\begin{aligned} \displaystyle \sum_{i = 1}^n \log \begin{pmatrix} \displaystyle \int \frac{P_{\theta}(z|x_{i})P_{\theta}(x_{i},z)}{P_{\theta}(z|x_{i})}dz\end{pmatrix} &\approx& \displaystyle \sum_{i = 1}^n \log \begin{pmatrix} \displaystyle \int \frac{q_{\varphi}(z|x_{i})P_{\theta}(x_{i},z)}{q_{\varphi}(z|x_{i})}dz \end{pmatrix}\\ &≧& \displaystyle \sum_{i = 1}^n \displaystyle \int q_{\varphi}(z|x_{i}) \log \begin{pmatrix} \frac{P_{\theta}(x_{i} , z)}{q_{\varphi}(z|x_{i})}\end{pmatrix}dz\\ &=& \displaystyle \sum_{i = 1}^n L(x_{i},\theta,\varphi) \dots ② \end{aligned}

    ②の下限は、変分下限(Variational Lower Bound)と言われており、VAE がそう呼ばれることの根拠となっています

    対数尤度の最大化をするには、この下限を押し上げればよいということになります。

    対数尤度①と変分下限②の差を計算

    ここで、対数尤度①と、右辺に出てきた変分下限②の差を計算してみましょう。

    i=1nlog(Pθ(xi))L(xi,θ,φ)=i=1nlog(Pθ(xi))qφ(zxi)dzqφ(zxi)log(Pθ(xi,z)qφ(zxi))dz=i=1nqφ(zxi)(log(Pθ(xi))log(Pθ(xi,z))+log(qφ(zxi)))dz=i=1nqφ(zxi)(log(Pθ(xi))log(Pθ(zxi))log(Pθ(xi))+log(qφ(zxi)))dz=i=1nqφ(zxi)(log(qφ(zxi))log(Pθ(zxi)))dz=i=1nKL[qφ(zxi)Pθ(zxi)]\begin{aligned} &&\sum_{i = 1}^n\log (P_{\theta}(x_{i})) - L(x_{i},\theta,\varphi)\\ &=& \sum_{i = 1}^n \log (P_{\theta}(x_{i}))\displaystyle \int q_{\varphi}(z|x_{i})dz - \displaystyle \int q_{\varphi}(z|x_{i}) \log \begin{pmatrix} \frac{P_{\theta}(x_{i} , z)}{q_{\varphi}(z|x_{i})} \end{pmatrix}dz\\ &=& \sum_{i = 1}^n \displaystyle \int q_{\varphi}(z|x_{i})\begin{pmatrix} \log (P_{\theta}(x_{i})) - \log (P_{\theta}(x_{i},z)) + \log (q_{\varphi}(z|x_{i}))\end{pmatrix}dz\\ &=& \sum_{i = 1}^n \displaystyle \int q_{\varphi}(z|x_{i})\begin{pmatrix} \log (P_{\theta}(x_{i})) - \log (P_{\theta}(z|x_{i})) - \log (P_{\theta}(x_{i})) + \log(q_{\varphi}(z|x_{i}))\end{pmatrix}dz\\ &=& \sum_{i = 1}^n \displaystyle \int q_{\varphi}(z|x_{i})\begin{pmatrix} \log (q_{\varphi}(z|x_{i})) - \log (P_{\theta}(z|x_{i}))\end{pmatrix}dz\\ &=& \sum_{i = 1}^n KL[q_{\varphi}(z|x_{i})||P_{\theta}(z|x_{i})] \end{aligned}

    最後の式は、KL ダイバージェンスと呼ばれるもので、分布間の距離を表す指標です。

    KL ダイバージェンスは、2つの分布が全く同じであるならば「0」を示します。

    つまり、対数尤度は、以下の2つの項で表されることが分かりました。

    i=1nlog(Pθ(xi))=i=1nL(xi,θ,φ)+i=1nKL[qφ(zxi)Pθ(zxi)]\sum_{i = 1}^n \log (P_{\theta}(x_{i})) = \sum_{i = 1}^n L(x_{i},\theta,\varphi) + \sum_{i = 1}^n KL[q_{\varphi}(z|x_{i})||P_{\theta}(z|x_{i})]

    第2項目は、近似精度がよくなれば、おのずと「0」に近づく非負値です。

    そのため、対数尤度の最大化を行うには、変分下限の最大化を行えばよいことになります。

    argmaxθi=1nlog(Pθ(xi))=argmaxθ,φi=1nL(xi,θ,φ)ar g \max_{\theta} \sum_{i =1}^n \log (P_{\theta}(x_{i})) = ar g \max_{\theta,\varphi} \sum_{i =1}^n L(x_{i},\theta,\varphi)

    変分下限を式変形していくと、ようやくお目当ての損失関数が得られます。

    i=1nlog(Pθ(xi))qφ(zxi)(log(qφ(zxi))log(Pθ(z))log(Pθ(xiz))+log(Pθ(xi)))dz=i=1nqφ(zxi)(log(qφ(zxi))log(Pθ(z))log(Pθ(xiz)))dz=i=1nqφ(zxi)log(Pθ(xiz))dzKL[qφ(zxi)Pθ(z)]\small \begin{aligned} &\sum_{i = 1}^n& \log (P_{\theta}(x_{i}))-\displaystyle \int q_{\varphi}(z|x_{i})\begin{pmatrix} \log(q_{\varphi}(z|x_{i}))-\log(P_{\theta}(z))-\log(P_{\theta}(x_{i}|z))+\log(P_{\theta}(x_{i})) \end{pmatrix}dz\\ &=& \sum_{i = 1}^n - \displaystyle \int q_{\varphi}(z|x_{i})\begin{pmatrix} \log(q_{\varphi}(z|x_{i})) - \log(P_{\theta}(z)) -\log(P_{\theta}(x_{i}|z)) \end{pmatrix}dz\\ &=& \sum_{i = 1}^n \displaystyle \int q_{\varphi}(z|x_{i})\log(P_{\theta}(x_{i}|z))dz - KL[q_{\varphi}(z|x_{i})||P_{\theta}(z)] \end{aligned}

    この式を最大化することで、VAE は学習を進めていきます。

    実装上の損失関数

    損失関数は、このままでは積分計算が入っているため、積分が入らない形に変形します。

    1項目の KL ダイバージェンスについて

    まずは、1項目の KL ダイバージェンスについて説明します。

    正規分布同士の KL ダイバージェンスは、以下のように計算されます。

    12j=1J(1+log(σj(X))μj(X)2σj(X)2\frac{1}{2}\sum_{j = 1}^J(1 + \log(\sigma_{j}(X))-\mu_{j}(X)^2 - \sigma_{j}(X)^2

    JJ は、潜在変数の次元数です。

    積分計算をゴリゴリやれば出るので、式は割愛します。

    2項目の積分はサンプリング近似を実行

    2項目に関して、積分計算を解くのが難しいので、サンプリング近似を行います。

    qφ(zxi)log(Pθ(xiz))dz1Li=1Llog(Pθ(xizi))\displaystyle \int q_{\varphi}(z|x_{i})\log(P_{\theta}(x_{i}|z))dz \approx \frac {1}{L} \sum_{i = 1}^L \log (P_{\theta}(x_{i}|z_{i}))

    これは、確率分布 q(zX)q(z|X) に対する期待値計算であるため、有限個のサンプルの平均で近似してしまおうというものです。

    また、mnist データは、普通「0, 1」 のデータで表されているため、分布には、ベルヌーイ分布を仮定します。

    さらに、学習時のバッチサイズが大きければ、L=1L=1 で十分です。

    以上を元に、近似した値を求めていくと、以下のようになります。

    log(Pθ(xizi))=log(f(zi)xi(1f(zi)1xi))=xilog(f(zi))+(1xi)log(1f(zi))\begin{aligned} \log(P_{\theta}(x_{i}|z_{i})) &=& \log(f(z_{i})^{x_{i}}(1-f(z_{i})^{1-x_{i}}))\\ &=& x_{i}\log(f(z_{i}))+(1-x_{i})\log(1-f(z_{i})) \end{aligned}

    ただし、ffは、潜在変数を入力とする Decoder の関数、つまりは、再構成後の画素値を出力します。

    以上から、第2項の損失関数は、以下のように表されます。

    i=1nxilog(f(zi))+(1xi)log(1f(zi))\sum_{i = 1}^n x_{i} \log (f(z_{i}))+(1-x_{i}) \log(1-f(z_{i}))

    実験結果

    画像の再構成

    AutoEncoder と同様に、再構成をしてみました。

    元画像

    再構成結果

    潜在変数が二次元なので、間違えている部分もありますね。

    VAE だからと言って、AutoEncoder よりも表現力が上がるわけではなさそうです

    それに画像がぼやけています。

    VAE に限らず、ピクセル単位で誤差を計算するモデルは、すべて画像がぼけやすいです。

    潜在変数の分布の可視化

    潜在変数の散らばりも見てみましょう!

    AutoEncoder のように無秩序ではなく、分布が正規分布に引き寄せられるようになりました。

    この分布であれば、本来は、非線形であるそれぞれの数字の分布を線形分離することもできそうですね。

    連続的な数字画像の生成

    では今度は、この二次元に落とし込んだ潜在変数上で、画像を滑らかに変化させてみましょう!

    潜在変数の値を徐々に変化させ、その値を Decoder に通して画像を生成しました。

    「2」⇒「4」⇒「9」⇒「3」⇒「5」⇒「3」⇒「8」と、左上から右下にジグザグに見ていくことで、数字が徐々に変化していく様子が分かります。

    さいごに

    確率モデルを扱う VAE を用いることで、画像を連続的に変化をさせることが可能になりました

    今回は、数字のみで実験しましたが、VAE を用いれば、顔画像であろうと衣類であろうと様々なものを連続的に変化させ、モーフィングさせることができます

    面白いので、ぜひ、ご自身の手でも確かめてみてくださいね!

    こちらの記事もオススメ!

    featureImg2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...

    featureImg2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...

    ソースコード全体

    1import torch
    2import torchvision
    3
    4import torchvision.datasets as dset
    5from torch import nn
    6from torch.autograd import Variable
    7from torch.utils.data import DataLoader
    8from torchvision import transforms
    9from torchvision.utils import save_image
    10from mpl_toolkits.mplot3d import axes3d
    11from torchvision.datasets import MNIST
    12import torch.nn.functional as F
    13import os
    14import pylab
    15import math
    16import matplotlib.pyplot as plt
    17
    18#gpuデバイスがあるならgpuを使う
    19#ないならcpuで
    20device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    21
    22#カレントディレクトリにdc_imgというフォルダが作られる
    23if not os.path.exists('./dc_img'):
    24    os.mkdir('./dc_img')
    25    
    26#カレントディレクトリにdataというフォルダが作られる
    27if not os.path.exists('./data'):
    28    os.mkdir('./data')
    29    
    30
    31num_epochs = 5#エポック数 
    32batch_size = 30  #バッチサイズ
    33learning_rate = 1e-3   #学習率
    34
    35train =True#Trueなら訓練用データ、Falseなら検証用データを使う
    36pretrained =False #学習済みのモデルを使うときはここをTrueに
    37latent_dim = 2  #最終的に落とし込む次元数
    38save_img = True   #元画像と再構成画像を保存するかどうか、バッチサイズが大きいときは保存しない方がいい
    39
    40def to_img(x):
    41    x = x
    42    x = x.clamp(0, 1)
    43    
    44    return x
    45
    46#画像データを前処理する関数
    47transform = transforms.Compose([
    48    
    49    transforms.RandomResizedCrop(32, scale=(1.0, 1.0), ratio=(1., 1.)),
    50    transforms.ToTensor(), 
    51    ])
    52    
    53#このコードで自動で./data/以下にm-nistデータがダウンロードされる
    54trainset = torchvision.datasets.MNIST(root='./data/', 
    55                                        train=True,
    56                                        download=True,
    57                                        transform=transform)
    58
    59#このコードで自動で./data/以下にm-nistデータがダウンロードされる
    60testset = torchvision.datasets.MNIST(root='./data/', 
    61                                        train=False, 
    62                                        download=True, 
    63                                        transform=transform)
    64#学習時なら訓練用データを用いる
    65if train:
    66    dataloader = DataLoader(trainset,
    67                            
    68                            batch_size=batch_size, shuffle=True)
    69    
    70#テスト時なら検証用データを用いる
    71else:
    72    dataloader = DataLoader(testset, batch_size=batch_size, shuffle=True)
    73
    74#ネットワーク定義
    75class VAE(nn.Module):
    76    def __init__(self, z_dim):
    77        super(VAE, self).__init__()
    78        
    79        self.conv1 = nn.Sequential(
    80            nn.Conv2d(1, 64, 3, stride=1, padding=1),  # b, 64, 32, 32
    81            nn.BatchNorm2d(64),
    82            nn.LeakyReLU(0,True),            
    83            nn.MaxPool2d(2)  # b, 64, 16, 16
    84        )
    85        
    86        self.conv2 = nn.Sequential(
    87            nn.Conv2d(64, 128, 3, stride=1, padding=1),  # b, 128, 16, 16
    88            nn.BatchNorm2d(128),
    89            nn.LeakyReLU(0,True),            
    90            nn.MaxPool2d(2)  # b, 128, 8, 8
    91        )
    92        
    93        self.conv3 = nn.Sequential(
    94            nn.Conv2d(128, 256, 3, stride=1, padding=1),  # b, 256, 8, 8
    95            nn.BatchNorm2d(256),
    96            nn.LeakyReLU(0,True),            
    97            nn.MaxPool2d(2)  # b, 256, 4, 4
    98        )
    99        
    100        self.conv4 = nn.Sequential(
    101            nn.Conv2d(256, 512, 4, stride=1, padding=0),  # b, 512, 1, 1
    102            nn.BatchNorm2d(512),
    103            nn.LeakyReLU(0,True),
    104        )
    105            
    106
    107        self.mean = nn.Sequential(
    108            nn.Linear(512,latent_dim),# b, 512 ==> b, latent_dim
    109            )
    110        
    111        self.var = nn.Sequential(
    112            nn.Linear(512,latent_dim),# b, 512 ==> b, latent_dim
    113            )
    114        self.decoder = nn.Sequential(
    115            nn.Linear(latent_dim,512),# b, latent_dim ==> b, 512
    116            nn.BatchNorm1d(512),
    117            nn.LeakyReLU(0,True),
    118            )
    119        self.convTrans1 = nn.Sequential(
    120            nn.ConvTranspose2d(512, 256, 4, stride=2,padding = 0),  # b, 256, 4, 4
    121            nn.BatchNorm2d(256),
    122            nn.LeakyReLU(0,True),
    123            )
    124        self.convTrans2 = nn.Sequential(
    125            nn.ConvTranspose2d(256, 128, 4, stride=2,padding = 1),  # b, 128, 8, 8
    126            nn.BatchNorm2d(128),
    127            nn.LeakyReLU(0,True),
    128            )
    129        self.convTrans3 = nn.Sequential(
    130            nn.ConvTranspose2d(128, 64, 4, stride=2,padding = 1),  # b, 64, 16, 16
    131            nn.BatchNorm2d(64),
    132            nn.LeakyReLU(0,True),
    133            )
    134        self.convTrans4 = nn.Sequential(
    135            nn.ConvTranspose2d(64, 1, 4, stride=2,padding = 1),  # b, 3, 32, 32
    136            nn.BatchNorm2d(1),
    137            nn.Sigmoid()
    138            )
    139        
    140    #Encoderの出力に基づいてzをサンプリングする関数
    141    #誤差逆伝搬ができるようにreparameterization trickを用いる
    142    def _sample_z(self, mean, var):
    143      std = var.mul(0.5).exp_()
    144      eps = Variable(std.data.new(std.size()).normal_())
    145      return eps.mul(std).add_(mean)
    146    
    147    #Encoder
    148    def _encoder(self, x):
    149        x = self.conv1(x)      
    150        x = self.conv2(x)    
    151        x = self.conv3(x)     
    152        x = self.conv4(x)   
    153        x = x.view(-1,512)
    154        mean = self.mean(x)
    155        var = self.var(x)
    156        return mean,var
    157    
    158    #Decoder
    159    def _decoder(self, z):
    160      z = self.decoder(z)
    161      z = z.view(-1,512,1,1)
    162      x = self.convTrans1(z)
    163      x = self.convTrans2(x)
    164      x = self.convTrans3(x)
    165      x = self.convTrans4(x)
    166      return x
    167  
    168    def forward(self, x):
    169        # xは元画像 
    170        mean,var = self._encoder(x) #Decoderの出力はlog σ^2を想定
    171        z = self._sample_z(mean, var) #潜在変数の分布に基づいてzをサンプリング
    172        x = self._decoder(z) #サンプリングしたzに対して画像を再構成
    173        return x,mean,var,z
    174    
    175    def loss(self, x):
    176      mean, var = self._encoder(x) #Decoderの出力はlog σ^2を想定
    177      
    178      KL = -0.5 * torch.mean(torch.sum(1 + var- mean**2 - var.exp())) #KLダイバージェンス
    179      
    180      z = self._sample_z(mean, var) #潜在変数の分布に基づいてzをサンプリング
    181      y = self._decoder(z) #サンプリングしたzに対して画像を再構成
    182      delta = 1e-7 #logの中身がマイナスにならないように微小な値を与える
    183      reconstruction = torch.mean(torch.sum(x * torch.log(y+delta) + (1 - x) * torch.log(1 - y +delta))) #再構成誤差
    184      lower_bound = [-KL, reconstruction]                           
    185      return -sum(lower_bound),y,mean,var,z
    186
    187def main():
    188    #ネットワーク宣言
    189    model = VAE(latent_dim).to(device)
    190    
    191    #事前に学習したモデルがあるならそれを使う
    192    if pretrained:
    193        param = torch.load('./conv_Variational_autoencoder_{}dim.pth'.format(latent_dim))
    194        model.load_state_dict(param)
    195    
    196    
    197    #最適化法はAdamを選択
    198    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
    199    
    200                                 weight_decay=1e-5)
    201    
    202    for epoch in range(num_epochs):
    203        itr = 0
    204        print(epoch)
    205        for data in dataloader:    
    206            itr+=1
    207            img, num = data            
    208            #img --> [batch_size,1,32,32]
    209            #num --> [batch_size,1]
    210            #imgは画像本体
    211            #numは画像に対する正解ラベル
    212            #ただし、学習時にnumは使わない
    213                
    214            #imgをデバイスに乗っける
    215            img = Variable(img).to(device)    
    216            # ===================forward=====================
    217            
    218            #outputが再構成画像、latentは次元削減されたデータ
    219            if train == False:
    220                output,mu,var,latent = model(img)            
    221            
    222            #学習時であれば、ネットワークパラメータを更新
    223            if train:
    224                #lossを計算
    225                #元画像と再構成後の画像が近づくように学習
    226                loss,output,mu,var,latent = model.loss(img)        
    227                # ===================backward====================
    228                #勾配を初期化
    229                optimizer.zero_grad()                
    230                #微分値を求める
    231                loss.backward()                
    232                #パラメータの更新
    233                optimizer.step()                
    234                print('{} {}'.format(itr,loss))
    235        # ===================log========================
    236            
    237    #データをtorchからnumpyに変換
    238    z = latent.cpu().detach().numpy()
    239    num = num.cpu().detach().numpy()
    240    
    241    #次元数が3の時のプロット
    242    if latent_dim == 3:
    243        fig = plt.figure(figsize=(15, 15))
    244        ax = fig.add_subplot(111, projection='3d')
    245        ax.scatter(z[:, 0], z[:, 1], z[:, 2], marker='.', c=num, cmap=pylab.cm.jet)
    246        for angle in range(0,360,60):
    247            ax.view_init(30,angle)
    248            plt.savefig("./fig{}.png".format(angle))
    249        
    250    #次元数が2の時のプロット
    251    if latent_dim == 2:
    252        plt.figure(figsize=(15, 15))
    253        plt.scatter(z[:, 0], z[:, 1], marker='.', c=num, cmap=pylab.cm.jet)
    254        plt.colorbar()
    255        plt.grid()
    256        plt.savefig("./fig.png")
    257    
    258    #元画像と再構成後の画像を保存するなら
    259    if save_img:
    260        value = int(math.sqrt(batch_size))
    261        pic = to_img(img.cpu().data)
    262        pic = torchvision.utils.make_grid(pic,nrow = value)
    263        save_image(pic, './dc_img/real_image_{}.png'.format(epoch))  #元画像の保存
    264        
    265        pic = to_img(output.cpu().data)
    266        pic = torchvision.utils.make_grid(pic,nrow = value)
    267        save_image(pic, './dc_img/image_{}.png'.format(epoch))  #再構成後の画像の保存
    268        
    269    #もし学習時ならモデルを保存
    270    #バージョン管理は各々で
    271    if train == True:
    272        torch.save(model.state_dict(), './conv_Variational_autoencoder_{}dim.pth'.format(latent_dim))
    273        
    274if __name__ == '__main__':
    275    main()

     

    広告メディア事業部

    広告メディア事業部

    おすすめ記事