diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 09f1e1dd..fb7a4958 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -546,6 +546,9 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex case "NewSet": pset, errs := oc.processNewSet(info, pkgPath, call, nil, varName) return pset, notePositionAll(exprPos, errs) + case "Subtract": + pset, errs := oc.processSubtract(info, pkgPath, call, nil, varName) + return pset, notePositionAll(exprPos, errs) case "Bind": b, err := processBind(oc.fset, info, call) if err != nil { @@ -880,6 +883,116 @@ func isPrevented(tag string) bool { return reflect.StructTag(tag).Get("wire") == "-" } +func (oc *objectCache) processSubtract(info *types.Info, pkgPath string, call *ast.CallExpr, args *InjectorArgs, varName string) (interface{}, []error) { + // Assumes that call.Fun is wire.Subtract. + if len(call.Args) < 2 { + return nil, []error{notePosition(oc.fset.Position(call.Pos()), + errors.New("call to Subtract must specify types to be subtracted"))} + } + firstArg, errs := oc.processExpr(info, pkgPath, call.Args[0], "") + if len(errs) > 0 { + return nil, errs + } + set, ok := firstArg.(*ProviderSet) + if !ok { + return nil, []error{ + notePosition(oc.fset.Position(call.Pos()), + fmt.Errorf("first argument to Subtract must be a Set")), + } + } + pset := &ProviderSet{ + Pos: call.Pos(), + InjectorArgs: args, + PkgPath: pkgPath, + VarName: varName, + // Copy the other fields. + Providers: set.Providers, + Bindings: set.Bindings, + Values: set.Values, + Fields: set.Fields, + Imports: set.Imports, + } + ec := new(errorCollector) + for _, arg := range call.Args[1:] { + ptr, ok := info.TypeOf(arg).(*types.Pointer) + if !ok { + ec.add(notePosition(oc.fset.Position(arg.Pos()), + fmt.Errorf("argument to Subtract must be a pointer"), + )) + continue + } + ec.add(oc.filterType(pset, ptr.Elem())...) + } + if len(ec.errors) > 0 { + return nil, ec.errors + } + return pset, nil +} + +func (oc *objectCache) filterType(set *ProviderSet, t types.Type) []error { + hasType := func(outs []types.Type) bool { + for _, o := range outs { + if types.Identical(o, t) { + return true + } + pt, ok := o.(*types.Pointer) + if ok && types.Identical(pt.Elem(), t) { + return true + } + } + return false + } + + providers := make([]*Provider, 0, len(set.Providers)) + for _, p := range set.Providers { + if !hasType(p.Out) { + providers = append(providers, p) + } + } + set.Providers = providers + + bindings := make([]*IfaceBinding, 0, len(set.Bindings)) + for _, i := range set.Bindings { + if !types.Identical(i.Iface, t) { + bindings = append(bindings, i) + } + } + set.Bindings = bindings + + values := make([]*Value, 0, len(set.Values)) + for _, v := range set.Values { + if !types.Identical(v.Out, t) { + values = append(values, v) + } + } + set.Values = values + + fields := make([]*Field, 0, len(set.Fields)) + for _, f := range set.Fields { + if !hasType(f.Out) { + fields = append(fields, f) + } + } + set.Fields = fields + + imports := make([]*ProviderSet, 0, len(set.Imports)) + for _, p := range set.Imports { + clone := *p + if errs := oc.filterType(&clone, t); len(errs) > 0 { + return errs + } + imports = append(imports, &clone) + } + set.Imports = imports + + var errs []error + set.providerMap, set.srcMap, errs = buildProviderMap(oc.fset, oc.hasher, set) + if len(errs) > 0 { + return errs + } + return nil +} + // processBind creates an interface binding from a wire.Bind call. func processBind(fset *token.FileSet, info *types.Info, call *ast.CallExpr) (*IfaceBinding, error) { // Assumes that call.Fun is wire.Bind. @@ -1122,7 +1235,6 @@ func findInjectorBuild(info *types.Info, fn *ast.FuncDecl) (*ast.CallExpr, error default: invalid = true } - } if wireBuildCall == nil { return nil, nil diff --git a/internal/wire/testdata/Subtract/foo/foo.go b/internal/wire/testdata/Subtract/foo/foo.go new file mode 100644 index 00000000..e43d799c --- /dev/null +++ b/internal/wire/testdata/Subtract/foo/foo.go @@ -0,0 +1,68 @@ +// Copyright 2018 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "github.com/google/wire" +) + +type context struct{} + +func main() {} + +type ( + FooOptions struct{} + Foo string + Bar struct{} + BarName string +) + +func (b *Bar) Bar() {} + +func provideFooOptions() *FooOptions { + return &FooOptions{} +} + +func provideFoo(*FooOptions) Foo { + return Foo("foo") +} + +func provideBar(Foo, BarName) *Bar { + return &Bar{} +} + +type BarService interface { + Bar() +} + +type FooBar struct { + BarService + Foo +} + +var Set = wire.NewSet( + provideFooOptions, + provideFoo, + provideBar, +) + +var SuperSet = wire.NewSet(Set, + wire.Struct(new(FooBar), "*"), + wire.Bind(new(BarService), new(*Bar)), +) + +type FakeBarService struct{} + +func (f *FakeBarService) Bar() {} diff --git a/internal/wire/testdata/Subtract/foo/wire.go b/internal/wire/testdata/Subtract/foo/wire.go new file mode 100644 index 00000000..1bb8c4f9 --- /dev/null +++ b/internal/wire/testdata/Subtract/foo/wire.go @@ -0,0 +1,47 @@ +// Copyright 2018 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build wireinject +// +build wireinject + +package main + +import ( + "github.com/google/wire" +) + +func inject(name BarName, opts *FooOptions) *Bar { + panic(wire.Build(wire.Subtract(Set, new(FooOptions)))) +} + +func injectBarService(name BarName, opts *FakeBarService) *FooBar { + panic(wire.Build( + wire.Subtract(SuperSet, new(BarService)), + wire.Bind(new(BarService), new(*FakeBarService)), + )) +} + +func injectFooBarService(name BarName, opts *FooOptions, bar *FakeBarService) *FooBar { + panic(wire.Build( + wire.Subtract(SuperSet, new(FooOptions), new(BarService)), + wire.Bind(new(BarService), new(*FakeBarService)), + )) +} + +func injectNone(name BarName, foo Foo, bar *FakeBarService) *FooBar { + panic(wire.Build( + wire.Subtract(SuperSet, new(Foo), new(BarService)), + wire.Bind(new(BarService), new(*FakeBarService)), + )) +} diff --git a/internal/wire/testdata/Subtract/pkg b/internal/wire/testdata/Subtract/pkg new file mode 100644 index 00000000..f7a5c8ce --- /dev/null +++ b/internal/wire/testdata/Subtract/pkg @@ -0,0 +1 @@ +example.com/foo diff --git a/internal/wire/testdata/Subtract/want/program_out.txt b/internal/wire/testdata/Subtract/want/program_out.txt new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/internal/wire/testdata/Subtract/want/program_out.txt @@ -0,0 +1 @@ + diff --git a/internal/wire/testdata/Subtract/want/wire_gen.go b/internal/wire/testdata/Subtract/want/wire_gen.go new file mode 100644 index 00000000..130671c0 --- /dev/null +++ b/internal/wire/testdata/Subtract/want/wire_gen.go @@ -0,0 +1,42 @@ +// Code generated by Wire. DO NOT EDIT. + +//go:generate go run -mod=mod github.com/google/wire/cmd/wire +//go:build !wireinject +// +build !wireinject + +package main + +// Injectors from wire.go: + +func inject(name BarName, opts *FooOptions) *Bar { + foo := provideFoo(opts) + bar := provideBar(foo, name) + return bar +} + +func injectBarService(name BarName, opts *FakeBarService) *FooBar { + fooOptions := provideFooOptions() + foo := provideFoo(fooOptions) + fooBar := &FooBar{ + BarService: opts, + Foo: foo, + } + return fooBar +} + +func injectFooBarService(name BarName, opts *FooOptions, bar *FakeBarService) *FooBar { + foo := provideFoo(opts) + fooBar := &FooBar{ + BarService: bar, + Foo: foo, + } + return fooBar +} + +func injectNone(name BarName, foo Foo, bar *FakeBarService) *FooBar { + fooBar := &FooBar{ + BarService: bar, + Foo: foo, + } + return fooBar +} diff --git a/internal/wire/testdata/SubtractErrors/foo/foo.go b/internal/wire/testdata/SubtractErrors/foo/foo.go new file mode 100644 index 00000000..097fa247 --- /dev/null +++ b/internal/wire/testdata/SubtractErrors/foo/foo.go @@ -0,0 +1,48 @@ +// Copyright 2018 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "github.com/google/wire" +) + +type context struct{} + +func main() {} + +type ( + FooOptions struct{} + Foo string + Bar struct{} + BarName string +) + +func provideFooOptions() *FooOptions { + return &FooOptions{} +} + +func provideFoo(*FooOptions) Foo { + return Foo("foo") +} + +func provideBar(Foo, BarName) *Bar { + return &Bar{} +} + +var Set = wire.NewSet( + provideFooOptions, + provideFoo, + provideBar, +) diff --git a/internal/wire/testdata/SubtractErrors/foo/wire.go b/internal/wire/testdata/SubtractErrors/foo/wire.go new file mode 100644 index 00000000..51d4fb75 --- /dev/null +++ b/internal/wire/testdata/SubtractErrors/foo/wire.go @@ -0,0 +1,34 @@ +// Copyright 2018 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build wireinject +// +build wireinject + +package main + +import ( + "github.com/google/wire" +) + +func injectMissArgs(opts *FooOptions) Foo { + panic(wire.Build(wire.Subtract(provideFoo))) +} + +func injectNonSet(opts *FooOptions) Foo { + panic(wire.Build(wire.Subtract(provideFoo, new(FooOptions)))) +} + +func injectNonPointer(name BarName, opts *FooOptions) *Bar { + panic(wire.Build(wire.Subtract(Set, FooOptions{}))) +} diff --git a/internal/wire/testdata/SubtractErrors/pkg b/internal/wire/testdata/SubtractErrors/pkg new file mode 100644 index 00000000..f7a5c8ce --- /dev/null +++ b/internal/wire/testdata/SubtractErrors/pkg @@ -0,0 +1 @@ +example.com/foo diff --git a/internal/wire/testdata/SubtractErrors/want/wire_errs.txt b/internal/wire/testdata/SubtractErrors/want/wire_errs.txt new file mode 100644 index 00000000..6da996ec --- /dev/null +++ b/internal/wire/testdata/SubtractErrors/want/wire_errs.txt @@ -0,0 +1,5 @@ +example.com/foo/wire.go:x:y: call to Subtract must specify types to be subtracted + +example.com/foo/wire.go:x:y: first argument to Subtract must be a Set + +example.com/foo/wire.go:x:y: argument to Subtract must be a pointer \ No newline at end of file diff --git a/wire.go b/wire.go index 6af91dda..0e23e460 100644 --- a/wire.go +++ b/wire.go @@ -59,6 +59,31 @@ func NewSet(...interface{}) ProviderSet { return ProviderSet{} } +// Subtract removes type declaration from the provider set. +// +// Example: +// +// var MySetA = wire.NewSet( +// otherpkg.FooSet, +// otherpkg.BarSet, +// ) +// +// var MySetB = wire.NewSet( +// otherpkg.CarSet, +// otherpkg.BarSet, +// ) +// +// func Build() Set { +// panic(wire.Build( +// MySetA, +// wire.Subtract(MySetB, otherpkg.BarSet), +// NewSet, +// )) +// } +func Subtract(...interface{}) ProviderSet { + return ProviderSet{} +} + // Build is placed in the body of an injector function template to declare the // providers to use. The Wire code generation tool will fill in an // implementation of the function. The arguments to Build are interpreted the