diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 93fbda85..1c86e23e 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -546,6 +546,9 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex case "NewSet": pset, errs := oc.processNewSet(info, pkgPath, call, nil, varName) return pset, notePositionAll(exprPos, errs) + case "Subtract": + pset, errs := oc.processSubtract(info, pkgPath, call, nil, varName) + return pset, notePositionAll(exprPos, errs) case "Bind": b, err := processBind(oc.fset, info, call) if err != nil { @@ -590,6 +593,115 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex return nil, []error{notePosition(exprPos, errors.New("unknown pattern"))} } +func (oc *objectCache) filterType(set *ProviderSet, t types.Type) []error { + hasType := func(outs []types.Type) bool { + for _, o := range outs { + if types.Identical(o, t) { + return true + } + pt, ok := o.(*types.Pointer) + if ok && types.Identical(pt.Elem(), t) { + return true + } + } + return false + } + + providers := make([]*Provider, 0, len(set.Providers)) + for _, p := range set.Providers { + if !hasType(p.Out) { + providers = append(providers, p) + } + } + set.Providers = providers + + bindings := make([]*IfaceBinding, 0, len(set.Bindings)) + for _, i := range set.Bindings { + if !types.Identical(i.Iface, t) { + bindings = append(bindings, i) + } + } + set.Bindings = bindings + + values := make([]*Value, 0, len(set.Values)) + for _, v := range set.Values { + if !types.Identical(v.Out, t) { + values = append(values, v) + } + } + set.Values = values + + fields := make([]*Field, 0, len(set.Fields)) + for _, f := range set.Fields { + if !hasType(f.Out) { + fields = append(fields, f) + } + } + set.Fields = fields + + imports := make([]*ProviderSet, 0, len(set.Imports)) + for _, p := range set.Imports { + clone := *p + if errs := oc.filterType(&clone, t); len(errs) > 0 { + return errs + } + imports = append(imports, &clone) + } + set.Imports = imports + + var errs []error + set.providerMap, set.srcMap, errs = buildProviderMap(oc.fset, oc.hasher, set) + if len(errs) > 0 { + return errs + } + return nil +} + +func (oc *objectCache) processSubtract(info *types.Info, pkgPath string, call *ast.CallExpr, args *InjectorArgs, varName string) (interface{}, []error) { + // Assumes that call.Fun is wire.Subtract. + if len(call.Args) < 2 { + return nil, []error{notePosition(oc.fset.Position(call.Pos()), + errors.New("call to Subtract must specify types to be subtracted"))} + } + firstArg, errs := oc.processExpr(info, pkgPath, call.Args[0], "") + if len(errs) > 0 { + return nil, errs + } + set, ok := firstArg.(*ProviderSet) + if !ok { + return nil, []error{notePosition(oc.fset.Position(call.Pos()), + fmt.Errorf("first argument to Subtract must be a Set")), + } + } + pset := &ProviderSet{ + Pos: call.Pos(), + InjectorArgs: args, + PkgPath: pkgPath, + VarName: varName, + // Copy the other fields. + Providers: set.Providers, + Bindings: set.Bindings, + Values: set.Values, + Fields: set.Fields, + Imports: set.Imports, + } + ec := new(errorCollector) + for _, arg := range call.Args[1:] { + ptr, ok := info.TypeOf(arg).(*types.Pointer) + if !ok { + ec.add(notePosition(oc.fset.Position(arg.Pos()), + errors.New("argument to Subtract must be a pointer"), + )) + continue + } + ec.add(oc.filterType(pset, ptr.Elem())...) + } + if len(ec.errors) > 0 { + return nil, ec.errors + } + return pset, nil +} + func (oc *objectCache) processNewSet(info *types.Info, pkgPath string, call *ast.CallExpr, args *InjectorArgs, varName string) (*ProviderSet, []error) { // Assumes that call.Fun is wire.NewSet or wire.Build. @@ -1173,9 +1285,9 @@ func (pt ProvidedType) IsNil() bool { // // - For a function provider, this is the first return value type. // - For a struct provider, this is either the struct type or the pointer type -// whose element type is the struct type. -// - For a value, this is the type of the expression. -// - For an argument, this is the type of the argument. +// whose element type is the struct type. +// - For a value, this is the type of the expression. +// - For an argument, this is the type of the argument. func (pt ProvidedType) Type() types.Type { return pt.t } diff --git a/internal/wire/testdata/Subtract/foo/foo.go b/internal/wire/testdata/Subtract/foo/foo.go new file mode 100644 index 00000000..f7b9bc37 --- /dev/null +++ b/internal/wire/testdata/Subtract/foo/foo.go @@ -0,0 +1,66 @@ +// Copyright 2018 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "github.com/google/wire" +) + +type context struct{} + +func main() {} + +type FooOptions struct{} +type Foo string +type Bar struct{} +type BarName string + +func (b *Bar) Bar() {} + +func provideFooOptions() *FooOptions { + return &FooOptions{} +} + +func provideFoo(*FooOptions) Foo { + return Foo("foo") +} + +func provideBar(Foo, BarName) *Bar { + return &Bar{} +} + +type BarService interface { + Bar() +} + +type FooBar struct { + BarService + Foo +} + +var Set = wire.NewSet( + provideFooOptions, + provideFoo, + provideBar, +) + +var SuperSet = wire.NewSet(Set, + wire.Struct(new(FooBar), "*"), + wire.Bind(new(BarService), new(*Bar)), +) + +type FakeBarService struct{} + +func (f *FakeBarService) Bar() {} diff --git a/internal/wire/testdata/Subtract/foo/wire.go b/internal/wire/testdata/Subtract/foo/wire.go new file mode 100644 index 00000000..d2f23ea1 --- /dev/null +++ b/internal/wire/testdata/Subtract/foo/wire.go @@ -0,0 +1,49 @@ +// Copyright 2018 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build wireinject +// +build wireinject + +package main + +import ( + // "strings" + + "github.com/google/wire" +) + +func inject(name BarName, opts *FooOptions) *Bar { + panic(wire.Build(wire.Subtract(Set, new(FooOptions)))) +} + +func injectBarService(name BarName, opts *FakeBarService) *FooBar { + panic(wire.Build( + wire.Subtract(SuperSet, new(BarService)), + wire.Bind(new(BarService), new(*FakeBarService)), + )) +} + +func injectFooBarService(name BarName, opts *FooOptions, bar *FakeBarService) *FooBar { + panic(wire.Build( + wire.Subtract(SuperSet, new(FooOptions), new(BarService)), + wire.Bind(new(BarService), new(*FakeBarService)), + )) +} + +func injectNone(name BarName, foo Foo, bar *FakeBarService) *FooBar { + panic(wire.Build( + wire.Subtract(SuperSet, new(Foo), new(BarService)), + wire.Bind(new(BarService), new(*FakeBarService)), + )) +} diff --git a/internal/wire/testdata/Subtract/pkg b/internal/wire/testdata/Subtract/pkg new file mode 100644 index 00000000..f7a5c8ce --- /dev/null +++ b/internal/wire/testdata/Subtract/pkg @@ -0,0 +1 @@ +example.com/foo diff --git a/internal/wire/testdata/Subtract/want/program_out.txt b/internal/wire/testdata/Subtract/want/program_out.txt new file mode 100644 index 00000000..e69de29b diff --git a/internal/wire/testdata/Subtract/want/wire_gen.go b/internal/wire/testdata/Subtract/want/wire_gen.go new file mode 100644 index 00000000..130671c0 --- /dev/null +++ b/internal/wire/testdata/Subtract/want/wire_gen.go @@ -0,0 +1,42 @@ +// Code generated by Wire. DO NOT EDIT. + +//go:generate go run -mod=mod github.com/google/wire/cmd/wire +//go:build !wireinject +// +build !wireinject + +package main + +// Injectors from wire.go: + +func inject(name BarName, opts *FooOptions) *Bar { + foo := provideFoo(opts) + bar := provideBar(foo, name) + return bar +} + +func injectBarService(name BarName, opts *FakeBarService) *FooBar { + fooOptions := provideFooOptions() + foo := provideFoo(fooOptions) + fooBar := &FooBar{ + BarService: opts, + Foo: foo, + } + return fooBar +} + +func injectFooBarService(name BarName, opts *FooOptions, bar *FakeBarService) *FooBar { + foo := provideFoo(opts) + fooBar := &FooBar{ + BarService: bar, + Foo: foo, + } + return fooBar +} + +func injectNone(name BarName, foo Foo, bar *FakeBarService) *FooBar { + fooBar := &FooBar{ + BarService: bar, + Foo: foo, + } + return fooBar +} diff --git a/wire.go b/wire.go index fe8edc8c..dafbb872 100644 --- a/wire.go +++ b/wire.go @@ -59,6 +59,10 @@ func NewSet(...interface{}) ProviderSet { return ProviderSet{} } +func Subtract(...interface{}) ProviderSet { + return ProviderSet{} +} + // Build is placed in the body of an injector function template to declare the // providers to use. The Wire code generation tool will fill in an // implementation of the function. The arguments to Build are interpreted the @@ -156,12 +160,12 @@ type StructProvider struct{} // // For example: // -// type S struct { -// MyFoo *Foo -// MyBar *Bar -// } -// var Set = wire.NewSet(wire.Struct(new(S), "MyFoo")) -> inject only S.MyFoo -// var Set = wire.NewSet(wire.Struct(new(S), "*")) -> inject all fields +// type S struct { +// MyFoo *Foo +// MyBar *Bar +// } +// var Set = wire.NewSet(wire.Struct(new(S), "MyFoo")) -> inject only S.MyFoo +// var Set = wire.NewSet(wire.Struct(new(S), "*")) -> inject all fields func Struct(structType interface{}, fieldNames ...string) StructProvider { return StructProvider{} } @@ -175,22 +179,22 @@ type StructFields struct{} // // The following example would provide Foo and Bar using S.MyFoo and S.MyBar respectively: // -// type S struct { -// MyFoo Foo -// MyBar Bar -// } +// type S struct { +// MyFoo Foo +// MyBar Bar +// } // -// func NewStruct() S { /* ... */ } -// var Set = wire.NewSet(wire.FieldsOf(new(S), "MyFoo", "MyBar")) +// func NewStruct() S { /* ... */ } +// var Set = wire.NewSet(wire.FieldsOf(new(S), "MyFoo", "MyBar")) // -// or +// or // -// func NewStruct() *S { /* ... */ } -// var Set = wire.NewSet(wire.FieldsOf(new(*S), "MyFoo", "MyBar")) +// func NewStruct() *S { /* ... */ } +// var Set = wire.NewSet(wire.FieldsOf(new(*S), "MyFoo", "MyBar")) // -// If the structType argument is a pointer to a pointer to a struct, then FieldsOf -// additionally provides a pointer to each field type (e.g., *Foo and *Bar in the -// example above). +// If the structType argument is a pointer to a pointer to a struct, then FieldsOf +// additionally provides a pointer to each field type (e.g., *Foo and *Bar in the +// example above). func FieldsOf(structType interface{}, fieldNames ...string) StructFields { return StructFields{} }