From 953c8abfea839f1c8b5829fcfe199c758fd47bb2 Mon Sep 17 00:00:00 2001 From: Alec Iverson Date: Wed, 21 Apr 2021 16:14:51 -0700 Subject: [PATCH] Add explicit AutoBinding capability, with tests --- autowire.go | 34 +++++ internal/autowire/analyze.go | 79 +++++++----- internal/autowire/parse.go | 116 +++++++++++++++--- .../InterfaceAutoBinding/foo/autowire.go | 26 ++++ .../testdata/InterfaceAutoBinding/foo/foo.go | 43 +++++++ .../testdata/InterfaceAutoBinding/pkg | 1 + .../InterfaceAutoBinding/want/autowire_gen.go | 13 ++ .../InterfaceAutoBinding/want/program_out.txt | 1 + .../foo/autowire.go | 30 +++++ .../foo/foo.go | 63 ++++++++++ .../pkg | 1 + .../want/autowire_gen.go | 14 +++ .../want/program_out.txt | 1 + 13 files changed, 375 insertions(+), 47 deletions(-) create mode 100644 internal/autowire/testdata/InterfaceAutoBinding/foo/autowire.go create mode 100644 internal/autowire/testdata/InterfaceAutoBinding/foo/foo.go create mode 100644 internal/autowire/testdata/InterfaceAutoBinding/pkg create mode 100644 internal/autowire/testdata/InterfaceAutoBinding/want/autowire_gen.go create mode 100644 internal/autowire/testdata/InterfaceAutoBinding/want/program_out.txt create mode 100644 internal/autowire/testdata/InterfaceExplicitBindingPreferredOverAutoBinding/foo/autowire.go create mode 100644 internal/autowire/testdata/InterfaceExplicitBindingPreferredOverAutoBinding/foo/foo.go create mode 100644 internal/autowire/testdata/InterfaceExplicitBindingPreferredOverAutoBinding/pkg create mode 100644 internal/autowire/testdata/InterfaceExplicitBindingPreferredOverAutoBinding/want/autowire_gen.go create mode 100644 internal/autowire/testdata/InterfaceExplicitBindingPreferredOverAutoBinding/want/program_out.txt diff --git a/autowire.go b/autowire.go index ea024607..bd900066 100644 --- a/autowire.go +++ b/autowire.go @@ -114,6 +114,40 @@ func Bind(iface, to interface{}) Binding { return Binding{} } +// An AutoBinding makes a concrete type available to be bound to any interfaces it implements. +type AutoBinding struct{} + +// AutoBind declares that a concrete type should be used to satisfy any dependencies on interfaces +// that it implements. typ must be a pointer to a concrete type. +// +// Example: +// +// type Fooer interface { +// Foo() +// } +// +// type Barer interface { +// Bar() +// } +// +// type MyFooBar struct{} +// +// func (MyFooBar) Foo() {} +// func (MyFooBar) Bar() {} +// +// func useFoo(foo Fooer) error {} +// func useBar(bar Barer) error {} +// +// var MySet = wire.NewSet( +// wire.Struct(new(MyFooBar)), +// wire.AutoBind(new(MyFooBar)), +// useFoo, // *FooBar is injected +// useBar, // *Foobar is injected +// ) +func AutoBind(typ interface{}) AutoBinding { + return AutoBinding{} +} + // bindToUsePointer is detected by the autowire tool to indicate that Bind's second argument should take a pointer. // See https://github.com/dabbertorres/autowire/issues/120 for details. const bindToUsePointer = true diff --git a/internal/autowire/analyze.go b/internal/autowire/analyze.go index 7bf7ed63..3afa5b8a 100644 --- a/internal/autowire/analyze.go +++ b/internal/autowire/analyze.go @@ -287,6 +287,18 @@ func verifyArgsUsed(set *ProviderSet, used []*providerSetSrc) []error { errs = append(errs, fmt.Errorf("unused provider %q", p.Pkg.Name()+"."+p.Name)) } } + for _, ab := range set.AutoBindings { + found := false + for _, u := range used { + if u.AutoBinding == ab { + found = true + break + } + } + if !found { + errs = append(errs, fmt.Errorf("unused auto binding %q", types.TypeString(ab.Concrete, nil))) + } + } for _, v := range set.Values { found := false for _, u := range used { @@ -458,49 +470,50 @@ func verifyAcyclic(providerMap *typeutil.Map, hasher typeutil.Hasher) []error { continue } pt := x.(*ProvidedType) + var args []types.Type switch { case pt.IsValue(): // Leaf: values do not have dependencies. case pt.IsArg(): // Injector arguments do not have dependencies. - case pt.IsProvider() || pt.IsField(): - var args []types.Type - if pt.IsProvider() { - for _, arg := range pt.Provider().Args { - args = append(args, arg.Type) - } - } else { - args = append(args, pt.Field().Parent) + case pt.IsProvider(): + for _, arg := range pt.Provider().Args { + args = append(args, arg.Type) } - for _, a := range args { - hasCycle := false - for i, b := range curr { - if types.Identical(a, b) { - sb := new(strings.Builder) - fmt.Fprintf(sb, "cycle for %s:\n", types.TypeString(a, nil)) - for j := i; j < len(curr); j++ { - t := providerMap.At(curr[j]).(*ProvidedType) - if t.IsProvider() { - p := t.Provider() - fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.Pkg.Path(), p.Name) - } else { - p := t.Field() - fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.Parent, p.Name) - } + case pt.IsAutoBinding(): + args = append(args, pt.AutoBinding().Concrete) + case pt.IsField(): + args = append(args, pt.Field().Parent) + default: + panic("invalid provider map value") + } + + for _, a := range args { + hasCycle := false + for i, b := range curr { + if types.Identical(a, b) { + sb := new(strings.Builder) + fmt.Fprintf(sb, "cycle for %s:\n", types.TypeString(a, nil)) + for j := i; j < len(curr); j++ { + t := providerMap.At(curr[j]).(*ProvidedType) + if t.IsProvider() { + p := t.Provider() + fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.Pkg.Path(), p.Name) + } else { + p := t.Field() + fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.Parent, p.Name) } - fmt.Fprintf(sb, "%s", types.TypeString(a, nil)) - ec.add(errors.New(sb.String())) - hasCycle = true - break } - } - if !hasCycle { - next := append(append([]types.Type(nil), curr...), a) - stk = append(stk, next) + fmt.Fprintf(sb, "%s", types.TypeString(a, nil)) + ec.add(errors.New(sb.String())) + hasCycle = true + break } } - default: - panic("invalid provider map value") + if !hasCycle { + next := append(append([]types.Type(nil), curr...), a) + stk = append(stk, next) + } } } } diff --git a/internal/autowire/parse.go b/internal/autowire/parse.go index e9416089..acd91534 100644 --- a/internal/autowire/parse.go +++ b/internal/autowire/parse.go @@ -35,6 +35,7 @@ import ( // Exactly one of the fields will be set. type providerSetSrc struct { Provider *Provider + AutoBinding *AutoBinding Binding *IfaceBinding Value *Value Import *ProviderSet @@ -57,6 +58,8 @@ func (p *providerSetSrc) description(fset *token.FileSet, typ types.Type) string kind = "struct provider" } return fmt.Sprintf("%s %s(%s)", kind, quoted(p.Provider.Name), fset.Position(p.Provider.Pos)) + case p.AutoBinding != nil: + return fmt.Sprintf("wire.AutoBind (%s)", fset.Position(p.AutoBinding.Pos)) case p.Binding != nil: return fmt.Sprintf("autowire.Bind (%s)", fset.Position(p.Binding.Pos)) case p.Value != nil: @@ -98,12 +101,13 @@ type ProviderSet struct { // variable. VarName string - Providers []*Provider - Bindings []*IfaceBinding - Values []*Value - Fields []*Field - Imports []*ProviderSet - // InjectorArgs is only filled in for autowire.Build. + Providers []*Provider + Bindings []*IfaceBinding + AutoBindings []*AutoBinding + Values []*Value + Fields []*Field + Imports []*ProviderSet + // InjectorArgs is only filled in for wire.Build. InjectorArgs *InjectorArgs // providerMap maps from provided type to a *ProvidedType. @@ -125,6 +129,22 @@ func (set *ProviderSet) Outputs() []types.Type { func (set *ProviderSet) For(t types.Type) ProvidedType { pt := set.providerMap.At(t) if pt == nil { + // if t is an interface, we may have an AutoBinding that implements it. + iface, ok := t.Underlying().(*types.Interface) + if !ok { + return ProvidedType{} + } + + for _, ab := range set.AutoBindings { + if types.Implements(ab.Concrete, iface) { + // cache for later + pt := &ProvidedType{t: ab.Concrete, ab: ab} + set.providerMap.Set(t, pt) + set.srcMap.Set(t, &providerSetSrc{AutoBinding: ab}) + return *pt + } + } + return ProvidedType{} } return *pt.(*ProvidedType) @@ -179,6 +199,17 @@ type Provider struct { HasErr bool } +// AutoBinding records the signature of a provider eligible for auto-binding +// to interfaces it implements. A provider is a single Go object, either a +// function or a named type. +type AutoBinding struct { + // Concrete is always a type that implements N number of interfaces. + Concrete types.Type + + // Pos is the position where the binding was declared. + Pos token.Pos +} + // ProviderInput describes an incoming edge in the provider graph. type ProviderInput struct { Type types.Type @@ -520,7 +551,7 @@ func (oc *objectCache) varDecl(obj *types.Var) *ast.ValueSpec { } // processExpr converts an expression into a Wire structure. It may return a -// *Provider, an *IfaceBinding, a *ProviderSet, a *Value or a []*Field. +// *Provider, an *AutoBinding, an *IfaceBinding, a *ProviderSet, a *Value or a []*Field. func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Expr, varName string) (interface{}, []error) { exprPos := oc.fset.Position(expr.Pos()) expr = astutil.Unparen(expr) @@ -546,6 +577,12 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex case "NewSet": pset, errs := oc.processNewSet(info, pkgPath, call, nil, varName) return pset, notePositionAll(exprPos, errs) + case "AutoBind": + abs, err := processAutoBind(oc.fset, info, call) + if err != nil { + return nil, []error{notePosition(exprPos, err)} + } + return abs, nil case "Bind": b, err := processBind(oc.fset, info, call) if err != nil { @@ -607,6 +644,8 @@ func (oc *objectCache) processNewSet(info *types.Info, pkgPath string, call *ast continue } switch item := item.(type) { + case *AutoBinding: + pset.AutoBindings = append(pset.AutoBindings, item) case *Provider: pset.Providers = append(pset.Providers, item) case *ProviderSet: @@ -880,6 +919,41 @@ func isPrevented(tag string) bool { return reflect.StructTag(tag).Get("autowire") == "-" } +func processAutoBind(fset *token.FileSet, info *types.Info, call *ast.CallExpr) (*AutoBinding, error) { + // Assumes that call.Fun is wire.AutoBind. + + if len(call.Args) != 1 { + return nil, notePosition(fset.Position(call.Pos()), + errors.New("call to AutoBind takes exactly one argument")) + } + const firstArgReqFormat = "first argument to AutoBind must be a pointer to a type; found %s" + typ := info.TypeOf(call.Args[0]) + ptr, ok := typ.(*types.Pointer) + if !ok { + return nil, notePosition(fset.Position(call.Pos()), + fmt.Errorf(firstArgReqFormat, types.TypeString(typ, nil))) + } + + switch ptr.Elem().Underlying().(type) { + case *types.Named, + *types.Struct, + *types.Basic: + // good! + + default: + return nil, notePosition(fset.Position(call.Pos()), + fmt.Errorf(firstArgReqFormat, types.TypeString(ptr, nil))) + } + + typeExpr := call.Args[0].(*ast.CallExpr) + typeName := qualifiedIdentObject(info, typeExpr.Args[0]) // should be either an identifier or selector + autoBinding := &AutoBinding{ + Concrete: ptr, + Pos: typeName.Pos(), + } + return autoBinding, nil +} + // processBind creates an interface binding from a autowire.Bind call. func processBind(fset *token.FileSet, info *types.Info, call *ast.CallExpr) (*IfaceBinding, error) { // Assumes that call.Fun is autowire.Bind. @@ -1122,7 +1196,6 @@ func findInjectorBuild(info *types.Info, fn *ast.FuncDecl) (*ast.CallExpr, error default: invalid = true } - } if wireBuildCall == nil { return nil, nil @@ -1157,16 +1230,17 @@ func isProviderSetType(t types.Type) bool { // none of the above, and returns true for IsNil. type ProvidedType struct { // t is the provided concrete type. - t types.Type - p *Provider - v *Value - a *InjectorArg - f *Field + t types.Type + p *Provider + ab *AutoBinding + v *Value + a *InjectorArg + f *Field } // IsNil reports whether pt is the zero value. func (pt ProvidedType) IsNil() bool { - return pt.p == nil && pt.v == nil && pt.a == nil && pt.f == nil + return pt.p == nil && pt.ab == nil && pt.v == nil && pt.a == nil && pt.f == nil } // Type returns the output type. @@ -1185,6 +1259,11 @@ func (pt ProvidedType) IsProvider() bool { return pt.p != nil } +// IsAutoBinding reports whether pt points to an AutoBinding. +func (pt ProvidedType) IsAutoBinding() bool { + return pt.ab != nil +} + // IsValue reports whether pt points to a Value. func (pt ProvidedType) IsValue() bool { return pt.v != nil @@ -1209,6 +1288,15 @@ func (pt ProvidedType) Provider() *Provider { return pt.p } +// AutoBinding returns pt as a AutoBinding pointer. It panics if pt does not point +// to a AutoBinding. +func (pt ProvidedType) AutoBinding() *AutoBinding { + if pt.ab == nil { + panic("ProvidedType does not hold an AutoBinding") + } + return pt.ab +} + // Value returns pt as a Value pointer. It panics if pt does not point // to a Value. func (pt ProvidedType) Value() *Value { diff --git a/internal/autowire/testdata/InterfaceAutoBinding/foo/autowire.go b/internal/autowire/testdata/InterfaceAutoBinding/foo/autowire.go new file mode 100644 index 00000000..a1f4032f --- /dev/null +++ b/internal/autowire/testdata/InterfaceAutoBinding/foo/autowire.go @@ -0,0 +1,26 @@ +// Copyright 2018 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//+build wireinject + +package main + +import ( + "github.com/google/wire" +) + +func injectFooer() Fooer { + wire.Build(Set, wire.AutoBind(new(Bar))) + return nil +} diff --git a/internal/autowire/testdata/InterfaceAutoBinding/foo/foo.go b/internal/autowire/testdata/InterfaceAutoBinding/foo/foo.go new file mode 100644 index 00000000..f005e5e6 --- /dev/null +++ b/internal/autowire/testdata/InterfaceAutoBinding/foo/foo.go @@ -0,0 +1,43 @@ +// Copyright 2018 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "fmt" + + "github.com/google/wire" +) + +func main() { + fmt.Println(injectFooer().Foo()) +} + +type Fooer interface { + Foo() string +} + +type Bar string + +func (b *Bar) Foo() string { + return string(*b) +} + +func provideBar() *Bar { + b := new(Bar) + *b = "Hello, World!" + return b +} + +var Set = wire.NewSet(provideBar) diff --git a/internal/autowire/testdata/InterfaceAutoBinding/pkg b/internal/autowire/testdata/InterfaceAutoBinding/pkg new file mode 100644 index 00000000..f7a5c8ce --- /dev/null +++ b/internal/autowire/testdata/InterfaceAutoBinding/pkg @@ -0,0 +1 @@ +example.com/foo diff --git a/internal/autowire/testdata/InterfaceAutoBinding/want/autowire_gen.go b/internal/autowire/testdata/InterfaceAutoBinding/want/autowire_gen.go new file mode 100644 index 00000000..7c7fa825 --- /dev/null +++ b/internal/autowire/testdata/InterfaceAutoBinding/want/autowire_gen.go @@ -0,0 +1,13 @@ +// Code generated by Wire. DO NOT EDIT. + +//go:generate go run github.com/google/wire/cmd/wire +//+build !wireinject + +package main + +// Injectors from wire.go: + +func injectFooer() Fooer { + bar := provideBar() + return bar +} diff --git a/internal/autowire/testdata/InterfaceAutoBinding/want/program_out.txt b/internal/autowire/testdata/InterfaceAutoBinding/want/program_out.txt new file mode 100644 index 00000000..8ab686ea --- /dev/null +++ b/internal/autowire/testdata/InterfaceAutoBinding/want/program_out.txt @@ -0,0 +1 @@ +Hello, World! diff --git a/internal/autowire/testdata/InterfaceExplicitBindingPreferredOverAutoBinding/foo/autowire.go b/internal/autowire/testdata/InterfaceExplicitBindingPreferredOverAutoBinding/foo/autowire.go new file mode 100644 index 00000000..6fc6bc51 --- /dev/null +++ b/internal/autowire/testdata/InterfaceExplicitBindingPreferredOverAutoBinding/foo/autowire.go @@ -0,0 +1,30 @@ +// Copyright 2018 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//+build wireinject + +package main + +import ( + "github.com/google/wire" +) + +func injectPlugher() Plugher { + wire.Build( + Set, + wire.AutoBind(new(Qux)), + wire.Bind(new(Fooer), new(*Bar)), + ) + return nil +} diff --git a/internal/autowire/testdata/InterfaceExplicitBindingPreferredOverAutoBinding/foo/foo.go b/internal/autowire/testdata/InterfaceExplicitBindingPreferredOverAutoBinding/foo/foo.go new file mode 100644 index 00000000..7c427230 --- /dev/null +++ b/internal/autowire/testdata/InterfaceExplicitBindingPreferredOverAutoBinding/foo/foo.go @@ -0,0 +1,63 @@ +// Copyright 2018 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "fmt" + + "github.com/google/wire" +) + +func main() { + fmt.Println(injectPlugher().Plugh()) +} + +type Fooer interface { + Foo() string +} + +type Plugher interface { + Plugh() string +} + +type Bar string + +func (b *Bar) Foo() string { + return string(*b) +} + +func provideBar() *Bar { + b := new(Bar) + *b = "Bar!" + return b +} + +type Qux string + +func (q *Qux) Foo() string { + return string(*q) +} + +func (q *Qux) Plugh() string { + return string(*q) + string(*q) +} + +func provideQux(fooer Fooer) *Qux { + b := new(Qux) + *b = Qux("Qux!" + fooer.Foo()) + return b +} + +var Set = wire.NewSet(provideBar, provideQux) diff --git a/internal/autowire/testdata/InterfaceExplicitBindingPreferredOverAutoBinding/pkg b/internal/autowire/testdata/InterfaceExplicitBindingPreferredOverAutoBinding/pkg new file mode 100644 index 00000000..f7a5c8ce --- /dev/null +++ b/internal/autowire/testdata/InterfaceExplicitBindingPreferredOverAutoBinding/pkg @@ -0,0 +1 @@ +example.com/foo diff --git a/internal/autowire/testdata/InterfaceExplicitBindingPreferredOverAutoBinding/want/autowire_gen.go b/internal/autowire/testdata/InterfaceExplicitBindingPreferredOverAutoBinding/want/autowire_gen.go new file mode 100644 index 00000000..6313b19c --- /dev/null +++ b/internal/autowire/testdata/InterfaceExplicitBindingPreferredOverAutoBinding/want/autowire_gen.go @@ -0,0 +1,14 @@ +// Code generated by Wire. DO NOT EDIT. + +//go:generate go run github.com/google/wire/cmd/wire +//+build !wireinject + +package main + +// Injectors from wire.go: + +func injectPlugher() Plugher { + bar := provideBar() + qux := provideQux(bar) + return qux +} diff --git a/internal/autowire/testdata/InterfaceExplicitBindingPreferredOverAutoBinding/want/program_out.txt b/internal/autowire/testdata/InterfaceExplicitBindingPreferredOverAutoBinding/want/program_out.txt new file mode 100644 index 00000000..2c00b960 --- /dev/null +++ b/internal/autowire/testdata/InterfaceExplicitBindingPreferredOverAutoBinding/want/program_out.txt @@ -0,0 +1 @@ +Qux!Bar!Qux!Bar!