【実装編】肺のCT画像からCOVID19かどうかの判断は可能か?【機械学習】
IT技術
実装編~肺のCT画像からCOVID19か予想できるのか?
今回は、前回に引き続き、肺の CT 画像から、COVID 19か否かを予測する深層学習モデルを、「PyTorch」で実装してみたいと思います。
PyTorch を採用した理由は、「Kaggle」での実装例も多く、公式ドキュメントも充実しているためです。
ちなみに、こちらの記事は「プログラミングで分類に挑戦する」ということが目的で、COVID 19を確実に分類できるわけではありませんので、予めご了承をお願い致します。
【前処理編】をお読みでない方は、まずはこちらからお読みください。
モデルの訓練
モデルの訓練部分は、PyTorch の公式チュートリアルを参考にしています。
【 PyTorch公式:チュートリアル】
https://pytorch.org/tutorials/
二値分類では、出力を2次元にして、「CrossEntropyloss」を用いる人もいるかと思います。
ですが、今回は1次元なので、「BinaryCrossEntropyloss」を用いています。
1def train_fn(fold):
2 print(f"### fold: {fold} ###")
3 trn_idx = folds[folds['fold'] != fold].index
4 val_idx = folds[folds['fold'] == fold].index
5 train_dataset = TrainDataset(folds.loc[trn_idx].reset_index(drop=True),
6 folds.loc[trn_idx].reset_index(drop=True)[CFG.target_col],
7 transform1=get_transforms1(data='train'),transform2=to_tensor())
8 valid_dataset = TrainDataset(folds.loc[val_idx].reset_index(drop=True),
9 folds.loc[val_idx].reset_index(drop=True)[CFG.target_col],
10 transform1=get_transforms1(data='valid'),transform2=to_tensor())
11
12
13 train_loader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True, num_workers=4)
14 valid_loader = DataLoader(valid_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=4)
15 model = Efnet_b2_ns(weight_path="/kaggle/input/pytorch-efnet-ns-weights/tf_efficientnet_b2_aa-60c94f97.pth")
16 model.to(device)
17
18 optimizer = Adam(model.parameters(), lr=CFG.lr, amsgrad=False)
19 #scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=2, verbose=True, eps=1e-6)
20 scheduler = CosineAnnealingLR(optimizer, T_max=20, eta_min=0.001)
21
22 criterion = nn.BCELoss()#weight = class_weight
23 best_score = -100
24 best_loss = np.inf
25 best_preds = None
26 for epoch in range(CFG.epochs):
27 start_time = time.time()
28 model.train()
29 avg_loss = 0.
30
31 optimizer.zero_grad()
32 tk0 = tqdm(enumerate(train_loader), total=len(train_loader))
33
34 for i, (images, labels) in tk0:
35 images = images.to(device)
36 labels = labels.to(device)
37 y_preds = model(images.float())
38 y_preds = torch.sigmoid(y_preds.view(-1))
39 loss = criterion(y_preds, labels)
40 loss.backward()
41 optimizer.step()
42 optimizer.zero_grad()
43 avg_loss += loss.item() / len(train_loader)
44
45 model.eval()
46 avg_val_loss = 0.
47 preds = []
48 valid_labels = []
49 tk1 = tqdm(enumerate(valid_loader), total=len(valid_loader))
50
51 for i, (images, labels) in tk1:
52 images = images.to(device)
53 labels = labels.to(device)
54
55 with torch.no_grad():
56 y_preds = model(images.float())
57
58 y_preds = torch.sigmoid(y_preds.view(-1))
59 preds.append(y_preds.to('cpu').numpy())
60 valid_labels.append(labels.to('cpu').numpy())
61
62 loss = criterion(y_preds, labels)
63 avg_val_loss += loss.item() / len(valid_loader)
64 scheduler.step(avg_val_loss)
65 preds = np.concatenate(preds)
66 valid_labels = np.concatenate(valid_labels)
67 score = auc(valid_labels,preds)
68 elapsed = time.time() - start_time
69 print(f' Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f} avg_val_loss: {avg_val_loss:.4f} time: {elapsed:.0f}s')
70 print(f' Epoch {epoch+1} - AUC: {score}')
71 if score>best_score:#aucのスコアが良かったら予測値を更新...best_epochをきめるため
72 best_score = score
73 best_preds = preds
74 print("====",f' Epoch {epoch+1} - Save Best Score: {best_score:.4f}',"===")
75 torch.save(model.state_dict(), f'/kaggle/working/fold{fold}_efnet_b2_ns_.pth')#各epochのモデルを保存。。。best_epoch終了時のモデルを推論に使用
76 return best_preds, valid_labels
純正の PyTorch では、人によって書き方にバリエーションがあるため、「各種ラッパーを使った方がいい」という意見もあります。
ですが、一度純正で書いてみることで、「どこで何をやっているか」が理解しやすくなるのです。
たとえば、ラッパーのひとつである「fastai」は、理解がしにくい代表的な例ですね。
1import fastai
2from fastai.vision import *
データ型とsizeに注意!
データ型や size が異なると、損失関数やモデルにテンソルを通すとき、エラーが出ることがあります。
もし、以下のようなエラーが出たら、デバックして「dtype」や「size」を確認してみましょう。
1RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #3 'mat1' in call to _th_addmm_
検証データに対するAUCを確認
モデルの訓練の関数を実行して、「out-of-fold の予測値」と「正解ラベル」を考慮することで、AUC が計算できます。
1# check your CV score!
2preds = np.concatenate(preds)
3valid_labels = np.concatenate(valid_labels)
4
5score = auc(valid_labels,preds)
6import datetime
7dt_now = datetime.datetime.now()
8print("現在時刻",dt_now)
9print("=====AUC(CV)======",score)
1現在時刻 2020-08-12 12:36:38.190845
2=====AUC(CV)====== 0.9531154143179866
AUC では、高いスコアが出ていることがわかりますね。
test データに対しても、高いスコアが出れば、「今回学習したモデルには汎化性能がある」と言えます。
テストデータに対する予測をみる
先ほど学習させた、各 fold の重みを用いて、推論を実行します。
モデルの部分が少し違うことに注意!
先ほどのモデルのコードと、比較をしてみてください。
1class Efnet_b2_inference(nn.Module):
2 def __init__(self,weight_path):
3 super().__init__()
4 self.weight_path = weight_path
5 self.model = geffnet.tf_efficientnet_b2_ns(pretrained=False)
6 #さいごの部分を付け替え
7 self.model.global_pool=nn.AdaptiveAvgPool2d(1)
8 self.model.classifier = nn.Linear(self.model.classifier.in_features, 1)
9 state_dict = torch.load(self.weight_path,map_location=device)
10 self.model.load_state_dict(fix_model_state_dict(state_dict))
11
12 def forward(self, x):
13 x = self.model(x)#ベースのモデルの流れに同じ
14 return x
15
16def fix_model_state_dict(state_dict):
17 from collections import OrderedDict
18 new_state_dict = OrderedDict()
19 for k, v in state_dict.items():
20 name = k
21 if name.startswith('model.'):
22 name = name[6:] # remove 'model.' of dataparallel
23 new_state_dict[name] = v
24 return new_state_dict
ImageNet で学習させた重みは、出力が1000次元です。
ですが、自分で学習させたモデルは、出力が1次元であることを考慮せねばなりません。
推論を実際に行う部分
各 fold ごとのモデル出力の平均を、最終的な出力とする、単縦なアンサンブルを用いています。
1def inference(model, test_loader, device):
2 model.to(device)
3 probs = []
4 labels = []
5 for i, (images,label) in tqdm(enumerate(test_loader), total=len(test_loader)):
6
7 images = images.to(device)
8
9 with torch.no_grad():
10 y_preds = model(images)
11 y_preds = torch.sigmoid(y_preds.view(-1))
12
13 probs.append(y_preds.to('cpu').numpy())
14 labels.append(label.numpy())
15
16 probs = np.concatenate(probs)
17 labels = np.concatenate(labels)
18 return probs,labels
19
20#ensamble your folds' models!
21def submit():
22 print('run inference')
23 test_dataset = TestDataset(test_df, transform1=get_transforms1(data='valid'),transform2=to_tensor())
24 test_loader = DataLoader(test_dataset, batch_size=CFG.batch_size, shuffle=False)
25 probs = []
26 labels = []
27 for fold in range(4):
28 weights_path = "/kaggle/working/fold{}_efnet_b2_ns_a_512_augmix_gridmask.pth".format(fold)
29 model = Efnet_b2_inference(weights_path)
30 _probs,_label = inference(model, test_loader, device)
31 probs.append(_probs)
32 labels.append(_label)
33 probs = np.mean(probs, axis=0)
34 return probs
推論結果のAUCをみてみる
上の関数を実行すれば、AUC が確認できます。
1test_df['predict'] = submit()
2print(test_df.head())
3score = auc(test_df['covid'].values[:],test_df['predict'])
4print("=====AUC(test)======",score)
これによる結果は、以下のようになりました。
1run inference
2100%|██████████| 16/16 [00:01<00:00, 14.93it/s]
3100%|██████████| 16/16 [00:00<00:00, 16.77it/s]
4100%|██████████| 16/16 [00:00<00:00, 16.90it/s]
5100%|██████████| 16/16 [00:00<00:00, 16.72it/s]
6 filename covid predict
70 CODE19 Data/Testing Data/Non Covid/1845.png 0 0.422551
81 CODE19 Data/Testing Data/Non Covid/1501.png 0 0.481845
92 CODE19 Data/Testing Data/Non Covid/1923.png 0 0.624478
103 CODE19 Data/Testing Data/Non Covid/1562.png 0 0.452440
114 CODE19 Data/Testing Data/Non Covid/1952.png 0 0.521156
12=====AUC(test)====== 0.5926229508196722
さいごに
今回は、Kaggle の環境を用いて画像分類を試してみました。
ただ、CV のスコアより大幅に低下しており、「今回モデルは訓練データに過学習してしまった」ということになります。
パラメーターが多いほど、表現できることは増えますが、少ないデータに対しては、過学習をしてしまうのです。
今回用いた「EfficientNet」は、VGG16 などに比べれば、パラメーターの数は多いです。
ただ、test データは少なく、画像の撮影にばらつきがあるなど、理想的なデータセットだったとは言えません。
また、実は、今回用いたものよりも、大規模なデータセットが既に公開されています。
70 GB ほどで、セグメンテーションマスク付きなので、勉強にはちょうど良いかもしれません。
【コーネル大学:BIMCV COVID-19+ 】
https://arxiv.org/abs/2006.01174
Kaggle公開中
なお、今回のコードは、Kaggleで公開しています。
リンク:肺のCT画像からCOVID19かどうか判断してみた!
実行環境も整っているので、ぜひ試してみてくださいね!
こちらの記事もオススメ!
2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...
2020.07.30Python 特集実装編※最新記事順Responder + Firestore でモダンかつサーバーレスなブログシステムを作ってみた!P...
2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...
ライトコードでは、エンジニアを積極採用中!
ライトコードでは、エンジニアを積極採用しています!社長と一杯しながらお話しする機会もご用意しております。そのほかカジュアル面談等もございますので、くわしくは採用情報をご確認ください。
採用情報へ
「好きを仕事にするエンジニア集団」の(株)ライトコードです! ライトコードは、福岡、東京、大阪の3拠点で事業展開するIT企業です。 現在は、国内を代表する大手IT企業を取引先にもち、ITシステムの受託事業が中心。 いずれも直取引で、月間PV数1億を超えるWebサービスのシステム開発・運営、インフラの構築・運用に携わっています。 システム開発依頼・お見積もり大歓迎! また、現在「WEBエンジニア」「モバイルエンジニア」「営業」「WEBデザイナー」「WEBディレクター」を積極採用中です! インターンや新卒採用も行っております。 以下よりご応募をお待ちしております! https://rightcode.co.jp/recruit