Skip to content

Commit 691cd1b

Browse files
committed
fix: make mock package more flexible and support wrapperspb
This replaces use of the mock.Register() test helper function with a more flexible solution that supports wrapperspb types. The only caveat is that this requires pre-registration via TestMain, but is much more flexible, since we can add custom services, methods and message types for our testing needs without having to define them via proto files that needs to be generated. Note: this adds mock.GetValue, which will be used in a future test.
1 parent 14854d6 commit 691cd1b

File tree

9 files changed

+161
-173
lines changed

9 files changed

+161
-173
lines changed

.vscode/gorums.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,5 +97,6 @@ unmarshaled
9797
unmarshaling
9898
unmarshals
9999
userguide
100+
wrapperspb
100101
Xeon
101102
Zorums

channel_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"github.com/relab/gorums/ordering"
1414
"google.golang.org/grpc"
1515
"google.golang.org/protobuf/proto"
16+
pb "google.golang.org/protobuf/types/known/wrapperspb"
1617
)
1718

1819
const defaultTestTimeout = 3 * time.Second
@@ -34,15 +35,15 @@ func newNodeWithServer(t testing.TB, delay time.Duration) *RawNode {
3435
type mockSrv struct{}
3536

3637
func (mockSrv) Test(_ ServerCtx, req proto.Message) (proto.Message, error) {
37-
return mock.NewResponse(mock.GetVal(req) + "-mocked-"), nil
38+
return pb.String(mock.GetVal(req) + "-mocked-"), nil
3839
}
3940

4041
func newNodeWithStoppableServer(t testing.TB, delay time.Duration) (*RawNode, func()) {
4142
t.Helper()
4243
addrs, teardown := TestSetup(t, 1, func(_ int) ServerIface {
4344
mockSrv := &mockSrv{}
4445
srv := NewServer()
45-
srv.RegisterHandler(mock.ServerMethodName, func(ctx ServerCtx, in *Message) (*Message, error) {
46+
srv.RegisterHandler(mock.TestMethod, func(ctx ServerCtx, in *Message) (*Message, error) {
4647
// Simulate slow processing
4748
time.Sleep(delay)
4849
resp, err := mockSrv.Test(ctx, in.GetProtoMessage())
@@ -59,7 +60,7 @@ func sendRequest(t testing.TB, node *RawNode, req request, msgID uint64) respons
5960
if req.ctx == nil {
6061
req.ctx = t.Context()
6162
}
62-
req.msg = NewRequestMessage(ordering.NewGorumsMetadata(req.ctx, msgID, mock.ServerMethodName), nil)
63+
req.msg = NewRequestMessage(ordering.NewGorumsMetadata(req.ctx, msgID, mock.TestMethod), nil)
6364
replyChan := make(chan response, 1)
6465
node.channel.enqueue(req, replyChan)
6566

@@ -517,7 +518,7 @@ func TestChannelDeadlock(t *testing.T) {
517518
ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
518519
defer cancel()
519520

520-
md := ordering.NewGorumsMetadata(ctx, uint64(100+id), mock.ServerMethodName)
521+
md := ordering.NewGorumsMetadata(ctx, uint64(100+id), mock.TestMethod)
521522
req := request{ctx: ctx, msg: NewRequestMessage(md, nil)}
522523

523524
// try to enqueue

config_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"github.com/relab/gorums"
1111
"github.com/relab/gorums/internal/testutils/mock"
1212
"google.golang.org/grpc/encoding"
13+
pb "google.golang.org/protobuf/types/known/wrapperspb"
1314
)
1415

1516
func init() {
@@ -226,8 +227,8 @@ func TestConfigConcurrentAccess(t *testing.T) {
226227
for range 2 {
227228
wg.Go(func() {
228229
_, err := node.RPCCall(context.Background(), gorums.CallData{
229-
Message: mock.NewRequest(""),
230-
Method: mock.ServerMethodName,
230+
Message: pb.String(""),
231+
Method: mock.TestMethod,
231232
})
232233
if err != nil {
233234
errCh <- err

internal/testutils/mock/mock.go

Lines changed: 91 additions & 145 deletions
Original file line numberDiff line numberDiff line change
@@ -1,187 +1,133 @@
11
package mock
22

33
import (
4-
"errors"
54
"fmt"
6-
"sync"
7-
"testing"
5+
"strings"
86

97
"google.golang.org/protobuf/proto"
108
"google.golang.org/protobuf/reflect/protodesc"
11-
"google.golang.org/protobuf/reflect/protoreflect"
129
"google.golang.org/protobuf/reflect/protoregistry"
1310
"google.golang.org/protobuf/types/descriptorpb"
14-
"google.golang.org/protobuf/types/dynamicpb"
11+
"google.golang.org/protobuf/types/known/wrapperspb"
1512
)
1613

17-
// ServerMethodName is the method supported by the mock package.
18-
const ServerMethodName = "mock.Server.Test"
19-
20-
var (
21-
// Mock Service Descriptors
22-
mockFile *descriptorpb.FileDescriptorProto
23-
requestType protoreflect.MessageType
24-
responseType protoreflect.MessageType
14+
// TestMethod and GetValueMethod are the methods supported by the mock package.
15+
const (
16+
TestMethod = "mock.MockService.Test"
17+
GetValueMethod = "mock.MockService.GetValue"
2518
)
2619

27-
func init() {
28-
// Initialize Mock Definitions
29-
mockFile = &descriptorpb.FileDescriptorProto{
30-
Name: proto.String("mock/mock.proto"),
31-
Package: proto.String("mock"),
32-
MessageType: []*descriptorpb.DescriptorProto{
33-
{
34-
Name: proto.String("Request"),
35-
Field: []*descriptorpb.FieldDescriptorProto{
36-
{
37-
Name: proto.String("val"),
38-
Number: proto.Int32(1),
39-
Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
40-
Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(),
41-
},
42-
},
43-
},
44-
{
45-
Name: proto.String("Response"),
46-
Field: []*descriptorpb.FieldDescriptorProto{
47-
{
48-
Name: proto.String("val"),
49-
Number: proto.Int32(1),
50-
Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
51-
Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(),
52-
},
53-
},
54-
},
55-
},
56-
Service: []*descriptorpb.ServiceDescriptorProto{
57-
{
58-
Name: proto.String("Server"),
59-
Method: []*descriptorpb.MethodDescriptorProto{
60-
{
61-
Name: proto.String("Test"),
62-
InputType: proto.String(".mock.Request"),
63-
OutputType: proto.String(".mock.Response"),
64-
},
65-
},
66-
},
67-
},
68-
}
20+
// Service represents a service to be registered.
21+
type Service struct {
22+
Name string // Full package and service name, e.g., "mock.MockService"
23+
Methods []Method
6924
}
7025

71-
// Register registers the mock types in the global registry.
72-
// It is safe to call multiple times.
73-
func Register(t testing.TB) {
74-
t.Helper()
75-
if err := registerOnce(); err != nil {
76-
t.Fatal(err)
77-
}
26+
// Method represents a method in a service.
27+
type Method struct {
28+
Name string
29+
Input proto.Message
30+
Output proto.Message
7831
}
7932

80-
var registerOnce = sync.OnceValue(func() error {
81-
if fd, err := protoregistry.GlobalFiles.FindFileByPath(mockFile.GetName()); err == nil {
82-
// Already registered
83-
return initDescriptors(fd)
84-
}
33+
// RegisterServices registers the given services in the global registry.
34+
// It is safe to call multiple times, but services with the same package name
35+
// must be registered in the same call or be identical to previous registrations.
36+
// Returns an error if registration fails.
37+
func RegisterServices(services []Service) error {
38+
// Group by package
39+
packages := make(map[string][]*descriptorpb.ServiceDescriptorProto)
8540

86-
fd, err := protodesc.NewFile(mockFile, nil)
87-
if err != nil {
88-
return fmt.Errorf("failed to create mock file descriptor: %v", err)
89-
}
90-
if err := protoregistry.GlobalFiles.RegisterFile(fd); err != nil {
91-
return fmt.Errorf("failed to register mock file: %v", err)
92-
}
93-
if err := initDescriptors(fd); err != nil {
94-
return fmt.Errorf("failed to initialize mock descriptors: %v", err)
95-
}
96-
if err := protoregistry.GlobalTypes.RegisterMessage(requestType); err != nil {
97-
return fmt.Errorf("failed to register Request type: %v", err)
98-
}
99-
if err := protoregistry.GlobalTypes.RegisterMessage(responseType); err != nil {
100-
return fmt.Errorf("failed to register Response type: %v", err)
101-
}
102-
return nil
103-
})
41+
for _, s := range services {
42+
pkgName, svcName, found := strings.Cut(s.Name, ".")
43+
if !found {
44+
return fmt.Errorf("service name %q must contain a package", s.Name)
45+
}
10446

105-
func initDescriptors(fd protoreflect.FileDescriptor) error {
106-
mockServiceDesc := fd.Services().ByName("Server")
107-
if mockServiceDesc == nil {
108-
return errors.New("service Server not found")
109-
}
110-
mockMethodDesc := mockServiceDesc.Methods().ByName("Test")
111-
if mockMethodDesc == nil {
112-
return errors.New("method Test not found")
113-
}
114-
requestMsgDesc := fd.Messages().ByName("Request")
115-
if requestMsgDesc == nil {
116-
return errors.New("message Request not found")
117-
}
118-
responseMsgDesc := fd.Messages().ByName("Response")
119-
if responseMsgDesc == nil {
120-
return errors.New("message Response not found")
47+
svcDesc := &descriptorpb.ServiceDescriptorProto{
48+
Name: proto.String(svcName),
49+
}
50+
51+
for _, m := range s.Methods {
52+
inDesc := m.Input.ProtoReflect().Descriptor()
53+
outDesc := m.Output.ProtoReflect().Descriptor()
54+
inName := string(inDesc.FullName())
55+
outName := string(outDesc.FullName())
56+
57+
svcDesc.Method = append(svcDesc.Method, &descriptorpb.MethodDescriptorProto{
58+
Name: proto.String(m.Name),
59+
InputType: proto.String("." + inName),
60+
OutputType: proto.String("." + outName),
61+
})
62+
}
63+
packages[pkgName] = append(packages[pkgName], svcDesc)
12164
}
122-
requestType = dynamicpb.NewMessageType(requestMsgDesc)
123-
responseType = dynamicpb.NewMessageType(responseMsgDesc)
124-
return nil
125-
}
12665

127-
// Helpers for Mock messages
66+
for pkg, svcDescriptors := range packages {
67+
// Collect dependencies
68+
deps := make(map[string]struct{})
12869

129-
const panicMsg = "mock.Register() must be called before using NewRequest/NewResponse"
70+
// Iterate over the original services to find dependencies for this package.
71+
for _, s := range services {
72+
pName, _, found := strings.Cut(s.Name, ".")
73+
if !found || pName != pkg {
74+
continue
75+
}
13076

131-
func NewRequest(val string) proto.Message {
132-
if requestType == nil {
133-
panic(panicMsg)
134-
}
135-
msg := requestType.New()
136-
if val != "" {
137-
fd := msg.Descriptor().Fields().ByName("val")
138-
if fd != nil {
139-
msg.Set(fd, protoreflect.ValueOfString(val))
77+
for _, m := range s.Methods {
78+
if d := m.Input.ProtoReflect().Descriptor().ParentFile(); d != nil {
79+
deps[d.Path()] = struct{}{}
80+
}
81+
if d := m.Output.ProtoReflect().Descriptor().ParentFile(); d != nil {
82+
deps[d.Path()] = struct{}{}
83+
}
84+
}
14085
}
141-
}
142-
return msg.Interface()
143-
}
14486

145-
func NewResponse(val string) proto.Message {
146-
if responseType == nil {
147-
panic(panicMsg)
148-
}
149-
msg := responseType.New()
150-
if val != "" {
151-
fd := msg.Descriptor().Fields().ByName("val")
152-
if fd != nil {
153-
msg.Set(fd, protoreflect.ValueOfString(val))
87+
fd := &descriptorpb.FileDescriptorProto{
88+
Name: proto.String(fmt.Sprintf("mock/%s.proto", pkg)),
89+
Package: proto.String(pkg),
90+
Service: svcDescriptors,
91+
}
92+
93+
for dep := range deps {
94+
fd.Dependency = append(fd.Dependency, dep)
95+
}
96+
97+
// Check if already registered
98+
if _, err := protoregistry.GlobalFiles.FindFileByPath(fd.GetName()); err == nil {
99+
continue // Already registered
100+
}
101+
102+
fileDesc, err := protodesc.NewFile(fd, protoregistry.GlobalFiles)
103+
if err != nil {
104+
return fmt.Errorf("failed to create file descriptor for %s: %v", pkg, err)
105+
}
106+
107+
if err := protoregistry.GlobalFiles.RegisterFile(fileDesc); err != nil {
108+
return fmt.Errorf("failed to register file %s: %v", pkg, err)
154109
}
155110
}
156-
return msg.Interface()
111+
return nil
157112
}
158113

114+
// Helpers for Mock messages
115+
159116
func GetVal(msg proto.Message) string {
160117
if msg == nil {
161118
return ""
162119
}
163-
m := msg.ProtoReflect()
164-
if !m.IsValid() {
165-
return ""
166-
}
167-
fd := m.Descriptor().Fields().ByName("val")
168-
if fd == nil {
169-
return ""
120+
if m, ok := msg.(*wrapperspb.StringValue); ok {
121+
return m.Value
170122
}
171-
return m.Get(fd).String()
123+
return ""
172124
}
173125

174126
func SetVal(msg proto.Message, val string) {
175127
if msg == nil {
176128
return
177129
}
178-
m := msg.ProtoReflect()
179-
if !m.IsValid() {
180-
return
181-
}
182-
fd := m.Descriptor().Fields().ByName("val")
183-
if fd == nil {
184-
return
130+
if m, ok := msg.(*wrapperspb.StringValue); ok {
131+
m.Value = val
185132
}
186-
m.Set(fd, protoreflect.ValueOfString(val))
187133
}

main_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package gorums
2+
3+
import (
4+
"os"
5+
"testing"
6+
7+
"github.com/relab/gorums/internal/testutils/mock"
8+
pb "google.golang.org/protobuf/types/known/wrapperspb"
9+
)
10+
11+
func TestMain(m *testing.M) {
12+
// Register the default mock services for integration tests.
13+
if err := mock.RegisterServices([]mock.Service{
14+
{
15+
Name: "mock.MockService",
16+
Methods: []mock.Method{
17+
{
18+
Name: "Test",
19+
Input: &pb.StringValue{},
20+
Output: &pb.StringValue{},
21+
},
22+
{
23+
Name: "GetValue",
24+
Input: &pb.Int32Value{},
25+
Output: &pb.Int32Value{},
26+
},
27+
},
28+
},
29+
}); err != nil {
30+
panic(err)
31+
}
32+
33+
os.Exit(m.Run())
34+
}

0 commit comments

Comments
 (0)