Magicode logo
Magicode
2 min read

ECGでの心拍分類+転移学習

https://cdn.magicode.io/media/notebox/1369009b-4071-4427-b802-0287c366ea11.jpeg
論文を基に1-D CNN・転移学習・ECGデータ処理法を実装法を中心に学ぶ。

Abstract

従来のECG解析に関する機械学習メソッドは異なるタスクごとに独立したものであったが、この論文ではタスクごとに知識を再利用できないかを検証することを目的としている。
まずMIT-BIHデータセットで不整脈分類タスク(5クラス)について学習させる。その後得られた学習済みNN(転移学習)に対してPTBデータセットで心筋梗塞2値分類タスクについて学習させた。
結果、平均Accuracyは不整脈分類タスクでは 93.4% で心筋梗塞2値分類タスクでは 95.9%と高精度な予測器を生成することができた。この結果から不整脈分類タスクの知識をうまく心筋梗塞2値分類タスクに転移させたことができたと主張している。

Method & Implementation

Preprocessing

以下のような手順でECG波形の前処理を行ってます
  1. ECGデータを10秒間ごとのWindowに分けてそのうちの一つのWindowをとってくる
  2. 0~1に正規化する
  3. 極大値をすべて見つける
  4. 0.9以上の極大値をR-peakとする
  5. RR間隔の平均値を求め、その間隔をWindow長のTとする
  6. それぞれのR-peakから1.2Tだけデータをとる
  7. あらかじめ指定したデータ長に満たないデータを0で埋める(Zero-padding)
これら前処理済みデータセットはすでにKaggleに公開されています。
実際のデータに学習モデルを適応するためには入力データの形式を学習時のそれと合わせる必要があります。実際に得られるECGデータはサンプリング周波数も異なるので上の前処理に加えてそこも調節する必要があります。
学習データセットは125Hzなのでまずはその周波数に従ってリサンプリングしましょう。
チュートリアル用に360HzでサンプリングされたScipyのECGデータを使います。
python
import numpy as np
from scipy.misc import electrocardiogram
import matplotlib.pyplot as plt

V = electrocardiogram()
Hz = 360 # 360Hzだから
T = np.arange(ecg.size) * 1000 / Hz
plt.figure(figsize=(10,7))
plt.plot(T, V)
plt.xlabel("[ms]")
plt.ylabel("[mV]")
plt.xlim(0, 10000)
plt.ylim(-1.1, 2.0)
plt.show()
png
まずは125Hzでリサンプリングします。
python
from scipy import interpolate

def resample(T, V, Hz=125, kind='linear'):
    f = interpolate.interp1d(T,V,kind=kind)
    T = np.arange(np.min(T), np.max(T), 1000/Hz)
    V = f(T)
    return T, V

T_new, V_new = resample(T, V)
plt.figure(figsize=(10,7))
plt.plot(T, V)
plt.plot(T_new, V_new)
plt.xlabel("[ms]")
plt.ylabel("[mV]")
plt.xlim(0, 10000)
plt.ylim(-1.1, 2.0)
plt.show()
png
ほぼ一致していることが分かります。それでは1.からみていきましょう。
  1. ECGデータを10秒間ごとのWindowに分けてそのうちの一つのWindowをとってくる
python
def split(T, V, window=10):
    Hz = int(1000 / (T[1] - T[0]))
    Ts = []
    Vs = []
    for i in range(0, len(T), window*Hz):
        if T[i + window * Hz - 1:i + window * Hz]:
            Ts.append(T[i:i+window*Hz])
            Vs.append(V[i:i+window*Hz])
        else:
            Ts.append(T[i:])
            Vs.append(V[i:])    
    return Ts, Vs

Ts, Vs = split(T_new, V_new)

plt.figure(figsize=(10,7))
for T, V in zip(Ts,Vs):    
    plt.plot(T, V)
plt.xlabel("[ms]")
plt.ylabel("[mV]")
plt.show()
png
  1. 0~1に正規化する
