PyTorchでデータオーグメンテーションを試そう
IT技術
PyTorchでデータオーグメンテーションを試してみる
機械学習、特にディープラーニングでは、学習データの量が重要であることは、ご承知のとおりだと思います。
しかし、大量の学習データを用意するには、金銭的にも時間的にもコストがかかります。
そのため、学習データをランダムに変更することによって、データを水増し(オーグメント: augment )することがよく行われます。
データオーグメンテーションは、かねてより研究されてきましたが、ディープラーニングの台頭によって、研究は勢いを増し、様々な手法が提案されています。
今回は、特に画像分類タスクに興味を絞り、いくつかの手法を紹介します。
最新手法の実装
多くの手法は、 torchvision.transforms に実装されていたり、組み合わせで実現できます。
しかし、まだ実装のない最新手法を実装し、実際にディープラーニングモデルを学習させて、結果を比較検討します。
注意点
今回の記事では、「glob」「joblib」「numpy」「torch」「torchvision」 がインストール済みであることを前提としております。
下記の内容をインポートしておきます。
1import os
2import glob
3import numpy as np
4from sklearn.metrics import accuracy_score
5from sklearn.model_selection import KFold
6
7import torch
8import torch.nn as nn
9import torch.utils
10from torch import optim
11import torchvision
12from torch.optim import lr_scheduler
13from torchvision.datasets import ImageFolder
14from torchvision import transforms
データセット
データオーグメンテーションの手法を説明する前に、今回使用するデータセット, 「Animal -10」を紹介します。
「Animal -10」は犬・猫・蝶など、10種類の動物の画像データセットです。
【Animal -10(GPL-2)】
https://www.kaggle.com/alessiocorrado99/animals10
このような画像が、28000枚ほど含まれています。
大きさも縦横比もまちまちです。
zip ファイルを解凍すると、「raw-img」というフォルダの下に、動物名(スペイン語)のフォルダがあり、その中に jpeg 画像が入っています。
「象」がラベルであるサンプルが1446個、「犬」がラベルであるサンプルが4863個と、バランスの悪いデータセットなので、「象」に合わせて他のクラスの画像は減らします。
クラスごとにフォルダが分けられたデータ
1for directory in sorted(glob.glob('raw-img/*')):
2 files = sorted(glob.glob(directory + '/*'))
3 print(len(files), directory)
4 for cnt, f in enumerate(files): # 1446 raw-img/elefante
5 if cnt >= 1446:
6 os.remove(f)
さて、このようにクラスごとにフォルダが分けられたデータがあるとき、torchvision.datasets.ImageFolder によって簡単に PyTorch 用のデータセットを得ることができます。
また、ds = ImageFolder('raw-img/') により、ds というインスタンスが得ることができます。
例えば ds[0] とすれば (0番目のPIL形式の画像, 0番目のラベル) というタプルが得られます。
データオーグメンテーション手法
まず、何もデータオーグメンテーションを行わない場合を見てみましょう。
Baseline
仮に、「224x224の画像を入力」とするモデルを考えると、シンプルに「元の画像を224x224にリサイズする」というのが、最も直感的です。
torchvision.transforms.Resize((h, w)) によって、__call__(Input) されると、 Input を「高さ h 」、「幅 w 」に変換するインスタンスが得られます。
以下、このベースラインにデータオーグメンテーション手法を適用することにしましょう。
左右反転
画像をランダムに左右反転させます。
transforms.RandomHorizontalFlip によって実現できます。
1transform = transforms.Compose([
2 transforms.Resize((224, 224)),
3 transforms.RandomHorizontalFlip(),
4])
と、torchvision.transforms.Compose を使うと、画像の変換の組み合わせが簡単に書けます。
変換後の画像
フリップはランダムに起こるので、「Baseline」と同じ画像が得られることもあります。
Random Erasing ( Z Zhong et al., 2017, arXiv )
「Random Erasing」は下図のように、四角形で画像をマスクするデータオーグメンテーションです。
四角形の大きさや個数はランダムです。
マスク後の画像
コード
1transform = transforms.Compose([
2 transforms.Resize((224, 224)),
3 transforms.ToTensor(),
4 transforms.RandomErasing(),
5])
「 torchvision 」に実装されていますが、torchvision.transforms.RandomErasing の引数は torch.Tensor なので、torchvision.transforms.ToTensor によって変換しておかなければなりません。
ややこしいですね。
「 RandomErasing 」の発生確率やマスクの最大サイズなどは、与える引数でコントロールできます。
【論文のリンク】
https://arxiv.org/pdf/1708.04896.pdf
画像の情報が失われてしまう場合
Random Erasing によって画像の情報が失われてしまうことがあります。
例えば、下図は Random Erasing のマスクが、象を覆い尽くしてしまった例です。
このような状況でも、学習モデルはこの画像を象と判定するように学習しますが、これによって性能が向上するとは考えづらいです。
情報が失われた画像
この問題意識から、次に紹介する「GridMask」が開発されました。
GridMask ("GridMask Data Augmentation", P. Cheng et al., 2020, arXiv)
「GridMask」は、下図のように、小さめの正方形のマスクを等間隔に並べて、元画像をマスクします。
GridMask には4つのパラメータがあります。
まず、 はマスクの間隔を表すパラメータです。
さらに は、どれほど元の画像を残すかを決めるパラメータで、 なら画像は全てマスクされ、 なら全くマスクされません。
は、オフセットです。
は、予め決められています。
は、ハイパーパラメータとして、与えられた範囲(実装では d_range )から、 は [0, d-1] から、画像ごとにランダムに選ばれます。
GridMaskを用いた画像
GridMaskを自前で実装する
さて、GridMask はまだ torchvision に実装されていないので、自前で実装してみましょう。
transformは__init__ にハイパーパラメータを渡し、__call__ に実際の処理を書くだけで実装できます。
実装内容
実装内容は、コチラになります。
1class GridMask():
2 def __init__(self, p=0.6, d_range=(96, 224), r=0.6):
3 self.p = p
4 self.d_range = d_range
5 self.r = r
6
7 def __call__(self, sample):
8 """
9 sample: torch.Tensor(3, height, width)
10 """
11 if np.random.uniform() > self.p:
12 return sample
13 sample = sample.numpy()
14 side = sample.shape[1]
15 d = np.random.randint(*self.d_range, dtype=np.uint8)
16 r = int(self.r * d)
17
18 mask = np.ones((side+d, side+d), dtype=np.uint8)
19 for i in range(0, side+d, d):
20 for j in range(0, side+d, d):
21 mask[i: i+(d-r), j: j+(d-r)] = 0
22 delta_x, delta_y = np.random.randint(0, d, size=2)
23 mask = mask[delta_x: delta_x+side, delta_y: delta_y+side]
24 sample *= np.expand_dims(mask, 0)
25 return torch.from_numpy(sample)
Mobius Transform ("Data augmentation with Mobius transformations", Zhou et al., 2020, arXiv)
最後に紹介するのが、メビウス変換を利用したデータオーグメンテーションです。
下図のように、画像をグニャリと曲げたような変換を行います。
参考画像
メビウス変換を行うため、計算が非常に遅くなります。
そのため、予め画像を変換して保存し、ランダムに読み込むほうが速いです。
愚直に都度変換を行った場合、他のデータオーグメンテーションに比べて、「8倍」程度学習に時間がかかりました。
こちらのURLが活用できるでしょう。
【GitHubのリンク】
https://github.com/nattochaduke/MobiusTransform_PyTorch
実験
実際にモデルを学習させて、性能を比較してみましょう!
- モデルはResNet -18 ( random initialization )
- optimizer は Adam
- 学習率は0.0001で、40エポック後に0.1倍しました。
- 学習は60エポック行いました。
- 実験数値は 3-fold cross validation の平均値です。
データオーグメンテーションのハイパーパラメーター
データオーグメンテーションのハイパーパラメーターは、以下の通りです。
見出し | 意味 | 発生確率 | その他の ハイパーパラメータ | 備考 |
Baseline | ベースライン | 1 | ||
Flip | 左右反転 | 0.5 | ||
RE | Random Erasing | 0.5 | Torchvision実装デフォルト | 実装によってハイパーパラメータは異なる |
GM | GridMask | 0.6 | 拙実装デフォルト | |
Mobius | Mobius Transform | 0.6 | 文献では0.2くらいが良い |
この他、「A+BによってAの後にBを適用する」という複数段階のデータオーグメンテーションを、「Flip+RE」「Flip+GM」「Flip+Mobius」「Flip+GM+RE」の4つで考えます。
validation accuracy の最高値
下グラフが「validation accuracy」の最高値です。
すべてのデータオーグメンテーションで、Baseline よりも性能が向上しました。
「Random Erasing」が振るわなかったのが気になりますが、ちゃんとハイパーパラメータチューニングを行えば改善する…かもしれません。
1段階のデータオーグメンテーションでは、「Mobius Transform」が明らかに他のデータオーグメンテーションよりも優れています。
「左右反転」との組み合わせでも、「Mobius Transform」は非常に良好ですね。
「左右反転」と、他のデータオーグメンテーションを組み合わせるだけで、すべての場合で1段階どのデータオーグメンテーションよりも良い結果が得られました。
このように、データオーグメンテーションは複数を組み合わせるのが普通です。
注意点
一方、「左右反転」「GridMask」「Random Erasing」の3つを組み合わせた場合は、「左右反転」と「Random Erasing」の組み合わせよりも僅かに良くなります。
しかし、「左右反転」と「GridMask」の組み合わせと比べると、明らかに性能が下がっています。
これは、「GridMask」と「Random Erasing」が、とても似た処理を行っていることに起因すると考えられます。
というのも、「GridMask」と「Random Erasing」が同時に適用された場合、下図のような画像が入力されてしまう可能性が有ります。
これでは、まともな学習が不可能になってしまうのです。
参考画像
したがって、データオーグメンテーションを組み合わせるときには、できるだけ似ていないデータオーグメンテーションを選ぶことが重要です。
あるデータオーグメンテーションと、別のデータオーグメンテーションが似ていないことをOrthogonal(直交している)と、文献ではよく表現されます。
データオーグメンテーションで覚えるべきこと
この記事で覚えていただきたい事は「3つだけ」です!
1.データオーグメンテーションによって、性能が飛躍的に向上する可能性がある。
今回は、ロクにハイパーパラメータチューニングを行いませんでしたが、ベースラインに比べ最大6%精度が向上しました。
2.torchvision の transform はにハイパーパラメータを渡し、に実際の処理を書くだけで実装できる。
torchvision は、画像処理用のパッケージですが、音声データや時系列データも同じ方法で transform を書くことで、簡単にデータオーグメンテーションが実装できます。
3.データオーグメンテーションを複数組み合わせる時、その手法が Orthogonal であるか気をつけることが重要。
似たようなデータオーグメンテーションを組み合わせても、性能は向上しないどころか悪化してしまうかもしれません。
これらの注意点に気を付ければ飛躍的に性能を向上させることも可能です。
ぜひ一度試してみてください!
こちらの記事もオススメ!
2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...
2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...
2020.07.30Python 特集実装編※最新記事順Responder + Firestore でモダンかつサーバーレスなブログシステムを作ってみた!P...
ライトコードでは、エンジニアを積極採用中!
ライトコードでは、エンジニアを積極採用しています!社長と一杯しながらお話しする機会もご用意しております。そのほかカジュアル面談等もございますので、くわしくは採用情報をご確認ください。
採用情報へ
「好きを仕事にするエンジニア集団」の(株)ライトコードです! ライトコードは、福岡、東京、大阪の3拠点で事業展開するIT企業です。 現在は、国内を代表する大手IT企業を取引先にもち、ITシステムの受託事業が中心。 いずれも直取引で、月間PV数1億を超えるWebサービスのシステム開発・運営、インフラの構築・運用に携わっています。 システム開発依頼・お見積もり大歓迎! また、現在「WEBエンジニア」「モバイルエンジニア」「営業」「WEBデザイナー」「WEBディレクター」を積極採用中です! インターンや新卒採用も行っております。 以下よりご応募をお待ちしております! https://rightcode.co.jp/recruit