Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions frontend/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ type API interface {
//
// If the absolute difference between the variables i1 and i2 is known, then
// it is more efficient to use the bounded methods in package
// [github.com/consensys/gnark/std/math/bits].
// [https://github.com/Consensys/gnark/blob/master/std/math/cmp].
Cmp(i1, i2 Variable) Variable

// ---------------------------------------------------------------------------------------------
Expand All @@ -121,7 +121,7 @@ type API interface {
//
// If the absolute difference between the variables b and bound is known, then
// it is more efficient to use the bounded methods in package
// [github.com/consensys/gnark/std/math/bits].
// [https://github.com/Consensys/gnark/blob/master/std/math/cmp].
AssertIsLessOrEqual(v Variable, bound Variable)

// Println behaves like fmt.Println but accepts cd.Variable as parameter
Expand Down
11 changes: 10 additions & 1 deletion std/math/cmp/bounded.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package cmp

import (
"fmt"
"math/big"

"github.com/consensys/gnark/constraint/solver"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/bits"
"math/big"
)

func init() {
Expand Down Expand Up @@ -151,6 +152,10 @@ func (bc BoundedComparator) AssertIsLess(a, b frontend.Variable) {
}

// IsLess returns 1 if a < b, and returns 0 if a >= b.
// When |a - b| >= 2^absDiffUpp.BitLen(), a panic is occurred,
// then the method has no return value, and a proof can not be generated.
// It is recommended to use the IsLess method to get a valid return value
// in https://github.com/Consensys/gnark/blob/master/std/math/cmp/generic.go
func (bc BoundedComparator) IsLess(a, b frontend.Variable) frontend.Variable {
res, err := bc.api.Compiler().NewHint(isLessOutputHint, 1, a, b)
if err != nil {
Expand All @@ -164,6 +169,10 @@ func (bc BoundedComparator) IsLess(a, b frontend.Variable) frontend.Variable {
}

// IsLessEq returns 1 if a <= b, and returns 0 if a > b.
// When |a - b| > 2^absDiffUpp.BitLen(), a panic is occurred,
// then the method has no return value, and a proof can not be generated.
// It is recommended to use the IsLessOrEqual method to get a valid return value
// in https://github.com/Consensys/gnark/blob/master/std/math/cmp/generic.go
func (bc BoundedComparator) IsLessEq(a, b frontend.Variable) frontend.Variable {
// a <= b <==> a < b + 1
return bc.IsLess(a, bc.api.Add(b, 1))
Expand Down
76 changes: 74 additions & 2 deletions std/math/cmp/bounded_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package cmp_test

import (
"fmt"
"math/big"
"testing"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/cmp"
"github.com/consensys/gnark/test"
"math/big"
"testing"
)

func TestAssertIsLessEq(t *testing.T) {
Expand Down Expand Up @@ -143,3 +146,72 @@ func (c *minCircuit) Define(api frontend.API) error {

return nil
}

type boundedComparatorCircuit struct {
A frontend.Variable

WantIsLess int
WantIsLessEq int
Bound int
}

func (c *boundedComparatorCircuit) Define(api frontend.API) error {
comparator := cmp.NewBoundedComparator(api, big.NewInt(int64(c.Bound)), true)
if c.WantIsLess == 1 {
comparator.AssertIsLess(c.A, c.Bound)
}
if c.WantIsLessEq == 1 {
comparator.AssertIsLessEq(c.A, c.Bound)
}

api.AssertIsEqual(c.WantIsLess, comparator.IsLess(c.A, c.Bound))
api.AssertIsEqual(c.WantIsLessEq, comparator.IsLessEq(c.A, c.Bound))

return nil
}

type boundedComparatorTestCase struct {
A int

WantIsLess int
WantIsLessEq int
Bound int

expectedSuccess bool
}

func TestBoundedComparator(t *testing.T) {
assert := test.NewAssert(t)

var testCases []boundedComparatorTestCase
for bound := 2; bound <= 15; bound++ {
c := 1 << (big.NewInt(int64(bound)).BitLen())
for i := 0; i <= bound+5; i++ {
testCase := boundedComparatorTestCase{
A: i, Bound: bound, WantIsLess: 1, WantIsLessEq: 1, expectedSuccess: true}
if i >= bound {
testCase.WantIsLess = 0
if i > bound {
testCase.WantIsLessEq = 0
}
}
if i-bound >= c {
testCase.expectedSuccess = false
}
testCases = append(testCases, testCase)
}
}

for _, tc := range testCases {
assert.Run(func(assert *test.Assert) {
circuit := &boundedComparatorCircuit{Bound: tc.Bound, WantIsLess: tc.WantIsLess, WantIsLessEq: tc.WantIsLessEq}
assignment := &boundedComparatorCircuit{A: tc.A}
err := test.IsSolved(circuit, assignment, ecc.BN254.ScalarField())
if tc.expectedSuccess {
assert.NoError(err)
} else {
assert.Error(err)
}
}, fmt.Sprintf("bound=%d a=%d", tc.Bound, tc.A))
}
}