Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions math.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package lo

import (
"math"
"sort"

"github.com/samber/lo/internal/constraints"
)
Expand Down Expand Up @@ -212,3 +213,80 @@ func Mode[T constraints.Integer | constraints.Float](collection []T) []T {

return mode
}

// Median calculates the median of a collection of numbers
func Median[T constraints.Float | constraints.Integer](collection []T) T {
length := len(collection)
if length == 0 {
return 0
}

sorted := make([]T, length)
copy(sorted, collection)
sort.Slice(sorted, func(i, j int) bool {
return sorted[i] < sorted[j]
})

mid := length / 2
if length%2 == 1 {
return sorted[mid]
}

return (sorted[mid-1] + sorted[mid]) / 2
}

// MedianBy calculates the median of a collection of numbers using the given return value from the iteration function
func MedianBy[T any, R constraints.Float | constraints.Integer](collection []T, iteratee func(item T) R) R {
length := len(collection)
if length == 0 {
return 0
}

values := make([]R, length)
for i, item := range collection {
values[i] = iteratee(item)
}

sort.Slice(values, func(i, j int) bool {
return values[i] < values[j]
})

mid := length / 2
if length%2 == 1 {
return values[mid]
}

return (values[mid-1] + values[mid]) / 2
}

// MedianByErr calculates the median of a collection of numbers using the given return value from the iteration function
// If the iteratee returns an error, iteration stops and the error is returned
// If collection is empty 0 and nil error are returned
func MedianByErr[T any, R constraints.Float | constraints.Integer](collection []T, iteratee func(item T) (R, error)) (R, error) {
length := len(collection)
if length == 0 {
return 0, nil
}

values := make([]R, length)
for i, item := range collection {
val, err := iteratee(item)
if err != nil {
return 0, err
}
values[i] = val
}

sorted := make([]R, length)
copy(sorted, values)
sort.Slice(sorted, func(i, j int) bool {
return sorted[i] < sorted[j]
})

mid := length / 2
if length%2 == 1 {
return sorted[mid], nil
}

return (sorted[mid-1] + sorted[mid]) / 2, nil
}
166 changes: 166 additions & 0 deletions math_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -529,3 +529,169 @@ func TestModeCapacityConsistency(t *testing.T) {
is.Equal([]int{3}, result, "Mode should return correct mode value")
is.Equal(len(result), cap(result), "Mode slice capacity should match its length")
}

func TestMedian(t *testing.T) {
t.Parallel()
is := assert.New(t)

result1 := Median([]float32{1.0, 2.0, 3.0, 4.0, 5.0})
result2 := Median([]float32{1.0, 2.0, 3.0, 4.0})
result3 := Median([]int32{1, 2, 3, 4, 5})
result4 := Median([]int32{1, 2, 3, 4})
result5 := Median([]uint32{})
result6 := Median([]float64{1.5, 2.5, 3.5, 4.5})

is.InEpsilon(3.0, result1, 1e-7)
is.InEpsilon(2.5, result2, 1e-7)
is.Equal(int32(3), result3)
is.Equal(int32(2), result4)
is.Equal(uint32(0), result5)
is.InEpsilon(3.0, result6, 1e-7)
}

func TestMedianBy(t *testing.T) {
t.Parallel()
is := assert.New(t)

result1 := MedianBy([]float32{1.0, 2.0, 3.0, 4.0, 5.0}, func(n float32) float32 { return n })
result2 := MedianBy([]float32{1.0, 2.0, 3.0, 4.0}, func(n float32) float32 { return n })
result3 := MedianBy([]int32{1, 2, 3, 4, 5}, func(n int32) int32 { return n })
result4 := MedianBy([]int32{1, 2, 3, 4}, func(n int32) int32 { return n })
result5 := MedianBy([]uint32{}, func(n uint32) uint32 { return n })
result6 := MedianBy([]float64{1.5, 2.5, 3.5, 4.5}, func(n float64) float64 { return n })

is.InEpsilon(3.0, result1, 1e-7)
is.InEpsilon(2.5, result2, 1e-7)
is.Equal(int32(3), result3)
is.Equal(int32(2), result4)
is.Equal(uint32(0), result5)
is.InEpsilon(3.0, result6, 1e-7)
}

//nolint:errcheck,forcetypeassert
func TestMedianByErr(t *testing.T) {
t.Parallel()
is := assert.New(t)

testErr := assert.AnError

// Test normal operation (no error) - table driven
tests := []struct {
name string
input any
expected any
}{
{
name: "float32 slice (odd)",
input: []float32{1.0, 2.0, 3.0, 4.0, 5.0},
expected: float32(3.0),
},
{
name: "float32 slice (even)",
input: []float32{1.0, 2.0, 3.0, 4.0},
expected: float32(2.5),
},
{
name: "float64 slice",
input: []float64{1.5, 2.5, 3.5, 4.5},
expected: float64(3.0),
},
{
name: "int32 slice (odd)",
input: []int32{1, 2, 3, 4, 5},
expected: int32(3),
},
{
name: "int32 slice (even)",
input: []int32{1, 2, 3, 4},
expected: int32(2),
},
{
name: "uint32 slice",
input: []uint32{2, 3, 4, 5},
expected: uint32(3),
},
{
name: "empty uint32 slice",
input: []uint32{},
expected: uint32(0),
},
{
name: "nil int32 slice",
input: ([]int32)(nil),
expected: int32(0),
},
}

for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

switch input := tt.input.(type) {
case []float32:
result, err := MedianByErr(input, func(n float32) (float32, error) { return n, nil })
is.NoError(err)
is.InEpsilon(tt.expected.(float32), result, 1e-7)
case []float64:
result, err := MedianByErr(input, func(n float64) (float64, error) { return n, nil })
is.NoError(err)
is.InEpsilon(tt.expected.(float64), result, 1e-7)
case []int32:
result, err := MedianByErr(input, func(n int32) (int32, error) { return n, nil })
is.NoError(err)
is.Equal(tt.expected.(int32), result)
case []uint32:
result, err := MedianByErr(input, func(n uint32) (uint32, error) { return n, nil })
is.NoError(err)
is.Equal(tt.expected.(uint32), result)
}
})
}

// Test error cases - table driven
errorTests := []struct {
name string
input []int32
errorAt int32
expectedCalls int
}{
{
name: "error at third element",
input: []int32{1, 2, 3, 4, 5},
errorAt: 3,
expectedCalls: 3,
},
{
name: "error at first element",
input: []int32{1, 2, 3},
errorAt: 1,
expectedCalls: 1,
},
{
name: "error at last element",
input: []int32{1, 2, 3},
errorAt: 3,
expectedCalls: 3,
},
}

for _, tt := range errorTests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

callbackCount := 0
result, err := MedianByErr(tt.input, func(n int32) (int32, error) {
callbackCount++
if n == tt.errorAt {
return 0, testErr
}
return n, nil
})
is.ErrorIs(err, testErr)
is.Equal(int32(0), result)
is.Equal(tt.expectedCalls, callbackCount)
})
}
}