Skip to content

Commit

Permalink
Add explicit AutoBinding capability, with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dabbertorres committed Aug 21, 2021
1 parent 9e24387 commit 82059a6
Show file tree
Hide file tree
Showing 13 changed files with 375 additions and 47 deletions.
34 changes: 34 additions & 0 deletions autowire.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,40 @@ func Bind(iface, to interface{}) Binding {
return Binding{}
}

// An AutoBinding makes a concrete type available to be bound to any interfaces it implements.
type AutoBinding struct{}

// AutoBind declares that a concrete type should be used to satisfy any dependencies on interfaces
// that it implements. typ must be a pointer to a concrete type.
//
// Example:
//
// type Fooer interface {
// Foo()
// }
//
// type Barer interface {
// Bar()
// }
//
// type MyFooBar struct{}
//
// func (MyFooBar) Foo() {}
// func (MyFooBar) Bar() {}
//
// func useFoo(foo Fooer) error {}
// func useBar(bar Barer) error {}
//
// var MySet = wire.NewSet(
// wire.Struct(new(MyFooBar)),
// wire.AutoBind(new(MyFooBar)),
// useFoo, // *FooBar is injected
// useBar, // *Foobar is injected
// )
func AutoBind(typ interface{}) AutoBinding {
return AutoBinding{}
}

