Skip to content

Commit e381dd5

Browse files
authored
Merge pull request #12 from Algebra8/itertools-count-float-input
Itertools count float input
2 parents d3af75a + b847cb7 commit e381dd5

File tree

2 files changed

+192
-51
lines changed

2 files changed

+192
-51
lines changed

itertools.go

Lines changed: 99 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,98 @@ import (
66
"go.starlark.net/starlark"
77
)
88

9-
type countObject struct {
10-
cnt int
11-
step int
12-
frozen bool
13-
value starlark.Value
9+
// float or int type to allow mixed inputs.
10+
type floatOrInt struct {
11+
value starlark.Value
12+
}
13+
14+
// Unpacker for floatOrInt.
15+
func (p *floatOrInt) Unpack(v starlark.Value) error {
16+
switch v := v.(type) {
17+
case starlark.Int:
18+
p.value = v
19+
return nil
20+
case starlark.Float:
21+
p.value = v
22+
return nil
23+
}
24+
return fmt.Errorf("got %s, want float or int", v.Type())
25+
}
26+
27+
func (f *floatOrInt) add(n floatOrInt) error {
28+
switch _f := f.value.(type) {
29+
case starlark.Int:
30+
switch _n := n.value.(type) {
31+
// int + int
32+
case starlark.Int:
33+
f.value = _f.Add(_n)
34+
return nil
35+
// int + float
36+
case starlark.Float:
37+
_n += _f.Float()
38+
f.value = _n
39+
return nil
40+
}
41+
case starlark.Float:
42+
switch _n := n.value.(type) {
43+
// float + int
44+
case starlark.Int:
45+
_f += _n.Float()
46+
f.value = _f
47+
return nil
48+
// float + float
49+
case starlark.Float:
50+
_f += _n
51+
f.value = _f
52+
return nil
53+
}
54+
}
55+
56+
return fmt.Errorf("error with addition: types are not int, float combos")
57+
}
58+
59+
func (f *floatOrInt) String() string {
60+
return f.value.String()
1461
}
1562

16-
func newCountObject(cnt int, stepValue int) *countObject {
17-
return &countObject{cnt: cnt, step: stepValue, value: starlark.MakeInt(cnt)}
63+
// Iterator implementation for countObject.
64+
type countIter struct {
65+
co *countObject
1866
}
1967

20-
func (co *countObject) String() string {
68+
func (c *countIter) Next(p *starlark.Value) bool {
69+
if c.co.frozen {
70+
return false
71+
}
72+
73+
*p = c.co.cnt.value
74+
75+
if e := c.co.cnt.add(c.co.step); e != nil {
76+
return false
77+
}
78+
79+
return true
80+
}
81+
82+
func (c *countIter) Done() {}
83+
84+
// countObject implementation as a starlark.Value.
85+
type countObject struct {
86+
cnt, step floatOrInt
87+
frozen bool
88+
}
89+
90+
func (co countObject) String() string {
2191
// As with the cpython implementation, we don't display
22-
// step when it is an integer equal to 1.
23-
if co.step == 1 {
24-
return fmt.Sprintf("count(%v)", co.cnt)
92+
// step when it is an integer equal to 1 (default step value).
93+
step, ok := co.step.value.(starlark.Int)
94+
if ok {
95+
if x, ok := step.Int64(); ok && x == 1 {
96+
return fmt.Sprintf("count(%v)", co.cnt.String())
97+
}
2598
}
26-
return fmt.Sprintf("count(%v, %v)", co.cnt, co.step)
99+
100+
return fmt.Sprintf("count(%v, %v)", co.cnt.String(), co.step.String())
27101
}
28102

29103
func (co *countObject) Type() string {
@@ -33,7 +107,6 @@ func (co *countObject) Type() string {
33107
func (co *countObject) Freeze() {
34108
if !co.frozen {
35109
co.frozen = true
36-
co.value.Freeze()
37110
}
38111
}
39112

@@ -50,59 +123,35 @@ func (co *countObject) Iterate() starlark.Iterator {
50123
return &countIter{co: co}
51124
}
52125

53-
type countIter struct {
54-
co *countObject
55-
}
56-
57-
func (c *countIter) Next(p *starlark.Value) bool {
58-
if c.co.frozen {
59-
return false
60-
}
61-
*p = starlark.MakeInt(c.co.cnt)
62-
c.co.cnt += c.co.step
63-
return true
64-
}
65-
66-
func (c *countIter) Done() {}
67-
68126
func count_(
69127
thread *starlark.Thread,
70128
_ *starlark.Builtin,
71129
args starlark.Tuple,
72130
kwargs []starlark.Tuple,
73131
) (starlark.Value, error) {
74132
var (
75-
start int
76-
step int
133+
defaultStart = starlark.MakeInt(0)
134+
defaultStep = starlark.MakeInt(1)
135+
start floatOrInt
136+
step floatOrInt
77137
)
78138

79139
if err := starlark.UnpackPositionalArgs(
80140
"count", args, kwargs, 0, &start, &step,
81141
); err != nil {
82142
return nil, fmt.Errorf(
83-
"Got %v but expected NoneType or valid integer values for "+
84-
"start and step, such as (0, 1).", args.String(),
143+
"Got %v but expected no args, or one or two valid numbers",
144+
args.String(),
85145
)
86146
}
87147

88-
const (
89-
defaultStart = 0
90-
defaultStep = 1
91-
)
92-
// The rules for populating the count object based on the number
93-
// of args passed is as follows:
94-
// 0 args -> default values for start and step
95-
// 1 args -> arg defines start, default for step
96-
// 2 args -> both start and step are defined by args
97-
var co_ *countObject
98-
switch nargs := len(args); {
99-
case nargs == 0:
100-
co_ = newCountObject(defaultStart, defaultStep)
101-
case nargs == 1:
102-
co_ = newCountObject(start, defaultStep)
103-
default: // nargs == 2
104-
co_ = newCountObject(start, step)
148+
// Check if start or step require default values.
149+
if start.value == nil {
150+
start.value = defaultStart
151+
}
152+
if step.value == nil {
153+
step.value = defaultStep
105154
}
106155

107-
return co_, nil
156+
return &countObject{cnt: start, step: step}, nil
108157
}

testdata/itertools.star

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,107 @@ def test_count():
2929
assert.eq(str(c2), "count(11, 3)")
3030
assert.eq(next(c2), 11)
3131

32+
# Negative args.
33+
c3 = count(-5, -10)
34+
assert.eq(str(c3), "count(-5, -10)")
35+
assert.eq(next(c3), -5)
36+
assert.eq(str(c3), "count(-15, -10)")
37+
assert.eq(next(c3), -15)
38+
39+
c4 = count(5, -5)
40+
assert.eq(str(c4), "count(5, -5)")
41+
assert.eq(next(c4), 5)
42+
assert.eq(str(c4), "count(0, -5)")
43+
assert.eq(next(c4), 0)
44+
assert.eq(str(c4), "count(-5, -5)")
45+
assert.eq(next(c4), -5)
46+
47+
# Int start, float step.
48+
c5 = count(0, 0.1)
49+
assert.eq(str(c5), "count(0, 0.1)")
50+
assert.eq(next(c5), 0)
51+
assert.eq(str(c5), "count(0.1, 0.1)")
52+
assert.eq(next(c5), 0.1)
53+
assert.eq(str(c5), "count(0.2, 0.1)")
54+
assert.eq(next(c5), 0.2)
55+
56+
# Float start, int step — this should be handled same as above
57+
# but check to be exhaustive.
58+
c6 = count(0.5, 5)
59+
assert.eq(str(c6), "count(0.5, 5)")
60+
assert.eq(next(c6), 0.5)
61+
assert.eq(str(c6), "count(5.5, 5)")
62+
assert.eq(next(c6), 5.5)
63+
assert.eq(str(c6), "count(10.5, 5)")
64+
assert.eq(next(c6), 10.5)
65+
66+
# This test may seem similar to c5 but is different because
67+
# here step > 1. In the case that 0 < step < 1, fmt.Sprintf,
68+
# which is used in String(), will display it as a float but
69+
# may display it as an int if the proper flags aren't used.
70+
c7 = count(5.0, 0.5)
71+
assert.eq(str(c7), "count(5.0, 0.5)")
72+
assert.eq(next(c7), 5.0)
73+
assert.eq(str(c7), "count(5.5, 0.5)")
74+
assert.eq(next(c7), 5.5)
75+
assert.eq(str(c7), "count(6.0, 0.5)")
76+
assert.eq(next(c7), 6.0)
77+
78+
# NaNs
79+
c8 = count(0, float('nan'))
80+
assert.eq(str(c8), "count(0, %s)" % (float('nan')))
81+
assert.eq(next(c8), 0)
82+
assert.eq(str(c8), "count(%s, %s)" % (float('nan'), float('nan')))
83+
assert.eq(next(c8), float('nan'))
84+
assert.eq(str(c8), "count(%s, %s)" % (float('nan'), float('nan')))
85+
assert.eq(next(c8), float('nan'))
86+
87+
c9 = count(0, float("+inf"))
88+
assert.eq(str(c9), "count(0, %s)" % (float("+inf")))
89+
assert.eq(next(c9), 0)
90+
assert.eq(str(c9), "count(%s, %s)" % (float("+inf"), float("+inf")))
91+
assert.eq(next(c9), float("+inf"))
92+
assert.eq(str(c9), "count(%s, %s)" % (float("+inf"), float("+inf")))
93+
assert.eq(next(c9), float("+inf"))
94+
95+
c10 = count(0, float("-inf"))
96+
assert.eq(str(c10), "count(0, %s)" % (float("-inf")))
97+
assert.eq(next(c10), 0)
98+
assert.eq(str(c10), "count(%s, %s)" % (float("-inf"), float("-inf")))
99+
assert.eq(next(c10), float("-inf"))
100+
assert.eq(str(c10), "count(%s, %s)" % (float("-inf"), float("-inf")))
101+
assert.eq(next(c10), float("-inf"))
102+
103+
c11 = count(float("nan"), 2)
104+
assert.eq(str(c11), "count(%s, 2)" % (float('nan')))
105+
assert.eq(next(c11), float('nan'))
106+
assert.eq(str(c11), "count(%s, 2)" % (float('nan')))
107+
assert.eq(next(c11), float('nan'))
108+
109+
c12 = count(float("+inf"), 2)
110+
assert.eq(str(c12), "count(%s, 2)" % (float('+inf')))
111+
assert.eq(next(c12), float('+inf'))
112+
assert.eq(str(c12), "count(%s, 2)" % (float('+inf')))
113+
assert.eq(next(c12), float('+inf'))
114+
115+
c13 = count(float("-inf"), 2)
116+
assert.eq(str(c13), "count(%s, 2)" % (float('-inf')))
117+
assert.eq(next(c13), float('-inf'))
118+
assert.eq(str(c13), "count(%s, 2)" % (float('-inf')))
119+
assert.eq(next(c13), float('-inf'))
120+
32121
# Fails
33-
z = ("a", "b")
122+
# Non-numeric arg fails.
34123
assert.fails(
35124
lambda: count("a", "b"),
36125
# fails uses match under the hood, which will use
37126
# regexp.MatchString, so need to use raw pattern
38127
# that MatchString would accept.
39128
r'Got \(\"a\", \"b\"\)',
40129
)
130+
131+
# Too many arg fails — should be handled by UnpackArgs but
132+
# check to be exhaustive.
41133
assert.fails(
42134
lambda: count(1, 2, 3),
43135
r'Got \(1, 2, 3\)'

0 commit comments

Comments
 (0)