Skip to content

Commit

Permalink
feat: allow sdks to declare custom scalars
Browse files Browse the repository at this point in the history
Signed-off-by: Justin Chadwell <[email protected]>
  • Loading branch information
jedevc committed Apr 25, 2024
1 parent 8cf88b7 commit 1feb12e
Show file tree
Hide file tree
Showing 13 changed files with 437 additions and 16 deletions.
18 changes: 18 additions & 0 deletions cmd/codegen/generator/go/templates/modules.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,24 @@ func (funcs goTemplateFuncs) moduleMainSrc() (string, error) { //nolint: gocyclo
}

switch underlyingObj := named.Underlying().(type) {
case *types.Basic:
if ps.isDaggerGenerated(obj) {
break
}

typeSpec, err := ps.parseGoTypeReference(underlyingObj, named, false)
if err != nil {
return "", err
}

// Add the scalar to the module
scalarTypeDefCode, err := typeSpec.TypeDefCode()
if err != nil {
return "", fmt.Errorf("failed to generate type def code for %s: %w", obj.Name(), err)
}
createMod = dotLine(createMod, "WithScalar").Call(Add(Line(), scalarTypeDefCode))
added[obj.Name()] = struct{}{}

case *types.Struct:
strct := underlyingObj
objTypeSpec, err := ps.parseGoStruct(strct, named)
Expand Down
31 changes: 17 additions & 14 deletions cmd/dagger/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -744,27 +744,27 @@ fragment TypeDefRefParts on TypeDef {
kind
optional
asObject {
name
name
}
asInterface {
name
name
}
asInput {
name
name
}
asList {
elementTypeDef {
kind
asObject {
name
}
asInterface {
name
}
asInput {
name
}
elementTypeDef {
kind
asObject {
name
}
asInterface {
name
}
asInput {
name
}
}
}
asScalar {
name
Expand Down Expand Up @@ -851,6 +851,8 @@ query TypeDefs {
modDef := &moduleDef{Name: name}
for _, typeDef := range res.TypeDefs {
switch typeDef.Kind {
case dagger.ScalarKind:
modDef.Scalars = append(modDef.Scalars, typeDef)
case dagger.ObjectKind:
modDef.Objects = append(modDef.Objects, typeDef)
case dagger.InterfaceKind:
Expand All @@ -865,6 +867,7 @@ query TypeDefs {
// moduleDef is a representation of dagger.Module.
type moduleDef struct {
Name string
Scalars []*modTypeDef
Objects []*modTypeDef
Interfaces []*modTypeDef
Inputs []*modTypeDef
Expand Down
64 changes: 64 additions & 0 deletions core/integration/module_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3395,6 +3395,70 @@ func (t *Toplevel) SayHello(ctx context.Context) (string, error) {
require.JSONEq(t, `{"toplevel":{"sayHello":"hello!"}}`, out)
}

func TestModuleGoCustomScalar(t *testing.T) {
t.Parallel()

var logs safeBuffer
c, ctx := connect(t, dagger.WithLogOutput(io.MultiWriter(os.Stdout, &logs)))

ctr := c.Container().From(golangImage).
WithMountedFile(testCLIBinPath, daggerCliFile(t, c))

ctr = ctr.
WithWorkdir("/toplevel/foo").
With(daggerExec("init", "--name=foo", "--sdk=go", "--source=.")).
WithNewFile("main.go", dagger.ContainerWithNewFileOpts{
Contents: `package main
import "fmt"
type Foo struct {}
type CustomString string
type CustomInt int
type CustomBool bool
func (foo *Foo) Hello(s CustomString, i CustomInt, b CustomBool) string {
return fmt.Sprintf("%s %d %t", s, i, b)
}
`,
})
out, err := ctr.With(daggerQuery(`{foo{hello(s: "abc", i: 1, b: true)}}`)).Stdout(ctx)
require.NoError(t, err)
require.JSONEq(t, `{"foo":{"hello":"abc 1 true"}}`, out)

ctr = ctr.
WithWorkdir("/toplevel").
With(daggerExec("init", "--name=toplevel", "--sdk=go", "--source=.")).
With(daggerExec("install", "./foo")).
WithNewFile("main.go", dagger.ContainerWithNewFileOpts{
Contents: `package main
import "context"
type Toplevel struct {}
func (t *Toplevel) Hello(ctx context.Context) (string, error) {
return dag.Foo().Hello(ctx, FooCustomString("xyz"), FooCustomInt("2"), FooCustomBool("true"))
}
func (t *Toplevel) HelloBad(ctx context.Context) (string, error) {
return dag.Foo().Hello(ctx, FooCustomString("xyz"), FooCustomInt("2"), FooCustomBool("invalid"))
}
`,
})
logGen(ctx, t, ctr.Directory("."))

out, err = ctr.With(daggerQuery(`{toplevel{hello}}`)).Stdout(ctx)
require.NoError(t, err)
require.JSONEq(t, `{"toplevel":{"hello":"xyz 2 true"}}`, out)

_, err = ctr.With(daggerQuery(`{toplevel{helloBad}}`)).Stdout(ctx)
require.Error(t, err)
require.NoError(t, c.Close())
require.Contains(t, logs.String(), "invalid syntax")
}

var useInner = `package main
type Dep struct{}
Expand Down
26 changes: 26 additions & 0 deletions core/modtypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,32 @@ func (t *PrimitiveType) TypeDef() *TypeDef {
return t.Def
}

// ModuleScalarType is an arbitrary scalar type.
type ModuleScalarType struct {
Def *TypeDef
Mod *Module
}

func (t *ModuleScalarType) ConvertFromSDKResult(ctx context.Context, value any) (dagql.Typed, error) {
// NB: we lean on the fact that all primitive types are also dagql.Inputs
return t.Def.ToInput().Decoder().DecodeInput(value)
}

func (t *ModuleScalarType) ConvertToSDKInput(ctx context.Context, value dagql.Typed) (any, error) {
return value, nil
}

func (t *ModuleScalarType) SourceMod() Mod {
if t.Mod == nil {
return nil
}
return t.Mod
}

func (t *ModuleScalarType) TypeDef() *TypeDef {
return t.Def
}

type ListType struct {
Elem *TypeDef
Underlying ModType
Expand Down
92 changes: 90 additions & 2 deletions core/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ type Module struct {
// The module's interfaces
InterfaceDefs []*TypeDef `field:"true" name:"interfaces" doc:"Interfaces served by this module."`

// The module's scalars
ScalarDefs []*TypeDef `field:"true" name:"scalars" doc:"Scalars served by this module."`

// InstanceID is the ID of the initialized module.
InstanceID *call.ID
}
Expand Down Expand Up @@ -161,6 +164,12 @@ func (mod *Module) Initialize(ctx context.Context, oldSelf dagql.Instance[*Modul
return nil, fmt.Errorf("failed to add interface to module %q: %w", mod.Name(), err)
}
}
for _, scalar := range inst.Self.ScalarDefs {
newMod, err = newMod.WithScalar(ctx, scalar)
if err != nil {
return nil, fmt.Errorf("failed to add scalar to module %q: %w", mod.Name(), err)
}
}
newMod.InstanceID = newID

return newMod, nil
Expand Down Expand Up @@ -215,6 +224,25 @@ func (mod *Module) Install(ctx context.Context, dag *dagql.Server) error {
}
}

for _, def := range mod.ScalarDefs {
scalarDef := def.AsScalar.Value

slog.ExtraDebug("installing scalar", "name", mod.Name(), "scalar", scalarDef.Name)

var sc dagql.ScalarType
switch scalarDef.Kind {
case TypeDefKindString:
sc = dagql.NewScalar[dagql.String](scalarDef.Name, dagql.String(""))
case TypeDefKindInteger:
sc = dagql.NewScalar[dagql.Int](scalarDef.Name, dagql.Int(0))
case TypeDefKindBoolean:
sc = dagql.NewScalar[dagql.Boolean](scalarDef.Name, dagql.Boolean(false))
default:
return fmt.Errorf("unsupported type kind: %s", scalarDef.Kind)
}
dag.InstallScalar(sc)
}

return nil
}

Expand All @@ -234,6 +262,13 @@ func (mod *Module) TypeDefs(ctx context.Context) ([]*TypeDef, error) {
}
typeDefs = append(typeDefs, typeDef)
}
for _, def := range mod.ScalarDefs {
typeDef := def.Clone()
if typeDef.AsScalar.Valid {
typeDef.AsScalar.Value.SourceModuleName = mod.Name()
}
typeDefs = append(typeDefs, typeDef)
}
return typeDefs, nil
}

Expand Down Expand Up @@ -330,8 +365,19 @@ func (mod *Module) ModTypeFor(ctx context.Context, typeDef *TypeDef, checkDirect
}
}

slog.ExtraDebug("module did not find scalar", "mod", mod.Name(), "scalar", typeDef.AsScalar.Value.Name)
return nil, false, nil
var found bool
// otherwise it must be from this module
for _, obj := range mod.ScalarDefs {
if obj.AsScalar.Value.Name == typeDef.AsScalar.Value.Name {
modType = &ModuleScalarType{typeDef, mod}
found = true
break
}
}
if !found {
slog.ExtraDebug("module did not find scalar", "mod", mod.Name(), "scalar", typeDef.AsScalar.Value.Name)
return nil, false, nil
}

default:
return nil, false, fmt.Errorf("unexpected type def kind %s", typeDef.Kind)
Expand Down Expand Up @@ -485,6 +531,7 @@ func (mod *Module) namespaceTypeDef(ctx context.Context, typeDef *TypeDef) error
if err := mod.namespaceTypeDef(ctx, typeDef.AsList.Value.ElementTypeDef); err != nil {
return err
}

case TypeDefKindObject:
obj := typeDef.AsObject.Value

Expand Down Expand Up @@ -514,6 +561,7 @@ func (mod *Module) namespaceTypeDef(ctx context.Context, typeDef *TypeDef) error
}
}
}

case TypeDefKindInterface:
iface := typeDef.AsInterface.Value

Expand All @@ -537,6 +585,16 @@ func (mod *Module) namespaceTypeDef(ctx context.Context, typeDef *TypeDef) error
}
}
}

case TypeDefKindScalar:
scalar := typeDef.AsScalar.Value
_, ok, err := mod.Deps.ModTypeFor(ctx, typeDef)
if err != nil {
return fmt.Errorf("failed to get mod type for type def: %w", err)
}
if !ok {
scalar.Name = namespaceObject(scalar.OriginalName, mod.Name(), mod.OriginalName)
}
}
return nil
}
Expand Down Expand Up @@ -644,6 +702,11 @@ func (mod Module) Clone() *Module {
cp.InterfaceDefs[i] = def.Clone()
}

cp.ScalarDefs = make([]*TypeDef, len(mod.ScalarDefs))
for i, def := range mod.ScalarDefs {
cp.ScalarDefs[i] = def.Clone()
}

return &cp
}

Expand Down Expand Up @@ -703,6 +766,31 @@ func (mod *Module) WithInterface(ctx context.Context, def *TypeDef) (*Module, er
return mod, nil
}

func (mod *Module) WithScalar(ctx context.Context, def *TypeDef) (*Module, error) {
mod = mod.Clone()
if !def.AsScalar.Valid {
return nil, fmt.Errorf("expected scalar def, got %s: %+v", def.Kind, def)
}

// skip validation+namespacing for module objects being constructed by SDK with* calls
// they will be validated when merged into the real final module

if mod.Deps != nil {
if err := mod.validateTypeDef(ctx, def); err != nil {
return nil, fmt.Errorf("failed to validate type def: %w", err)
}
}
if mod.NameField != "" {
def = def.Clone()
if err := mod.namespaceTypeDef(ctx, def); err != nil {
return nil, fmt.Errorf("failed to namespace type def: %w", err)
}
}

mod.ScalarDefs = append(mod.ScalarDefs, def)
return mod, nil
}

type CurrentModule struct {
Module *Module
}
Expand Down
13 changes: 13 additions & 0 deletions core/schema/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,9 @@ func (s *moduleSchema) Install() {
Doc(`Retrieves the module with the given description`).
ArgDoc("description", `The description to set`),

dagql.Func("withScalar", s.moduleWithScalar).
Doc(`This module plus the given Scalar type.`),

dagql.Func("withObject", s.moduleWithObject).
Doc(`This module plus the given Object type and associated functions.`),

Expand Down Expand Up @@ -524,6 +527,16 @@ func (s *moduleSchema) moduleWithInterface(ctx context.Context, mod *core.Module
return mod.WithInterface(ctx, def.Self)
}

func (s *moduleSchema) moduleWithScalar(ctx context.Context, mod *core.Module, args struct {
Scalar core.TypeDefID
}) (_ *core.Module, rerr error) {
def, err := args.Scalar.Load(ctx, s.dag)
if err != nil {
return nil, err
}
return mod.WithScalar(ctx, def.Self)
}

func (s *moduleSchema) currentModuleName(
ctx context.Context,
curMod *core.CurrentModule,
Expand Down
6 changes: 6 additions & 0 deletions docs/docs-graphql/schema.graphqls
Original file line number Diff line number Diff line change
Expand Up @@ -1751,6 +1751,9 @@ type Module {
"""
runtime: Container!

"""Scalars served by this module."""
scalars: [TypeDef!]!

"""
The SDK used by this module. Either a name of a builtin SDK or a module source
ref string pointing to the SDK's implementation.
Expand Down Expand Up @@ -1780,6 +1783,9 @@ type Module {
"""This module plus the given Object type and associated functions."""
withObject(object: TypeDefID!): Module!

"""This module plus the given Scalar type."""
withScalar(scalar: TypeDefID!): Module!

"""Retrieves the module with basic configuration loaded if present."""
withSource(
"""The module source to initialize from."""
Expand Down

0 comments on commit 1feb12e

Please sign in to comment.