• トップ
  • ブログ一覧
  • 【実装編】肺のCT画像からCOVID19かどうかの判断は可能か?【機械学習】
  • 【実装編】肺のCT画像からCOVID19かどうかの判断は可能か?【機械学習】

    広告メディア事業部広告メディア事業部
    2020.10.20

    IT技術

    実装編~肺のCT画像からCOVID19か予想できるのか?

    肺のCT画像

    今回は、前回に引き続き、肺の CT 画像から、COVID 19か否かを予測する深層学習モデルを、「PyTorch」で実装してみたいと思います。

    PyTorch を採用した理由は、「Kaggle」での実装例も多く、公式ドキュメントも充実しているためです。

    ちなみに、こちらの記事は「プログラミングで分類に挑戦する」ということが目的で、COVID 19を確実に分類できるわけではありませんので、予めご了承をお願い致します。

    【前処理編】をお読みでない方は、まずはこちらからお読みください。

    【前処理編】肺のCT画像からCOVID19かどうかの判断は可能か?2020.10.14【前処理編】肺のCT画像からCOVID19かどうかの判断は可能か?【機械学習】前処理編~肺のCT画像からCOVID19か予想できるのか?今回は、肺の CT 画像から、COVID 19か否かを予測す...

    モデルの訓練

    モデルの訓練部分は、PyTorch の公式チュートリアルを参考にしています。

    【 PyTorch公式:チュートリアル】
    https://pytorch.org/tutorials/

    二値分類では、出力を2次元にして、「CrossEntropyloss」を用いる人もいるかと思います。

    ですが、今回は1次元なので、「BinaryCrossEntropyloss」を用いています。

    1def train_fn(fold):
    2    print(f"### fold: {fold} ###")
    3    trn_idx = folds[folds['fold'] != fold].index
    4    val_idx = folds[folds['fold'] == fold].index
    5    train_dataset = TrainDataset(folds.loc[trn_idx].reset_index(drop=True), 
    6                                 folds.loc[trn_idx].reset_index(drop=True)[CFG.target_col], 
    7                                 transform1=get_transforms1(data='train'),transform2=to_tensor())
    8    valid_dataset = TrainDataset(folds.loc[val_idx].reset_index(drop=True), 
    9                                 folds.loc[val_idx].reset_index(drop=True)[CFG.target_col], 
    10                                 transform1=get_transforms1(data='valid'),transform2=to_tensor())
    11    
    12    
    13    train_loader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True, num_workers=4)
    14    valid_loader = DataLoader(valid_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=4)
    15    model = Efnet_b2_ns(weight_path="/kaggle/input/pytorch-efnet-ns-weights/tf_efficientnet_b2_aa-60c94f97.pth")
    16    model.to(device)
    17    
    18    optimizer = Adam(model.parameters(), lr=CFG.lr, amsgrad=False)
    19    #scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=2, verbose=True, eps=1e-6)
    20    scheduler = CosineAnnealingLR(optimizer, T_max=20, eta_min=0.001)
    21    
    22    criterion = nn.BCELoss()#weight = class_weight
    23    best_score = -100
    24    best_loss = np.inf
    25    best_preds = None
    26    for epoch in range(CFG.epochs): 
    27        start_time = time.time()
    28        model.train()
    29        avg_loss = 0.
    30
    31        optimizer.zero_grad()
    32        tk0 = tqdm(enumerate(train_loader), total=len(train_loader))
    33
    34        for i, (images, labels) in tk0:
    35            images = images.to(device)
    36            labels = labels.to(device)    
    37            y_preds = model(images.float())
    38            y_preds = torch.sigmoid(y_preds.view(-1))
    39            loss = criterion(y_preds, labels)
    40            loss.backward()
    41            optimizer.step()
    42            optimizer.zero_grad()
    43            avg_loss += loss.item() / len(train_loader)
    44            
    45        model.eval()
    46        avg_val_loss = 0.
    47        preds = []
    48        valid_labels = []
    49        tk1 = tqdm(enumerate(valid_loader), total=len(valid_loader))
    50
    51        for i, (images, labels) in tk1:
    52            images = images.to(device)
    53            labels = labels.to(device)
    54            
    55            with torch.no_grad():
    56                y_preds = model(images.float())
    57                
    58                y_preds = torch.sigmoid(y_preds.view(-1))
    59            preds.append(y_preds.to('cpu').numpy())
    60            valid_labels.append(labels.to('cpu').numpy())
    61
    62            loss = criterion(y_preds, labels)
    63            avg_val_loss += loss.item() / len(valid_loader)
    64        scheduler.step(avg_val_loss)
    65        preds = np.concatenate(preds)
    66        valid_labels = np.concatenate(valid_labels)
    67        score = auc(valid_labels,preds)
    68        elapsed = time.time() - start_time
    69        print(f'  Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')
    70        print(f'  Epoch {epoch+1} - AUC: {score}')
    71        if score>best_score:#aucのスコアが良かったら予測値を更新...best_epochをきめるため
    72            best_score = score
    73            best_preds = preds
    74            print("====",f'  Epoch {epoch+1} - Save Best Score: {best_score:.4f}',"===")
    75            torch.save(model.state_dict(), f'/kaggle/working/fold{fold}_efnet_b2_ns_.pth')#各epochのモデルを保存。。。best_epoch終了時のモデルを推論に使用
    76    return best_preds, valid_labels

    純正の PyTorch では、人によって書き方にバリエーションがあるため、「各種ラッパーを使った方がいい」という意見もあります。

    ですが、一度純正で書いてみることで、「どこで何をやっているか」が理解しやすくなるのです。

    たとえば、ラッパーのひとつである「fastai」は、理解がしにくい代表的な例ですね。

    1import fastai
    2from fastai.vision import *

    データ型とsizeに注意!

    データ型や size が異なると、損失関数やモデルにテンソルを通すとき、エラーが出ることがあります。

    もし、以下のようなエラーが出たら、デバックして「dtype」や「size」を確認してみましょう。

    1RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #3 'mat1' in call to _th_addmm_

    検証データに対するAUCを確認

    モデルの訓練の関数を実行して、「out-of-fold の予測値」と「正解ラベル」を考慮することで、AUC が計算できます。

    1# check your CV score!
    2preds = np.concatenate(preds)
    3valid_labels = np.concatenate(valid_labels)
    4
    5score = auc(valid_labels,preds)
    6import datetime
    7dt_now = datetime.datetime.now()
    8print("現在時刻",dt_now)
    9print("=====AUC(CV)======",score)
    1現在時刻 2020-08-12 12:36:38.190845
    2=====AUC(CV)====== 0.9531154143179866

    AUC では、高いスコアが出ていることがわかりますね。

    test データに対しても、高いスコアが出れば、「今回学習したモデルには汎化性能がある」と言えます。

    テストデータに対する予測をみる

    先ほど学習させた、各 fold の重みを用いて、推論を実行します。

    モデルの部分が少し違うことに注意!

    先ほどのモデルのコードと、比較をしてみてください。

    1class Efnet_b2_inference(nn.Module):
    2    def __init__(self,weight_path):
    3        super().__init__()
    4        self.weight_path = weight_path
    5        self.model = geffnet.tf_efficientnet_b2_ns(pretrained=False)
    6        #さいごの部分を付け替え
    7        self.model.global_pool=nn.AdaptiveAvgPool2d(1)
    8        self.model.classifier = nn.Linear(self.model.classifier.in_features, 1)
    9        state_dict = torch.load(self.weight_path,map_location=device)
    10        self.model.load_state_dict(fix_model_state_dict(state_dict))
    11        
    12   def forward(self, x):
    13        x = self.model(x)#ベースのモデルの流れに同じ
    14        return x
    15
    16def fix_model_state_dict(state_dict):
    17    from collections import OrderedDict
    18    new_state_dict = OrderedDict()
    19    for k, v in state_dict.items():
    20        name = k
    21        if name.startswith('model.'):
    22            name = name[6:]  # remove 'model.' of dataparallel
    23        new_state_dict[name] = v
    24    return new_state_dict

    ImageNet で学習させた重みは、出力が1000次元です。

    ですが、自分で学習させたモデルは、出力が1次元であることを考慮せねばなりません。

    推論を実際に行う部分

    各 fold ごとのモデル出力の平均を、最終的な出力とする、単縦なアンサンブルを用いています。

    1def inference(model, test_loader, device):
    2    model.to(device) 
    3    probs = []
    4    labels = []
    5    for i, (images,label) in tqdm(enumerate(test_loader), total=len(test_loader)):
    6            
    7        images = images.to(device)
    8            
    9        with torch.no_grad():
    10            y_preds = model(images)
    11            y_preds = torch.sigmoid(y_preds.view(-1))
    12            
    13        probs.append(y_preds.to('cpu').numpy())
    14        labels.append(label.numpy())
    15
    16    probs = np.concatenate(probs)
    17    labels = np.concatenate(labels)
    18    return probs,labels
    19
    20#ensamble your folds' models!
    21def submit():
    22        print('run inference')
    23        test_dataset = TestDataset(test_df, transform1=get_transforms1(data='valid'),transform2=to_tensor())
    24        test_loader = DataLoader(test_dataset, batch_size=CFG.batch_size, shuffle=False)
    25        probs = []
    26        labels = []
    27        for fold in range(4):
    28            weights_path = "/kaggle/working/fold{}_efnet_b2_ns_a_512_augmix_gridmask.pth".format(fold)
    29            model = Efnet_b2_inference(weights_path)
    30            _probs,_label = inference(model, test_loader, device)
    31            probs.append(_probs)
    32            labels.append(_label)
    33        probs = np.mean(probs, axis=0)
    34        return probs

    推論結果のAUCをみてみる

    上の関数を実行すれば、AUC が確認できます。

    1test_df['predict'] = submit()
    2print(test_df.head())
    3score = auc(test_df['covid'].values[:],test_df['predict'])
    4print("=====AUC(test)======",score)

    これによる結果は、以下のようになりました。

    1run inference
    2100%|██████████| 16/16 [00:01<00:00, 14.93it/s]
    3100%|██████████| 16/16 [00:00<00:00, 16.77it/s]
    4100%|██████████| 16/16 [00:00<00:00, 16.90it/s]
    5100%|██████████| 16/16 [00:00<00:00, 16.72it/s]
    6                                      filename  covid   predict
    70  CODE19 Data/Testing Data/Non Covid/1845.png      0  0.422551
    81  CODE19 Data/Testing Data/Non Covid/1501.png      0  0.481845
    92  CODE19 Data/Testing Data/Non Covid/1923.png      0  0.624478
    103  CODE19 Data/Testing Data/Non Covid/1562.png      0  0.452440
    114  CODE19 Data/Testing Data/Non Covid/1952.png      0  0.521156
    12=====AUC(test)====== 0.5926229508196722

    さいごに

    今回は、Kaggle の環境を用いて画像分類を試してみました。

    ただ、CV のスコアより大幅に低下しており、「今回モデルは訓練データに過学習してしまった」ということになります。

    パラメーターが多いほど、表現できることは増えますが、少ないデータに対しては、過学習をしてしまうのです。

    今回用いた「EfficientNet」は、VGG16 などに比べれば、パラメーターの数は多いです。

    ただ、test データは少なく、画像の撮影にばらつきがあるなど、理想的なデータセットだったとは言えません。

    また、実は、今回用いたものよりも、大規模なデータセットが既に公開されています。

    70 GB ほどで、セグメンテーションマスク付きなので、勉強にはちょうど良いかもしれません。

    【コーネル大学:BIMCV COVID-19+ 】
    https://arxiv.org/abs/2006.01174

    Kaggle公開中

    なお、今回のコードは、Kaggleで公開しています。

    リンク:肺のCT画像からCOVID19かどうか判断してみた!

    実行環境も整っているので、ぜひ試してみてくださいね!

    こちらの記事もオススメ!

    featureImg2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...
    featureImg2020.07.30Python 特集実装編※最新記事順Responder + Firestore でモダンかつサーバーレスなブログシステムを作ってみた!P...
    featureImg2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...

    広告メディア事業部

    広告メディア事業部

    おすすめ記事