Skip to content

Commit beeed05

Browse files
authored
Add version constraints (#4)
Signed-off-by: Kimmo Lehto <[email protected]>
1 parent a97a29d commit beeed05

File tree

4 files changed

+320
-0
lines changed

4 files changed

+320
-0
lines changed

constraint.go

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
package version
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"regexp"
7+
"strings"
8+
)
9+
10+
var constraintRegex = regexp.MustCompile(`^(?:(>=|>|<=|<|!=|==?)\s*)?(.+)$`)
11+
12+
type constraintFunc func(a, b *Version) bool
13+
type constraint struct {
14+
f constraintFunc
15+
b *Version
16+
original string
17+
}
18+
19+
// Constraints is a collection of version constraint rules that can be checked against a version.
20+
type Constraints []constraint
21+
22+
// NewConstraint parses a string into a Constraints object that can be used to check
23+
// if a given version satisfies the constraint.
24+
func NewConstraint(cs string) (Constraints, error) {
25+
parts := strings.Split(cs, ",")
26+
newC := make(Constraints, len(parts))
27+
for i, p := range parts {
28+
parts[i] = strings.TrimSpace(p)
29+
}
30+
for i, p := range parts {
31+
c, err := newConstraint(p)
32+
if err != nil {
33+
return Constraints{}, err
34+
}
35+
newC[i] = c
36+
}
37+
38+
return newC, nil
39+
}
40+
41+
// MustConstraint is like NewConstraint but panics if the constraint is invalid.
42+
func MustConstraint(cs string) Constraints {
43+
c, err := NewConstraint(cs)
44+
if err != nil {
45+
panic("github.com/k0sproject/version: NewConstraint: " + err.Error())
46+
}
47+
return c
48+
}
49+
50+
// Check returns true if the given version satisfies all of the constraints.
51+
func (cs Constraints) Check(v *Version) bool {
52+
for _, c := range cs {
53+
if c.b.Prerelease() == "" && v.Prerelease() != "" {
54+
return false
55+
}
56+
if !c.f(c.b, v) {
57+
return false
58+
}
59+
}
60+
61+
return true
62+
}
63+
64+
// CheckString is like Check but takes a string version. If the version is invalid,
65+
// it returns false.
66+
func (cs Constraints) CheckString(v string) bool {
67+
vv, err := NewVersion(v)
68+
if err != nil {
69+
return false
70+
}
71+
return cs.Check(vv)
72+
}
73+
74+
// String returns the original constraint string.
75+
func (c *constraint) String() string {
76+
return c.original
77+
}
78+
79+
func newConstraint(s string) (constraint, error) {
80+
match := constraintRegex.FindStringSubmatch(s)
81+
if len(match) != 3 {
82+
return constraint{}, errors.New("invalid constraint: " + s)
83+
}
84+
85+
op := match[1]
86+
f, err := opfunc(op)
87+
if err != nil {
88+
return constraint{}, err
89+
}
90+
91+
// convert one or two digit constraints to threes digit unless it's an equality operation
92+
if op != "" && op != "=" && op != "==" {
93+
segments := strings.Split(match[2], ".")
94+
if len(segments) < 3 {
95+
lastSegment := segments[len(segments)-1]
96+
var pre string
97+
if strings.Contains(lastSegment, "-") {
98+
parts := strings.Split(lastSegment, "-")
99+
segments[len(segments)-1] = parts[0]
100+
pre = "-" + parts[1]
101+
}
102+
switch len(segments) {
103+
case 1:
104+
// >= 1 becomes >= 1.0.0
105+
// >= 1-rc.1 becomes >= 1.0.0-rc.1
106+
return newConstraint(fmt.Sprintf("%s %s.0.0%s", op, segments[0], pre))
107+
case 2:
108+
// >= 1.1 becomes >= 1.1.0
109+
// >= 1.1-rc.1 becomes >= 1.1.0-rc.1
110+
return newConstraint(fmt.Sprintf("%s %s.%s.0%s", op, segments[0], segments[1], pre))
111+
}
112+
}
113+
}
114+
115+
target, err := NewVersion(match[2])
116+
if err != nil {
117+
return constraint{}, err
118+
}
119+
120+
return constraint{f: f, b: target, original: s}, nil
121+
}
122+
123+
func opfunc(s string) (constraintFunc, error) {
124+
switch s {
125+
case "", "=", "==":
126+
return eq, nil
127+
case ">":
128+
return gt, nil
129+
case ">=":
130+
return gte, nil
131+
case "<":
132+
return lt, nil
133+
case "<=":
134+
return lte, nil
135+
case "!=":
136+
return neq, nil
137+
default:
138+
return nil, errors.New("invalid operator: " + s)
139+
}
140+
}
141+
142+
func gt(a, b *Version) bool { return b.GreaterThan(a) }
143+
func lt(a, b *Version) bool { return b.LessThan(a) }
144+
func gte(a, b *Version) bool { return b.GreaterThanOrEqual(a) }
145+
func lte(a, b *Version) bool { return b.LessThanOrEqual(a) }
146+
func eq(a, b *Version) bool { return b.Equal(a) }
147+
func neq(a, b *Version) bool { return !b.Equal(a) }
148+

