Skip to content

Commit cdee119

Browse files
authored
Merge pull request #1124 from MakMukhi/rst_stream_issue
Upon observing timeout on rpc context, the client should send a RST_S…
2 parents 0713829 + 5535384 commit cdee119

File tree

3 files changed

+9
-81
lines changed

3 files changed

+9
-81
lines changed

test/end2end_test.go

Lines changed: 0 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -2447,85 +2447,6 @@ func testFailedServerStreaming(t *testing.T, e env) {
24472447
}
24482448
}
24492449

2450-
// checkTimeoutErrorServer is a gRPC server checks context timeout error in FullDuplexCall().
2451-
// It is only used in TestStreamingRPCTimeoutServerError.
2452-
type checkTimeoutErrorServer struct {
2453-
t *testing.T
2454-
done chan struct{}
2455-
testpb.TestServiceServer
2456-
}
2457-
2458-
func (s *checkTimeoutErrorServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
2459-
defer close(s.done)
2460-
for {
2461-
_, err := stream.Recv()
2462-
if err != nil {
2463-
if grpc.Code(err) != codes.DeadlineExceeded {
2464-
s.t.Errorf("stream.Recv() = _, %v, want error code %s", err, codes.DeadlineExceeded)
2465-
}
2466-
return err
2467-
}
2468-
if err := stream.Send(&testpb.StreamingOutputCallResponse{
2469-
Payload: &testpb.Payload{
2470-
Body: []byte{'0'},
2471-
},
2472-
}); err != nil {
2473-
if grpc.Code(err) != codes.DeadlineExceeded {
2474-
s.t.Errorf("stream.Send(_) = %v, want error code %s", err, codes.DeadlineExceeded)
2475-
}
2476-
return err
2477-
}
2478-
}
2479-
}
2480-
2481-
func TestStreamingRPCTimeoutServerError(t *testing.T) {
2482-
defer leakCheck(t)()
2483-
for _, e := range listTestEnv() {
2484-
testStreamingRPCTimeoutServerError(t, e)
2485-
}
2486-
}
2487-
2488-
// testStreamingRPCTimeoutServerError tests the server side behavior.
2489-
// When context timeout happens on client side, server should get deadline exceeded error.
2490-
func testStreamingRPCTimeoutServerError(t *testing.T, e env) {
2491-
te := newTest(t, e)
2492-
serverDone := make(chan struct{})
2493-
te.startServer(&checkTimeoutErrorServer{t: t, done: serverDone})
2494-
defer te.tearDown()
2495-
2496-
cc := te.clientConn()
2497-
tc := testpb.NewTestServiceClient(cc)
2498-
2499-
req := &testpb.StreamingOutputCallRequest{}
2500-
for duration := 50 * time.Millisecond; ; duration *= 2 {
2501-
ctx, _ := context.WithTimeout(context.Background(), duration)
2502-
stream, err := tc.FullDuplexCall(ctx, grpc.FailFast(false))
2503-
if grpc.Code(err) == codes.DeadlineExceeded {
2504-
// Redo test with double timeout.
2505-
continue
2506-
}
2507-
if err != nil {
2508-
t.Errorf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
2509-
return
2510-
}
2511-
for {
2512-
err := stream.Send(req)
2513-
if err != nil {
2514-
break
2515-
}
2516-
_, err = stream.Recv()
2517-
if err != nil {
2518-
break
2519-
}
2520-
}
2521-
2522-
// Wait for context timeout on server before closing connection
2523-
// to make sure the server will get timeout error.
2524-
<-serverDone
2525-
break
2526-
}
2527-
}
2528-
25292450
// concurrentSendServer is a TestServiceServer whose
25302451
// StreamingOutputCall makes ten serial Send calls, sending payloads
25312452
// "0".."9", inclusive. TestServerStreamingConcurrent verifies they

transport/http2_client.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -533,17 +533,19 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
533533
// after having acquired the writableChan to send RST_STREAM out (look at
534534
// the controller() routine).
535535
var rstStream bool
536+
var rstError http2.ErrCode
536537
defer func() {
537538
// In case, the client doesn't have to send RST_STREAM to server
538539
// we can safely add back to streamsQuota pool now.
539540
if !rstStream {
540541
t.streamsQuota.add(1)
541542
return
542543
}
543-
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeCancel})
544+
t.controlBuf.put(&resetStream{s.id, rstError})
544545
}()
545546
s.mu.Lock()
546547
rstStream = s.rstStream
548+
rstError = s.rstError
547549
if q := s.fc.resetPendingData(); q > 0 {
548550
if n := t.fc.onRead(q); n > 0 {
549551
t.controlBuf.put(&windowUpdate{0, n})
@@ -559,8 +561,9 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
559561
}
560562
s.state = streamDone
561563
s.mu.Unlock()
562-
if se, ok := err.(StreamError); ok && se.Code != codes.DeadlineExceeded {
564+
if _, ok := err.(StreamError); ok {
563565
rstStream = true
566+
rstError = http2.ErrCodeCancel
564567
}
565568
}
566569

@@ -807,6 +810,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
807810
s.statusCode = codes.Internal
808811
s.statusDesc = err.Error()
809812
s.rstStream = true
813+
s.rstError = http2.ErrCodeFlowControl
810814
close(s.done)
811815
s.mu.Unlock()
812816
s.write(recvMsg{err: io.EOF})

transport/transport.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ import (
4545
"sync"
4646

4747
"golang.org/x/net/context"
48+
"golang.org/x/net/http2"
4849
"google.golang.org/grpc/codes"
4950
"google.golang.org/grpc/credentials"
5051
"google.golang.org/grpc/keepalive"
@@ -217,6 +218,8 @@ type Stream struct {
217218
// rstStream indicates whether a RST_STREAM frame needs to be sent
218219
// to the server to signify that this stream is closing.
219220
rstStream bool
221+
// rstError is the error that needs to be sent along with the RST_STREAM frame.
222+
rstError http2.ErrCode
220223
}
221224

222225
// RecvCompress returns the compression algorithm applied to the inbound

0 commit comments

Comments
 (0)