From 48da26a133a910879478339abceaaa39674896f8 Mon Sep 17 00:00:00 2001 From: "qiheng.zhou" Date: Mon, 4 Nov 2024 09:55:17 +0800 Subject: [PATCH] fix(gRPC): check interface in ProtocolMatch --- pkg/remote/trans/nphttp2/server_handler.go | 22 ++++++----- .../trans/nphttp2/server_handler_test.go | 38 +++++++++++++++++++ 2 files changed, 51 insertions(+), 9 deletions(-) diff --git a/pkg/remote/trans/nphttp2/server_handler.go b/pkg/remote/trans/nphttp2/server_handler.go index f249f84242..2de1006bfe 100644 --- a/pkg/remote/trans/nphttp2/server_handler.go +++ b/pkg/remote/trans/nphttp2/server_handler.go @@ -83,16 +83,20 @@ var prefaceReadAtMost = func() int { return grpcTransport.ClientPrefaceLen }() -func (t *svrTransHandler) ProtocolMatch(ctx context.Context, conn net.Conn) (err error) { +func (t *svrTransHandler) ProtocolMatch(ctx context.Context, conn net.Conn) error { // Check the validity of client preface. - npReader := conn.(interface{ Reader() netpoll.Reader }).Reader() - // read at most avoid block - preface, err := npReader.Peek(prefaceReadAtMost) - if err != nil { - return err - } - if bytes.Equal(preface[:prefaceReadAtMost], grpcTransport.ClientPreface[:prefaceReadAtMost]) { - return nil + // FIXME: should not rely on netpoll.Reader + if withReader, ok := conn.(interface{ Reader() netpoll.Reader }); ok { + if npReader := withReader.Reader(); npReader != nil { + // read at most avoid block + preface, err := npReader.Peek(prefaceReadAtMost) + if err != nil { + return err + } + if len(preface) >= prefaceReadAtMost && bytes.Equal(preface[:prefaceReadAtMost], grpcTransport.ClientPreface[:prefaceReadAtMost]) { + return nil + } + } } return errors.New("error protocol not match") } diff --git a/pkg/remote/trans/nphttp2/server_handler_test.go b/pkg/remote/trans/nphttp2/server_handler_test.go index a1a0d6f1dc..b3f2c537ea 100644 --- a/pkg/remote/trans/nphttp2/server_handler_test.go +++ b/pkg/remote/trans/nphttp2/server_handler_test.go @@ -23,11 +23,17 @@ import ( "testing" "time" + "github.com/cloudwego/kitex/internal/mocks" + + "github.com/cloudwego/kitex/internal/mocks/netpoll" + "github.com/golang/mock/gomock" + mocksremote "github.com/cloudwego/kitex/internal/mocks/remote" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc" + grpcTransport "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streaming" @@ -303,3 +309,35 @@ func Test_invokeStreamUnaryHandler(t *testing.T) { test.Assert(t, sendCalled) }) } + +func TestSvrTransHandlerProtocolMatch(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + th := &svrTransHandler{} + // netpoll reader + // 1. success + reader := netpoll.NewMockReader(ctrl) + reader.EXPECT().Peek(prefaceReadAtMost).Times(1).Return(grpcTransport.ClientPreface, nil) + conn := netpoll.NewMockConnection(ctrl) + conn.EXPECT().Reader().AnyTimes().Return(reader) + err := th.ProtocolMatch(context.Background(), conn) + test.Assert(t, err == nil, err) + // 2. failed, no reader + conn = netpoll.NewMockConnection(ctrl) + conn.EXPECT().Reader().AnyTimes().Return(nil) + err = th.ProtocolMatch(context.Background(), conn) + test.Assert(t, err != nil, err) + // 3. failed, wrong preface + failedReader := netpoll.NewMockReader(ctrl) + failedReader.EXPECT().Peek(prefaceReadAtMost).Times(1).Return([]byte{}, nil) + conn = netpoll.NewMockConnection(ctrl) + conn.EXPECT().Reader().AnyTimes().Return(failedReader) + err = th.ProtocolMatch(context.Background(), conn) + test.Assert(t, err != nil, err) + + // non-netpoll reader, all failed + rawConn := &mocks.Conn{} + err = th.ProtocolMatch(context.Background(), rawConn) + test.Assert(t, err != nil, err) +}