diff --git a/cmd/wire/main_test.go b/cmd/wire/main_test.go new file mode 100644 index 00000000..7d9cd93d --- /dev/null +++ b/cmd/wire/main_test.go @@ -0,0 +1,38 @@ +package main + +import ( + "context" + "flag" + "os" + "path/filepath" + "runtime" + "testing" +) + +func TestMethodProvider(t *testing.T) { + /* + // Support type's method as provider, for example: + + func InitDog() *animals.Dog { + panic(wire.Build( + animals.NewAnimals, + (*animals.Animals).NewDog, // pointer receiver method + )) + } + + func InitCat() *animals.Cat { + panic(wire.Build( + wire.Value(animals.Animals{}), + animals.Animals.NewCat, // struct receiver method + )) + } + */ + _, b, _, _ := runtime.Caller(0) + wireRepoPath := filepath.Dir(filepath.Dir(filepath.Dir(b))) + _ = os.Chdir(filepath.Join(wireRepoPath, "tests", "method_provider")) + cmd := &genCmd{} + code := int(cmd.Execute(context.Background(), flag.CommandLine)) + if code != 0 { + t.Fatal(code) + } +} diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 93fbda85..ce2b852c 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -659,14 +659,21 @@ func qualifiedIdentObject(info *types.Info, expr ast.Expr) types.Object { case *ast.Ident: return info.ObjectOf(expr) case *ast.SelectorExpr: - pkgName, ok := expr.X.(*ast.Ident) - if !ok { - return nil - } - if _, ok := info.ObjectOf(pkgName).(*types.PkgName); !ok { - return nil + switch value := expr.X.(type) { + case *ast.ParenExpr, *ast.SelectorExpr: + if selection, ok := info.Selections[expr]; !ok { + return nil + } else { + return selection.Obj() + } + case *ast.Ident: + pkgName := value + if _, ok := info.ObjectOf(pkgName).(*types.PkgName); !ok { + return nil + } + return info.ObjectOf(expr.Sel) } - return info.ObjectOf(expr.Sel) + return nil default: return nil } @@ -701,6 +708,17 @@ func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, []erro } } } + if recv := sig.Recv(); recv != nil { + switch typ := recv.Type().(type) { + case *types.Pointer: + elem := typ.Elem().(*types.Named) + provider.Name = fmt.Sprintf("(*%s).%s", elem.Obj().Name(), fn.Name()) + provider.Args = append([]ProviderInput{{Type: typ}}, provider.Args...) + case *types.Named: + provider.Name = fmt.Sprintf("(%s).%s", typ.Obj().Name(), fn.Name()) + provider.Args = append([]ProviderInput{{Type: typ}}, provider.Args...) + } + } return provider, nil } diff --git a/internal/wire/wire.go b/internal/wire/wire.go index a9b7a50d..f4331435 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -523,6 +523,11 @@ func (g *gen) qualifiedID(pkgName, pkgPath, sym string) string { if name == "" { return sym } + if strings.HasPrefix(sym, "(*") { + return sym[:2] + name + "." + sym[2:] + } else if strings.HasPrefix(sym, "(") { + return sym[:1] + name + "." + sym[1:] + } return name + "." + sym } diff --git a/tests/method_provider/animals/animals.go b/tests/method_provider/animals/animals.go new file mode 100644 index 00000000..c6f4d033 --- /dev/null +++ b/tests/method_provider/animals/animals.go @@ -0,0 +1,18 @@ +package animals + +type Dog struct{} +type Cat struct{} + +type Animals struct{} + +func (f *Animals) NewDog() *Dog { + return &Dog{} +} + +func (Animals) NewCat() *Cat { + return &Cat{} +} + +func NewAnimals() *Animals { + return &Animals{} +} diff --git a/tests/method_provider/wire.go b/tests/method_provider/wire.go new file mode 100644 index 00000000..ed902582 --- /dev/null +++ b/tests/method_provider/wire.go @@ -0,0 +1,23 @@ +//+build wireinject +//go:generate wire + +package method_provider + +import ( + "github.com/google/wire" + "github.com/google/wire/tests/method_provider/animals" +) + +func InitDog() *animals.Dog { + panic(wire.Build( + animals.NewAnimals, + (*animals.Animals).NewDog, + )) +} + +func InitCat() *animals.Cat { + panic(wire.Build( + wire.Value(animals.Animals{}), + animals.Animals.NewCat, + )) +} diff --git a/tests/method_provider/wire_gen.go b/tests/method_provider/wire_gen.go new file mode 100644 index 00000000..6a9f025c --- /dev/null +++ b/tests/method_provider/wire_gen.go @@ -0,0 +1,28 @@ +// Code generated by Wire. DO NOT EDIT. + +//go:generate go run github.com/google/wire/cmd/wire +//+build !wireinject + +package method_provider + +import ( + "github.com/google/wire/tests/method_provider/animals" +) + +// Injectors from wire.go: + +func InitDog() *animals.Dog { + animalsAnimals := animals.NewAnimals() + dog := (*animals.Animals).NewDog(animalsAnimals) + return dog +} + +func InitCat() *animals.Cat { + animalsAnimals := _wireAnimalsValue + cat := (animals.Animals).NewCat(animalsAnimals) + return cat +} + +var ( + _wireAnimalsValue = animals.Animals{} +)