@@ -303,3 +303,140 @@ func TestIteratorMethods(t *testing.T) {
303303 }
304304 })
305305}
306+
307+ // -------------------------------------------------------------------------
308+ // Custom Aggregation Pattern Tests
309+ // -------------------------------------------------------------------------
310+
311+ // TestCustomAggregation demonstrates how users can define custom aggregation
312+ // functions that operate on *Responses and return custom types.
313+ func TestCustomAggregation (t * testing.T ) {
314+ t .Run ("SameTypeAggregation" , func (t * testing.T ) {
315+ // Aggregation function that returns the same type (Resp -> Resp)
316+ majorityQF := func (resp * Responses [* pb.StringValue ]) (* pb.StringValue , error ) {
317+ replies := resp .IgnoreErrors ().CollectN (2 )
318+ if len (replies ) < 2 {
319+ return nil , ErrIncomplete
320+ }
321+ for _ , v := range replies {
322+ return v , nil
323+ }
324+ return nil , ErrIncomplete
325+ }
326+
327+ responses := []NodeResponse [proto.Message ]{
328+ {NodeID : 1 , Value : pb .String ("response1" ), Err : nil },
329+ {NodeID : 2 , Value : pb .String ("response2" ), Err : nil },
330+ {NodeID : 3 , Value : pb .String ("response3" ), Err : nil },
331+ }
332+ clientCtx := makeClientCtx [* pb.StringValue , * pb.StringValue ](t , 3 , responses )
333+ r := NewResponses (clientCtx )
334+
335+ // Call the aggregation function directly
336+ result , err := majorityQF (r )
337+ if err != nil {
338+ t .Fatalf ("Expected no error, got %v" , err )
339+ }
340+ if result .GetValue () != "response1" && result .GetValue () != "response2" {
341+ t .Errorf ("Expected response1 or response2, got %s" , result .GetValue ())
342+ }
343+ })
344+
345+ t .Run ("CustomReturnType" , func (t * testing.T ) {
346+ // Aggregation function that returns a different type (Resp -> []string)
347+ // This demonstrates the key benefit: Out can differ from In
348+ collectAllValues := func (resp * Responses [* pb.StringValue ]) ([]string , error ) {
349+ replies := resp .IgnoreErrors ().CollectAll ()
350+ if len (replies ) == 0 {
351+ return nil , ErrIncomplete
352+ }
353+ result := make ([]string , 0 , len (replies ))
354+ for _ , v := range replies {
355+ result = append (result , v .GetValue ())
356+ }
357+ return result , nil
358+ }
359+
360+ responses := []NodeResponse [proto.Message ]{
361+ {NodeID : 1 , Value : pb .String ("alpha" ), Err : nil },
362+ {NodeID : 2 , Value : pb .String ("beta" ), Err : nil },
363+ {NodeID : 3 , Value : pb .String ("gamma" ), Err : nil },
364+ }
365+ clientCtx := makeClientCtx [* pb.StringValue , * pb.StringValue ](t , 3 , responses )
366+ r := NewResponses (clientCtx )
367+
368+ // Call the aggregation function directly - returns []string from *Responses[*pb.StringValue]
369+ result , err := collectAllValues (r )
370+ if err != nil {
371+ t .Fatalf ("Expected no error, got %v" , err )
372+ }
373+ if len (result ) != 3 {
374+ t .Errorf ("Expected 3 values, got %d" , len (result ))
375+ }
376+ })
377+
378+ t .Run ("WithFiltering" , func (t * testing.T ) {
379+ // Aggregation function that uses filtering and custom logic
380+ filterAndCount := func (resp * Responses [* pb.StringValue ]) (int , error ) {
381+ count := 0
382+ for range resp .IgnoreErrors ().Filter (func (r NodeResponse [* pb.StringValue ]) bool {
383+ return r .NodeID > 1 // Only nodes 2 and 3
384+ }) {
385+ count ++
386+ }
387+ if count == 0 {
388+ return 0 , ErrIncomplete
389+ }
390+ return count , nil
391+ }
392+
393+ responses := []NodeResponse [proto.Message ]{
394+ {NodeID : 1 , Value : pb .String ("response1" ), Err : nil },
395+ {NodeID : 2 , Value : pb .String ("response2" ), Err : nil },
396+ {NodeID : 3 , Value : pb .String ("response3" ), Err : nil },
397+ }
398+ clientCtx := makeClientCtx [* pb.StringValue , * pb.StringValue ](t , 3 , responses )
399+ r := NewResponses (clientCtx )
400+
401+ // Call the aggregation function directly
402+ count , err := filterAndCount (r )
403+ if err != nil {
404+ t .Fatalf ("Expected no error, got %v" , err )
405+ }
406+ if count != 2 {
407+ t .Errorf ("Expected 2 filtered responses, got %d" , count )
408+ }
409+ })
410+
411+ t .Run ("ErrorHandling" , func (t * testing.T ) {
412+ // Aggregation function that handles errors explicitly
413+ requireAllSuccess := func (resp * Responses [* pb.StringValue ]) (* pb.StringValue , error ) {
414+ var first * pb.StringValue
415+ for r := range resp .Seq () {
416+ if r .Err != nil {
417+ return nil , r .Err
418+ }
419+ if first == nil {
420+ first = r .Value
421+ }
422+ }
423+ if first == nil {
424+ return nil , ErrIncomplete
425+ }
426+ return first , nil
427+ }
428+
429+ responses := []NodeResponse [proto.Message ]{
430+ {NodeID : 1 , Value : pb .String ("response1" ), Err : nil },
431+ {NodeID : 2 , Value : nil , Err : errors .New ("node 2 failed" )},
432+ }
433+ clientCtx := makeClientCtx [* pb.StringValue , * pb.StringValue ](t , 2 , responses )
434+ r := NewResponses (clientCtx )
435+
436+ // Call the aggregation function directly
437+ _ , err := requireAllSuccess (r )
438+ if err == nil {
439+ t .Error ("Expected error, got nil" )
440+ }
441+ })
442+ }
0 commit comments