constraint_test.go

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
package version
2+
3+
import (
4+
"fmt"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
)
9+
10+
func TestConstraint(t *testing.T) {
11+
type testCase struct {
12+
constraint string
13+
truthTable map[bool][]string
14+
}
15+
16+
testCases := []testCase{
17+
{
18+
constraint: ">= 1.1.0-beta.1+k0s.1",
19+
truthTable: map[bool][]string{
20+
true: {
21+
"1.1.0+k0s.0",
22+
"1.1.0-rc.1+k0s.0",
23+
"1.1.1+k0s.0",
24+
"1.1.1-rc.1+k0s.0",
25+
},
26+
false: {
27+
"1.1.0-alpha.1+k0s.2",
28+
"1.0.1+k0s.10",
29+
},
30+
},
31+
},
32+
{
33+
constraint: ">= 1.1.0+k0s.1",
34+
truthTable: map[bool][]string{
35+
true: {
36+
"1.1.0+k0s.1",
37+
"1.1.0+k0s.2",
38+
"1.1.1+k0s.0",
39+
},
40+
false: {
41+
"1.0.9+k0s.255",
42+
"1.1.0+k0s.0",
43+
},
44+
},
45+
},
46+
// simple operator checks
47+
{
48+
constraint: "= 1.0.0",
49+
truthTable: map[bool][]string{
50+
true: {"1.0.0"},
51+
false: {"1.0.1", "0.9.9"},
52+
},
53+
},
54+
{
55+
constraint: "1.0.0",
56+
truthTable: map[bool][]string{
57+
true: {"1.0.0"},
58+
false: {"1.0.1", "0.9.9"},
59+
},
60+
},
61+
{
62+
constraint: "!= 1.0.0",
63+
truthTable: map[bool][]string{
64+
true: {"1.0.1", "0.9.9"},
65+
false: {"1.0.0"},
66+
},
67+
},
68+
{
69+
constraint: "> 1.0.0",
70+
truthTable: map[bool][]string{
71+
true: {"1.0.1", "1.1.0"},
72+
false: {"1.0.0", "0.9.9"},
73+
},
74+
},
75+
{
76+
constraint: "< 1.0.0",
77+
truthTable: map[bool][]string{
78+
true: {"0.9.9", "0.9.8"},
79+
false: {"1.0.0", "1.0.1"},
80+
},
81+
},
82+
{
83+
constraint: ">= 1.0.0",
84+
truthTable: map[bool][]string{
85+
true: {"1.0.0", "1.0.1"},
86+
false: {"0.9.9"},
87+
},
88+
},
89+
{
90+
constraint: "<= 1.0.0",
91+
truthTable: map[bool][]string{
92+
true: {"1.0.0", "0.9.9"},
93+
false: {"1.0.1"},
94+
},
95+
},
96+
// two digit constraints
97+
{
98+
constraint: ">= 1.0",
99+
truthTable: map[bool][]string{
100+
true: {"1.0.0", "1.0.1", "1.1.0"},
101+
false: {"0.9.9", "1.0.1-alpha.1"},
102+
},
103+
},
104+
{
105+
constraint: ">= 1.0-a",
106+
truthTable: map[bool][]string{
107+
true: {"1.0.0", "1.0.1", "1.0.0-alpha.1"},
108+
false: {"0.9.9"},
109+
},
110+
},
111+
}
112+
113+
for _, tc := range testCases {
114+
t.Run(tc.constraint, func(t *testing.T) {
115+
c, err := NewConstraint(tc.constraint)
116+
assert.NoError(t, err)
117+
118+
for expected, versions := range tc.truthTable {
119+
t.Run(fmt.Sprintf("%t", expected), func(t *testing.T) {
120+
for _, version := range versions {
121+
t.Run(version, func(t *testing.T) {
122+
assert.Equal(t, expected, c.Check(MustParse(version)))
123+
})
124+
}
125+
})
126+
}
127+
})
128+
}
129+
}
130+
131+
func TestInvalidConstraint(t *testing.T) {
132+
invalidConstraints := []string{
133+
"",
134+
"==",
135+
">= ",
136+
"invalid",
137+
">= abc",
138+
}
139+
140+
for _, invalidConstraint := range invalidConstraints {
141+
_, err := newConstraint(invalidConstraint)
142+
assert.Error(t, err, "Expected error for invalid constraint: "+invalidConstraint)
143+
}
144+
}
145+
146+
func TestCheckString(t *testing.T) {
147+
c, err := NewConstraint(">= 1.0.0")
148+
assert.NoError(t, err)
149+
150+
assert.True(t, c.CheckString("1.0.0"))
151+
assert.False(t, c.CheckString("0.9.9"))
152+
assert.False(t, c.CheckString("x"))
153+
}

