Commutty IT

Python library (decision tree visualization): dtreeviz

https://cdn.magicode.io/media/notebox/5096276d-628a-42b1-8565-d134f6cd5d2e.jpeg
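This article introduces dtreeviz, a library for visualizing the decision tree models covered in our machine learning articles. It produces more intuitive charts than Graphviz. That makes it easier to see how a trained model reaches its predictions, so the model is less of a black box.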

1. Creating the sample data

python
!pip install graphviz
!pip install dtreeviz

Successfully installed graphviz-0.20
Successfully installed colour-0.1.5 dtreeviz-1.3.6 iniconfig-1.1.1 pluggy-1.0.0 py-1.11.0 pytest-7.1.2 tomli-2.0.1
python
import numpy as np
import pandas as pd
from sklearn import datasets
from sklearn.model_selection import train_test_split


iris = datasets.load_iris()  # Load the Iris dataset
data, target = iris.data, iris.target  # Separate the features from the labels
x_train, x_test, y_train, y_test = train_test_split(data, target, test_size=0.3, random_state=0)  # Split 7:3 into training and test data

print(x_train.dtype, x_test.dtype, y_train.dtype, y_test.dtype)  # Check the data types
print(x_train.shape, x_test.shape, y_train.shape, y_test.shape)  # Check the number of samples

float64 float64 int64 int64
(105, 4) (45, 4) (105,) (45,)
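The shapes (105, 4) and (45, 4) come from the way a fractional `test_size` is converted into sample counts. The sketch below assumes scikit-learn's rounding rule as we understand it: the test count is rounded up, and when `train_size` is not given the training set gets the remainder. `split_sizes` is a hypothetical helper and is not part of scikit-learn.

```python
import math


def split_sizes(n_samples, test_size):
    """Hypothetical helper: turn a fractional test_size into (n_train, n_test).

    Assumes the test count is rounded up and the training set takes the rest,
    which is our understanding of scikit-learn's behaviour when train_size is None.
    """
    if not 0.0 < test_size < 1.0:
        raise ValueError("test_size must be a fraction between 0 and 1")
    n_test = math.ceil(test_size * n_samples)
    n_train = n_samples - n_test
    return n_train, n_test


# Iris has 150 samples, split 7:3 as in the code above
print(split_sizes(150, 0.3))  # (105, 45)
```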

2. Building the decision tree model

python
from sklearn.tree import DecisionTreeClassifier

tree = DecisionTreeClassifier()  # Classification model (random_state is unset, so tied splits may differ between runs)
tree.fit(x_train, y_train)  # Train the model
y_pred = tree.predict(x_test)  # Predictions for the test data

print(tree.get_params())
print(y_pred)
print('Training score:', tree.score(x_train, y_train), 'Test score:', tree.score(x_test, y_test))

{'ccp_alpha': 0.0, 'class_weight': None, 'criterion': 'gini', 'max_depth': None, 'max_features': None, 'max_leaf_nodes': None, 'min_impurity_decrease': 0.0, 'min_impurity_split': None, 'min_samples_leaf': 1, 'min_samples_split': 2, 'min_weight_fraction_leaf': 0.0, 'presort': 'deprecated', 'random_state': None, 'splitter': 'best'}
[2 1 0 2 0 2 0 1 1 1 2 1 1 1 1 0 1 1 0 0 2 1 0 0 2 0 0 1 1 0 2 1 0 2 2 1 0 2 1 1 2 0 2 0 0]
Training score: 1.0 Test score: 0.9777777777777777
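`get_params()` reports `criterion='gini'`. That means each split is chosen to maximize the decrease in Gini impurity, defined as 1 minus the sum of squared class proportions. The following is a minimal pure-Python sketch of that quantity and of the weighted impurity decrease for a single split. `gini` and `split_gain` are hypothetical helper names. The 50/50/50 node is an illustrative example, not a node taken from the fitted tree.

```python
from collections import Counter


def gini(labels):
    """Hypothetical helper: Gini impurity = 1 - sum of squared class proportions."""
    n = len(labels)
    if n == 0:
        return 0.0
    counts = Counter(labels)
    return 1.0 - sum((c / n) ** 2 for c in counts.values())


def split_gain(parent, left, right):
    """Hypothetical helper: parent impurity minus the size-weighted child impurities."""
    n = len(parent)
    return gini(parent) - (len(left) / n) * gini(left) - (len(right) / n) * gini(right)


# Illustrative node holding 50 samples of each Iris class (0, 1, 2)
parent = [0] * 50 + [1] * 50 + [2] * 50
left = [0] * 50               # a split that isolates class 0 perfectly
right = [1] * 50 + [2] * 50
print(round(gini(parent), 4))                     # 0.6667
print(round(split_gain(parent, left, right), 4))  # 0.3333
```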

3. Feature importance analysis

python
import matplotlib.pyplot as plt

feature_names = iris.feature_names  # Feature names -> ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
importances = tree.feature_importances_  # Importance of each feature

print(importances)
plt.barh(feature_names, importances)  # Horizontal bar chart of the importances
plt.show()

[0.02150464 0.02150464 0.90006666 0.05692405]
(Figure: horizontal bar chart of the four feature importances)
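In scikit-learn, `feature_importances_` is the mean decrease in impurity. For every split node, the weighted drop in impurity is credited to the feature that node splits on, and the totals are then normalized to sum to 1. The sketch below recomputes this from tree arrays. `mdi_importances` is a hypothetical helper, and the five-node tree is made up for illustration. With the model from section 2, passing `tree.tree_.feature`, `tree.tree_.impurity`, `tree.tree_.weighted_n_node_samples`, `tree.tree_.children_left`, `tree.tree_.children_right` and `tree.tree_.n_features` should reproduce `tree.feature_importances_`.

```python
import numpy as np

TREE_LEAF = -1  # scikit-learn marks "no child" with -1 in children_left / children_right


def mdi_importances(feature, impurity, n_node_samples, children_left,
                    children_right, n_features):
    """Hypothetical helper: mean-decrease-in-impurity importances, normalized to sum to 1."""
    importances = np.zeros(n_features)
    for node, left in enumerate(children_left):
        if left == TREE_LEAF:  # leaves do not split, so they credit no feature
            continue
        right = children_right[node]
        # Weighted impurity of the parent minus that of its two children
        decrease = (n_node_samples[node] * impurity[node]
                    - n_node_samples[left] * impurity[left]
                    - n_node_samples[right] * impurity[right])
        importances[feature[node]] += decrease
    total = importances.sum()
    return importances / total if total > 0 else importances


# Made-up tree: node 0 splits on feature 2, node 2 splits on feature 3
feature = [2, -2, 3, -2, -2]
impurity = [2 / 3, 0.0, 0.5, 0.0, 0.0]
n_node_samples = [150, 50, 100, 50, 50]
children_left = [1, -1, 3, -1, -1]
children_right = [2, -1, 4, -1, -1]

print(mdi_importances(feature, impurity, n_node_samples,
                      children_left, children_right, n_features=4))
# roughly [0, 0, 0.5, 0.5]: both splits remove the same weighted impurity (50)
```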

4. Visualizing the tree structure with Graphviz

python
import graphviz
from sklearn.tree import export_graphviz

dot = export_graphviz(tree)  # Get the decision tree as DOT source (a string)
graph = graphviz.Source(dot)  # Render the DOT source
# print(dot)  # Prints the raw DOT source
graph  # Display the graph
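`export_graphviz` returns the tree as DOT source, which is a plain-text graph description that `graphviz.Source` renders. To give a rough idea of what that text looks like, the hypothetical helper `tree_to_dot` below writes a much simpler DOT string from the same kind of node arrays, using a made-up five-node tree. The real `export_graphviz` output also labels each node with details such as gini, samples and value.

```python
def tree_to_dot(feature, threshold, children_left, children_right, feature_names):
    """Hypothetical helper: emit a minimal DOT digraph for a binary decision tree."""
    lines = ["digraph Tree {", "node [shape=box] ;"]
    for node, left in enumerate(children_left):
        if left == -1:  # leaf node
            label = f"leaf {node}"
        else:  # split node: samples with value <= threshold go to the left child
            label = f"{feature_names[feature[node]]} <= {threshold[node]:.2f}"
        lines.append(f'{node} [label="{label}"] ;')
    for node, left in enumerate(children_left):
        if left != -1:
            lines.append(f"{node} -> {left} ;")
            lines.append(f"{node} -> {children_right[node]} ;")
    lines.append("}")
    return "\n".join(lines)


names = ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
dot_text = tree_to_dot(
    feature=[2, -2, 3, -2, -2],            # made-up tree, not the fitted model
    threshold=[2.45, -2.0, 1.75, -2.0, -2.0],
    children_left=[1, -1, 3, -1, -1],
    children_right=[2, -1, 4, -1, -1],
    feature_names=names,
)
print(dot_text)
```

You can pass the resulting string to `graphviz.Source(dot_text)` in the same way as above. Rendering it requires the Graphviz binaries to be installed.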
python
import graphviz
from sklearn.tree import export_graphviz

dot = export_graphviz(tree, filled=True, rounded=True,
                      class_names=['setosa', 'versicolor', 'virginica'],
                      feature_names=['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'],
                      out_file=None)

graph = graphviz.Source(dot)  # Render the DOT source
graph  # Display the graph
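When `class_names` is set, `export_graphviz` also labels every node with a class, namely the class that has the largest count in that node's value. Here is a tiny sketch of that rule. `majority_class` is a hypothetical helper, and the counts below are illustrative rather than read from the fitted tree.

```python
import numpy as np

CLASS_NAMES = ['setosa', 'versicolor', 'virginica']


def majority_class(value, class_names=CLASS_NAMES):
    """Hypothetical helper: the class shown for a node is the one with the largest count."""
    counts = np.asarray(value)
    return class_names[int(np.argmax(counts))]  # argmax picks the first class on ties


print(majority_class([0, 49, 5]))  # versicolor
print(majority_class([0, 1, 45]))  # virginica
```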

5. Visualizing with dtreeviz

python
from dtreeviz.trees import dtreeviz

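viz = dtreeviz(
    tree,  # Decision tree model
    iris.data,  # Feature data (the full dataset; the model itself was trained on x_train only)
    iris.target,  # Data labels
    target_name='variety',  # Name of the target variable
    feature_names=iris.feature_names,  # Feature names
    class_names=[str(i) for i in iris.target_names],  # Class names: ['setosa', 'versicolor', 'virginica']
)

# viz.view()  # Display in the browser
viz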
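At each node, dtreeviz plots how the data passed in (here `iris.data` and `iris.target`) is distributed among the samples that reach that node. The bookkeeping behind those plots amounts, roughly, to routing every sample down the tree and counting classes per node. The sketch below works under that assumption. `node_class_counts` is a hypothetical helper, the five-node tree is made up, and the three rows are illustrative Iris-like samples.

```python
import numpy as np


def node_class_counts(X, y, feature, threshold, children_left, children_right, n_classes):
    """Hypothetical helper: count, per node, how many samples of each class reach it."""
    counts = np.zeros((len(children_left), n_classes), dtype=int)
    for x, label in zip(X, y):
        node = 0
        while True:
            counts[node, label] += 1
            left = children_left[node]
            if left == -1:  # reached a leaf
                break
            # Go left when the feature value is <= the threshold, otherwise right
            node = left if x[feature[node]] <= threshold[node] else children_right[node]
    return counts


X_demo = np.array([[5.1, 3.5, 1.4, 0.2],
                   [6.0, 2.9, 4.5, 1.5],
                   [6.3, 3.3, 6.0, 2.5]])
y_demo = np.array([0, 1, 2])
counts = node_class_counts(X_demo, y_demo,
                           feature=[2, -2, 3, -2, -2],  # made-up tree
                           threshold=[2.45, -2.0, 1.75, -2.0, -2.0],
                           children_left=[1, -1, 3, -1, -1],
                           children_right=[2, -1, 4, -1, -1],
                           n_classes=3)
print(counts)  # row i holds the class counts at node i
```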
python
from dtreeviz.trees import dtreeviz

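viz = dtreeviz(
    tree,  # Decision tree model
    iris.data,  # Feature data
    iris.target,  # Data labels
    target_name='variety',  # Name of the target variable
    feature_names=iris.feature_names,  # Feature names
    class_names=[str(i) for i in iris.target_names],  # Class names: ['setosa', 'versicolor', 'virginica']
    X=[1, 2, 3, 4],  # Arbitrary sample whose prediction path through the tree is highlighted
)

# viz.view()  # Display in the browser
display(viz)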
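The `X` argument is a single sample, and dtreeviz highlights the route it takes through the tree. That route follows the rule scikit-learn trees use: go to the left child when the feature value is <= the node's threshold, otherwise go right. Below is a minimal sketch using a hypothetical `decision_path_nodes` helper on a made-up five-node tree. For the fitted model itself, `tree.decision_path([[1, 2, 3, 4]])` returns a sparse indicator matrix whose nonzero columns should be the node ids on the path.

```python
def decision_path_nodes(x, feature, threshold, children_left, children_right):
    """Hypothetical helper: list the node ids one sample visits, from root to leaf."""
    node = 0
    path = [node]
    while children_left[node] != -1:  # -1 marks a leaf
        if x[feature[node]] <= threshold[node]:
            node = children_left[node]
        else:
            node = children_right[node]
        path.append(node)
    return path


path = decision_path_nodes([1, 2, 3, 4],
                           feature=[2, -2, 3, -2, -2],  # made-up tree
                           threshold=[2.45, -2.0, 1.75, -2.0, -2.0],
                           children_left=[1, -1, 3, -1, -1],
                           children_right=[2, -1, 4, -1, -1])
print(path)  # [0, 2, 4]: petal length 3 > 2.45, then petal width 4 > 1.75
```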
