Skip to content

Commit

Permalink
wip: feat: add scalar kind
Browse files Browse the repository at this point in the history
[no ci]

Signed-off-by: Justin Chadwell <[email protected]>
  • Loading branch information
jedevc committed Apr 23, 2024
1 parent 819f525 commit 0e76110
Show file tree
Hide file tree
Showing 12 changed files with 495 additions and 22 deletions.
11 changes: 10 additions & 1 deletion cmd/codegen/generator/go/templates/module_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ func (ps *parseState) parseGoTypeReference(typ types.Type, named *types.Named, i
}
parsedType := &parsedPrimitiveType{goType: t, isPtr: isPtr}
if named != nil {
parsedType.aliasType = named
parsedType.alias = named.Obj().Name()
}
return parsedType, nil
Expand Down Expand Up @@ -121,7 +122,8 @@ type parsedPrimitiveType struct {
isPtr bool

// if this is something like `type Foo string`, then alias will be "Foo"
alias string
alias string
aliasType *types.Named
}

var _ ParsedType = &parsedPrimitiveType{}
Expand All @@ -144,6 +146,10 @@ func (spec *parsedPrimitiveType) TypeDefCode() (*Statement, error) {
if spec.isPtr {
def = def.Dot("WithOptional").Call(Lit(true))
}
if spec.alias != "" {
// XXX: should replace WithKind
def = def.Dot("WithScalar").Call(Lit(spec.alias))
}
return def, nil
}

Expand All @@ -152,6 +158,9 @@ func (spec *parsedPrimitiveType) GoType() types.Type {
}

func (spec *parsedPrimitiveType) GoSubTypes() []types.Type {
if spec.aliasType != nil {
return []types.Type{spec.aliasType}
}
return nil
}

Expand Down
18 changes: 18 additions & 0 deletions cmd/codegen/generator/go/templates/modules.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,24 @@ func (funcs goTemplateFuncs) moduleMainSrc() (string, error) { //nolint: gocyclo
}

