Skip to content

Commit

Permalink
feat: Add T.Context() (#77)
Browse files Browse the repository at this point in the history
* feat: Add T.Context()

*Context*:
Go 1.24 adds a `Context()` method to `testing.TB` that returns
a context bound to the current test's lifetime.
The context becomes invalid _before_ `T.Cleanup` functions are run.

*Details*:
This change adds a similar `Context()` method to `rapid.T`,
except this context is only valid for the duration of one iteration
of a rapid check.

*Implementation notes*:
This changes `newT` to return a `cancel` function
instead of just adding a `cancel` method to `T`
to ensure that all callers of `newT` remember to call `cancel`.

This also uses IIFEs (immediately-invoked function expressions)
in a couple places to rely on well-timed `defer` calls for cleanup
instead of manually calling `cancel`.

*Future work*:
The logic added in this commit will make it relatively straightforward
to add a `Cleanup` method (#62).

* lazy init, cleanup in checkOnce, maybeValue, example

Per GitHub comment, delete the `cancel` return value from newT,
instead add a single `cleanup` method.

The method is called for cleanup in three places:

- checkOnce: this is per property
- maybeValue: this is per Custom generator function call
- example: this is per Example call

Context is now initialized lazily:
if there isn't one, it is created.

* fix: data race in t.ctx access
  • Loading branch information
abhinav authored Feb 20, 2025
1 parent ecc839f commit c3c5e3c
Show file tree
Hide file tree
Showing 8 changed files with 170 additions and 3 deletions.
1 change: 1 addition & 0 deletions combinators.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ func (g *customGen[V]) value(t *T) V {

func (g *customGen[V]) maybeValue(t *T) (V, bool) {
t = newT(t.tb, t.s, flags.debug, nil)
defer t.cleanup()

defer func() {
if r := recover(); r != nil {
Expand Down
35 changes: 35 additions & 0 deletions combinators_external_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
package rapid_test

import (
"context"
"errors"
"fmt"
"strconv"
"testing"
Expand Down Expand Up @@ -60,6 +62,39 @@ func TestCustom(t *testing.T) {
}
}

func TestCustomContext(t *testing.T) {
t.Parallel()

type key struct{}

gen := Custom(func(t *T) context.Context {
ctx := t.Context()

// Inside the custom generator, the context must be valid.
if err := ctx.Err(); err != nil {
t.Fatalf("context must be valid: %v", err)
}

x := Int().Draw(t, "x")
return context.WithValue(ctx, key{}, x)
})

Check(t, func(t *T) {
ctx := gen.Draw(t, "value")

if _, ok := ctx.Value(key{}).(int); !ok {
t.Fatalf("context must contain an int")
}

// Outside the custom generator,
// the context from inside the generator
// must no longer be valid.
if err := ctx.Err(); err == nil || !errors.Is(err, context.Canceled) {
t.Fatalf("context must be canceled: %v", err)
}
})
}

func TestFilter(t *testing.T) {
t.Parallel()

Expand Down
58 changes: 57 additions & 1 deletion engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package rapid

import (
"bytes"
"context"
"encoding/binary"
"flag"
"fmt"
Expand Down Expand Up @@ -365,6 +366,7 @@ func checkOnce(t *T, prop func(*T)) (err *testError) {
}
defer func() { err = panicToError(recover(), 3) }()

defer t.cleanup()
prop(t)
t.failOnError()

Expand Down Expand Up @@ -500,7 +502,11 @@ func (nilTB) Failed() bool { panic("call to TB.Failed() outside a test"
// If concurrency is unavoidable, methods on *T, such as [*testing.T.Helper] and [*T.Errorf],
// are safe for concurrent calls, but *Generator.Draw from a given *T is not.
type T struct {
tb // unnamed to force re-export of (*T).Helper()
tb // unnamed to force re-export of (*T).Helper()

ctx context.Context
cancelCtx context.CancelFunc

tbLog bool
rawLog *log.Logger
s bitStream
Expand Down Expand Up @@ -539,6 +545,56 @@ func (t *T) shouldLog() bool {
return t.rawLog != nil || t.tbLog
}

// Context returns a context.Context associated with the test.
// It is valid only for the duration of the rapid check.
func (t *T) Context() context.Context {
// Fast path: no need to lock if the context is already set.
t.mu.RLock()
ctx := t.ctx
t.mu.RUnlock()
if ctx != nil {
return ctx
}

// Slow path: lock and check again, create new context if needed.
t.mu.Lock()
defer t.mu.Unlock()

if t.ctx != nil {
// Another goroutine set the context
// while we were waiting for the lock.
return t.ctx
}

// Use the testing.TB's context as the starting point if available,
// and the Background context if not.
//
// T.Context was added in Go 1.24.
if tctx, ok := t.tb.(interface{ Context() context.Context }); ok {
ctx = tctx.Context()
} else {
ctx = context.Background()
}

ctx, cancel := context.WithCancel(ctx)
t.ctx = ctx
t.cancelCtx = cancel
return ctx
}

// cleanup runs any cleanup tasks associated with the property check.
// It is safe to call multiple times.
func (t *T) cleanup() {
t.mu.Lock()
defer t.mu.Unlock()

if t.cancelCtx != nil {
t.cancelCtx()
t.cancelCtx = nil
t.ctx = nil
}
}

func (t *T) Logf(format string, args ...any) {
if t.rawLog != nil {
t.rawLog.Printf(format, args...)
Expand Down
26 changes: 26 additions & 0 deletions engine_fuzz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package rapid_test

import (
"context"
"testing"

. "pgregory.net/rapid"
Expand Down Expand Up @@ -66,3 +67,28 @@ func FuzzInt(f *testing.F) { f.Fuzz(MakeFuzz(checkInt)) }
func FuzzSlice(f *testing.F) { f.Fuzz(MakeFuzz(checkSlice)) }
func FuzzString(f *testing.F) { f.Fuzz(MakeFuzz(checkString)) }
func FuzzStuckStateMachine(f *testing.F) { f.Fuzz(MakeFuzz(checkStuckStateMachine)) }

func FuzzContext(f *testing.F) {
type key struct{}

var ctx context.Context
f.Fuzz(MakeFuzz(func(t *T) {
// Assign to outer variable
// so we can check it after the fuzzing.
ctx = context.WithValue(t.Context(), key{}, "value")
if err := ctx.Err(); err != nil {
t.Fatalf("context must be valid: %v", err)
}
}))

// ctx is set only if the fuzzing function was called.
if ctx != nil {
if err := ctx.Err(); err == nil {
f.Fatalf("context must be canceled")
}

if want, got := "value", ctx.Value(key{}); want != got {
f.Fatalf("context must have value %q, got %q", want, got)
}
}
}
22 changes: 22 additions & 0 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
package rapid

import (
"context"
"errors"
"strings"
"testing"
)
Expand Down Expand Up @@ -93,3 +95,23 @@ func BenchmarkCheckOverhead(b *testing.B) {
checkTB(b, deadline, f)
}
}

func TestCheckContext(t *testing.T) {
type key struct{}

var ctx context.Context
Check(t, func(t *T) {
ctx = context.WithValue(t.Context(), key{}, Int().Draw(t, "x"))
if err := ctx.Err(); err != nil {
t.Fatalf("unexpected error: %v", err)
}
})

if err := ctx.Err(); err == nil || !errors.Is(err, context.Canceled) {
t.Fatalf("expected context to be canceled, got: %v", err)
}

if _, ok := ctx.Value(key{}).(int); !ok {
t.Fatalf("context must have a value")
}
}
2 changes: 2 additions & 0 deletions generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ func (g *Generator[V]) AsAny() *Generator[any] {
}

func example[V any](g *Generator[V], t *T) (V, int, error) {
defer t.cleanup()

for i := 1; ; i++ {
r, err := recoverValue(g, t)
if err == nil {
Expand Down
28 changes: 27 additions & 1 deletion generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

package rapid

import "testing"
import (
"context"
"errors"
"testing"
)

type trivialGenImpl struct{}

Expand Down Expand Up @@ -41,3 +45,25 @@ func TestExampleHelper(t *testing.T) {

g.Example(0)
}

func TestExampleContext(t *testing.T) {
type key struct{}

g := Custom(func(t *T) context.Context {
ctx := context.WithValue(t.Context(), key{}, Int().Draw(t, "x"))
if err := ctx.Err(); err != nil {
t.Fatalf("unexpected error: %v", err)
}
return ctx
})

ctx := g.Example(0)

if err := ctx.Err(); err == nil || !errors.Is(err, context.Canceled) {
t.Fatalf("expected context to be canceled, got: %v", err)
}

if _, ok := ctx.Value(key{}).(int); !ok {
t.Fatalf("context must have a value")
}
}
1 change: 0 additions & 1 deletion vis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ func TestDataVis(t *testing.T) {
f, err := os.Create("vis-test.html")
if err != nil {
t.Fatalf("failed to create vis html file: %v", err)

}
defer func() { _ = f.Close() }()

Expand Down

0 comments on commit c3c5e3c

Please sign in to comment.