python
def normalize(V):
    return (V-np.min(V))/(np.max(V)-np.min(V))

T_new, V_new = Ts[0], Vs[0]
V_new = normalize(V_new)

plt.figure(figsize=(10,7))
plt.plot(T_new, V_new)
plt.xlabel("[ms]")
plt.ylabel("[mV]")
plt.show()
png
  1. 極大値をすべて見つける
  2. 0.9以上の極大値をR-peakとする
python
from scipy.signal import find_peaks

def find_R_peaks(V, threshold=0.9):
    R_peaks, _ = find_peaks(V, height=threshold)
    return R_peaks

R_peaks = find_R_peaks(V_new)

plt.figure(figsize=(10,7))
plt.plot(T_new, V_new)
plt.scatter(T_new[R_peaks], V_new[R_peaks], color='r')
plt.xlabel("[ms]")
plt.ylabel("[mV]")
plt.show()
png
このデータではThreshold=0.9というのはあまりよくないようですが、論文の通りにいきましょう。
  1. RR間隔の平均値を求め、その間隔をWindow長のTとする
python
def find_median_interval(R_peaks):
    return np.mean(np.diff(R_peaks)) # index

interval = find_median_interval(R_peaks)
  1. それぞれのR-peakから1.2Tだけデータをとる
  2. あらかじめ指定したデータ長に満たないデータを0で埋める(Zero-padding)
python
def extract_beats(T, V, R_peaks, interval, max_duration=187):
    window = int(1.2*interval) # index
    beats = []
    durations = []
    for peak in R_peaks:
        beat = np.zeros(max_duration) # 固定長の空の行列をつくっとく

        if peak + window <= len(V): # R_peakからWindow長データを取り切れる前提をおく

            if window > max_duration: # Window長が指定した固定長を超えている場合
                duration = [T[peak],T[peak+max_duration-1]]
                beat = V[peak:peak+max_duration]
                beats.append(beat)
                durations.append(duration)

            else:
                duration = [T[peak],T[peak+window-1]]
                beat[:window] = V[peak:peak+window]
                beats.append(beat)
                durations.append(duration)
            
    return np.array(beats), durations # 抽出された心拍データとその始まりと終わりの時間を返す

beats, durations = extract_beats(T_new, V_new, R_peaks, interval)
print("Shape of the extracted beats data: ", beats.shape)
Shape of the extracted beats data: (4, 187)
python
def ecg_with_beats(T,V,durations):
    fig = plt.figure(figsize = (10,7))
    ax = fig.add_subplot(111)
    for i in range(len(durations)):
        duration = durations[i]
        ax.axvspan(duration[0], duration[1],color="coral" if i%2 == 0 else "lime" ,alpha=0.3)
    ax.plot(T,V)
    plt.xlabel("[ms]")
    plt.ylabel("[mV]")
    plt.show()
    return

ecg_with_beats(T_new, V_new, durations)
png
python
plt.figure(figsize=(10,7))
for beat in beats:
    plt.plot(beat)
plt.xlabel("index")
plt.ylabel("[mV]")
plt.show()
png
上手く心拍を抽出できていることがわかります。
最後に以上の処理をpreprocess関数にまとめてみましょう。
python
def preprocess(T, V, Hz=125, max_duration=187):
    T, V = resample(T, V, Hz)
    Ts, Vs = split(T, V)
    Beats = []
    Durations = []
    for T, V in zip(Ts, Vs):
        V = normalize(V)
        R_peaks = find_R_peaks(V)
        if len(R_peaks) >= 2:
            interval = find_median_interval(R_peaks)
            beats, durations = extract_beats(T, V, R_peaks, interval)
            if len(beats) >= 1:
                Beats.append(beats)
                Durations += durations
    Beats = np.vstack(Beats)
    return Beats, Durations

