Ccmmutty logo
Commutty IT
109 min read

04 Deep Learningフレームワークの基礎

https://picsum.photos/seed/458e752d19ff4a6999507e810e72cd21/600/800
ChainerはDeep Learningフレームワークの一つで,現在様々なDeep Learningフレームワーク(TensorFlow, PyTorch, etc.)でも採用され主要なニューラルネットワークの記法となっているDefine-by-Runというアイデアを汎用的なDeep Learningフレームワークとしては初めて採用し,2015年からPreferred Networks社によって開発が続けられています.Define-by-Runとは,ニューラルネットワーク中の計算を行うコードを記述することでニューラルネットワークの構造を定義する考え方です.学習を行う前にネットワーク構造を定義しておき,そのネットワークに学習に用いるデータを入力するためのコードを別途書く必要がある方法はDefine-and-Runと呼ばれます.Define-by-Runは実行時にネットワーク構造が決定されるため,動的な構造を記述しやすいという特徴があります.
ここでは,その柔軟性直感的であることを特徴とするこのChainerというフレームワークの基本的な使い方を解説します.

環境構築

まずはColab上で以下のセルを実行し,必要なライブラリをインストールしましょう.ここではgraphvizというソフトウェアをインストールしています.これは,後にニューラルネットワークのアーキテクチャをグラフ構造として可視化するために使用します.Google Colab上には,ChainerやCuPyは予めインストールされています.
python
!apt-get install -y graphviz
Reading package lists... Done Building dependency tree
Reading state information... Done graphviz is already the newest version (2.40.1-2). The following packages were automatically installed and are no longer required: cuda-cufft-10-1 cuda-cufft-dev-10-1 cuda-curand-10-1 cuda-curand-dev-10-1 cuda-cusolver-10-1 cuda-cusolver-dev-10-1 cuda-cusparse-10-1 cuda-cusparse-dev-10-1 cuda-license-10-2 cuda-npp-10-1 cuda-npp-dev-10-1 cuda-nsight-10-1 cuda-nsight-compute-10-1 cuda-nsight-systems-10-1 cuda-nvgraph-10-1 cuda-nvgraph-dev-10-1 cuda-nvjpeg-10-1 cuda-nvjpeg-dev-10-1 cuda-nvrtc-10-1 cuda-nvrtc-dev-10-1 cuda-nvvp-10-1 libcublas10 libnvidia-common-430 nsight-compute-2019.5.0 nsight-systems-2019.5.2 Use 'apt autoremove' to remove them. 0 upgraded, 0 newly installed, 0 to remove and 5 not upgraded.
それでは,以下のコマンドをターミナルで実行し,Chainerや,ChainerでGPUを活用するために必要となるCuPyというパッケージが正しくインストールされているかどうかを確認してみましょう.
python
!python -c 'import chainer; chainer.print_runtime_info()'
Platform: Linux-4.14.137+-x86_64-with-Ubuntu-18.04-bionic Chainer: 6.5.0 ChainerX: Not Available NumPy: 1.17.4 CuPy: CuPy Version : 6.5.0 CUDA Root : /usr/local/cuda CUDA Build Version : 10010 CUDA Driver Version : 10010 CUDA Runtime Version : 10010 cuDNN Build Version : 7603 cuDNN Version : 7603 NCCL Build Version : 2402 NCCL Runtime Version : 2402 iDeep: 2.0.0.post3
Chainer, NumPy, そしてCuPy, さらにCuPyの下にCUDAやcuDNN, NCCLといった項目があり,それぞれバージョン番号が表示されていれば成功です.

Chainerの基本的な使い方

はじめに,シンプルなタスクに実際に取り組むことによって,Chainerの基本的な使い方を説明していきます.さっそく,有名な手書き数字のデータセットMNISTを使って,画像を10クラス(数字の0 - 9)のいずれかに分類するネットワークを書き,学習させてみましょう.

データセットの準備

まずは学習対象となるデータセットの準備をします.教師あり学習の場合,データセットは**「入力データ」と「それと対になるラベルデータ」のペアを返すオブジェクト**である必要があります.
Chainerには,MNISTやCIFAR10/100のような良く用いられるデータセットに対して,データのダウンロードからオブジェクト作成までを自動的に行ってくれる便利なメソッドがあります.ここではひとまずこれを用いましょう.
python
from chainer.datasets import mnist

# データセットがダウンロード済みでなければ,ダウンロードも行う
train_val, test = mnist.get_mnist(withlabel=True, ndim=1)
データセットオブジェクトの準備ができました.このオブジェクトは, train_val[i] のように指定すると,i番目の (data, label) というタプルを返すリスト と同様のものと考えてください.(実際ただのPythonリストもChainerのデータセットオブジェクトとして利用可能です).それでは,0番目のデータとラベルを取り出して,表示してみましょう.
python
# matplotlibを使ったグラフ描画結果がnotebook内に表示されるようにします.
%matplotlib inline
import matplotlib.pyplot as plt

# データの例示
x, t = train_val[0]  # 0番目の (data, label) を取り出す
plt.imshow(x.reshape(28, 28), cmap='gray')
plt.axis('off')
plt.show()
print('label:', t)
png
label: 5

Validation用データセットを作る

次に,さきほど作成したtrain_valデータセットを,Training用のデータセットとValidation用のデータセットに分割します.Validationデータセットとは,学習には用いずにモデルの汎化性能をチェックしたり,学習率などのハイパーパラメータを調整するために用いる検証用のデータセットスプリットのことです.分割処理も,Chainerが提供しているデータセット分割用の関数を用いて行うことができます.元々60000個のデータが入っているtrainデータセットを,ランダムに選択された50000個のデータと残りの10000個のデータの2つに分割しましょう.これには,split_dataset_randomという関数を使用します.
python
from chainer.datasets import split_dataset_random

train, valid = split_dataset_random(train_val, 50000, seed=0)
関数の第1引数が分割したい対象のデータセットオブジェクト,第2引数が1つ目のデータセットの要素数,第3引数がランダムな抽出を行う際に用いられる乱数シード(これは省略可)となります.第3引数のseedとして同じ値を指定すると,再実行した際にデータセットを同じように分割するようになります.それでは,それぞれのデータセットの中に入っているデータの数を確認してみましょう.
python
print('Training dataset size:', len(train))
print('Validation dataset size:', len(valid))
Training dataset size: 50000 Validation dataset size: 10000

Iteratorの作成

次に,さきほど準備したデータセットオブジェクトから,幾つかのデータ(入力とラベルのペア)を束ねて学習モデルに次々に渡す,Iteratorという機能を紹介します.なぜIteratorの機能が必要かというと,ニューラルネットワークのパラメータを更新する際に利用される,確率的勾配降下法(Stochastic Gradient Descent, SGD)をはじめとする最適化手法では,一つのデータだけを元に更新する処理を繰り返すのではなく,幾つかのデータを束ねた ミニバッチ を元に計算していくのが一般的となっているためです(ミニバッチ計算が一般的である理由としては,勾配のミニバッチ平均を計算することでパラメータ更新が安定することや,GPUなどを用いた並列化がしやすいこと等が挙げられます).
Iteratorは,さきほど作成したデータセットオブジェクトを引数として指定し,next()メソッドを呼ぶことで新しいミニバッチを返してくれます.データセット内のデータすべてを1度ずつ学習に利用し終えた時点のことを 1エポック(epoch) と呼びます.Iteratorの内部では,学習中に何エポックまで学習を行ったか,などの情報が逐次記録されており,データセット内のデータを何度も使って学習のループを回すようなコードを簡単に書くことができるようになります.
データセットオブジェクトからイテレータを作るには,以下のようにします.
python
from chainer import iterators

batchsize = 128

train_iter = iterators.SerialIterator(train, batchsize)
valid_iter = iterators.SerialIterator(
    valid, batchsize, repeat=False, shuffle=False)
test_iter = iterators.SerialIterator(
    test, batchsize, repeat=False, shuffle=False)
今,学習データセット用のイテレータ(train_iter)と,検証データセット用のイテレータ(valid_iter),および学習したネットワークの評価に用いるテストデータセット用のイテレータ(test_iter)の計3つを作成しました.ここではbatchsize = 128としているため,作成した3つのイテレータはnext()メソッドが(train_iter.next()のように)呼ばれると,128枚の数字画像データを一括りにして返します.実際にnext()の返り値を調べてみましょう.
python
minibatch = train_iter.next()
このminibatchという変数は,(img, label)というタプルが128個(ミニバッチサイズだけ)並んだリストになっています.実際に,このリストの長さが128であることを確認してみましょう.
python
print('batchsize:', len(minibatch))
batchsize: 128
次に,このminibatchというリストの一つ目の要素(画像とラベルを持つタプルになっているはずです)をminibatch[0]として取り出してみます.
python
x, t = minibatch[0]

print('x:', x.shape)
print('t:', t.shape)
x: (784,) t: ()
そのときの返り値である2つの配列 xt のshapeを調べてみると,データはそれぞれ長さ784のベクトルとして格納されており,正解ラベルはスカラー値となっています.784は,28×2828 \times 28で,28ピクセル四方の画像データの画素値を1列に並べたものになっています.
SerialIteratorについて
Chainerにいくつか用意されているイテレータの一種であるSerialIteratorは,データセットの中のデータを順番に取り出してくる最もシンプルなイテレータです.SerialIterator のコンストラクタ(クラスをインスタンス化するタイミングで呼ばれるメソッド)の引数にデータセットオブジェクトと,バッチサイズを取ります.このとき,渡したデータセットオブジェクトから,データを繰り返し読み出す必要がある場合はrepeat引数をTrueとし,1周が終わったらそれ以上データを取り出したくない場合はこれをFalseとします.これは,主にvalidation用のデータセットに対して使うフラグです.デフォルトでは,Trueになっています.また,shuffle引数にTrueを渡すと,データセットから取り出されてくるデータの順番をエポックごとにランダムに変更します.SerialIteratorの他にも,マルチプロセスで高速にデータを処理できるようにしたMultiprocessIteratorMultithreadIteratorなど,複数のイテレータが用意されています.詳しくは以下を見てください.

ネットワークの定義

それでは,学習させるネットワークを定義してみましょう.今回は,全結合層のみからなるニューラルネットワーク(多層パーセプトロン)を作ることにして,中間層のユニット数は100とします.今回用いるMNISTデータセットは0〜9までの数字のいずれかを意味する10種のラベルを持つことから,出力ユニット数は10とします.
ここで,ネットワークを定義するために必要なLink, Function, Chainについて簡単に説明します.
LinkとFunction
Chainerでは,ニューラルネットワークの各層を,LinkFunctionに区別します.
  • Linkは,パラメータを持つ関数です.
  • Functionは,パラメータを持たない関数です.
