Skip to content

Commit

Permalink
feat: add scalar kind (dagger#7158)
Browse files Browse the repository at this point in the history
* chore: remove unneccessary InputType

Signed-off-by: Justin Chadwell <[email protected]>

* chore: format mod typedef query

Signed-off-by: Justin Chadwell <[email protected]>

* chore: avoid ModuleObjectType typed nil

Signed-off-by: Justin Chadwell <[email protected]>

* feat: add scalar typedefs to propagate scalars

Signed-off-by: Justin Chadwell <[email protected]>

* feat: interpret current arg to get current platform

Signed-off-by: Justin Chadwell <[email protected]>

* Add Python fix

Signed-off-by: Helder Correia <[email protected]>

* Replace ignored stdout with sync

Signed-off-by: Helder Correia <[email protected]>

* fix test to-platform

Signed-off-by: Justin Chadwell <[email protected]>

* Test return values and enum types

Signed-off-by: Helder Correia <[email protected]>

* Fix linter on changed function

Signed-off-by: Helder Correia <[email protected]>

* chore: regen rust

Signed-off-by: Justin Chadwell <[email protected]>

* feat: support scalar in TypeScript

Signed-off-by: Tom Chauveau <[email protected]>

* feat: add enum tests

Signed-off-by: Tom Chauveau <[email protected]>

* Remove old tests

Signed-off-by: Helder Correia <[email protected]>

* Add change log

Signed-off-by: Helder Correia <[email protected]>

---------

Signed-off-by: Justin Chadwell <[email protected]>
Signed-off-by: Helder Correia <[email protected]>
Signed-off-by: Tom Chauveau <[email protected]>
Co-authored-by: Helder Correia <[email protected]>
Co-authored-by: Tom Chauveau <[email protected]>
  • Loading branch information
3 people authored and vikram-dagger committed May 8, 2024
1 parent cc472d1 commit 7672599
Show file tree
Hide file tree
Showing 42 changed files with 1,499 additions and 105 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Added-20240506-124137.yaml
@@ -0,0 +1,6 @@
kind: Added
body: Added support for custom scalars and enums in function arguments
time: 2024-05-06T12:41:37.849472Z
custom:
Author: jedevc
PR: "7158"
18 changes: 15 additions & 3 deletions cmd/codegen/generator/go/templates/module_types.go
Expand Up @@ -69,6 +69,10 @@ func (ps *parseState) parseGoTypeReference(typ types.Type, named *types.Named, i
}
parsedType := &parsedPrimitiveType{goType: t, isPtr: isPtr}
if named != nil {
if ps.isDaggerGenerated(named.Obj()) {
// only pre-generated scalars allowed here
parsedType.scalarType = named
}
parsedType.alias = named.Obj().Name()
}
return parsedType, nil
Expand Down Expand Up @@ -120,6 +124,8 @@ type parsedPrimitiveType struct {
goType *types.Basic
isPtr bool

scalarType *types.Named

// if this is something like `type Foo string`, then alias will be "Foo"
alias string
}
Expand All @@ -138,9 +144,12 @@ func (spec *parsedPrimitiveType) TypeDefCode() (*Statement, error) {
default:
return nil, fmt.Errorf("unsupported basic type: %+v", spec.goType)
}
def := Qual("dag", "TypeDef").Call().Dot("WithKind").Call(
kind,
)
var def *Statement
if spec.scalarType != nil {
def = Qual("dag", "TypeDef").Call().Dot("WithScalar").Call(Lit(spec.scalarType.Obj().Name()))
} else {
def = Qual("dag", "TypeDef").Call().Dot("WithKind").Call(kind)
}
if spec.isPtr {
def = def.Dot("WithOptional").Call(Lit(true))
}
Expand All @@ -152,6 +161,9 @@ func (spec *parsedPrimitiveType) GoType() types.Type {
}

func (spec *parsedPrimitiveType) GoSubTypes() []types.Type {
if spec.scalarType != nil {
return []types.Type{spec.scalarType}
}
return nil
}

Expand Down
57 changes: 57 additions & 0 deletions cmd/dagger/flags.go
Expand Up @@ -17,6 +17,7 @@ import (
"strconv"
"strings"

"github.com/containerd/containerd/platforms"
"github.com/moby/buildkit/util/gitutil"
"github.com/spf13/pflag"

Expand Down Expand Up @@ -44,6 +45,8 @@ func GetCustomFlagValue(name string) DaggerValue {
return &moduleSourceValue{}
case Module:
return &moduleValue{}
case Platform:
return &platformValue{}
}
return nil
}
Expand All @@ -69,6 +72,8 @@ func GetCustomFlagValueSlice(name string) DaggerValue {
return &sliceValue[*moduleSourceValue]{}
case Module:
return &sliceValue[*moduleValue]{}
case Platform:
return &sliceValue[*platformValue]{}
}
return nil
}
Expand Down Expand Up @@ -592,6 +597,36 @@ func (v *moduleSourceValue) Get(ctx context.Context, dag *dagger.Client, _ *dagg
return modConf.Source, nil
}

