• トップ
  • ブログ一覧
  • 【機械学習】CNNで低解像度な画像を高解像度に変換してみる
  • 【機械学習】CNNで低解像度な画像を高解像度に変換してみる

    広告メディア事業部広告メディア事業部
    2020.07.06

    IT技術

    CNNで低解像度な画像を高解像度に変換してみたい!

    前回は、「畳み込みニューラルネットワーク(Convolutional Neural Network:CNN)」を用いて、顔画像のカラー復元を行いました。

    小さな画像サイズで実験を行いましたが、実際の場面では限られたリソース(計算資源)の中で、大きな画像サイズを取り扱う場面も多くあります。

    そこで今回は、CNNを用いて、前回行ったカラー化と画像の高解像度化を組み合わせた実験をしてみたいと思います。

    前回の記事はこちら

    featureImg2020.05.21白黒画像を 畳み込みニューラルネットワーク(CNN)を用いてカラー化する畳み込みニューラルネットワーク(CNN)で白黒画像をカラー化しようカメラができた当初、撮影された画像は、今のようなカラ...

    低解像度画像と高解像度画像

    一般的にボケている画像のことを「低解像度な画像」と呼びます。

    それに対し「高解像度な画像」とは、ボケの少ない画像のことをいいます。

    そもそもボケというのは、画像のピクセル数が減少することで起こります。

    低解像度化のイメージ

    例えば、左下の画像のように適当なブロック分けを行い、その一ブロック内の全画素値の平均値を新たな画素値とすることを、画像全体に行うことでピクセル数を減少させることができます。

    低解像度化は簡単

    右の画像は、元画像に対してピクセル数が減少しているため「低解像度な画像」と言えます。

    このように、「高解像度の画像」から「低解像度な画像」を作ることは比較的容易にできます。

    低解像度画像から高解像度画像への変換

    では逆のパターンで、「低解像度な画像」から「高解像度な画像」を生成することを考えてみましょう。

    先ほどのブロックに分けた例で言うと、「低解像度な画像」の高解像度化は、一つのピクセルから複数のピクセルの値を決定するということが必要になります。

    高解像度化のイメージ

    「高解像度化」の作業は、下図のように解が一意に決まらず、様々なパターンが存在することが容易に想像できます。

    高解像度化は難しい

    言うなれば、先ほどの「高解像度な画像」もその解の一つであるものの、正解かどうかは分かりません。

    そのため、「高解像度化」を人間が手探りで行うのは途方もなく難しい作業なのです。

    そこで、あり得る様々なパターンから最も自然な画像になるパターンをニューラルネットワークに出力してもらうと言うのが、CNN で「高解像度化」をするモチベーションです。

    高解像度化を用いるメリット

    では、「高解像度化」をすることでどんなメリットがあるのでしょう。

    機械学習のボトルネック

    これは前回も話したとおり、機械学習の研究が進んだ現在、ボトルネックとなるのは計算資源です。

    「高解像度な画像」を使って機械学習を行うことができればそれでいいのですが、そうはいきません。

    ただでさえ制限があるメモリに、ニューラルネットワークのパラメータを多くメモリに保存しなければなりません。

    そのため、画像サイズでその足りないメモリの帳尻合わせをする必要があるのです。

    ボトルネック解消のために

    そのため病理画像など、「高解像度な画像」に対しては、かなり小さい画像でしか学習ができません

    そんな中で低解像度で学習したものを「高解像度化」できれば、メモリの節約にもなり、学習結果もより向上するのではないかと言うのが私の考えるメリットです。

    高解像度化のための学習モデル

    では実際に、「学習モデル」について話していきます。

    U-net のような構造にしました。

    また、白黒画像のカラー化のときと同じように、深い層で浅い層の情報を渡し、細部の画像情報と全体的な情報を両方とも加味できるようにしました。

    私の中でこれが一番損失関数が小さくなったので、このモデルを選びました。

    カラー化と高解像度化の併用実験

    検証用データに対する精度

    「低解像度画像」に比べて生成画像はボケが小さい、つまりは「高解像度」になっていることがわかります。

    低解像度画像

    生成画像

    私自身、かなり精度が良くてびっくりしています。

    SRGAN を使わなくてもある程度は精度が出るようです。

    白黒画像のカラー化との併用実験

    次に、白黒画像の「カラー化」と画像の「高解像度化」を併用して実験してみました。

    併用実験の方法

    併用実験の方法としては、以下のように行いました。

    1. 64×64の白黒画像をカラー化する
    2. 変換した64×64のカラー画像を128×128に高解像度化

    こうすることで、実際の実験環境ではメモリの影響で学習できなかった画像サイズでも、カラー画像の変換ができるようになります。

    実験!

    実際に、以下の変換画像の比較を行ってみましょう!

    1. 「64×64のサイズ」で学習したモデルで「128×128の画像」をカラー化した場合
    2. 「64×64の画像」をカラー化した後で「128×128に高解像度化」した場合

    カラー化のみ

    カラー化+高解像度化

    実験結果

    見てわかるように、「カラー化のみ」の結果では、色が正しく塗れていない部分があり、ムラがあります。

    しかし、「カラー化+高解像度化」の結果は、前者と比較してムラが少なく、「64×64のサイズ」で学習したモデルが「128×128の画像サイズ」にも適用できていることがわかります。

    このように、高解像度化は他の実験に併用することも有効であると言えます。

    様々なタスクへの応用

    それは白黒画像のカラー化だけでなく、画像のセグメンテーションなど様々なタスクに応用することができます。

    学習時の画像サイズに限界を感じている機械学習エンジニアの方は、ぜひ参考にしてみてくださ!

    さいごに ~次回予告~

    さて、画像の「カラー化」と「高解像度化」と行ってきました。

    しかし、白黒画像を「カラー化」したモデルでは、よく見ると肌色を顔の近くに塗っているだけで、背景も肌色に塗られていることがわかります。

    正直言って、自然な画像であるとは言えません

    さらに、「カラー化」との併用実験では、画像の「高解像度化」も、実際の画像の解像度に比べてボケていることがわかります。

    これは、「損失関数による問題」です。

    ピクセル単位での二乗誤差をとることで、画像が自然なパターンにならずに全体的にボヤッとした見た目になるからです。

    そこで次回からは、ピクセル単位での誤差に加えて、「GAN」による「Adversarial Loss」を加えることで、画像のボケ除去や色の多様性の追加を行っていきたいと思います!

    こちらの記事もオススメ!

    featureImg2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...

    featureImg2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...

    ソースコード

    1import torch
    2import torchvision
    3 
    4import torchvision.datasets as dset
    5from torch import nn
    6from torch.autograd import Variable
    7from torch.utils.data import DataLoader
    8from torchvision import transforms
    9from torchvision.utils import save_image
    10from mpl_toolkits.mplot3d import axes3d
    11from torchvision.datasets import MNIST
    12import os
    13import math
    14import pylab
    15import matplotlib.pyplot as plt
    16 
    17   
    18 
    19num_epochs = 1 #エポック数
    20batch_size = 100 #バッチサイズ
    21learning_rate = 1e-3 #学習率
    22 
    23train = True#学習を行うかどうかのフラグ
    24pretrained =False#事前に学習したモデルがあるならそれを使う
    25save_img = True #ネットワークによる生成画像を保存するかどうのフラグ
    26 
    27def to_img(x):
    28    x = 0.5 * (x + 1)
    29    x = x.clamp(0, 1)
    30    x = x.view(x.size(0), 3, x.shape[2], x.shape[3])
    31    return x
    32def to_img_mono(x):
    33    x = 0.5 * (x + 1)
    34    x = x.clamp(0, 1)
    35    x = x.view(x.size(0), 3,x.shape[2], x.shape[3])
    36    return x
    37 
    38#データセットを調整する関数
    39transform = transforms.Compose(
    40    [transforms.ToTensor(),
    41     transforms.Normalize((0.5, ), (0.5, ))])
    42   
    43#訓練用データセット
    44dataset = dset.ImageFolder(root='./drive/My Drive/face/',
    45                              transform=transforms.Compose([
    46                              transforms.RandomResizedCrop(64, scale=(1.0, 1.0), ratio=(1., 1.)),
    47                              transforms.RandomHorizontalFlip(),
    48                              transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05),
    49                              transforms.ToTensor(),
    50                              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    51                          ])) 
    52 
    53#データセットをdataoaderで読み込み
    54dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    55 
    56#低解像度画像を高解像度化するニューラルネットワーク
    57class SizeDecoder(nn.Module):
    58    def __init__(self):
    59        super(SizeDecoder, self).__init__()
    60        nch_g = 64
    61        self.layer1 = nn.ModuleDict({
    62            'layer0': nn.Sequential(
    63                nn.Conv2d(3, nch_g , 3, 2, 1),    
    64                nn.BatchNorm2d(nch_g),                     
    65                nn.ReLU()                                   
    66            ),  # (1, 64, 64) -> (64, 32, 32)
    67           
    68           
    69        })
    70           
    71        self.layer2 = nn.ModuleDict({
    72           
    73            'layer0': nn.Sequential(
    74                nn.Conv2d(nch_g , nch_g*2 , 3, 2, 1),
    75                nn.BatchNorm2d(nch_g*2),
    76                nn.ReLU()  
    77            ),  # (64, 32, 32) -> (128, 16, 16)
    78          })
    79        self.layer3 = nn.ModuleDict({
    80           
    81            'layer0': nn.Sequential(
    82                nn.Conv2d(nch_g*2 , nch_g*4 , 3, 2, 1),
    83                nn.BatchNorm2d(nch_g*4),
    84                nn.ReLU()  
    85            ),  # (128, 16, 16) -> (256, 8, 8)
    86           
    87          })
    88        self.layer4= nn.ModuleDict({
    89            'layer0': nn.Sequential(
    90                nn.Conv2d(nch_g*4 , nch_g*8 , 3, 2, 1),
    91                nn.BatchNorm2d(nch_g*8),
    92                nn.ReLU()  
    93            ),  # (256, 8, 8) -> (512, 4, 4)
    94          
    95          })
    96           
    97      
    98        self.layer7 = nn.ModuleDict({
    99            'layer0': nn.Sequential(
    100                nn.ConvTranspose2d(nch_g*8 , nch_g*4 , 4, 2, 1),
    101                nn.BatchNorm2d(nch_g*4),
    102                nn.ReLU()  
    103            ),  # (512, 4, 4) -> (256, 8, 8)
    104          })
    105        self.layer8 = nn.ModuleDict({
    106            'layer0': nn.Sequential(
    107                nn.ConvTranspose2d(nch_g*4 , nch_g*2 , 4, 2, 1),
    108                nn.BatchNorm2d(nch_g*2),
    109                nn.ReLU()  
    110            ),  # (256, 8,8) -> (128, 16, 16)
    111          })
    112        self.layer9= nn.ModuleDict({
    113            'layer0': nn.Sequential(
    114                nn.ConvTranspose2d(nch_g*2 , nch_g , 4, 2, 1),
    115                nn.BatchNorm2d(nch_g),
    116                nn.ReLU()  
    117            ),  # (128, 16, 16) -> (64, 32, 32)
    118          })
    119        self.layer10 = nn.ModuleDict({
    120            'layer0': nn.Sequential(
    121                nn.ConvTranspose2d(nch_g,int(nch_g/2)  , 4, 2, 1),
    122                nn.BatchNorm2d(int(nch_g/2)),
    123                nn.Tanh()  
    124            ),  # (64, 32, 32) -> (32, 64, 64)
    125          })
    126        self.layer11 = nn.ModuleDict({
    127            'layer0': nn.Sequential(
    128                nn.ConvTranspose2d(int(nch_g/2) , 3 , 4, 2, 1),
    129                nn.BatchNorm2d(3),
    130                nn.Tanh()  
    131            ),  # (32, 64, 64) -> (3, 128, 128)
    132          })
    133 
    134    def forward(self, z):
    135       
    136        for layer in self.layer1.values(): 
    137            z = layer(z)
    138        z1 =z
    139        for layer in self.layer2.values(): 
    140            z = layer(z)
    141        z2 =z
    142        for layer in self.layer3.values(): 
    143            z = layer(z)
    144        z3=z
    145        for layer in self.layer4.values(): 
    146            z = layer(z)
    147       
    148        for layer in self.layer7.values(): 
    149            z = layer(z)
    150        z =z+z3
    151        
    152        for layer in self.layer8.values(): 
    153            z = layer(z)
    154        z =z+z2
    155        for layer in self.layer9.values(): 
    156            z = layer(z)
    157        z =z+z1
    158 
    159        for layer in self.layer10.values(): 
    160            z = layer(z)
    161 
    162        for layer in self.layer11.values(): 
    163            z = layer(z)
    164        return z
    165 
    166 
    167def main():
    168    #もしGPUがあるならGPUを使用してないならCPUを使用
    169    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    170   
    171    #ネットワークを呼び出し
    172    model = SizeDecoder().to(device)
    173    
    174 
    175    #事前に学習しているモデルがあるならそれを読み込む
    176    if pretrained:
    177        param = torch.load('./Size_Decoder.pth')
    178        model.load_state_dict(param)
    179   
    180    #誤差関数には二乗誤差を使用
    181    criterion = nn.MSELoss()
    182   
    183    #更新式はAdamを適用
    184    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
    185   
    186                                 weight_decay=1e-5)
    187   
    188   
    189    loss_train_list = []
    190    loss_test_list= []
    191    for epoch in range(num_epochs):
    192       
    193        print(epoch)
    194       
    195        for data in dataloader:
    196           
    197            img, num = data
    198            #img --> [batch_size,1,32,32]
    199            #imgは元画像
    200            #imgをGPUに載せる
    201            img = Variable(img).to(device)
    202   
    203            # ===================forward=====================
    204           
    205            #_imgは高解像度画像を低解像度に変換した画像
    206            
    207            _img = (img[:,:,::2, ::2] + img[:,:,1::2, ::2] + img[:,:,::2, 1::2] + img[:,:,1::2, 1::2])/4
    208            
    209            
    210            _img =_img.to(device)
    211            
    212            #ネットワークの出力結果
    213            output = model(_img)
    214            print(output.shape)
    215            #もし学習するなら
    216            if train:
    217                #ネットワークの出力と高解像度画像との誤差を損失として学習
    218               
    219                
    220                # ===================backward====================
    221                loss = criterion(output, img)
    222                print(loss)
    223                #勾配を初期化
    224                optimizer.zero_grad()
    225               
    226                #微分値を計算
    227                loss.backward()
    228               
    229                #パラメータを更新
    230                optimizer.step()
    231               
    232                
    233            else:#学習しないなら
    234                break
    235        # ===================log========================
    236          
    237        if train == True:
    238                #モデルを保存
    239                torch.save(model.state_dict(), './Size_Decoder.pth')
    240   
    241    
    242    #もし生成画像を保存するなら
    243    if save_img:
    244        value = int(math.sqrt(batch_size))
    245        
    246        pic = to_img(img.cpu().data)
    247        pic = torchvision.utils.make_grid(pic,nrow = value)
    248        save_image(pic, './real_image_{}.png'.format(epoch))  #元画像を保存
    249       
    250        pic = to_img_mono(output.cpu().data)
    251        pic = torchvision.utils.make_grid(pic,nrow = value)
    252        save_image(pic, './image_{}.png'.format(epoch))  #生成画像   
    253if __name__ == '__main__':
    254    main()

     

    広告メディア事業部

    広告メディア事業部

    おすすめ記事