Skip to content

Commit

Permalink
feat: add scalar typedefs to propagate scalars
Browse files Browse the repository at this point in the history
Signed-off-by: Justin Chadwell <[email protected]>
  • Loading branch information
jedevc committed Apr 30, 2024
1 parent df0e922 commit b3cb647
Show file tree
Hide file tree
Showing 27 changed files with 1,171 additions and 7 deletions.
16 changes: 12 additions & 4 deletions cmd/codegen/generator/go/templates/module_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ func (ps *parseState) parseGoTypeReference(typ types.Type, named *types.Named, i
}
parsedType := &parsedPrimitiveType{goType: t, isPtr: isPtr}
if named != nil {
parsedType.aliasType = named
parsedType.alias = named.Obj().Name()
}
return parsedType, nil
Expand Down Expand Up @@ -121,7 +122,8 @@ type parsedPrimitiveType struct {
isPtr bool

// if this is something like `type Foo string`, then alias will be "Foo"
alias string
alias string
aliasType *types.Named
}

var _ ParsedType = &parsedPrimitiveType{}
Expand All @@ -138,9 +140,12 @@ func (spec *parsedPrimitiveType) TypeDefCode() (*Statement, error) {
default:
return nil, fmt.Errorf("unsupported basic type: %+v", spec.goType)
}
def := Qual("dag", "TypeDef").Call().Dot("WithKind").Call(
kind,
)
var def *Statement
if spec.alias == "" {
def = Qual("dag", "TypeDef").Call().Dot("WithKind").Call(kind)
} else {
def = Qual("dag", "TypeDef").Call().Dot("WithScalar").Call(Lit(spec.alias))
}
if spec.isPtr {
def = def.Dot("WithOptional").Call(Lit(true))
}
Expand All @@ -152,6 +157,9 @@ func (spec *parsedPrimitiveType) GoType() types.Type {
}

func (spec *parsedPrimitiveType) GoSubTypes() []types.Type {
if spec.aliasType != nil {
return []types.Type{spec.aliasType}
}
return nil
}