beats, durations = preprocess(T, V)
print("Shape of the extracted beats data: ", beats.shape)
ecg_with_beats(T, V, durations)
Shape of the extracted beats data: (104, 187)
png
ほんの一部の心拍が抽出されていることが分かります。この論文大丈夫か心配になってきましたね。

Model

論文のモデルは以下
論文にのってるモデルを少し変えたやつ(Residual blockなしバージョン)

1D-Convolution layer

"all convolution layers are applying 1-D convolution through time and each have 32 kernels of size 5"
  • カーネル:入力にかける行列のこと、今回は1次元。32 Kernalsはカーネルの層数を意味するので**Kerasだったらfilters = 32**にあたる。
  • サイズ:カーネルのWindow長。今回は一次元。size 5は**Kerasだったらkernel_size = 5**にあたる。
二次元畳み込み層よりもパラメータ数はもちろん少ない。今回は入力が一次元なので一次元畳み込み層で自然。

Dataset

https://www.kaggle.com/shayanfazeli/heartbeat のデータセットを使います。前処理済み最高長187の心拍がCSVファイルで格納されています。188番目のカラムにはその心拍のラベル(心室期外収縮や心筋梗塞など)が整数クラスで入ってます。
MITBIHのAnnotationは以下のようになってます。
N,S,V,F,Qはそれぞれ0,1,2,3,4クラスに対応しています。実際にデータセットを見てみましょう。
python
import pandas as pd

df_train = pd.read_csv("/content/drive/My Drive/kaggle_ECG/mitbih_train.csv", header=None) # 自分のGoogle driveにでもデータセットダウンロード
print("Data shape: ", df_train.shape)
print("All classes (shown in 188th column): ", df_train.iloc[:,187].unique())
Data shape: (87554, 188) All classes (shown in 188th column): [0. 1. 2. 3. 4.]
python
plt.figure(figsize=(10,7))
for beat in df_train.iloc[:5,:].values:
    plt.plot(beat)
plt.xlabel("index")
plt.ylabel("[mV]")
plt.show()
png
論文記載のアルゴリズムに従って前処理されていることが分かります。

Training the Arrhythmia Classifier

MITBIHデータセットでまずは学習します。github借りパくです。
python
from keras import optimizers, losses, activations, models
from keras.callbacks import ModelCheckpoint, EarlyStopping, LearningRateScheduler, ReduceLROnPlateau
from keras.layers import Dense, Input, Dropout, Convolution1D, MaxPool1D, GlobalMaxPool1D, GlobalAveragePooling1D, \
    concatenate
from sklearn.metrics import f1_score, accuracy_score


df_train = pd.read_csv("/content/drive/My Drive/kaggle_ECG/mitbih_train.csv", header=None) 
df_train = df_train.sample(frac=1)
df_test = pd.read_csv("/content/drive/My Drive/kaggle_ECG/mitbih_test.csv", header=None)

Y = np.array(df_train[187].values).astype(np.int8)
X = np.array(df_train[list(range(187))].values)[..., np.newaxis]

Y_test = np.array(df_test[187].values).astype(np.int8)
X_test = np.array(df_test[list(range(187))].values)[..., np.newaxis]


