diff --git a/kad/trie/trie.go b/kad/trie/trie.go index 588d235..b76cca3 100644 --- a/kad/trie/trie.go +++ b/kad/trie/trie.go @@ -129,7 +129,7 @@ func (tr *Trie[K, D]) addAtDepth(depth int, kk K, data D) bool { } } -// Add adds the key to trie, returning a new trie. +// Add adds the key to trie, returning a new trie if the key was not already in the trie. // Add is immutable/non-destructive: the original trie remains unchanged. func Add[K kad.Key[K], D any](tr *Trie[K, D], kk K, data D) (*Trie[K, D], error) { return addAtDepth(0, tr, kk, data), nil @@ -148,8 +148,12 @@ func addAtDepth[K kad.Key[K], D any](depth int, tr *Trie[K, D], kk K, data D) *T default: dir := kk.Bit(depth) + b := addAtDepth(depth+1, tr.branch[dir], kk, data) + if b == tr.branch[dir] { + return tr + } s := &Trie[K, D]{} - s.branch[dir] = addAtDepth(depth+1, tr.branch[dir], kk, data) + s.branch[dir] = b s.branch[1-dir] = tr.branch[1-dir] return s } diff --git a/kad/trie/trie_test.go b/kad/trie/trie_test.go index d7acd65..541c37c 100644 --- a/kad/trie/trie_test.go +++ b/kad/trie/trie_test.go @@ -167,6 +167,22 @@ func TestSize(t *testing.T) { require.Equal(t, len(sampleKeySet.Keys), tr.Size()) } +func TestImmutableAddReturnsNewTrie(t *testing.T) { + tr := New[kadtest.Key32, any]() + for _, kk := range sampleKeySet.Keys { + var err error + trBefore := *tr // take a copy of tr before Add is called + trNext, err := Add(tr, kk, nil) + require.NoError(t, err) + // a new trie must be returned + require.NotSame(t, tr, trNext) + // old trie must be not be modified + require.EqualValues(t, trBefore, *tr) + tr = trNext + } + require.Equal(t, len(sampleKeySet.Keys), tr.Size()) +} + func TestAddIgnoresDuplicates(t *testing.T) { tr := New[kadtest.Key32, any]() for _, kk := range sampleKeySet.Keys { @@ -206,6 +222,28 @@ func TestImmutableAddIgnoresDuplicates(t *testing.T) { } } +func TestImmutableAddReturnsOriginalTrieForDuplicates(t *testing.T) { + tr := New[kadtest.Key32, any]() + var err error + for _, kk := range sampleKeySet.Keys { + tr, err = Add(tr, kk, nil) + require.NoError(t, err) + } + require.Equal(t, len(sampleKeySet.Keys), tr.Size()) + + for _, kk := range sampleKeySet.Keys { + next, err := Add(tr, kk, nil) + require.NoError(t, err) + // trie has not been changed + require.Same(t, tr, next) + } + require.Equal(t, len(sampleKeySet.Keys), tr.Size()) + + if d := CheckInvariant(tr); d != nil { + t.Fatalf("reordered trie invariant discrepancy: %v", d) + } +} + func TestAddWithData(t *testing.T) { tr := New[kadtest.Key32, int]() for i, kk := range sampleKeySet.Keys { @@ -223,10 +261,13 @@ func TestAddWithData(t *testing.T) { func TestImmutableAddWithData(t *testing.T) { tr := New[kadtest.Key32, int]() - var err error for i, kk := range sampleKeySet.Keys { - tr, err = Add(tr, kk, i) + var err error + trNext, err := Add(tr, kk, i) require.NoError(t, err) + // a new trie must be returned + require.NotSame(t, tr, trNext) + tr = trNext } require.Equal(t, len(sampleKeySet.Keys), tr.Size()) @@ -264,6 +305,9 @@ func TestImmutableRemove(t *testing.T) { require.NoError(t, err) require.Equal(t, len(sampleKeySet.Keys)-1, trNext.Size()) + // a new trie must be returned + require.NotSame(t, tr, trNext) + if d := CheckInvariant(tr); d != nil { t.Fatalf("reordered trie invariant discrepancy: %v", d) } diff --git a/kad/triert/table_test.go b/kad/triert/table_test.go index caf1f80..a2268a0 100644 --- a/kad/triert/table_test.go +++ b/kad/triert/table_test.go @@ -321,7 +321,8 @@ type nodeFilter struct { } func (f *nodeFilter) TryAdd(rt *TrieRT[kadtest.Key32, node[kadtest.Key32]], - n node[kadtest.Key32]) bool { + n node[kadtest.Key32], +) bool { if n == node2 { return false }