1. HOME
  2. ブログ
  3. IT技術
  4. PyTorchでデータオーグメンテーションを試そう

PyTorchでデータオーグメンテーションを試そう

PyTorchでデータオーグメンテーションを試してみる

機械学習、特にディープラーニングでは、学習データの量が重要であることは、ご承知のとおりだと思います。

しかし、大量の学習データを用意するには、金銭的にも時間的にもコストがかかります。

そのため、学習データをランダムに変更することによって、データを水増し(オーグメント: augment )することがよく行われます

データオーグメンテーションは、かねてより研究されてきましたが、ディープラーニングの台頭によって、研究は勢いを増し、様々な手法が提案されています。

今回は、特に画像分類タスクに興味を絞り、いくつかの手法を紹介します。

最新手法の実装

多くの手法は、  torchvision.transforms に実装されていたり、組み合わせで実現できます。

しかし、まだ実装のない最新手法を実装し、実際にディープラーニングモデルを学習させて、結果を比較検討します。

注意点

今回の記事では、「glob」「joblib」「numpy」「torch」「torchvision」 がインストール済みであることを前提としております。

下記の内容をインポートしておきます。

こちらの記事もオススメ!

データセット

データオーグメンテーションの手法を説明する前に、今回使用するデータセット, 「Animal -10」を紹介します。

「Animal -10」は犬・猫・蝶など、10種類の動物の画像データセットです。

【Animal -10(GPL-2)】
https://www.kaggle.com/alessiocorrado99/animals10

このような画像が、28000枚ほど含まれています。

大きさも縦横比もまちまちです。

zip ファイルを解凍すると、「raw-img」というフォルダの下に、動物名(スペイン語)のフォルダがあり、その中に jpeg 画像が入っています。

「象」がラベルであるサンプルが1446個、「犬」がラベルであるサンプルが4863個と、バランスの悪いデータセットなので、「象」に合わせて他のクラスの画像は減らします。

クラスごとにフォルダが分けられたデータ

さて、このようにクラスごとにフォルダが分けられたデータがあるとき、 torchvision.datasets.ImageFolder によって簡単に PyTorch 用のデータセットを得ることができます。

また、 ds = ImageFolder('raw-img/') により、 ds というインスタンスが得ることができます。

例えば ds[0] とすれば (0番目のPIL形式の画像, 0番目のラベル) というタプルが得られます。

データオーグメンテーション手法

まず、何もデータオーグメンテーションを行わない場合を見てみましょう。

Baseline

仮に、「224×224の画像を入力」とするモデルを考えると、シンプルに「元の画像を224×224にリサイズする」というのが、最も直感的です。

torchvision.transforms.Resize((h, w)) によって、 __call__(Input) されると、  Input を「高さ h 」、「幅 w 」に変換するインスタンスが得られます。

以下、このベースラインにデータオーグメンテーション手法を適用することにしましょう。

左右反転

画像をランダムに左右反転させます。

transforms.RandomHorizontalFlip によって実現できます。

と、 torchvision.transforms.Compose を使うと、画像の変換の組み合わせが簡単に書けます。

変換後の画像

フリップはランダムに起こるので、「Baseline」と同じ画像が得られることもあります。

Random Erasing ( Z Zhong et al., 2017, arXiv )

「Random Erasing」は下図のように、四角形で画像をマスクするデータオーグメンテーションです。

四角形の大きさや個数はランダムです。

マスク後の画像

コード

「 torchvision 」に実装されていますが、 torchvision.transforms.RandomErasing の引数は torch.Tensor なので、 torchvision.transforms.ToTensor によって変換しておかなければなりません。

ややこしいですね。

「 RandomErasing 」の発生確率やマスクの最大サイズなどは、与える引数でコントロールできます。

【論文のリンク】
https://arxiv.org/pdf/1708.04896.pdf

画像の情報が失われてしまう場合

Random Erasing によって画像の情報が失われてしまうことがあります。

例えば、下図は Random Erasing のマスクが、象を覆い尽くしてしまった例です。

このような状況でも、学習モデルはこの画像を象と判定するように学習しますが、これによって性能が向上するとは考えづらいです。

情報が失われた画像

この問題意識から、次に紹介する「GridMask」が開発されました。

GridMask (“GridMask Data Augmentation”, P. Cheng et al., 2020, arXiv)

「GridMask」は、下図のように、小さめの正方形のマスクを等間隔に並べて、元画像をマスクします。

GridMask には4つのパラメータがあります。

まず、\(d\) はマスクの間隔を表すパラメータです。

さらに \(r\) は、どれほど元の画像を残すかを決めるパラメータで、\(r=0\) なら画像は全てマスクされ、\(r=1\) なら全くマスクされません。

\(\delta_x,\delta_y\) は、オフセットです。

\(r\) は、予め決められています。

\(d\) は、ハイパーパラメータとして、与えられた範囲(実装では d_range )から、\(\delta_x, \ delta_y\) は [0, d-1] から、画像ごとにランダムに選ばれます。

GridMaskを用いた画像

GridMaskを自前で実装する

さて、GridMask はまだ torchvision に実装されていないので、自前で実装してみましょう。

transformは __init__ にハイパーパラメータを渡し、 __call__ に実際の処理を書くだけで実装できます。

実装内容

実装内容は、コチラになります。

Mobius Transform (“Data augmentation with Mobius transformations”, Zhou et al., 2020, arXiv)

最後に紹介するのが、メビウス変換を利用したデータオーグメンテーションです。

下図のように、画像をグニャリと曲げたような変換を行います。

