Skip to content

Commit

Permalink
feat(subtract): added Subtract function to wire
Browse files Browse the repository at this point in the history
  • Loading branch information
krhubert committed Oct 13, 2024
1 parent e57deea commit 5995e84
Show file tree
Hide file tree
Showing 11 changed files with 385 additions and 1 deletion.
114 changes: 113 additions & 1 deletion internal/wire/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,9 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex
case "NewSet":
pset, errs := oc.processNewSet(info, pkgPath, call, nil, varName)
return pset, notePositionAll(exprPos, errs)
case "Subtract":
pset, errs := oc.processSubtract(info, pkgPath, call, nil, varName)
return pset, notePositionAll(exprPos, errs)
case "Bind":
b, err := processBind(oc.fset, info, call)
if err != nil {
Expand Down Expand Up @@ -880,6 +883,116 @@ func isPrevented(tag string) bool {
return reflect.StructTag(tag).Get("wire") == "-"
}

func (oc *objectCache) processSubtract(info *types.Info, pkgPath string, call *ast.CallExpr, args *InjectorArgs, varName string) (interface{}, []error) {
// Assumes that call.Fun is wire.Subtract.
if len(call.Args) < 2 {
return nil, []error{notePosition(oc.fset.Position(call.Pos()),
errors.New("call to Subtract must specify types to be subtracted"))}
}
firstArg, errs := oc.processExpr(info, pkgPath, call.Args[0], "")
if len(errs) > 0 {
return nil, errs
}
set, ok := firstArg.(*ProviderSet)
if !ok {
return nil, []error{
notePosition(oc.fset.Position(call.Pos()),
fmt.Errorf("first argument to Subtract must be a Set")),
}
}
pset := &ProviderSet{
Pos: call.Pos(),
InjectorArgs: args,
PkgPath: pkgPath,
VarName: varName,
// Copy the other fields.
Providers: set.Providers,
Bindings: set.Bindings,
Values: set.Values,
Fields: set.Fields,
Imports: set.Imports,
}
ec := new(errorCollector)
for _, arg := range call.Args[1:] {
ptr, ok := info.TypeOf(arg).(*types.Pointer)
if !ok {
ec.add(notePosition(oc.fset.Position(arg.Pos()),
fmt.Errorf("argument to Subtract must be a pointer"),
))
continue
}
ec.add(oc.filterType(pset, ptr.Elem())...)
}
if len(ec.errors) > 0 {
return nil, ec.errors
}
return pset, nil
}

func (oc *objectCache) filterType(set *ProviderSet, t types.Type) []error {
hasType := func(outs []types.Type) bool {
for _, o := range outs {
if types.Identical(o, t) {
return true
}
pt, ok := o.(*types.Pointer)
if ok && types.Identical(pt.Elem(), t) {
return true
}
}
return false
}

providers := make([]*Provider, 0, len(set.Providers))
for _, p := range set.Providers {
if !hasType(p.Out) {
providers = append(providers, p)
}
}
set.Providers = providers

bindings := make([]*IfaceBinding, 0, len(set.Bindings))
for _, i := range set.Bindings {
if !types.Identical(i.Iface, t) {
bindings = append(bindings, i)
}
}
set.Bindings = bindings

values := make([]*Value, 0, len(set.Values))
for _, v := range set.Values {
if !types.Identical(v.Out, t) {
values = append(values, v)
}
}
set.Values = values

fields := make([]*Field, 0, len(set.Fields))
for _, f := range set.Fields {
if !hasType(f.Out) {
fields = append(fields, f)
}
}
set.Fields = fields

imports := make([]*ProviderSet, 0, len(set.Imports))
for _, p := range set.Imports {
clone := *p
if errs := oc.filterType(&clone, t); len(errs) > 0 {
return errs
}
imports = append(imports, &clone)
}
set.Imports = imports

var errs []error
set.providerMap, set.srcMap, errs = buildProviderMap(oc.fset, oc.hasher, set)
if len(errs) > 0 {
return errs
}
return nil
}

// processBind creates an interface binding from a wire.Bind call.
func processBind(fset *token.FileSet, info *types.Info, call *ast.CallExpr) (*IfaceBinding, error) {
// Assumes that call.Fun is wire.Bind.
Expand Down Expand Up @@ -1122,7 +1235,6 @@ func findInjectorBuild(info *types.Info, fn *ast.FuncDecl) (*ast.CallExpr, error
default:
invalid = true
}

}
if wireBuildCall == nil {
return nil, nil
Expand Down
68 changes: 68 additions & 0 deletions internal/wire/testdata/Subtract/foo/foo.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright 2018 The Wire Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

import (
"github.com/google/wire"
)

type context struct{}

func main() {}

type (
FooOptions struct{}
Foo string
Bar struct{}
BarName string
)

func (b *Bar) Bar() {}

func provideFooOptions() *FooOptions {
return &FooOptions{}
}

func provideFoo(*FooOptions) Foo {
return Foo("foo")
}

func provideBar(Foo, BarName) *Bar {
return &Bar{}
}

type BarService interface {
Bar()
}

type FooBar struct {
BarService
Foo
}

var Set = wire.NewSet(
provideFooOptions,
provideFoo,
provideBar,
)

var SuperSet = wire.NewSet(Set,
wire.Struct(new(FooBar), "*"),
wire.Bind(new(BarService), new(*Bar)),
)

type FakeBarService struct{}

func (f *FakeBarService) Bar() {}
47 changes: 47 additions & 0 deletions internal/wire/testdata/Subtract/foo/wire.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright 2018 The Wire Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build wireinject
// +build wireinject

package main

import (
"github.com/google/wire"
)

func inject(name BarName, opts *FooOptions) *Bar {
panic(wire.Build(wire.Subtract(Set, new(FooOptions))))
}

func injectBarService(name BarName, opts *FakeBarService) *FooBar {
panic(wire.Build(
wire.Subtract(SuperSet, new(BarService)),
wire.Bind(new(BarService), new(*FakeBarService)),
))
}

func injectFooBarService(name BarName, opts *FooOptions, bar *FakeBarService) *FooBar {
panic(wire.Build(
wire.Subtract(SuperSet, new(FooOptions), new(BarService)),
wire.Bind(new(BarService), new(*FakeBarService)),
))
}

func injectNone(name BarName, foo Foo, bar *FakeBarService) *FooBar {
panic(wire.Build(
wire.Subtract(SuperSet, new(Foo), new(BarService)),
wire.Bind(new(BarService), new(*FakeBarService)),
))
}
1 change: 1 addition & 0 deletions internal/wire/testdata/Subtract/pkg
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
example.com/foo
1 change: 1 addition & 0 deletions internal/wire/testdata/Subtract/want/program_out.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

42 changes: 42 additions & 0 deletions internal/wire/testdata/Subtract/want/wire_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

48 changes: 48 additions & 0 deletions internal/wire/testdata/SubtractErrors/foo/foo.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright 2018 The Wire Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

import (
"github.com/google/wire"
)

type context struct{}

func main() {}

type (
FooOptions struct{}
Foo string
Bar struct{}
BarName string
)

func provideFooOptions() *FooOptions {
return &FooOptions{}
}

func provideFoo(*FooOptions) Foo {
return Foo("foo")
}

func provideBar(Foo, BarName) *Bar {
return &Bar{}
}

var Set = wire.NewSet(
provideFooOptions,
provideFoo,
provideBar,
)
34 changes: 34 additions & 0 deletions internal/wire/testdata/SubtractErrors/foo/wire.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright 2018 The Wire Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build wireinject
// +build wireinject

package main

import (
"github.com/google/wire"
)

func injectMissArgs(opts *FooOptions) Foo {
panic(wire.Build(wire.Subtract(provideFoo)))
}

func injectNonSet(opts *FooOptions) Foo {
panic(wire.Build(wire.Subtract(provideFoo, new(FooOptions))))
}

func injectNonPointer(name BarName, opts *FooOptions) *Bar {
panic(wire.Build(wire.Subtract(Set, FooOptions{})))
}
1 change: 1 addition & 0 deletions internal/wire/testdata/SubtractErrors/pkg
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
example.com/foo
5 changes: 5 additions & 0 deletions internal/wire/testdata/SubtractErrors/want/wire_errs.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
example.com/foo/wire.go:x:y: call to Subtract must specify types to be subtracted

example.com/foo/wire.go:x:y: first argument to Subtract must be a Set

example.com/foo/wire.go:x:y: argument to Subtract must be a pointer
Loading

0 comments on commit 5995e84

Please sign in to comment.