• トップ
  • ブログ一覧
  • PyTorchでデータオーグメンテーションを試そう
  • PyTorchでデータオーグメンテーションを試そう

    広告メディア事業部広告メディア事業部
    2020.04.21

    IT技術

    PyTorchでデータオーグメンテーションを試してみる

    機械学習、特にディープラーニングでは、学習データの量が重要であることは、ご承知のとおりだと思います。

    しかし、大量の学習データを用意するには、金銭的にも時間的にもコストがかかります。

    そのため、学習データをランダムに変更することによって、データを水増し(オーグメント: augment )することがよく行われます

    データオーグメンテーションは、かねてより研究されてきましたが、ディープラーニングの台頭によって、研究は勢いを増し、様々な手法が提案されています。

    今回は、特に画像分類タスクに興味を絞り、いくつかの手法を紹介します。

    最新手法の実装

    多くの手法は、 torchvision.transforms に実装されていたり、組み合わせで実現できます。

    しかし、まだ実装のない最新手法を実装し、実際にディープラーニングモデルを学習させて、結果を比較検討します。

    注意点

    今回の記事では、「glob」「joblib」「numpy」「torch」「torchvision」 がインストール済みであることを前提としております。

    下記の内容をインポートしておきます。

    1import os
    2import glob
    3import numpy as np
    4from sklearn.metrics import accuracy_score
    5from sklearn.model_selection import KFold
    6
    7import torch
    8import torch.nn as nn
    9import torch.utils
    10from torch import optim
    11import torchvision
    12from torch.optim import lr_scheduler
    13from torchvision.datasets import ImageFolder
    14from torchvision import transforms

    データセット

    データオーグメンテーションの手法を説明する前に、今回使用するデータセット, 「Animal -10」を紹介します。

    「Animal -10」は犬・猫・蝶など、10種類の動物の画像データセットです。

    【Animal -10(GPL-2)】
    https://www.kaggle.com/alessiocorrado99/animals10

    このような画像が、28000枚ほど含まれています。

    大きさも縦横比もまちまちです。

    zip ファイルを解凍すると、「raw-img」というフォルダの下に、動物名(スペイン語)のフォルダがあり、その中に jpeg 画像が入っています。

    「象」がラベルであるサンプルが1446個、「犬」がラベルであるサンプルが4863個と、バランスの悪いデータセットなので、「象」に合わせて他のクラスの画像は減らします。

    クラスごとにフォルダが分けられたデータ

    1for directory in sorted(glob.glob('raw-img/*')):
    2    files = sorted(glob.glob(directory + '/*'))
    3    print(len(files), directory)
    4    for cnt, f in enumerate(files): # 1446 raw-img/elefante
    5        if cnt >= 1446:
    6            os.remove(f)

    さて、このようにクラスごとにフォルダが分けられたデータがあるとき、torchvision.datasets.ImageFolder によって簡単に PyTorch 用のデータセットを得ることができます。

    また、ds = ImageFolder('raw-img/') により、ds というインスタンスが得ることができます。

    例えば ds[0] とすれば (0番目のPIL形式の画像, 0番目のラベル) というタプルが得られます。

    データオーグメンテーション手法

    まず、何もデータオーグメンテーションを行わない場合を見てみましょう。

    Baseline

    仮に、「224x224の画像を入力」とするモデルを考えると、シンプルに「元の画像を224x224にリサイズする」というのが、最も直感的です。

    torchvision.transforms.Resize((h, w)) によって、__call__(Input) されると、 Input を「高さ h 」、「幅 w 」に変換するインスタンスが得られます。

    以下、このベースラインにデータオーグメンテーション手法を適用することにしましょう。

    左右反転

    画像をランダムに左右反転させます。

    transforms.RandomHorizontalFlip によって実現できます。

    1transform = transforms.Compose([
    2    transforms.Resize((224, 224)),
    3    transforms.RandomHorizontalFlip(),
    4])

    と、torchvision.transforms.Compose を使うと、画像の変換の組み合わせが簡単に書けます。

    変換後の画像

    フリップはランダムに起こるので、「Baseline」と同じ画像が得られることもあります。

    Random Erasing ( Z Zhong et al., 2017, arXiv )

    「Random Erasing」は下図のように、四角形で画像をマスクするデータオーグメンテーションです。

    四角形の大きさや個数はランダムです。

    マスク後の画像

    コード

    1transform = transforms.Compose([
    2    transforms.Resize((224, 224)),
    3    transforms.ToTensor(),
    4    transforms.RandomErasing(),
    5])

    「 torchvision 」に実装されていますが、torchvision.transforms.RandomErasing の引数は torch.Tensor なので、torchvision.transforms.ToTensor によって変換しておかなければなりません。

    ややこしいですね。

    「 RandomErasing 」の発生確率やマスクの最大サイズなどは、与える引数でコントロールできます。

    【論文のリンク】
    https://arxiv.org/pdf/1708.04896.pdf

    画像の情報が失われてしまう場合

    Random Erasing によって画像の情報が失われてしまうことがあります。

    例えば、下図は Random Erasing のマスクが、象を覆い尽くしてしまった例です。

    このような状況でも、学習モデルはこの画像を象と判定するように学習しますが、これによって性能が向上するとは考えづらいです。

    情報が失われた画像

    この問題意識から、次に紹介する「GridMask」が開発されました。

    GridMask ("GridMask Data Augmentation", P. Cheng et al., 2020, arXiv)

    「GridMask」は、下図のように、小さめの正方形のマスクを等間隔に並べて、元画像をマスクします。

    GridMask には4つのパラメータがあります。

    まず、dd はマスクの間隔を表すパラメータです。

    さらに rr は、どれほど元の画像を残すかを決めるパラメータで、r=0r=0 なら画像は全てマスクされ、r=1r=1 なら全くマスクされません。

    δx,δy\delta_x,\delta_y は、オフセットです。

    rr は、予め決められています。

    dd は、ハイパーパラメータとして、与えられた範囲(実装では d_range )から、δx, deltay\delta_x, \ delta_y は [0, d-1] から、画像ごとにランダムに選ばれます。

    GridMaskを用いた画像

    GridMaskを自前で実装する

    さて、GridMask はまだ torchvision に実装されていないので、自前で実装してみましょう。

    transformは__init__ にハイパーパラメータを渡し、__call__ に実際の処理を書くだけで実装できます。

    実装内容

    実装内容は、コチラになります。

    1class GridMask():
    2    def __init__(self, p=0.6, d_range=(96, 224), r=0.6):
    3        self.p = p
    4        self.d_range = d_range
    5        self.r = r
    6        
    7    def __call__(self, sample):
    8        """
    9        sample: torch.Tensor(3, height, width)
    10        """
    11        if np.random.uniform() > self.p:
    12            return sample
    13        sample = sample.numpy()
    14        side = sample.shape[1]
    15        d = np.random.randint(*self.d_range, dtype=np.uint8)
    16        r = int(self.r * d)
    17        
    18        mask = np.ones((side+d, side+d), dtype=np.uint8)
    19        for i in range(0, side+d, d):
    20            for j in range(0, side+d, d):
    21                mask[i: i+(d-r), j: j+(d-r)] = 0
    22        delta_x, delta_y = np.random.randint(0, d, size=2)
    23        mask = mask[delta_x: delta_x+side, delta_y: delta_y+side]
    24        sample *= np.expand_dims(mask, 0)
    25        return torch.from_numpy(sample)

    Mobius Transform ("Data augmentation with Mobius transformations", Zhou et al., 2020, arXiv)

    最後に紹介するのが、メビウス変換を利用したデータオーグメンテーションです。

    下図のように、画像をグニャリと曲げたような変換を行います。

    参考画像

    メビウス変換を行うため、計算が非常に遅くなります

    そのため、予め画像を変換して保存し、ランダムに読み込むほうが速いです。

    愚直に都度変換を行った場合、他のデータオーグメンテーションに比べて、「8倍」程度学習に時間がかかりました。

    こちらのURLが活用できるでしょう。

    【GitHubのリンク】
    https://github.com/nattochaduke/MobiusTransform_PyTorch

    実験

    実際にモデルを学習させて、性能を比較してみましょう!

    1. モデルはResNet -18 ( random initialization )
    2. optimizer は Adam
    3. 学習率は0.0001で、40エポック後に0.1倍しました。
    4. 学習は60エポック行いました。
    5. 実験数値は 3-fold cross validation の平均値です。

    データオーグメンテーションのハイパーパラメーター

    データオーグメンテーションのハイパーパラメーターは、以下の通りです。

    見出し意味発生確率その他の
    ハイパーパラメータ
    備考
    Baselineベースライン1
    Flip左右反転0.5
    RERandom Erasing0.5Torchvision実装デフォルト実装によってハイパーパラメータは異なる
    GMGridMask0.6拙実装デフォルト
    MobiusMobius Transform0.6文献では0.2くらいが良い

    この他、「A+BによってAの後にBを適用する」という複数段階のデータオーグメンテーションを、「Flip+RE」「Flip+GM」「Flip+Mobius」「Flip+GM+RE」の4つで考えます。

    validation accuracy の最高値

    下グラフが「validation accuracy」の最高値です。

    すべてのデータオーグメンテーションで、Baseline よりも性能が向上しました

    「Random Erasing」が振るわなかったのが気になりますが、ちゃんとハイパーパラメータチューニングを行えば改善する…かもしれません。

    1段階のデータオーグメンテーションでは、「Mobius Transform」が明らかに他のデータオーグメンテーションよりも優れています。

    「左右反転」との組み合わせでも、「Mobius Transform」は非常に良好ですね。

    「左右反転」と、他のデータオーグメンテーションを組み合わせるだけで、すべての場合で1段階どのデータオーグメンテーションよりも良い結果が得られました。

    このように、データオーグメンテーションは複数を組み合わせるのが普通です。

    注意点

    一方、「左右反転」「GridMask」「Random Erasing」の3つを組み合わせた場合は、「左右反転」と「Random Erasing」の組み合わせよりも僅かに良くなります

    しかし、「左右反転」と「GridMask」の組み合わせと比べると、明らかに性能が下がっています。

    これは、「GridMask」と「Random Erasing」が、とても似た処理を行っていることに起因すると考えられます。

    というのも、「GridMask」と「Random Erasing」が同時に適用された場合、下図のような画像が入力されてしまう可能性が有ります。

    これでは、まともな学習が不可能になってしまうのです。

    参考画像

    したがって、データオーグメンテーションを組み合わせるときには、できるだけ似ていないデータオーグメンテーションを選ぶことが重要です。

    あるデータオーグメンテーションと、別のデータオーグメンテーションが似ていないことをOrthogonal(直交している)と、文献ではよく表現されます。

    データオーグメンテーションで覚えるべきこと

    この記事で覚えていただきたい事は「3つだけ」です!

    1.データオーグメンテーションによって、性能が飛躍的に向上する可能性がある。

    今回は、ロクにハイパーパラメータチューニングを行いませんでしたが、ベースラインに比べ最大6%精度が向上しました。

    2.torchvision の transform は__init__ にハイパーパラメータを渡し、__call__ に実際の処理を書くだけで実装できる。

    torchvision は、画像処理用のパッケージですが、音声データや時系列データも同じ方法で transform を書くことで、簡単にデータオーグメンテーションが実装できます。

    3.データオーグメンテーションを複数組み合わせる時、その手法が Orthogonal であるか気をつけることが重要。

    似たようなデータオーグメンテーションを組み合わせても、性能は向上しないどころか悪化してしまうかもしれません。

    これらの注意点に気を付ければ飛躍的に性能を向上させることも可能です。

    ぜひ一度試してみてください!

    こちらの記事もオススメ!

    featureImg2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...

    featureImg2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...

    featureImg2020.07.30Python 特集実装編※最新記事順Responder + Firestore でモダンかつサーバーレスなブログシステムを作ってみた!P...

    広告メディア事業部

    広告メディア事業部

    おすすめ記事

    GitHubActionsのランナーに触れてみた

    こやまん(エンジニア)

    こやまん(エンジニア)

    2024.03.28

    IT技術

    Azure Data FactoryでSlackへ通知をしてみる

    たかやん(エンジニア)

    たかやん(エンジニア)

    2024.03.28

    IT技術

    GCP Secret Managerを使ってみた

    たなゆー(エンジニア)

    たなゆー(エンジニア)

    2024.03.21

    IT技術

    Bitriseのパイプラインと環境変数

    加納(エンジニア)

    加納(エンジニア)

    2024.03.11

    IT技術