def get_model_mitbih():
    nclass = 5
    inp = Input(shape=(187, 1))
    img_1 = Convolution1D(16, kernel_size=5, activation=activations.relu, padding="valid")(inp)
    img_1 = Convolution1D(16, kernel_size=5, activation=activations.relu, padding="valid")(img_1)
    img_1 = MaxPool1D(pool_size=2)(img_1)
    img_1 = Dropout(rate=0.1)(img_1)
    img_1 = Convolution1D(32, kernel_size=3, activation=activations.relu, padding="valid")(img_1)
    img_1 = Convolution1D(32, kernel_size=3, activation=activations.relu, padding="valid")(img_1)
    img_1 = MaxPool1D(pool_size=2)(img_1)
    img_1 = Dropout(rate=0.1)(img_1)
    img_1 = Convolution1D(32, kernel_size=3, activation=activations.relu, padding="valid")(img_1)
    img_1 = Convolution1D(32, kernel_size=3, activation=activations.relu, padding="valid")(img_1)
    img_1 = MaxPool1D(pool_size=2)(img_1)
    img_1 = Dropout(rate=0.1)(img_1)
    img_1 = Convolution1D(256, kernel_size=3, activation=activations.relu, padding="valid")(img_1)
    img_1 = Convolution1D(256, kernel_size=3, activation=activations.relu, padding="valid")(img_1)
    img_1 = GlobalMaxPool1D()(img_1)
    img_1 = Dropout(rate=0.2)(img_1)

    dense_1 = Dense(64, activation=activations.relu, name="dense_1")(img_1)
    dense_1 = Dense(64, activation=activations.relu, name="dense_2")(dense_1)
    dense_1 = Dense(nclass, activation=activations.softmax, name="dense_3_mitbih")(dense_1)

    model = models.Model(inputs=inp, outputs=dense_1)
    opt = optimizers.Adam(0.001)

    model.compile(optimizer=opt, loss=losses.sparse_categorical_crossentropy, metrics=['acc'])
    model.summary()
    return model

model = get_model_mitbih()
file_path = "/content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5"
checkpoint = ModelCheckpoint(file_path, monitor='val_acc', verbose=1, save_best_only=True, mode='max')
early = EarlyStopping(monitor="val_acc", mode="max", patience=5, verbose=1)
redonplat = ReduceLROnPlateau(monitor="val_acc", mode="max", patience=3, verbose=2)
callbacks_list = [checkpoint, early, redonplat]  # early

model.fit(X, Y, epochs=1000, verbose=2, callbacks=callbacks_list, validation_split=0.1)
model.load_weights(file_path)

pred_test = model.predict(X_test)
pred_test = np.argmax(pred_test, axis=-1)

f1 = f1_score(Y_test, pred_test, average="macro")

print("Test f1 score : %s "% f1)

acc = accuracy_score(Y_test, pred_test)

