Skip to content

Commit f5c011e

Browse files
committed
test: add custom aggregation pattern tests for Responses
1 parent 729ecbf commit f5c011e

File tree

1 file changed

+137
-0
lines changed

1 file changed

+137
-0
lines changed

responses_test.go

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,3 +303,140 @@ func TestIteratorMethods(t *testing.T) {
303303
}
304304
})
305305
}
306+
307+
// -------------------------------------------------------------------------
308+
// Custom Aggregation Pattern Tests
309+
// -------------------------------------------------------------------------
310+
311+
// TestCustomAggregation demonstrates how users can define custom aggregation
312+
// functions that operate on *Responses and return custom types.
313+
func TestCustomAggregation(t *testing.T) {
314+
t.Run("SameTypeAggregation", func(t *testing.T) {
315+
// Aggregation function that returns the same type (Resp -> Resp)
316+
majorityQF := func(resp *Responses[*pb.StringValue]) (*pb.StringValue, error) {
317+
replies := resp.IgnoreErrors().CollectN(2)
318+
if len(replies) < 2 {
319+
return nil, ErrIncomplete
320+
}
321+
for _, v := range replies {
322+
return v, nil
323+
}
324+
return nil, ErrIncomplete
325+
}
326+
327+
responses := []NodeResponse[proto.Message]{
328+
{NodeID: 1, Value: pb.String("response1"), Err: nil},
329+
{NodeID: 2, Value: pb.String("response2"), Err: nil},
330+
{NodeID: 3, Value: pb.String("response3"), Err: nil},
331+
}
332+
clientCtx := makeClientCtx[*pb.StringValue, *pb.StringValue](t, 3, responses)
333+
r := NewResponses(clientCtx)
334+
335+
// Call the aggregation function directly
336+
result, err := majorityQF(r)
337+
if err != nil {
338+
t.Fatalf("Expected no error, got %v", err)
339+
}
340+
if result.GetValue() != "response1" && result.GetValue() != "response2" {
341+
t.Errorf("Expected response1 or response2, got %s", result.GetValue())
342+
}
343+
})
344+
345+
t.Run("CustomReturnType", func(t *testing.T) {
346+
// Aggregation function that returns a different type (Resp -> []string)
347+
// This demonstrates the key benefit: Out can differ from In
348+
collectAllValues := func(resp *Responses[*pb.StringValue]) ([]string, error) {
349+
replies := resp.IgnoreErrors().CollectAll()
350+
if len(replies) == 0 {
351+
return nil, ErrIncomplete
352+
}
353+
result := make([]string, 0, len(replies))
354+
for _, v := range replies {
355+
result = append(result, v.GetValue())
356+
}
357+
return result, nil
358+
}
359+
360+
responses := []NodeResponse[proto.Message]{
361+
{NodeID: 1, Value: pb.String("alpha"), Err: nil},
362+
{NodeID: 2, Value: pb.String("beta"), Err: nil},
363+
{NodeID: 3, Value: pb.String("gamma"), Err: nil},
364+
}
365+
clientCtx := makeClientCtx[*pb.StringValue, *pb.StringValue](t, 3, responses)
366+
r := NewResponses(clientCtx)
367+
368+
// Call the aggregation function directly - returns []string from *Responses[*pb.StringValue]
369+
result, err := collectAllValues(r)
370+
if err != nil {
371+
t.Fatalf("Expected no error, got %v", err)
372+
}
373+
if len(result) != 3 {
374+
t.Errorf("Expected 3 values, got %d", len(result))
375+
}
376+
})
377+
378+
t.Run("WithFiltering", func(t *testing.T) {
379+
// Aggregation function that uses filtering and custom logic
380+
filterAndCount := func(resp *Responses[*pb.StringValue]) (int, error) {
381+
count := 0
382+
for range resp.IgnoreErrors().Filter(func(r NodeResponse[*pb.StringValue]) bool {
383+
return r.NodeID > 1 // Only nodes 2 and 3
384+
}) {
385+
count++
386+
}
387+
if count == 0 {
388+
return 0, ErrIncomplete
389+
}
390+
return count, nil
391+
}
392+
393+
responses := []NodeResponse[proto.Message]{
394+
{NodeID: 1, Value: pb.String("response1"), Err: nil},
395+
{NodeID: 2, Value: pb.String("response2"), Err: nil},
396+
{NodeID: 3, Value: pb.String("response3"), Err: nil},
397+
}
398+
clientCtx := makeClientCtx[*pb.StringValue, *pb.StringValue](t, 3, responses)
399+
r := NewResponses(clientCtx)
400+
401+
// Call the aggregation function directly
402+
count, err := filterAndCount(r)
403+
if err != nil {
404+
t.Fatalf("Expected no error, got %v", err)
405+
}
406+
if count != 2 {
407+
t.Errorf("Expected 2 filtered responses, got %d", count)
408+
}
409+
})
410+
411+
t.Run("ErrorHandling", func(t *testing.T) {
412+
// Aggregation function that handles errors explicitly
413+
requireAllSuccess := func(resp *Responses[*pb.StringValue]) (*pb.StringValue, error) {
414+
var first *pb.StringValue
415+
for r := range resp.Seq() {
416+
if r.Err != nil {
417+
return nil, r.Err
418+
}
419+
if first == nil {
420+
first = r.Value
421+
}
422+
}
423+
if first == nil {
424+
return nil, ErrIncomplete
425+
}
426+
return first, nil
427+
}
428+
429+
responses := []NodeResponse[proto.Message]{
430+
{NodeID: 1, Value: pb.String("response1"), Err: nil},
431+
{NodeID: 2, Value: nil, Err: errors.New("node 2 failed")},
432+
}
433+
clientCtx := makeClientCtx[*pb.StringValue, *pb.StringValue](t, 2, responses)
434+
r := NewResponses(clientCtx)
435+
436+
// Call the aggregation function directly
437+
_, err := requireAllSuccess(r)
438+
if err == nil {
439+
t.Error("Expected error, got nil")
440+
}
441+
})
442+
}

0 commit comments

Comments
 (0)