Skip to content

Commit

Permalink
client: implement maxAttempts for retryPolicy (#7229)
Browse files Browse the repository at this point in the history
  • Loading branch information
imoore76 committed May 24, 2024
1 parent f7d3d3e commit 03da31a
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 13 deletions.
4 changes: 2 additions & 2 deletions clientconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func NewClient(target string, opts ...DialOption) (conn *ClientConn, err error)
}

if cc.dopts.defaultServiceConfigRawJSON != nil {
scpr := parseServiceConfig(*cc.dopts.defaultServiceConfigRawJSON)
scpr := parseServiceConfig(*cc.dopts.defaultServiceConfigRawJSON, cc.dopts.maxCallAttempts)
if scpr.Err != nil {
return nil, fmt.Errorf("%s: %v", invalidDefaultServiceConfigErrPrefix, scpr.Err)
}
Expand Down Expand Up @@ -693,7 +693,7 @@ func (cc *ClientConn) waitForResolvedAddrs(ctx context.Context) error {
var emptyServiceConfig *ServiceConfig

func init() {
cfg := parseServiceConfig("{}")
cfg := parseServiceConfig("{}", defaultMaxCallAttempts)
if cfg.Err != nil {
panic(fmt.Sprintf("impossible error parsing empty service config: %v", cfg.Err))
}
Expand Down
24 changes: 24 additions & 0 deletions dialoptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ import (
"google.golang.org/grpc/stats"
)

const (
// https://github.com/grpc/proposal/blob/master/A6-client-retries.md#limits-on-retries-and-hedges
defaultMaxCallAttempts = 5
)

func init() {
internal.AddGlobalDialOptions = func(opt ...DialOption) {
globalDialOptions = append(globalDialOptions, opt...)
Expand Down Expand Up @@ -89,6 +94,7 @@ type dialOptions struct {
idleTimeout time.Duration
recvBufferPool SharedBufferPool
defaultScheme string
maxCallAttempts int
}

// DialOption configures how we set up the connection.
Expand Down Expand Up @@ -677,6 +683,7 @@ func defaultDialOptions() dialOptions {
idleTimeout: 30 * time.Minute,
recvBufferPool: nopBufferPool{},
defaultScheme: "dns",
maxCallAttempts: defaultMaxCallAttempts,
}
}

Expand Down Expand Up @@ -734,6 +741,23 @@ func WithIdleTimeout(d time.Duration) DialOption {
})
}

// WithMaxCallAttempts returns a DialOption that configures the maximum number
// of attempts per call (including retries and hedging) using the channel.
// Service owners may specify a higher value for these parameters, but higher
// values will be treated as equal to the maximum value by the client
// implementation. This mitigates security concerns related to the service
// config being transferred to the client via DNS.
//
// A value of 5 will be used if this dial option is not set or n < 2.
func WithMaxCallAttempts(n int) DialOption {
return newFuncDialOption(func(o *dialOptions) {
if n < 2 {
n = defaultMaxCallAttempts
}
o.maxCallAttempts = n
})
}

// WithRecvBufferPool returns a DialOption that configures the ClientConn
// to use the provided shared buffer pool for parsing incoming messages. Depending
// on the application's workload, this could result in reduced memory allocation.
Expand Down
2 changes: 1 addition & 1 deletion resolver_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func (ccr *ccResolverWrapper) NewAddress(addrs []resolver.Address) {
// ParseServiceConfig is called by resolver implementations to parse a JSON
// representation of the service config.
func (ccr *ccResolverWrapper) ParseServiceConfig(scJSON string) *serviceconfig.ParseResult {
return parseServiceConfig(scJSON)
return parseServiceConfig(scJSON, ccr.cc.dopts.maxCallAttempts)
}

// addChannelzTraceEvent adds a channelz trace event containing the new
Expand Down
19 changes: 10 additions & 9 deletions service_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,11 @@ type jsonSC struct {
}

func init() {
internal.ParseServiceConfig = parseServiceConfig
internal.ParseServiceConfig = func(js string) *serviceconfig.ParseResult {
return parseServiceConfig(js, defaultMaxCallAttempts)
}
}
func parseServiceConfig(js string) *serviceconfig.ParseResult {
func parseServiceConfig(js string, maxAttempts int) *serviceconfig.ParseResult {
if len(js) == 0 {
return &serviceconfig.ParseResult{Err: fmt.Errorf("no JSON service config provided")}
}
Expand Down Expand Up @@ -219,7 +221,7 @@ func parseServiceConfig(js string) *serviceconfig.ParseResult {
WaitForReady: m.WaitForReady,
Timeout: (*time.Duration)(m.Timeout),
}
if mc.RetryPolicy, err = convertRetryPolicy(m.RetryPolicy); err != nil {
if mc.RetryPolicy, err = convertRetryPolicy(m.RetryPolicy, maxAttempts); err != nil {
logger.Warningf("grpc: unmarshalling service config %s: %v", js, err)
return &serviceconfig.ParseResult{Err: err}
}
Expand Down Expand Up @@ -265,7 +267,7 @@ func parseServiceConfig(js string) *serviceconfig.ParseResult {
return &serviceconfig.ParseResult{Config: &sc}
}

func convertRetryPolicy(jrp *jsonRetryPolicy) (p *internalserviceconfig.RetryPolicy, err error) {
func convertRetryPolicy(jrp *jsonRetryPolicy, maxAttempts int) (p *internalserviceconfig.RetryPolicy, err error) {
if jrp == nil {
return nil, nil
}
Expand All @@ -279,17 +281,16 @@ func convertRetryPolicy(jrp *jsonRetryPolicy) (p *internalserviceconfig.RetryPol
return nil, nil
}

if jrp.MaxAttempts < maxAttempts {
maxAttempts = jrp.MaxAttempts
}
rp := &internalserviceconfig.RetryPolicy{
MaxAttempts: jrp.MaxAttempts,
MaxAttempts: maxAttempts,
InitialBackoff: time.Duration(jrp.InitialBackoff),
MaxBackoff: time.Duration(jrp.MaxBackoff),
BackoffMultiplier: jrp.BackoffMultiplier,
RetryableStatusCodes: make(map[codes.Code]bool),
}
if rp.MaxAttempts > 5 {
// TODO(retry): Make the max maxAttempts configurable.
rp.MaxAttempts = 5
}
for _, code := range jrp.RetryableStatusCodes {
rp.RetryableStatusCodes[code] = true
}
Expand Down
2 changes: 1 addition & 1 deletion service_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func runParseTests(t *testing.T, testCases []parseTestCase) {
t.Helper()
for i, c := range testCases {
t.Run(fmt.Sprint(i), func(t *testing.T) {
scpr := parseServiceConfig(c.scjs)
scpr := parseServiceConfig(c.scjs, defaultMaxCallAttempts)
var sc *ServiceConfig
sc, _ = scpr.Config.(*ServiceConfig)
if !c.wantErr {
Expand Down
91 changes: 91 additions & 0 deletions test/retry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,97 @@ func (s) TestRetryStreaming(t *testing.T) {
}
}

func (s) TestMaxCallAttempts(t *testing.T) {
testCases := []struct {
serviceMaxAttempts int
clientMaxAttempts int
expectedAttempts int
}{
{serviceMaxAttempts: 9, clientMaxAttempts: 4, expectedAttempts: 4},
{serviceMaxAttempts: 9, clientMaxAttempts: 7, expectedAttempts: 7},
{serviceMaxAttempts: 3, clientMaxAttempts: 10, expectedAttempts: 3},
{serviceMaxAttempts: 8, clientMaxAttempts: -1, expectedAttempts: 5}, // 5 is default max
{serviceMaxAttempts: 3, clientMaxAttempts: 0, expectedAttempts: 3},
}

for _, tc := range testCases {
clientOpts := []grpc.DialOption{
grpc.WithMaxCallAttempts(tc.clientMaxAttempts),
grpc.WithDefaultServiceConfig(fmt.Sprintf(`{
"methodConfig": [{
"name": [{"service": "grpc.testing.TestService"}],
"waitForReady": true,
"retryPolicy": {
"MaxAttempts": %d,
"InitialBackoff": ".01s",
"MaxBackoff": ".01s",
"BackoffMultiplier": 1.0,
"RetryableStatusCodes": [ "UNAVAILABLE" ]
}
}]}`, tc.serviceMaxAttempts),
),
}

streamCallCount := 0
unaryCallCount := 0

ss := &stubserver.StubServer{
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
streamCallCount++
return status.New(codes.Unavailable, "this is a test error").Err()
},
EmptyCallF: func(context.Context, *testpb.Empty) (r *testpb.Empty, err error) {
unaryCallCount++
return nil, status.New(codes.Unavailable, "this is a test error").Err()
},
}

func() {

if err := ss.Start([]grpc.ServerOption{}, clientOpts...); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

for {
if ctx.Err() != nil {
t.Fatalf("Timed out waiting for service config update")
}
if ss.CC.GetMethodConfig("/grpc.testing.TestService/FullDuplexCall").WaitForReady != nil {
break
}
time.Sleep(time.Millisecond)
}

// Test streaming RPC
stream, err := ss.Client.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("Error while creating stream: %v", err)
}
if got, err := stream.Recv(); err == nil {
t.Fatalf("client: Recv() = %s, %v; want <nil>, error", got, err)
} else if status.Code(err) != codes.Unavailable {
t.Fatalf("client: Recv() = _, %v; want _, Unavailable", err)
}
if streamCallCount != tc.expectedAttempts {
t.Fatalf("stream expectedAttempts = %v; want %v", streamCallCount, tc.expectedAttempts)
}

// Test unary RPC
if ugot, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err == nil {
t.Fatalf("client: EmptyCall() = %s, %v; want <nil>, error", ugot, err)
} else if status.Code(err) != codes.Unavailable {
t.Fatalf("client: EmptyCall() = _, %v; want _, Unavailable", err)
}
if unaryCallCount != tc.expectedAttempts {
t.Fatalf("unary expectedAttempts = %v; want %v", unaryCallCount, tc.expectedAttempts)
}
}()
}
}

type retryStatsHandler struct {
mu sync.Mutex
s []stats.RPCStats
Expand Down

0 comments on commit 03da31a

Please sign in to comment.