diff --git a/internal/wire/analyze.go b/internal/wire/analyze.go index 9650ef16..90777939 100644 --- a/internal/wire/analyze.go +++ b/internal/wire/analyze.go @@ -33,6 +33,7 @@ const ( structProvider valueExpr selectorExpr + rawValueExpr ) // A call represents a step of an injector function. It may be either a @@ -210,8 +211,12 @@ dfs: case pv.IsValue(): v := pv.Value() index.Set(curr.t, given.Len()+len(calls)) + valueExprKind := valueExpr + if pv.v.RawValue == true { + valueExprKind = rawValueExpr + } calls = append(calls, call{ - kind: valueExpr, + kind: valueExprKind, out: curr.t, valueExpr: v.expr, valueTypeInfo: v.info, diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 93fbda85..e694b2ea 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -200,6 +200,8 @@ type Value struct { // info is the type info for the expression. info *types.Info + + RawValue bool } // InjectorArg describes a specific argument passed to an injector function. @@ -558,6 +560,12 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex return nil, []error{notePosition(exprPos, err)} } return v, nil + case "RawValue": + v, err := processRawValue(oc.fset, info, call) + if err != nil { + return nil, []error{notePosition(exprPos, err)} + } + return v, nil case "InterfaceValue": v, err := processInterfaceValue(oc.fset, info, call) if err != nil { @@ -964,10 +972,56 @@ func processValue(fset *token.FileSet, info *types.Info, call *ast.CallExpr) (*V return nil, notePosition(fset.Position(call.Pos()), fmt.Errorf("argument to Value may not be an interface value (found %s); use InterfaceValue instead", types.TypeString(argType, nil))) } return &Value{ - Pos: call.Args[0].Pos(), - Out: info.TypeOf(call.Args[0]), - expr: call.Args[0], - info: info, + Pos: call.Args[0].Pos(), + Out: info.TypeOf(call.Args[0]), + expr: call.Args[0], + info: info, + RawValue: false, + }, nil +} + +func processRawValue(fset *token.FileSet, info *types.Info, call *ast.CallExpr) (*Value, error) { + // Assumes that call.Fun is wire.Value. + + if len(call.Args) != 1 { + return nil, notePosition(fset.Position(call.Pos()), errors.New("call to Value takes exactly one argument")) + } + ok := true + ast.Inspect(call.Args[0], func(node ast.Node) bool { + switch expr := node.(type) { + case nil, *ast.ArrayType, *ast.BasicLit, *ast.BinaryExpr, *ast.ChanType, *ast.CompositeLit, *ast.FuncType, *ast.Ident, *ast.IndexExpr, *ast.InterfaceType, *ast.KeyValueExpr, *ast.MapType, *ast.ParenExpr, *ast.SelectorExpr, *ast.SliceExpr, *ast.StarExpr, *ast.StructType, *ast.TypeAssertExpr: + // Good! + case *ast.UnaryExpr: + if expr.Op == token.ARROW { + ok = false + return false + } + case *ast.CallExpr: + // Only acceptable if it's a type conversion. + if _, isFunc := info.TypeOf(expr.Fun).(*types.Signature); isFunc { + ok = false + return false + } + default: + ok = false + return false + } + return true + }) + if !ok { + return nil, notePosition(fset.Position(call.Pos()), errors.New("argument to Value is too complex")) + } + // Result type can't be an interface type; use wire.InterfaceValue for that. + argType := info.TypeOf(call.Args[0]) + if _, isInterfaceType := argType.Underlying().(*types.Interface); isInterfaceType { + return nil, notePosition(fset.Position(call.Pos()), fmt.Errorf("argument to Value may not be an interface value (found %s); use InterfaceValue instead", types.TypeString(argType, nil))) + } + return &Value{ + Pos: call.Args[0].Pos(), + Out: info.TypeOf(call.Args[0]), + expr: call.Args[0], + info: info, + RawValue: true, }, nil } @@ -993,10 +1047,11 @@ func processInterfaceValue(fset *token.FileSet, info *types.Info, call *ast.Call return nil, notePosition(fset.Position(call.Pos()), fmt.Errorf("%s does not implement %s", types.TypeString(provided, nil), types.TypeString(iface, nil))) } return &Value{ - Pos: call.Args[1].Pos(), - Out: iface, - expr: call.Args[1], - info: info, + Pos: call.Args[1].Pos(), + Out: iface, + expr: call.Args[1], + info: info, + RawValue: false, }, nil } @@ -1173,9 +1228,9 @@ func (pt ProvidedType) IsNil() bool { // // - For a function provider, this is the first return value type. // - For a struct provider, this is either the struct type or the pointer type -// whose element type is the struct type. -// - For a value, this is the type of the expression. -// - For an argument, this is the type of the argument. +// whose element type is the struct type. +// - For a value, this is the type of the expression. +// - For an argument, this is the type of the argument. func (pt ProvidedType) Type() types.Type { return pt.t } diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 5cedeb1a..484be923 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -364,6 +364,19 @@ func (g *gen) inject(pos token.Pos, name string, sig *types.Signature, set *Prov typeInfo: c.valueTypeInfo, }) } + } else if c.kind == rawValueExpr { + if err := accessibleFrom(c.valueTypeInfo, c.valueExpr, g.pkg.PkgPath); err != nil { + // TODO(light): Display line number of value expression. + ts := types.TypeString(c.out, nil) + ec.add(notePosition( + g.pkg.Fset.Position(pos), + fmt.Errorf("inject %s: value %s can't be used: %v", name, ts, err))) + } + if g.values[c.valueExpr] == "" { + var printValue strings.Builder + printer.Fprint(&printValue, g.pkg.Fset, g.rewritePkgRefs(c.valueTypeInfo, c.valueExpr)) + g.values[c.valueExpr] = printValue.String() + } } } if len(ec.errors) > 0 { @@ -645,6 +658,8 @@ func injectPass(name string, sig *types.Signature, calls []call, set *ProviderSe ig.funcProviderCall(lname, c, injectSig) case valueExpr: ig.valueExpr(lname, c) + case rawValueExpr: + ig.valueExpr(lname, c) case selectorExpr: ig.fieldExpr(lname, c) default: diff --git a/wire.go b/wire.go index fe8edc8c..5c522eb9 100644 --- a/wire.go +++ b/wire.go @@ -143,6 +143,10 @@ func InterfaceValue(typ interface{}, x interface{}) ProvidedValue { return ProvidedValue{} } +func RawValue(interface{}) ProvidedValue { + return ProvidedValue{} +} + // A StructProvider represents a named struct. type StructProvider struct{} @@ -156,12 +160,12 @@ type StructProvider struct{} // // For example: // -// type S struct { -// MyFoo *Foo -// MyBar *Bar -// } -// var Set = wire.NewSet(wire.Struct(new(S), "MyFoo")) -> inject only S.MyFoo -// var Set = wire.NewSet(wire.Struct(new(S), "*")) -> inject all fields +// type S struct { +// MyFoo *Foo +// MyBar *Bar +// } +// var Set = wire.NewSet(wire.Struct(new(S), "MyFoo")) -> inject only S.MyFoo +// var Set = wire.NewSet(wire.Struct(new(S), "*")) -> inject all fields func Struct(structType interface{}, fieldNames ...string) StructProvider { return StructProvider{} } @@ -175,22 +179,22 @@ type StructFields struct{} // // The following example would provide Foo and Bar using S.MyFoo and S.MyBar respectively: // -// type S struct { -// MyFoo Foo -// MyBar Bar -// } +// type S struct { +// MyFoo Foo +// MyBar Bar +// } // -// func NewStruct() S { /* ... */ } -// var Set = wire.NewSet(wire.FieldsOf(new(S), "MyFoo", "MyBar")) +// func NewStruct() S { /* ... */ } +// var Set = wire.NewSet(wire.FieldsOf(new(S), "MyFoo", "MyBar")) // -// or +// or // -// func NewStruct() *S { /* ... */ } -// var Set = wire.NewSet(wire.FieldsOf(new(*S), "MyFoo", "MyBar")) +// func NewStruct() *S { /* ... */ } +// var Set = wire.NewSet(wire.FieldsOf(new(*S), "MyFoo", "MyBar")) // -// If the structType argument is a pointer to a pointer to a struct, then FieldsOf -// additionally provides a pointer to each field type (e.g., *Foo and *Bar in the -// example above). +// If the structType argument is a pointer to a pointer to a struct, then FieldsOf +// additionally provides a pointer to each field type (e.g., *Foo and *Bar in the +// example above). func FieldsOf(structType interface{}, fieldNames ...string) StructFields { return StructFields{} }