version.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,11 @@ func (v *Version) UnmarshalJSON(b []byte) error {
147147
})
148148
}
149149

150+
// Satisfies returns true if the version satisfies the supplied constraint
151+
func (v *Version) Satisfies(constraint Constraints) bool {
152+
return constraint.Check(v)
153+
}
154+
150155
// NewVersion returns a new Version created from the supplied string or an error if the string is not a valid version number
151156
func NewVersion(v string) (*Version, error) {
152157
n, err := goversion.NewVersion(strings.TrimPrefix(v, "v"))

version_test.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,20 @@ func TestK0sComparison(t *testing.T) {
3838
assert.False(t, b.Equal(a), "version %s should not be equal to %s", b, a)
3939
}
4040

41+
func TestSatisfies(t *testing.T) {
42+
v, err := NewVersion("1.23.1+k0s.1")
43+
assert.NoError(t, err)
44+
assert.True(t, v.Satisfies(MustConstraint(">=1.23.1")))
45+
assert.True(t, v.Satisfies(MustConstraint(">=1.23.1+k0s.0")))
46+
assert.True(t, v.Satisfies(MustConstraint(">=1.23.1+k0s.1")))
47+
assert.True(t, v.Satisfies(MustConstraint("=1.23.1+k0s.1")))
48+
assert.True(t, v.Satisfies(MustConstraint("<1.23.1+k0s.2")))
49+
assert.False(t, v.Satisfies(MustConstraint(">=1.23.1+k0s.2")))
50+
assert.False(t, v.Satisfies(MustConstraint(">=1.23.2")))
51+
assert.False(t, v.Satisfies(MustConstraint(">1.23.1+k0s.1")))
52+
assert.False(t, v.Satisfies(MustConstraint("<1.23.1+k0s.1")))
53+
}
54+
4155
func TestURLs(t *testing.T) {
4256
a, err := NewVersion("1.23.3+k0s.1")
4357
assert.NoError(t, err)

0 commit comments

Comments
 (0)