これらを組み合わせてネットワークを記述します.パラメータを持つ層は,chainer.linksモジュール以下に用意されています.例えば chainer.links.Linear は,前章で説明した全結合層に対応しており,内部に Wb という学習できるパラメータが保持されています.パラメータを持たない層は,chainer.functionsモジュール以下に用意されています.これらに簡単にアクセスするために,
import chainer.links as L
import chainer.functions as F
と別名を与えて,L.Convolution2D(...)F.relu(...)のように用いる慣習がありますが,特にこれが決まった書き方というわけではありません.
Chain
Chainは,パラメータを持つ層(Link)をまとめておくためのクラスです.パラメータを持つということは,基本的にネットワークの学習の際にそれらを更新していく必要があるということです(更新されないパラメータを持たせることもできます).Chainerでは,モデルのパラメータの更新は,Optimizerという機能が担います.その際,更新すべき全てのパラメータを簡単に発見できるように,Chainで一箇所にまとめておきます.
同じ結果を保証する
ネットワークを書き始める際に乱数シードを固定すると,本記事とほぼ同様の結果が再現できるようになります.(cuDNNが有効になっている環境下でより厳密に計算結果の再現性を保証したい場合は,chainer.config.cudnn_deterministicというConfiguringオプションについて知る必要があります.こちらのドキュメントを参照してください:chainer.config.cudnn_deterministic
python
import random
import numpy
import chainer

def reset_seed(seed=0):
    random.seed(seed)
    numpy.random.seed(seed)
    if chainer.cuda.available:
        chainer.cuda.cupy.random.seed(seed)
        
reset_seed(0)
Chainを継承したネットワークの定義
Chainerでは,ネットワークは Chain クラスを継承したクラスとして定義されることが一般的です. Chain を継承することで,中間層のユニット数=100,出力ユニット数=10とした3層の多層パーセプトロンは以下のように書くことができます.
python
import chainer
import chainer.links as L
import chainer.functions as F

class MLP(chainer.Chain):

    def __init__(self, n_mid_units=100, n_out=10):
        super(MLP, self).__init__()
        
        # パラメータを持つ層の登録
        with self.init_scope():
            self.l1 = L.Linear(None, n_mid_units)
            self.l2 = L.Linear(n_mid_units, n_mid_units)
            self.l3 = L.Linear(n_mid_units, n_out)

    def forward(self, x):
        # データを受け取った際のforward計算を書く
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)

gpu_id = 0  # CPUを用いる場合は,この値を-1にしてください

net = MLP()

if gpu_id >= 0:
    net.to_gpu(gpu_id)
