diff --git a/gen.go b/gen.go index f3fc39b..31c6cfc 100644 --- a/gen.go +++ b/gen.go @@ -70,7 +70,7 @@ func (s *Schema) Generate() ([]byte, error) { Definition: schema, } - if !s.AreTitleLinksUnique() { + if !context.Definition.AreTitleLinksUnique() { return nil, fmt.Errorf("duplicate titles detected for %s", context.Name) } diff --git a/gen_test.go b/gen_test.go index f63abfd..288df29 100644 --- a/gen_test.go +++ b/gen_test.go @@ -2,11 +2,134 @@ package schematic import ( "fmt" + "go/ast" + "go/parser" + "go/token" "reflect" "strings" "testing" ) +type walker func(ast.Node) bool + +func (w walker) Visit(node ast.Node) ast.Visitor { + if w(node) { + return w + } + return nil +} + +var generateTests = []struct { + ExpectedServiceFunctions []string + Schema *Schema +}{ + { + ExpectedServiceFunctions: []string{"AccountCreate"}, + Schema: &Schema{ + Title: "Account Manager", + Properties: map[string]*Schema{ + "account": { + Ref: NewReference("#/definitions/account"), + }, + }, + Definitions: map[string]*Schema{ + "account": { + Title: "Account", + Type: "object", + Definitions: map[string]*Schema{ + "id": { + Type: "string", + }, + "name": { + Type: "string", + }, + }, + Links: []*Link{ + { + Title: "Create", + Rel: "create", + HRef: NewHRef("/accounts"), + Method: "POST", + Schema: &Schema{ + Type: "object", + Properties: map[string]*Schema{ + "name": { + Ref: NewReference("#/definitions/account/definitions/name"), + }, + }, + }, + TargetSchema: &Schema{ + Ref: NewReference("#/definitions/account"), + }, + }, + }, + }, + }, + Links: []*Link{ + { + Rel: "self", + HRef: NewHRef("https://accounts.example.com"), + }, + { + Rel: "self", + HRef: NewHRef("/schema"), + Method: "GET", + TargetSchema: &Schema{ + AdditionalProperties: true, + }, + }, + }, + }, + }, +} + +func TestGenerate(t *testing.T) { + for i, tc := range generateTests { + tc := tc + t.Run(fmt.Sprintf("generateTests[%d]", i), func(t *testing.T) { + src, err := tc.Schema.Generate() + if err != nil { + t.Fatal(err) + } + + f, err := parser.ParseFile(token.NewFileSet(), "", src, 0) + if err != nil { + t.Fatal(err) + } + + // Extract all methods on *Service + serviceMethods := make(map[string]bool) + accumulator := func(node ast.Node) bool { + switch v := node.(type) { + case *ast.File: + return true + case *ast.FuncDecl: + if v.Recv != nil && len(v.Recv.List) > 0 { + if funt, ok := v.Recv.List[0].Type.(*ast.StarExpr); ok { + if ident, ok := funt.X.(*ast.Ident); ok { + if ident.Name == "Service" { + // Found a method on *Service + serviceMethods[v.Name.Name] = true + } + } + } + } + return false + default: + return false + } + } + ast.Walk(walker(accumulator), f) + + for _, fn := range tc.ExpectedServiceFunctions { + if !serviceMethods[fn] { + t.Fatalf("expected to find function %s on *Service, but was not present in the generated source", fn) + } + } + }) + } +} + var resolveTests = []struct { Schema *Schema }{