• トップ
  • ブログ一覧
  • 【実装編】肺のCT画像からCOVID19かどうかの判断は可能か?【機械学習】
  • 【実装編】肺のCT画像からCOVID19かどうかの判断は可能か?【機械学習】

    広告メディア事業部広告メディア事業部
    2020.10.20

    IT技術

    【実装編】肺のCT画像からCOVID19かどうかの判断は可能か?

    実装編~肺のCT画像からCOVID19か予想できるのか?

    肺のCT画像

    今回は、前回に引き続き、肺の CT 画像から、COVID 19か否かを予測する深層学習モデルを、「PyTorch」で実装してみたいと思います。

    PyTorch を採用した理由は、「Kaggle」での実装例も多く、公式ドキュメントも充実しているためです。

    ちなみに、こちらの記事は「プログラミングで分類に挑戦する」ということが目的で、COVID 19を確実に分類できるわけではありませんので、予めご了承をお願い致します。

    【前処理編】をお読みでない方は、まずはこちらからお読みください。

    【前処理編】肺のCT画像からCOVID19かどうかの判断は可能か?2020.10.14【前処理編】肺のCT画像からCOVID19かどうかの判断は可能か?【機械学習】前処理編~肺のCT画像からCOVID19か予想できるのか?今回は、肺の CT 画像から、COVID 19か否かを予測す...

    モデルの訓練

    モデルの訓練部分は、PyTorch の公式チュートリアルを参考にしています。

    【 PyTorch公式:チュートリアル】
    https://pytorch.org/tutorials/

    二値分類では、出力を2次元にして、「CrossEntropyloss」を用いる人もいるかと思います。

    ですが、今回は1次元なので、「BinaryCrossEntropyloss」を用いています。

    1def train_fn(fold):
    2    print(f"### fold: {fold} ###")
    3    trn_idx = folds[folds['fold'] != fold].index
    4    val_idx = folds[folds['fold'] == fold].index
    5    train_dataset = TrainDataset(folds.loc[trn_idx].reset_index(drop=True), 
    6                                 folds.loc[trn_idx].reset_index(drop=True)[CFG.target_col], 
    7                                 transform1=get_transforms1(data='train'),transform2=to_tensor())
    8    valid_dataset = TrainDataset(folds.loc[val_idx].reset_index(drop=True), 
    9                                 folds.loc[val_idx].reset_index(drop=True)[CFG.target_col], 
    10                                 transform1=get_transforms1(data='valid'),transform2=to_tensor())
    11    
    12    
    13    train_loader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True, num_workers=4)
    14    valid_loader = DataLoader(valid_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=4)
    15    model = Efnet_b2_ns(weight_path="/kaggle/input/pytorch-efnet-ns-weights/tf_efficientnet_b2_aa-60c94f97.pth")
    16    model.to(device)
    17    
    18    optimizer = Adam(model.parameters(), lr=CFG.lr, amsgrad=False)
    19    #scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=2, verbose=True, eps=1e-6)
    20    scheduler = CosineAnnealingLR(optimizer, T_max=20, eta_min=0.001)
    21    
    22    criterion = nn.BCELoss()#weight = class_weight
    23    best_score = -100
    24    best_loss = np.inf
    25    best_preds = None
    26    for epoch in range(CFG.epochs): 
    27        start_time = time.time()
    28        model.train()
    29        avg_loss = 0.
    30
    31        optimizer.zero_grad()
    32        tk0 = tqdm(enumerate(train_loader), total=len(train_loader))
    33
    34        for i, (images, labels) in tk0:
    35            images = images.to(device)
    36            labels = labels.to(device)    
    37            y_preds = model(images.float())
    38            y_preds = torch.sigmoid(y_preds.view(-1))
    39            loss = criterion(y_preds, labels)
    40            loss.backward()
    41            optimizer.step()
    42            optimizer.zero_grad()
    43            avg_loss += loss.item() / len(train_loader)
    44            
    45        model.eval()
    46        avg_val_loss = 0.
    47        preds = []
    48        valid_labels = []
    49        tk1 = tqdm(enumerate(valid_loader), total=len(valid_loader))
    50
    51        for i, (images, labels) in tk1:
    52            images = images.to(device)
    53            labels = labels.to(device)
    54            
    55            with torch.no_grad():
    56                y_preds = model(images.float())
    57                
    58                y_preds = torch.sigmoid(y_preds.view(-1))
    59            preds.append(y_preds.to('cpu').numpy())
    60            valid_labels.append(labels.to('cpu').numpy())
    61
    62            loss = criterion(y_preds, labels)
    63            avg_val_loss += loss.item() / len(valid_loader)
    64        scheduler.step(avg_val_loss)
    65        preds = np.concatenate(preds)
    66        valid_labels = np.concatenate(valid_labels)
    67        score = auc(valid_labels,preds)
    68        elapsed = time.time() - start_time
    69        print(f'  Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')
    70        print(f'  Epoch {epoch+1} - AUC: {score}')
    71        if score>best_score:#aucのスコアが良かったら予測値を更新...best_epochをきめるため
    72            best_score = score
    73            best_preds = preds
    74            print("====",f'  Epoch {epoch+1} - Save Best Score: {best_score:.4f}',"===")
    75            torch.save(model.state_dict(), f'/kaggle/working/fold{fold}_efnet_b2_ns_.pth')#各epochのモデルを保存。。。best_epoch終了時のモデルを推論に使用
    76    return best_preds, valid_labels

    純正の PyTorch では、人によって書き方にバリエーションがあるため、「各種ラッパーを使った方がいい」という意見もあります。

    ですが、一度純正で書いてみることで、「どこで何をやっているか」が理解しやすくなるのです。

    たとえば、ラッパーのひとつである「fastai」は、理解がしにくい代表的な例ですね。

    1import fastai
    2from fastai.vision import *

    データ型とsizeに注意!

    データ型や size が異なると、損失関数やモデルにテンソルを通すとき、エラーが出ることがあります。

    もし、以下のようなエラーが出たら、デバックして「dtype」や「size」を確認してみましょう。

    1RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #3 'mat1' in call to _th_addmm_

    検証データに対するAUCを確認

    モデルの訓練の関数を実行して、「out-of-fold の予測値」と「正解ラベル」を考慮することで、AUC が計算できます。

    1# check your CV score!
    2preds = np.concatenate(preds)
    3valid_labels = np.concatenate(valid_labels)
    4
    5score = auc(valid_labels,preds)
    6import datetime
    7dt_now = datetime.datetime.now()
    8print("現在時刻",dt_now)
    9print("=====AUC(CV)======",score)
    1現在時刻 2020-08-12 12:36:38.190845
    2=====AUC(CV)====== 0.9531154143179866

    AUC では、高いスコアが出ていることがわかりますね。

    test データに対しても、高いスコアが出れば、「今回学習したモデルには汎化性能がある」と言えます。

    テストデータに対する予測をみる

    先ほど学習させた、各 fold の重みを用いて、推論を実行します。

    モデルの部分が少し違うことに注意!

    先ほどのモデルのコードと、比較をしてみてください。

    1class Efnet_b2_inference(nn.Module):
    2    def __init__(self,weight_path):
    3        super().__init__()
    4        self.weight_path = weight_path
    5        self.model = geffnet.tf_efficientnet_b2_ns(pretrained=False)
    6        #さいごの部分を付け替え
    7        self.model.global_pool=nn.AdaptiveAvgPool2d(1)
    8        self.model.classifier = nn.Linear(self.model.classifier.in_features, 1)
    9        state_dict = torch.load(self.weight_path,map_location=device)
    10        self.model.load_state_dict(fix_model_state_dict(state_dict))
    11        
    12   def forward(self, x):
    13        x = self.model(x)#ベースのモデルの流れに同じ
    14        return x
    15
    16def fix_model_state_dict(state_dict):
    17    from collections import OrderedDict
    18    new_state_dict = OrderedDict()
    19    for k, v in state_dict.items():
    20        name = k
    21        if name.startswith('model.'):
    22            name = name[6:]  # remove 'model.' of dataparallel
    23        new_state_dict[name] = v
    24    return new_state_dict

    ImageNet で学習させた重みは、出力が1000次元です。

    ですが、自分で学習させたモデルは、出力が1次元であることを考慮せねばなりません。

    推論を実際に行う部分

    各 fold ごとのモデル出力の平均を、最終的な出力とする、単縦なアンサンブルを用いています。

    1def inference(model, test_loader, device):
    2    model.to(device) 
    3    probs = []
    4    labels = []
    5    for i, (images,label) in tqdm(enumerate(test_loader), total=len(test_loader)):
    6            
    7        images = images.to(device)
    8            
    9        with torch.no_grad():
    10            y_preds = model(images)
    11            y_preds = torch.sigmoid(y_preds.view(-1))
    12            
    13        probs.append(y_preds.to('cpu').numpy())
    14        labels.append(label.numpy())
    15
    16    probs = np.concatenate(probs)
    17    labels = np.concatenate(labels)
    18    return probs,labels
    19
    20#ensamble your folds' models!
    21def submit():
    22        print('run inference')
    23        test_dataset = TestDataset(test_df, transform1=get_transforms1(data='valid'),transform2=to_tensor())
    24        test_loader = DataLoader(test_dataset, batch_size=CFG.batch_size, shuffle=False)
    25        probs = []
    26        labels = []
    27        for fold in range(4):
    28            weights_path = "/kaggle/working/fold{}_efnet_b2_ns_a_512_augmix_gridmask.pth".format(fold)
    29            model = Efnet_b2_inference(weights_path)
    30            _probs,_label = inference(model, test_loader, device)
    31            probs.append(_probs)
    32            labels.append(_label)
    33        probs = np.mean(probs, axis=0)
    34        return probs

    推論結果のAUCをみてみる

    上の関数を実行すれば、AUC が確認できます。

    1test_df['predict'] = submit()
    2print(test_df.head())
    3score = auc(test_df['covid'].values[:],test_df['predict'])
    4print("=====AUC(test)======",score)

    これによる結果は、以下のようになりました。

    1run inference
    2100%|██████████| 16/16 [00:01<00:00, 14.93it/s]
    3100%|██████████| 16/16 [00:00<00:00, 16.77it/s]
    4100%|██████████| 16/16 [00:00<00:00, 16.90it/s]
    5100%|██████████| 16/16 [00:00<00:00, 16.72it/s]
    6                                      filename  covid   predict
    70  CODE19 Data/Testing Data/Non Covid/1845.png      0  0.422551
    81  CODE19 Data/Testing Data/Non Covid/1501.png      0  0.481845
    92  CODE19 Data/Testing Data/Non Covid/1923.png      0  0.624478
    103  CODE19 Data/Testing Data/Non Covid/1562.png      0  0.452440
    114  CODE19 Data/Testing Data/Non Covid/1952.png      0  0.521156
    12=====AUC(test)====== 0.5926229508196722

    さいごに

    今回は、Kaggle の環境を用いて画像分類を試してみました。

    ただ、CV のスコアより大幅に低下しており、「今回モデルは訓練データに過学習してしまった」ということになります。

    パラメーターが多いほど、表現できることは増えますが、少ないデータに対しては、過学習をしてしまうのです。

    今回用いた「EfficientNet」は、VGG16 などに比べれば、パラメーターの数は多いです。

    ただ、test データは少なく、画像の撮影にばらつきがあるなど、理想的なデータセットだったとは言えません。

    また、実は、今回用いたものよりも、大規模なデータセットが既に公開されています。

    70 GB ほどで、セグメンテーションマスク付きなので、勉強にはちょうど良いかもしれません。

    【コーネル大学:BIMCV COVID-19+ 】
    https://arxiv.org/abs/2006.01174

    Kaggle公開中

    なお、今回のコードは、Kaggleで公開しています。

    リンク:肺のCT画像からCOVID19かどうか判断してみた!

    実行環境も整っているので、ぜひ試してみてくださいね!

    こちらの記事もオススメ!

    featureImg2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...
    featureImg2020.07.30Python 特集実装編※最新記事順Responder + Firestore でモダンかつサーバーレスなブログシステムを作ってみた!P...
    featureImg2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...

    ライトコードでは、エンジニアを積極採用中!

    ライトコードでは、エンジニアを積極採用しています!社長と一杯しながらお話しする機会もご用意しております。そのほかカジュアル面談等もございますので、くわしくは採用情報をご確認ください。

    採用情報へ

    広告メディア事業部
    広告メディア事業部
    Show more...

    おすすめ記事

    エンジニア大募集中!

    ライトコードでは、エンジニアを積極採用中です。

    特に、WEBエンジニアとモバイルエンジニアは是非ご応募お待ちしております!

    また、フリーランスエンジニア様も大募集中です。

    background