Skip to content

Commit 953da99

Browse files
authored
checked constructor for potential error conditions (#4)
* checked constructor for potential error conditions: This commit introduces a variant of the NewChooser constructor that will return an error on conditions that could later cause a runtime issue during Pick(). The conditions handled are a lack of valid choices and a potential integer overflow in the running total. This is a proof of concept, but the final API will likely be different to avoid introducing extra complexity into the library. This commit merely serves as the initial code to aid a discussion in the PR. * checked constructor as new default * privatize sentinel errors for NewChooser: I don't see a current use case where being able to act upon these as sentinel errors would be significant, so better to avoid the API complexity and keep them private for now. It is always easier to export them later if needed than to take them away once out in the wild. * docs: clean up variable names
1 parent 3b00289 commit 953da99

File tree

5 files changed

+117
-30
lines changed

5 files changed

+117
-30
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import (
2222
func main() {
2323
rand.Seed(time.Now().UTC().UnixNano()) // always seed random!
2424

25-
c := wr.NewChooser(
25+
chooser, _ := wr.NewChooser(
2626
wr.Choice{Item: "🍒", Weight: 0},
2727
wr.Choice{Item: "🍋", Weight: 1},
2828
wr.Choice{Item: "🍊", Weight: 1},
@@ -33,7 +33,7 @@ func main() {
3333
probability, and 🥑 with 0.5 probability. 🍒 will never be printed. (Note
3434
the weights don't have to add up to 10, that was just done here to make the
3535
example easier to read.) */
36-
result := c.Pick().(string)
36+
result := chooser.Pick().(string)
3737
fmt.Println(result)
3838
}
3939
```

examples/compbench/bench_test.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@ func BenchmarkMultiple(b *testing.B) {
3232
for n := BMMinChoices; n <= BMMaxChoices; n *= 10 {
3333
b.Run(strconv.Itoa(n), func(b *testing.B) {
3434
choices := mockChoices(b, n)
35-
chs := weightedrand.NewChooser(choices...)
35+
chs, err := weightedrand.NewChooser(choices...)
36+
if err != nil {
37+
b.Fatal(err)
38+
}
3639
b.ResetTimer()
3740
for i := 0; i < b.N; i++ {
3841
chs.Pick()
@@ -45,7 +48,10 @@ func BenchmarkMultiple(b *testing.B) {
4548
for n := BMMinChoices; n <= BMMaxChoices; n *= 10 {
4649
b.Run(strconv.Itoa(n), func(b *testing.B) {
4750
choices := mockChoices(b, n)
48-
chs := weightedrand.NewChooser(choices...)
51+
chs, err := weightedrand.NewChooser(choices...)
52+
if err != nil {
53+
b.Fatal(err)
54+
}
4955
b.ResetTimer()
5056
b.RunParallel(func(pb *testing.PB) {
5157
rs := rand.New(rand.NewSource(time.Now().UTC().UnixNano()))
@@ -84,7 +90,7 @@ func BenchmarkSingle(b *testing.B) {
8490
choices := mockChoices(b, n)
8591
b.ResetTimer()
8692
for i := 0; i < b.N; i++ {
87-
chs := weightedrand.NewChooser(choices...)
93+
chs, _ := weightedrand.NewChooser(choices...)
8894
chs.Pick()
8995
}
9096
})

examples/frequency/main.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package main
22

33
import (
44
"fmt"
5+
"log"
56
"math/rand"
67
"time"
78

@@ -11,13 +12,16 @@ import (
1112
func main() {
1213
rand.Seed(time.Now().UTC().UnixNano()) // always seed random!
1314

14-
c := wr.NewChooser(
15+
c, err := wr.NewChooser(
1516
wr.Choice{Item: '🍒', Weight: 0}, // alternatively: wr.NewChoice('🍒', 0)
1617
wr.Choice{Item: '🍋', Weight: 1},
1718
wr.Choice{Item: '🍊', Weight: 1},
1819
wr.Choice{Item: '🍉', Weight: 3},
1920
wr.Choice{Item: '🥑', Weight: 5},
2021
)
22+
if err != nil {
23+
log.Fatal(err)
24+
}
2125

2226
/* Let's pick a bunch of fruits so we can see the distribution in action! */
2327
fruits := make([]rune, 40*18)

weightedrand.go

Lines changed: 46 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
package weightedrand
1212

1313
import (
14+
"errors"
1415
"math/rand"
1516
"sort"
1617
)
@@ -34,28 +35,55 @@ type Chooser struct {
3435
max int
3536
}
3637

37-
// NewChooser initializes a new Chooser for picking from the provided Choices.
38-
func NewChooser(cs ...Choice) Chooser {
39-
sort.Slice(cs, func(i, j int) bool {
40-
return cs[i].Weight < cs[j].Weight
38+
// NewChooser initializes a new Chooser for picking from the provided choices.
39+
func NewChooser(choices ...Choice) (*Chooser, error) {
40+
sort.Slice(choices, func(i, j int) bool {
41+
return choices[i].Weight < choices[j].Weight
4142
})
42-
totals := make([]int, len(cs))
43+
44+
totals := make([]int, len(choices))
4345
runningTotal := 0
44-
for i, c := range cs {
45-
runningTotal += int(c.Weight)
46+
for i, c := range choices {
47+
weight := int(c.Weight)
48+
if (maxInt - runningTotal) <= weight {
49+
return nil, errWeightOverflow
50+
}
51+
runningTotal += weight
4652
totals[i] = runningTotal
4753
}
48-
return Chooser{data: cs, totals: totals, max: runningTotal}
54+
55+
if runningTotal <= 1 {
56+
return nil, errNoValidChoices
57+
}
58+
59+
return &Chooser{data: choices, totals: totals, max: runningTotal}, nil
4960
}
5061

62+
const (
63+
intSize = 32 << (^uint(0) >> 63) // cf. strconv.IntSize
64+
maxInt = 1<<(intSize-1) - 1
65+
)
66+
67+
// Possible errors returned by NewChooser, preventing the creation of a Chooser
68+
// with unsafe runtime states.
69+
var (
70+
// If the sum of provided Choice weights exceeds the maximum integer value
71+
// for the current platform (e.g. math.MaxInt32 or math.MaxInt64), then
72+
// the internal running total will overflow, resulting in an imbalanced
73+
// distribution generating improper results.
74+
errWeightOverflow = errors.New("sum of Choice Weights exceeds max int")
75+
// If there are no Choices available to the Chooser with a weight >= 1,
76+
// there are no valid choices and Pick would produce a runtime panic.
77+
errNoValidChoices = errors.New("zero Choices with Weight >= 1")
78+
)
79+
5180
// Pick returns a single weighted random Choice.Item from the Chooser.
5281
//
53-
// Utilizes global rand as the source of randomness -- you will likely want to
54-
// seed it.
55-
func (chs Chooser) Pick() interface{} {
56-
r := rand.Intn(chs.max) + 1
57-
i := searchInts(chs.totals, r)
58-
return chs.data[i].Item
82+
// Utilizes global rand as the source of randomness.
83+
func (c Chooser) Pick() interface{} {
84+
r := rand.Intn(c.max) + 1
85+
i := searchInts(c.totals, r)
86+
return c.data[i].Item
5987
}
6088

6189
// PickSource returns a single weighted random Choice.Item from the Chooser,
@@ -67,10 +95,10 @@ func (chs Chooser) Pick() interface{} {
6795
//
6896
// It is the responsibility of the caller to ensure the provided rand.Source is
6997
// free from thread safety issues.
70-
func (chs Chooser) PickSource(rs *rand.Rand) interface{} {
71-
r := rs.Intn(chs.max) + 1
72-
i := searchInts(chs.totals, r)
73-
return chs.data[i].Item
98+
func (c Chooser) PickSource(rs *rand.Rand) interface{} {
99+
r := rs.Intn(c.max) + 1
100+
i := searchInts(c.totals, r)
101+
return c.data[i].Item
74102
}
75103

76104
// The standard library sort.SearchInts() just wraps the generic sort.Search()

weightedrand_test.go

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import (
1818
// not on any absolute scoring system. In this trivial case, we will assign a
1919
// weight of 0 to all but one fruit, so that the output will be predictable.
2020
func Example() {
21-
chooser := NewChooser(
21+
chooser, _ := NewChooser(
2222
NewChoice('🍋', 0),
2323
NewChoice('🍊', 0),
2424
NewChoice('🍉', 0),
@@ -42,12 +42,52 @@ const (
4242
testIterations = 1000000
4343
)
4444

45+
func TestNewChooser(t *testing.T) {
46+
tests := []struct {
47+
name string
48+
cs []Choice
49+
wantErr error
50+
}{
51+
{
52+
name: "zero choices",
53+
cs: []Choice{},
54+
wantErr: errNoValidChoices,
55+
},
56+
{
57+
name: "no choices with positive weight",
58+
cs: []Choice{{Item: 'a', Weight: 0}, {Item: 'b', Weight: 0}},
59+
wantErr: errNoValidChoices,
60+
},
61+
{
62+
name: "weight overflow",
63+
cs: []Choice{{Item: 'a', Weight: maxInt/2 + 1}, {Item: 'b', Weight: maxInt/2 + 1}},
64+
wantErr: errWeightOverflow,
65+
},
66+
{
67+
name: "nominal case",
68+
cs: []Choice{{Item: 'a', Weight: 1}, {Item: 'b', Weight: 2}},
69+
wantErr: nil,
70+
},
71+
}
72+
for _, tt := range tests {
73+
t.Run(tt.name, func(t *testing.T) {
74+
_, err := NewChooser(tt.cs...)
75+
if err != tt.wantErr {
76+
t.Errorf("NewChooser() error = %v, wantErr %v", err, tt.wantErr)
77+
}
78+
})
79+
}
80+
}
81+
4582
// TestChooser_Pick assembles a list of Choices, weighted 0-9, and tests that
4683
// over the course of 1,000,000 calls to Pick() each choice is returned more
4784
// often than choices with a lower weight.
4885
func TestChooser_Pick(t *testing.T) {
4986
choices := mockFrequencyChoices(t, testChoices)
50-
chooser := NewChooser(choices...)
87+
chooser, err := NewChooser(choices...)
88+
if err != nil {
89+
t.Fatal(err)
90+
}
5191
t.Log("totals in chooser", chooser.totals)
5292

5393
// run Pick() a million times, and record how often it returns each of the
@@ -67,7 +107,10 @@ func TestChooser_Pick(t *testing.T) {
67107
// randomness.
68108
func TestChooser_PickSource(t *testing.T) {
69109
choices := mockFrequencyChoices(t, testChoices)
70-
chooser := NewChooser(choices...)
110+
chooser, err := NewChooser(choices...)
111+
if err != nil {
112+
t.Fatal(err)
113+
}
71114
t.Log("totals in chooser", chooser.totals)
72115

73116
counts1 := make(map[int]int)
@@ -137,7 +180,7 @@ func BenchmarkNewChooser(b *testing.B) {
137180
b.ResetTimer()
138181

139182
for i := 0; i < b.N; i++ {
140-
_ = NewChooser(choices...)
183+
_, _ = NewChooser(choices...)
141184
}
142185
})
143186
}
@@ -147,7 +190,10 @@ func BenchmarkPick(b *testing.B) {
147190
for n := BMMinChoices; n <= BMMaxChoices; n *= 10 {
148191
b.Run(strconv.Itoa(n), func(b *testing.B) {
149192
choices := mockChoices(n)
150-
chooser := NewChooser(choices...)
193+
chooser, err := NewChooser(choices...)
194+
if err != nil {
195+
b.Fatal(err)
196+
}
151197
b.ResetTimer()
152198

153199
for i := 0; i < b.N; i++ {
@@ -161,7 +207,10 @@ func BenchmarkPickParallel(b *testing.B) {
161207
for n := BMMinChoices; n <= BMMaxChoices; n *= 10 {
162208
b.Run(strconv.Itoa(n), func(b *testing.B) {
163209
choices := mockChoices(n)
164-
chooser := NewChooser(choices...)
210+
chooser, err := NewChooser(choices...)
211+
if err != nil {
212+
b.Fatal(err)
213+
}
165214
b.ResetTimer()
166215
b.RunParallel(func(pb *testing.PB) {
167216
rs := rand.New(rand.NewSource(time.Now().UTC().UnixNano()))

0 commit comments

Comments
 (0)