From 624143455f1b667ea5595eaf22d4f5eb6a1c6a10 Mon Sep 17 00:00:00 2001 From: "qiheng.zhou" Date: Mon, 10 Feb 2025 11:14:24 +0800 Subject: [PATCH] fix: check address when GetCallerAddress to prevent panic --- pkg/utils/kitexutil/kitexutil.go | 5 ++- pkg/utils/kitexutil/kitexutil_test.go | 45 ++++++++++++++------------- 2 files changed, 26 insertions(+), 24 deletions(-) diff --git a/pkg/utils/kitexutil/kitexutil.go b/pkg/utils/kitexutil/kitexutil.go index b8fba7e3a5..700cb0c6a7 100644 --- a/pkg/utils/kitexutil/kitexutil.go +++ b/pkg/utils/kitexutil/kitexutil.go @@ -63,7 +63,6 @@ func GetMethod(ctx context.Context) (string, bool) { // GetCallerHandlerMethod is used to get the method name of caller. // Only the caller is a Kitex server will have this method information by default, or you can set K_METHOD into context.Context then kitex will get it. -// Return false if failed to get the information. func GetCallerHandlerMethod(ctx context.Context) (string, bool) { defer func() { recover() }() @@ -92,7 +91,7 @@ func GetCallerAddr(ctx context.Context) (net.Addr, bool) { defer func() { recover() }() ri := rpcinfo.GetRPCInfo(ctx) - if ri == nil { + if ri == nil || ri.From() == nil || ri.From().Address() == nil { return nil, false } return ri.From().Address(), true @@ -104,7 +103,7 @@ func GetCallerIP(ctx context.Context) (string, bool) { defer func() { recover() }() ri := rpcinfo.GetRPCInfo(ctx) - if ri == nil { + if ri == nil || ri.From() == nil || ri.From().Address() == nil { return "", false } addrStr := ri.From().Address().String() diff --git a/pkg/utils/kitexutil/kitexutil_test.go b/pkg/utils/kitexutil/kitexutil_test.go index 58b9b858d0..76745dce0a 100644 --- a/pkg/utils/kitexutil/kitexutil_test.go +++ b/pkg/utils/kitexutil/kitexutil_test.go @@ -31,20 +31,20 @@ import ( ) var ( - testRi rpcinfo.RPCInfo - testCtx context.Context - panicCtx context.Context - caller = "kitexutil.from.service" - callee = "kitexutil.to.service" - idlServiceName = "MockService" - fromAddr = utils.NewNetAddr("test", "127.0.0.1:12345") - fromMethod = "from_method" - method = "method" - tp = transport.TTHeader + testRi rpcinfo.RPCInfo + testCtx context.Context + panicCtx context.Context + testCaller = "kitexutil.from.service" + testCallee = "kitexutil.to.service" + testIdlServiceName = "MockService" + testFromAddr = utils.NewNetAddr("test", "127.0.0.1:12345") + testFromMethod = "from_method" + testMethod = "testMethod" + testTp = transport.TTHeader ) func TestMain(m *testing.M) { - testRi = buildRPCInfo() + testRi = buildRPCInfo(testCaller, testFromMethod, testCallee, testMethod, testIdlServiceName, testFromAddr, testTp) testCtx = context.Background() testCtx = rpcinfo.NewCtxWithRPCInfo(testCtx, testRi) @@ -63,7 +63,7 @@ func TestGetCaller(t *testing.T) { want string want1 bool }{ - {name: "Success", args: args{testCtx}, want: caller, want1: true}, + {name: "Success", args: args{testCtx}, want: testCaller, want1: true}, {name: "Failure", args: args{context.Background()}, want: "", want1: false}, {name: "Panic recovered", args: args{panicCtx}, want: "", want1: false}, } @@ -90,7 +90,7 @@ func TestGetCallee(t *testing.T) { want string want1 bool }{ - {name: "Success", args: args{testCtx}, want: callee, want1: true}, + {name: "Success", args: args{testCtx}, want: testCallee, want1: true}, {name: "Failure", args: args{context.Background()}, want: "", want1: false}, {name: "Panic recovered", args: args{panicCtx}, want: "", want1: false}, } @@ -108,6 +108,8 @@ func TestGetCallee(t *testing.T) { } func TestGetCallerAddr(t *testing.T) { + riWithoutAddress := buildRPCInfo(testCaller, testFromMethod, testCallee, testMethod, testIdlServiceName, nil, testTp) + ctxWithoutAddress := rpcinfo.NewCtxWithRPCInfo(context.Background(), riWithoutAddress) type args struct { ctx context.Context } @@ -117,8 +119,9 @@ func TestGetCallerAddr(t *testing.T) { want net.Addr want1 bool }{ - {name: "Success", args: args{testCtx}, want: fromAddr, want1: true}, - {name: "Failure", args: args{context.Background()}, want: nil, want1: false}, + {name: "Success", args: args{testCtx}, want: testFromAddr, want1: true}, + {name: "Failure: nil rpcinfo", args: args{context.Background()}, want: nil, want1: false}, + {name: "Failure: nil address", args: args{ctxWithoutAddress}, want: nil, want1: false}, {name: "Panic recovered", args: args{panicCtx}, want: nil, want1: false}, } for _, tt := range tests { @@ -139,7 +142,7 @@ func TestGetCallerIP(t *testing.T) { test.Assert(t, ok) test.Assert(t, ip == "127.0.0.1", ip) - ri := buildRPCInfo() + ri := buildRPCInfo(testCaller, testFromMethod, testCallee, testMethod, testIdlServiceName, testFromAddr, testTp) rpcinfo.AsMutableEndpointInfo(ri.From()).SetAddress(utils.NewNetAddr("test", "127.0.0.1")) ip, ok = GetCallerIP(rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)) test.Assert(t, ok) @@ -169,7 +172,7 @@ func TestGetMethod(t *testing.T) { want string want1 bool }{ - {name: "Success", args: args{testCtx}, want: method, want1: true}, + {name: "Success", args: args{testCtx}, want: testMethod, want1: true}, {name: "Failure", args: args{context.Background()}, want: "", want1: false}, } for _, tt := range tests { @@ -195,7 +198,7 @@ func TestGetCallerHandlerMethod(t *testing.T) { want string want1 bool }{ - {name: "Success", args: args{testCtx}, want: fromMethod, want1: true}, + {name: "Success", args: args{testCtx}, want: testFromMethod, want1: true}, {name: "Failure", args: args{context.Background()}, want: "", want1: false}, {name: "Panic recovered", args: args{panicCtx}, want: "", want1: false}, } @@ -222,7 +225,7 @@ func TestGetIDLServiceName(t *testing.T) { want string want1 bool }{ - {name: "Success", args: args{testCtx}, want: idlServiceName, want1: true}, + {name: "Success", args: args{testCtx}, want: testIdlServiceName, want1: true}, {name: "Failure", args: args{context.Background()}, want: "", want1: false}, {name: "Panic recovered", args: args{panicCtx}, want: "", want1: false}, } @@ -275,7 +278,7 @@ func TestGetCtxTransportProtocol(t *testing.T) { want string want1 bool }{ - {name: "Success", args: args{testCtx}, want: tp.String(), want1: true}, + {name: "Success", args: args{testCtx}, want: testTp.String(), want1: true}, {name: "Failure", args: args{context.Background()}, want: "", want1: false}, {name: "Panic recovered", args: args{panicCtx}, want: "", want1: false}, } @@ -340,7 +343,7 @@ func TestGetRealResponse(t *testing.T) { } } -func buildRPCInfo() rpcinfo.RPCInfo { +func buildRPCInfo(caller, fromMethod, callee, method, idlServiceName string, fromAddr net.Addr, tp transport.Protocol) rpcinfo.RPCInfo { from := rpcinfo.NewEndpointInfo(caller, fromMethod, fromAddr, nil) to := rpcinfo.NewEndpointInfo(callee, method, nil, nil) ink := rpcinfo.NewInvocation(idlServiceName, method)