// bindToUsePointer is detected by the autowire tool to indicate that Bind's second argument should take a pointer.
// See https://github.com/dabbertorres/autowire/issues/120 for details.
const bindToUsePointer = true
Expand Down
79 changes: 46 additions & 33 deletions internal/autowire/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,18 @@ func verifyArgsUsed(set *ProviderSet, used []*providerSetSrc) []error {
errs = append(errs, fmt.Errorf("unused provider %q", p.Pkg.Name()+"."+p.Name))
}
}
for _, ab := range set.AutoBindings {
found := false
for _, u := range used {
if u.AutoBinding == ab {
found = true
break
}
}
if !found {
errs = append(errs, fmt.Errorf("unused auto binding %q", types.TypeString(ab.Concrete, nil)))
}
}
for _, v := range set.Values {
found := false
for _, u := range used {
Expand Down Expand Up @@ -458,49 +470,50 @@ func verifyAcyclic(providerMap *typeutil.Map, hasher typeutil.Hasher) []error {
continue
}
pt := x.(*ProvidedType)
var args []types.Type
switch {
case pt.IsValue():
// Leaf: values do not have dependencies.
case pt.IsArg():
// Injector arguments do not have dependencies.
case pt.IsProvider() || pt.IsField():
var args []types.Type
if pt.IsProvider() {
for _, arg := range pt.Provider().Args {
args = append(args, arg.Type)
}
} else {
args = append(args, pt.Field().Parent)
case pt.IsProvider():
for _, arg := range pt.Provider().Args {
args = append(args, arg.Type)
}
for _, a := range args {
hasCycle := false
for i, b := range curr {
if types.Identical(a, b) {
sb := new(strings.Builder)
fmt.Fprintf(sb, "cycle for %s:\n", types.TypeString(a, nil))
for j := i; j < len(curr); j++ {
t := providerMap.At(curr[j]).(*ProvidedType)
if t.IsProvider() {
p := t.Provider()
fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.Pkg.Path(), p.Name)
} else {
p := t.Field()
fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.Parent, p.Name)
}
case pt.IsAutoBinding():
args = append(args, pt.AutoBinding().Concrete)
case pt.IsField():
args = append(args, pt.Field().Parent)
default:
panic("invalid provider map value")
}

for _, a := range args {
hasCycle := false
for i, b := range curr {
if types.Identical(a, b) {
sb := new(strings.Builder)
fmt.Fprintf(sb, "cycle for %s:\n", types.TypeString(a, nil))
for j := i; j < len(curr); j++ {
t := providerMap.At(curr[j]).(*ProvidedType)
if t.IsProvider() {
p := t.Provider()
fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.Pkg.Path(), p.Name)
} else {
p := t.Field()
fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.Parent, p.Name)
}
fmt.Fprintf(sb, "%s", types.TypeString(a, nil))
ec.add(errors.New(sb.String()))
hasCycle = true
break
}
}
if !hasCycle {
next := append(append([]types.Type(nil), curr...), a)
stk = append(stk, next)
fmt.Fprintf(sb, "%s", types.TypeString(a, nil))
ec.add(errors.New(sb.String()))
hasCycle = true
break
}
}
default:
panic("invalid provider map value")
if !hasCycle {
next := append(append([]types.Type(nil), curr...), a)
stk = append(stk, next)
}
}
}
}
Expand Down
116 changes: 102 additions & 14 deletions internal/autowire/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
// Exactly one of the fields will be set.
type providerSetSrc struct {
Provider *Provider
AutoBinding *AutoBinding
Binding *IfaceBinding
Value *Value
Import *ProviderSet
Expand All @@ -57,6 +58,8 @@ func (p *providerSetSrc) description(fset *token.FileSet, typ types.Type) string
kind = "struct provider"
}
return fmt.Sprintf("%s %s(%s)", kind, quoted(p.Provider.Name), fset.Position(p.Provider.Pos))
case p.AutoBinding != nil:
return fmt.Sprintf("wire.AutoBind (%s)", fset.Position(p.AutoBinding.Pos))
case p.Binding != nil:
return fmt.Sprintf("autowire.Bind (%s)", fset.Position(p.Binding.Pos))
case p.Value != nil:
Expand Down Expand Up @@ -98,12 +101,13 @@ type ProviderSet struct {
// variable.
VarName string

Providers []*Provider
Bindings []*IfaceBinding
Values []*Value
Fields []*Field
Imports []*ProviderSet
// InjectorArgs is only filled in for autowire.Build.
Providers []*Provider
Bindings []*IfaceBinding
AutoBindings []*AutoBinding
Values []*Value
Fields []*Field
Imports []*ProviderSet
// InjectorArgs is only filled in for wire.Build.
InjectorArgs *InjectorArgs

// providerMap maps from provided type to a *ProvidedType.
Expand All @@ -125,6 +129,22 @@ func (set *ProviderSet) Outputs() []types.Type {
func (set *ProviderSet) For(t types.Type) ProvidedType {
pt := set.providerMap.At(t)
if pt == nil {
// if t is an interface, we may have an AutoBinding that implements it.
iface, ok := t.Underlying().(*types.Interface)
if !ok {
return ProvidedType{}
}

for _, ab := range set.AutoBindings {
if types.Implements(ab.Concrete, iface) {
// cache for later
pt := &ProvidedType{t: ab.Concrete, ab: ab}
set.providerMap.Set(t, pt)
set.srcMap.Set(t, &providerSetSrc{AutoBinding: ab})
return *pt
}
}

return ProvidedType{}
}
return *pt.(*ProvidedType)
Expand Down Expand Up @@ -179,6 +199,17 @@ type Provider struct {
HasErr bool
}

// AutoBinding records the signature of a provider eligible for auto-binding
// to interfaces it implements. A provider is a single Go object, either a
// function or a named type.
type AutoBinding struct {
// Concrete is always a type that implements N number of interfaces.
Concrete types.Type

// Pos is the position where the binding was declared.
Pos token.Pos
}

// ProviderInput describes an incoming edge in the provider graph.
type ProviderInput struct {
Type types.Type
Expand Down Expand Up @@ -520,7 +551,7 @@ func (oc *objectCache) varDecl(obj *types.Var) *ast.ValueSpec {
}

// processExpr converts an expression into a Wire structure. It may return a
// *Provider, an *IfaceBinding, a *ProviderSet, a *Value or a []*Field.
// *Provider, an *AutoBinding, an *IfaceBinding, a *ProviderSet, a *Value or a []*Field.
func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Expr, varName string) (interface{}, []error) {
exprPos := oc.fset.Position(expr.Pos())
expr = astutil.Unparen(expr)
Expand All @@ -546,6 +577,12 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex
case "NewSet":
pset, errs := oc.processNewSet(info, pkgPath, call, nil, varName)
return pset, notePositionAll(exprPos, errs)
case "AutoBind":
abs, err := processAutoBind(oc.fset, info, call)
if err != nil {
return nil, []error{notePosition(exprPos, err)}
}
return abs, nil
case "Bind":
b, err := processBind(oc.fset, info, call)
if err != nil {
Expand Down Expand Up @@ -607,6 +644,8 @@ func (oc *objectCache) processNewSet(info *types.Info, pkgPath string, call *ast
continue
}
switch item := item.(type) {
case *AutoBinding:
pset.AutoBindings = append(pset.AutoBindings, item)
case *Provider:
pset.Providers = append(pset.Providers, item)
case *ProviderSet:
Expand Down Expand Up @@ -880,6 +919,41 @@ func isPrevented(tag string) bool {
return reflect.StructTag(tag).Get("autowire") == "-"
}

func processAutoBind(fset *token.FileSet, info *types.Info, call *ast.CallExpr) (*AutoBinding, error) {
// Assumes that call.Fun is wire.AutoBind.

if len(call.Args) != 1 {
return nil, notePosition(fset.Position(call.Pos()),
errors.New("call to AutoBind takes exactly one argument"))
}
const firstArgReqFormat = "first argument to AutoBind must be a pointer to a type; found %s"
typ := info.TypeOf(call.Args[0])
ptr, ok := typ.(*types.Pointer)
if !ok {
return nil, notePosition(fset.Position(call.Pos()),
fmt.Errorf(firstArgReqFormat, types.TypeString(typ, nil)))
}

switch ptr.Elem().Underlying().(type) {
case *types.Named,
*types.Struct,
*types.Basic:
// good!

default:
return nil, notePosition(fset.Position(call.Pos()),
fmt.Errorf(firstArgReqFormat, types.TypeString(ptr, nil)))
}

typeExpr := call.Args[0].(*ast.CallExpr)
typeName := qualifiedIdentObject(info, typeExpr.Args[0]) // should be either an identifier or selector
autoBinding := &AutoBinding{
Concrete: ptr,
Pos: typeName.Pos(),
}
return autoBinding, nil
}

// processBind creates an interface binding from a autowire.Bind call.
func processBind(fset *token.FileSet, info *types.Info, call *ast.CallExpr) (*IfaceBinding, error) {
// Assumes that call.Fun is autowire.Bind.
Expand Down Expand Up @@ -1122,7 +1196,6 @@ func findInjectorBuild(info *types.Info, fn *ast.FuncDecl) (*ast.CallExpr, error
default:
invalid = true
}

}
if wireBuildCall == nil {
return nil, nil
Expand Down Expand Up @@ -1157,16 +1230,17 @@ func isProviderSetType(t types.Type) bool {
// none of the above, and returns true for IsNil.
type ProvidedType struct {
// t is the provided concrete type.
t types.Type
p *Provider
v *Value
a *InjectorArg
f *Field
t types.Type
p *Provider
ab *AutoBinding
v *Value
a *InjectorArg
f *Field
}

// IsNil reports whether pt is the zero value.
func (pt ProvidedType) IsNil() bool {
return pt.p == nil && pt.v == nil && pt.a == nil && pt.f == nil
return pt.p == nil && pt.ab == nil && pt.v == nil && pt.a == nil && pt.f == nil
}

// Type returns the output type.
Expand All @@ -1185,6 +1259,11 @@ func (pt ProvidedType) IsProvider() bool {
return pt.p != nil
}

// IsAutoBinding reports whether pt points to an AutoBinding.
func (pt ProvidedType) IsAutoBinding() bool {
return pt.ab != nil
}

// IsValue reports whether pt points to a Value.
func (pt ProvidedType) IsValue() bool {
return pt.v != nil
Expand All @@ -1209,6 +1288,15 @@ func (pt ProvidedType) Provider() *Provider {
return pt.p
}

// AutoBinding returns pt as a AutoBinding pointer. It panics if pt does not point
// to a AutoBinding.
func (pt ProvidedType) AutoBinding() *AutoBinding {
if pt.ab == nil {
panic("ProvidedType does not hold an AutoBinding")
}
return pt.ab
}

// Value returns pt as a Value pointer. It panics if pt does not point
// to a Value.
func (pt ProvidedType) Value() *Value {
Expand Down
26 changes: 26 additions & 0 deletions internal/autowire/testdata/InterfaceAutoBinding/foo/autowire.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright 2018 The Wire Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//+build wireinject

package main

import (
"github.com/google/wire"
)

func injectFooer() Fooer {
wire.Build(Set, wire.AutoBind(new(Bar)))
return nil
}
Loading

0 comments on commit 82059a6

Please sign in to comment.