Expand Down
22 changes: 22 additions & 0 deletions cmd/dagger/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,17 @@ func (r *modFunctionArg) AddFlag(flags *pflag.FlagSet, dag *dagger.Client) (any,
val, _ := getDefaultValue[bool](r)
return flags.Bool(name, val, usage), nil

case dagger.ScalarKind:
scalarName := r.TypeDef.AsScalar.Name

if val := GetCustomFlagValue(scalarName); val != nil {
flags.Var(val, name, usage)
return val, nil
}

val, _ := getDefaultValue[string](r)
return flags.String(name, val, usage), nil

case dagger.ObjectKind:
objName := r.TypeDef.AsObject.Name

Expand Down Expand Up @@ -653,6 +664,17 @@ func (r *modFunctionArg) AddFlag(flags *pflag.FlagSet, dag *dagger.Client) (any,
val, _ := getDefaultValue[[]bool](r)
return flags.BoolSlice(name, val, usage), nil

case dagger.ScalarKind:
scalarName := r.TypeDef.AsScalar.Name

if val := GetCustomFlagValueSlice(scalarName); val != nil {
flags.Var(val, name, usage)
return val, nil
}

val, _ := getDefaultValue[[]string](r)
return flags.StringSlice(name, val, usage), nil

case dagger.ObjectKind:
objName := elementType.AsObject.Name

Expand Down
11 changes: 11 additions & 0 deletions cmd/dagger/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,9 @@ fragment TypeDefRefParts on TypeDef {
}
}
}
asScalar {
name
}
}
fragment FunctionParts on Function {
Expand Down Expand Up @@ -809,6 +812,9 @@ query TypeDefs {
...FieldParts
}
}
asScalar {
name
}
asInterface {
name
sourceModuleName
Expand Down Expand Up @@ -987,6 +993,7 @@ type modTypeDef struct {
AsInterface *modInterface
AsInput *modInput
AsList *modList
AsScalar *modScalar
}

type functionProvider interface {
Expand Down Expand Up @@ -1089,6 +1096,10 @@ func (o *modInterface) GetFunction(name string) (*modFunction, error) {
return nil, fmt.Errorf("no function '%s' in interface type '%s'", name, o.Name)
}

type modScalar struct {
Name string
}

type modInput struct {
Name string
Fields []*modField
Expand Down
74 changes: 74 additions & 0 deletions core/integration/module_call_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,80 @@ func (m *Test) Cacher(ctx context.Context, cache *CacheVolume, val string) (stri
require.Equal(t, "foo\nbar\n", out)
})

t.Run("platform args", func(t *testing.T) {
t.Parallel()

c, ctx := connect(t)

modGen := c.Container().From(golangImage).
WithMountedFile(testCLIBinPath, daggerCliFile(t, c)).
WithWorkdir("/work").
With(daggerExec("init", "--source=.", "--name=test", "--sdk=go")).
WithNewFile("main.go", dagger.ContainerWithNewFileOpts{
Contents: `package main
type Test struct {}
func (m *Test) FromPlatform(platform Platform) string {
return string(platform)
}
func (m *Test) ToPlatform(platform string) Platform {
return Platform(platform)
}
`,
})

out, err := modGen.With(daggerCall("from-platform", "--platform", "linux/amd64")).Stdout(ctx)
require.NoError(t, err)
require.Equal(t, "linux/amd64", out)
_, err = modGen.With(daggerCall("from-platform", "--platform", "invalid")).Stdout(ctx)
require.ErrorContains(t, err, "unknown operating system or architecture")

out, err = modGen.With(daggerCall("to-platform", "--platform", "linux/amd64")).Stdout(ctx)
require.NoError(t, err)
require.Equal(t, "linux/amd64", out)
_, err = modGen.With(daggerCall("from-platform", "--platform", "invalid")).Stdout(ctx)
require.ErrorContains(t, err, "unknown operating system or architecture")
})

t.Run("enum args", func(t *testing.T) {
t.Parallel()

c, ctx := connect(t)

modGen := c.Container().From(golangImage).
WithMountedFile(testCLIBinPath, daggerCliFile(t, c)).
WithWorkdir("/work").
With(daggerExec("init", "--source=.", "--name=test", "--sdk=go")).
WithNewFile("main.go", dagger.ContainerWithNewFileOpts{
Contents: `package main
type Test struct {}
func (m *Test) FromProto(proto NetworkProtocol) string {
return string(proto)
}
func (m *Test) ToProto(proto string) NetworkProtocol {
return NetworkProtocol(proto)
}
`,
})

out, err := modGen.With(daggerCall("from-proto", "--proto", "TCP")).Stdout(ctx)
require.NoError(t, err)
require.Equal(t, "TCP", out)
_, err = modGen.With(daggerCall("from-proto", "--proto", "INVALID")).Stdout(ctx)
require.ErrorContains(t, err, "invalid enum value")

out, err = modGen.With(daggerCall("to-proto", "--proto", "TCP")).Stdout(ctx)
require.NoError(t, err)
require.Equal(t, "TCP", out)
_, err = modGen.With(daggerCall("to-proto", "--proto", "INVALID")).Stdout(ctx)
require.ErrorContains(t, err, "invalid enum value")
})

t.Run("module args", func(t *testing.T) {
t.Parallel()

Expand Down
43 changes: 43 additions & 0 deletions core/integration/module_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2888,6 +2888,49 @@ class Foo {
}
}

func TestModuleScalarType(t *testing.T) {
// Verify use of a core scalar as an argument type.

t.Parallel()

type testCase struct {
sdk string
source string
}
for _, tc := range []testCase{
{
sdk: "go",
source: `package main
type Foo struct{}
func (m *Foo) SayHello(platform Platform) string {
return "hello " + string(platform)
}`,
},
} {
tc := tc

t.Run(tc.sdk, func(t *testing.T) {
t.Parallel()
c, ctx := connect(t)

modGen := c.Container().From(golangImage).
WithMountedFile(testCLIBinPath, daggerCliFile(t, c)).
WithWorkdir("/work").
With(daggerExec("init", "--name=foo", "--sdk="+tc.sdk)).
With(sdkSource(tc.sdk, tc.source))

out, err := modGen.With(daggerQuery(`{foo{sayHello(platform: "linux/amd64")}}`)).Stdout(ctx)
require.NoError(t, err)
require.Equal(t, "hello linux/amd64", gjson.Get(out, "foo.sayHello").String())

_, err = modGen.With(daggerQuery(`{foo{sayHello(platform: "invalid")}}`)).Stdout(ctx)
require.ErrorContains(t, err, "unknown operating system or architecture")
})
}
}

func TestModuleConflictingSameNameDeps(t *testing.T) {
// A -> B -> Dint
// A -> C -> Dstr
Expand Down
15 changes: 15 additions & 0 deletions core/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,21 @@ func (mod *Module) ModTypeFor(ctx context.Context, typeDef *TypeDef, checkDirect
return nil, false, nil
}

case TypeDefKindScalar:
if checkDirectDeps {
// check to see if this is from a *direct* dependency
depType, ok, err := mod.Deps.ModTypeFor(ctx, typeDef)
if err != nil {
return nil, false, fmt.Errorf("failed to get type from dependency: %w", err)
}
if ok {
return depType, true, nil
}
}

slog.ExtraDebug("module did not find scalar", "mod", mod.Name(), "scalar", typeDef.AsScalar.Value.Name)
return nil, false, nil

default:
return nil, false, fmt.Errorf("unexpected type def kind %s", typeDef.Kind)
}
Expand Down
56 changes: 53 additions & 3 deletions core/schema/coremod.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,19 @@ func (m *CoreMod) ModTypeFor(ctx context.Context, typeDef *core.TypeDef, checkDi
Underlying: underlyingType,
}

case core.TypeDefKindScalar:
_, ok := m.Dag.ScalarType(typeDef.AsScalar.Value.Name)
if !ok {
return nil, false, nil
}
modType = &CoreModScalar{coreMod: m, name: typeDef.AsScalar.Value.Name}

case core.TypeDefKindObject:
_, ok := m.Dag.ObjectType(typeDef.AsObject.Value.Name)
if !ok {
return nil, false, nil
}
modType = &CoreModObject{coreMod: m}
modType = &CoreModObject{coreMod: m, name: typeDef.AsObject.Value.Name}

case core.TypeDefKindInterface:
// core does not yet defined any interfaces
Expand Down Expand Up @@ -204,6 +211,48 @@ func (m *CoreMod) TypeDefs(ctx context.Context) ([]*core.TypeDef, error) {
return typeDefs, nil
}

// CoreModScalar represents scalars from core (Platform, etc)
type CoreModScalar struct {
coreMod *CoreMod
name string
}

var _ core.ModType = (*CoreModScalar)(nil)

func (obj *CoreModScalar) ConvertFromSDKResult(ctx context.Context, value any) (dagql.Typed, error) {
s, ok := obj.coreMod.Dag.ScalarType(obj.name)
if !ok {
return nil, fmt.Errorf("CoreModScalar.ConvertFromSDKResult: found no scalar type")
}
return s.DecodeInput(value)
}

func (obj *CoreModScalar) ConvertToSDKInput(ctx context.Context, value dagql.Typed) (any, error) {
s, ok := obj.coreMod.Dag.ScalarType(obj.name)
if !ok {
return nil, fmt.Errorf("CoreModScalar.ConvertToSDKInput: found no scalar type")
}
val, ok := value.(dagql.Scalar[dagql.String])
if !ok {
// we assume all core scalars are strings
return nil, fmt.Errorf("CoreModScalar.ConvertToSDKInput: core scalar should be string")
}
return s.DecodeInput(string(val.Value))
}

func (obj *CoreModScalar) SourceMod() core.Mod {
return obj.coreMod
}

func (obj *CoreModScalar) TypeDef() *core.TypeDef {
return &core.TypeDef{
Kind: core.TypeDefKindScalar,
AsScalar: dagql.NonNull(&core.ScalarTypeDef{
Name: obj.name,
}),
}
}

// CoreModObject represents objects from core (Container, Directory, etc.)
type CoreModObject struct {
coreMod *CoreMod
Expand Down Expand Up @@ -291,8 +340,9 @@ func introspectionRefToTypeDef(introspectionType *introspection.TypeRef, nonNull
case string(introspection.ScalarBoolean):
typeDef.Kind = core.TypeDefKindBoolean
default:
// default to saying it's a string for now
typeDef.Kind = core.TypeDefKindString
// assume that all core scalars are strings
typeDef.Kind = core.TypeDefKindScalar
typeDef.AsScalar = dagql.NonNull(core.NewScalarTypeDef(introspectionType.Name, ""))
}

return typeDef, true, nil
Expand Down
14 changes: 14 additions & 0 deletions core/schema/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,9 @@ func (s *moduleSchema) Install() {
dagql.Func("withKind", s.typeDefWithKind).
Doc(`Sets the kind of the type.`),

dagql.Func("withScalar", s.typeDefWithScalar).
Doc(`Returns a TypeDef of kind Scalar with the provided name.`),

dagql.Func("withListOf", s.typeDefWithListOf).
Doc(`Returns a TypeDef of kind List with the provided type for its elements.`),

Expand Down Expand Up @@ -284,6 +287,7 @@ func (s *moduleSchema) Install() {
dagql.Fields[*core.InputTypeDef]{}.Install(s.dag)
dagql.Fields[*core.FieldTypeDef]{}.Install(s.dag)
dagql.Fields[*core.ListTypeDef]{}.Install(s.dag)
dagql.Fields[*core.ScalarTypeDef]{}.Install(s.dag)

dagql.Fields[*core.GeneratedCode]{
dagql.Func("withVCSGeneratedPaths", s.generatedCodeWithVCSGeneratedPaths).
Expand All @@ -309,6 +313,16 @@ func (s *moduleSchema) typeDefWithKind(ctx context.Context, def *core.TypeDef, a
return def.WithKind(args.Kind), nil
}

func (s *moduleSchema) typeDefWithScalar(ctx context.Context, def *core.TypeDef, args struct {
Name string
Description string `default:""`
}) (*core.TypeDef, error) {
if args.Name == "" {
return nil, fmt.Errorf("scalar type def must have a name")
}
return def.WithScalar(args.Name, args.Description), nil
}

func (s *moduleSchema) typeDefWithListOf(ctx context.Context, def *core.TypeDef, args struct {
ElementType core.TypeDefID
}) (*core.TypeDef, error) {
Expand Down

0 comments on commit b3cb647

Please sign in to comment.