Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add scalar kind #7158

Merged
merged 15 commits into from May 6, 2024
6 changes: 6 additions & 0 deletions .changes/unreleased/Added-20240506-124137.yaml
@@ -0,0 +1,6 @@
kind: Added
body: Added support for custom scalars and enums in function arguments
time: 2024-05-06T12:41:37.849472Z
custom:
Author: jedevc
PR: "7158"
18 changes: 15 additions & 3 deletions cmd/codegen/generator/go/templates/module_types.go
Expand Up @@ -69,6 +69,10 @@ func (ps *parseState) parseGoTypeReference(typ types.Type, named *types.Named, i
}
parsedType := &parsedPrimitiveType{goType: t, isPtr: isPtr}
if named != nil {
if ps.isDaggerGenerated(named.Obj()) {
// only pre-generated scalars allowed here
parsedType.scalarType = named
}
parsedType.alias = named.Obj().Name()
}
return parsedType, nil
Expand Down Expand Up @@ -120,6 +124,8 @@ type parsedPrimitiveType struct {
goType *types.Basic
isPtr bool

scalarType *types.Named

// if this is something like `type Foo string`, then alias will be "Foo"
alias string
}
Expand All @@ -138,9 +144,12 @@ func (spec *parsedPrimitiveType) TypeDefCode() (*Statement, error) {
default:
return nil, fmt.Errorf("unsupported basic type: %+v", spec.goType)
}
def := Qual("dag", "TypeDef").Call().Dot("WithKind").Call(
kind,
)
var def *Statement
if spec.scalarType != nil {
def = Qual("dag", "TypeDef").Call().Dot("WithScalar").Call(Lit(spec.scalarType.Obj().Name()))
} else {
def = Qual("dag", "TypeDef").Call().Dot("WithKind").Call(kind)
}
if spec.isPtr {
def = def.Dot("WithOptional").Call(Lit(true))
}
Expand All @@ -152,6 +161,9 @@ func (spec *parsedPrimitiveType) GoType() types.Type {
}

func (spec *parsedPrimitiveType) GoSubTypes() []types.Type {
if spec.scalarType != nil {
return []types.Type{spec.scalarType}
}
return nil
}

Expand Down
57 changes: 57 additions & 0 deletions cmd/dagger/flags.go
Expand Up @@ -17,6 +17,7 @@ import (
"strconv"
"strings"

"github.com/containerd/containerd/platforms"
"github.com/moby/buildkit/util/gitutil"
"github.com/spf13/pflag"

Expand Down Expand Up @@ -44,6 +45,8 @@ func GetCustomFlagValue(name string) DaggerValue {
return &moduleSourceValue{}
case Module:
return &moduleValue{}
case Platform:
return &platformValue{}
}
return nil
}
Expand All @@ -69,6 +72,8 @@ func GetCustomFlagValueSlice(name string) DaggerValue {
return &sliceValue[*moduleSourceValue]{}
case Module:
return &sliceValue[*moduleValue]{}
case Platform:
return &sliceValue[*platformValue]{}
}
return nil
}
Expand Down Expand Up @@ -592,6 +597,36 @@ func (v *moduleSourceValue) Get(ctx context.Context, dag *dagger.Client, _ *dagg
return modConf.Source, nil
}

type platformValue struct {
platform string
}

func (v *platformValue) Type() string {
return Platform
}

func (v *platformValue) Set(s string) error {
if s == "" {
return fmt.Errorf("platform cannot be empty")
}
if s == "current" {
s = platforms.DefaultString()
}
v.platform = s
return nil
}

func (v *platformValue) String() string {
return v.platform
}

func (v *platformValue) Get(ctx context.Context, dag *dagger.Client, _ *dagger.ModuleSource) (any, error) {
if v.platform == "" {
return nil, fmt.Errorf("platform cannot be empty")
}
return v.platform, nil
}