type platformValue struct {
platform string
}

func (v *platformValue) Type() string {
return Platform
}

func (v *platformValue) Set(s string) error {
if s == "" {
return fmt.Errorf("platform cannot be empty")
}
if s == "current" {
s = platforms.DefaultString()
}
v.platform = s
return nil
}

func (v *platformValue) String() string {
return v.platform
}

func (v *platformValue) Get(ctx context.Context, dag *dagger.Client, _ *dagger.ModuleSource) (any, error) {
if v.platform == "" {
return nil, fmt.Errorf("platform cannot be empty")
}
return v.platform, nil
}

// AddFlag adds a flag appropriate for the argument type. Should return a
// pointer to the value.
func (r *modFunctionArg) AddFlag(flags *pflag.FlagSet) (any, error) {
Expand All @@ -615,6 +650,17 @@ func (r *modFunctionArg) AddFlag(flags *pflag.FlagSet) (any, error) {
val, _ := getDefaultValue[bool](r)
return flags.Bool(name, val, usage), nil

case dagger.ScalarKind:
scalarName := r.TypeDef.AsScalar.Name

if val := GetCustomFlagValue(scalarName); val != nil {
flags.Var(val, name, usage)
return val, nil
}

val, _ := getDefaultValue[string](r)
return flags.String(name, val, usage), nil

case dagger.ObjectKind:
objName := r.TypeDef.AsObject.Name

Expand Down Expand Up @@ -653,6 +699,17 @@ func (r *modFunctionArg) AddFlag(flags *pflag.FlagSet) (any, error) {
val, _ := getDefaultValue[[]bool](r)
return flags.BoolSlice(name, val, usage), nil

case dagger.ScalarKind:
scalarName := r.TypeDef.AsScalar.Name

if val := GetCustomFlagValueSlice(scalarName); val != nil {
flags.Var(val, name, usage)
return val, nil
}

val, _ := getDefaultValue[[]string](r)
return flags.StringSlice(name, val, usage), nil

case dagger.ObjectKind:
objName := elementType.AsObject.Name

Expand Down
1 change: 1 addition & 0 deletions cmd/dagger/functions.go
Expand Up @@ -30,6 +30,7 @@ const (
CacheVolume string = "CacheVolume"
ModuleSource string = "ModuleSource"
Module string = "Module"
Platform string = "Platform"
)

var funcGroup = &cobra.Group{
Expand Down
39 changes: 25 additions & 14 deletions cmd/dagger/module.go
Expand Up @@ -744,27 +744,30 @@ fragment TypeDefRefParts on TypeDef {
kind
optional
asObject {
name
name
}
asInterface {
name
name
}
asInput {
name
name
}
asList {
elementTypeDef {
kind
asObject {
name
}
asInterface {
name
}
asInput {
name
}
elementTypeDef {
kind
asObject {
name
}
asInterface {
name
}
asInput {
name
}
}
}
asScalar {
name
}
}
Expand Down Expand Up @@ -809,6 +812,9 @@ query TypeDefs {
...FieldParts
}
}
asScalar {
name
}
asInterface {
name
sourceModuleName
Expand Down Expand Up @@ -987,6 +993,7 @@ type modTypeDef struct {
AsInterface *modInterface
AsInput *modInput
AsList *modList
AsScalar *modScalar
}

type functionProvider interface {
Expand Down Expand Up @@ -1089,6 +1096,10 @@ func (o *modInterface) GetFunction(name string) (*modFunction, error) {
return nil, fmt.Errorf("no function '%s' in interface type '%s'", name, o.Name)
}

type modScalar struct {
Name string
}

type modInput struct {
Name string
Fields []*modField
Expand Down
78 changes: 78 additions & 0 deletions core/integration/module_call_test.go
Expand Up @@ -5,6 +5,7 @@ import (
"strings"
"testing"

"github.com/containerd/containerd/platforms"
"github.com/moby/buildkit/identity"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -499,6 +500,83 @@ func (m *Test) Cacher(ctx context.Context, cache *CacheVolume, val string) (stri
require.Equal(t, "foo\nbar\n", out)
})

t.Run("platform args", func(t *testing.T) {
t.Parallel()

c, ctx := connect(t)

modGen := c.Container().From(golangImage).
WithMountedFile(testCLIBinPath, daggerCliFile(t, c)).
WithWorkdir("/work").
With(daggerExec("init", "--source=.", "--name=test", "--sdk=go")).
WithNewFile("main.go", dagger.ContainerWithNewFileOpts{
Contents: `package main
type Test struct {}
func (m *Test) FromPlatform(platform Platform) string {
return string(platform)
}
func (m *Test) ToPlatform(platform string) Platform {
return Platform(platform)
}
`,
})

out, err := modGen.With(daggerCall("from-platform", "--platform", "linux/amd64")).Stdout(ctx)
require.NoError(t, err)
require.Equal(t, "linux/amd64", out)
out, err = modGen.With(daggerCall("from-platform", "--platform", "current")).Stdout(ctx)
require.NoError(t, err)
require.Equal(t, platforms.DefaultString(), out)
_, err = modGen.With(daggerCall("from-platform", "--platform", "invalid")).Stdout(ctx)
require.ErrorContains(t, err, "unknown operating system or architecture")

out, err = modGen.With(daggerCall("to-platform", "--platform", "linux/amd64")).Stdout(ctx)
require.NoError(t, err)
require.Equal(t, "linux/amd64", out)
_, err = modGen.With(daggerCall("to-platform", "--platform", "invalid")).Stdout(ctx)
require.ErrorContains(t, err, "unknown operating system or architecture")
})

t.Run("enum args", func(t *testing.T) {
t.Parallel()

c, ctx := connect(t)

modGen := c.Container().From(golangImage).
WithMountedFile(testCLIBinPath, daggerCliFile(t, c)).
WithWorkdir("/work").
With(daggerExec("init", "--source=.", "--name=test", "--sdk=go")).
WithNewFile("main.go", dagger.ContainerWithNewFileOpts{
Contents: `package main
type Test struct {}
func (m *Test) FromProto(proto NetworkProtocol) string {
return string(proto)
}
func (m *Test) ToProto(proto string) NetworkProtocol {
return NetworkProtocol(proto)
}
`,
})

out, err := modGen.With(daggerCall("from-proto", "--proto", "TCP")).Stdout(ctx)
require.NoError(t, err)
require.Equal(t, "TCP", out)
_, err = modGen.With(daggerCall("from-proto", "--proto", "INVALID")).Stdout(ctx)
require.ErrorContains(t, err, "invalid enum value")

out, err = modGen.With(daggerCall("to-proto", "--proto", "TCP")).Stdout(ctx)
require.NoError(t, err)
require.Equal(t, "TCP", out)
_, err = modGen.With(daggerCall("to-proto", "--proto", "INVALID")).Stdout(ctx)
require.ErrorContains(t, err, "invalid enum value")
})

t.Run("module args", func(t *testing.T) {
t.Parallel()

Expand Down
42 changes: 0 additions & 42 deletions core/integration/module_python_test.go
Expand Up @@ -1420,48 +1420,6 @@ func TestModulePythonWithOtherModuleTypes(t *testing.T) {
})
}

func TestModulePythonScalarKind(t *testing.T) {
t.Parallel()

c, ctx := connect(t)

_, err := pythonModInit(t, c, `
import dagger
from dagger import dag, function, object_type
@object_type
class Test:
@function
def foo(self, platform: dagger.Platform) -> dagger.Container:
return dag.container(platform=platform)
`).
With(daggerCall("foo", "--platform", "linux/arm64")).
Sync(ctx)

require.ErrorContains(t, err, "not supported yet")
}

func TestModulePythonEnumKind(t *testing.T) {
t.Parallel()

c, ctx := connect(t)

_, err := pythonModInit(t, c, `
import dagger
from dagger import dag, function, object_type
@object_type
class Test:
@function
def foo(self, protocol: dagger.NetworkProtocol) -> dagger.Container:
return dag.container().with_exposed_port(8000, protocol=protocol)
`).
With(daggerCall("foo", "--protocol", "UDP")).
Sync(ctx)

require.ErrorContains(t, err, "not supported yet")
}

func pythonSource(contents string) dagger.WithContainerFunc {
return pythonSourceAt("", contents)
}
Expand Down

0 comments on commit 7672599

Please sign in to comment.