Ccmmutty logo
Commutty IT
6 min read

K近傍によるアヤメの分類

https://cdn.magicode.io/media/notebox/b17a36e4-40bb-4b60-8e97-78d100df8588.jpeg
K近傍法とは、入力として与えられたデータ群に対して、未知のデータの中心から近い周辺の既知のデータ K 個を集め、それらの分類を多数決で決めるという手法です。以下では K 近傍法を Go 言語でどの様に実装するかを解説します。
機械学習によく利用されるアヤメの品種を入力データとして用います。
SepalLength,SepalWidth,PetalLength,PetalWidth,Name
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
5.4,3.9,1.7,0.4,Iris-setosa
4.6,3.4,1.4,0.3,Iris-setosa
    (略)
5.1,3.8,1.6,0.2,Iris-setosa
4.6,3.2,1.4,0.2,Iris-setosa
5.3,3.7,1.5,0.2,Iris-setosa
5.0,3.3,1.4,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
6.9,3.1,4.9,1.5,Iris-versicolor
5.5,2.3,4.0,1.3,Iris-versicolor
6.5,2.8,4.6,1.5,Iris-versicolor
5.7,2.8,4.5,1.3,Iris-versicolor
アヤメの「がく」と「はなびら」について、それぞれの幅と高さを属性とし、それに対応した品種名が並ぶデータフォーマットです。CSV では左から4カラムが入力となる属性、残り1カラムが分類となる品種になります。
まずはこの CSV を扱いやすいように、float64 型の属性4カラムのスライス、また各行に対応した string 型の分類名に分ける関数を用意します。
go
import (
    "encoding/csv"
    "os"
    "strconv"
    "sort"
    "math"
    "fmt"
  	"net/http"
)

func loadData() ([][]float64, []string, error) {
	resp, err := http.Get("https://gist.githubusercontent.com/mattn/8e652b219847dbf03a68aef6a6ff18d8/raw/cb4aed0ef98ac2cba29e2319f413df7487ed965d/iris.csv")
    if err != nil {
        return nil, nil, err
    }
    defer resp.Body.Close()
    
    r := csv.NewReader(resp.Body)
    r.Comma = ','
    r.LazyQuotes = true
    _, err = r.Read()
    if err != nil {
        return nil, nil, err
    }
    rows, err := r.ReadAll()
    if err != nil {
        return nil, nil, err
    }

    X := [][]float64{}
    Y := []string{}
    for _, cols := range rows {
        x := make([]float64, 4)
        y := cols[4]
        for j, s := range cols[:4] {
            v, err := strconv.ParseFloat(s, 64)
            if err != nil {
                return nil, nil, err
            }
            x[j] = v
        }
        X = append(X, x)
        Y = append(Y, y)
    }
    return X, Y, nil
}

X, Y, _ := loadData()
_, _ = fmt.Println(X[:3], T[:3])

[[5.1 3.5 1.4 0.2] [4.9 3 1.4 0.2] [4.7 3.2 1.3 0.2]] [Iris-setosa Iris-setosa Iris-setosa]
学習データとテストデータに分けましょう。半分を学習データ、残りをテストデータにします。
go
var trainX, testX [][]float64
var trainY, testY []string
for i, _ := range X {
    if i%2 == 0 {
        trainX = append(trainX, X[i])
        trainY = append(trainY, Y[i])
    } else {
        testX = append(testX, X[i])
        testY = append(testY, Y[i])
    }
}

_, _ = fmt.Println(testX[:3], testY[:3])
_, _ = fmt.Println(trainX[:3], trainY[:3])

[[4.9 3 1.4 0.2] [4.6 3.1 1.5 0.2] [5.4 3.9 1.7 0.4]] [Iris-setosa Iris-setosa Iris-setosa] [[5.1 3.5 1.4 0.2] [4.7 3.2 1.3 0.2] [5 3.6 1.4 0.2]] [Iris-setosa Iris-setosa Iris-setosa]
以下が K 近傍の実装になります。
go
type KNN struct {
    k  int         // 近傍のいくらのデータを対象とするか
    XX [][]float64 // 学習する属性
    Y []string     // それに対する分類
}

// distance は距離を返します
func distance(lhs, rhs []float64) float64 {
    val := 0.0
    for i, _ := range lhs {
        val += math.Pow(lhs[i] - rhs[i], 2)
    }
    return math.Sqrt(val)
}

// predict は検査対象 X を学習データ knn.XX で個々の距離を求め
// 距離が近い順に並べ替える。その内 k 個を取り出し、多い分類名
// を多数決で決定する。X と同じ行数分の分類名を返す。
func (knn *KNN) predict(X [][]float64) []string {
    results := []string{}
    for _, x := range X {
        type item struct {
            i int     // インデックス(どの行か)
            f float64 // 距離
        }
        // 学習データと検査対象との各々の距離を調べる
        var items []item
        for i, xx := range knn.XX {
            items = append(items, item {
                i: i,
                f: distance(x, xx),
            })
        }
        // 距離で並び変える
        sort.Slice(items, func(i, j int) bool {
            return items[i].f < items[j].f
        })

        // 近い方から k 個取り出し分類名を集める
        var labels []string
        for i := 0; i < knn.k; i++ {
            labels = append(labels, knn.Y[items[i].i])
        }

        // 見付かった各分類の個数を調べる
        founds := map[string]int{}
        for _, label := range labels {
            founds[label]++
        }

        type rank struct {
            i int
            s string
        }
        // 個数が多かった分類名を決定する
        var ranks []rank
        for k, v := range founds {
            ranks = append(ranks, rank {
                i: v,
                s: k,
            })
        }
        sort.Slice(ranks, func(i, j int) bool {
            return ranks[i].i > ranks[j].i
        })
        // 一番個数が多かった分類名を結果として残す
        results = append(results, ranks[0].s)
    }
    return results
}
実際に実行してみましょう。
go
knn := KNN {
    k: 8,       // 近傍 8 つを検査対象とする
    XX: trainX,
    Y: trainY,
}

// 推論
predicted := knn.predict(testX)

_, _ = fmt.Println(testX[:3], predicted[:3])

[[4.9 3 1.4 0.2] [4.6 3.1 1.5 0.2] [5.4 3.9 1.7 0.4]] [Iris-setosa Iris-setosa Iris-setosa]
predicted の各行は、testX の各行から推論した分類名になっています。元の結果と比較して、どれくらい正しく分類できたか調べましょう。
go
correct := 0
for i, _ := range predicted {
    // 推論結果(predicted)と実際の分類(testY)を確認します
    if predicted[i] == testY[i] {
        correct += 1
    }
}
_, _ = fmt.Printf("%f%%\n", 100*float64(correct)/float64(len(predicted)))

98.666667%
まずまずの結果ではないでしょうか。

Discussion

コメントにはログインが必要です。