print("Test accuracy score : %s "% acc)
Model: "functional_1" _________________________________________________________________ Layer (type) Output Shape Param #
================================================================= input_1 (InputLayer) [(None, 187, 1)] 0
_________________________________________________________________ conv1d (Conv1D) (None, 183, 16) 96
_________________________________________________________________ conv1d_1 (Conv1D) (None, 179, 16) 1296
_________________________________________________________________ max_pooling1d (MaxPooling1D) (None, 89, 16) 0
_________________________________________________________________ dropout (Dropout) (None, 89, 16) 0
_________________________________________________________________ conv1d_2 (Conv1D) (None, 87, 32) 1568
_________________________________________________________________ conv1d_3 (Conv1D) (None, 85, 32) 3104
_________________________________________________________________ max_pooling1d_1 (MaxPooling1 (None, 42, 32) 0
_________________________________________________________________ dropout_1 (Dropout) (None, 42, 32) 0
_________________________________________________________________ conv1d_4 (Conv1D) (None, 40, 32) 3104
_________________________________________________________________ conv1d_5 (Conv1D) (None, 38, 32) 3104
_________________________________________________________________ max_pooling1d_2 (MaxPooling1 (None, 19, 32) 0
_________________________________________________________________ dropout_2 (Dropout) (None, 19, 32) 0
_________________________________________________________________ conv1d_6 (Conv1D) (None, 17, 256) 24832
_________________________________________________________________ conv1d_7 (Conv1D) (None, 15, 256) 196864
_________________________________________________________________ global_max_pooling1d (Global (None, 256) 0
_________________________________________________________________ dropout_3 (Dropout) (None, 256) 0
_________________________________________________________________ dense_1 (Dense) (None, 64) 16448
_________________________________________________________________ dense_2 (Dense) (None, 64) 4160
_________________________________________________________________ dense_3_mitbih (Dense) (None, 5) 325
================================================================= Total params: 254,901 Trainable params: 254,901 Non-trainable params: 0 _________________________________________________________________ Epoch 1/1000
Epoch 00001: val_acc improved from -inf to 0.92005, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 11s - loss: 0.3817 - acc: 0.8841 - val_loss: 0.2877 - val_acc: 0.9201
Epoch 2/1000

Epoch 00002: val_acc improved from 0.92005 to 0.95420, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 11s - loss: 0.2364 - acc: 0.9326 - val_loss: 0.1660 - val_acc: 0.9542
Epoch 3/1000

Epoch 00003: val_acc improved from 0.95420 to 0.96551, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 11s - loss: 0.1691 - acc: 0.9537 - val_loss: 0.1357 - val_acc: 0.9655
Epoch 4/1000

Epoch 00004: val_acc improved from 0.96551 to 0.96962, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 11s - loss: 0.1409 - acc: 0.9620 - val_loss: 0.1146 - val_acc: 0.9696
Epoch 5/1000

Epoch 00005: val_acc improved from 0.96962 to 0.97270, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 11s - loss: 0.1234 - acc: 0.9672 - val_loss: 0.1008 - val_acc: 0.9727
Epoch 6/1000

Epoch 00006: val_acc improved from 0.97270 to 0.97487, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 12s - loss: 0.1114 - acc: 0.9697 - val_loss: 0.0901 - val_acc: 0.9749
Epoch 7/1000

Epoch 00007: val_acc improved from 0.97487 to 0.97602, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 11s - loss: 0.1013 - acc: 0.9723 - val_loss: 0.0847 - val_acc: 0.9760
Epoch 8/1000

Epoch 00008: val_acc did not improve from 0.97602
2463/2463 - 10s - loss: 0.0945 - acc: 0.9741 - val_loss: 0.0886 - val_acc: 0.9751
Epoch 9/1000

Epoch 00009: val_acc improved from 0.97602 to 0.97796, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 12s - loss: 0.0894 - acc: 0.9752 - val_loss: 0.0814 - val_acc: 0.9780
Epoch 10/1000

Epoch 00010: val_acc improved from 0.97796 to 0.97967, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 11s - loss: 0.0840 - acc: 0.9771 - val_loss: 0.0723 - val_acc: 0.9797
Epoch 11/1000

Epoch 00011: val_acc did not improve from 0.97967
2463/2463 - 10s - loss: 0.0795 - acc: 0.9775 - val_loss: 0.0743 - val_acc: 0.9788
Epoch 12/1000

Epoch 00012: val_acc improved from 0.97967 to 0.98184, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 11s - loss: 0.0772 - acc: 0.9776 - val_loss: 0.0666 - val_acc: 0.9818
Epoch 13/1000

Epoch 00013: val_acc did not improve from 0.98184
2463/2463 - 10s - loss: 0.0741 - acc: 0.9790 - val_loss: 0.0649 - val_acc: 0.9814
Epoch 14/1000

Epoch 00014: val_acc did not improve from 0.98184
2463/2463 - 11s - loss: 0.0702 - acc: 0.9798 - val_loss: 0.0660 - val_acc: 0.9802
Epoch 15/1000

Epoch 00015: val_acc improved from 0.98184 to 0.98241, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 12s - loss: 0.0711 - acc: 0.9795 - val_loss: 0.0708 - val_acc: 0.9824
Epoch 16/1000

Epoch 00016: val_acc did not improve from 0.98241
2463/2463 - 10s - loss: 0.0679 - acc: 0.9804 - val_loss: 0.0646 - val_acc: 0.9823
Epoch 17/1000

Epoch 00017: val_acc did not improve from 0.98241
2463/2463 - 10s - loss: 0.0656 - acc: 0.9808 - val_loss: 0.0651 - val_acc: 0.9812
Epoch 18/1000

Epoch 00018: val_acc improved from 0.98241 to 0.98344, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 12s - loss: 0.0641 - acc: 0.9813 - val_loss: 0.0602 - val_acc: 0.9834
Epoch 19/1000

Epoch 00019: val_acc did not improve from 0.98344
2463/2463 - 10s - loss: 0.0615 - acc: 0.9820 - val_loss: 0.0659 - val_acc: 0.9814
Epoch 20/1000

Epoch 00020: val_acc did not improve from 0.98344
2463/2463 - 10s - loss: 0.0615 - acc: 0.9820 - val_loss: 0.0592 - val_acc: 0.9828
Epoch 21/1000

Epoch 00021: val_acc did not improve from 0.98344

Epoch 00021: ReduceLROnPlateau reducing learning rate to 0.00010000000474974513.
2463/2463 - 11s - loss: 0.0599 - acc: 0.9828 - val_loss: 0.0545 - val_acc: 0.9831
Epoch 22/1000

Epoch 00022: val_acc improved from 0.98344 to 0.98561, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 12s - loss: 0.0449 - acc: 0.9867 - val_loss: 0.0477 - val_acc: 0.9856
Epoch 23/1000

Epoch 00023: val_acc improved from 0.98561 to 0.98584, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 11s - loss: 0.0408 - acc: 0.9879 - val_loss: 0.0453 - val_acc: 0.9858
Epoch 24/1000

Epoch 00024: val_acc improved from 0.98584 to 0.98595, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 12s - loss: 0.0390 - acc: 0.9882 - val_loss: 0.0439 - val_acc: 0.9860
Epoch 25/1000

Epoch 00025: val_acc improved from 0.98595 to 0.98721, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 12s - loss: 0.0358 - acc: 0.9891 - val_loss: 0.0426 - val_acc: 0.9872
Epoch 26/1000

Epoch 00026: val_acc did not improve from 0.98721
2463/2463 - 10s - loss: 0.0347 - acc: 0.9897 - val_loss: 0.0430 - val_acc: 0.9864
Epoch 27/1000

Epoch 00027: val_acc did not improve from 0.98721
2463/2463 - 10s - loss: 0.0351 - acc: 0.9889 - val_loss: 0.0438 - val_acc: 0.9869
Epoch 28/1000

Epoch 00028: val_acc did not improve from 0.98721

Epoch 00028: ReduceLROnPlateau reducing learning rate to 1.0000000474974514e-05.
2463/2463 - 11s - loss: 0.0338 - acc: 0.9897 - val_loss: 0.0425 - val_acc: 0.9862
Epoch 29/1000

Epoch 00029: val_acc did not improve from 0.98721
2463/2463 - 10s - loss: 0.0324 - acc: 0.9896 - val_loss: 0.0421 - val_acc: 0.9862
Epoch 30/1000

Epoch 00030: val_acc did not improve from 0.98721
2463/2463 - 10s - loss: 0.0324 - acc: 0.9897 - val_loss: 0.0422 - val_acc: 0.9866
Epoch 00030: early stopping
Test f1 score : 0.9158830356755775 
Test accuracy score : 0.9850630367257446 
結果
  • Test f1 score : 0.9158830356755775
  • Test accuracy score : 0.9850630367257446

Training the MI Predictor

先ほどのMITBIHデータセットで得られたNNを利用して心筋梗塞2値分類タスクについてPTBDBデータセットで学習します。
論文では不整脈分類タスクのNNの最後の2層のみFine-tuningしてましたが、今回使うGithubのほうでは最後の2層以外の重みを固定するということはしないで、一緒に学習しなおすということをして実際最終2層より以前の重みをフリーズして学習するよりもスコアが良かったのでそちらを紹介します。
python
from keras import optimizers, losses, activations, models
from keras.callbacks import ModelCheckpoint, EarlyStopping, LearningRateScheduler, ReduceLROnPlateau
from keras.layers import Dense, Input, Dropout, Convolution1D, MaxPool1D, GlobalMaxPool1D, GlobalAveragePooling1D, \
    concatenate
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split

df_1 = pd.read_csv("/content/drive/My Drive/kaggle_ECG/ptbdb_normal.csv", header=None)
df_2 = pd.read_csv("/content/drive/My Drive/kaggle_ECG/ptbdb_abnormal.csv", header=None)
df = pd.concat([df_1, df_2])

df_train, df_test = train_test_split(df, test_size=0.2, random_state=1337, stratify=df[187])


Y = np.array(df_train[187].values).astype(np.int8)
X = np.array(df_train[list(range(187))].values)[..., np.newaxis]

Y_test = np.array(df_test[187].values).astype(np.int8)
X_test = np.array(df_test[list(range(187))].values)[..., np.newaxis]


def get_model_ptbdb():
    nclass = 1
    inp = Input(shape=(187, 1))
    img_1 = Convolution1D(16, kernel_size=5, activation=activations.relu, padding="valid")(inp)
    img_1 = Convolution1D(16, kernel_size=5, activation=activations.relu, padding="valid")(img_1)
    img_1 = MaxPool1D(pool_size=2)(img_1)
    img_1 = Dropout(rate=0.1)(img_1)
    img_1 = Convolution1D(32, kernel_size=3, activation=activations.relu, padding="valid")(img_1)
    img_1 = Convolution1D(32, kernel_size=3, activation=activations.relu, padding="valid")(img_1)
    img_1 = MaxPool1D(pool_size=2)(img_1)
    img_1 = Dropout(rate=0.1)(img_1)
    img_1 = Convolution1D(32, kernel_size=3, activation=activations.relu, padding="valid")(img_1)
    img_1 = Convolution1D(32, kernel_size=3, activation=activations.relu, padding="valid")(img_1)
    img_1 = MaxPool1D(pool_size=2)(img_1)
    img_1 = Dropout(rate=0.1)(img_1)
    img_1 = Convolution1D(256, kernel_size=3, activation=activations.relu, padding="valid")(img_1)
    img_1 = Convolution1D(256, kernel_size=3, activation=activations.relu, padding="valid")(img_1)
    img_1 = GlobalMaxPool1D()(img_1)
    img_1 = Dropout(rate=0.2)(img_1)

    dense_1 = Dense(64, activation=activations.relu, name="dense_1")(img_1)
    dense_1 = Dense(64, activation=activations.relu, name="dense_2")(dense_1)
    dense_1 = Dense(nclass, activation=activations.sigmoid, name="dense_3_ptbdb")(dense_1)

    model = models.Model(inputs=inp, outputs=dense_1)
    opt = optimizers.Adam(0.001)

    model.compile(optimizer=opt, loss=losses.binary_crossentropy, metrics=['acc'])
    model.summary()
    return model

model = get_model_ptbdb()
file_path = "/content/drive/My Drive/kaggle_ECG/baseline_cnn_ptbdb_transfer_fullupdate.h5"
checkpoint = ModelCheckpoint(file_path, monitor='val_acc', verbose=1, save_best_only=True, mode='max')
early = EarlyStopping(monitor="val_acc", mode="max", patience=5, verbose=1)
redonplat = ReduceLROnPlateau(monitor="val_acc", mode="max", patience=3, verbose=2)
callbacks_list = [checkpoint, early, redonplat]  # early
model.load_weights("/content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5", by_name=True)
model.fit(X, Y, epochs=1000, verbose=2, callbacks=callbacks_list, validation_split=0.1)
model.load_weights(file_path)

pred_test = model.predict(X_test)
pred_test = (pred_test>0.5).astype(np.int8)

f1 = f1_score(Y_test, pred_test)

print("Test f1 score : %s "% f1)

acc = accuracy_score(Y_test, pred_test)

print("Test accuracy score : %s "% acc)

Discussion

コメントにはログインが必要です。