diff --git a/pkg/icingadb/entitiesbyid_test.go b/pkg/icingadb/entitiesbyid_test.go index 4608ff095..3e1046d7c 100644 --- a/pkg/icingadb/entitiesbyid_test.go +++ b/pkg/icingadb/entitiesbyid_test.go @@ -1,10 +1,13 @@ package icingadb import ( + "context" + "github.com/icinga/icinga-go-library/database" "github.com/icinga/icinga-go-library/types" "github.com/icinga/icingadb/pkg/icingadb/v1" "github.com/stretchr/testify/require" "testing" + "time" ) func TestEntitiesById_Keys(t *testing.T) { @@ -70,3 +73,75 @@ func TestEntitiesById_IDs(t *testing.T) { }) } } + +func TestEntitiesById_Entities(t *testing.T) { + subtests := []struct { + name string + io EntitiesById + }{ + {name: "nil"}, + { + name: "empty", + io: EntitiesById{}, + }, + { + name: "one", + io: EntitiesById{"one": newEntity([]byte{23})}, + }, + { + name: "two", + io: EntitiesById{"one": newEntity([]byte{23}), "two": newEntity([]byte{42})}, + }, + } + + for _, st := range subtests { + t.Run(st.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + expected := make([]database.Entity, 0, len(st.io)) + actual := make([]database.Entity, 0, len(st.io)) + + for _, v := range st.io { + expected = append(expected, v) + } + + ch := st.io.Entities(ctx) + require.NotNil(t, ch) + + for range expected { + select { + case v, ok := <-ch: + require.True(t, ok, "channel closed too early") + actual = append(actual, v) + case <-time.After(time.Second): + require.Fail(t, "channel should not block") + } + } + + require.ElementsMatch(t, expected, actual) + + select { + case v, ok := <-ch: + require.False(t, ok, "channel should be closed, got %#v", v) + case <-time.After(time.Second): + require.Fail(t, "channel should not block") + } + }) + } + + t.Run("closed-ctx", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + ch := EntitiesById{"one": newEntity([]byte{23})}.Entities(ctx) + require.NotNil(t, ch) + + select { + case v, ok := <-ch: + require.False(t, ok, "channel should be closed, got %#v", v) + case <-time.After(time.Second): + require.Fail(t, "channel should not block") + } + }) +}