Skip to content

Commit

Permalink
fix: generic headers type print issue
Browse files Browse the repository at this point in the history
add: test for generated stub
  • Loading branch information
sonalys committed Feb 29, 2024
1 parent 0b115f9 commit baf1ac4
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 73 deletions.
11 changes: 8 additions & 3 deletions ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@ func (f *ParsedInterface) printAstExpr(expr ast.Expr) string {
// Extract package and type name
switch fieldType := expr.(type) {
case *ast.Ident:
if strings.ToLower(fieldType.Name) == fieldType.Name {
// If it's a generic type, we don't need to print package name with it.
for _, name := range f.GenericsName {
if name == fieldType.Name {
return fieldType.Name
}
}
if strings.ToLower(fieldType.Name[:1]) == fieldType.Name[:1] {
return fieldType.Name
}
// If we have an object, that means we need to translate the type from mock package to current package.
Expand Down Expand Up @@ -62,15 +68,14 @@ func (f *ParsedInterface) printAstExpr(expr ast.Expr) string {
b := &strings.Builder{}
for _, method := range methods {
methodName := method.Ref.Names[0].Name
b.WriteString("\t")
b.WriteString("\t\t")
f.PrintMethodHeader(b, methodName, &ParsedField{
Interface: f,
Ref: &ast.Field{
Type: method.Ref.Type.(*ast.FuncType),
},
Name: methodName,
})
b.WriteString("\n")
}
return fmt.Sprintf("interface{\n%s\n}", b.String())
}
Expand Down
70 changes: 50 additions & 20 deletions boilerplate/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,47 +6,77 @@ import (
)

type (
// Config represents the configuration for all the functions passed to the On... function.
Config interface {
// Repeat sets how many times the function group should be called, note that if more than 1 function is given,
// all the functions should repeat n times.
// Repeat is 1 by default, meaning that functions can only be called 1 time.
// Set Repeat(-1) to allow the group to repeat indefinitely.
Repeat(times int)
// Maybe sets the group as not required for AssertExpectations,
// meaning that the function group will not fail the test if not called.
Maybe()
}

call[T any] struct {
Call[T any] struct {
lock *sync.Mutex
repeat int
maybe bool
cur int
hooks []T
}

mock[T any] struct {
Mock[T any] struct {
lock *sync.Mutex
calls []*call[T]
calls []*Call[T]
}
)

var (
// RepeatForever, can be used with:
// OnMock().Repeat(mocks.RepeatForever).
RepeatForever int = -1
)

// setupLocker is a sync.Once func to configure sync.Mutex in case newMock wasn't called.
func setupLocker() *sync.Mutex { return &sync.Mutex{} }

func newMock[T any](t *testing.T) mock[T] {
value := mock[T]{
lock: &sync.Mutex{},
func (c *Call[T]) Repeat(times int) {
c.lock = sync.OnceValue(setupLocker)()
c.lock.Lock()
defer c.lock.Unlock()
c.repeat = times
}

func (c *Call[T]) Maybe() {
c.lock = sync.OnceValue(setupLocker)()
c.lock.Lock()
defer c.lock.Unlock()
c.maybe = true
}

func NewMock[T any](t *testing.T) Mock[T] {
value := Mock[T]{
lock: sync.OnceValue(setupLocker)(),
}
t.Cleanup(func() {
value.AssertExpectations(t)
})
return value
}

func (c *mock[T]) AssertExpectations(t *testing.T) bool {
// AssertExpectations asserts that all expected function calls have been called.
// Returns true if all expectations were met, otherwise returns false.
func (c *Mock[T]) AssertExpectations(t *testing.T) bool {
c.lock = sync.OnceValue(setupLocker)()
c.lock.Lock()
defer c.lock.Unlock()

var missingCalls int
for _, call := range c.calls {
call.lock.Lock()
if call.maybe {
continue
}
if call.repeat <= 0 {
call.repeat = 1
}
Expand All @@ -60,14 +90,10 @@ func (c *mock[T]) AssertExpectations(t *testing.T) bool {
return true
}

func (c *call[T]) Repeat(times int) {
c.lock = sync.OnceValue(setupLocker)()
c.lock.Lock()
defer c.lock.Unlock()
c.repeat = times
}

func (c *call[T]) draw() (f T, empty bool) {
// Draw is a card Draw design, in which each call determines when to get removed from the deck.
// if repeat == 0, the call gets removed from deck.
// setting repeat = -1 will skip this condition, allowing it to repeat indefinitely through the group.
func (c *Call[T]) Draw() (f T, empty bool) {
c.lock = sync.OnceValue(setupLocker)()
c.lock.Lock()
defer c.lock.Unlock()
Expand All @@ -80,25 +106,29 @@ func (c *call[T]) draw() (f T, empty bool) {
return f, c.repeat == -1
}

func (c *mock[T]) call() (*T, bool) {
// Call returns a func of type T and a bool from the deck.
// It either returns (func, true) or (nil, false).
func (c *Mock[T]) Call() (*T, bool) {
c.lock = sync.OnceValue(setupLocker)()
c.lock.Lock()
defer c.lock.Unlock()
if len(c.calls) == 0 {
return nil, false
}
first, empty := c.calls[0].draw()
first, empty := c.calls[0].Draw()
if empty {
c.calls = c.calls[1:]
}
return &first, true
}

func (c *mock[T]) append(f ...T) Config {
// Append creates a new card for the group of functions given, returning Config.
// With Config you can configure the group expectations.
func (c *Mock[T]) Append(f ...T) Config {
c.lock = sync.OnceValue(setupLocker)()
c.lock.Lock()
defer c.lock.Unlock()
call := &call[T]{
call := &Call[T]{
hooks: f,
lock: &sync.Mutex{},
}
Expand Down
6 changes: 5 additions & 1 deletion file.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"os"
"path/filepath"
"strings"
)

type ParsedFile struct {
Expand Down Expand Up @@ -59,6 +60,9 @@ func (f *ParsedFile) ListInterfaces() []*ParsedInterface {
Name: typeSpec.Name.Name,
}
cur.GenericsHeader, cur.GenericsName = cur.GetGenericsInfo()
if len(cur.GenericsName) > 0 {
cur.GenericsNamelessHeader = fmt.Sprintf("[%s]", strings.Join(cur.GenericsName, ","))
}
resp = append(resp, cur)
}
}
Expand All @@ -70,7 +74,7 @@ func (f *ParsedFile) WriteImports(w io.Writer) {
fmt.Fprintf(w, "import (\n")
fmt.Fprintf(w, "\t\"fmt\"\n")
fmt.Fprintf(w, "\t\"testing\"\n")
fmt.Fprintf(w, "\t_ \"github.com/sonalys/fake/boilerplate\"\n")
fmt.Fprintf(w, "\tmockSetup \"github.com/sonalys/fake/boilerplate\"\n")
for name := range f.UsedImports {
info, ok := f.Imports[name]
if !ok {
Expand Down
44 changes: 21 additions & 23 deletions interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@ import (
)

type ParsedInterface struct {
ParsedFile *ParsedFile
Type *ast.TypeSpec
Ref *ast.InterfaceType
Name string
GenericsHeader string
GenericsName string
ParsedFile *ParsedFile
Type *ast.TypeSpec
Ref *ast.InterfaceType
Name string
GenericsHeader string
GenericsNamelessHeader string
GenericsName []string
}

func (i *ParsedInterface) ListFields() []*ParsedField {
Expand All @@ -43,18 +44,16 @@ func (g *Generator) ListInterfaceFields(i *ParsedInterface, imports map[string]*
return resp
}

func (f *ParsedInterface) getTypeGenerics(t *ast.TypeSpec) (string, string) {
func (f *ParsedInterface) getTypeGenerics(t *ast.TypeSpec) (string, []string) {
var genericsHeader string
var genericsNames string
var genericsNames []string
if t.TypeParams != nil {
types := []string{}
names := []string{}
for _, t := range t.TypeParams.List {
types = append(types, fmt.Sprintf("%s %s", t.Names[0].Name, f.printAstExpr(t.Type)))
names = append(names, t.Names[0].Name)
genericsNames = append(genericsNames, t.Names[0].Name)
}
genericsHeader = fmt.Sprintf("[%s]", strings.Join(types, ","))
genericsNames = fmt.Sprintf("[%s]", strings.Join(names, ","))
}
return genericsHeader, genericsNames
}
Expand Down Expand Up @@ -101,26 +100,26 @@ func (g *Generator) ParseInterface(ident *ast.SelectorExpr, usedImports map[stri
return nil
}

func (i *ParsedInterface) GetGenericsInfo() (string, string) {
func (i *ParsedInterface) GetGenericsInfo() (string, []string) {
return i.getTypeGenerics(i.Type)
}

func (i *ParsedInterface) WriteStruct(w io.Writer) {
// Write struct definition implementing the interface
fmt.Fprintf(w, "type %s%s struct {\n", i.Name, i.GenericsHeader)
for _, field := range i.ListFields() {
fmt.Fprintf(w, "\tsetup%s mock[", field.Name)
fmt.Fprintf(w, "\tsetup%s mockSetup.Mock[", field.Name)
i.PrintMethodHeader(w, "func", field)
fmt.Fprintf(w, "]\n")
}
fmt.Fprintf(w, "}\n\n")
}

func (i *ParsedInterface) WriteInitializer(w io.Writer) {
fmt.Fprintf(w, "func New%s%s(t *testing.T) *%s%s {\n", i.Name, i.GenericsHeader, i.Name, i.GenericsName)
fmt.Fprintf(w, "\treturn &%s{\n", i.Name)
fmt.Fprintf(w, "func New%s%s(t *testing.T) *%s%s {\n", i.Name, i.GenericsHeader, i.Name, i.GenericsNamelessHeader)
fmt.Fprintf(w, "\treturn &%s%s{\n", i.Name, i.GenericsNamelessHeader)
for _, field := range i.ListFields() {
fmt.Fprintf(w, "\t\tsetup%s: newMock[", field.Name)
fmt.Fprintf(w, "\t\tsetup%s: mockSetup.NewMock[", field.Name)
i.PrintMethodHeader(w, "func", field)
fmt.Fprintf(w, "](t),\n")
}
Expand All @@ -129,7 +128,7 @@ func (i *ParsedInterface) WriteInitializer(w io.Writer) {
}

func (i *ParsedInterface) WriteAssertExpectations(w io.Writer) {
fmt.Fprintf(w, "func (s *%s) AssertExpectations(t *testing.T) bool {\n", i.Name)
fmt.Fprintf(w, "func (s *%s%s) AssertExpectations(t *testing.T) bool {\n", i.Name, i.GenericsNamelessHeader)
fmt.Fprintf(w, "\treturn ")
for _, field := range i.ListFields() {
fmt.Fprintf(w, "s.setup%s.AssertExpectations(t) &&\n\t\t", field.Name)
Expand All @@ -139,16 +138,15 @@ func (i *ParsedInterface) WriteAssertExpectations(w io.Writer) {
}

func (i *ParsedInterface) WriteOnMethod(w io.Writer, methodName string, f *ParsedField) {
fmt.Fprintf(w, "func (s *%s) On%s(funcs ...", i.Name, methodName)
fmt.Fprintf(w, "func (s *%s%s) On%s(funcs ...", i.Name, i.GenericsNamelessHeader, methodName)
i.PrintMethodHeader(w, "func", f)
fmt.Fprintf(w, ") Config {\n")
fmt.Fprintf(w, "\treturn s.setup%s.append(funcs...)\n", methodName)
fmt.Fprintf(w, ") mockSetup.Config {\n")
fmt.Fprintf(w, "\treturn s.setup%s.Append(funcs...)\n", methodName)
fmt.Fprintf(w, "}\n\n")
}

func (i *ParsedInterface) WriteMethod(w io.Writer, methodName string, f *ParsedField) {
_, genericsNames := i.GetGenericsInfo()
fmt.Fprintf(w, "func (s *%s%s) ", i.Name, genericsNames)
fmt.Fprintf(w, "func (s *%s%s) ", i.Name, i.GenericsNamelessHeader)
i.PrintMethodHeader(w, methodName, f)
fmt.Fprintf(w, "{\n")
var callingNames []string
Expand All @@ -168,7 +166,7 @@ func (i *ParsedInterface) WriteMethod(w io.Writer, methodName string, f *ParsedF
argFlag = append(argFlag, "%v")
}
}
fmt.Fprintf(w, "\tf, ok := s.setup%s.call()\n", methodName)
fmt.Fprintf(w, "\tf, ok := s.setup%s.Call()\n", methodName)
fmt.Fprintf(w, "\tif !ok {\n")
fmt.Fprintf(
w, "\t\tpanic(fmt.Sprintf(\"unexpected call %s(%s)\", %v))\n",
Expand Down
19 changes: 19 additions & 0 deletions testdata/out/mock_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package out

import (
"testing"

stub "github.com/sonalys/fake/testdata"
mocks "github.com/sonalys/fake/testdata/out/testdata"
)

func Test_Mock(t *testing.T) {
mock := mocks.NewStubInterface[int](t)
var Stub stub.StubInterface[int] = mock
mock.OnWeirdFunc1(func(a any, b interface{ A() int }) {
if a == nil {
t.Fail()
}
})
Stub.WeirdFunc1(1, nil)
}
Loading

0 comments on commit baf1ac4

Please sign in to comment.