diff --git a/tests/robustness/main_test.go b/tests/robustness/main_test.go index 6863cfbb17b..5cda05cfe48 100644 --- a/tests/robustness/main_test.go +++ b/tests/robustness/main_test.go @@ -16,12 +16,12 @@ package robustness import ( "context" + "sync" "testing" "time" "go.uber.org/zap" "go.uber.org/zap/zaptest" - "golang.org/x/sync/errgroup" "go.etcd.io/etcd/tests/v3/framework" "go.etcd.io/etcd/tests/v3/framework/e2e" @@ -108,7 +108,7 @@ func testRobustness(ctx context.Context, t *testing.T, lg *zap.Logger, s testSce func (s testScenario) run(ctx context.Context, t *testing.T, lg *zap.Logger, clus *e2e.EtcdProcessCluster) (reports []report.ClientReport) { ctx, cancel := context.WithCancel(ctx) defer cancel() - g := errgroup.Group{} + wg := sync.WaitGroup{} var operationReport, watchReport []report.ClientReport failpointInjected := make(chan failpoint.InjectionReport, 1) @@ -116,7 +116,9 @@ func (s testScenario) run(ctx context.Context, t *testing.T, lg *zap.Logger, clu // see https://github.com/golang/go/blob/master/src/time/time.go#L17 baseTime := time.Now() ids := identity.NewIDProvider() - g.Go(func() error { + wg.Add(3) + go func() { + defer wg.Done() defer close(failpointInjected) // Give some time for traffic to reach qps target before injecting failpoint. time.Sleep(time.Second) @@ -130,22 +132,21 @@ func (s testScenario) run(ctx context.Context, t *testing.T, lg *zap.Logger, clu if fr != nil { failpointInjected <- *fr } - return nil - }) + }() maxRevisionChan := make(chan int64, 1) - g.Go(func() error { + go func() { + defer wg.Done() defer close(maxRevisionChan) operationReport = traffic.SimulateTraffic(ctx, t, lg, clus, s.profile, s.traffic, failpointInjected, baseTime, ids) maxRevision := operationsMaxRevision(operationReport) maxRevisionChan <- maxRevision lg.Info("Finished simulating traffic", zap.Int64("max-revision", maxRevision)) - return nil - }) - g.Go(func() error { + }() + go func() { + defer wg.Done() watchReport = collectClusterWatchEvents(ctx, t, clus, maxRevisionChan, s.watch, baseTime, ids) - return nil - }) - g.Wait() + }() + wg.Wait() return append(operationReport, watchReport...) } diff --git a/tests/robustness/traffic/kubernetes.go b/tests/robustness/traffic/kubernetes.go index 2f065d84e9a..2744cc8f94c 100644 --- a/tests/robustness/traffic/kubernetes.go +++ b/tests/robustness/traffic/kubernetes.go @@ -21,7 +21,6 @@ import ( "math/rand" "sync" - "golang.org/x/sync/errgroup" "golang.org/x/time/rate" "go.etcd.io/etcd/api/v3/mvccpb" @@ -62,16 +61,18 @@ func (t kubernetesTraffic) Run(ctx context.Context, c *RecordingClient, limiter kc := &kubernetesClient{client: c} s := newStorage() keyPrefix := "/registry/" + t.resource + "/" - g := errgroup.Group{} + wg := sync.WaitGroup{} readLimit := t.averageKeyCount - g.Go(func() error { + wg.Add(2) + go func() { + defer wg.Done() for { select { case <-ctx.Done(): - return ctx.Err() + return case <-finish: - return nil + return default: } rev, err := t.Read(ctx, kc, s, limiter, keyPrefix, readLimit) @@ -80,15 +81,16 @@ func (t kubernetesTraffic) Run(ctx context.Context, c *RecordingClient, limiter } t.Watch(ctx, kc, s, limiter, keyPrefix, rev+1) } - }) - g.Go(func() error { + }() + go func() { + defer wg.Done() lastWriteFailed := false for { select { case <-ctx.Done(): - return ctx.Err() + return case <-finish: - return nil + return default: } // Avoid multiple failed writes in a row @@ -104,8 +106,9 @@ func (t kubernetesTraffic) Run(ctx context.Context, c *RecordingClient, limiter continue } } - }) - g.Wait() + }() + + wg.Wait() } func (t kubernetesTraffic) Read(ctx context.Context, kc *kubernetesClient, s *storage, limiter *rate.Limiter, keyPrefix string, limit int) (rev int64, err error) { diff --git a/tests/robustness/traffic/limiter_test.go b/tests/robustness/traffic/limiter_test.go index ef3ead7444d..010ac909fc4 100644 --- a/tests/robustness/traffic/limiter_test.go +++ b/tests/robustness/traffic/limiter_test.go @@ -15,40 +15,42 @@ package traffic import ( + "sync" "sync/atomic" "testing" "github.com/stretchr/testify/assert" - "golang.org/x/sync/errgroup" ) func TestLimiter(t *testing.T) { limiter := NewConcurrencyLimiter(3) counter := &atomic.Int64{} - g := errgroup.Group{} + wg := sync.WaitGroup{} for i := 0; i < 10; i++ { - g.Go(func() error { + wg.Add(1) + go func() { + defer wg.Done() if limiter.Take() { counter.Add(1) } - return nil - }) + }() } - g.Wait() + wg.Wait() assert.Equal(t, 3, int(counter.Load())) assert.False(t, limiter.Take()) limiter.Return() counter.Store(0) for i := 0; i < 10; i++ { - g.Go(func() error { + wg.Add(1) + go func() { + defer wg.Done() if limiter.Take() { counter.Add(1) } - return nil - }) + }() } - g.Wait() + wg.Wait() assert.Equal(t, 1, int(counter.Load())) assert.False(t, limiter.Take())