継承した MLP クラスのコンストラクタ内で with self.init_scope() が呼ばれており,その中でネットワークに登場するLink (具体的には,全結合層の L.Linear )が定義されています.このような形で記述することで,Optimizer はこれらが最適化対象となるパラメータを持つ層であると自動的に解釈してくれるようになります.
また, forward というメソッドには,関数の名前の通り,ネットワークの順伝播を記述します.forward の引数としてデータ x を受け取り,出力として順伝播の計算結果を返すようにすることで, MLP クラスをインスタンス化して作成されたオブジェクトを,関数のように使えるようになります.(例:output = net(data)
Chainerには数多くの FunctionLink が用意されています.ぜひ一度以下の一覧のページを見てみてください.
Linkには,ニューラルネットワークによく用いられる全結合層や畳み込み層,LSTMなどに加えて,ResNetや,VGGなどの有名なネットワーク構造も登録されています.また,Functionには,ReLUなどの活性化関数や,画像の大きさをresizeする関数,サイン・コサインのような関数を始め,ネットワークの要素として使える関数が登録されています.Define-by-Runでは,データをネットワークに入力して順伝播計算を行ったあとに,データに適用された関数(パラメータあり・なし両方)の履歴をたどり直すことで,バックプロパゲーションによる勾配計算を行うパスを取得するため,パラメータを持たない関数であっても chainer.functions に含まれているものを繋げて用いる必要があります.
GPUで実行するには
深層学習で用いられるような多くのパラメータを持ったネットワークの学習には,GPUを用いることが一般的となっています.GPUを使うと,行列演算などの一部の処理をCPUに比べとても高速に行うことができます.Chainerで計算をGPUで行う方法は簡単です.Chainクラスはto_gpuメソッドを持ち,この引数にGPU IDを指定すると,指定したGPU IDのメモリ上にネットワークの全パラメータを転送します.こうしておくと,順伝播も学習の際のパラメータ更新なども全てGPU上で行われるようになります.GPU IDとして-1を指定すると,CPUを使用します.
入力側ユニット数の自動計算
上のネットワーク定義で,最初のLinear層は第一引数にNoneが渡されています.このように引数を指定すると,データが最初にその層に入力されたタイミングで,自動的に必要な数の入力側のユニット数を判断し, n_input ×\times n_mid_units の大きさの行列を作成し,学習対象パラメータとして保持します.これは後々,畳み込み層を全結合層の前に配置する際などに便利な機能となるため,覚えておいてください.

最適化手法の選択

それでは,上で定義したネットワークをMNISTデータセットを使って訓練してみましょう.学習時に用いる最適化の手法は数多く提案されていますが,Chainerは多くの手法を同一のインターフェースで利用できるよう,Optimizerという機能でそれらを提供しています.chainer.optimizersモジュール以下に定義されています.一覧はこちらにあります:
ここでは最もシンプルな勾配降下法の手法であるoptimizers.SGDを用います.Optimizerのオブジェクトには,setupメソッドを使ってモデル(Chainオブジェクト)を渡します.こうすることでOptimizerに,何を最適化すればいいか把握させることができます.
他にもいろいろな最適化手法が手軽に試せるので,色々と試してみて結果の変化を見てみてください.例えば,下のchainer.optimizers.SGDのうちSGDの部分をMomentumSGD, RMSprop, Adamなどに変えるだけで,最適化手法の違いがどのような学習曲線(ロスカーブとも言う.目的関数の値のプロットのこと)の違いを生むかなどを簡単に調べることができます.最適化の手法によっては,人が与える必要があった学習率を適切に自動決定するものもあります.
python
from chainer import optimizers

optimizer = optimizers.SGD(lr=0.01).setup(net)
学習率(learning rate)
今回はSGDのlrという引数に 0.010.01 を与えました.この値は学習率として知られ,モデルをうまく訓練して良いパフォーマンスを発揮させるために調整する必要がある重要なハイパーパラメータとして知られています.ハイパーパラメータは学習されるパラメータとは異なり人が手で与える学習の設定に関するものやネットワークの構造に関するもののことを指します.

学習の開始

今回は0〜9の数字を区別する分類問題なので,softmax_cross_entropyという損失関数を使って最小化すべき損失を計算します.Softmax関数は,dd次元のベクトルyRd{\bf y} \in \mathbb{R}^dが与えられたとき,その各次元の値の合計が1になるように正規化することができます.すなわち,確率分布のような出力を任意の実数ベクトルから作ることができます.y{\bf y}ii番目の次元をyiy_iと書くと,Softmax関数は
pi=exp(yi)j=1dexp(yj) p_i = \frac{\exp(y_i)}{\sum_{j=1}^d \exp(y_j)}
と表せます.これによって正規化された出力ベクトルを入力が各クラスに所属する確率を表しているものと考え,正解の1-hotベクトルとの間で前章で説明した交差エントロピーを計算するのが softmax_cross_entropy 関数です.
まずネットワークにデータを渡し,順伝播により予測値を計算します.そして,この予測値と入力データに対応する正解ラベルを損失関数に渡して損失(最小化したい値)を計算をします.損失は,chainer.Variableのオブジェクトとして得られます.このVariableは,過去の計算の履歴を覚えていて,辿れるようになっています.この仕組みが,Define-by-Run [Tokui 2015]とよばれる発明の中心的な役割を果たしています.
計算した損失に対する勾配をネットワークに逆向きに計算していく処理は,Chainerではネットワークが出力したVariableから,backwardメソッドを呼ぶだけで実現できます.これを呼ぶことで,誤差逆伝播用の計算グラフを構築し,途中のパラメータの勾配を連鎖率を使って計算してくれます.(詳しくは日本ソフトウェア科学会におけるチュートリアルの資料をご覧ください.)
最後に,計算された各パラメータに対する勾配を用いて,Optimizerによってネットワークパラメータの更新(=学習)が行われます.
まとめると,一連の更新処理の中で行われるのは,以下の4項目となります.
  1. ネットワークにデータを渡して順伝播を計算し,出力yを得る
  2. 出力yと正解ラベルtを使って,最小化すべき損失をsoftmax_cross_entropy関数で計算する
  3. softmax_cross_entropy関数の出力(Variable)のbackwardメソッドを呼んで,ネットワークの全てのパラメータの勾配を誤差逆伝播法で計算する
  4. Optimizerのupdateメソッドを呼び,3.で計算した勾配を使って全パラメータを更新する
パラメータの更新は,上記ステップを繰り返すことで行われます.一度のパラメータ更新に用いられるデータは,ネットワークに入力された,ミニバッチとして束ねられたデータのみです.次々と新しいミニバッチを入力し,上記のステップを繰り返すことで,データセット全体を用いて学習を行います.この過程を学習ループと呼んでいます.
目的関数
目的関数として,例えば分類問題ではなく回帰問題を解きたいような場合,F.softmax_cross_entropyの代わりにF.mean_squared_errorなどを用いることもできます.他にも,いろいろな問題設定に対応するために様々な損失関数がChainerには用意されています.こちらからその一覧を見ることができます:
学習ループのコード
python
import numpy as np
from chainer.dataset import concat_examples
from chainer.cuda import to_cpu

max_epoch = 10

while train_iter.epoch < max_epoch:
    
    # ---------- 学習の1イテレーション ----------
    train_batch = train_iter.next()
    x, t = concat_examples(train_batch, gpu_id)
    
    # 予測値の計算
    y = net(x)

    # 損失の計算
    loss = F.softmax_cross_entropy(y, t)

    # 勾配の計算
    net.cleargrads()
    loss.backward()

    # パラメータの更新
    optimizer.update()
    # --------------- ここまで ----------------

    # 1エポック終了ごとにValidationデータに対する予測精度を測って,
    # モデルの汎化性能が向上していることをチェックしよう
    if train_iter.is_new_epoch:  # 1 epochが終わったら

        # 損失の表示
        print('epoch:{:02d} train_loss:{:.4f} '.format(
            train_iter.epoch, float(to_cpu(loss.data))), end='')

        valid_losses = []
        valid_accuracies = []
        while True:
            valid_batch = valid_iter.next()
            x_valid, t_valid = concat_examples(valid_batch, gpu_id)

            # Validationデータをforward
            with chainer.using_config('train', False), \
                    chainer.using_config('enable_backprop', False):
                y_valid = net(x_valid)

            # 損失を計算
            loss_valid = F.softmax_cross_entropy(y_valid, t_valid)
            valid_losses.append(to_cpu(loss_valid.array))

            # 精度を計算
            accuracy = F.accuracy(y_valid, t_valid)
            accuracy.to_cpu()
            valid_accuracies.append(accuracy.array)
                        
            if valid_iter.is_new_epoch:
                valid_iter.reset()
                break

        print('val_loss:{:.4f} val_accuracy:{:.4f}'.format(
            np.mean(valid_losses), np.mean(valid_accuracies)))
        
# テストデータでの評価
test_accuracies = []
while True:
    test_batch = test_iter.next()
    x_test, t_test = concat_examples(test_batch, gpu_id)

    # テストデータをforward
    with chainer.using_config('train', False), \
            chainer.using_config('enable_backprop', False):
        y_test = net(x_test)

    # 精度を計算
    accuracy = F.accuracy(y_test, t_test)
    accuracy.to_cpu()
    test_accuracies.append(accuracy.array)

    if test_iter.is_new_epoch:
        test_iter.reset()
        break

print('test_accuracy:{:.4f}'.format(np.mean(test_accuracies)))
epoch:01 train_loss:0.9100 val_loss:0.9743 val_accuracy:0.8018 epoch:02 train_loss:0.5396 val_loss:0.5336 val_accuracy:0.8645 epoch:03 train_loss:0.4012 val_loss:0.4230 val_accuracy:0.8847 epoch:04 train_loss:0.3329 val_loss:0.3741 val_accuracy:0.8941 epoch:05 train_loss:0.4588 val_loss:0.3455 val_accuracy:0.9002 epoch:06 train_loss:0.2481 val_loss:0.3274 val_accuracy:0.9074 epoch:07 train_loss:0.3306 val_loss:0.3109 val_accuracy:0.9118 epoch:08 train_loss:0.3801 val_loss:0.2990 val_accuracy:0.9145 epoch:09 train_loss:0.2974 val_loss:0.2886 val_accuracy:0.9180 epoch:10 train_loss:0.3216 val_loss:0.2803 val_accuracy:0.9204 test_accuracy:0.9234
val_accuracyに着目してみると,最終的におよそ92%程度の精度で手書きの数字が分類できるようになりました.ここで言う精度とは,Validationデータセット中に NN 個のデータがあり分類結果が正しかったものが MM 個あるとすると M/NM/N を指します.学習中は,各ループの終わりに始めに取り分けておいたValidationデータセットを使って精度をはかることで,モデルの汎化性能をチェックしています.汎化性能とは,主に未知のデータに対する性能の高さのことを意味します.学習終了後には,テスト用のデータセットを用いて,学習が完了したネットワークの評価を行います.テストデータでの評価結果は,およそ92.37%の正解率となりました.
ValidationやTestを行う際の注意点
学習終了後の最終的な評価には,ハイパーパラメータ調整などにも用いられるValidationデータセットとはさらに別のTestデータセットを用います.TestデータセットはTrainingデータセットともValidationデータセットともデータの重複がないように用意しておきます.
さて,これまでは主に,「学習」のやり方について説明してきましたが,「評価」を行う際には注意すべき点があります.なぜなら,一部の関数や,計算過程において,学習時と評価時でその挙動が異なるためです.以下では,それらの挙動の違いを制御するための方法について説明します.
chainer.using_config('train', False)
先程の例では,学習時と推論時で動作が異なる関数は含まれていませんでしたが,Validationやテストのために推論を行うときは以下のように,chainer.using_config('train', False)をwith構文と共に使うことで,その中では対応する関数が推論モードとして動作することになります.これによって,学習時と推論時で挙動が異なる関数などが正しく推論のための動作をするようになります(例えば,Dropoutなど).詳しくはこちらの train の項をお読みください:Configuration Keys
python
with chainer.using_config('train', False):
    --- 何か推論処理 ---
chainer.using_config('enable_backprop', False)
評価のみ行うことを考えた場合,出力の計算後に損失関数の各パラメータについての勾配の情報は不要なため,chainer.using_config('enable_backprop', False)とすることで,無駄な計算グラフの構築が行われず,メモリ消費量を節約することができます.詳しくはこちらの enable_backprop の項をお読みください:Configuration Keys
ChainerのConfig
Chainerにはこの他にも,いくつかのグローバルなConfigが用意されています.また,chainer.config以下にユーザが自由な設定値を置くこともできます.詳しくはこちらを一読してください:Configuring Chainer

学習済みモデルの保存

学習が終了後,その結果を保存します.Chainerには,2種類のフォーマットで学習済みネットワークをファイルに保存する機能が用意されています.一つはHDF5形式,もう一つはNumPyのNPZ形式で,ネットワークを保存します.今回は,追加ライブラリのインストールが必要なHDF5ではなく,NumPy標準機能で提供されているシリアライズ機能(numpy.savez())を利用したNPZ形式でのモデルの保存を行います.
python
from chainer import serializers

serializers.save_npz('my_mnist.model', net)
python
# 保存されていることを確認
%ls -la my_mnist.model
-rw-r--r-- 1 root root 334084 Dec 9 11:13 my_mnist.model

保存したモデルを読み込んで推論

学習が終了して保存したモデルを読み込み,推論を行う方法について説明します.はじめに,学習に利用したネットワークを再度インスタンス化して,そこにさきほど保存したNPZファイルを読み込ませます.
python
# まず同じネットワークのオブジェクトを作る
infer_net = MLP()

# そのオブジェクトに保存済みパラメータをロードする
serializers.load_npz('my_mnist.model', infer_net)
以上で準備が整いました.それでは,試しにテストデータの中から一つ目の画像を取ってきて,それに対する分類を行ってみましょう.
python
gpu_id = 0  # CPUで計算をしたい場合は,-1を指定してください

if gpu_id >= 0:
    infer_net.to_gpu(gpu_id)

# 1つ目のテストデータを取り出します
x, t = test[0]  #  tは使わない

# どんな画像か表示してみます
plt.imshow(x.reshape(28, 28), cmap='gray')
plt.show()

# ミニバッチの形にする(複数の画像をまとめて推論に使いたい場合は,サイズnのミニバッチにしてまとめればよい)
print('元の形:', x.shape, end=' -> ')

x = x[None, ...]

print('ミニバッチの形にしたあと:', x.shape)

# ネットワークと同じデバイス上にデータを送る
x = infer_net.xp.asarray(x)

# モデルのforward関数に渡す
with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
    y = infer_net(x)

# Variable形式で出てくるので中身を取り出す
y = y.array

# 結果をCPUに送る
y = to_cpu(y)

# 予測確率の最大値のインデックスを見る
pred_label = y.argmax(axis=1)

print('ネットワークの予測:', pred_label[0])
png
元の形: (784,) -> ミニバッチの形にしたあと: (1, 784)
ネットワークの予測: 7
ネットワークの予測は7でした.画像を見る限り,正しく予測できていることが確認できます.

Trainerの使用方法

Chainerは,これまで書いてきたような学習ループを隠蔽するTrainerという機能を提供しています.これを使うと,学習ループを自ら書く必要がなくなり,また便利な拡張機能(Extention)を使うことで,学習過程での学習曲線の可視化や,ログの保存なども簡単に行うことができます.

データセット・Iterator・ネットワークの準備

データセット,Iterator,ネットワークは,Trainerを使用する場合にも同様に準備します.
python
reset_seed(0)

train_val, test = mnist.get_mnist()
train, valid = split_dataset_random(train_val, 50000, seed=0)

batchsize = 128

train_iter = iterators.SerialIterator(train, batchsize)
valid_iter = iterators.SerialIterator(valid, batchsize, False, False)
test_iter = iterators.SerialIterator(test, batchsize, False, False)

gpu_id = 0  # CPUを用いたい場合は,-1を指定してください

net = MLP()

if gpu_id >= 0:
    net.to_gpu(gpu_id)

Updaterの準備

学習ループを自分で書く場合の学習ステップについて再度確認すると,「データセットからミニバッチを作成」「ネットワークに入力して予測を出力」「正解と比較し誤差を計算」「バックワード(誤差逆伝播)を実行」「Optimizerによってパラメータを更新」という一連のステップを,以下のように書いていました.
python
# ---------- 学習の1イテレーション ----------
train_batch = train_iter.next()
x, t = concat_examples(train_batch, gpu_id)

# 予測値の計算
y = net(x)

# 損失の計算
loss = F.softmax_cross_entropy(y, t)

# 勾配の計算
net.cleargrads()
loss.backward()

# パラメータの更新
optimizer.update()
Chainerの機能として提供されているUpdaterを用いることで,これらの一連の処理を簡単に書けるようになります.UpdaterにはIteratorOptimizerを渡します. Iteratorはデータセットオブジェクトを持っているため,そこからミニバッチを作成します.Optimizerは最適化対象のネットワークを持っているため,それを使って順伝播と誤差計算・パラメータのアップデートをすることができます.従って,この2つを渡すことで,Updater内で全ての処理が完結します.さっそく,Updaterオブジェクトを作成してみましょう.
python
from chainer import training

gpu_id = 0  # CPUを使いたい場合は-1を指定してください

# ネットワークをClassifierで包んで,損失の計算などをモデルに含める
net = L.Classifier(net)

# 最適化手法の選択
optimizer = optimizers.SGD(lr=0.01).setup(net)

# UpdaterにIteratorとOptimizerを渡す
updater = training.StandardUpdater(train_iter, optimizer, device=gpu_id)
損失計算のためのChain
ここでは,ネットワークをL.Classifierで包んでいます.L.Classifierは,渡されたネットワーク自体をpredictorというattributeに持ち,損失計算を行う機能を追加してくれます.こうすることで,net()はデータxだけでなくラベルtも取るようになり,受け取ったデータをpredictorに通して予測値を計算し,正解ラベルtと比較して**損失のVariableを返します.**損失関数として何を用いるかはデフォルトではF.softmax_cross_entropyとなっていますが,L.Classifierの引数lossfunに損失計算を行う関数を渡してやれば変更することができ,(Classifierという名前ながら)回帰問題などの損失計算機能の追加にも使うことができます.(L.Classifier(net, lossfun=L.mean_squared_error, compute_accuracy=False)のようにする)
StandardUpdaterは前述のようなUpdaterの担当する処理を遂行するための最もシンプルなクラスです.この他にも複数のGPUを用いるためのParallelUpdaterなどが用意されています.

Trainerの準備

実際に学習ループ部分を隠蔽しているのはUpdaterですが,TrainerはさらにUpdaterを受け取って学習全体の管理を行う機能を提供しています.例えば,データセットを何周したら学習を終了するか(stop_trigger) や,途中の損失の値をどのファイルに保存したいか学習曲線を可視化した画像ファイルを保存するかどうかなど,学習全体の設定として必須・もしくはあると便利な色々な機能を提供しています.
必須なものとしては学習終了のタイミングを指定するstop_triggerがありますが,これはTrainerオブジェクトを作成するときのコンストラクタで指定します.指定の方法は単純で,(長さ, 単位)という形のタプルを与えればよいだけです.「長さ」には数字を,「単位」には'iteration'もしくは'epoch'のいずれかの文字列を指定します.こうすると,たとえば100 epoch(データセット100周)で学習を終了してください,とか,1000 iteration(1000回更新)で学習を終了してください,といったことが指定できます.Trainerを作るときに,stop_triggerを指定しないと,学習は自動的には止まりません.
では,実際にTrainerオブジェクトを作ってみましょう.
python
max_epoch = 10

# TrainerにUpdaterを渡す
trainer = training.Trainer(
    updater, (max_epoch, 'epoch'), out='results/mnist_result')
out引数では,この次に説明するExtensionを使って,ログファイルや損失の変化の過程を描画したグラフの画像ファイルなどを保存するディレクトリを指定しています.
Trainerと,その内側にあるいろいろなオブジェクトの関係は,図にまとめると以下のようになっています.このイメージを持っておくと自分で部分的に改造したりする際に便利だと思います.
Trainerに関連するオブジェクト間の関係図

TrainerにExtensionを追加

Trainerを使う利点として,
  • ログを自動的にファイルに保存(LogReport)
  • ターミナルに定期的に損失などの情報を表示(PrintReport
  • 損失を定期的にグラフで可視化して画像として保存(PlotReport)
  • 定期的にモデルやOptimizerの状態を自動シリアライズ(snapshot
  • 学習の進捗を示すプログレスバーを表示(ProgressBar
  • ネットワークの構造をGraphvizのdot形式で保存(dump_graph
  • ネットワークのパラメータの平均や分散などの統計情報を出力(ParameterStatistics
などの様々な便利な機能を簡単に利用することができる点があります.これらの機能を利用するには,Trainerオブジェクトに対してextendメソッドを使って追加したいExtensionのオブジェクトを渡すだけです.では実際に幾つかのExtensionを追加してみましょう.
python
from chainer.training import extensions

trainer.extend(extensions.LogReport())
trainer.extend(extensions.snapshot(filename='snapshot_epoch-{.updater.epoch}'))
trainer.extend(extensions.Evaluator(valid_iter, net, device=gpu_id), name='val')
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy', 'val/main/loss', 'val/main/accuracy', 'l1/W/data/std', 'elapsed_time']))
trainer.extend(extensions.ParameterStatistics(net.predictor.l1, {'std': np.std}))
trainer.extend(extensions.PlotReport(['l1/W/data/std'], x_key='epoch', file_name='std.png'))
trainer.extend(extensions.PlotReport(['main/loss', 'val/main/loss'], x_key='epoch', file_name='loss.png'))
trainer.extend(extensions.PlotReport(['main/accuracy', 'val/main/accuracy'], x_key='epoch', file_name='accuracy.png'))
trainer.extend(extensions.dump_graph('main/loss'))
LogReport
epochiterationごとのloss, accuracyなどを自動的に集計し,logというファイル名で保存します.
snapshot
Trainerオブジェクトを指定されたタイミング(デフォルトでは1エポックごと)で保存します.Trainerオブジェクトは上述のようにUpdaterを持っており,この中にOptimizerとモデルが保持されているため,このExtensionでスナップショットをとっておけば,その時点から学習を再開させたり,学習済みモデルを使った推論などが可能になります.
dump_graph
指定されたVariableオブジェクトから辿れる計算グラフをGraphvizのdot形式で保存します.
Evaluator
評価用のデータセットのIteratorと,学習に使うモデルのオブジェクトを渡しておくことで,学習中のモデルを指定されたタイミングで評価用データセットを用いて評価します.内部では,chainer.config.using_config('train', False)が自動的に行われます.backprop_enableFalseにすることは行われないため,メモリ使用効率はデフォルトでは最適ではありませんが,基本的にはEvaluatorを使えば評価を行えるという点において問題はありません.
PrintReport
LogReportと同様に集計された値を標準出力に出力します.この際,どの値を出力するかをリストの形で与えます.
PlotReport
引数のリストで指定された値の変遷をmatplotlibライブラリを使ってグラフに描画し,出力ディレクトリにfile_name引数で指定されたファイル名で画像として保存します.
ParameterStatistics
指定したレイヤ(Link)が持つパラメータの平均・分散・最小値・最大値などなどの統計情報を計算して,ログに保存します.パラメータが発散していないかなどをチェックするのに便利です.

これらのExtensionは,ここで紹介した以外にも,例えばtriggerによって個別に作動するタイミングを指定できるなどのいくつかのオプションを持っており,より柔軟に組み合わせることができます.詳しくは公式のドキュメントを見てください.

学習の開始 (Trainer利用)

学習を開始するために,Trainerオブジェクトのメソッドrunを実行してください.
python
trainer.run()
epoch main/loss main/accuracy val/main/loss val/main/accuracy l1/W/data/std elapsed_time 1 1.66917 0.599904 0.938911 0.806764 0.0359232 4.00874
2 0.673337 0.843211 0.519283 0.86699 0.0366054 6.59789
3 0.459913 0.878686 0.414855 0.887757 0.0370351 9.17739
4 0.38953 0.893262 0.370488 0.896855 0.037301 11.7382
5 0.353165 0.901215 0.342328 0.90447 0.03749 14.4128
6 0.33014 0.90609 0.32212 0.90981 0.037639 17.1353
7 0.312328 0.910906 0.30679 0.913172 0.0377671 19.834
8 0.298127 0.914704 0.295095 0.915744 0.0378811 22.4303
9 0.28583 0.917659 0.284156 0.918513 0.0379864 25.0918
10 0.275227 0.921096 0.274761 0.921677 0.0380852 27.7848
学習ループを自分で書いた場合よりも遥かに簡単に,同様の結果を得ることができました.さらに,Extensionの機能を利用することで,様々なスコアや,学習曲線の可視化も自動で出力されます.
では,保存されている損失のグラフを確認してみましょう.
python
from IPython.display import Image
Image(filename='results/mnist_result/loss.png')
png
精度のグラフも見てみましょう.
python
Image(filename='results/mnist_result/accuracy.png')
png
もう少し学習を続ければ,さらに精度の向上が期待できそうです.
最後に,dump_graphというExtensionによって出力された計算グラフのファイルを,Graphvizで画像化してみましょう.
python
!dot -Tpng results/mnist_result/cg.dot -o results/mnist_result/cg.png
python
Image(filename='results/mnist_result/cg.png')
png
データやパラメータが関数に次々と渡され,損失が出力されるまでの一連の計算過程が確認できます.

テストデータ評価

Validationデータに対する評価を学習中に行うために使用されるEvaluatorは,Trainerと関係なく独立して使うこともできます.以下のようにしてIteratorとネットワークのオブジェクト(net),使用するデバイスIDを渡してEvaluatorオブジェクトを作成し,これを関数として実行するだけです.
python
test_evaluator = extensions.Evaluator(test_iter, net, device=gpu_id)
results = test_evaluator()
print('Test accuracy:', results['main/accuracy'])
Test accuracy: 0.9250395

学習済みモデルで推論する

それでは,Trainer Extensionのsnapshotが保存した学習済みパラメータを読み込んで,以前と同様に1番目のテストデータで推論を行ってみましょう.
ここで一点注意が必要ですが,snapshotが保存するnpzファイルはTrainer全体のスナップショットとなっており,学習の再開に必要となるextensionの内部のパラメータなども一緒に保存されています.しかし,今回はネットワークのパラメータだけを読み込めば良いので, serializers.load_npz()path引数にネットワーク部分までのパスを指定します.こうすることで,ネットワークのオブジェクトにパラメータだけを読み込ませることができます.
python
reset_seed(0)

infer_net = MLP()
serializers.load_npz(
    'results/mnist_result/snapshot_epoch-10',
    infer_net, path='updater/model:main/predictor/')

if gpu_id >= 0:
    infer_net.to_gpu(gpu_id)

x, t = test[0]
plt.imshow(x.reshape(28, 28), cmap='gray')
plt.show()

x = infer_net.xp.asarray(x[None, ...])
with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
    y = infer_net(x)
y = to_cpu(y.array)

print('予測ラベル:', y.argmax(axis=1)[0])
png
予測ラベル: 7
無事正解できていることが確認できました.

新しいネットワークの利用

ここでは,MNISTデータセットではなくCIFAR10という32x32サイズの小さなカラー画像に10クラスのいずれかのラベルがついたデータセットを用いて,いろいろなモデルを自分で書いて試行錯誤する流れを体験してみます.
airplaneautomobilebirdcatdeerdogfroghorseshiptruck
AirplaneAutomobileBirdCatDeerDogFrogHorseShipTruck

新しいネットワークの定義

ここでは,さきほど試した全結合層だけからなるネットワークではなく,前章で紹介した,畳込み層を持つネットワークを定義してみます.3つの畳み込み層を持ち,2つの全結合層がそのあとに続いています.
python
class MyNet(chainer.Chain):
    
    def __init__(self, n_out):
        super(MyNet, self).__init__()
        with self.init_scope():
            self.conv1 = L.Convolution2D(None, 32, 3, 3, 1)
            self.conv2 = L.Convolution2D(32, 64, 3, 3, 1)
            self.conv3 = L.Convolution2D(64, 128, 3, 3, 1)
            self.fc4 = L.Linear(None, 1000)
            self.fc5 = L.Linear(1000, n_out)
        
    def forward(self, x):
        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        h = F.relu(self.conv3(h))
        h = F.relu(self.fc4(h))
        h = self.fc5(h)
        return h

学習

ここで,あとから別のネットワークも簡単に同じ設定で訓練できるよう,train関数を作っておきます.これは,
  • ネットワークのオブジェクト
  • バッチサイズ
  • 使用するGPU ID
  • 学習を終了するエポック数
  • データセットオブジェクト
  • 学習率の初期値
  • 学習率減衰のタイミング
などを渡すと,内部でTrainerを用いて渡されたデータセットを使ってネットワークを訓練し,学習が終了した状態のネットワークを返してくれる関数です.Trainer.run()が終了した後に,テストデータセットを使って評価まで行ってくれます.先程のMNISTでの例と違い,最適化手法にはMomentumSGDを用い,ExponentialShiftというExtentionを使って,指定したタイミングごとに学習率を減衰させるようにしてみます.
また,ここではcifar.get_cifar10()が返す学習用データセットのうち9割のデータをtrain,残りの1割をvalidとして使うようにしています.
このtrain関数を用いて,上で定義したMyNetモデルを訓練してみます.
python
from chainer.datasets import cifar


def train(network_object, batchsize=128, gpu_id=0, max_epoch=20, train_dataset=None, valid_dataset=None, test_dataset=None, postfix='', base_lr=0.01, lr_decay=None, snapshot=None):

    # 1. Dataset
    if train_dataset is None and valid_dataset is None and test_dataset is None:
        train_val, test = cifar.get_cifar10()
        train_size = int(len(train_val) * 0.9)
        train, valid = split_dataset_random(train_val, train_size, seed=0)
    else:
        train, valid, test = train_dataset, valid_dataset, test_dataset

    # 2. Iterator
    train_iter = iterators.MultiprocessIterator(train, batchsize)
    valid_iter = iterators.MultiprocessIterator(valid, batchsize, False, False)

    # 3. Model
    net = L.Classifier(network_object)

    # 4. Optimizer
    optimizer = optimizers.MomentumSGD(lr=base_lr).setup(net)
    optimizer.add_hook(chainer.optimizer.WeightDecay(0.0005))

    # 5. Updater
    updater = training.StandardUpdater(train_iter, optimizer, device=gpu_id)

    # 6. Trainer
    trainer = training.Trainer(updater, (max_epoch, 'epoch'), out='results/{}_cifar10_{}result'.format(network_object.__class__.__name__, postfix))
    
    # 7. Trainer extensions
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.observe_lr())
    trainer.extend(extensions.snapshot(filename='snapshot_epoch_{.updater.epoch}'), trigger=(10, 'epoch'))
    trainer.extend(extensions.Evaluator(valid_iter, net, device=gpu_id), name='val')
    trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy', 'val/main/loss', 'val/main/accuracy', 'elapsed_time', 'lr']))
    trainer.extend(extensions.PlotReport(['main/loss', 'val/main/loss'], x_key='epoch', file_name='loss.png'))
    trainer.extend(extensions.PlotReport(['main/accuracy', 'val/main/accuracy'], x_key='epoch', file_name='accuracy.png'))
    if lr_decay is not None:
        trainer.extend(extensions.ExponentialShift('lr', 0.1), trigger=lr_decay)
    if snapshot is not None:
        chainer.serializers.load_npz(snapshot, trainer)
    trainer.run()
    del trainer
    
    # 8. Evaluation
    test_iter = iterators.MultiprocessIterator(test, batchsize, False, False)
    test_evaluator = extensions.Evaluator(test_iter, net, device=gpu_id)
    results = test_evaluator()
    print('Test accuracy:', results['main/accuracy'])
    
    return net
python
net = train(MyNet(10), gpu_id=0)
epoch       main/loss   main/accuracy  val/main/loss  val/main/accuracy  elapsed_time  lr        
1           1.92583     0.305065       1.72466        0.39668            4.88937       0.01        
2           1.60857     0.423007       1.53026        0.463281           9.25879       0.01        
3           1.47209     0.46964        1.48127        0.478125           13.5662       0.01        
4           1.39223     0.499911       1.39299        0.499609           18.0876       0.01        
5           1.32882     0.526197       1.3789         0.511719           22.5673       0.01        
6           1.26765     0.547852       1.35271        0.516406           27.1432       0.01        
7           1.21327     0.568999       1.25582        0.560547           31.8979       0.01        
8           1.16433     0.583984       1.22899        0.570508           36.4486       0.01        
9           1.12036     0.602384       1.23554        0.565039           40.9875       0.01        
10          1.07057     0.61899        1.21995        0.56543            45.5839       0.01        
11          1.02992     0.636808       1.1724         0.585938           50.4524       0.01        
12          0.98116     0.653112       1.19605        0.579883           55.0429       0.01        
13          0.938254    0.667392       1.159          0.59375            59.4494       0.01        
14          0.901819    0.681067       1.20838        0.579492           64.0684       0.01        
15          0.855333    0.698287       1.19485        0.585938           68.5982       0.01        
16          0.810262    0.714321       1.19381        0.583984           73.0674       0.01        
17          0.764117    0.731423       1.21938        0.587109           77.5318       0.01        
18          0.72205     0.743697       1.20823        0.585742           81.9437       0.01        
19          0.666414    0.764712       1.23899        0.593164           86.2922       0.01        
20          0.620457    0.782715       1.24922        0.597461           90.6681       0.01        
Test accuracy: 0.6065071
学習が20エポックまで終わりました.損失と精度のプロットを見てみましょう.
python
Image(filename='results/MyNet_cifar10_result/loss.png')
png
python
Image(filename='results/MyNet_cifar10_result/accuracy.png')
png
学習データでの精度(main/accuracy)は77%程度まで到達していますが,テストデータでの損失(val/main/loss)は途中から下げ止まり,精度(val/main/accuracy)も60%前後で頭打ちになってしまっています.表示されたログの最後の行を確認すると,テストデータでの精度も同様に60%程度となっています.学習データでは精度が良いが, テストデータでは精度が良くない場合,モデルが学習データにオーバーフィッティングしていると考えられます.

学習済みネットワークを使った予測

テスト精度は60%程度でしたが,試しにこの学習済みネットワークを使っていくつかのテスト画像を分類させてみましょう.あとで使いまわせるようにpredict関数を作っておきます.
python
cls_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
             'dog', 'frog', 'horse', 'ship', 'truck']

def predict(net, image_id):
    _, test = cifar.get_cifar10()
    x, t = test[image_id]
    net.to_cpu()
    with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
        y = net.predictor(x[None, ...]).data.argmax(axis=1)[0]

    plt.imshow(x.transpose(1, 2, 0))
    plt.show()
    print('predicted_label:', cls_names[y])
    print('answer:', cls_names[t])

for i in range(10, 15):
    predict(net, i)
png
predicted_label: airplane
answer: airplane



png
predicted_label: truck
answer: truck



png
predicted_label: dog
answer: dog



png
predicted_label: horse
answer: horse



png
predicted_label: truck
answer: truck
うまく分類できているものもあれば,そうでないものもありました.ネットワークの学習に使用したデータセット上ではほぼ百発百中で正解できても,未知のデータ,すなわちテストデータセットの画像に対して高精度な予測ができなければ意味がありません.テストデータでの精度は,モデルの汎化性能に関係していると言われています.
どうすれば高い汎化性能を持つネットワークを設計し,学習することができるでしょうか?これは非常に難しい問いですが,機械学習を使った応用を考えるとき,最も重要な問いの一つです.

深いネットワークの定義

では,さきほどのネットワークよりも多層のネットワークを定義してみましょう.ここでは,1層の畳み込みネットワークをConvBlock,1層の全結合ネットワークをLinearBlockとして定義し,これを数多く積み重ねることで大きなネットワークを定義してみます.
構成要素を定義する
まず,ネットワークの構成要素となるConvBlockLinearBlockを定義してみましょう.
python
class ConvBlock(chainer.Chain):
    
    def __init__(self, n_ch, pool_drop=False):
        w = chainer.initializers.HeNormal()
        super(ConvBlock, self).__init__()
        with self.init_scope():
            self.conv = L.Convolution2D(None, n_ch, 3, 1, 1, nobias=True, initialW=w)
            self.bn = L.BatchNormalization(n_ch)
        self.pool_drop = pool_drop
        
    def forward(self, x):
        h = F.relu(self.bn(self.conv(x)))
        if self.pool_drop:
            h = F.max_pooling_2d(h, 2, 2)
            h = F.dropout(h, ratio=0.25)
        return h
    
class LinearBlock(chainer.Chain):
    
    def __init__(self, drop=False):
        w = chainer.initializers.HeNormal()
        super(LinearBlock, self).__init__()
        with self.init_scope():
            self.fc = L.Linear(None, 1024, initialW=w)
        self.drop = drop
        
    def forward(self, x):
        h = F.relu(self.fc(x))
        if self.drop:
            h = F.dropout(h)
        return h
ConvBlockChainを継承した小さなネットワークとして定義されており,一つの畳み込み層とBatch Normalization層で構成されます.Batch Normalization層は,ネットワークの学習プロセスを安定させるために広く利用されている手法の一つで,例えば今回のように,畳み込み層の直後に挿入する形で利用されます.forwardメソッドでは,これらにデータを渡しつつ,活性化関数ReLUを適用して,さらにpool_drop引数がTrueであれば,Max PoolingとDropoutを適用するような順伝播の計算が行われます.Dropoutは,ネットワークの過学習を避けて汎化性能を上げる目的で利用される手法の一つで,層の中のノードのうち,一定割合(dropout ratioと呼ばれる)をランダムに無効にしながら学習を行います(無効にする割合はratioという引数で指定でき,何も指定しなければ50%が無効化されます).推論時は,dropout ratioをppとすると,Dropout層への入力をただpp倍して出力するだけの層として働きます.これによって,擬似的に複数のネットワークの学習結果をアンサンブル(参考:Ensemble averaging)するような効果があると言われ,汎化性能が向上する場合があります.最適化の際にモデルのパラメータに何らかの制約を与えて汎化性能を向上させるための工夫は正則化(regularization)と呼ばれ,このDropoutやパラメータの絶対値が大きくなりすぎないようにするWeight decayなどの方法が知られています.
Chainerでは,Pythonを使って書いたforward計算のコード自体がネットワークの構造を表します.すなわち,実行時にデータがどの層を通過していったか,によってネットワークそのものが定義されます.この性質によって,上記のような分岐などを含むネットワークも簡単に記述でき,柔軟かつシンプルで可読性の高いネットワーク定義が可能になります.これがDefine-by-Runの大きな特徴となっています.
大きなネットワークの定義
次に,これらの小さなネットワークを構成要素として積み重ねて,大きなネットワークを定義してみましょう.
python
class DeepCNN(chainer.ChainList):

    def __init__(self, n_output):
        super(DeepCNN, self).__init__(
            ConvBlock(64),
            ConvBlock(64, True),
            ConvBlock(128),
            ConvBlock(128, True),
            ConvBlock(256),
            ConvBlock(256),
            ConvBlock(256),
            ConvBlock(256, True),
            LinearBlock(),
            LinearBlock(),
            L.Linear(None, n_output)
        )

    def forward(self, x):
        for f in self.children():
            x = f(x)
        return x
ここで,ChainListというクラスが利用されています.このクラスはChainを継承したクラスで,いくつものLinkChainを順次呼び出していくようなネットワークを定義するときに便利です.ChainListを継承して定義されるモデルは,親クラスのコンストラクタを呼び出す際に,キーワード引数ではなく通常の引数としてLinkもしくはChainオブジェクトを渡すことができ,self.children()メソッドによって登録した順番に取り出すことができます.この特徴を使うと,forward計算が上記のように簡単に記述可能となります.
高速化のTIPS
今回は多くの畳込み層を使う大きなネットワークを使うので,Chainerが用意してくれているcuDNNのautotune機能を有効にしてみます.やり方は簡単で,以下の二行を事前に実行しておくだけです.これを有効にすると,cuDNNが自動的に高速な畳み込みのアルゴリズムを選択するなどの実行時の調整を行ってくれるようになります.
python
chainer.cuda.set_max_workspace_size(1024 * 1024 * 1024)
chainer.config.autotune = True
それでは,学習を回してみます.今回はパラメータ数も多いので,学習を停止するエポック数を100に設定します.また,学習率を0.1から始めて,30エポックごとに10分の1にするように設定します.
本来は,以下の2行を実行することで乱数シードを固定し,100エポック分上で定義した DeepCNN というクラスが表すモデルの学習ができるのですが,これは40分以上の時間を要するので,今回は事前に90エポックまで学習を進めておいた重みを読み込んで,90エポック終了時点から学習を再開し,最後の10エポックだけ実際にここで学習を回すことにします.
python
!wget https://github.com/japan-medical-ai/medical-ai-course-materials/releases/download/v0.1/DeepCNN_cifar10_snapshot_epoch_90.npz

reset_seed(0)

model = train(DeepCNN(10), max_epoch=100, base_lr=0.1, lr_decay=(30, 'epoch'), snapshot='DeepCNN_cifar10_snapshot_epoch_90.npz')
--2019-12-09 11:16:00-- https://github.com/japan-medical-ai/medical-ai-course-materials/releases/download/v0.1/DeepCNN_cifar10_snapshot_epoch_90.npz Resolving github.com (github.com)... 13.250.177.223 Connecting to github.com (github.com)|13.250.177.223|:443... connected. HTTP request sent, awaiting response... 302 Found Location: https://github-production-release-asset-2e65be.s3.amazonaws.com/153412006/4fcc1200-eeb7-11e8-8ca0-9095e5bca078?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20191209%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20191209T111601Z&X-Amz-Expires=300&X-Amz-Signature=f2d9063e8a6ec0edd8567b4ae8caeebbd09f2d1e6eabc59db80bcd2c519f2787&X-Amz-SignedHeaders=host&actor_id=0&response-content-disposition=attachment%3B%20filename%3DDeepCNN_cifar10_snapshot_epoch_90.npz&response-content-type=application%2Foctet-stream [following] --2019-12-09 11:16:01-- https://github-production-release-asset-2e65be.s3.amazonaws.com/153412006/4fcc1200-eeb7-11e8-8ca0-9095e5bca078?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20191209%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20191209T111601Z&X-Amz-Expires=300&X-Amz-Signature=f2d9063e8a6ec0edd8567b4ae8caeebbd09f2d1e6eabc59db80bcd2c519f2787&X-Amz-SignedHeaders=host&actor_id=0&response-content-disposition=attachment%3B%20filename%3DDeepCNN_cifar10_snapshot_epoch_90.npz&response-content-type=application%2Foctet-stream Resolving github-production-release-asset-2e65be.s3.amazonaws.com (github-production-release-asset-2e65be.s3.amazonaws.com)... 52.216.164.75 Connecting to github-production-release-asset-2e65be.s3.amazonaws.com (github-production-release-asset-2e65be.s3.amazonaws.com)|52.216.164.75|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 56889445 (54M) [application/octet-stream] Saving to: ‘DeepCNN_cifar10_snapshot_epoch_90.npz’
DeepCNN_cifar10_sna 100%[===================>]  54.25M  12.5MB/s    in 5.4s    

2019-12-09 11:16:07 (10.1 MB/s) - ‘DeepCNN_cifar10_snapshot_epoch_90.npz’ saved [56889445/56889445]



/usr/local/lib/python3.6/dist-packages/chainer/training/triggers/interval_trigger.py:92: UserWarning: The previous value of iteration is not saved. IntervalTrigger guesses it using current iteration. If this trigger is not called at every iteration, it may not work correctly.
  'The previous value of iteration is not saved. '
/usr/local/lib/python3.6/dist-packages/chainer/training/triggers/interval_trigger.py:104: UserWarning: The previous value of epoch_detail is not saved. IntervalTrigger uses the value of trainer.updater.previous_epoch_detail. If this trigger is not called at every iteration, it may not work correctly.
  'The previous value of epoch_detail is not saved. '


epoch       main/loss   main/accuracy  val/main/loss  val/main/accuracy  elapsed_time  lr        
1           2.62849     0.144931       2.22525        0.15625            27.9748       0.1         
2           2.11316     0.210804       1.97533        0.266406           54.5483       0.1         
3           1.87483     0.289396       1.8026         0.320508           81.0827       0.1         
4           1.74066     0.340443       1.74728        0.358203           107.7         0.1         
5           1.58789     0.409411       1.63916        0.40332            134.235       0.1         
6           1.38757     0.492831       1.22399        0.561719           161.093       0.1         
7           1.2036      0.566128       1.32939        0.553906           187.55        0.1         
8           1.07527     0.617077       1.17079        0.589648           214.184       0.1         
9           0.964984    0.660711       0.950063       0.67207            240.817       0.1         
10          0.895905    0.688056       0.972697       0.657031           267.372       0.1         
11          0.828796    0.715043       0.944733       0.686914           297.779       0.1         
12          0.784123    0.731793       0.960205       0.687305           324.312       0.1         
13          0.742033    0.748291       0.858308       0.719727           351.032       0.1         
14          0.692516    0.76627        0.811853       0.725195           377.591       0.1         
15          0.65644     0.776256       0.692374       0.767578           404.23        0.1         
16          0.650682    0.780738       0.796731       0.733008           430.797       0.1         
17          0.610249    0.793435       0.632162       0.791406           457.467       0.1         
18          0.591896    0.803667       0.791757       0.739453           484.04        0.1         
19          0.578718    0.804465       1.04659        0.671484           510.602       0.1         
20          0.554954    0.814431       0.8127         0.733594           537.247       0.1         
21          0.549188    0.814236       0.654926       0.787109           567.567       0.1         
22          0.535866    0.820446       0.640323       0.788672           594.214       0.1         
23          0.527765    0.823028       0.958373       0.70957            620.761       0.1         
24          0.512286    0.830056       0.793664       0.748633           647.363       0.1         
25          0.497195    0.833141       0.717548       0.759961           673.951       0.1         
26          0.495153    0.835759       1.7557         0.511328           700.494       0.1         
27          0.486062    0.837069       0.732518       0.775391           727.13        0.1         
28          0.47963     0.839387       0.669157       0.786328           753.692       0.1         
29          0.475502    0.838312       0.915904       0.718555           780.308       0.1         
30          0.460236    0.844841       0.877713       0.72793            806.806       0.1         
31          0.298483    0.897239       0.381921       0.879687           837.16        0.01        
32          0.211182    0.92784        0.364046       0.882617           863.898       0.01        
33          0.180429    0.937478       0.374651       0.883984           890.554       0.01        
34          0.164156    0.943692       0.361041       0.888867           917.199       0.01        
35          0.144584    0.950387       0.375391       0.889258           943.774       0.01        
36          0.132288    0.954235       0.377427       0.890625           970.4         0.01        
37          0.12103     0.957376       0.390434       0.892578           996.96        0.01        
38          0.111974    0.961204       0.400307       0.886133           1023.62       0.01        
39          0.102573    0.964476       0.399275       0.892773           1050.2        0.01        
40          0.0972647   0.965931       0.432854       0.887109           1076.83       0.01        
41          0.0928545   0.966597       0.418165       0.887305           1107.19       0.01        
42          0.08498     0.96964        0.432462       0.884961           1133.75       0.01        
43          0.0845448   0.970792       0.4365         0.880273           1160.37       0.01        
44          0.0770674   0.973914       0.441935       0.883789           1187.02       0.01        
45          0.0732439   0.974565       0.469901       0.881836           1213.62       0.01        
46          0.070491    0.975962       0.453856       0.8875             1240.16       0.01        
47          0.068846    0.97583        0.461264       0.881055           1266.75       0.01        
48          0.0694101   0.976941       0.435111       0.885742           1293.25       0.01        
49          0.0653772   0.977117       0.461284       0.882617           1319.87       0.01        
50          0.0633419   0.97836        0.464232       0.889648           1346.47       0.01        
51          0.0584663   0.979834       0.464193       0.885547           1376.6        0.01        
52          0.0607617   0.979714       0.466352       0.881445           1403.21       0.01        
53          0.0615791   0.978632       0.451807       0.889453           1429.69       0.01        
54          0.0588031   0.979914       0.489054       0.881836           1456.28       0.01        
55          0.0582368   0.979367       0.4719         0.882617           1482.83       0.01        
56          0.0558719   0.981379       0.495846       0.878125           1509.45       0.01        
57          0.0579962   0.979936       0.472415       0.875781           1536.06       0.01        
58          0.0592009   0.9793         0.454762       0.88418            1562.6        0.01        
59          0.0568546   0.980735       0.487556       0.876172           1589.17       0.01        
60          0.0579785   0.980079       0.472908       0.883398           1615.93       0.01        
61          0.0318918   0.98968        0.416953       0.895703           1646.15       0.001       
62          0.0220127   0.993612       0.416859       0.899609           1672.69       0.001       
63          0.0186169   0.99434        0.417849       0.898242           1699.29       0.001       
64          0.0159041   0.995526       0.41804        0.900781           1725.78       0.001       
65          0.0147089   0.995916       0.429896       0.899609           1752.39       0.001       
66          0.0129457   0.996404       0.433748       0.898828           1779.01       0.001       
67          0.0131643   0.996283       0.433923       0.898828           1805.54       0.001       
68          0.0112659   0.996893       0.437222       0.901758           1832.17       0.001       
69          0.0106502   0.997151       0.443475       0.901758           1858.73       0.001       
70          0.0107926   0.997203       0.445066       0.900391           1885.34       0.001       
71          0.0105973   0.997062       0.44159        0.898633           1915.8        0.001       
72          0.00934292  0.99767        0.450084       0.897266           1942.45       0.001       
73          0.0104884   0.997092       0.451691       0.899023           1969.06       0.001       
74          0.00849317  0.997707       0.450391       0.9                1995.61       0.001       
75          0.00846932  0.997891       0.451362       0.902148           2022.21       0.001       
76          0.00826699  0.997841       0.448779       0.900781           2048.69       0.001       
77          0.00875069  0.997492       0.45095        0.900391           2075.31       0.001       
78          0.00823296  0.998019       0.449194       0.898438           2102.1        0.001       
79          0.00701245  0.998113       0.454196       0.899609           2128.74       0.001       
80          0.00846517  0.997596       0.455877       0.901172           2155.29       0.001       
81          0.00677115  0.99818        0.459518       0.899805           2185.54       0.001       
82          0.00717393  0.998047       0.465337       0.899805           2212.13       0.001       
83          0.00709802  0.997908       0.464472       0.898828           2238.69       0.001       
84          0.00699702  0.998091       0.470595       0.899219           2265.27       0.001       
85          0.00746627  0.997975       0.470063       0.901172           2291.78       0.001       
86          0.00666763  0.998069       0.466201       0.899414           2318.38       0.001       
87          0.00616266  0.998531       0.462948       0.9                2344.86       0.001       
88          0.00719447  0.997847       0.463587       0.899609           2371.48       0.001       
89          0.00638493  0.998335       0.465655       0.901367           2398.08       0.001       
90          0.0061445   0.998286       0.464918       0.900586           2424.59       0.001       
91          0.00582547  0.99838        0.46484        0.900977           2439.32       0.0001      
92          0.00602776  0.998331       0.461773       0.901367           2452.62       0.0001      
93          0.00597372  0.998491       0.464172       0.900195           2466.17       0.0001      
94          0.00619668  0.99813        0.46446        0.900391           2479.48       0.0001      
95          0.00545569  0.998557       0.466654       0.900977           2492.64       0.0001      
96          0.00613322  0.998308       0.465335       0.900781           2505.78       0.0001      
97          0.0054181   0.998624       0.465642       0.900586           2518.9        0.0001      
98          0.00512285  0.998801       0.467116       0.900781           2532.13       0.0001      
99          0.00562234  0.998576       0.464966       0.901367           2545.35       0.0001      
100         0.00551726  0.998624       0.462725       0.901367           2558.63       0.0001      
Test accuracy: 0.8966574
ゼロから学習する場合:
python
reset_seed(0)

model = train(DeepCNN(10), max_epoch=100, base_lr=0.1, lr_decay=(30, 'epoch'))
学習が終了しました.学習曲線と精度のグラフを見てみましょう.
python
Image(filename='results/DeepCNN_cifar10_result/loss.png')
png
python
Image(filename='results/DeepCNN_cifar10_result/accuracy.png')
png
先程の浅い(層数の少ない)CNNを用いた際には60%前後だったValidationデータでの精度が,90%程度まで上がりました.また,テストデータを用いた精度も,およそ90%程度となっています.しかし最新の研究成果では97%以上まで達成されています.さらに精度を上げるには,今回行ったようなネットワークの構造自体の改良ももちろんのこと,学習データを擬似的に増やす操作(Data augmentation)や,複数のモデルの出力を一つの出力に統合する操作(Ensemble)などなど,いろいろな工夫が考えられます.

データセットクラスの使用方法

ここでは,Chainerにすでに用意されているCIFAR10のデータを取得する機能を使って,データセットクラスを自分で書いてみます.Chainerでは,データセットを表すクラスは以下の機能を持っている必要があります.
  • データセット内のデータ数を返す__len__メソッド
  • 引数として渡されるiに対応したデータもしくはデータとラベルの組を返すget_exampleメソッド
その他のデータセットに必要な機能は,chainer.dataset.DatasetMixinクラスを継承することで用意できます.ここでは,DatasetMixinクラスを継承し,学習時に学習データに変換を施してモデルが受け取るデータのバリエーションを増やすData augmentation機能のついたデータセットクラスを作成してみましょう.

CIFAR10データセットクラス

python
class CIFAR10Augmented(chainer.dataset.DatasetMixin):

    def __init__(self, split='train', train_ratio=0.9):
        train_val, test_data = cifar.get_cifar10()
        train_size = int(len(train_val) * train_ratio)
        train_data, valid_data = split_dataset_random(train_val, train_size, seed=0)
        if split == 'train':
            self.data = train_data
        elif split == 'valid':
            self.data = valid_data
        elif split == 'test':
            self.data = test_data
        else:
            raise ValueError("'split' argument should be either 'train', 'valid', or 'test'. But {} was given.".format(split))

        self.split = split
        self.random_crop = 4

    def __len__(self):
        return len(self.data)

    def get_example(self, i):
        x, t = self.data[i]
        if self.split == 'train':
            x = x.transpose(1, 2, 0)
            h, w, _ = x.shape
            x_offset = np.random.randint(self.random_crop)
            y_offset = np.random.randint(self.random_crop)
            x = x[y_offset:y_offset + h - self.random_crop,
                  x_offset:x_offset + w - self.random_crop]
            if np.random.rand() > 0.5:
                x = np.fliplr(x)
            x = x.transpose(2, 0, 1)

        return x, t
このクラスは,CIFAR10のデータのそれぞれに対し,
  • 32x32の大きさの中からランダムに28x28の領域をクロップ
  • 1/2の確率で左右を反転させる
という加工を行っています.このような操作を加えて擬似的に学習データのバリエーションを増やすことで,オーバーフィッティングの抑制などに寄与することが知られています.これらの操作以外にも,画像の色味を変化させるような変換やランダムな回転,アフィン変換など,さまざまな加工によって学習データ数を擬似的に増やす方法が提案されています.

作成したデータセットクラスを用いた学習

それではさっそくこのCIFAR10クラスを使って学習を行ってみましょう.先程と同じネットワークを用い,Data augmentationの効果がどの程度あるのかを調べてみましょう.train関数も含め,データセットクラス以外は先程とすべて同様です.
ここでも,40分ほどの時間がかかりますので,上と同様に90エポックまで学習したあとのsnapshotをダウンロードして読み込ませ,最後の10エポックだけ実際に学習させてみましょう.
python
!wget https://github.com/japan-medical-ai/medical-ai-course-materials/releases/download/v0.1/DeepCNN_cifar10_augmented_snapshot_epoch_90.npz

reset_seed(0)

model = train(DeepCNN(10), max_epoch=100, train_dataset=CIFAR10Augmented(), valid_dataset=CIFAR10Augmented('valid'), test_dataset=CIFAR10Augmented('test'), postfix='augmented_', base_lr=0.1, lr_decay=(30, 'epoch'), snapshot='DeepCNN_cifar10_augmented_snapshot_epoch_90.npz')
--2019-12-09 11:18:31-- https://github.com/japan-medical-ai/medical-ai-course-materials/releases/download/v0.1/DeepCNN_cifar10_augmented_snapshot_epoch_90.npz Resolving github.com (github.com)... 52.74.223.119 Connecting to github.com (github.com)|52.74.223.119|:443... connected. HTTP request sent, awaiting response... 302 Found Location: https://github-production-release-asset-2e65be.s3.amazonaws.com/153412006/5064a880-eeb7-11e8-95bf-80b5d9533256?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20191209%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20191209T111832Z&X-Amz-Expires=300&X-Amz-Signature=1dc193a84a383234eb92fd9c7cb0db88c9ad7dad53e4e9a32d06f33a3825b832&X-Amz-SignedHeaders=host&actor_id=0&response-content-disposition=attachment%3B%20filename%3DDeepCNN_cifar10_augmented_snapshot_epoch_90.npz&response-content-type=application%2Foctet-stream [following] --2019-12-09 11:18:32-- https://github-production-release-asset-2e65be.s3.amazonaws.com/153412006/5064a880-eeb7-11e8-95bf-80b5d9533256?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20191209%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20191209T111832Z&X-Amz-Expires=300&X-Amz-Signature=1dc193a84a383234eb92fd9c7cb0db88c9ad7dad53e4e9a32d06f33a3825b832&X-Amz-SignedHeaders=host&actor_id=0&response-content-disposition=attachment%3B%20filename%3DDeepCNN_cifar10_augmented_snapshot_epoch_90.npz&response-content-type=application%2Foctet-stream Resolving github-production-release-asset-2e65be.s3.amazonaws.com (github-production-release-asset-2e65be.s3.amazonaws.com)... 52.216.241.116 Connecting to github-production-release-asset-2e65be.s3.amazonaws.com (github-production-release-asset-2e65be.s3.amazonaws.com)|52.216.241.116|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 56730280 (54M) [application/octet-stream] Saving to: ‘DeepCNN_cifar10_augmented_snapshot_epoch_90.npz’
DeepCNN_cifar10_aug 100%[===================>]  54.10M  12.6MB/s    in 5.1s    

2019-12-09 11:18:38 (10.5 MB/s) - ‘DeepCNN_cifar10_augmented_snapshot_epoch_90.npz’ saved [56730280/56730280]



/usr/local/lib/python3.6/dist-packages/chainer/training/triggers/interval_trigger.py:92: UserWarning: The previous value of iteration is not saved. IntervalTrigger guesses it using current iteration. If this trigger is not called at every iteration, it may not work correctly.
  'The previous value of iteration is not saved. '
/usr/local/lib/python3.6/dist-packages/chainer/training/triggers/interval_trigger.py:104: UserWarning: The previous value of epoch_detail is not saved. IntervalTrigger uses the value of trainer.updater.previous_epoch_detail. If this trigger is not called at every iteration, it may not work correctly.
  'The previous value of epoch_detail is not saved. '


epoch       main/loss   main/accuracy  val/main/loss  val/main/accuracy  elapsed_time  lr        
1           2.5875      0.156405       2.11656        0.203125           24.1767       0.1         
2           1.99359     0.233842       1.84577        0.304492           47.9466       0.1         
3           1.76968     0.325365       1.98983        0.26543            71.6864       0.1         
4           1.61662     0.389537       2.07369        0.26875            95.5072       0.1         
5           1.41259     0.478989       1.52089        0.446484           119.236       0.1         
6           1.23382     0.555487       1.48775        0.480664           143.06        0.1         
7           1.09323     0.613404       1.11949        0.590234           166.813       0.1         
8           0.99303     0.650857       1.33017        0.566992           190.624       0.1         
9           0.926848    0.678755       1.0075         0.665234           214.446       0.1         
10          0.863165    0.702769       0.858035       0.717383           238.218       0.1         
11          0.807948    0.727162       0.9556         0.679297           265.757       0.1         
12          0.769224    0.739116       0.857765       0.711328           289.507       0.1         
13          0.739977    0.752952       0.91583        0.711133           313.339       0.1         
14          0.72064     0.759393       1.20587        0.61875            337.097       0.1         
15          0.690136    0.77093        0.837919       0.726562           360.905       0.1         
16          0.673935    0.774706       1.03539        0.678711           384.67        0.1         
17          0.662879    0.778742       0.730712       0.758789           408.47        0.1         
18          0.639202    0.78742        0.758566       0.765625           432.298       0.1         
19          0.625988    0.792713       1.24791        0.664062           456.061       0.1         
20          0.616269    0.795277       0.963706       0.70625            479.882       0.1         
21          0.611734    0.795962       0.887129       0.723437           507.312       0.1         
22          0.600444    0.800582       0.889526       0.710352           531.173       0.1         
23          0.605317    0.800414       0.715702       0.756445           554.926       0.1         
24          0.584194    0.805731       0.984225       0.694336           578.687       0.1         
25          0.584041    0.805464       0.956576       0.685156           602.49        0.1         
26          0.57384     0.80954        0.977559       0.712695           627.031       0.1         
27          0.560405    0.814298       0.894127       0.718945           650.845       0.1         
28          0.559933    0.816195       0.729981       0.752734           674.584       0.1         
29          0.555933    0.814387       0.841304       0.727344           698.4         0.1         
30          0.558057    0.814971       0.753542       0.757227           722.119       0.1         
31          0.397613    0.866455       0.337977       0.888867           749.619       0.01        
32          0.320024    0.890202       0.322082       0.894336           773.368       0.01        
33          0.293655    0.900479       0.323365       0.888867           797.169       0.01        
34          0.279553    0.904874       0.308263       0.897656           820.951       0.01        
35          0.26519     0.909945       0.301763       0.897852           844.682       0.01        
36          0.259166    0.909846       0.286909       0.904688           868.482       0.01        
37          0.246168    0.915064       0.289997       0.904688           892.324       0.01        
38          0.24097     0.91697        0.280986       0.903906           916.139       0.01        
39          0.233951    0.918892       0.291962       0.904102           939.884       0.01        
40          0.220939    0.923628       0.299502       0.902539           963.666       0.01        
41          0.216055    0.924272       0.2946         0.905664           991.125       0.01        
42          0.215143    0.926972       0.308637       0.897266           1014.88       0.01        
43          0.213903    0.926625       0.291742       0.907812           1038.64       0.01        
44          0.203283    0.929198       0.296043       0.905469           1062.33       0.01        
45          0.19772     0.931041       0.327708       0.89375            1086.11       0.01        
46          0.194907    0.932893       0.312555       0.901563           1109.82       0.01        
47          0.190354    0.934326       0.336271       0.895703           1133.58       0.01        
48          0.190431    0.932915       0.326305       0.902539           1157.3        0.01        
49          0.18926     0.934326       0.312767       0.901172           1181.05       0.01        
50          0.184469    0.936035       0.296937       0.907812           1204.93       0.01        
51          0.181691    0.936521       0.324149       0.901172           1232.22       0.01        
52          0.17546     0.939675       0.347524       0.893945           1256          0.01        
53          0.175786    0.937723       0.335627       0.893164           1279.74       0.01        
54          0.173612    0.940274       0.317897       0.902344           1303.52       0.01        
55          0.171849    0.939548       0.306998       0.90625            1327.23       0.01        
56          0.168304    0.94165        0.31145        0.902148           1351.04       0.01        
57          0.170139    0.941495       0.301311       0.910156           1374.82       0.01        
58          0.165011    0.942708       0.359516       0.892773           1398.52       0.01        
59          0.163968    0.94256        0.365818       0.886133           1422.3        0.01        
60          0.16541     0.942575       0.357          0.890234           1446.01       0.01        
61          0.129435    0.955988       0.277052       0.915234           1473.42       0.001       
62          0.101981    0.965434       0.284798       0.916406           1497.13       0.001       
63          0.0953637   0.967285       0.279956       0.919727           1520.93       0.001       
64          0.0911066   0.968171       0.28204        0.918359           1544.63       0.001       
65          0.0853851   0.97037        0.28504        0.918945           1568.41       0.001       
66          0.0800331   0.972101       0.287688       0.917969           1592.21       0.001       
67          0.0761374   0.973202       0.29148        0.920117           1616.19       0.001       
68          0.0756613   0.973699       0.299635       0.918945           1639.98       0.001       
69          0.075577    0.97407        0.293845       0.918359           1663.68       0.001       
70          0.0730666   0.974676       0.29563        0.920508           1687.47       0.001       
71          0.070825    0.975516       0.295581       0.920313           1714.73       0.001       
72          0.0710753   0.975697       0.29838        0.919336           1738.52       0.001       
73          0.0705982   0.975142       0.298369       0.920508           1762.3        0.001       
74          0.0667571   0.976562       0.299809       0.920508           1786.01       0.001       
75          0.0642319   0.978427       0.300881       0.920898           1809.77       0.001       
76          0.0640179   0.977742       0.304647       0.918359           1833.51       0.001       
77          0.0629752   0.977761       0.299763       0.919336           1857.31       0.001       
78          0.0586612   0.979523       0.306034       0.922461           1881.03       0.001       
79          0.059752    0.979869       0.311227       0.921289           1904.82       0.001       
80          0.0571715   0.980213       0.304607       0.920703           1928.51       0.001       
81          0.0573339   0.980136       0.315108       0.92168            1956.06       0.001       
82          0.0560348   0.979847       0.321934       0.916992           1979.87       0.001       
83          0.0553193   0.980613       0.315378       0.914648           2003.58       0.001       
84          0.0531816   0.98129        0.318977       0.919531           2027.34       0.001       
85          0.0560367   0.98097        0.310993       0.919141           2051.02       0.001       
86          0.0535048   0.981534       0.317829       0.920117           2075.04       0.001       
87          0.0522188   0.981571       0.31144        0.920313           2098.73       0.001       
88          0.0526632   0.982156       0.318594       0.920703           2122.46       0.001       
89          0.0528096   0.981445       0.309017       0.92207            2146.21       0.001       
90          0.0499371   0.982928       0.313269       0.920508           2169.92       0.001       
91          0.046129    0.984375       0.312747       0.919141           2182.89       0.0001      
92          0.0442634   0.984642       0.308682       0.921484           2195.39       0.0001      
93          0.0456881   0.98422        0.308501       0.920898           2208.13       0.0001      
94          0.0450576   0.984553       0.311052       0.921484           2220.46       0.0001      
95          0.0450287   0.985152       0.310078       0.920117           2232.59       0.0001      
96          0.0444907   0.984486       0.312837       0.921289           2244.59       0.0001      
97          0.0439379   0.985418       0.310588       0.92168            2256.65       0.0001      
98          0.0430186   0.985352       0.310459       0.920508           2268.82       0.0001      
99          0.0419429   0.985288       0.310179       0.92168            2281.02       0.0001      
100         0.0421573   0.985574       0.31397        0.920898           2293.27       0.0001      
Test accuracy: 0.917227
先程のData augmentationなしの場合は90%程度だったテスト精度が,学習データにaugmentationを施すことでおよそ1.8%程度向上していることが分かりました.
損失と精度のグラフを見てみましょう.
python
Image(filename='results/DeepCNN_cifar10_augmented_result/loss.png')
png
python
Image(filename='results/DeepCNN_cifar10_augmented_result/accuracy.png')
png

Data Augmentationの簡単な使い方

前述のようにデータセット内の各画像についていろいろな変換を行って擬似的にデータを増やすような操作をData Augmentationといいます.上では,オリジナルのデータセットクラスを作る方法を示すために変換の操作もget_example()内に書くという実装を行いましたが,実はもっと簡単にいろいろな変換をデータに対して行う方法があります.
それは,TransformDatasetクラスを使う方法です.TransformDatasetは,元になるデータセットオブジェクトと,そこからサンプルしてきた各データ点に対して行いたい変換を関数の形で与えると,変換済みのデータを返してくれるようなデータセットオブジェクトに加工してくれる便利なクラスです.簡単な使い方は以下のようになります.
python
from chainer.datasets import TransformDataset

train_val, test_dataset = cifar.get_cifar10()
train_size = int(len(train_val) * 0.9)
train_dataset, valid_dataset = split_dataset_random(train_val, train_size, seed=0)


# 行いたい変換を関数の形で書く
def transform(inputs):
    x, t = inputs
    x = x.transpose(1, 2, 0)
    h, w, _ = x.shape
    x_offset = np.random.randint(4)
    y_offset = np.random.randint(4)
    x = x[y_offset:y_offset + h - 4,
          x_offset:x_offset + w - 4]
    if np.random.rand() > 0.5:
        x = np.fliplr(x)
    x = x.transpose(2, 0, 1)
    
    return x, t


# 各データをtransform関数で処理して返すデータセットオブジェクト
train_dataset = TransformDataset(train_dataset, transform)
このようにして得られた新しいtrain_datasetは,自作のデータセットクラスと同じような変換処理を行った上でデータを返してくれるデータセットオブジェクトとなります.

ChainerCVを活用した変換処理

さて,先ほどご紹介したコードでは,画像に対するランダムクロップ,及びランダムな左右反転の処理を自ら実装していました.もし,より多様な変換を行いたい場合,上記のtransform関数に処理を追加していくことになりますが,一般的に用いられる変換処理をその度に自ら実装するのは手間です.そこで本項では最後に,ChainerCV[Niitani 2017]をご紹介します.ChainerCVは,Computer Visionに特化した機能が豊富に追加された,Chainerの補助パッケージとしての役割を担うオープンソース・ソフトウェアです.
python
!pip install chainercv
Collecting chainercv [?25l Downloading https://files.pythonhosted.org/packages/e8/1c/1f267ccf5ebdf1f63f1812fa0d2d0e6e35f0d08f63d2dcdb1351b0e77d85/chainercv-0.13.1.tar.gz (260kB)  |████████████████████████████████| 266kB 35.5MB/s [?25hRequirement already satisfied: chainer>=6.0 in /usr/local/lib/python3.6/dist-packages (from chainercv) (6.5.0) Requirement already satisfied: Pillow in /usr/local/lib/python3.6/dist-packages (from chainercv) (4.3.0) Requirement already satisfied: protobuf>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from chainer>=6.0->chainercv) (3.10.0) Requirement already satisfied: six>=1.9.0 in /usr/local/lib/python3.6/dist-packages (from chainer>=6.0->chainercv) (1.12.0) Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from chainer>=6.0->chainercv) (3.0.12) Requirement already satisfied: typing-extensions<=3.6.6 in /usr/local/lib/python3.6/dist-packages (from chainer>=6.0->chainercv) (3.6.6) Requirement already satisfied: typing<=3.6.6 in /usr/local/lib/python3.6/dist-packages (from chainer>=6.0->chainercv) (3.6.6) Requirement already satisfied: numpy>=1.9.0 in /usr/local/lib/python3.6/dist-packages (from chainer>=6.0->chainercv) (1.17.4) Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from chainer>=6.0->chainercv) (42.0.1) Requirement already satisfied: olefile in /usr/local/lib/python3.6/dist-packages (from Pillow->chainercv) (0.46) Building wheels for collected packages: chainercv Building wheel for chainercv (setup.py) ... [?25l[?25hdone Created wheel for chainercv: filename=chainercv-0.13.1-cp36-cp36m-linux_x86_64.whl size=537355 sha256=fc0aaac281e6d6effbb0699ea77604eacc16f6683691d9463803c54c192a746b Stored in directory: /root/.cache/pip/wheels/ea/10/01/e221beaa4b3d8341aa819a39ab8d4677457c79c81f521f3a94 Successfully built chainercv Installing collected packages: chainercv Successfully installed chainercv-0.13.1
ChainerCVには,画像に対する様々な変換があらかじめ用意されています.
例えば,上でNumPyを使って書いていたランダムクロップやランダム左右反転は,chainercv.transformsモジュールを使うと,それぞれ以下のように1行で書くことができます:
python
x = chainercv.transforms.random_crop(x, (28, 28))  # ランダムクロップ
x = chainercv.transforms.random_flip(x)  # ランダム左右反転
chainercv.transformsモジュールを使って,transform関数をアップデートしてみましょう.ちなみに,get_cifar10()で得られるデータセットでは,デフォルトで画像の画素値の範囲が[0, 1]にスケールされています.しかし,get_cifar10()scale=255.を渡しておくと,値の範囲をもともとの[0, 255]のままにできます.今回行われる処理は,以下の5つです:
  1. PCA lighting: 先行研究(AlexNet)の学習で使われていた方法で,色を変化させる変換処理を行います.
  2. Standardization: 訓練用データセット全体からチャンネルごとの画素値の平均・標準偏差を求めて標準化をします
  3. Random flip: ランダムに画像の左右を反転します
  4. Random expand: [1, 1.5]からランダムに決めた大きさの黒いキャンバスを作り,その中のランダムな位置へ画像を配置します
  5. Random crop: (28, 28)の大きさの領域をランダムにクロップします
python
from functools import partial
from chainercv import transforms

train_val, test_dataset = cifar.get_cifar10(scale=255.)
train_size = int(len(train_val) * 0.9)
train_dataset, valid_dataset = split_dataset_random(train_val, train_size, seed=0)

mean = np.mean([x for x, _ in train_dataset], axis=(0, 2, 3))
std = np.std([x for x, _ in train_dataset], axis=(0, 2, 3))


def transform(inputs, mean, std, train=True):
    img, label = inputs
    img = img.copy()
    
    # Color augmentation
    if train:
        img = transforms.pca_lighting(img, 76.5)
        
    # Standardization
    img -= mean[:, None, None]
    img /= std[:, None, None]
    
    # Random flip & crop
    if train:
        img = transforms.random_flip(img, x_random=True)
        img = transforms.random_expand(img, max_ratio=1.5)
        img = transforms.random_crop(img, (28, 28))
        
    return img, label

train_dataset = TransformDataset(train_dataset, partial(transform, mean=mean, std=std, train=True))
valid_dataset = TransformDataset(valid_dataset, partial(transform, mean=mean, std=std, train=False))
test_dataset = TransformDataset(test_dataset, partial(transform, mean=mean, std=std, train=False))
では,standardizationとChainerCVによるPCA Lightingを追加したTransformDatasetを使って学習をしてみましょう.
これまでと同様,90エポックまで学習させておいたsnapshotを用いて,最後の10エポックだけ学習を行います.
python
!wget https://github.com/japan-medical-ai/medical-ai-course-materials/releases/download/v0.1/DeepCNN_cifar10_augmented2_snapshot_epoch_90.npz

reset_seed(0)

model = train(DeepCNN(10), max_epoch=100, train_dataset=train_dataset, valid_dataset=valid_dataset, test_dataset=test_dataset, postfix='augmented2_', base_lr=0.1, lr_decay=(30, 'epoch'), snapshot='DeepCNN_cifar10_augmented2_snapshot_epoch_90.npz')

Discussion

コメントにはログインが必要です。