// AddFlag adds a flag appropriate for the argument type. Should return a
// pointer to the value.
func (r *modFunctionArg) AddFlag(flags *pflag.FlagSet) (any, error) {
Expand All @@ -615,6 +650,17 @@ func (r *modFunctionArg) AddFlag(flags *pflag.FlagSet) (any, error) {
val, _ := getDefaultValue[bool](r)
return flags.Bool(name, val, usage), nil

case dagger.ScalarKind:
scalarName := r.TypeDef.AsScalar.Name

if val := GetCustomFlagValue(scalarName); val != nil {
flags.Var(val, name, usage)
return val, nil
}

val, _ := getDefaultValue[string](r)
return flags.String(name, val, usage), nil

case dagger.ObjectKind:
objName := r.TypeDef.AsObject.Name

Expand Down Expand Up @@ -653,6 +699,17 @@ func (r *modFunctionArg) AddFlag(flags *pflag.FlagSet) (any, error) {
val, _ := getDefaultValue[[]bool](r)
return flags.BoolSlice(name, val, usage), nil

case dagger.ScalarKind:
scalarName := r.TypeDef.AsScalar.Name

if val := GetCustomFlagValueSlice(scalarName); val != nil {
flags.Var(val, name, usage)
return val, nil
}

val, _ := getDefaultValue[[]string](r)
return flags.StringSlice(name, val, usage), nil

case dagger.ObjectKind:
objName := elementType.AsObject.Name

Expand Down
1 change: 1 addition & 0 deletions cmd/dagger/functions.go
Expand Up @@ -30,6 +30,7 @@ const (
CacheVolume string = "CacheVolume"
ModuleSource string = "ModuleSource"
Module string = "Module"
Platform string = "Platform"
)

var funcGroup = &cobra.Group{
Expand Down
39 changes: 25 additions & 14 deletions cmd/dagger/module.go
Expand Up @@ -744,27 +744,30 @@ fragment TypeDefRefParts on TypeDef {
kind
optional
asObject {
name
name
}
asInterface {
name
name
}
asInput {
name
name
}
asList {
elementTypeDef {
kind
asObject {
name
}
asInterface {
name
}
asInput {
name
}
elementTypeDef {
kind
asObject {
name
}
asInterface {
name
}
asInput {
name
}
}
}
asScalar {
name
}
}

Expand Down Expand Up @@ -809,6 +812,9 @@ query TypeDefs {
...FieldParts
}
}
asScalar {
name
}
asInterface {
name
sourceModuleName
Expand Down Expand Up @@ -987,6 +993,7 @@ type modTypeDef struct {
AsInterface *modInterface
AsInput *modInput
AsList *modList
AsScalar *modScalar
}

type functionProvider interface {
Expand Down Expand Up @@ -1089,6 +1096,10 @@ func (o *modInterface) GetFunction(name string) (*modFunction, error) {
return nil, fmt.Errorf("no function '%s' in interface type '%s'", name, o.Name)
}

type modScalar struct {
Name string
}

type modInput struct {
Name string
Fields []*modField
Expand Down
78 changes: 78 additions & 0 deletions core/integration/module_call_test.go
Expand Up @@ -5,6 +5,7 @@ import (
"strings"
"testing"

"github.com/containerd/containerd/platforms"
"github.com/moby/buildkit/identity"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -499,6 +500,83 @@ func (m *Test) Cacher(ctx context.Context, cache *CacheVolume, val string) (stri
require.Equal(t, "foo\nbar\n", out)
})

t.Run("platform args", func(t *testing.T) {
t.Parallel()

c, ctx := connect(t)

modGen := c.Container().From(golangImage).
WithMountedFile(testCLIBinPath, daggerCliFile(t, c)).
WithWorkdir("/work").
With(daggerExec("init", "--source=.", "--name=test", "--sdk=go")).
WithNewFile("main.go", dagger.ContainerWithNewFileOpts{
Contents: `package main

type Test struct {}

func (m *Test) FromPlatform(platform Platform) string {
return string(platform)
}

func (m *Test) ToPlatform(platform string) Platform {
return Platform(platform)
}
`,
})

out, err := modGen.With(daggerCall("from-platform", "--platform", "linux/amd64")).Stdout(ctx)
require.NoError(t, err)
require.Equal(t, "linux/amd64", out)
out, err = modGen.With(daggerCall("from-platform", "--platform", "current")).Stdout(ctx)
require.NoError(t, err)
require.Equal(t, platforms.DefaultString(), out)
_, err = modGen.With(daggerCall("from-platform", "--platform", "invalid")).Stdout(ctx)
require.ErrorContains(t, err, "unknown operating system or architecture")

out, err = modGen.With(daggerCall("to-platform", "--platform", "linux/amd64")).Stdout(ctx)
require.NoError(t, err)
require.Equal(t, "linux/amd64", out)
_, err = modGen.With(daggerCall("to-platform", "--platform", "invalid")).Stdout(ctx)
require.ErrorContains(t, err, "unknown operating system or architecture")
})

t.Run("enum args", func(t *testing.T) {
t.Parallel()

c, ctx := connect(t)

modGen := c.Container().From(golangImage).
WithMountedFile(testCLIBinPath, daggerCliFile(t, c)).
WithWorkdir("/work").
With(daggerExec("init", "--source=.", "--name=test", "--sdk=go")).
WithNewFile("main.go", dagger.ContainerWithNewFileOpts{
Contents: `package main

type Test struct {}

func (m *Test) FromProto(proto NetworkProtocol) string {
return string(proto)
}

func (m *Test) ToProto(proto string) NetworkProtocol {
return NetworkProtocol(proto)
}
`,
})

out, err := modGen.With(daggerCall("from-proto", "--proto", "TCP")).Stdout(ctx)
require.NoError(t, err)
require.Equal(t, "TCP", out)
_, err = modGen.With(daggerCall("from-proto", "--proto", "INVALID")).Stdout(ctx)
require.ErrorContains(t, err, "invalid enum value")

out, err = modGen.With(daggerCall("to-proto", "--proto", "TCP")).Stdout(ctx)
require.NoError(t, err)
require.Equal(t, "TCP", out)
_, err = modGen.With(daggerCall("to-proto", "--proto", "INVALID")).Stdout(ctx)
require.ErrorContains(t, err, "invalid enum value")
})

t.Run("module args", func(t *testing.T) {
t.Parallel()

Expand Down
42 changes: 0 additions & 42 deletions core/integration/module_python_test.go
Expand Up @@ -1420,48 +1420,6 @@ func TestModulePythonWithOtherModuleTypes(t *testing.T) {
})
}

func TestModulePythonScalarKind(t *testing.T) {
t.Parallel()

c, ctx := connect(t)

_, err := pythonModInit(t, c, `
import dagger
from dagger import dag, function, object_type

@object_type
class Test:
@function
def foo(self, platform: dagger.Platform) -> dagger.Container:
return dag.container(platform=platform)
`).
With(daggerCall("foo", "--platform", "linux/arm64")).
Sync(ctx)

require.ErrorContains(t, err, "not supported yet")
}

func TestModulePythonEnumKind(t *testing.T) {
t.Parallel()

c, ctx := connect(t)

_, err := pythonModInit(t, c, `
import dagger
from dagger import dag, function, object_type

@object_type
class Test:
@function
def foo(self, protocol: dagger.NetworkProtocol) -> dagger.Container:
return dag.container().with_exposed_port(8000, protocol=protocol)
`).
With(daggerCall("foo", "--protocol", "UDP")).
Sync(ctx)

require.ErrorContains(t, err, "not supported yet")
}

func pythonSource(contents string) dagger.WithContainerFunc {
return pythonSourceAt("", contents)
}
Expand Down