Skip to content

Commit

Permalink
fix: check address when GetCallerAddress to prevent panic
Browse files Browse the repository at this point in the history
  • Loading branch information
ppzqh committed Feb 10, 2025
1 parent 047444c commit bb4a79e
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 29 deletions.
16 changes: 8 additions & 8 deletions pkg/utils/kitexutil/kitexutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (
"github.com/cloudwego/kitex/pkg/utils"
)

// GetCaller is used to get the Service Name of the caller.
// GetCaller is used to get the Service Name of the testCaller.
// Return false if failed to get the information.
func GetCaller(ctx context.Context) (string, bool) {
defer func() { recover() }()
Expand All @@ -37,7 +37,7 @@ func GetCaller(ctx context.Context) (string, bool) {
return ri.From().ServiceName(), true
}

// GetCallee is used to get the Service Name of the callee.
// GetCallee is used to get the Service Name of the testCallee.
// Return false if failed to get the information.
func GetCallee(ctx context.Context) (string, bool) {
defer func() { recover() }()
Expand All @@ -61,8 +61,8 @@ func GetMethod(ctx context.Context) (string, bool) {
return ri.To().Method(), true
}

// GetCallerHandlerMethod is used to get the method name of caller.
// Only the caller is a Kitex server will have this method information by default, or you can set K_METHOD into context.Context then kitex will get it.
// GetCallerHandlerMethod is used to get the testMethod name of testCaller.
// Only the testCaller is a Kitex server will have this testMethod information by default, or you can set K_METHOD into context.Context then kitex will get it.
// Return false if failed to get the information.
func GetCallerHandlerMethod(ctx context.Context) (string, bool) {
defer func() { recover() }()
Expand Down Expand Up @@ -92,7 +92,7 @@ func GetCallerAddr(ctx context.Context) (net.Addr, bool) {
defer func() { recover() }()

ri := rpcinfo.GetRPCInfo(ctx)
if ri == nil {
if ri == nil || ri.From() == nil || ri.From().Address() == nil {
return nil, false
}
return ri.From().Address(), true
Expand All @@ -104,7 +104,7 @@ func GetCallerIP(ctx context.Context) (string, bool) {
defer func() { recover() }()

ri := rpcinfo.GetRPCInfo(ctx)
if ri == nil {
if ri == nil || ri.From() == nil || ri.From().Address() == nil {
return "", false
}
addrStr := ri.From().Address().String()
Expand Down Expand Up @@ -142,7 +142,7 @@ func GetRPCInfo(ctx context.Context) (rpcinfo.RPCInfo, bool) {
}

// GetRealReqFromKitexArgs assert the req to be KitexArgs and return the real request if succeeded, otherwise return nil.
// This method should be used in the middleware.
// This testMethod should be used in the middleware.
func GetRealReqFromKitexArgs(req interface{}) interface{} {
if arg, ok := req.(utils.KitexArgs); ok {
return arg.GetFirstArgument()
Expand All @@ -151,7 +151,7 @@ func GetRealReqFromKitexArgs(req interface{}) interface{} {
}

// GetRealRespFromKitexResult assert the req to be KitexResult and return the real response if succeeded, otherwise return nil.
// This method should be used in the middleware.
// This testMethod should be used in the middleware.
func GetRealRespFromKitexResult(resp interface{}) interface{} {
if res, ok := resp.(utils.KitexResult); ok {
return res.GetResult()
Expand Down
45 changes: 24 additions & 21 deletions pkg/utils/kitexutil/kitexutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,20 @@ import (
)

var (
testRi rpcinfo.RPCInfo
testCtx context.Context
panicCtx context.Context
caller = "kitexutil.from.service"
callee = "kitexutil.to.service"
idlServiceName = "MockService"
fromAddr = utils.NewNetAddr("test", "127.0.0.1:12345")
fromMethod = "from_method"
method = "method"
tp = transport.TTHeader
testRi rpcinfo.RPCInfo
testCtx context.Context
panicCtx context.Context
testCaller = "kitexutil.from.service"
testCallee = "kitexutil.to.service"
testIdlServiceName = "MockService"
testFromAddr = utils.NewNetAddr("test", "127.0.0.1:12345")
testFromMethod = "from_method"
testMethod = "testMethod"
testTp = transport.TTHeader
)

func TestMain(m *testing.M) {
testRi = buildRPCInfo()
testRi = buildRPCInfo(testCaller, testFromMethod, testCallee, testMethod, testIdlServiceName, testFromAddr, testTp)

testCtx = context.Background()
testCtx = rpcinfo.NewCtxWithRPCInfo(testCtx, testRi)
Expand All @@ -63,7 +63,7 @@ func TestGetCaller(t *testing.T) {
want string
want1 bool
}{
{name: "Success", args: args{testCtx}, want: caller, want1: true},
{name: "Success", args: args{testCtx}, want: testCaller, want1: true},
{name: "Failure", args: args{context.Background()}, want: "", want1: false},
{name: "Panic recovered", args: args{panicCtx}, want: "", want1: false},
}
Expand All @@ -90,7 +90,7 @@ func TestGetCallee(t *testing.T) {
want string
want1 bool
}{
{name: "Success", args: args{testCtx}, want: callee, want1: true},
{name: "Success", args: args{testCtx}, want: testCallee, want1: true},
{name: "Failure", args: args{context.Background()}, want: "", want1: false},
{name: "Panic recovered", args: args{panicCtx}, want: "", want1: false},
}
Expand All @@ -108,6 +108,8 @@ func TestGetCallee(t *testing.T) {
}

func TestGetCallerAddr(t *testing.T) {
riWithoutAddress := buildRPCInfo(testCaller, testFromMethod, testCallee, testMethod, testIdlServiceName, nil, testTp)
ctxWithoutAddress := rpcinfo.NewCtxWithRPCInfo(context.Background(), riWithoutAddress)
type args struct {
ctx context.Context
}
Expand All @@ -117,8 +119,9 @@ func TestGetCallerAddr(t *testing.T) {
want net.Addr
want1 bool
}{
{name: "Success", args: args{testCtx}, want: fromAddr, want1: true},
{name: "Failure", args: args{context.Background()}, want: nil, want1: false},
{name: "Success", args: args{testCtx}, want: testFromAddr, want1: true},
{name: "Failure: nil rpcinfo", args: args{context.Background()}, want: nil, want1: false},
{name: "Failure: nil address", args: args{ctxWithoutAddress}, want: nil, want1: false},
{name: "Panic recovered", args: args{panicCtx}, want: nil, want1: false},
}
for _, tt := range tests {
Expand All @@ -139,7 +142,7 @@ func TestGetCallerIP(t *testing.T) {
test.Assert(t, ok)
test.Assert(t, ip == "127.0.0.1", ip)

ri := buildRPCInfo()
ri := buildRPCInfo(testCaller, testFromMethod, testCallee, testMethod, testIdlServiceName, testFromAddr, testTp)
rpcinfo.AsMutableEndpointInfo(ri.From()).SetAddress(utils.NewNetAddr("test", "127.0.0.1"))
ip, ok = GetCallerIP(rpcinfo.NewCtxWithRPCInfo(context.Background(), ri))
test.Assert(t, ok)
Expand Down Expand Up @@ -169,7 +172,7 @@ func TestGetMethod(t *testing.T) {
want string
want1 bool
}{
{name: "Success", args: args{testCtx}, want: method, want1: true},
{name: "Success", args: args{testCtx}, want: testMethod, want1: true},
{name: "Failure", args: args{context.Background()}, want: "", want1: false},
}
for _, tt := range tests {
Expand All @@ -195,7 +198,7 @@ func TestGetCallerHandlerMethod(t *testing.T) {
want string
want1 bool
}{
{name: "Success", args: args{testCtx}, want: fromMethod, want1: true},
{name: "Success", args: args{testCtx}, want: testFromMethod, want1: true},
{name: "Failure", args: args{context.Background()}, want: "", want1: false},
{name: "Panic recovered", args: args{panicCtx}, want: "", want1: false},
}
Expand All @@ -222,7 +225,7 @@ func TestGetIDLServiceName(t *testing.T) {
want string
want1 bool
}{
{name: "Success", args: args{testCtx}, want: idlServiceName, want1: true},
{name: "Success", args: args{testCtx}, want: testIdlServiceName, want1: true},
{name: "Failure", args: args{context.Background()}, want: "", want1: false},
{name: "Panic recovered", args: args{panicCtx}, want: "", want1: false},
}
Expand Down Expand Up @@ -275,7 +278,7 @@ func TestGetCtxTransportProtocol(t *testing.T) {
want string
want1 bool
}{
{name: "Success", args: args{testCtx}, want: tp.String(), want1: true},
{name: "Success", args: args{testCtx}, want: testTp.String(), want1: true},
{name: "Failure", args: args{context.Background()}, want: "", want1: false},
{name: "Panic recovered", args: args{panicCtx}, want: "", want1: false},
}
Expand Down Expand Up @@ -340,7 +343,7 @@ func TestGetRealResponse(t *testing.T) {
}
}

func buildRPCInfo() rpcinfo.RPCInfo {
func buildRPCInfo(caller, fromMethod, callee, method, idlServiceName string, fromAddr net.Addr, tp transport.Protocol) rpcinfo.RPCInfo {
from := rpcinfo.NewEndpointInfo(caller, fromMethod, fromAddr, nil)
to := rpcinfo.NewEndpointInfo(callee, method, nil, nil)
ink := rpcinfo.NewInvocation(idlServiceName, method)
Expand Down

0 comments on commit bb4a79e

Please sign in to comment.