switch underlyingObj := named.Underlying().(type) {
case *types.Basic:
if ps.isDaggerGenerated(obj) {
break
}

typeSpec, err := ps.parseGoTypeReference(underlyingObj, named, false)
if err != nil {
return "", err
}

// Add the scalar to the module
scalarTypeDefCode, err := typeSpec.TypeDefCode()
if err != nil {
return "", fmt.Errorf("failed to generate type def code for %s: %w", obj.Name(), err)
}
createMod = dotLine(createMod, "WithScalar").Call(Add(Line(), scalarTypeDefCode))
added[obj.Name()] = struct{}{}

case *types.Struct:
strct := underlyingObj
objTypeSpec, err := ps.parseGoStruct(strct, named)
Expand Down
22 changes: 22 additions & 0 deletions cmd/dagger/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,17 @@ func (r *modFunctionArg) AddFlag(flags *pflag.FlagSet, dag *dagger.Client) (any,
val, _ := getDefaultValue[bool](r)
return flags.Bool(name, val, usage), nil

case dagger.ScalarKind:
scalarName := r.TypeDef.AsScalar.Name

if val := GetCustomFlagValue(scalarName); val != nil {
flags.Var(val, name, usage)
return val, nil
}

val, _ := getDefaultValue[string](r)
return flags.String(name, val, usage), nil

case dagger.ObjectKind:
objName := r.TypeDef.AsObject.Name

Expand Down Expand Up @@ -653,6 +664,17 @@ func (r *modFunctionArg) AddFlag(flags *pflag.FlagSet, dag *dagger.Client) (any,
val, _ := getDefaultValue[[]bool](r)
return flags.BoolSlice(name, val, usage), nil

case dagger.ScalarKind:
scalarName := r.TypeDef.AsScalar.Name

if val := GetCustomFlagValueSlice(scalarName); val != nil {
flags.Var(val, name, usage)
return val, nil
}

val, _ := getDefaultValue[[]string](r)
return flags.StringSlice(name, val, usage), nil

case dagger.ObjectKind:
objName := elementType.AsObject.Name

Expand Down
42 changes: 28 additions & 14 deletions cmd/dagger/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -744,27 +744,30 @@ fragment TypeDefRefParts on TypeDef {
kind
optional
asObject {
name
name
}
asInterface {
name
name
}
asInput {
name
name
}
asList {
elementTypeDef {
kind
asObject {
name
}
asInterface {
name
}
asInput {
name
}
elementTypeDef {
kind
asObject {
name
}
asInterface {
name
}
asInput {
name
}
}
}
asScalar {
name
}
}
Expand Down Expand Up @@ -809,6 +812,9 @@ query TypeDefs {
...FieldParts
}
}
asScalar {
name
}
asInterface {
name
sourceModuleName
Expand Down Expand Up @@ -843,6 +849,8 @@ query TypeDefs {
modDef := &moduleDef{Name: name}
for _, typeDef := range res.TypeDefs {
switch typeDef.Kind {
case dagger.ScalarKind:
modDef.Scalars = append(modDef.Scalars, typeDef)
case dagger.ObjectKind:
modDef.Objects = append(modDef.Objects, typeDef)
case dagger.InterfaceKind:
Expand All @@ -857,6 +865,7 @@ query TypeDefs {
// moduleDef is a representation of dagger.Module.
type moduleDef struct {
Name string
Scalars []*modTypeDef
Objects []*modTypeDef
Interfaces []*modTypeDef
Inputs []*modTypeDef
Expand Down Expand Up @@ -987,6 +996,7 @@ type modTypeDef struct {
AsInterface *modInterface
AsInput *modInput
AsList *modList
AsScalar *modScalar
}

type functionProvider interface {
Expand Down Expand Up @@ -1089,6 +1099,10 @@ func (o *modInterface) GetFunction(name string) (*modFunction, error) {
return nil, fmt.Errorf("no function '%s' in interface type '%s'", name, o.Name)
}

type modScalar struct {
Name string
}

type modInput struct {
Name string
Fields []*modField
Expand Down
6 changes: 5 additions & 1 deletion core/modtypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type ModType interface {
// PrimitiveType are the basic types like string, int, bool, void, etc.
type PrimitiveType struct {
Def *TypeDef
Mod *Module // XXX: no this is the wrong place for it
}

func (t *PrimitiveType) ConvertFromSDKResult(ctx context.Context, value any) (dagql.Typed, error) {
Expand All @@ -42,7 +43,10 @@ func (t *PrimitiveType) ConvertToSDKInput(ctx context.Context, value dagql.Typed
}

func (t *PrimitiveType) SourceMod() Mod {
return nil
if t.Mod == nil {
return nil
}
return t.Mod
}

func (t *PrimitiveType) TypeDef() *TypeDef {
Expand Down
98 changes: 97 additions & 1 deletion core/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ type Module struct {
// The module's interfaces
InterfaceDefs []*TypeDef `field:"true" name:"interfaces" doc:"Interfaces served by this module."`

// The module's scalars
ScalarDefs []*TypeDef `field:"true" name:"scalars" doc:"Scalars served by this module."`

// InstanceID is the ID of the initialized module.
InstanceID *call.ID
}
Expand Down Expand Up @@ -161,6 +164,12 @@ func (mod *Module) Initialize(ctx context.Context, oldSelf dagql.Instance[*Modul
return nil, fmt.Errorf("failed to add interface to module %q: %w", mod.Name(), err)
}
}
for _, scalar := range inst.Self.ScalarDefs {
newMod, err = newMod.WithScalar(ctx, scalar)
if err != nil {
return nil, fmt.Errorf("failed to add scalar to module %q: %w", mod.Name(), err)
}
}
newMod.InstanceID = newID

return newMod, nil
Expand Down Expand Up @@ -215,6 +224,17 @@ func (mod *Module) Install(ctx context.Context, dag *dagql.Server) error {
}
}

for _, def := range mod.ScalarDefs {
scalarDef := def.AsScalar.Value

slog.ExtraDebug("installing scalar", "name", mod.Name(), "scalar", scalarDef.Name)

s := dagql.NewScalar(scalarDef.Name, scalarDef.Description)
dag.InstallScalar(s)

return nil
}

return nil
}

Expand All @@ -234,6 +254,13 @@ func (mod *Module) TypeDefs(ctx context.Context) ([]*TypeDef, error) {
}
typeDefs = append(typeDefs, typeDef)
}
for _, def := range mod.ScalarDefs {
typeDef := def.Clone()
if typeDef.AsScalar.Valid {
typeDef.AsScalar.Value.SourceModuleName = mod.Name()
}
typeDefs = append(typeDefs, typeDef)
}
return typeDefs, nil
}

Expand All @@ -245,7 +272,7 @@ func (mod *Module) ModTypeFor(ctx context.Context, typeDef *TypeDef, checkDirect
var modType ModType
switch typeDef.Kind {
case TypeDefKindString, TypeDefKindInteger, TypeDefKindBoolean, TypeDefKindVoid:
modType = &PrimitiveType{typeDef}
modType = &PrimitiveType{typeDef, nil}

case TypeDefKindList:
underlyingType, ok, err := mod.ModTypeFor(ctx, typeDef.AsList.Value.ElementTypeDef, checkDirectDeps)
Expand All @@ -260,6 +287,32 @@ func (mod *Module) ModTypeFor(ctx context.Context, typeDef *TypeDef, checkDirect
Underlying: underlyingType,
}

case TypeDefKindScalar:
if checkDirectDeps {
// check to see if this is from a *direct* dependency
depType, ok, err := mod.Deps.ModTypeFor(ctx, typeDef)
if err != nil {
return nil, false, fmt.Errorf("failed to get type from dependency: %w", err)
}
if ok {
return depType, true, nil
}
}

var found bool
// otherwise it must be from this module
for _, obj := range mod.ScalarDefs {
if obj.AsScalar.Value.Name == typeDef.AsScalar.Value.Name {
modType = &PrimitiveType{typeDef, mod}
found = true
break
}
}
if !found {
slog.ExtraDebug("module did not find scalar", "mod", mod.Name(), "scalar", typeDef.AsScalar.Value.Name)
return nil, false, nil
}

case TypeDefKindObject:
if checkDirectDeps {
// check to see if this is from a *direct* dependency
Expand Down Expand Up @@ -470,6 +523,17 @@ func (mod *Module) namespaceTypeDef(ctx context.Context, typeDef *TypeDef) error
if err := mod.namespaceTypeDef(ctx, typeDef.AsList.Value.ElementTypeDef); err != nil {
return err
}

case TypeDefKindScalar:
scalar := typeDef.AsScalar.Value
_, ok, err := mod.Deps.ModTypeFor(ctx, typeDef)
if err != nil {
return fmt.Errorf("failed to get mod type for type def: %w", err)
}
if !ok {
scalar.Name = namespaceObject(scalar.OriginalName, mod.Name(), mod.OriginalName)
}

case TypeDefKindObject:
obj := typeDef.AsObject.Value

Expand Down Expand Up @@ -499,6 +563,7 @@ func (mod *Module) namespaceTypeDef(ctx context.Context, typeDef *TypeDef) error
}
}
}

case TypeDefKindInterface:
iface := typeDef.AsInterface.Value

Expand Down Expand Up @@ -629,6 +694,11 @@ func (mod Module) Clone() *Module {
cp.InterfaceDefs[i] = def.Clone()
}

cp.ScalarDefs = make([]*TypeDef, len(mod.ScalarDefs))
for i, def := range mod.ScalarDefs {
cp.ScalarDefs[i] = def.Clone()
}

return &cp
}

Expand Down Expand Up @@ -688,6 +758,32 @@ func (mod *Module) WithInterface(ctx context.Context, def *TypeDef) (*Module, er
return mod, nil
}

func (mod *Module) WithScalar(ctx context.Context, def *TypeDef) (*Module, error) {
mod = mod.Clone()
if !def.AsScalar.Valid {
return nil, fmt.Errorf("expected scalar def, got %s: %+v", def.Kind, def)
}

// skip validation+namespacing for module objects being constructed by SDK with* calls
// they will be validated when merged into the real final module

// XXX: aha do this bit!
if mod.Deps != nil {
if err := mod.validateTypeDef(ctx, def); err != nil {
return nil, fmt.Errorf("failed to validate type def: %w", err)
}
}
if mod.NameField != "" {
def = def.Clone()
if err := mod.namespaceTypeDef(ctx, def); err != nil {
return nil, fmt.Errorf("failed to namespace type def: %w", err)
}
}

mod.ScalarDefs = append(mod.ScalarDefs, def)
return mod, nil
}

type CurrentModule struct {
Module *Module
}
Expand Down

0 comments on commit 0e76110

Please sign in to comment.