-
Notifications
You must be signed in to change notification settings - Fork 1
/
knn.go
64 lines (54 loc) · 1.28 KB
/
knn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
package main
import (
"fmt"
"math"
"sort"
)
// DataPoint is a single labelled training example: a feature vector
// (Features) and its class (Label). Classify groups votes by exact (==)
// equality of Label values.
type DataPoint struct {
	Features []float64
	Label    float64
}

// Distance pairs the index of a training point in KNN.DataPoints (Index)
// with that point's distance from a query point (Value).
type Distance struct {
	Index int
	Value float64
}

// KNN is a k-nearest-neighbours classifier. It keeps the training set
// (DataPoints) as given and predicts labels by majority vote among the K
// training points closest to a query, using Euclidean distance.
type KNN struct {
	K          int
	DataPoints []DataPoint
}

// NewKNN returns a classifier that votes among the k nearest of dataPoints.
// The slice is stored as-is (not copied), so later modifications to it are
// visible to the classifier.
func NewKNN(k int, dataPoints []DataPoint) *KNN {
	return &KNN{
		K:          k,
		DataPoints: dataPoints,
	}
}

// Distance returns the Euclidean distance between point1 and point2.
//
// Both points must have the same number of dimensions. Distance panics
// otherwise, because a distance computed over only some of the coordinates
// would be silently wrong.
func (knn *KNN) Distance(point1 []float64, point2 []float64) float64 {
	if len(point1) != len(point2) {
		panic(fmt.Sprintf("knn: dimension mismatch: %d vs %d", len(point1), len(point2)))
	}
	sum := 0.0
	for i, v := range point1 {
		d := v - point2[i]
		sum += d * d
	}
	return math.Sqrt(sum)
}

// Classify predicts the label of point by majority vote among its K nearest
// training points.
//
// Edge cases:
//   - If K exceeds the number of training points, every training point votes.
//   - If there are no voters (K <= 0 or an empty training set), it returns 0.
//   - Training points at equal distance are ranked by their position in
//     DataPoints, so the same input always selects the same neighbours.
//   - If several labels receive the same number of votes, the label whose
//     closest member is nearest to point wins.
//
// Classify panics if point and a training point differ in dimension (see
// Distance).
func (knn *KNN) Classify(point []float64) float64 {
	k := knn.K
	if k > len(knn.DataPoints) {
		// Use every point rather than indexing past the end.
		k = len(knn.DataPoints)
	}
	if k <= 0 {
		return 0
	}

	distances := make([]Distance, len(knn.DataPoints))
	for i := range knn.DataPoints {
		distances[i] = Distance{
			Index: i,
			Value: knn.Distance(point, knn.DataPoints[i].Features),
		}
	}
	// A stable sort keeps equal distances in training-set order, which makes
	// neighbour selection deterministic.
	sort.SliceStable(distances, func(i, j int) bool {
		return distances[i].Value < distances[j].Value
	})
	nearest := distances[:k]

	votes := make(map[float64]int, k)
	for _, n := range nearest {
		votes[knn.DataPoints[n.Index].Label]++
	}

	// Walk the neighbours from nearest to farthest instead of ranging over the
	// map, whose iteration order is random. The strict > keeps the first label
	// reaching the top count, i.e. the tied label with the closest member.
	bestLabel, bestVotes := 0.0, 0
	for _, n := range nearest {
		label := knn.DataPoints[n.Index].Label
		if v := votes[label]; v > bestVotes {
			bestLabel, bestVotes = label, v
		}
	}
	return bestLabel
}