Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding capability to auto-bind a type to interfaces it implements #285

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 46 additions & 33 deletions internal/wire/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,18 @@ func verifyArgsUsed(set *ProviderSet, used []*providerSetSrc) []error {
errs = append(errs, fmt.Errorf("unused provider %q", p.Pkg.Name()+"."+p.Name))
}
}
for _, ab := range set.AutoBindings {
found := false
for _, u := range used {
if u.AutoBinding == ab {
found = true
break
}
}
if !found {
errs = append(errs, fmt.Errorf("unused auto binding %q", types.TypeString(ab.Concrete, nil)))
}
}
for _, v := range set.Values {
found := false
for _, u := range used {
Expand Down Expand Up @@ -458,49 +470,50 @@ func verifyAcyclic(providerMap *typeutil.Map, hasher typeutil.Hasher) []error {
continue
}
pt := x.(*ProvidedType)
var args []types.Type
switch {
case pt.IsValue():
// Leaf: values do not have dependencies.
case pt.IsArg():
// Injector arguments do not have dependencies.
case pt.IsProvider() || pt.IsField():
var args []types.Type
if pt.IsProvider() {
for _, arg := range pt.Provider().Args {
args = append(args, arg.Type)
}
} else {
args = append(args, pt.Field().Parent)
case pt.IsProvider():
for _, arg := range pt.Provider().Args {
args = append(args, arg.Type)
}
for _, a := range args {
hasCycle := false
for i, b := range curr {
if types.Identical(a, b) {
sb := new(strings.Builder)
fmt.Fprintf(sb, "cycle for %s:\n", types.TypeString(a, nil))
for j := i; j < len(curr); j++ {
t := providerMap.At(curr[j]).(*ProvidedType)
if t.IsProvider() {
p := t.Provider()
fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.Pkg.Path(), p.Name)
} else {
p := t.Field()
fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.Parent, p.Name)
}
case pt.IsAutoBinding():
args = append(args, pt.AutoBinding().Concrete)
case pt.IsField():
args = append(args, pt.Field().Parent)
default:
panic("invalid provider map value")
}

for _, a := range args {
hasCycle := false
for i, b := range curr {
if types.Identical(a, b) {
sb := new(strings.Builder)
fmt.Fprintf(sb, "cycle for %s:\n", types.TypeString(a, nil))
for j := i; j < len(curr); j++ {
t := providerMap.At(curr[j]).(*ProvidedType)
if t.IsProvider() {
p := t.Provider()
fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.Pkg.Path(), p.Name)
} else {
p := t.Field()
fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.Parent, p.Name)
}
fmt.Fprintf(sb, "%s", types.TypeString(a, nil))
ec.add(errors.New(sb.String()))
hasCycle = true
break
}
}
if !hasCycle {
next := append(append([]types.Type(nil), curr...), a)
stk = append(stk, next)
fmt.Fprintf(sb, "%s", types.TypeString(a, nil))
ec.add(errors.New(sb.String()))
hasCycle = true
break
}
}
default:
panic("invalid provider map value")
if !hasCycle {
next := append(append([]types.Type(nil), curr...), a)
stk = append(stk, next)
}
}
}
}
Expand Down
114 changes: 101 additions & 13 deletions internal/wire/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
// Exactly one of the fields will be set.
type providerSetSrc struct {
Provider *Provider
AutoBinding *AutoBinding
Binding *IfaceBinding
Value *Value
Import *ProviderSet
Expand All @@ -57,6 +58,8 @@ func (p *providerSetSrc) description(fset *token.FileSet, typ types.Type) string
kind = "struct provider"
}
return fmt.Sprintf("%s %s(%s)", kind, quoted(p.Provider.Name), fset.Position(p.Provider.Pos))
case p.AutoBinding != nil:
return fmt.Sprintf("wire.AutoBind (%s)", fset.Position(p.AutoBinding.Pos))
case p.Binding != nil:
return fmt.Sprintf("wire.Bind (%s)", fset.Position(p.Binding.Pos))
case p.Value != nil:
Expand Down Expand Up @@ -98,11 +101,12 @@ type ProviderSet struct {
// variable.
VarName string

Providers []*Provider
Bindings []*IfaceBinding
Values []*Value
Fields []*Field
Imports []*ProviderSet
Providers []*Provider
Bindings []*IfaceBinding
AutoBindings []*AutoBinding
Values []*Value
Fields []*Field
Imports []*ProviderSet
// InjectorArgs is only filled in for wire.Build.
InjectorArgs *InjectorArgs

Expand All @@ -125,6 +129,22 @@ func (set *ProviderSet) Outputs() []types.Type {
func (set *ProviderSet) For(t types.Type) ProvidedType {
pt := set.providerMap.At(t)
if pt == nil {
// if t is an interface, we may have an AutoBinding that implements it.
iface, ok := t.Underlying().(*types.Interface)
if !ok {
return ProvidedType{}
}

for _, ab := range set.AutoBindings {
if types.Implements(ab.Concrete, iface) {
// cache for later
pt := &ProvidedType{t: ab.Concrete, ab: ab}
set.providerMap.Set(t, pt)
set.srcMap.Set(t, &providerSetSrc{AutoBinding: ab})
return *pt
}
}

return ProvidedType{}
}
return *pt.(*ProvidedType)
Expand Down Expand Up @@ -179,6 +199,17 @@ type Provider struct {
HasErr bool
}

// AutoBinding records the signature of a provider eligible for auto-binding
// to interfaces it implements. A provider is a single Go object, either a
// function or a named type.
type AutoBinding struct {
// Concrete is always a type that implements N number of interfaces.
Concrete types.Type

// Pos is the position where the binding was declared.
Pos token.Pos
}

// ProviderInput describes an incoming edge in the provider graph.
type ProviderInput struct {
Type types.Type
Expand Down Expand Up @@ -520,7 +551,7 @@ func (oc *objectCache) varDecl(obj *types.Var) *ast.ValueSpec {
}

// processExpr converts an expression into a Wire structure. It may return a
// *Provider, an *IfaceBinding, a *ProviderSet, a *Value or a []*Field.
// *Provider, an *AutoBinding, an *IfaceBinding, a *ProviderSet, a *Value or a []*Field.
func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Expr, varName string) (interface{}, []error) {
exprPos := oc.fset.Position(expr.Pos())
expr = astutil.Unparen(expr)
Expand All @@ -546,6 +577,12 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex
case "NewSet":
pset, errs := oc.processNewSet(info, pkgPath, call, nil, varName)
return pset, notePositionAll(exprPos, errs)
case "AutoBind":
abs, err := processAutoBind(oc.fset, info, call)
if err != nil {
return nil, []error{notePosition(exprPos, err)}
}
return abs, nil
case "Bind":
b, err := processBind(oc.fset, info, call)
if err != nil {
Expand Down Expand Up @@ -607,6 +644,8 @@ func (oc *objectCache) processNewSet(info *types.Info, pkgPath string, call *ast
continue
}
switch item := item.(type) {
case *AutoBinding:
pset.AutoBindings = append(pset.AutoBindings, item)
case *Provider:
pset.Providers = append(pset.Providers, item)
case *ProviderSet:
Expand Down Expand Up @@ -880,6 +919,41 @@ func isPrevented(tag string) bool {
return reflect.StructTag(tag).Get("wire") == "-"
}

func processAutoBind(fset *token.FileSet, info *types.Info, call *ast.CallExpr) (*AutoBinding, error) {
// Assumes that call.Fun is wire.AutoBind.

if len(call.Args) != 1 {
return nil, notePosition(fset.Position(call.Pos()),
errors.New("call to AutoBind takes exactly one argument"))
}
const firstArgReqFormat = "first argument to AutoBind must be a pointer to a type; found %s"
typ := info.TypeOf(call.Args[0])
ptr, ok := typ.(*types.Pointer)
if !ok {
return nil, notePosition(fset.Position(call.Pos()),
fmt.Errorf(firstArgReqFormat, types.TypeString(typ, nil)))
}

switch ptr.Elem().Underlying().(type) {
case *types.Named,
*types.Struct,
*types.Basic:
// good!

default:
return nil, notePosition(fset.Position(call.Pos()),
fmt.Errorf(firstArgReqFormat, types.TypeString(ptr, nil)))
}

typeExpr := call.Args[0].(*ast.CallExpr)
typeName := qualifiedIdentObject(info, typeExpr.Args[0]) // should be either an identifier or selector
autoBinding := &AutoBinding{
Concrete: ptr,
Pos: typeName.Pos(),
}
return autoBinding, nil
}

// processBind creates an interface binding from a wire.Bind call.
func processBind(fset *token.FileSet, info *types.Info, call *ast.CallExpr) (*IfaceBinding, error) {
// Assumes that call.Fun is wire.Bind.
Expand Down Expand Up @@ -1122,7 +1196,6 @@ func findInjectorBuild(info *types.Info, fn *ast.FuncDecl) (*ast.CallExpr, error
default:
invalid = true
}

}
if wireBuildCall == nil {
return nil, nil
Expand Down Expand Up @@ -1157,16 +1230,17 @@ func isProviderSetType(t types.Type) bool {
// none of the above, and returns true for IsNil.
type ProvidedType struct {
// t is the provided concrete type.
t types.Type
p *Provider
v *Value
a *InjectorArg
f *Field
t types.Type
p *Provider
ab *AutoBinding
v *Value
a *InjectorArg
f *Field
}

// IsNil reports whether pt is the zero value.
func (pt ProvidedType) IsNil() bool {
return pt.p == nil && pt.v == nil && pt.a == nil && pt.f == nil
return pt.p == nil && pt.ab == nil && pt.v == nil && pt.a == nil && pt.f == nil
}

// Type returns the output type.
Expand All @@ -1185,6 +1259,11 @@ func (pt ProvidedType) IsProvider() bool {
return pt.p != nil
}

// IsAutoBinding reports whether pt points to an AutoBinding.
func (pt ProvidedType) IsAutoBinding() bool {
return pt.ab != nil
}

// IsValue reports whether pt points to a Value.
func (pt ProvidedType) IsValue() bool {
return pt.v != nil
Expand All @@ -1209,6 +1288,15 @@ func (pt ProvidedType) Provider() *Provider {
return pt.p
}

// AutoBinding returns pt as a AutoBinding pointer. It panics if pt does not point
// to a AutoBinding.
func (pt ProvidedType) AutoBinding() *AutoBinding {
if pt.ab == nil {
panic("ProvidedType does not hold an AutoBinding")
}
return pt.ab
}

// Value returns pt as a Value pointer. It panics if pt does not point
// to a Value.
func (pt ProvidedType) Value() *Value {
Expand Down
43 changes: 43 additions & 0 deletions internal/wire/testdata/InterfaceAutoBinding/foo/foo.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright 2018 The Wire Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

import (
"fmt"

"github.com/google/wire"
)

func main() {
fmt.Println(injectFooer().Foo())
}

type Fooer interface {
Foo() string
}

type Bar string

func (b *Bar) Foo() string {
return string(*b)
}

func provideBar() *Bar {
b := new(Bar)
*b = "Hello, World!"
return b
}

var Set = wire.NewSet(provideBar)
26 changes: 26 additions & 0 deletions internal/wire/testdata/InterfaceAutoBinding/foo/wire.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright 2018 The Wire Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//+build wireinject

package main

import (
"github.com/google/wire"
)

func injectFooer() Fooer {
wire.Build(Set, wire.AutoBind(new(Bar)))
return nil
}
1 change: 1 addition & 0 deletions internal/wire/testdata/InterfaceAutoBinding/pkg
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
example.com/foo
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Hello, World!
Loading