Ccmmutty logo
Commutty IT
2 min read

ChexNetで学ぶCNNの解釈

https://cdn.magicode.io/media/notebox/6475f8ee-b200-48c7-83da-fb766ca5e341.jpeg
畳み込みニューラルネットワークの推論過程の解釈法をCheXNetを例に学ぶ。

CheXNetとは

概要

主旨:2017年に公表された、NIHの公開している14疾患のラベル付き胸部X線画像に対して肺炎検出タスクをCNNに学習させ、放射線技師を上回るF1スコアを出せたという論文。肺炎検出という2値分類タスクから、14疾患の検出という多クラス分類タスクへと拡張し、過去のモデルを上回るAUROCスコアもだせたとも言っている。
画像のどこが推論に対して影響力が大きいか表したヒートマップ(CAMs)も導入している。
image.png

Dataset

  1. 30805人の患者の112120枚の前面胸部X線画像(1024*1024 png)
  2. 各画像には14種類の疾患ラベルが複数可で紐づいている。
    • ラベルは診断分をNLPによって生成したもので、精度は90+%程
    • NLPによる抽出法はCoursera AI for treatmentで紹介されている。
  3. 1000枚弱の画像に対して、病変部のBounding boxの座標が示されている。

Model architecture

ベースモデルはDenseNet。最後の特徴量マップにGlobal averaging pooling、その次に全結合、そのアウトプットにSigmoid活性化関数を適応してラベル予測確率を出す小構造になっている。Kerasだと下のようになる。
python
from keras.applications.densenet import DenseNet121
from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D

base_model = DenseNet121(include_top=False)
x = base_model.output
x = GlobalAveragePooling2D()(x)
predictions = Dense(len(labels), activation="sigmoid")(x)
model = Model(inputs=base_model.input, outputs=predictions)

Discussion

コメントにはログインが必要です。