diff --git a/mux/node_test.go b/mux/node_test.go index 36decd7..2fb58c2 100644 --- a/mux/node_test.go +++ b/mux/node_test.go @@ -384,21 +384,80 @@ func TestStaticNodeMatchMiddleware(t *testing.T) { runMiddlewareTests(tests, t) } +func TestWildcardNodeMatchMiddleware(t *testing.T) { + paramSize := 3 + middleware1 := buildMockMiddlewareFunc("1") + middleware2 := buildMockMiddlewareFunc("2") + middleware3 := buildMockMiddlewareFunc("3") + middleware4 := buildMockMiddlewareFunc("4") + mw1 := middleware.NewCollection(middleware1) + mw2 := middleware.NewCollection(middleware2, middleware3) + mw3 := middleware.NewCollection(middleware4) + + node1 := NewNode("test", uint8(paramSize)) + node1.PrependMiddleware(mw1) + item := NewNode("{item}", node1.MaxParamsSize()) // wildcardnode + item.PrependMiddleware(mw2) + view := NewNode("view", node1.MaxParamsSize()+1) + view.PrependMiddleware(mw3) + node1.WithChildren(node1.Tree().withNode(item).sort()) + item.WithChildren(item.Tree().withNode(view).sort()) + node1.WithChildren(node1.Tree().Compile()) + + node2 := NewNode("test", uint8(paramSize)) + node2.PrependMiddleware(mw1) + item2 := NewNode("{item}", node1.MaxParamsSize()) // wildcardnode + item2.PrependMiddleware(mw2) + item2.SkipSubPath() + node2.WithChildren(node2.Tree().withNode(item2).sort()) + node2.WithChildren(node2.Tree().Compile()) + + tests := []middlewareTest{ + { + name: "WildcardNode Exact match", + node: node1, + path: "test/item1", + expectedResult: middleware.NewCollection(middleware1, middleware2, middleware3), + }, + { + name: "WildcardNode Subpath match with skipSubPath", + node: node2, + path: "test/item2/random", + expectedResult: middleware.NewCollection(middleware1, middleware2, middleware3), + }, + { + name: "WildcardNode Subpath match without skipSubPath", + node: node1, + path: "test/item3/view", + expectedResult: middleware.NewCollection(middleware1, middleware2, middleware3, middleware4), + }, + { + name: "WildcardNode Subpath No match", + node: node1, + path: "test/item4/nomatch", + expectedResult: middleware.NewCollection(middleware1, middleware2, middleware3), + }, + } + + runMiddlewareTests(tests, t) +} + func runMiddlewareTests(tests []middlewareTest, t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := tt.node.MatchMiddleware(tt.path) if len(got) != len(tt.expectedResult) { t.Errorf("%s: middleware length mismatch: got= %v, want %v", tt.name, got, tt.expectedResult) - } - for k, v := range tt.expectedResult { - // reflect.DeepEqual do not work for function values. - // hence compare the pointers of functions as a substitute. - // function pointers are unique to each function, even if the functions have the same code. - expectedPointer := reflect.ValueOf(v).Pointer() - gotPointer := reflect.ValueOf(got[k]).Pointer() - if expectedPointer != gotPointer { - t.Errorf("%s: middleware mismatch: got= %v, want %v", tt.name, v, got[k]) + } else { + for k, v := range tt.expectedResult { + // reflect.DeepEqual do not work for function values. + // hence compare the pointers of functions as a substitute. + // function pointers are unique to each function, even if the functions have the same code. + expectedPointer := reflect.ValueOf(v).Pointer() + gotPointer := reflect.ValueOf(got[k]).Pointer() + if expectedPointer != gotPointer { + t.Errorf("%s: middleware mismatch: got= %v, want %v", tt.name, v, got[k]) + } } } })