diff --git a/cmd/shim/main.go b/cmd/shim/main.go index ae01fc8d9d0..8172b3a2e58 100644 --- a/cmd/shim/main.go +++ b/cmd/shim/main.go @@ -24,7 +24,6 @@ import ( "github.com/cenkalti/backoff/v4" "github.com/containerd/console" "github.com/google/uuid" - "github.com/opencontainers/go-digest" "github.com/opencontainers/runtime-spec/specs-go" "go.opentelemetry.io/otel/propagation" "go.opentelemetry.io/otel/sdk/resource" @@ -501,11 +500,8 @@ func setupBundle() (returnExitCode int) { } var searchDomains []string - for _, parentClientID := range execMetadata.ParentClientIDs { - searchDomains = append(searchDomains, network.ClientDomain(parentClientID)) - } - if len(searchDomains) > 0 { - spec.Process.Env = append(spec.Process.Env, "_DAGGER_PARENT_CLIENT_IDS="+strings.Join(execMetadata.ParentClientIDs, " ")) + if ns := execMetadata.ServerID; ns != "" { + searchDomains = append(searchDomains, network.ClientDomain(ns)) } var hostsFilePath string @@ -543,7 +539,7 @@ func setupBundle() (returnExitCode int) { keepEnv := []string{} for _, env := range spec.Process.Env { switch { - case strings.HasPrefix(env, "_DAGGER_ENABLE_NESTING="): + case strings.HasPrefix(env, "_DAGGER_NESTED_CLIENT_ID="): // keep the env var; we use it at runtime keepEnv = append(keepEnv, env) @@ -562,6 +558,9 @@ func setupBundle() (returnExitCode int) { Source: "/run/buildkit/buildkitd.sock", }) case strings.HasPrefix(env, "_DAGGER_SERVER_ID="): + case strings.HasPrefix(env, "_DAGGER_ENGINE_VERSION="): + // don't need this at runtime, it is just for invalidating cache, which + // has already happened by now case strings.HasPrefix(env, aliasPrefix): // NB: don't keep this env var, it's only for the bundling step // keepEnv = append(keepEnv, env) @@ -715,7 +714,8 @@ func internalEnv(name string) (string, bool) { } func runWithNesting(ctx context.Context, cmd *exec.Cmd) error { - if _, found := internalEnv("_DAGGER_ENABLE_NESTING"); !found { + clientID, ok := internalEnv("_DAGGER_NESTED_CLIENT_ID") + if !ok { // no nesting; run as normal return execProcess(cmd, true) } @@ -733,25 +733,16 @@ func runWithNesting(ctx context.Context, cmd *exec.Cmd) error { } sessionPort := l.Addr().(*net.TCPAddr).Port - parentClientIDsVal, _ := internalEnv("_DAGGER_PARENT_CLIENT_IDS") - - clientParams := client.Params{ - SecretToken: sessionToken.String(), - RunnerHost: "unix:///.runner.sock", - ParentClientIDs: strings.Fields(parentClientIDsVal), - } - - if _, ok := internalEnv("_DAGGER_ENABLE_NESTING_IN_SAME_SESSION"); ok { - serverID, ok := internalEnv("_DAGGER_SERVER_ID") - if !ok { - return fmt.Errorf("missing _DAGGER_SERVER_ID") - } - clientParams.ServerID = serverID + serverID, ok := internalEnv("_DAGGER_SERVER_ID") + if !ok { + return errors.New("missing nested client server ID") } - moduleCallerDigest, ok := internalEnv("_DAGGER_MODULE_CALLER_DIGEST") - if ok { - clientParams.ModuleCallerDigest = digest.Digest(moduleCallerDigest) + clientParams := client.Params{ + ID: clientID, + ServerID: serverID, + SecretToken: sessionToken.String(), + RunnerHost: "unix:///.runner.sock", } sess, ctx, err := client.Connect(ctx, clientParams) diff --git a/core/c2h.go b/core/c2h.go index b11dc09e7d6..8a70d931ba8 100644 --- a/core/c2h.go +++ b/core/c2h.go @@ -20,6 +20,7 @@ type c2hTunnel struct { upstreamHost string tunnelServiceHost string tunnelServicePorts []PortForward + sessionID string } func (d *c2hTunnel) Tunnel(ctx context.Context) (rerr error) { @@ -56,6 +57,7 @@ func (d *c2hTunnel) Tunnel(ctx context.Context) (rerr error) { upstream := NewHostIPSocket( port.Protocol.Network(), fmt.Sprintf("%s:%d", d.upstreamHost, port.Backend), + d.sessionID, ) sockPath := fmt.Sprintf("/upstream.%d.sock", frontend) diff --git a/core/container.go b/core/container.go index 8abecd4a3e3..e8e82c0d66e 100644 --- a/core/container.go +++ b/core/container.go @@ -1007,16 +1007,22 @@ func (container *Container) WithExec(ctx context.Context, opts ContainerExecOpts // this allows executed containers to communicate back to this API if opts.ExperimentalPrivilegedNesting { - // include the engine version so that these execs get invalidated if the engine/API change - runOpts = append(runOpts, llb.AddEnv("_DAGGER_ENABLE_NESTING", engine.Version)) - } - - if opts.ModuleCallerDigest != "" { - runOpts = append(runOpts, llb.AddEnv("_DAGGER_MODULE_CALLER_DIGEST", opts.ModuleCallerDigest.String())) - } - - if opts.NestedInSameSession { - runOpts = append(runOpts, llb.AddEnv("_DAGGER_ENABLE_NESTING_IN_SAME_SESSION", "")) + callerOpts := opts.NestedExecFunctionCall + if callerOpts == nil { + // default to caching the nested exec + callerOpts = &FunctionCall{ + Cache: true, + } + } + clientID, err := container.Query.RegisterCaller(ctx, callerOpts) + if err != nil { + return nil, fmt.Errorf("register caller: %w", err) + } + runOpts = append(runOpts, + llb.AddEnv("_DAGGER_NESTED_CLIENT_ID", clientID), + // include the engine version so that these execs get invalidated if the engine/API change + llb.AddEnv("_DAGGER_ENGINE_VERSION", engine.Version), + ) } metaSt, metaSourcePath := metaMount(opts.Stdin) @@ -1057,13 +1063,7 @@ func (container *Container) WithExec(ctx context.Context, opts ContainerExecOpts } // don't pass these through to the container when manually set, they are internal only - if name == "_DAGGER_ENABLE_NESTING" && !opts.ExperimentalPrivilegedNesting { - continue - } - if name == "_DAGGER_MODULE_CALLER_DIGEST" && opts.ModuleCallerDigest == "" { - continue - } - if name == "_DAGGER_ENABLE_NESTING_IN_SAME_SESSION" && !opts.NestedInSameSession { + if name == "_DAGGER_NESTED_CLIENT_ID" && !opts.ExperimentalPrivilegedNesting { continue } @@ -1188,8 +1188,7 @@ func (container *Container) WithExec(ctx context.Context, opts ContainerExecOpts return nil, err } execMeta := buildkit.ContainerExecUncachedMetadata{ - ParentClientIDs: clientMetadata.ClientIDs(), - ServerID: clientMetadata.ServerID, + ServerID: clientMetadata.ServerID, } proxyVal, err := execMeta.ToPBFtpProxyVal() if err != nil { @@ -1784,22 +1783,15 @@ type ContainerExecOpts struct { // Redirect the command's standard error to a file in the container RedirectStderr string `default:""` - // Provide dagger access to the executed command - // Do not use this option unless you trust the command being executed. - // The command being executed WILL BE GRANTED FULL ACCESS TO YOUR HOST FILESYSTEM + // Provide the executed command access back to the Dagger API ExperimentalPrivilegedNesting bool `default:"false"` // Grant the process all root capabilities InsecureRootCapabilities bool `default:"false"` - // (Internal-only) If this exec is for a module function, this digest will be set in the - // grpc context metadata for any api requests back to the engine. It's used by the API - // server to determine which schema to serve and other module context metadata. - ModuleCallerDigest digest.Digest `name:"-"` - - // (Internal-only) Used for module function execs to trigger the nested api client to - // be connected back to the same session. - NestedInSameSession bool `name:"-"` + // (Internal-only) If this is a nested exec for a Function call, this should be set + // with the metadata for that call + NestedExecFunctionCall *FunctionCall `name:"-"` } type BuildArg struct { diff --git a/core/git.go b/core/git.go index f0f9774f815..3317c13f583 100644 --- a/core/git.go +++ b/core/git.go @@ -60,14 +60,20 @@ func (*GitRef) TypeDescription() string { func (ref *GitRef) Tree(ctx context.Context) (*Directory, error) { bk := ref.Query.Buildkit - st := ref.getState(ctx, bk) - return NewDirectorySt(ctx, ref.Query, *st, "", ref.Repo.Platform, ref.Repo.Services) + st, err := ref.getState(ctx, bk) + if err != nil { + return nil, err + } + return NewDirectorySt(ctx, ref.Query, st, "", ref.Repo.Platform, ref.Repo.Services) } func (ref *GitRef) Commit(ctx context.Context) (string, error) { bk := ref.Query.Buildkit - st := ref.getState(ctx, bk) - p, err := resolveProvenance(ctx, bk, *st) + st, err := ref.getState(ctx, bk) + if err != nil { + return "", err + } + p, err := resolveProvenance(ctx, bk, st) if err != nil { return "", err } @@ -77,7 +83,7 @@ func (ref *GitRef) Commit(ctx context.Context) (string, error) { return p.Sources.Git[0].Commit, nil } -func (ref *GitRef) getState(ctx context.Context, bk *buildkit.Client) *llb.State { +func (ref *GitRef) getState(ctx context.Context, bk *buildkit.Client) (llb.State, error) { opts := []llb.GitOption{} if ref.Repo.KeepGitDir { @@ -96,26 +102,10 @@ func (ref *GitRef) getState(ctx context.Context, bk *buildkit.Client) *llb.State opts = append(opts, llb.AuthHeaderSecret(ref.Repo.AuthHeader.Accessor)) } - useDNS := len(ref.Repo.Services) > 0 - clientMetadata, err := engine.ClientMetadataFromContext(ctx) - if err == nil && !useDNS { - useDNS = len(clientMetadata.ParentClientIDs) > 0 + if err != nil { + return llb.State{}, err } - var st llb.State - if useDNS { - // NB: only configure search domains if we're directly using a service, or - // if we're nested beneath another search domain. - // - // we have to be a bit selective here to avoid breaking Dockerfile builds - // that use a Buildkit frontend (# syntax = ...) that doesn't have the - // networks API cap. - // - // TODO: add API cap - st = gitdns.Git(ref.Repo.URL, ref.Ref, clientMetadata.ClientIDs(), opts...) - } else { - st = llb.Git(ref.Repo.URL, ref.Ref, opts...) - } - return &st + return gitdns.Git(ref.Repo.URL, ref.Ref, clientMetadata.ServerID, opts...), nil } diff --git a/core/integration/module_test.go b/core/integration/module_test.go index 7e1a2dcf04f..b1ea971a8c4 100644 --- a/core/integration/module_test.go +++ b/core/integration/module_test.go @@ -5628,6 +5628,118 @@ func TestModuleUnicodePath(t *testing.T) { require.JSONEq(t, `{"test":{"hello":"hello"}}`, out) } +func TestModuleStartServices(t *testing.T) { + t.Parallel() + + // regression test for https://github.com/dagger/dagger/pull/6914 + t.Run("use service in multiple functions", func(t *testing.T) { + t.Parallel() + c, ctx := connect(t) + + out, err := c.Container().From(golangImage). + WithMountedFile(testCLIBinPath, daggerCliFile(t, c)). + WithWorkdir("/work"). + With(daggerExec("init", "--source=.", "--name=test", "--sdk=go")). + WithNewFile("/work/main.go", dagger.ContainerWithNewFileOpts{ + Contents: `package main + import ( + "context" + "fmt" + ) + + type Test struct { + } + + func (m *Test) FnA(ctx context.Context) (*Sub, error) { + svc := dag.Container(). + From("python"). + WithMountedDirectory( + "/srv/www", + dag.Directory().WithNewFile("index.html", "hey there"), + ). + WithWorkdir("/srv/www"). + WithExposedPort(23457). + WithExec([]string{"python", "-m", "http.server", "23457"}). + AsService() + + ctr := dag.Container(). + From("alpine:3.18.6"). + WithServiceBinding("svc", svc). + WithExec([]string{"wget", "-O", "-", "http://svc:23457"}) + + out, err := ctr.Stdout(ctx) + if err != nil { + return nil, err + } + if out != "hey there" { + return nil, fmt.Errorf("unexpected output: %q", out) + } + return &Sub{Ctr: ctr}, nil + } + + type Sub struct { + Ctr *Container + } + + func (m *Sub) FnB(ctx context.Context) (string, error) { + return m.Ctr. + WithExec([]string{"wget", "-O", "-", "http://svc:23457"}). + Stdout(ctx) + } + `, + }). + With(daggerCall("fnA", "fnB")). + Stdout(ctx) + require.NoError(t, err) + require.Equal(t, "hey there", strings.TrimSpace(out)) + }) + + // regression test for https://github.com/dagger/dagger/issues/6951 + t.Run("service in multiple containers", func(t *testing.T) { + t.Parallel() + c, ctx := connect(t) + + _, err := c.Container().From(golangImage). + WithMountedFile(testCLIBinPath, daggerCliFile(t, c)). + WithWorkdir("/work"). + With(daggerExec("init", "--source=.", "--name=test", "--sdk=go")). + WithNewFile("/work/main.go", dagger.ContainerWithNewFileOpts{ + Contents: `package main +import ( + "context" +) + +type Test struct { +} + +func (m *Test) Fn(ctx context.Context) *Container { + redis := dag.Container(). + From("redis"). + WithExposedPort(6379). + AsService() + cli := dag.Container(). + From("redis"). + WithoutEntrypoint(). + WithServiceBinding("redis", redis) + + ctrA := cli.WithExec([]string{"sh", "-c", "redis-cli -h redis info >> /tmp/out.txt"}) + + file := ctrA.Directory("/tmp").File("/out.txt") + + ctrB := dag.Container(). + From("alpine"). + WithFile("/out.txt", file) + + return ctrB.WithExec([]string{"cat", "/out.txt"}) +} + `, + }). + With(daggerCall("fn", "stdout")). + Sync(ctx) + require.NoError(t, err) + }) +} + func daggerExec(args ...string) dagger.WithContainerFunc { return func(c *dagger.Container) *dagger.Container { return c.WithExec(append([]string{"dagger", "--debug"}, args...), dagger.ContainerWithExecOpts{ diff --git a/core/interface.go b/core/interface.go index 9eb3d604fc9..8cdca1ed137 100644 --- a/core/interface.go +++ b/core/interface.go @@ -222,7 +222,7 @@ func (iface *InterfaceType) Install(ctx context.Context, dag *dagql.Server) erro }) } - res, err := callable.Call(ctx, dagql.CurrentID(ctx), &CallOpts{ + res, err := callable.Call(ctx, &CallOpts{ Inputs: callInputs, ParentVal: runtimeVal.Fields, }) diff --git a/core/modfunc.go b/core/modfunc.go index dba748d2a9f..c75e539cbfe 100644 --- a/core/modfunc.go +++ b/core/modfunc.go @@ -17,7 +17,6 @@ import ( "github.com/dagger/dagger/core/pipeline" "github.com/dagger/dagger/dagql" "github.com/dagger/dagger/dagql/call" - "github.com/dagger/dagger/engine" "github.com/dagger/dagger/engine/buildkit" ) @@ -113,7 +112,7 @@ func (fn *ModuleFunction) recordCall(ctx context.Context) { analytics.Ctx(ctx).Capture(ctx, "module_call", props) } -func (fn *ModuleFunction) Call(ctx context.Context, caller *call.ID, opts *CallOpts) (t dagql.Typed, rerr error) { +func (fn *ModuleFunction) Call(ctx context.Context, opts *CallOpts) (t dagql.Typed, rerr error) { mod := fn.mod lg := bklog.G(ctx).WithField("module", mod.Name()).WithField("function", fn.metadata.Name) @@ -166,23 +165,6 @@ func (fn *ModuleFunction) Call(ctx context.Context, caller *call.ID, opts *CallO }) } - callerDigestInputs := []string{} - { - callerIDDigest := caller.Digest() // FIXME(vito) canonicalize, once all that's implemented - callerDigestInputs = append(callerDigestInputs, callerIDDigest.String()) - } - if !opts.Cache { - // use the ServerID so that we bust cache once-per-session - clientMetadata, err := engine.ClientMetadataFromContext(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get client metadata: %w", err) - } - callerDigestInputs = append(callerDigestInputs, clientMetadata.ServerID) - } - - callerDigest := digest.FromString(strings.Join(callerDigestInputs, " ")) - - ctx = bklog.WithLogger(ctx, bklog.G(ctx).WithField("caller_digest", callerDigest.String())) bklog.G(ctx).Debug("function call") defer func() { bklog.G(ctx).Debug("function call done") @@ -191,10 +173,28 @@ func (fn *ModuleFunction) Call(ctx context.Context, caller *call.ID, opts *CallO } }() + parentJSON, err := json.Marshal(opts.ParentVal) + if err != nil { + return nil, fmt.Errorf("failed to marshal parent value: %w", err) + } + + callMeta := &FunctionCall{ + Query: fn.root, + Name: fn.metadata.OriginalName, + Parent: parentJSON, + InputArgs: callInputs, + Module: mod, + Cache: opts.Cache, + SkipSelfSchema: opts.SkipSelfSchema, + } + if fn.objDef != nil { + callMeta.ParentName = fn.objDef.OriginalName + } + ctr := fn.runtime metaDir := NewScratchDirectory(mod.Query, mod.Query.Platform) - ctr, err := ctr.WithMountedDirectory(ctx, modMetaDirPath, metaDir, "", false) + ctr, err = ctr.WithMountedDirectory(ctx, modMetaDirPath, metaDir, "", false) if err != nil { return nil, fmt.Errorf("failed to mount mod metadata directory: %w", err) } @@ -219,45 +219,13 @@ func (fn *ModuleFunction) Call(ctx context.Context, caller *call.ID, opts *CallO // Setup the Exec for the Function call and evaluate it ctr, err = ctr.WithExec(ctx, ContainerExecOpts{ - ModuleCallerDigest: callerDigest, ExperimentalPrivilegedNesting: true, - NestedInSameSession: true, + NestedExecFunctionCall: callMeta, }) if err != nil { return nil, fmt.Errorf("failed to exec function: %w", err) } - parentJSON, err := json.Marshal(opts.ParentVal) - if err != nil { - return nil, fmt.Errorf("failed to marshal parent value: %w", err) - } - - callMeta := &FunctionCall{ - Query: fn.root, - Name: fn.metadata.OriginalName, - Parent: parentJSON, - InputArgs: callInputs, - } - if fn.objDef != nil { - callMeta.ParentName = fn.objDef.OriginalName - } - - var deps *ModDeps - if opts.SkipSelfSchema { - // Only serve the APIs of the deps of this module. This is currently only needed for the special - // case of the function used to get the definition of the module itself (which can't obviously - // be served the API its returning the definition of). - deps = mod.Deps - } else { - // by default, serve both deps and the module's own API to itself - deps = mod.Deps.Prepend(mod) - } - - err = mod.Query.RegisterFunctionCall(ctx, callerDigest, deps, fn.mod, callMeta) - if err != nil { - return nil, fmt.Errorf("failed to register function call: %w", err) - } - _, err = ctr.Evaluate(ctx) if err != nil { if fn.metadata.OriginalName == "" { diff --git a/core/module.go b/core/module.go index b4e28b07cf9..3a84fb89a49 100644 --- a/core/module.go +++ b/core/module.go @@ -138,7 +138,7 @@ func (mod *Module) Initialize(ctx context.Context, oldSelf dagql.Instance[*Modul return nil, fmt.Errorf("failed to create module definition function for module %q: %w", mod.Name(), err) } - result, err := getModDefFn.Call(ctx, newID, &CallOpts{Cache: true, SkipSelfSchema: true}) + result, err := getModDefFn.Call(ctx, &CallOpts{Cache: true, SkipSelfSchema: true}) if err != nil { return nil, fmt.Errorf("failed to call module %q to get functions: %w", mod.Name(), err) } diff --git a/core/object.go b/core/object.go index d2c1009cdf0..80a81ca6291 100644 --- a/core/object.go +++ b/core/object.go @@ -9,7 +9,6 @@ import ( "github.com/vektah/gqlparser/v2/ast" "github.com/dagger/dagger/dagql" - "github.com/dagger/dagger/dagql/call" "github.com/dagger/dagger/engine/slog" ) @@ -85,7 +84,7 @@ func (t *ModuleObjectType) TypeDef() *TypeDef { } type Callable interface { - Call(context.Context, *call.ID, *CallOpts) (dagql.Typed, error) + Call(context.Context, *CallOpts) (dagql.Typed, error) ReturnType() (ModType, error) ArgType(argName string) (ModType, error) } @@ -262,7 +261,7 @@ func (obj *ModuleObject) installConstructor(ctx context.Context, dag *dagql.Serv Value: v, }) } - return fn.Call(ctx, dagql.CurrentID(ctx), &CallOpts{ + return fn.Call(ctx, &CallOpts{ Inputs: callInput, ParentVal: nil, }) @@ -359,7 +358,7 @@ func objFun(ctx context.Context, mod *Module, objDef *ObjectTypeDef, fun *Functi sort.Slice(opts.Inputs, func(i, j int) bool { return opts.Inputs[i].Name < opts.Inputs[j].Name }) - return modFun.Call(ctx, dagql.CurrentID(ctx), opts) + return modFun.Call(ctx, opts) }, }, nil } @@ -370,7 +369,7 @@ type CallableField struct { Return ModType } -func (f *CallableField) Call(ctx context.Context, id *call.ID, opts *CallOpts) (dagql.Typed, error) { +func (f *CallableField) Call(ctx context.Context, opts *CallOpts) (dagql.Typed, error) { val, ok := opts.ParentVal[f.Field.OriginalName] if !ok { return nil, fmt.Errorf("field %q not found on object %q", f.Field.Name, opts.ParentVal) diff --git a/core/query.go b/core/query.go index 9a9f7a5d24d..22114d5cac8 100644 --- a/core/query.go +++ b/core/query.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "strings" "sync" "github.com/containerd/containerd/content" @@ -17,6 +18,7 @@ import ( "github.com/dagger/dagger/dagql/call" "github.com/dagger/dagger/engine" "github.com/dagger/dagger/engine/buildkit" + "github.com/dagger/dagger/engine/slog" ) // Query forms the root of the DAG and houses all necessary state and @@ -57,8 +59,9 @@ type QueryOpts struct { // For the special case of the main client caller, the key is just empty string. // This is never explicitly deleted from; instead it will just be garbage collected // when this server for the session shuts down - ClientCallContext map[digest.Digest]*ClientCallContext - ClientCallMu *sync.RWMutex + ClientCallContext map[string]*ClientCallContext + ClientCallMu *sync.RWMutex + MainClientCallerID string // the http endpoints being served (as a map since APIs like shellEndpoint can add more) Endpoints map[string]http.Handler @@ -107,24 +110,18 @@ type ClientCallContext struct { // If the client is itself from a function call in a user module, these are set with the // metadata of that ongoing function call - Module *Module FnCall *FunctionCall } -func (q *Query) ServeModuleToMainClient(ctx context.Context, modMeta dagql.Instance[*Module]) error { +func (q *Query) ServeModule(ctx context.Context, mod *Module) error { clientMetadata, err := engine.ClientMetadataFromContext(ctx) if err != nil { return err } - if clientMetadata.ModuleCallerDigest != "" { - return fmt.Errorf("cannot serve module to client %s", clientMetadata.ClientID) - } - - mod := modMeta.Self q.ClientCallMu.Lock() defer q.ClientCallMu.Unlock() - callCtx, ok := q.ClientCallContext[""] + callCtx, ok := q.ClientCallContext[clientMetadata.ClientID] if !ok { return fmt.Errorf("client call not found") } @@ -132,34 +129,65 @@ func (q *Query) ServeModuleToMainClient(ctx context.Context, modMeta dagql.Insta return nil } -func (q *Query) RegisterFunctionCall( - ctx context.Context, - dgst digest.Digest, - deps *ModDeps, - mod *Module, - call *FunctionCall, -) error { - if dgst == "" { - return fmt.Errorf("cannot register function call with empty digest") +func (q *Query) RegisterCaller(ctx context.Context, call *FunctionCall) (string, error) { + if call == nil { + call = &FunctionCall{} + } + callCtx := &ClientCallContext{ + FnCall: call, + } + + currentID := dagql.CurrentID(ctx) + clientIDInputs := []string{currentID.Digest().String()} + if !call.Cache { + // use the ServerID so that we bust cache once-per-session + clientMetadata, err := engine.ClientMetadataFromContext(ctx) + if err != nil { + return "", err + } + clientIDInputs = append(clientIDInputs, clientMetadata.ServerID) + } + clientIDDigest := digest.FromString(strings.Join(clientIDInputs, " ")) + + // only use encoded part of digest because this ID ends up becoming a buildkit Session ID + // and buildkit has some ancient internal logic that splits on a colon to support some + // dev mode logic: https://github.com/moby/buildkit/pull/290 + // also trim it to 25 chars as it ends up becoming part of service URLs + clientID := clientIDDigest.Encoded()[:25] + + slog.ExtraDebug("registering nested caller", + "client_id", clientID, + "op", currentID.Display(), + ) + + if call.Module == nil { + callCtx.Deps = q.DefaultDeps + } else { + callCtx.Deps = call.Module.Deps + // By default, serve both deps and the module's own API to itself. But if SkipSelfSchema is set, + // only serve the APIs of the deps of this module. This is currently only needed for the special + // case of the function used to get the definition of the module itself (which can't obviously + // be served the API its returning the definition of). + if !call.SkipSelfSchema { + callCtx.Deps = callCtx.Deps.Append(call.Module) + } } q.ClientCallMu.Lock() defer q.ClientCallMu.Unlock() - _, ok := q.ClientCallContext[dgst] + _, ok := q.ClientCallContext[clientID] if ok { - return nil + return clientID, nil } - newRoot, err := NewRoot(ctx, q.QueryOpts) + + var err error + callCtx.Root, err = NewRoot(ctx, q.QueryOpts) if err != nil { - return err - } - q.ClientCallContext[dgst] = &ClientCallContext{ - Root: newRoot, - Deps: deps, - Module: mod, - FnCall: call, + return "", err } - return nil + + q.ClientCallContext[clientID] = callCtx + return clientID, nil } func (q *Query) CurrentModule(ctx context.Context) (*Module, error) { @@ -167,17 +195,20 @@ func (q *Query) CurrentModule(ctx context.Context) (*Module, error) { if err != nil { return nil, err } - if clientMetadata.ModuleCallerDigest == "" { - return nil, fmt.Errorf("%w: main client caller has no module", ErrNoCurrentModule) + if clientMetadata.ClientID == q.MainClientCallerID { + return nil, fmt.Errorf("%w: main client caller has no current module", ErrNoCurrentModule) } q.ClientCallMu.RLock() defer q.ClientCallMu.RUnlock() - callCtx, ok := q.ClientCallContext[clientMetadata.ModuleCallerDigest] + callCtx, ok := q.ClientCallContext[clientMetadata.ClientID] if !ok { - return nil, fmt.Errorf("client call %s not found", clientMetadata.ModuleCallerDigest) + return nil, fmt.Errorf("client call %s not found", clientMetadata.ClientID) + } + if callCtx.FnCall.Module == nil { + return nil, ErrNoCurrentModule } - return callCtx.Module, nil + return callCtx.FnCall.Module, nil } func (q *Query) CurrentFunctionCall(ctx context.Context) (*FunctionCall, error) { @@ -185,15 +216,15 @@ func (q *Query) CurrentFunctionCall(ctx context.Context) (*FunctionCall, error) if err != nil { return nil, err } - if clientMetadata.ModuleCallerDigest == "" { + if clientMetadata.ClientID == q.MainClientCallerID { return nil, fmt.Errorf("%w: main client caller has no function", ErrNoCurrentModule) } q.ClientCallMu.RLock() defer q.ClientCallMu.RUnlock() - callCtx, ok := q.ClientCallContext[clientMetadata.ModuleCallerDigest] + callCtx, ok := q.ClientCallContext[clientMetadata.ClientID] if !ok { - return nil, fmt.Errorf("client call %s not found", clientMetadata.ModuleCallerDigest) + return nil, fmt.Errorf("client call %s not found", clientMetadata.ClientID) } return callCtx.FnCall, nil @@ -204,9 +235,9 @@ func (q *Query) CurrentServedDeps(ctx context.Context) (*ModDeps, error) { if err != nil { return nil, err } - callCtx, ok := q.ClientCallContext[clientMetadata.ModuleCallerDigest] + callCtx, ok := q.ClientCallContext[clientMetadata.ClientID] if !ok { - return nil, fmt.Errorf("client call %s not found", clientMetadata.ModuleCallerDigest) + return nil, fmt.Errorf("client call %s not found", clientMetadata.ClientID) } return callCtx.Deps, nil } @@ -262,11 +293,12 @@ func (q *Query) NewTunnelService(upstream dagql.Instance[*Service], ports []Port } } -func (q *Query) NewHostService(upstream string, ports []PortForward) *Service { +func (q *Query) NewHostService(upstream string, ports []PortForward, sessionID string) *Service { return &Service{ - Query: q, - HostUpstream: upstream, - HostPorts: ports, + Query: q, + HostUpstream: upstream, + HostPorts: ports, + HostSessionID: sessionID, } } diff --git a/core/schema/host.go b/core/schema/host.go index 28e9c3b3bda..c9cc745c214 100644 --- a/core/schema/host.go +++ b/core/schema/host.go @@ -160,6 +160,7 @@ func (s *hostSchema) Install() { `If ports are given and native is true, the ports are additive.`), dagql.Func("service", s.service). + Impure("Value depends on the caller as it points to their host."). Doc(`Creates a service that forwards traffic to a specified address via the host.`). ArgDoc("ports", `Ports to expose via the service, forwarding through the host network.`, @@ -168,6 +169,10 @@ func (s *hostSchema) Install() { `An empty set of ports is not valid; an error will be returned.`). ArgDoc("host", `Upstream host to forward traffic to.`), + // hidden from external clients via the __ prefix + dagql.Func("__internalService", s.internalService). + Doc(`(Internal-only) "service" but scoped to the exact right buildkit session ID.`), + dagql.Func("setSecretFile", s.setSecretFile). Impure("`setSecretFile` reads its value from the local machine."). Doc( @@ -207,7 +212,7 @@ func (s *hostSchema) socket(ctx context.Context, host *core.Host, args hostSocke if err != nil { return nil, fmt.Errorf("failed to get client metadata: %w", err) } - if clientMetadata.ClientID != host.Query.Buildkit.MainClientCallerID { + if clientMetadata.ClientID != host.Query.MainClientCallerID { return nil, fmt.Errorf("only the main client can access the host's unix sockets") } @@ -279,10 +284,58 @@ type hostServiceArgs struct { Ports []dagql.InputObject[core.PortForward] } -func (s *hostSchema) service(ctx context.Context, parent *core.Host, args hostServiceArgs) (*core.Service, error) { +func (s *hostSchema) service(ctx context.Context, parent *core.Host, args hostServiceArgs) (inst dagql.Instance[*core.Service], err error) { if len(args.Ports) == 0 { - return nil, errors.New("no ports specified") + return inst, errors.New("no ports specified") + } + + clientMetadata, err := engine.ClientMetadataFromContext(ctx) + if err != nil { + return inst, fmt.Errorf("failed to get client metadata: %w", err) + } + + portsArg := make(dagql.ArrayInput[dagql.InputObject[core.PortForward]], len(args.Ports)) + copy(portsArg, args.Ports) + + err = s.srv.Select(ctx, s.srv.Root(), &inst, + dagql.Selector{ + Field: "host", + }, + dagql.Selector{ + Field: "__internalService", + Args: []dagql.NamedInput{ + { + Name: "host", + Value: dagql.NewString(args.Host), + }, + { + Name: "ports", + Value: portsArg, + }, + { + Name: "sessionId", + Value: dagql.NewString(clientMetadata.BuildkitSessionID()), + }, + }, + }, + ) + return inst, err +} + +type hostInternalServiceArgs struct { + Host string `default:"localhost"` + Ports []dagql.InputObject[core.PortForward] + SessionID string +} + +func (s *hostSchema) internalService(ctx context.Context, parent *core.Host, args hostInternalServiceArgs) (*core.Service, error) { + if args.SessionID == "" { + return nil, errors.New("no session ID specified") } - return parent.Query.NewHostService(args.Host, collectInputsSlice(args.Ports)), nil + return parent.Query.NewHostService( + args.Host, + collectInputsSlice(args.Ports), + args.SessionID, + ), nil } diff --git a/core/schema/http.go b/core/schema/http.go index d052c9e7146..a581317bdb5 100644 --- a/core/schema/http.go +++ b/core/schema/http.go @@ -60,26 +60,11 @@ func (s *httpSchema) http(ctx context.Context, parent *core.Query, args httpArgs llb.Filename(filename), } - useDNS := len(svcs) > 0 - clientMetadata, err := engine.ClientMetadataFromContext(ctx) - if err == nil && !useDNS { - useDNS = len(clientMetadata.ParentClientIDs) > 0 - } - - var st llb.State - if useDNS { - // NB: only configure search domains if we're directly using a service, or - // if we're nested. - // - // we have to be a bit selective here to avoid breaking Dockerfile builds - // that use a Buildkit frontend (# syntax = ...). - // - // TODO: add API cap - st = httpdns.HTTP(args.URL, clientMetadata.ClientIDs(), opts...) - } else { - st = llb.HTTP(args.URL, opts...) + if err != nil { + return nil, err } + st := httpdns.HTTP(args.URL, clientMetadata.ServerID, opts...) return core.NewFileSt(ctx, parent, st, filename, parent.Platform, svcs) } diff --git a/core/schema/module.go b/core/schema/module.go index 49d8d266fe6..dbbc2a1c23f 100644 --- a/core/schema/module.go +++ b/core/schema/module.go @@ -465,7 +465,7 @@ func (s *moduleSchema) currentFunctionCall(ctx context.Context, self *core.Query } func (s *moduleSchema) moduleServe(ctx context.Context, modMeta dagql.Instance[*core.Module], _ struct{}) (dagql.Nullable[core.Void], error) { - return dagql.Null[core.Void](), modMeta.Self.Query.ServeModuleToMainClient(ctx, modMeta) + return dagql.Null[core.Void](), modMeta.Self.Query.ServeModule(ctx, modMeta.Self) } func (s *moduleSchema) currentTypeDefs(ctx context.Context, self *core.Query, _ struct{}) ([]*core.TypeDef, error) { diff --git a/core/service.go b/core/service.go index 13e661bd688..c779141768f 100644 --- a/core/service.go +++ b/core/service.go @@ -44,6 +44,8 @@ type Service struct { HostUpstream string `json:"reverse_tunnel_upstream_addr,omitempty"` // HostPorts configures the port forwarding rules for the host. HostPorts []PortForward `json:"host_ports,omitempty"` + // HostSessionID is the session ID of the host (could differ from main client in the case of nested execs). + HostSessionID string `json:"host_session_id,omitempty"` } func (*Service) Type() *ast.Type { @@ -260,7 +262,7 @@ func (svc *Service) startContainer( } }() - fullHost := host + "." + network.ClientDomain(clientMetadata.ClientID) + fullHost := host + "." + network.ClientDomain(clientMetadata.ServerID) bk := svc.Query.Buildkit @@ -328,8 +330,7 @@ func (svc *Service) startContainer( } execMeta := buildkit.ContainerExecUncachedMetadata{ - ParentClientIDs: clientMetadata.ClientIDs(), - ServerID: clientMetadata.ServerID, + ServerID: clientMetadata.ServerID, } execOp.Meta.ProxyEnv.FtpProxy, err = execMeta.ToPBFtpProxyVal() if err != nil { @@ -596,7 +597,7 @@ func (svc *Service) startReverseTunnel(ctx context.Context, id *call.ID) (runnin return nil, err } - fullHost := host + "." + network.ClientDomain(clientMetadata.ClientID) + fullHost := host + "." + network.ClientDomain(clientMetadata.ServerID) bk := svc.Query.Buildkit @@ -605,6 +606,7 @@ func (svc *Service) startReverseTunnel(ctx context.Context, id *call.ID) (runnin upstreamHost: svc.HostUpstream, tunnelServiceHost: fullHost, tunnelServicePorts: svc.HostPorts, + sessionID: svc.HostSessionID, } checkPorts := []Port{} diff --git a/core/socket.go b/core/socket.go index a5b7ad1ff50..5b6958653cd 100644 --- a/core/socket.go +++ b/core/socket.go @@ -17,6 +17,9 @@ type Socket struct { // IP HostProtocol string `json:"host_protocol,omitempty"` HostAddr string `json:"host_addr,omitempty"` + + // The session ID of the host's client + SessionID string `json:"session_id,omitempty"` } func (*Socket) Type() *ast.Type { @@ -36,10 +39,11 @@ func NewHostUnixSocket(absPath string) *Socket { } } -func NewHostIPSocket(proto string, addr string) *Socket { +func NewHostIPSocket(proto string, addr string, sessionID string) *Socket { return &Socket{ HostAddr: addr, HostProtocol: proto, + SessionID: sessionID, } } @@ -52,6 +56,7 @@ func (socket *Socket) SSHID() string { default: u.Scheme = socket.HostProtocol u.Host = socket.HostAddr + u.Fragment = socket.SessionID } return u.String() } diff --git a/core/typedef.go b/core/typedef.go index 781edbb15c0..3b67e9689c7 100644 --- a/core/typedef.go +++ b/core/typedef.go @@ -849,6 +849,17 @@ type FunctionCall struct { ParentName string `field:"true" doc:"The name of the parent object of the function being called. If the function is top-level to the module, this is the name of the module."` Parent JSON `field:"true" doc:"The value of the parent object of the function being called. If the function is top-level to the module, this is always an empty object."` InputArgs []*FunctionCallArgValue `field:"true" doc:"The argument values the function is being invoked with."` + + // Below are not in public API + + // The module that the function is being called from + Module *Module + + // Whether the function call should be cached across different servers + Cache bool + + // Whether to serve the schema for the function's own module to it or not + SkipSelfSchema bool } func (*FunctionCall) Type() *ast.Type { diff --git a/engine/buildkit/client.go b/engine/buildkit/client.go index 04da992f04d..08ffbc84917 100644 --- a/engine/buildkit/client.go +++ b/engine/buildkit/client.go @@ -40,7 +40,6 @@ import ( "google.golang.org/grpc/metadata" "github.com/dagger/dagger/auth" - "github.com/dagger/dagger/engine" "github.com/dagger/dagger/engine/session" ) @@ -63,11 +62,10 @@ type Opts struct { // client. It is special in that when it shuts down, the client will be closed and // that registry auth and sockets are currently only ever sourced from this caller, // not any nested clients (may change in future). - MainClientCaller bksession.Caller - MainClientCallerID string - DNSConfig *oci.DNSConfig - Frontends map[string]bkfrontend.Frontend - BuildkitLogSink io.Writer + MainClientCaller bksession.Caller + DNSConfig *oci.DNSConfig + Frontends map[string]bkfrontend.Frontend + BuildkitLogSink io.Writer sharedClientState } @@ -577,13 +575,7 @@ func (c *Client) ListenHostToContainer( return nil, nil, err } - clientMetadata, err := engine.ClientMetadataFromContext(ctx) - if err != nil { - cancel() - return nil, nil, fmt.Errorf("failed to get requester session ID: %w", err) - } - - clientCaller, err := c.SessionManager.Get(ctx, clientMetadata.ClientID, false) + clientCaller, err := c.GetSessionCaller(ctx, false) if err != nil { cancel() return nil, nil, fmt.Errorf("failed to get requester session: %w", err) @@ -722,8 +714,7 @@ func withOutgoingContext(ctx context.Context) context.Context { // the "real" ftp proxy setting in here too and have the shim handle // leaving only that set in the actual env var. type ContainerExecUncachedMetadata struct { - ParentClientIDs []string `json:"parentClientIDs,omitempty"` - ServerID string `json:"serverID,omitempty"` + ServerID string `json:"serverID,omitempty"` } func (md ContainerExecUncachedMetadata) ToPBFtpProxyVal() (string, error) { diff --git a/engine/buildkit/containerimage.go b/engine/buildkit/containerimage.go index a9ec7920d20..78a5f1728ce 100644 --- a/engine/buildkit/containerimage.go +++ b/engine/buildkit/containerimage.go @@ -108,7 +108,7 @@ func (c *Client) ExportContainerImage( IsFileStream: true, }.AppendToOutgoingContext(ctx) - resp, descRef, err := expInstance.Export(ctx, combinedResult, nil, clientMetadata.ClientID) + resp, descRef, err := expInstance.Export(ctx, combinedResult, nil, clientMetadata.BuildkitSessionID()) if err != nil { return nil, fmt.Errorf("failed to export: %w", err) } diff --git a/engine/buildkit/filesync.go b/engine/buildkit/filesync.go index c45b5b75e58..7606f91f23d 100644 --- a/engine/buildkit/filesync.go +++ b/engine/buildkit/filesync.go @@ -45,7 +45,7 @@ func (c *Client) LocalImport( } localOpts := []llb.LocalOption{ - llb.SessionID(clientMetadata.ClientID), + llb.SessionID(clientMetadata.BuildkitSessionID()), llb.SharedKeyHint(strings.Join([]string{clientMetadata.ClientHostname, srcPath}, " ")), } @@ -108,13 +108,9 @@ func (c *Client) diffcopy(ctx context.Context, opts engine.LocalImportOpts, msg } defer cancel() - clientMetadata, err := engine.ClientMetadataFromContext(ctx) - if err != nil { - return fmt.Errorf("failed to get requester session ID: %w", err) - } ctx = opts.AppendToOutgoingContext(ctx) - clientCaller, err := c.SessionManager.Get(ctx, clientMetadata.ClientID, false) + clientCaller, err := c.GetSessionCaller(ctx, true) if err != nil { return fmt.Errorf("failed to get requester session: %w", err) } @@ -213,7 +209,7 @@ func (c *Client) LocalDirExport( Merge: merge, }.AppendToOutgoingContext(ctx) - _, descRef, err := expInstance.Export(ctx, cacheRes, nil, clientMetadata.ClientID) + _, descRef, err := expInstance.Export(ctx, cacheRes, nil, clientMetadata.BuildkitSessionID()) if err != nil { return fmt.Errorf("failed to export: %w", err) } @@ -288,11 +284,6 @@ func (c *Client) LocalFileExport( return fmt.Errorf("failed to stat file: %w", err) } - clientMetadata, err := engine.ClientMetadataFromContext(ctx) - if err != nil { - return fmt.Errorf("failed to get requester session ID: %w", err) - } - ctx = engine.LocalExportOpts{ Path: destPath, IsFileStream: true, @@ -301,7 +292,7 @@ func (c *Client) LocalFileExport( FileMode: stat.Mode().Perm(), }.AppendToOutgoingContext(ctx) - clientCaller, err := c.SessionManager.Get(ctx, clientMetadata.ClientID, false) + clientCaller, err := c.GetSessionCaller(ctx, true) if err != nil { return fmt.Errorf("failed to get requester session: %w", err) } @@ -357,11 +348,6 @@ func (c *Client) IOReaderExport(ctx context.Context, r io.Reader, destPath strin lg.Trace("finished exporting bytes") }() - clientMetadata, err := engine.ClientMetadataFromContext(ctx) - if err != nil { - return fmt.Errorf("failed to get requester session ID: %w", err) - } - ctx = engine.LocalExportOpts{ Path: destPath, IsFileStream: true, @@ -369,7 +355,7 @@ func (c *Client) IOReaderExport(ctx context.Context, r io.Reader, destPath strin FileMode: destMode, }.AppendToOutgoingContext(ctx) - clientCaller, err := c.SessionManager.Get(ctx, clientMetadata.ClientID, false) + clientCaller, err := c.GetSessionCaller(ctx, true) if err != nil { return fmt.Errorf("failed to get requester session: %w", err) } diff --git a/engine/buildkit/session.go b/engine/buildkit/session.go index 7ba27c76c6f..694dcddaa34 100644 --- a/engine/buildkit/session.go +++ b/engine/buildkit/session.go @@ -15,6 +15,7 @@ import ( "github.com/moby/buildkit/session/secrets/secretsprovider" "github.com/moby/buildkit/util/bklog" + "github.com/dagger/dagger/engine" "github.com/dagger/dagger/engine/client" "github.com/dagger/dagger/engine/distconsts" ) @@ -85,7 +86,17 @@ func (c *Client) newSession() (*bksession.Session, error) { return sess, nil } -func (c *Client) GetSessionCaller(ctx context.Context, clientID string) (bksession.Caller, error) { - waitForSession := true - return c.SessionManager.Get(ctx, clientID, !waitForSession) +func (c *Client) GetSessionCaller(ctx context.Context, wait bool) (bksession.Caller, error) { + clientMetadata, err := engine.ClientMetadataFromContext(ctx) + if err != nil { + return nil, err + } + caller, err := c.SessionManager.Get(ctx, clientMetadata.BuildkitSessionID(), !wait) + if err != nil { + return nil, err + } + if caller == nil { + return nil, fmt.Errorf("session for %q not found", clientMetadata.BuildkitSessionID()) + } + return caller, nil } diff --git a/engine/buildkit/socket.go b/engine/buildkit/socket.go index 3f8c6bf5020..789c6d336d1 100644 --- a/engine/buildkit/socket.go +++ b/engine/buildkit/socket.go @@ -2,11 +2,15 @@ package buildkit import ( "context" + "fmt" + "net/url" "github.com/moby/buildkit/session/sshforward" "go.opentelemetry.io/otel/trace" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" ) type socketProxy struct { @@ -29,10 +33,29 @@ func (p *socketProxy) ForwardAgent(stream sshforward.SSH_ForwardAgentServer) err ctx = trace.ContextWithSpanContext(ctx, p.c.spanCtx) // ensure server's span context is propagated - incomingMD, _ := metadata.FromIncomingContext(ctx) - ctx = metadata.NewOutgoingContext(ctx, incomingMD) + opts, _ := metadata.FromIncomingContext(ctx) + ctx = metadata.NewOutgoingContext(ctx, opts) - forwardAgentClient, err := sshforward.NewSSHClient(p.c.MainClientCaller.Conn()).ForwardAgent(ctx) + var connURL *url.URL + if v, ok := opts[sshforward.KeySSHID]; ok && len(v) > 0 && v[0] != "" { + var err error + connURL, err = url.Parse(v[0]) + if err != nil { + return status.Errorf(codes.InvalidArgument, "invalid id: %s", err) + } + } + + caller := p.c.MainClientCaller + if connURL != nil && connURL.Fragment != "" { + sessionID := connURL.Fragment + var err error + caller, err = p.c.SessionManager.Get(ctx, sessionID, true) + if err != nil { + return fmt.Errorf("failed to get session: %w", err) + } + } + + forwardAgentClient, err := sshforward.NewSSHClient(caller.Conn()).ForwardAgent(ctx) if err != nil { return err } diff --git a/engine/client/client.go b/engine/client/client.go index 64b731f24db..f60a0111841 100644 --- a/engine/client/client.go +++ b/engine/client/client.go @@ -31,7 +31,6 @@ import ( "github.com/moby/buildkit/session/filesync" "github.com/moby/buildkit/session/grpchijack" "github.com/moby/buildkit/util/grpcerrors" - "github.com/opencontainers/go-digest" "github.com/tonistiigi/fsutil" fstypes "github.com/tonistiigi/fsutil/types" sdktrace "go.opentelemetry.io/otel/sdk/trace" @@ -53,15 +52,14 @@ import ( ) type Params struct { - // The id of the server to connect to, or if blank a new one - // should be started. - ServerID string + // The id to connect to the API server with. If blank, will be set to a + // new random value. + ID string - // Parent client IDs of this Dagger client. - // - // Used by Dagger-in-Dagger so that nested sessions can resolve addresses - // passed from the parent. - ParentClientIDs []string + // The id of the server to connect to, or if blank a new one should be started. + // Needed separately from the client ID as that ID is a digest which could + // be reused across multiple servers. + ServerID string SecretToken string @@ -72,14 +70,8 @@ type Params struct { EngineNameCallback func(string) CloudURLCallback func(string) - - // If this client is for a module function, this digest will be set in the - // grpc context metadata for any api requests back to the engine. It's used by the API - // server to determine which schema to serve and other module context metadata. - ModuleCallerDigest digest.Digest - - EngineTrace sdktrace.SpanExporter - EngineLogs sdklog.LogExporter + EngineTrace sdktrace.SpanExporter + EngineLogs sdklog.LogExporter } type Client struct { @@ -119,12 +111,16 @@ type Client struct { func Connect(ctx context.Context, params Params) (_ *Client, _ context.Context, rerr error) { c := &Client{Params: params} - if c.SecretToken == "" { - c.SecretToken = uuid.New().String() + if c.ID == "" { + c.ID = identity.NewID() } + configuredServerID := c.ServerID if c.ServerID == "" { c.ServerID = identity.NewID() } + if c.SecretToken == "" { + c.SecretToken = uuid.New().String() + } // keep the root ctx around so we can detect whether we've been interrupted, // so we can drain immediately in that scenario @@ -181,7 +177,8 @@ func Connect(ctx context.Context, params Params) (_ *Client, _ context.Context, telemetry.Encapsulate(), } - if c.Params.ModuleCallerDigest != "" { + if configuredServerID != "" { + // infer that this is not a main client caller, server ID is never set for those currently connectSpanOpts = append(connectSpanOpts, telemetry.Internal()) } @@ -320,13 +317,10 @@ func (c *Client) startSession(ctx context.Context) (rerr error) { }() c.internalCtx = engine.ContextWithClientMetadata(c.internalCtx, &engine.ClientMetadata{ - ClientID: c.ID(), - ClientSecretToken: c.SecretToken, - ServerID: c.ServerID, - ClientHostname: c.hostname, - Labels: c.labels, - ParentClientIDs: c.ParentClientIDs, - ModuleCallerDigest: c.ModuleCallerDigest, + ClientID: c.ID, + ClientSecretToken: c.SecretToken, + ClientHostname: c.hostname, + Labels: c.labels, }) // filesync @@ -352,15 +346,13 @@ func (c *Client) startSession(ctx context.Context) (rerr error) { return bkSession.Run(c.internalCtx, func(ctx context.Context, proto string, meta map[string][]string) (net.Conn, error) { return grpchijack.Dialer(c.bkClient.ControlClient())(ctx, proto, engine.ClientMetadata{ RegisterClient: true, - ClientID: c.ID(), - ClientSecretToken: c.SecretToken, + ClientID: c.ID, ServerID: c.ServerID, - ParentClientIDs: c.ParentClientIDs, + ClientSecretToken: c.SecretToken, ClientHostname: hostname, UpstreamCacheImportConfig: c.upstreamCacheImportOptions, UpstreamCacheExportConfig: c.upstreamCacheExportOptions, Labels: c.labels, - ModuleCallerDigest: c.ModuleCallerDigest, CloudToken: os.Getenv("DAGGER_CLOUD_TOKEN"), DoNotTrack: analytics.DoNotTrack(), }.AppendToMD(meta)) @@ -609,10 +601,6 @@ func (c *Client) withClientCloseCancel(ctx context.Context) (context.Context, co return ctx, cancel, nil } -func (c *Client) ID() string { - return c.bkSession.ID() -} - func (c *Client) DialContext(ctx context.Context, _, _ string) (conn net.Conn, err error) { // NOTE: the context given to grpchijack.Dialer is for the lifetime of the stream. // If http connection re-use is enabled, that can be far past this DialContext call. @@ -629,13 +617,11 @@ func (c *Client) DialContext(ctx context.Context, _, _ string) (conn net.Conn, e }).Dial("tcp", "127.0.0.1:"+strconv.Itoa(c.nestedSessionPort)) } else { conn, err = grpchijack.Dialer(c.bkClient.ControlClient())(ctx, "", engine.ClientMetadata{ - ClientID: c.ID(), - ClientSecretToken: c.SecretToken, - ServerID: c.ServerID, - ClientHostname: c.hostname, - ParentClientIDs: c.ParentClientIDs, - Labels: c.labels, - ModuleCallerDigest: c.ModuleCallerDigest, + ClientID: c.ID, + ServerID: c.ServerID, + ClientSecretToken: c.SecretToken, + ClientHostname: c.hostname, + Labels: c.labels, }.ToGRPCMD()) } if err != nil { diff --git a/engine/opts.go b/engine/opts.go index 12f55f6fecf..2e7866b16d9 100644 --- a/engine/opts.go +++ b/engine/opts.go @@ -9,10 +9,10 @@ import ( "os" "path/filepath" "strconv" + "strings" "unicode" controlapi "github.com/moby/buildkit/api/services/control" - "github.com/opencontainers/go-digest" "google.golang.org/grpc/metadata" "github.com/dagger/dagger/telemetry" @@ -33,11 +33,12 @@ const ( ) type ClientMetadata struct { - // ClientID is unique to every session created by every client + // ClientID is unique to each client. The main client's ID is the empty string, + // any module and/or nested exec client's ID is a unique digest. ClientID string `json:"client_id"` // ClientSecretToken is a secret token that is unique to every client. It's - // initially provided to the server in the controller.Solve request. Every + // initially provided to the server in the controller.Session request. Every // other request w/ that client ID must also include the same token. ClientSecretToken string `json:"client_secret_token"` @@ -61,16 +62,6 @@ type ClientMetadata struct { // (Optional) Pipeline labels for e.g. vcs info like branch, commit, etc. Labels telemetry.Labels `json:"labels"` - // ParentClientIDs is a list of session ids that are parents of the current - // session. The first element is the direct parent, the second element is the - // parent of the parent, and so on. - ParentClientIDs []string `json:"parent_client_ids"` - - // If this client is for a module function, this digest will be set in the - // grpc context metadata for any api requests back to the engine. It's used by the API - // server to determine which schema to serve and other module context metadata. - ModuleCallerDigest digest.Digest `json:"module_caller_digest"` - // Import configuration for Buildkit's remote cache UpstreamCacheImportConfig []*controlapi.CacheOptionsEntry @@ -84,11 +75,6 @@ type ClientMetadata struct { DoNotTrack bool } -// ClientIDs returns the ClientID followed by ParentClientIDs. -func (m ClientMetadata) ClientIDs() []string { - return append([]string{m.ClientID}, m.ParentClientIDs...) -} - func (m ClientMetadata) ToGRPCMD() metadata.MD { return encodeMeta(clientMetadataMetaKey, m) } @@ -100,6 +86,15 @@ func (m ClientMetadata) AppendToMD(md metadata.MD) metadata.MD { return md } +// The ID to use for this client's buildkit session. It's a combination of both +// the client and the server IDs to account for the fact that the client ID is +// a content digest for functions/nested-execs, meaning it can reoccur across +// different servers; that doesn't work because buildkit's SessionManager is +// global to the whole process. +func (m ClientMetadata) BuildkitSessionID() string { + return strings.Join([]string{m.ClientID, m.ServerID}, "-") +} + func ContextWithClientMetadata(ctx context.Context, clientMetadata *ClientMetadata) context.Context { return contextWithMD(ctx, clientMetadata.ToGRPCMD()) } diff --git a/engine/server/buildkitcontroller.go b/engine/server/buildkitcontroller.go index 211a547c989..450df5c6e92 100644 --- a/engine/server/buildkitcontroller.go +++ b/engine/server/buildkitcontroller.go @@ -154,7 +154,6 @@ func (e *BuildkitController) Session(stream controlapi.Control_SessionServer) (r ctx = bklog.WithLogger(ctx, bklog.G(ctx). WithField("client_id", opts.ClientID). WithField("client_hostname", opts.ClientHostname). - WithField("client_call_digest", opts.ModuleCallerDigest). WithField("server_id", opts.ServerID)) { @@ -206,6 +205,15 @@ func (e *BuildkitController) Session(stream controlapi.Control_SessionServer) (r eg, egctx := errgroup.WithContext(ctx) eg.Go(func() error { + // overwrite the session ID to be our client ID + server ID + const sessionIDHeader = "x-docker-expose-session-uuid" + if _, ok := hijackmd[sessionIDHeader]; !ok { + // should never happen unless upstream changes the value of the header key, + // in which case we want to know + panic(fmt.Errorf("missing header %s", sessionIDHeader)) + } + hijackmd[sessionIDHeader] = []string{opts.BuildkitSessionID()} + bklog.G(ctx).Trace("session manager handling conn") err := e.SessionManager.HandleConn(egctx, conn, hijackmd) bklog.G(ctx).WithError(err).Trace("session manager handle conn done") @@ -259,6 +267,12 @@ func (e *BuildkitController) Session(stream controlapi.Control_SessionServer) (r if err != nil { return fmt.Errorf("failed to register client: %w", err) } + defer func() { + err := srv.UnregisterClient(opts.ClientID) + if err != nil { + slog.Error("failed to unregister client", "err", err) + } + }() eg.Go(func() error { bklog.G(ctx).Trace("waiting for server") diff --git a/engine/server/server.go b/engine/server/server.go index 9cfe4cad0a7..f7bca64d2dd 100644 --- a/engine/server/server.go +++ b/engine/server/server.go @@ -19,7 +19,6 @@ import ( bkgw "github.com/moby/buildkit/frontend/gateway/client" "github.com/moby/buildkit/session" "github.com/moby/buildkit/util/bklog" - "github.com/opencontainers/go-digest" "github.com/sirupsen/logrus" "github.com/vektah/gqlparser/v2/gqlerror" "go.opentelemetry.io/otel/propagation" @@ -46,10 +45,9 @@ type DaggerServer struct { clientIDMu sync.RWMutex // The metadata of client calls. - // For the special case of the main client caller, the key is just empty string. // This is never explicitly deleted from; instead it will just be garbage collected // when this server for the session shuts down - clientCallContext map[digest.Digest]*core.ClientCallContext + clientCallContext map[string]*core.ClientCallContext clientCallMu *sync.RWMutex // the http endpoints being served (as a map since APIs like shellEndpoint can add more) @@ -74,7 +72,7 @@ func (e *BuildkitController) newDaggerServer(ctx context.Context, clientMetadata serverID: clientMetadata.ServerID, clientIDToSecretToken: map[string]string{}, - clientCallContext: map[digest.Digest]*core.ClientCallContext{}, + clientCallContext: map[string]*core.ClientCallContext{}, clientCallMu: &sync.RWMutex{}, endpoints: map[string]http.Handler{}, endpointMu: &sync.RWMutex{}, @@ -108,7 +106,7 @@ func (e *BuildkitController) newDaggerServer(ctx context.Context, clientMetadata getSessionCtx, getSessionCancel := context.WithTimeout(ctx, 10*time.Second) defer getSessionCancel() - sessionCaller, err := e.SessionManager.Get(getSessionCtx, clientMetadata.ClientID, false) + sessionCaller, err := e.SessionManager.Get(getSessionCtx, clientMetadata.BuildkitSessionID(), false) if err != nil { return nil, fmt.Errorf("get session: %w", err) } @@ -149,21 +147,21 @@ func (e *BuildkitController) newDaggerServer(ctx context.Context, clientMetadata PrivilegedExecEnabled: e.privilegedExecEnabled, UpstreamCacheImports: cacheImporterCfgs, MainClientCaller: sessionCaller, - MainClientCallerID: s.mainClientCallerID, DNSConfig: e.DNSConfig, Frontends: e.Frontends, BuildkitLogSink: e.BuildkitLogSink, }, - Services: s.services, - Platform: core.Platform(e.worker.Platforms(true)[0]), - Secrets: secretStore, - OCIStore: e.worker.ContentStore(), - LeaseManager: e.worker.LeaseManager(), - Auth: authProvider, - ClientCallContext: s.clientCallContext, - ClientCallMu: s.clientCallMu, - Endpoints: s.endpoints, - EndpointMu: s.endpointMu, + Services: s.services, + Platform: core.Platform(e.worker.Platforms(true)[0]), + Secrets: secretStore, + OCIStore: e.worker.ContentStore(), + LeaseManager: e.worker.LeaseManager(), + Auth: authProvider, + ClientCallContext: s.clientCallContext, + ClientCallMu: s.clientCallMu, + MainClientCallerID: s.mainClientCallerID, + Endpoints: s.endpoints, + EndpointMu: s.endpointMu, }) if err != nil { return nil, err @@ -183,7 +181,7 @@ func (e *BuildkitController) newDaggerServer(ctx context.Context, clientMetadata } // the main client caller starts out with the core API loaded - s.clientCallContext[""] = &core.ClientCallContext{ + s.clientCallContext[s.mainClientCallerID] = &core.ClientCallContext{ Deps: root.DefaultDeps, Root: root, } @@ -251,18 +249,21 @@ func (s *DaggerServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - callContext, ok := s.ClientCallContext(clientMetadata.ModuleCallerDigest) + callContext, ok := s.ClientCallContext(clientMetadata.ClientID) if !ok { - errorOut(fmt.Errorf("client call %s not found", clientMetadata.ModuleCallerDigest), http.StatusInternalServerError) + errorOut(fmt.Errorf("client call for %s not found", clientMetadata.ClientID), http.StatusInternalServerError) return } + s.clientCallMu.RLock() schema, err := callContext.Deps.Schema(ctx) if err != nil { + s.clientCallMu.RUnlock() // TODO: technically this is not *always* bad request, should ideally be more specific and differentiate errorOut(err, http.StatusBadRequest) return } + s.clientCallMu.RUnlock() defer func() { if v := recover(); v != nil { @@ -314,12 +315,10 @@ func (s *DaggerServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { slog := slog.With( "isImmediate", immediate, "isMainClient", clientMetadata.ClientID == s.mainClientCallerID, - "isModule", clientMetadata.ModuleCallerDigest != "", "serverID", s.serverID, "traceID", s.traceID, "clientID", clientMetadata.ClientID, - "mainClientID", s.mainClientCallerID, - "callerID", clientMetadata.ModuleCallerDigest) + "mainClientID", s.mainClientCallerID) slog.Trace("shutting down server") defer slog.Trace("done shutting down server") @@ -346,7 +345,7 @@ func (s *DaggerServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } s.clientCallMu.RLock() - bk := s.clientCallContext[""].Root.Buildkit + bk := s.clientCallContext[s.mainClientCallerID].Root.Buildkit s.clientCallMu.RUnlock() err := bk.UpstreamCacheExport(ctx, cacheExporterFuncs) if err != nil { @@ -383,10 +382,6 @@ func (s *DaggerServer) RegisterClient(clientID, clientHostname, secretToken stri return nil } s.clientIDToSecretToken[clientID] = secretToken - // NOTE: we purposely don't delete the secret token, it should never be reused and will be released - // from memory once the dagger server instance corresponding to this buildkit client shuts down. - // Deleting it would make it easier to create race conditions around using the client's session - // before it is fully closed. return nil } @@ -404,10 +399,17 @@ func (s *DaggerServer) VerifyClient(clientID, secretToken string) error { return nil } -func (s *DaggerServer) ClientCallContext(clientDigest digest.Digest) (*core.ClientCallContext, bool) { +func (s *DaggerServer) UnregisterClient(clientID string) error { + s.clientIDMu.Lock() + defer s.clientIDMu.Unlock() + delete(s.clientIDToSecretToken, clientID) + return nil +} + +func (s *DaggerServer) ClientCallContext(clientID string) (*core.ClientCallContext, bool) { s.clientCallMu.RLock() defer s.clientCallMu.RUnlock() - ctx, ok := s.clientCallContext[clientDigest] + ctx, ok := s.clientCallContext[clientID] return ctx, ok } @@ -416,9 +418,9 @@ func (s *DaggerServer) CurrentServedDeps(ctx context.Context) (*core.ModDeps, er if err != nil { return nil, err } - callCtx, ok := s.ClientCallContext(clientMetadata.ModuleCallerDigest) + callCtx, ok := s.ClientCallContext(clientMetadata.ClientID) if !ok { - return nil, fmt.Errorf("client call %s not found", clientMetadata.ModuleCallerDigest) + return nil, fmt.Errorf("client call for %s not found", clientMetadata.ClientID) } return callCtx.Deps, nil } diff --git a/engine/sources/gitdns/identifier.go b/engine/sources/gitdns/identifier.go index ae78e187003..6ba5611ce28 100644 --- a/engine/sources/gitdns/identifier.go +++ b/engine/sources/gitdns/identifier.go @@ -4,10 +4,10 @@ import ( bkgit "github.com/moby/buildkit/source/git" ) -const AttrGitClientIDs = "dagger.git.clientids" +const AttrDNSNamespace = "dagger.dns.namespace" type GitIdentifier struct { bkgit.GitIdentifier - ClientIDs []string + Namespace string } diff --git a/engine/sources/gitdns/source.go b/engine/sources/gitdns/source.go index 31f8ab61e3f..26e1ea508d2 100644 --- a/engine/sources/gitdns/source.go +++ b/engine/sources/gitdns/source.go @@ -78,8 +78,8 @@ func (gs *gitSource) Identifier(scheme, ref string, attrs map[string]string, pla GitIdentifier: *(srcid.(*srcgit.GitIdentifier)), } - if v, ok := attrs[AttrGitClientIDs]; ok { - id.ClientIDs = strings.Split(v, ",") + if v, ok := attrs[AttrDNSNamespace]; ok { + id.Namespace = v } return id, nil @@ -322,10 +322,9 @@ func (gs *gitSourceHandler) mountKnownHosts() (string, func() error, error) { func (gs *gitSourceHandler) dnsConfig() *oci.DNSConfig { clientDomains := []string{} - for _, clientID := range gs.src.ClientIDs { - clientDomains = append(clientDomains, network.ClientDomain(clientID)) + if gs.src.Namespace != "" { + clientDomains = append(clientDomains, network.ClientDomain(gs.src.Namespace)) } - dns := *gs.dns dns.SearchDomains = append(clientDomains, dns.SearchDomains...) return &dns diff --git a/engine/sources/gitdns/state.go b/engine/sources/gitdns/state.go index 25192b3a4bb..b5e6cbc9d19 100644 --- a/engine/sources/gitdns/state.go +++ b/engine/sources/gitdns/state.go @@ -2,7 +2,6 @@ package gitdns import ( "path" - "strings" "github.com/moby/buildkit/client/llb" "github.com/moby/buildkit/solver/pb" @@ -11,11 +10,9 @@ import ( "github.com/pkg/errors" ) -const AttrNetConfig = "gitdns.netconfig" - // Git is a helper mimicking the llb.Git function, but with the ability to // set additional attributes. -func Git(url, ref string, clientIDs []string, opts ...llb.GitOption) llb.State { +func Git(url, ref string, namespace string, opts ...llb.GitOption) llb.State { remote, err := gitutil.ParseURL(url) if errors.Is(err, gitutil.ErrUnknownProtocol) { url = "https://" + url @@ -78,7 +75,7 @@ func Git(url, ref string, clientIDs []string, opts ...llb.GitOption) llb.State { } } - attrs[AttrGitClientIDs] = strings.Join(clientIDs, ",") + attrs[AttrDNSNamespace] = namespace source := llb.NewSource("git://"+id, attrs, gi.Constraints) return llb.NewState(source.Output()) diff --git a/engine/sources/httpdns/identifier.go b/engine/sources/httpdns/identifier.go index 47275d6beff..b30ec9e7c91 100644 --- a/engine/sources/httpdns/identifier.go +++ b/engine/sources/httpdns/identifier.go @@ -4,10 +4,10 @@ import ( bkhttp "github.com/moby/buildkit/source/http" ) -const AttrHTTPClientIDs = "dagger.http.clientids" +const AttrDNSNamespace = "dagger.dns.namespace" type HTTPIdentifier struct { bkhttp.HTTPIdentifier - ClientIDs []string + Namespace string } diff --git a/engine/sources/httpdns/source.go b/engine/sources/httpdns/source.go index be27de8f63c..c51201b2f56 100644 --- a/engine/sources/httpdns/source.go +++ b/engine/sources/httpdns/source.go @@ -76,8 +76,8 @@ func (hs *httpSource) Identifier(scheme, ref string, attrs map[string]string, pl HTTPIdentifier: *(srcid.(*srchttp.HTTPIdentifier)), } - if v, ok := attrs[AttrHTTPClientIDs]; ok { - id.ClientIDs = strings.Split(v, ",") + if v, ok := attrs[AttrDNSNamespace]; ok { + id.Namespace = v } return id, nil @@ -106,8 +106,8 @@ type httpSourceHandler struct { func (hs *httpSourceHandler) client(g session.Group) *http.Client { clientDomains := []string{} - for _, clientID := range hs.src.ClientIDs { - clientDomains = append(clientDomains, network.ClientDomain(clientID)) + if ns := hs.src.Namespace; ns != "" { + clientDomains = append(clientDomains, network.ClientDomain(ns)) } dns := *hs.dns diff --git a/engine/sources/httpdns/state.go b/engine/sources/httpdns/state.go index 458dff3f50a..e980a959509 100644 --- a/engine/sources/httpdns/state.go +++ b/engine/sources/httpdns/state.go @@ -2,17 +2,14 @@ package httpdns import ( "strconv" - "strings" "github.com/moby/buildkit/client/llb" "github.com/moby/buildkit/solver/pb" ) -const AttrNetConfig = "httpdns.netconfig" - // HTTP is a helper mimicking the llb.HTTP function, but with the ability to // set additional attributes. -func HTTP(url string, clientIDs []string, opts ...llb.HTTPOption) llb.State { +func HTTP(url string, namespace string, opts ...llb.HTTPOption) llb.State { hi := &llb.HTTPInfo{} for _, o := range opts { o.SetHTTPOption(hi) @@ -34,7 +31,7 @@ func HTTP(url string, clientIDs []string, opts ...llb.HTTPOption) llb.State { attrs[pb.AttrHTTPGID] = strconv.Itoa(hi.GID) } - attrs[AttrHTTPClientIDs] = strings.Join(clientIDs, ",") + attrs[AttrDNSNamespace] = namespace source := llb.NewSource(url, attrs, hi.Constraints) return llb.NewState(source.Output())