参考画像

メビウス変換を行うため、計算が非常に遅くなります

そのため、予め画像を変換して保存し、ランダムに読み込むほうが速いです。

愚直に都度変換を行った場合、他のデータオーグメンテーションに比べて、「8倍」程度学習に時間がかかりました。

こちらのURLが活用できるでしょう。

【GitHubのリンク】
https://github.com/nattochaduke/MobiusTransform_PyTorch

実験

実際にモデルを学習させて、性能を比較してみましょう!

  1. モデルはResNet -18 ( random initialization )
  2. optimizer は Adam
  3. 学習率は0.0001で、40エポック後に0.1倍しました。
  4. 学習は60エポック行いました。
  5. 実験数値は 3-fold cross validation の平均値です。

データオーグメンテーションのハイパーパラメーター

データオーグメンテーションのハイパーパラメーターは、以下の通りです。

見出し意味発生確率その他の
ハイパーパラメータ
備考
Baselineベースライン1
Flip左右反転0.5
RERandom Erasing0.5Torchvision実装デフォルト実装によってハイパーパラメータは異なる
GMGridMask0.6拙実装デフォルト
MobiusMobius Transform0.6文献では0.2くらいが良い

この他、「A+BによってAの後にBを適用する」という複数段階のデータオーグメンテーションを、「Flip+RE」「Flip+GM」「Flip+Mobius」「Flip+GM+RE」の4つで考えます。

validation accuracy の最高値

下グラフが「validation accuracy」の最高値です。

すべてのデータオーグメンテーションで、Baseline よりも性能が向上しました

「Random Erasing」が振るわなかったのが気になりますが、ちゃんとハイパーパラメータチューニングを行えば改善する…かもしれません。

1段階のデータオーグメンテーションでは、「Mobius Transform」が明らかに他のデータオーグメンテーションよりも優れています。

「左右反転」との組み合わせでも、「Mobius Transform」は非常に良好ですね。

「左右反転」と、他のデータオーグメンテーションを組み合わせるだけで、すべての場合で1段階どのデータオーグメンテーションよりも良い結果が得られました。

このように、データオーグメンテーションは複数を組み合わせるのが普通です。

注意点

一方、「左右反転」「GridMask」「Random Erasing」の3つを組み合わせた場合は、「左右反転」と「Random Erasing」の組み合わせよりも僅かに良くなります

しかし、「左右反転」と「GridMask」の組み合わせと比べると、明らかに性能が下がっています。

これは、「GridMask」と「Random Erasing」が、とても似た処理を行っていることに起因すると考えられます。

というのも、「GridMask」と「Random Erasing」が同時に適用された場合、下図のような画像が入力されてしまう可能性が有ります。

これでは、まともな学習が不可能になってしまうのです。

参考画像

したがって、データオーグメンテーションを組み合わせるときには、できるだけ似ていないデータオーグメンテーションを選ぶことが重要です。

あるデータオーグメンテーションと、別のデータオーグメンテーションが似ていないことをOrthogonal(直交している)と、文献ではよく表現されます。

データオーグメンテーションで覚えるべきこと

この記事で覚えていただきたい事は「3つだけ」です!

1.データオーグメンテーションによって、性能が飛躍的に向上する可能性がある。

今回は、ロクにハイパーパラメータチューニングを行いませんでしたが、ベースラインに比べ最大6%精度が向上しました。

2.torchvision の transform は __init__ にハイパーパラメータを渡し、 __call__ に実際の処理を書くだけで実装できる。

torchvision は、画像処理用のパッケージですが、音声データや時系列データも同じ方法で transform を書くことで、簡単にデータオーグメンテーションが実装できます。

3.データオーグメンテーションを複数組み合わせる時、その手法が Orthogonal であるか気をつけることが重要。

似たようなデータオーグメンテーションを組み合わせても、性能は向上しないどころか悪化してしまうかもしれません。

 

これらの注意点に気を付ければ飛躍的に性能を向上させることも可能です。

ぜひ一度試してみてください!

PyTorchでのシステム開発依頼・お見積もりはこちらまでお願いします。
また、機械学習系エンジニアを積極採用中です!詳しくはこちらをご覧ください。

こちらの記事もオススメ!

ライトコードよりお知らせ

にゃんこ師匠にゃんこ師匠
システム開発のご相談やご依頼はこちら
ミツオカミツオカ
ライトコードの採用募集はこちら
にゃんこ師匠にゃんこ師匠
社長と一杯飲みながらお話してみたい方はこちら
ミツオカミツオカ
フリーランスエンジニア様の募集はこちら
にゃんこ師匠にゃんこ師匠
その他、お問い合わせはこちら
ミツオカミツオカ
   
お気軽にお問い合わせください!せっかくなので、別の記事もぜひ読んでいって下さいね!

一緒に働いてくれる仲間を募集しております!

ライトコードでは、仲間を募集しております!

当社のモットーは「好きなことを仕事にするエンジニア集団」「エンジニアによるエンジニアのための会社」。エンジニアであるあなたの「やってみたいこと」を全力で応援する会社です。

また、ライトコードは現在、急成長中!だからこそ、あなたにお任せしたいやりがいのあるお仕事は沢山あります。「コアメンバー」として活躍してくれる、あなたからのご応募をお待ちしております!

なお、ご応募の前に、「話しだけ聞いてみたい」「社内の雰囲気を知りたい」という方はこちらをご覧ください。

ライトコードでは一緒に働いていただける方を募集しております!

採用情報はこちら

関連記事