【準備編】肺のCT画像からCOVID19かどうかの判断は可能か?【機械学習】
IT技術
準備編~肺のCT画像からCOVID19か予想できるのか?
今回は、前回に引き続き、肺の CT 画像から、COVID 19か否かを予測する深層学習モデルを、「PyTorch」で実装してみたいと思います。
PyTorch を採用した理由は、「Kaggle」での実装例も多く、公式ドキュメントも充実しているためです。
ちなみに、こちらの記事は「プログラミングで分類に挑戦する」ということが目的で、COVID 19を確実に分類できるわけではありませんので、予めご了承をお願い致します。
【下処理編】をお読みでない方は、まずは以下をお読みください。
画像拡張をする
「画像拡張」は、データ数の少なさを緩和するために、有効な手段です。
今回は、ライブラリとして、「albumentations」を用いました。
「albumentations」には、様々な拡張が搭載されています。
【 GitHub:albumentations 】
https://github.com/albumentations-team/albumentations
また、自作の画像拡張関数を作成することも可能です。
使用した augumentation は、以下の通りです。
1def get_transforms1(*, data):
2
3 #train,valid以外だったら処理を止める。
4
5 if data == 'train':
6 return Compose([
7 HorizontalFlip(p=0.5),
8 VerticalFlip(p=0.5),
9 GaussNoise(p=0.5),
10 RandomRotate90(p=0.5),
11 RandomGamma(p=0.5),
12 RandomAugMix(severity=3, width=3, alpha=1., p=0.5),
13 Normalize(
14 mean=[0.485, 0.456, 0.406],
15 std=[0.229, 0.224, 0.225],
16 )
17 ])
18 elif data == 'valid':
19 return Compose([
20 Normalize(
21 mean=[0.485, 0.456, 0.406],
22 std=[0.229, 0.224, 0.225],
23 )
24 ])
25
26def to_tensor(*args):
27
28 return Compose([
29 ToTensor()
30 ])
Augmix
画像拡張の手法では、コーネル大学が公開している、「Augmix」というものがあります。
【コーネル大学:Augmix 】
https://arxiv.org/abs/1912.02781
これは、「train」と「test」でデータに違いがある場合に生じる、「堅牢性」と「不確実性」の問題を大幅改善できる手法です。
「Dataset」や「model」も準備する
Datasetは典型!
PyTorch では、「channel」「height」「width」の順に、Tensor の軸を並べ替える必要があります。
以下が、PyTorch における Dataset です。
1class TrainDataset(Dataset):
2 def __init__(self, df, labels, transform1=None, transform2=None):
3 self.df = df
4 self.labels = labels
5 self.transform = transform1
6 self.transform_ = transform2
7
8 def __len__(self):
9 return len(self.df)
10
11 def __getitem__(self, idx):
12 file_name = self.df['filename'].values[idx]
13 file_path = '/kaggle/input/computed-tomography-of-lungs-datase-for-covid19/{}'.format(file_name)
14 image = cv2.imread(file_path)
15 image = cv2.resize(image,(SIZE,SIZE))
16 if self.transform:
17 image = self.transform(image=image)['image']
18 if self.transform_:
19 image = self.transform_(image=image)['image']
20
21 label = torch.tensor(self.labels[idx]).float()
22 return image, label
画像は、OpenCV により「ndarray」で扱っているので、NumPy の「swapaxis」でも対応できるかと思います。
それ以外は、Keras と似たような、Dataset を書けば良いわけです。
モデルはお好みで!
深層学習のモデルの技術革新は凄まじく、2019年に Google から発表された「EfficientNet」は、Kaggle 上で大人気です。
その EfficientNet よりも、ImageNet での性能が良いとされているのが、「EfficientNet-Noisy-Student」。
これは、「Self-learning」に、画像拡張のようなノイズを加えたものです。
同じパラメーターで見た精度は、従来のものよりも向上していますね。
というわけで今回は、「EfficientNet-Noisy-Student」の B2 を使用しました。
ImageNet では、出力が1000次元なので、出力が1次元の FC 層に付け替えます。
重みを読み込んでからでないと、エラーが出るので、注意が必要です。
コード
1class Efnet_b2_ns(nn.Module):
2
3 def __init__(self,weight_path):
4 super().__init__()
5 self.weight_path = weight_path
6 self.model = geffnet.tf_efficientnet_b2_ns(pretrained=False)
7 state_dict = torch.load(self.weight_path,map_location=device)
8 self.model.load_state_dict(fix_model_state_dict(state_dict))
9 #さいごの部分を付け替え
10 self.model.global_pool=nn.AdaptiveAvgPool2d(1)
11 self.model.classifier = nn.Linear(self.model.classifier.in_features, 1)
12
13
14 def forward(self, x):
15 x = self.model(x)#ベースのモデルの流れに同じ
16 return x
17def fix_model_state_dict(state_dict):
18 from collections import OrderedDict
19 new_state_dict = OrderedDict()
20 for k, v in state_dict.items():
21 name = k
22 if name.startswith('model.'):
23 name = name[6:] # remove 'model.' of dataparallel
24 new_state_dict[name] = v
25 return new_state_dict
精度を向上させるためには、以下のような手順を踏むのが、一般的です。
- 軽いモデルで試す
- 画像拡張の探索を行う
- 重いモデルを複数学習させる
- アンサンブル
EfficientNet 以外では、「SE-ResNeXt」などの SE 系モデルも、試す価値があるかもしれませんね。
評価指標はAUCを選択
モデルの出力は「0~1」ですが、ラベルは「0」か「1」。
そのため、F1-score などを用いると、閾値によっては精度のスコアが変化してしまいます。
ですが、ROC 曲線の曲線下面積である「AUC」を用いれば、閾値によらない精度を考えることができるのです。
AUC は、QWK と並んで、医療データでよく扱われます。
実装編へつづく!
こちらの記事は、【実装編】へつづきます。
こちらの記事もオススメ!
2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...
2020.07.30Python 特集実装編※最新記事順Responder + Firestore でモダンかつサーバーレスなブログシステムを作ってみた!P...
2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...
ライトコードでは、エンジニアを積極採用中!
ライトコードでは、エンジニアを積極採用しています!社長と一杯しながらお話しする機会もご用意しております。そのほかカジュアル面談等もございますので、くわしくは採用情報をご確認ください。
採用情報へ
「好きを仕事にするエンジニア集団」の(株)ライトコードです! ライトコードは、福岡、東京、大阪、名古屋の4拠点で事業展開するIT企業です。 現在は、国内を代表する大手IT企業を取引先にもち、ITシステムの受託事業が中心。 いずれも直取引で、月間PV数1億を超えるWebサービスのシステム開発・運営、インフラの構築・運用に携わっています。 システム開発依頼・お見積もり大歓迎! また、現在「WEBエンジニア」「モバイルエンジニア」「営業」「WEBデザイナー」を積極採用中です! インターンや新卒採用も行っております。 以下よりご応募をお待ちしております! https://rightcode.co.jp/recruit