From aaa722fc2877e460c9fabbd32ecd8821280921a6 Mon Sep 17 00:00:00 2001 From: Erik Sipsma Date: Wed, 20 Mar 2024 20:29:22 -0700 Subject: [PATCH 1/5] engine: consolidate IDs and re-use servers for nested execs This is an internal only refactor, though it fixes a few bugs while also simplifying quite a bit and setting us up for more simplifications soon. The biggest change is that nested execs connect back to the same server as the main client caller rather than being completely independent. * This is required for the fix to services used in modules (separate PR) to fully work * It also should fix the lack of docker auth in many of our integ tests, specifically those that use nested execs, which leads to dockerhub rate limiting Along the way it also does some consolidation of IDs, removing ModuleCallerDigest and just exclusively using ClientID. This requires that we tell module functions and other nested execs which ID to use, but that itself is setup for even more simplifications in follow-ups (we can remove the need for the current DaggerServer construct entirely, among other things). Signed-off-by: Erik Sipsma --- cmd/shim/main.go | 29 ++++---- core/container.go | 42 ++++------- core/interface.go | 2 +- core/modfunc.go | 74 ++++++------------- core/module.go | 2 +- core/object.go | 9 ++- core/query.go | 108 +++++++++++++++++----------- core/schema/host.go | 2 +- core/schema/module.go | 2 +- core/typedef.go | 11 +++ engine/buildkit/client.go | 18 ++--- engine/buildkit/containerimage.go | 2 +- engine/buildkit/filesync.go | 24 ++----- engine/buildkit/session.go | 17 ++++- engine/client/client.go | 65 ++++++++--------- engine/opts.go | 21 +++--- engine/server/buildkitcontroller.go | 10 ++- engine/server/server.go | 53 +++++++------- 18 files changed, 238 insertions(+), 253 deletions(-) diff --git a/cmd/shim/main.go b/cmd/shim/main.go index ae01fc8d9d..22943f5c45 100644 --- a/cmd/shim/main.go +++ b/cmd/shim/main.go @@ -24,7 +24,6 @@ import ( "github.com/cenkalti/backoff/v4" "github.com/containerd/console" "github.com/google/uuid" - "github.com/opencontainers/go-digest" "github.com/opencontainers/runtime-spec/specs-go" "go.opentelemetry.io/otel/propagation" "go.opentelemetry.io/otel/sdk/resource" @@ -543,7 +542,7 @@ func setupBundle() (returnExitCode int) { keepEnv := []string{} for _, env := range spec.Process.Env { switch { - case strings.HasPrefix(env, "_DAGGER_ENABLE_NESTING="): + case strings.HasPrefix(env, "_DAGGER_NESTED_CLIENT_ID="): // keep the env var; we use it at runtime keepEnv = append(keepEnv, env) @@ -562,6 +561,9 @@ func setupBundle() (returnExitCode int) { Source: "/run/buildkit/buildkitd.sock", }) case strings.HasPrefix(env, "_DAGGER_SERVER_ID="): + case strings.HasPrefix(env, "_DAGGER_ENGINE_VERSION="): + // don't need this at runtime, it is just for invalidating cache, which + // has already happened by now case strings.HasPrefix(env, aliasPrefix): // NB: don't keep this env var, it's only for the bundling step // keepEnv = append(keepEnv, env) @@ -715,7 +717,8 @@ func internalEnv(name string) (string, bool) { } func runWithNesting(ctx context.Context, cmd *exec.Cmd) error { - if _, found := internalEnv("_DAGGER_ENABLE_NESTING"); !found { + clientID, ok := internalEnv("_DAGGER_NESTED_CLIENT_ID") + if !ok { // no nesting; run as normal return execProcess(cmd, true) } @@ -733,27 +736,21 @@ func runWithNesting(ctx context.Context, cmd *exec.Cmd) error { } sessionPort := l.Addr().(*net.TCPAddr).Port + serverID, ok := internalEnv("_DAGGER_SERVER_ID") + if !ok { + return errors.New("missing nested client server ID") + } + parentClientIDsVal, _ := internalEnv("_DAGGER_PARENT_CLIENT_IDS") clientParams := client.Params{ + ID: clientID, + ServerID: serverID, SecretToken: sessionToken.String(), RunnerHost: "unix:///.runner.sock", ParentClientIDs: strings.Fields(parentClientIDsVal), } - if _, ok := internalEnv("_DAGGER_ENABLE_NESTING_IN_SAME_SESSION"); ok { - serverID, ok := internalEnv("_DAGGER_SERVER_ID") - if !ok { - return fmt.Errorf("missing _DAGGER_SERVER_ID") - } - clientParams.ServerID = serverID - } - - moduleCallerDigest, ok := internalEnv("_DAGGER_MODULE_CALLER_DIGEST") - if ok { - clientParams.ModuleCallerDigest = digest.Digest(moduleCallerDigest) - } - sess, ctx, err := client.Connect(ctx, clientParams) if err != nil { return fmt.Errorf("error connecting to engine: %w", err) diff --git a/core/container.go b/core/container.go index 8abecd4a3e..7683b57ac0 100644 --- a/core/container.go +++ b/core/container.go @@ -1007,16 +1007,15 @@ func (container *Container) WithExec(ctx context.Context, opts ContainerExecOpts // this allows executed containers to communicate back to this API if opts.ExperimentalPrivilegedNesting { - // include the engine version so that these execs get invalidated if the engine/API change - runOpts = append(runOpts, llb.AddEnv("_DAGGER_ENABLE_NESTING", engine.Version)) - } - - if opts.ModuleCallerDigest != "" { - runOpts = append(runOpts, llb.AddEnv("_DAGGER_MODULE_CALLER_DIGEST", opts.ModuleCallerDigest.String())) - } - - if opts.NestedInSameSession { - runOpts = append(runOpts, llb.AddEnv("_DAGGER_ENABLE_NESTING_IN_SAME_SESSION", "")) + clientID, err := container.Query.RegisterCaller(ctx, opts.NestedExecFunctionCall) + if err != nil { + return nil, fmt.Errorf("register caller: %w", err) + } + runOpts = append(runOpts, + llb.AddEnv("_DAGGER_NESTED_CLIENT_ID", clientID), + // include the engine version so that these execs get invalidated if the engine/API change + llb.AddEnv("_DAGGER_ENGINE_VERSION", engine.Version), + ) } metaSt, metaSourcePath := metaMount(opts.Stdin) @@ -1057,13 +1056,7 @@ func (container *Container) WithExec(ctx context.Context, opts ContainerExecOpts } // don't pass these through to the container when manually set, they are internal only - if name == "_DAGGER_ENABLE_NESTING" && !opts.ExperimentalPrivilegedNesting { - continue - } - if name == "_DAGGER_MODULE_CALLER_DIGEST" && opts.ModuleCallerDigest == "" { - continue - } - if name == "_DAGGER_ENABLE_NESTING_IN_SAME_SESSION" && !opts.NestedInSameSession { + if name == "_DAGGER_NESTED_CLIENT_ID" && !opts.ExperimentalPrivilegedNesting { continue } @@ -1784,22 +1777,15 @@ type ContainerExecOpts struct { // Redirect the command's standard error to a file in the container RedirectStderr string `default:""` - // Provide dagger access to the executed command - // Do not use this option unless you trust the command being executed. - // The command being executed WILL BE GRANTED FULL ACCESS TO YOUR HOST FILESYSTEM + // Provide the executed command access back to the Dagger API ExperimentalPrivilegedNesting bool `default:"false"` // Grant the process all root capabilities InsecureRootCapabilities bool `default:"false"` - // (Internal-only) If this exec is for a module function, this digest will be set in the - // grpc context metadata for any api requests back to the engine. It's used by the API - // server to determine which schema to serve and other module context metadata. - ModuleCallerDigest digest.Digest `name:"-"` - - // (Internal-only) Used for module function execs to trigger the nested api client to - // be connected back to the same session. - NestedInSameSession bool `name:"-"` + // (Internal-only) If this is a nested exec for a Function call, this should be set + // with the metadata for that call + NestedExecFunctionCall *FunctionCall `name:"-"` } type BuildArg struct { diff --git a/core/interface.go b/core/interface.go index 9eb3d604fc..8cdca1ed13 100644 --- a/core/interface.go +++ b/core/interface.go @@ -222,7 +222,7 @@ func (iface *InterfaceType) Install(ctx context.Context, dag *dagql.Server) erro }) } - res, err := callable.Call(ctx, dagql.CurrentID(ctx), &CallOpts{ + res, err := callable.Call(ctx, &CallOpts{ Inputs: callInputs, ParentVal: runtimeVal.Fields, }) diff --git a/core/modfunc.go b/core/modfunc.go index dba748d2a9..c75e539cbf 100644 --- a/core/modfunc.go +++ b/core/modfunc.go @@ -17,7 +17,6 @@ import ( "github.com/dagger/dagger/core/pipeline" "github.com/dagger/dagger/dagql" "github.com/dagger/dagger/dagql/call" - "github.com/dagger/dagger/engine" "github.com/dagger/dagger/engine/buildkit" ) @@ -113,7 +112,7 @@ func (fn *ModuleFunction) recordCall(ctx context.Context) { analytics.Ctx(ctx).Capture(ctx, "module_call", props) } -func (fn *ModuleFunction) Call(ctx context.Context, caller *call.ID, opts *CallOpts) (t dagql.Typed, rerr error) { +func (fn *ModuleFunction) Call(ctx context.Context, opts *CallOpts) (t dagql.Typed, rerr error) { mod := fn.mod lg := bklog.G(ctx).WithField("module", mod.Name()).WithField("function", fn.metadata.Name) @@ -166,23 +165,6 @@ func (fn *ModuleFunction) Call(ctx context.Context, caller *call.ID, opts *CallO }) } - callerDigestInputs := []string{} - { - callerIDDigest := caller.Digest() // FIXME(vito) canonicalize, once all that's implemented - callerDigestInputs = append(callerDigestInputs, callerIDDigest.String()) - } - if !opts.Cache { - // use the ServerID so that we bust cache once-per-session - clientMetadata, err := engine.ClientMetadataFromContext(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get client metadata: %w", err) - } - callerDigestInputs = append(callerDigestInputs, clientMetadata.ServerID) - } - - callerDigest := digest.FromString(strings.Join(callerDigestInputs, " ")) - - ctx = bklog.WithLogger(ctx, bklog.G(ctx).WithField("caller_digest", callerDigest.String())) bklog.G(ctx).Debug("function call") defer func() { bklog.G(ctx).Debug("function call done") @@ -191,10 +173,28 @@ func (fn *ModuleFunction) Call(ctx context.Context, caller *call.ID, opts *CallO } }() + parentJSON, err := json.Marshal(opts.ParentVal) + if err != nil { + return nil, fmt.Errorf("failed to marshal parent value: %w", err) + } + + callMeta := &FunctionCall{ + Query: fn.root, + Name: fn.metadata.OriginalName, + Parent: parentJSON, + InputArgs: callInputs, + Module: mod, + Cache: opts.Cache, + SkipSelfSchema: opts.SkipSelfSchema, + } + if fn.objDef != nil { + callMeta.ParentName = fn.objDef.OriginalName + } + ctr := fn.runtime metaDir := NewScratchDirectory(mod.Query, mod.Query.Platform) - ctr, err := ctr.WithMountedDirectory(ctx, modMetaDirPath, metaDir, "", false) + ctr, err = ctr.WithMountedDirectory(ctx, modMetaDirPath, metaDir, "", false) if err != nil { return nil, fmt.Errorf("failed to mount mod metadata directory: %w", err) } @@ -219,45 +219,13 @@ func (fn *ModuleFunction) Call(ctx context.Context, caller *call.ID, opts *CallO // Setup the Exec for the Function call and evaluate it ctr, err = ctr.WithExec(ctx, ContainerExecOpts{ - ModuleCallerDigest: callerDigest, ExperimentalPrivilegedNesting: true, - NestedInSameSession: true, + NestedExecFunctionCall: callMeta, }) if err != nil { return nil, fmt.Errorf("failed to exec function: %w", err) } - parentJSON, err := json.Marshal(opts.ParentVal) - if err != nil { - return nil, fmt.Errorf("failed to marshal parent value: %w", err) - } - - callMeta := &FunctionCall{ - Query: fn.root, - Name: fn.metadata.OriginalName, - Parent: parentJSON, - InputArgs: callInputs, - } - if fn.objDef != nil { - callMeta.ParentName = fn.objDef.OriginalName - } - - var deps *ModDeps - if opts.SkipSelfSchema { - // Only serve the APIs of the deps of this module. This is currently only needed for the special - // case of the function used to get the definition of the module itself (which can't obviously - // be served the API its returning the definition of). - deps = mod.Deps - } else { - // by default, serve both deps and the module's own API to itself - deps = mod.Deps.Prepend(mod) - } - - err = mod.Query.RegisterFunctionCall(ctx, callerDigest, deps, fn.mod, callMeta) - if err != nil { - return nil, fmt.Errorf("failed to register function call: %w", err) - } - _, err = ctr.Evaluate(ctx) if err != nil { if fn.metadata.OriginalName == "" { diff --git a/core/module.go b/core/module.go index b4e28b07cf..3a84fb89a4 100644 --- a/core/module.go +++ b/core/module.go @@ -138,7 +138,7 @@ func (mod *Module) Initialize(ctx context.Context, oldSelf dagql.Instance[*Modul return nil, fmt.Errorf("failed to create module definition function for module %q: %w", mod.Name(), err) } - result, err := getModDefFn.Call(ctx, newID, &CallOpts{Cache: true, SkipSelfSchema: true}) + result, err := getModDefFn.Call(ctx, &CallOpts{Cache: true, SkipSelfSchema: true}) if err != nil { return nil, fmt.Errorf("failed to call module %q to get functions: %w", mod.Name(), err) } diff --git a/core/object.go b/core/object.go index d2c1009cdf..80a81ca629 100644 --- a/core/object.go +++ b/core/object.go @@ -9,7 +9,6 @@ import ( "github.com/vektah/gqlparser/v2/ast" "github.com/dagger/dagger/dagql" - "github.com/dagger/dagger/dagql/call" "github.com/dagger/dagger/engine/slog" ) @@ -85,7 +84,7 @@ func (t *ModuleObjectType) TypeDef() *TypeDef { } type Callable interface { - Call(context.Context, *call.ID, *CallOpts) (dagql.Typed, error) + Call(context.Context, *CallOpts) (dagql.Typed, error) ReturnType() (ModType, error) ArgType(argName string) (ModType, error) } @@ -262,7 +261,7 @@ func (obj *ModuleObject) installConstructor(ctx context.Context, dag *dagql.Serv Value: v, }) } - return fn.Call(ctx, dagql.CurrentID(ctx), &CallOpts{ + return fn.Call(ctx, &CallOpts{ Inputs: callInput, ParentVal: nil, }) @@ -359,7 +358,7 @@ func objFun(ctx context.Context, mod *Module, objDef *ObjectTypeDef, fun *Functi sort.Slice(opts.Inputs, func(i, j int) bool { return opts.Inputs[i].Name < opts.Inputs[j].Name }) - return modFun.Call(ctx, dagql.CurrentID(ctx), opts) + return modFun.Call(ctx, opts) }, }, nil } @@ -370,7 +369,7 @@ type CallableField struct { Return ModType } -func (f *CallableField) Call(ctx context.Context, id *call.ID, opts *CallOpts) (dagql.Typed, error) { +func (f *CallableField) Call(ctx context.Context, opts *CallOpts) (dagql.Typed, error) { val, ok := opts.ParentVal[f.Field.OriginalName] if !ok { return nil, fmt.Errorf("field %q not found on object %q", f.Field.Name, opts.ParentVal) diff --git a/core/query.go b/core/query.go index 9a9f7a5d24..09b9d7bc18 100644 --- a/core/query.go +++ b/core/query.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "strings" "sync" "github.com/containerd/containerd/content" @@ -57,8 +58,9 @@ type QueryOpts struct { // For the special case of the main client caller, the key is just empty string. // This is never explicitly deleted from; instead it will just be garbage collected // when this server for the session shuts down - ClientCallContext map[digest.Digest]*ClientCallContext - ClientCallMu *sync.RWMutex + ClientCallContext map[string]*ClientCallContext + ClientCallMu *sync.RWMutex + MainClientCallerID string // the http endpoints being served (as a map since APIs like shellEndpoint can add more) Endpoints map[string]http.Handler @@ -107,24 +109,18 @@ type ClientCallContext struct { // If the client is itself from a function call in a user module, these are set with the // metadata of that ongoing function call - Module *Module FnCall *FunctionCall } -func (q *Query) ServeModuleToMainClient(ctx context.Context, modMeta dagql.Instance[*Module]) error { +func (q *Query) ServeModule(ctx context.Context, mod *Module) error { clientMetadata, err := engine.ClientMetadataFromContext(ctx) if err != nil { return err } - if clientMetadata.ModuleCallerDigest != "" { - return fmt.Errorf("cannot serve module to client %s", clientMetadata.ClientID) - } - - mod := modMeta.Self q.ClientCallMu.Lock() defer q.ClientCallMu.Unlock() - callCtx, ok := q.ClientCallContext[""] + callCtx, ok := q.ClientCallContext[clientMetadata.ClientID] if !ok { return fmt.Errorf("client call not found") } @@ -132,34 +128,63 @@ func (q *Query) ServeModuleToMainClient(ctx context.Context, modMeta dagql.Insta return nil } -func (q *Query) RegisterFunctionCall( - ctx context.Context, - dgst digest.Digest, - deps *ModDeps, - mod *Module, - call *FunctionCall, -) error { - if dgst == "" { - return fmt.Errorf("cannot register function call with empty digest") +func (q *Query) RegisterCaller(ctx context.Context, call *FunctionCall) (string, error) { + if call == nil { + call = &FunctionCall{} + } + callCtx := &ClientCallContext{ + FnCall: call, + } + + currentID := dagql.CurrentID(ctx) + clientIDInputs := []string{currentID.Digest().String()} + if !call.Cache { + // use the ServerID so that we bust cache once-per-session + clientMetadata, err := engine.ClientMetadataFromContext(ctx) + if err != nil { + return "", err + } + clientIDInputs = append(clientIDInputs, clientMetadata.ServerID) + } + clientIDDigest := digest.FromString(strings.Join(clientIDInputs, " ")) + + // only use encoded part of digest because this ID ends up becoming a buildkit Session ID + // and buildkit has some ancient internal logic that splits on a colon to support some + // dev mode logic: https://github.com/moby/buildkit/pull/290 + // also trim it to 25 chars as it ends up becoming part of service URLs + clientID := clientIDDigest.Encoded()[:25] + + // break glass for debugging which client is which operation + // bklog.G(ctx).Debugf("CLIENT ID %s = %s", clientID, currentID.Display()) + + if call.Module == nil { + callCtx.Deps = q.DefaultDeps + } else { + callCtx.Deps = call.Module.Deps + // By default, serve both deps and the module's own API to itself. But if SkipSelfSchema is set, + // only serve the APIs of the deps of this module. This is currently only needed for the special + // case of the function used to get the definition of the module itself (which can't obviously + // be served the API its returning the definition of). + if !call.SkipSelfSchema { + callCtx.Deps = callCtx.Deps.Append(call.Module) + } } q.ClientCallMu.Lock() defer q.ClientCallMu.Unlock() - _, ok := q.ClientCallContext[dgst] + _, ok := q.ClientCallContext[clientID] if ok { - return nil + return clientID, nil } - newRoot, err := NewRoot(ctx, q.QueryOpts) + + var err error + callCtx.Root, err = NewRoot(ctx, q.QueryOpts) if err != nil { - return err - } - q.ClientCallContext[dgst] = &ClientCallContext{ - Root: newRoot, - Deps: deps, - Module: mod, - FnCall: call, + return "", err } - return nil + + q.ClientCallContext[clientID] = callCtx + return clientID, nil } func (q *Query) CurrentModule(ctx context.Context) (*Module, error) { @@ -167,17 +192,20 @@ func (q *Query) CurrentModule(ctx context.Context) (*Module, error) { if err != nil { return nil, err } - if clientMetadata.ModuleCallerDigest == "" { - return nil, fmt.Errorf("%w: main client caller has no module", ErrNoCurrentModule) + if clientMetadata.ClientID == q.MainClientCallerID { + return nil, fmt.Errorf("%w: main client caller has no current module", ErrNoCurrentModule) } q.ClientCallMu.RLock() defer q.ClientCallMu.RUnlock() - callCtx, ok := q.ClientCallContext[clientMetadata.ModuleCallerDigest] + callCtx, ok := q.ClientCallContext[clientMetadata.ClientID] if !ok { - return nil, fmt.Errorf("client call %s not found", clientMetadata.ModuleCallerDigest) + return nil, fmt.Errorf("client call %s not found", clientMetadata.ClientID) + } + if callCtx.FnCall.Module == nil { + return nil, ErrNoCurrentModule } - return callCtx.Module, nil + return callCtx.FnCall.Module, nil } func (q *Query) CurrentFunctionCall(ctx context.Context) (*FunctionCall, error) { @@ -185,15 +213,15 @@ func (q *Query) CurrentFunctionCall(ctx context.Context) (*FunctionCall, error) if err != nil { return nil, err } - if clientMetadata.ModuleCallerDigest == "" { + if clientMetadata.ClientID == q.MainClientCallerID { return nil, fmt.Errorf("%w: main client caller has no function", ErrNoCurrentModule) } q.ClientCallMu.RLock() defer q.ClientCallMu.RUnlock() - callCtx, ok := q.ClientCallContext[clientMetadata.ModuleCallerDigest] + callCtx, ok := q.ClientCallContext[clientMetadata.ClientID] if !ok { - return nil, fmt.Errorf("client call %s not found", clientMetadata.ModuleCallerDigest) + return nil, fmt.Errorf("client call %s not found", clientMetadata.ClientID) } return callCtx.FnCall, nil @@ -204,9 +232,9 @@ func (q *Query) CurrentServedDeps(ctx context.Context) (*ModDeps, error) { if err != nil { return nil, err } - callCtx, ok := q.ClientCallContext[clientMetadata.ModuleCallerDigest] + callCtx, ok := q.ClientCallContext[clientMetadata.ClientID] if !ok { - return nil, fmt.Errorf("client call %s not found", clientMetadata.ModuleCallerDigest) + return nil, fmt.Errorf("client call %s not found", clientMetadata.ClientID) } return callCtx.Deps, nil } diff --git a/core/schema/host.go b/core/schema/host.go index 28e9c3b3bd..6d9246c03a 100644 --- a/core/schema/host.go +++ b/core/schema/host.go @@ -207,7 +207,7 @@ func (s *hostSchema) socket(ctx context.Context, host *core.Host, args hostSocke if err != nil { return nil, fmt.Errorf("failed to get client metadata: %w", err) } - if clientMetadata.ClientID != host.Query.Buildkit.MainClientCallerID { + if clientMetadata.ClientID != host.Query.MainClientCallerID { return nil, fmt.Errorf("only the main client can access the host's unix sockets") } diff --git a/core/schema/module.go b/core/schema/module.go index 49d8d266fe..dbbc2a1c23 100644 --- a/core/schema/module.go +++ b/core/schema/module.go @@ -465,7 +465,7 @@ func (s *moduleSchema) currentFunctionCall(ctx context.Context, self *core.Query } func (s *moduleSchema) moduleServe(ctx context.Context, modMeta dagql.Instance[*core.Module], _ struct{}) (dagql.Nullable[core.Void], error) { - return dagql.Null[core.Void](), modMeta.Self.Query.ServeModuleToMainClient(ctx, modMeta) + return dagql.Null[core.Void](), modMeta.Self.Query.ServeModule(ctx, modMeta.Self) } func (s *moduleSchema) currentTypeDefs(ctx context.Context, self *core.Query, _ struct{}) ([]*core.TypeDef, error) { diff --git a/core/typedef.go b/core/typedef.go index 781edbb15c..3b67e9689c 100644 --- a/core/typedef.go +++ b/core/typedef.go @@ -849,6 +849,17 @@ type FunctionCall struct { ParentName string `field:"true" doc:"The name of the parent object of the function being called. If the function is top-level to the module, this is the name of the module."` Parent JSON `field:"true" doc:"The value of the parent object of the function being called. If the function is top-level to the module, this is always an empty object."` InputArgs []*FunctionCallArgValue `field:"true" doc:"The argument values the function is being invoked with."` + + // Below are not in public API + + // The module that the function is being called from + Module *Module + + // Whether the function call should be cached across different servers + Cache bool + + // Whether to serve the schema for the function's own module to it or not + SkipSelfSchema bool } func (*FunctionCall) Type() *ast.Type { diff --git a/engine/buildkit/client.go b/engine/buildkit/client.go index 04da992f04..a781be3181 100644 --- a/engine/buildkit/client.go +++ b/engine/buildkit/client.go @@ -40,7 +40,6 @@ import ( "google.golang.org/grpc/metadata" "github.com/dagger/dagger/auth" - "github.com/dagger/dagger/engine" "github.com/dagger/dagger/engine/session" ) @@ -63,11 +62,10 @@ type Opts struct { // client. It is special in that when it shuts down, the client will be closed and // that registry auth and sockets are currently only ever sourced from this caller, // not any nested clients (may change in future). - MainClientCaller bksession.Caller - MainClientCallerID string - DNSConfig *oci.DNSConfig - Frontends map[string]bkfrontend.Frontend - BuildkitLogSink io.Writer + MainClientCaller bksession.Caller + DNSConfig *oci.DNSConfig + Frontends map[string]bkfrontend.Frontend + BuildkitLogSink io.Writer sharedClientState } @@ -577,13 +575,7 @@ func (c *Client) ListenHostToContainer( return nil, nil, err } - clientMetadata, err := engine.ClientMetadataFromContext(ctx) - if err != nil { - cancel() - return nil, nil, fmt.Errorf("failed to get requester session ID: %w", err) - } - - clientCaller, err := c.SessionManager.Get(ctx, clientMetadata.ClientID, false) + clientCaller, err := c.GetSessionCaller(ctx, false) if err != nil { cancel() return nil, nil, fmt.Errorf("failed to get requester session: %w", err) diff --git a/engine/buildkit/containerimage.go b/engine/buildkit/containerimage.go index a9ec7920d2..78a5f1728c 100644 --- a/engine/buildkit/containerimage.go +++ b/engine/buildkit/containerimage.go @@ -108,7 +108,7 @@ func (c *Client) ExportContainerImage( IsFileStream: true, }.AppendToOutgoingContext(ctx) - resp, descRef, err := expInstance.Export(ctx, combinedResult, nil, clientMetadata.ClientID) + resp, descRef, err := expInstance.Export(ctx, combinedResult, nil, clientMetadata.BuildkitSessionID()) if err != nil { return nil, fmt.Errorf("failed to export: %w", err) } diff --git a/engine/buildkit/filesync.go b/engine/buildkit/filesync.go index c45b5b75e5..7606f91f23 100644 --- a/engine/buildkit/filesync.go +++ b/engine/buildkit/filesync.go @@ -45,7 +45,7 @@ func (c *Client) LocalImport( } localOpts := []llb.LocalOption{ - llb.SessionID(clientMetadata.ClientID), + llb.SessionID(clientMetadata.BuildkitSessionID()), llb.SharedKeyHint(strings.Join([]string{clientMetadata.ClientHostname, srcPath}, " ")), } @@ -108,13 +108,9 @@ func (c *Client) diffcopy(ctx context.Context, opts engine.LocalImportOpts, msg } defer cancel() - clientMetadata, err := engine.ClientMetadataFromContext(ctx) - if err != nil { - return fmt.Errorf("failed to get requester session ID: %w", err) - } ctx = opts.AppendToOutgoingContext(ctx) - clientCaller, err := c.SessionManager.Get(ctx, clientMetadata.ClientID, false) + clientCaller, err := c.GetSessionCaller(ctx, true) if err != nil { return fmt.Errorf("failed to get requester session: %w", err) } @@ -213,7 +209,7 @@ func (c *Client) LocalDirExport( Merge: merge, }.AppendToOutgoingContext(ctx) - _, descRef, err := expInstance.Export(ctx, cacheRes, nil, clientMetadata.ClientID) + _, descRef, err := expInstance.Export(ctx, cacheRes, nil, clientMetadata.BuildkitSessionID()) if err != nil { return fmt.Errorf("failed to export: %w", err) } @@ -288,11 +284,6 @@ func (c *Client) LocalFileExport( return fmt.Errorf("failed to stat file: %w", err) } - clientMetadata, err := engine.ClientMetadataFromContext(ctx) - if err != nil { - return fmt.Errorf("failed to get requester session ID: %w", err) - } - ctx = engine.LocalExportOpts{ Path: destPath, IsFileStream: true, @@ -301,7 +292,7 @@ func (c *Client) LocalFileExport( FileMode: stat.Mode().Perm(), }.AppendToOutgoingContext(ctx) - clientCaller, err := c.SessionManager.Get(ctx, clientMetadata.ClientID, false) + clientCaller, err := c.GetSessionCaller(ctx, true) if err != nil { return fmt.Errorf("failed to get requester session: %w", err) } @@ -357,11 +348,6 @@ func (c *Client) IOReaderExport(ctx context.Context, r io.Reader, destPath strin lg.Trace("finished exporting bytes") }() - clientMetadata, err := engine.ClientMetadataFromContext(ctx) - if err != nil { - return fmt.Errorf("failed to get requester session ID: %w", err) - } - ctx = engine.LocalExportOpts{ Path: destPath, IsFileStream: true, @@ -369,7 +355,7 @@ func (c *Client) IOReaderExport(ctx context.Context, r io.Reader, destPath strin FileMode: destMode, }.AppendToOutgoingContext(ctx) - clientCaller, err := c.SessionManager.Get(ctx, clientMetadata.ClientID, false) + clientCaller, err := c.GetSessionCaller(ctx, true) if err != nil { return fmt.Errorf("failed to get requester session: %w", err) } diff --git a/engine/buildkit/session.go b/engine/buildkit/session.go index 7ba27c76c6..694dcddaa3 100644 --- a/engine/buildkit/session.go +++ b/engine/buildkit/session.go @@ -15,6 +15,7 @@ import ( "github.com/moby/buildkit/session/secrets/secretsprovider" "github.com/moby/buildkit/util/bklog" + "github.com/dagger/dagger/engine" "github.com/dagger/dagger/engine/client" "github.com/dagger/dagger/engine/distconsts" ) @@ -85,7 +86,17 @@ func (c *Client) newSession() (*bksession.Session, error) { return sess, nil } -func (c *Client) GetSessionCaller(ctx context.Context, clientID string) (bksession.Caller, error) { - waitForSession := true - return c.SessionManager.Get(ctx, clientID, !waitForSession) +func (c *Client) GetSessionCaller(ctx context.Context, wait bool) (bksession.Caller, error) { + clientMetadata, err := engine.ClientMetadataFromContext(ctx) + if err != nil { + return nil, err + } + caller, err := c.SessionManager.Get(ctx, clientMetadata.BuildkitSessionID(), !wait) + if err != nil { + return nil, err + } + if caller == nil { + return nil, fmt.Errorf("session for %q not found", clientMetadata.BuildkitSessionID()) + } + return caller, nil } diff --git a/engine/client/client.go b/engine/client/client.go index 64b731f24d..9baa9aef79 100644 --- a/engine/client/client.go +++ b/engine/client/client.go @@ -31,7 +31,6 @@ import ( "github.com/moby/buildkit/session/filesync" "github.com/moby/buildkit/session/grpchijack" "github.com/moby/buildkit/util/grpcerrors" - "github.com/opencontainers/go-digest" "github.com/tonistiigi/fsutil" fstypes "github.com/tonistiigi/fsutil/types" sdktrace "go.opentelemetry.io/otel/sdk/trace" @@ -53,8 +52,13 @@ import ( ) type Params struct { - // The id of the server to connect to, or if blank a new one - // should be started. + // The id to connect to the API server with. If blank, will be set to a + // new random value. + ID string + + // The id of the server to connect to, or if blank a new one should be started. + // Needed separately from the client ID as that ID is a digest which could + // be reused across multiple servers. ServerID string // Parent client IDs of this Dagger client. @@ -72,14 +76,8 @@ type Params struct { EngineNameCallback func(string) CloudURLCallback func(string) - - // If this client is for a module function, this digest will be set in the - // grpc context metadata for any api requests back to the engine. It's used by the API - // server to determine which schema to serve and other module context metadata. - ModuleCallerDigest digest.Digest - - EngineTrace sdktrace.SpanExporter - EngineLogs sdklog.LogExporter + EngineTrace sdktrace.SpanExporter + EngineLogs sdklog.LogExporter } type Client struct { @@ -119,12 +117,16 @@ type Client struct { func Connect(ctx context.Context, params Params) (_ *Client, _ context.Context, rerr error) { c := &Client{Params: params} - if c.SecretToken == "" { - c.SecretToken = uuid.New().String() + if c.ID == "" { + c.ID = identity.NewID() } + configuredServerID := c.ServerID if c.ServerID == "" { c.ServerID = identity.NewID() } + if c.SecretToken == "" { + c.SecretToken = uuid.New().String() + } // keep the root ctx around so we can detect whether we've been interrupted, // so we can drain immediately in that scenario @@ -181,7 +183,8 @@ func Connect(ctx context.Context, params Params) (_ *Client, _ context.Context, telemetry.Encapsulate(), } - if c.Params.ModuleCallerDigest != "" { + if configuredServerID != "" { + // infer that this is not a main client caller, server ID is never set for those currently connectSpanOpts = append(connectSpanOpts, telemetry.Internal()) } @@ -320,13 +323,11 @@ func (c *Client) startSession(ctx context.Context) (rerr error) { }() c.internalCtx = engine.ContextWithClientMetadata(c.internalCtx, &engine.ClientMetadata{ - ClientID: c.ID(), - ClientSecretToken: c.SecretToken, - ServerID: c.ServerID, - ClientHostname: c.hostname, - Labels: c.labels, - ParentClientIDs: c.ParentClientIDs, - ModuleCallerDigest: c.ModuleCallerDigest, + ClientID: c.ID, + ClientSecretToken: c.SecretToken, + ClientHostname: c.hostname, + Labels: c.labels, + ParentClientIDs: c.ParentClientIDs, }) // filesync @@ -352,15 +353,14 @@ func (c *Client) startSession(ctx context.Context) (rerr error) { return bkSession.Run(c.internalCtx, func(ctx context.Context, proto string, meta map[string][]string) (net.Conn, error) { return grpchijack.Dialer(c.bkClient.ControlClient())(ctx, proto, engine.ClientMetadata{ RegisterClient: true, - ClientID: c.ID(), - ClientSecretToken: c.SecretToken, + ClientID: c.ID, ServerID: c.ServerID, + ClientSecretToken: c.SecretToken, ParentClientIDs: c.ParentClientIDs, ClientHostname: hostname, UpstreamCacheImportConfig: c.upstreamCacheImportOptions, UpstreamCacheExportConfig: c.upstreamCacheExportOptions, Labels: c.labels, - ModuleCallerDigest: c.ModuleCallerDigest, CloudToken: os.Getenv("DAGGER_CLOUD_TOKEN"), DoNotTrack: analytics.DoNotTrack(), }.AppendToMD(meta)) @@ -609,10 +609,6 @@ func (c *Client) withClientCloseCancel(ctx context.Context) (context.Context, co return ctx, cancel, nil } -func (c *Client) ID() string { - return c.bkSession.ID() -} - func (c *Client) DialContext(ctx context.Context, _, _ string) (conn net.Conn, err error) { // NOTE: the context given to grpchijack.Dialer is for the lifetime of the stream. // If http connection re-use is enabled, that can be far past this DialContext call. @@ -629,13 +625,12 @@ func (c *Client) DialContext(ctx context.Context, _, _ string) (conn net.Conn, e }).Dial("tcp", "127.0.0.1:"+strconv.Itoa(c.nestedSessionPort)) } else { conn, err = grpchijack.Dialer(c.bkClient.ControlClient())(ctx, "", engine.ClientMetadata{ - ClientID: c.ID(), - ClientSecretToken: c.SecretToken, - ServerID: c.ServerID, - ClientHostname: c.hostname, - ParentClientIDs: c.ParentClientIDs, - Labels: c.labels, - ModuleCallerDigest: c.ModuleCallerDigest, + ClientID: c.ID, + ServerID: c.ServerID, + ClientSecretToken: c.SecretToken, + ClientHostname: c.hostname, + ParentClientIDs: c.ParentClientIDs, + Labels: c.labels, }.ToGRPCMD()) } if err != nil { diff --git a/engine/opts.go b/engine/opts.go index 12f55f6fec..6219b7d699 100644 --- a/engine/opts.go +++ b/engine/opts.go @@ -9,10 +9,10 @@ import ( "os" "path/filepath" "strconv" + "strings" "unicode" controlapi "github.com/moby/buildkit/api/services/control" - "github.com/opencontainers/go-digest" "google.golang.org/grpc/metadata" "github.com/dagger/dagger/telemetry" @@ -33,11 +33,12 @@ const ( ) type ClientMetadata struct { - // ClientID is unique to every session created by every client + // ClientID is unique to each client. The main client's ID is the empty string, + // any module and/or nested exec client's ID is a unique digest. ClientID string `json:"client_id"` // ClientSecretToken is a secret token that is unique to every client. It's - // initially provided to the server in the controller.Solve request. Every + // initially provided to the server in the controller.Session request. Every // other request w/ that client ID must also include the same token. ClientSecretToken string `json:"client_secret_token"` @@ -66,11 +67,6 @@ type ClientMetadata struct { // parent of the parent, and so on. ParentClientIDs []string `json:"parent_client_ids"` - // If this client is for a module function, this digest will be set in the - // grpc context metadata for any api requests back to the engine. It's used by the API - // server to determine which schema to serve and other module context metadata. - ModuleCallerDigest digest.Digest `json:"module_caller_digest"` - // Import configuration for Buildkit's remote cache UpstreamCacheImportConfig []*controlapi.CacheOptionsEntry @@ -100,6 +96,15 @@ func (m ClientMetadata) AppendToMD(md metadata.MD) metadata.MD { return md } +// The ID to use for this client's buildkit session. It's a combination of both +// the client and the server IDs to account for the fact that the client ID is +// a content digest for functions/nested-execs, meaning it can reoccur across +// different servers; that doesn't work because buildkit's SessionManager is +// global to the whole process. +func (m ClientMetadata) BuildkitSessionID() string { + return strings.Join([]string{m.ClientID, m.ServerID}, "-") +} + func ContextWithClientMetadata(ctx context.Context, clientMetadata *ClientMetadata) context.Context { return contextWithMD(ctx, clientMetadata.ToGRPCMD()) } diff --git a/engine/server/buildkitcontroller.go b/engine/server/buildkitcontroller.go index 211a547c98..a3ad5a58a2 100644 --- a/engine/server/buildkitcontroller.go +++ b/engine/server/buildkitcontroller.go @@ -154,7 +154,6 @@ func (e *BuildkitController) Session(stream controlapi.Control_SessionServer) (r ctx = bklog.WithLogger(ctx, bklog.G(ctx). WithField("client_id", opts.ClientID). WithField("client_hostname", opts.ClientHostname). - WithField("client_call_digest", opts.ModuleCallerDigest). WithField("server_id", opts.ServerID)) { @@ -206,6 +205,15 @@ func (e *BuildkitController) Session(stream controlapi.Control_SessionServer) (r eg, egctx := errgroup.WithContext(ctx) eg.Go(func() error { + // overwrite the session ID to be our client ID + server ID + const sessionIDHeader = "x-docker-expose-session-uuid" + if _, ok := hijackmd[sessionIDHeader]; !ok { + // should never happen unless upstream changes the value of the header key, + // in which case we want to know + panic(fmt.Errorf("missing header %s", sessionIDHeader)) + } + hijackmd[sessionIDHeader] = []string{opts.BuildkitSessionID()} + bklog.G(ctx).Trace("session manager handling conn") err := e.SessionManager.HandleConn(egctx, conn, hijackmd) bklog.G(ctx).WithError(err).Trace("session manager handle conn done") diff --git a/engine/server/server.go b/engine/server/server.go index 9cfe4cad0a..ffcb20465a 100644 --- a/engine/server/server.go +++ b/engine/server/server.go @@ -19,7 +19,6 @@ import ( bkgw "github.com/moby/buildkit/frontend/gateway/client" "github.com/moby/buildkit/session" "github.com/moby/buildkit/util/bklog" - "github.com/opencontainers/go-digest" "github.com/sirupsen/logrus" "github.com/vektah/gqlparser/v2/gqlerror" "go.opentelemetry.io/otel/propagation" @@ -46,10 +45,9 @@ type DaggerServer struct { clientIDMu sync.RWMutex // The metadata of client calls. - // For the special case of the main client caller, the key is just empty string. // This is never explicitly deleted from; instead it will just be garbage collected // when this server for the session shuts down - clientCallContext map[digest.Digest]*core.ClientCallContext + clientCallContext map[string]*core.ClientCallContext clientCallMu *sync.RWMutex // the http endpoints being served (as a map since APIs like shellEndpoint can add more) @@ -74,7 +72,7 @@ func (e *BuildkitController) newDaggerServer(ctx context.Context, clientMetadata serverID: clientMetadata.ServerID, clientIDToSecretToken: map[string]string{}, - clientCallContext: map[digest.Digest]*core.ClientCallContext{}, + clientCallContext: map[string]*core.ClientCallContext{}, clientCallMu: &sync.RWMutex{}, endpoints: map[string]http.Handler{}, endpointMu: &sync.RWMutex{}, @@ -108,7 +106,7 @@ func (e *BuildkitController) newDaggerServer(ctx context.Context, clientMetadata getSessionCtx, getSessionCancel := context.WithTimeout(ctx, 10*time.Second) defer getSessionCancel() - sessionCaller, err := e.SessionManager.Get(getSessionCtx, clientMetadata.ClientID, false) + sessionCaller, err := e.SessionManager.Get(getSessionCtx, clientMetadata.BuildkitSessionID(), false) if err != nil { return nil, fmt.Errorf("get session: %w", err) } @@ -149,21 +147,21 @@ func (e *BuildkitController) newDaggerServer(ctx context.Context, clientMetadata PrivilegedExecEnabled: e.privilegedExecEnabled, UpstreamCacheImports: cacheImporterCfgs, MainClientCaller: sessionCaller, - MainClientCallerID: s.mainClientCallerID, DNSConfig: e.DNSConfig, Frontends: e.Frontends, BuildkitLogSink: e.BuildkitLogSink, }, - Services: s.services, - Platform: core.Platform(e.worker.Platforms(true)[0]), - Secrets: secretStore, - OCIStore: e.worker.ContentStore(), - LeaseManager: e.worker.LeaseManager(), - Auth: authProvider, - ClientCallContext: s.clientCallContext, - ClientCallMu: s.clientCallMu, - Endpoints: s.endpoints, - EndpointMu: s.endpointMu, + Services: s.services, + Platform: core.Platform(e.worker.Platforms(true)[0]), + Secrets: secretStore, + OCIStore: e.worker.ContentStore(), + LeaseManager: e.worker.LeaseManager(), + Auth: authProvider, + ClientCallContext: s.clientCallContext, + ClientCallMu: s.clientCallMu, + MainClientCallerID: s.mainClientCallerID, + Endpoints: s.endpoints, + EndpointMu: s.endpointMu, }) if err != nil { return nil, err @@ -183,7 +181,7 @@ func (e *BuildkitController) newDaggerServer(ctx context.Context, clientMetadata } // the main client caller starts out with the core API loaded - s.clientCallContext[""] = &core.ClientCallContext{ + s.clientCallContext[s.mainClientCallerID] = &core.ClientCallContext{ Deps: root.DefaultDeps, Root: root, } @@ -251,18 +249,21 @@ func (s *DaggerServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - callContext, ok := s.ClientCallContext(clientMetadata.ModuleCallerDigest) + callContext, ok := s.ClientCallContext(clientMetadata.ClientID) if !ok { - errorOut(fmt.Errorf("client call %s not found", clientMetadata.ModuleCallerDigest), http.StatusInternalServerError) + errorOut(fmt.Errorf("client call for %s not found", clientMetadata.ClientID), http.StatusInternalServerError) return } + s.clientCallMu.RLock() schema, err := callContext.Deps.Schema(ctx) if err != nil { + s.clientCallMu.RUnlock() // TODO: technically this is not *always* bad request, should ideally be more specific and differentiate errorOut(err, http.StatusBadRequest) return } + s.clientCallMu.RUnlock() defer func() { if v := recover(); v != nil { @@ -314,12 +315,10 @@ func (s *DaggerServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { slog := slog.With( "isImmediate", immediate, "isMainClient", clientMetadata.ClientID == s.mainClientCallerID, - "isModule", clientMetadata.ModuleCallerDigest != "", "serverID", s.serverID, "traceID", s.traceID, "clientID", clientMetadata.ClientID, - "mainClientID", s.mainClientCallerID, - "callerID", clientMetadata.ModuleCallerDigest) + "mainClientID", s.mainClientCallerID) slog.Trace("shutting down server") defer slog.Trace("done shutting down server") @@ -346,7 +345,7 @@ func (s *DaggerServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } s.clientCallMu.RLock() - bk := s.clientCallContext[""].Root.Buildkit + bk := s.clientCallContext[s.mainClientCallerID].Root.Buildkit s.clientCallMu.RUnlock() err := bk.UpstreamCacheExport(ctx, cacheExporterFuncs) if err != nil { @@ -404,10 +403,10 @@ func (s *DaggerServer) VerifyClient(clientID, secretToken string) error { return nil } -func (s *DaggerServer) ClientCallContext(clientDigest digest.Digest) (*core.ClientCallContext, bool) { +func (s *DaggerServer) ClientCallContext(clientID string) (*core.ClientCallContext, bool) { s.clientCallMu.RLock() defer s.clientCallMu.RUnlock() - ctx, ok := s.clientCallContext[clientDigest] + ctx, ok := s.clientCallContext[clientID] return ctx, ok } @@ -416,9 +415,9 @@ func (s *DaggerServer) CurrentServedDeps(ctx context.Context) (*core.ModDeps, er if err != nil { return nil, err } - callCtx, ok := s.ClientCallContext(clientMetadata.ModuleCallerDigest) + callCtx, ok := s.ClientCallContext(clientMetadata.ClientID) if !ok { - return nil, fmt.Errorf("client call %s not found", clientMetadata.ModuleCallerDigest) + return nil, fmt.Errorf("client call for %s not found", clientMetadata.ClientID) } return callCtx.Deps, nil } From adc52ead0350b6dce4f98d38daff64c95dadd886 Mon Sep 17 00:00:00 2001 From: Alex Suraci Date: Tue, 9 Apr 2024 16:31:15 -0400 Subject: [PATCH 2/5] namespace services by server, not by client Previously it was possible to start a dependent service in one module API call, and then use it again in a later call, only to have it fail because it cannot resolve the service address, even though it's still running. This happened because each invocation has its own client ID, and client IDs were used to build service addresses. This change brings service addresses into alignment with the recent change to uniq them by service ID instead of client ID. The overall effect is that services are deduped within a Dagger invocation, even across module calls. So with this change, the service will just stay running and be re-used by a later call, thanks to the grace period. Signed-off-by: Alex Suraci --- cmd/shim/main.go | 18 +++++-------- core/container.go | 3 +-- core/git.go | 38 ++++++++++------------------ core/schema/http.go | 21 +++------------ core/service.go | 7 +++-- engine/buildkit/client.go | 3 +-- engine/client/client.go | 9 ------- engine/opts.go | 10 -------- engine/sources/gitdns/identifier.go | 4 +-- engine/sources/gitdns/source.go | 9 +++---- engine/sources/gitdns/state.go | 7 ++--- engine/sources/httpdns/identifier.go | 4 +-- engine/sources/httpdns/source.go | 8 +++--- engine/sources/httpdns/state.go | 7 ++--- 14 files changed, 44 insertions(+), 104 deletions(-) diff --git a/cmd/shim/main.go b/cmd/shim/main.go index 22943f5c45..8172b3a2e5 100644 --- a/cmd/shim/main.go +++ b/cmd/shim/main.go @@ -500,11 +500,8 @@ func setupBundle() (returnExitCode int) { } var searchDomains []string - for _, parentClientID := range execMetadata.ParentClientIDs { - searchDomains = append(searchDomains, network.ClientDomain(parentClientID)) - } - if len(searchDomains) > 0 { - spec.Process.Env = append(spec.Process.Env, "_DAGGER_PARENT_CLIENT_IDS="+strings.Join(execMetadata.ParentClientIDs, " ")) + if ns := execMetadata.ServerID; ns != "" { + searchDomains = append(searchDomains, network.ClientDomain(ns)) } var hostsFilePath string @@ -741,14 +738,11 @@ func runWithNesting(ctx context.Context, cmd *exec.Cmd) error { return errors.New("missing nested client server ID") } - parentClientIDsVal, _ := internalEnv("_DAGGER_PARENT_CLIENT_IDS") - clientParams := client.Params{ - ID: clientID, - ServerID: serverID, - SecretToken: sessionToken.String(), - RunnerHost: "unix:///.runner.sock", - ParentClientIDs: strings.Fields(parentClientIDsVal), + ID: clientID, + ServerID: serverID, + SecretToken: sessionToken.String(), + RunnerHost: "unix:///.runner.sock", } sess, ctx, err := client.Connect(ctx, clientParams) diff --git a/core/container.go b/core/container.go index 7683b57ac0..9d2c976d8b 100644 --- a/core/container.go +++ b/core/container.go @@ -1181,8 +1181,7 @@ func (container *Container) WithExec(ctx context.Context, opts ContainerExecOpts return nil, err } execMeta := buildkit.ContainerExecUncachedMetadata{ - ParentClientIDs: clientMetadata.ClientIDs(), - ServerID: clientMetadata.ServerID, + ServerID: clientMetadata.ServerID, } proxyVal, err := execMeta.ToPBFtpProxyVal() if err != nil { diff --git a/core/git.go b/core/git.go index f0f9774f81..3317c13f58 100644 --- a/core/git.go +++ b/core/git.go @@ -60,14 +60,20 @@ func (*GitRef) TypeDescription() string { func (ref *GitRef) Tree(ctx context.Context) (*Directory, error) { bk := ref.Query.Buildkit - st := ref.getState(ctx, bk) - return NewDirectorySt(ctx, ref.Query, *st, "", ref.Repo.Platform, ref.Repo.Services) + st, err := ref.getState(ctx, bk) + if err != nil { + return nil, err + } + return NewDirectorySt(ctx, ref.Query, st, "", ref.Repo.Platform, ref.Repo.Services) } func (ref *GitRef) Commit(ctx context.Context) (string, error) { bk := ref.Query.Buildkit - st := ref.getState(ctx, bk) - p, err := resolveProvenance(ctx, bk, *st) + st, err := ref.getState(ctx, bk) + if err != nil { + return "", err + } + p, err := resolveProvenance(ctx, bk, st) if err != nil { return "", err } @@ -77,7 +83,7 @@ func (ref *GitRef) Commit(ctx context.Context) (string, error) { return p.Sources.Git[0].Commit, nil } -func (ref *GitRef) getState(ctx context.Context, bk *buildkit.Client) *llb.State { +func (ref *GitRef) getState(ctx context.Context, bk *buildkit.Client) (llb.State, error) { opts := []llb.GitOption{} if ref.Repo.KeepGitDir { @@ -96,26 +102,10 @@ func (ref *GitRef) getState(ctx context.Context, bk *buildkit.Client) *llb.State opts = append(opts, llb.AuthHeaderSecret(ref.Repo.AuthHeader.Accessor)) } - useDNS := len(ref.Repo.Services) > 0 - clientMetadata, err := engine.ClientMetadataFromContext(ctx) - if err == nil && !useDNS { - useDNS = len(clientMetadata.ParentClientIDs) > 0 + if err != nil { + return llb.State{}, err } - var st llb.State - if useDNS { - // NB: only configure search domains if we're directly using a service, or - // if we're nested beneath another search domain. - // - // we have to be a bit selective here to avoid breaking Dockerfile builds - // that use a Buildkit frontend (# syntax = ...) that doesn't have the - // networks API cap. - // - // TODO: add API cap - st = gitdns.Git(ref.Repo.URL, ref.Ref, clientMetadata.ClientIDs(), opts...) - } else { - st = llb.Git(ref.Repo.URL, ref.Ref, opts...) - } - return &st + return gitdns.Git(ref.Repo.URL, ref.Ref, clientMetadata.ServerID, opts...), nil } diff --git a/core/schema/http.go b/core/schema/http.go index d052c9e714..a581317bdb 100644 --- a/core/schema/http.go +++ b/core/schema/http.go @@ -60,26 +60,11 @@ func (s *httpSchema) http(ctx context.Context, parent *core.Query, args httpArgs llb.Filename(filename), } - useDNS := len(svcs) > 0 - clientMetadata, err := engine.ClientMetadataFromContext(ctx) - if err == nil && !useDNS { - useDNS = len(clientMetadata.ParentClientIDs) > 0 - } - - var st llb.State - if useDNS { - // NB: only configure search domains if we're directly using a service, or - // if we're nested. - // - // we have to be a bit selective here to avoid breaking Dockerfile builds - // that use a Buildkit frontend (# syntax = ...). - // - // TODO: add API cap - st = httpdns.HTTP(args.URL, clientMetadata.ClientIDs(), opts...) - } else { - st = llb.HTTP(args.URL, opts...) + if err != nil { + return nil, err } + st := httpdns.HTTP(args.URL, clientMetadata.ServerID, opts...) return core.NewFileSt(ctx, parent, st, filename, parent.Platform, svcs) } diff --git a/core/service.go b/core/service.go index 13e661bd68..8cbfab5505 100644 --- a/core/service.go +++ b/core/service.go @@ -260,7 +260,7 @@ func (svc *Service) startContainer( } }() - fullHost := host + "." + network.ClientDomain(clientMetadata.ClientID) + fullHost := host + "." + network.ClientDomain(clientMetadata.ServerID) bk := svc.Query.Buildkit @@ -328,8 +328,7 @@ func (svc *Service) startContainer( } execMeta := buildkit.ContainerExecUncachedMetadata{ - ParentClientIDs: clientMetadata.ClientIDs(), - ServerID: clientMetadata.ServerID, + ServerID: clientMetadata.ServerID, } execOp.Meta.ProxyEnv.FtpProxy, err = execMeta.ToPBFtpProxyVal() if err != nil { @@ -596,7 +595,7 @@ func (svc *Service) startReverseTunnel(ctx context.Context, id *call.ID) (runnin return nil, err } - fullHost := host + "." + network.ClientDomain(clientMetadata.ClientID) + fullHost := host + "." + network.ClientDomain(clientMetadata.ServerID) bk := svc.Query.Buildkit diff --git a/engine/buildkit/client.go b/engine/buildkit/client.go index a781be3181..08ffbc8491 100644 --- a/engine/buildkit/client.go +++ b/engine/buildkit/client.go @@ -714,8 +714,7 @@ func withOutgoingContext(ctx context.Context) context.Context { // the "real" ftp proxy setting in here too and have the shim handle // leaving only that set in the actual env var. type ContainerExecUncachedMetadata struct { - ParentClientIDs []string `json:"parentClientIDs,omitempty"` - ServerID string `json:"serverID,omitempty"` + ServerID string `json:"serverID,omitempty"` } func (md ContainerExecUncachedMetadata) ToPBFtpProxyVal() (string, error) { diff --git a/engine/client/client.go b/engine/client/client.go index 9baa9aef79..f60a011184 100644 --- a/engine/client/client.go +++ b/engine/client/client.go @@ -61,12 +61,6 @@ type Params struct { // be reused across multiple servers. ServerID string - // Parent client IDs of this Dagger client. - // - // Used by Dagger-in-Dagger so that nested sessions can resolve addresses - // passed from the parent. - ParentClientIDs []string - SecretToken string RunnerHost string // host of dagger engine runner serving buildkit apis @@ -327,7 +321,6 @@ func (c *Client) startSession(ctx context.Context) (rerr error) { ClientSecretToken: c.SecretToken, ClientHostname: c.hostname, Labels: c.labels, - ParentClientIDs: c.ParentClientIDs, }) // filesync @@ -356,7 +349,6 @@ func (c *Client) startSession(ctx context.Context) (rerr error) { ClientID: c.ID, ServerID: c.ServerID, ClientSecretToken: c.SecretToken, - ParentClientIDs: c.ParentClientIDs, ClientHostname: hostname, UpstreamCacheImportConfig: c.upstreamCacheImportOptions, UpstreamCacheExportConfig: c.upstreamCacheExportOptions, @@ -629,7 +621,6 @@ func (c *Client) DialContext(ctx context.Context, _, _ string) (conn net.Conn, e ServerID: c.ServerID, ClientSecretToken: c.SecretToken, ClientHostname: c.hostname, - ParentClientIDs: c.ParentClientIDs, Labels: c.labels, }.ToGRPCMD()) } diff --git a/engine/opts.go b/engine/opts.go index 6219b7d699..2e7866b16d 100644 --- a/engine/opts.go +++ b/engine/opts.go @@ -62,11 +62,6 @@ type ClientMetadata struct { // (Optional) Pipeline labels for e.g. vcs info like branch, commit, etc. Labels telemetry.Labels `json:"labels"` - // ParentClientIDs is a list of session ids that are parents of the current - // session. The first element is the direct parent, the second element is the - // parent of the parent, and so on. - ParentClientIDs []string `json:"parent_client_ids"` - // Import configuration for Buildkit's remote cache UpstreamCacheImportConfig []*controlapi.CacheOptionsEntry @@ -80,11 +75,6 @@ type ClientMetadata struct { DoNotTrack bool } -// ClientIDs returns the ClientID followed by ParentClientIDs. -func (m ClientMetadata) ClientIDs() []string { - return append([]string{m.ClientID}, m.ParentClientIDs...) -} - func (m ClientMetadata) ToGRPCMD() metadata.MD { return encodeMeta(clientMetadataMetaKey, m) } diff --git a/engine/sources/gitdns/identifier.go b/engine/sources/gitdns/identifier.go index ae78e18700..6ba5611ce2 100644 --- a/engine/sources/gitdns/identifier.go +++ b/engine/sources/gitdns/identifier.go @@ -4,10 +4,10 @@ import ( bkgit "github.com/moby/buildkit/source/git" ) -const AttrGitClientIDs = "dagger.git.clientids" +const AttrDNSNamespace = "dagger.dns.namespace" type GitIdentifier struct { bkgit.GitIdentifier - ClientIDs []string + Namespace string } diff --git a/engine/sources/gitdns/source.go b/engine/sources/gitdns/source.go index 31f8ab61e3..26e1ea508d 100644 --- a/engine/sources/gitdns/source.go +++ b/engine/sources/gitdns/source.go @@ -78,8 +78,8 @@ func (gs *gitSource) Identifier(scheme, ref string, attrs map[string]string, pla GitIdentifier: *(srcid.(*srcgit.GitIdentifier)), } - if v, ok := attrs[AttrGitClientIDs]; ok { - id.ClientIDs = strings.Split(v, ",") + if v, ok := attrs[AttrDNSNamespace]; ok { + id.Namespace = v } return id, nil @@ -322,10 +322,9 @@ func (gs *gitSourceHandler) mountKnownHosts() (string, func() error, error) { func (gs *gitSourceHandler) dnsConfig() *oci.DNSConfig { clientDomains := []string{} - for _, clientID := range gs.src.ClientIDs { - clientDomains = append(clientDomains, network.ClientDomain(clientID)) + if gs.src.Namespace != "" { + clientDomains = append(clientDomains, network.ClientDomain(gs.src.Namespace)) } - dns := *gs.dns dns.SearchDomains = append(clientDomains, dns.SearchDomains...) return &dns diff --git a/engine/sources/gitdns/state.go b/engine/sources/gitdns/state.go index 25192b3a4b..b5e6cbc9d1 100644 --- a/engine/sources/gitdns/state.go +++ b/engine/sources/gitdns/state.go @@ -2,7 +2,6 @@ package gitdns import ( "path" - "strings" "github.com/moby/buildkit/client/llb" "github.com/moby/buildkit/solver/pb" @@ -11,11 +10,9 @@ import ( "github.com/pkg/errors" ) -const AttrNetConfig = "gitdns.netconfig" - // Git is a helper mimicking the llb.Git function, but with the ability to // set additional attributes. -func Git(url, ref string, clientIDs []string, opts ...llb.GitOption) llb.State { +func Git(url, ref string, namespace string, opts ...llb.GitOption) llb.State { remote, err := gitutil.ParseURL(url) if errors.Is(err, gitutil.ErrUnknownProtocol) { url = "https://" + url @@ -78,7 +75,7 @@ func Git(url, ref string, clientIDs []string, opts ...llb.GitOption) llb.State { } } - attrs[AttrGitClientIDs] = strings.Join(clientIDs, ",") + attrs[AttrDNSNamespace] = namespace source := llb.NewSource("git://"+id, attrs, gi.Constraints) return llb.NewState(source.Output()) diff --git a/engine/sources/httpdns/identifier.go b/engine/sources/httpdns/identifier.go index 47275d6bef..b30ec9e7c9 100644 --- a/engine/sources/httpdns/identifier.go +++ b/engine/sources/httpdns/identifier.go @@ -4,10 +4,10 @@ import ( bkhttp "github.com/moby/buildkit/source/http" ) -const AttrHTTPClientIDs = "dagger.http.clientids" +const AttrDNSNamespace = "dagger.dns.namespace" type HTTPIdentifier struct { bkhttp.HTTPIdentifier - ClientIDs []string + Namespace string } diff --git a/engine/sources/httpdns/source.go b/engine/sources/httpdns/source.go index be27de8f63..c51201b2f5 100644 --- a/engine/sources/httpdns/source.go +++ b/engine/sources/httpdns/source.go @@ -76,8 +76,8 @@ func (hs *httpSource) Identifier(scheme, ref string, attrs map[string]string, pl HTTPIdentifier: *(srcid.(*srchttp.HTTPIdentifier)), } - if v, ok := attrs[AttrHTTPClientIDs]; ok { - id.ClientIDs = strings.Split(v, ",") + if v, ok := attrs[AttrDNSNamespace]; ok { + id.Namespace = v } return id, nil @@ -106,8 +106,8 @@ type httpSourceHandler struct { func (hs *httpSourceHandler) client(g session.Group) *http.Client { clientDomains := []string{} - for _, clientID := range hs.src.ClientIDs { - clientDomains = append(clientDomains, network.ClientDomain(clientID)) + if ns := hs.src.Namespace; ns != "" { + clientDomains = append(clientDomains, network.ClientDomain(ns)) } dns := *hs.dns diff --git a/engine/sources/httpdns/state.go b/engine/sources/httpdns/state.go index 458dff3f50..e980a95950 100644 --- a/engine/sources/httpdns/state.go +++ b/engine/sources/httpdns/state.go @@ -2,17 +2,14 @@ package httpdns import ( "strconv" - "strings" "github.com/moby/buildkit/client/llb" "github.com/moby/buildkit/solver/pb" ) -const AttrNetConfig = "httpdns.netconfig" - // HTTP is a helper mimicking the llb.HTTP function, but with the ability to // set additional attributes. -func HTTP(url string, clientIDs []string, opts ...llb.HTTPOption) llb.State { +func HTTP(url string, namespace string, opts ...llb.HTTPOption) llb.State { hi := &llb.HTTPInfo{} for _, o := range opts { o.SetHTTPOption(hi) @@ -34,7 +31,7 @@ func HTTP(url string, clientIDs []string, opts ...llb.HTTPOption) llb.State { attrs[pb.AttrHTTPGID] = strconv.Itoa(hi.GID) } - attrs[AttrHTTPClientIDs] = strings.Join(clientIDs, ",") + attrs[AttrDNSNamespace] = namespace source := llb.NewSource(url, attrs, hi.Constraints) return llb.NewState(source.Output()) From a8785b825c0f1aa7c7b342524146d714774bfd30 Mon Sep 17 00:00:00 2001 From: Erik Sipsma Date: Mon, 29 Apr 2024 18:38:16 -0700 Subject: [PATCH 3/5] fix routing of host services to correct client Signed-off-by: Erik Sipsma --- core/c2h.go | 2 ++ core/query.go | 9 +++--- core/schema/host.go | 59 +++++++++++++++++++++++++++++++++++++-- core/service.go | 3 ++ core/socket.go | 7 ++++- engine/buildkit/socket.go | 29 +++++++++++++++++-- 6 files changed, 98 insertions(+), 11 deletions(-) diff --git a/core/c2h.go b/core/c2h.go index b11dc09e7d..8a70d931ba 100644 --- a/core/c2h.go +++ b/core/c2h.go @@ -20,6 +20,7 @@ type c2hTunnel struct { upstreamHost string tunnelServiceHost string tunnelServicePorts []PortForward + sessionID string } func (d *c2hTunnel) Tunnel(ctx context.Context) (rerr error) { @@ -56,6 +57,7 @@ func (d *c2hTunnel) Tunnel(ctx context.Context) (rerr error) { upstream := NewHostIPSocket( port.Protocol.Network(), fmt.Sprintf("%s:%d", d.upstreamHost, port.Backend), + d.sessionID, ) sockPath := fmt.Sprintf("/upstream.%d.sock", frontend) diff --git a/core/query.go b/core/query.go index 09b9d7bc18..be086d49d5 100644 --- a/core/query.go +++ b/core/query.go @@ -290,11 +290,12 @@ func (q *Query) NewTunnelService(upstream dagql.Instance[*Service], ports []Port } } -func (q *Query) NewHostService(upstream string, ports []PortForward) *Service { +func (q *Query) NewHostService(upstream string, ports []PortForward, sessionID string) *Service { return &Service{ - Query: q, - HostUpstream: upstream, - HostPorts: ports, + Query: q, + HostUpstream: upstream, + HostPorts: ports, + HostSessionID: sessionID, } } diff --git a/core/schema/host.go b/core/schema/host.go index 6d9246c03a..c9cc745c21 100644 --- a/core/schema/host.go +++ b/core/schema/host.go @@ -160,6 +160,7 @@ func (s *hostSchema) Install() { `If ports are given and native is true, the ports are additive.`), dagql.Func("service", s.service). + Impure("Value depends on the caller as it points to their host."). Doc(`Creates a service that forwards traffic to a specified address via the host.`). ArgDoc("ports", `Ports to expose via the service, forwarding through the host network.`, @@ -168,6 +169,10 @@ func (s *hostSchema) Install() { `An empty set of ports is not valid; an error will be returned.`). ArgDoc("host", `Upstream host to forward traffic to.`), + // hidden from external clients via the __ prefix + dagql.Func("__internalService", s.internalService). + Doc(`(Internal-only) "service" but scoped to the exact right buildkit session ID.`), + dagql.Func("setSecretFile", s.setSecretFile). Impure("`setSecretFile` reads its value from the local machine."). Doc( @@ -279,10 +284,58 @@ type hostServiceArgs struct { Ports []dagql.InputObject[core.PortForward] } -func (s *hostSchema) service(ctx context.Context, parent *core.Host, args hostServiceArgs) (*core.Service, error) { +func (s *hostSchema) service(ctx context.Context, parent *core.Host, args hostServiceArgs) (inst dagql.Instance[*core.Service], err error) { if len(args.Ports) == 0 { - return nil, errors.New("no ports specified") + return inst, errors.New("no ports specified") + } + + clientMetadata, err := engine.ClientMetadataFromContext(ctx) + if err != nil { + return inst, fmt.Errorf("failed to get client metadata: %w", err) + } + + portsArg := make(dagql.ArrayInput[dagql.InputObject[core.PortForward]], len(args.Ports)) + copy(portsArg, args.Ports) + + err = s.srv.Select(ctx, s.srv.Root(), &inst, + dagql.Selector{ + Field: "host", + }, + dagql.Selector{ + Field: "__internalService", + Args: []dagql.NamedInput{ + { + Name: "host", + Value: dagql.NewString(args.Host), + }, + { + Name: "ports", + Value: portsArg, + }, + { + Name: "sessionId", + Value: dagql.NewString(clientMetadata.BuildkitSessionID()), + }, + }, + }, + ) + return inst, err +} + +type hostInternalServiceArgs struct { + Host string `default:"localhost"` + Ports []dagql.InputObject[core.PortForward] + SessionID string +} + +func (s *hostSchema) internalService(ctx context.Context, parent *core.Host, args hostInternalServiceArgs) (*core.Service, error) { + if args.SessionID == "" { + return nil, errors.New("no session ID specified") } - return parent.Query.NewHostService(args.Host, collectInputsSlice(args.Ports)), nil + return parent.Query.NewHostService( + args.Host, + collectInputsSlice(args.Ports), + args.SessionID, + ), nil } diff --git a/core/service.go b/core/service.go index 8cbfab5505..c779141768 100644 --- a/core/service.go +++ b/core/service.go @@ -44,6 +44,8 @@ type Service struct { HostUpstream string `json:"reverse_tunnel_upstream_addr,omitempty"` // HostPorts configures the port forwarding rules for the host. HostPorts []PortForward `json:"host_ports,omitempty"` + // HostSessionID is the session ID of the host (could differ from main client in the case of nested execs). + HostSessionID string `json:"host_session_id,omitempty"` } func (*Service) Type() *ast.Type { @@ -604,6 +606,7 @@ func (svc *Service) startReverseTunnel(ctx context.Context, id *call.ID) (runnin upstreamHost: svc.HostUpstream, tunnelServiceHost: fullHost, tunnelServicePorts: svc.HostPorts, + sessionID: svc.HostSessionID, } checkPorts := []Port{} diff --git a/core/socket.go b/core/socket.go index a5b7ad1ff5..5b6958653c 100644 --- a/core/socket.go +++ b/core/socket.go @@ -17,6 +17,9 @@ type Socket struct { // IP HostProtocol string `json:"host_protocol,omitempty"` HostAddr string `json:"host_addr,omitempty"` + + // The session ID of the host's client + SessionID string `json:"session_id,omitempty"` } func (*Socket) Type() *ast.Type { @@ -36,10 +39,11 @@ func NewHostUnixSocket(absPath string) *Socket { } } -func NewHostIPSocket(proto string, addr string) *Socket { +func NewHostIPSocket(proto string, addr string, sessionID string) *Socket { return &Socket{ HostAddr: addr, HostProtocol: proto, + SessionID: sessionID, } } @@ -52,6 +56,7 @@ func (socket *Socket) SSHID() string { default: u.Scheme = socket.HostProtocol u.Host = socket.HostAddr + u.Fragment = socket.SessionID } return u.String() } diff --git a/engine/buildkit/socket.go b/engine/buildkit/socket.go index 3f8c6bf502..789c6d336d 100644 --- a/engine/buildkit/socket.go +++ b/engine/buildkit/socket.go @@ -2,11 +2,15 @@ package buildkit import ( "context" + "fmt" + "net/url" "github.com/moby/buildkit/session/sshforward" "go.opentelemetry.io/otel/trace" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" ) type socketProxy struct { @@ -29,10 +33,29 @@ func (p *socketProxy) ForwardAgent(stream sshforward.SSH_ForwardAgentServer) err ctx = trace.ContextWithSpanContext(ctx, p.c.spanCtx) // ensure server's span context is propagated - incomingMD, _ := metadata.FromIncomingContext(ctx) - ctx = metadata.NewOutgoingContext(ctx, incomingMD) + opts, _ := metadata.FromIncomingContext(ctx) + ctx = metadata.NewOutgoingContext(ctx, opts) - forwardAgentClient, err := sshforward.NewSSHClient(p.c.MainClientCaller.Conn()).ForwardAgent(ctx) + var connURL *url.URL + if v, ok := opts[sshforward.KeySSHID]; ok && len(v) > 0 && v[0] != "" { + var err error + connURL, err = url.Parse(v[0]) + if err != nil { + return status.Errorf(codes.InvalidArgument, "invalid id: %s", err) + } + } + + caller := p.c.MainClientCaller + if connURL != nil && connURL.Fragment != "" { + sessionID := connURL.Fragment + var err error + caller, err = p.c.SessionManager.Get(ctx, sessionID, true) + if err != nil { + return fmt.Errorf("failed to get session: %w", err) + } + } + + forwardAgentClient, err := sshforward.NewSSHClient(caller.Conn()).ForwardAgent(ctx) if err != nil { return err } From df0753150527377c2fb9f8f0f996e08ce1fd2b61 Mon Sep 17 00:00:00 2001 From: Erik Sipsma Date: Mon, 29 Apr 2024 19:54:41 -0700 Subject: [PATCH 4/5] deregister secret tokens once client disconnects We previously never explicitly removed client ID -> secret token mappings because it theoretically opened more possibilities for malicious attempts to register a client ID with a different token. However, we need to deregister these now since Client IDs are a content hash of the function call/nested exec definition, which means the same client ID can connect and disconnect multiple times per server. The security implications of this also end up being extremely minimal. Registering a client ID with a different secret token was and still is possible *before* a client fully connects. It is possible to after a client disconnects now but this would only amount to a DOS since the "real" client would just be unable to connect. No information would be leaked. It also would have to be in the same server (i.e. a module or nested exec called by the main client directly or transitively). This issue can also be squashed by not leaking the buildkit sock to nested execs/modules, which is possible now by migrating functionality from our shim to our custom executor. There's no immediate plans to do this but the possibility is open whenever needed (or when we make that change for other reasons). Signed-off-by: Erik Sipsma --- core/container.go | 9 ++++++++- core/query.go | 7 +++++-- engine/server/buildkitcontroller.go | 6 ++++++ engine/server/server.go | 11 +++++++---- 4 files changed, 26 insertions(+), 7 deletions(-) diff --git a/core/container.go b/core/container.go index 9d2c976d8b..e8e82c0d66 100644 --- a/core/container.go +++ b/core/container.go @@ -1007,7 +1007,14 @@ func (container *Container) WithExec(ctx context.Context, opts ContainerExecOpts // this allows executed containers to communicate back to this API if opts.ExperimentalPrivilegedNesting { - clientID, err := container.Query.RegisterCaller(ctx, opts.NestedExecFunctionCall) + callerOpts := opts.NestedExecFunctionCall + if callerOpts == nil { + // default to caching the nested exec + callerOpts = &FunctionCall{ + Cache: true, + } + } + clientID, err := container.Query.RegisterCaller(ctx, callerOpts) if err != nil { return nil, fmt.Errorf("register caller: %w", err) } diff --git a/core/query.go b/core/query.go index be086d49d5..22114d5cac 100644 --- a/core/query.go +++ b/core/query.go @@ -18,6 +18,7 @@ import ( "github.com/dagger/dagger/dagql/call" "github.com/dagger/dagger/engine" "github.com/dagger/dagger/engine/buildkit" + "github.com/dagger/dagger/engine/slog" ) // Query forms the root of the DAG and houses all necessary state and @@ -154,8 +155,10 @@ func (q *Query) RegisterCaller(ctx context.Context, call *FunctionCall) (string, // also trim it to 25 chars as it ends up becoming part of service URLs clientID := clientIDDigest.Encoded()[:25] - // break glass for debugging which client is which operation - // bklog.G(ctx).Debugf("CLIENT ID %s = %s", clientID, currentID.Display()) + slog.ExtraDebug("registering nested caller", + "client_id", clientID, + "op", currentID.Display(), + ) if call.Module == nil { callCtx.Deps = q.DefaultDeps diff --git a/engine/server/buildkitcontroller.go b/engine/server/buildkitcontroller.go index a3ad5a58a2..450df5c6e9 100644 --- a/engine/server/buildkitcontroller.go +++ b/engine/server/buildkitcontroller.go @@ -267,6 +267,12 @@ func (e *BuildkitController) Session(stream controlapi.Control_SessionServer) (r if err != nil { return fmt.Errorf("failed to register client: %w", err) } + defer func() { + err := srv.UnregisterClient(opts.ClientID) + if err != nil { + slog.Error("failed to unregister client", "err", err) + } + }() eg.Go(func() error { bklog.G(ctx).Trace("waiting for server") diff --git a/engine/server/server.go b/engine/server/server.go index ffcb20465a..f7bca64d2d 100644 --- a/engine/server/server.go +++ b/engine/server/server.go @@ -382,10 +382,6 @@ func (s *DaggerServer) RegisterClient(clientID, clientHostname, secretToken stri return nil } s.clientIDToSecretToken[clientID] = secretToken - // NOTE: we purposely don't delete the secret token, it should never be reused and will be released - // from memory once the dagger server instance corresponding to this buildkit client shuts down. - // Deleting it would make it easier to create race conditions around using the client's session - // before it is fully closed. return nil } @@ -403,6 +399,13 @@ func (s *DaggerServer) VerifyClient(clientID, secretToken string) error { return nil } +func (s *DaggerServer) UnregisterClient(clientID string) error { + s.clientIDMu.Lock() + defer s.clientIDMu.Unlock() + delete(s.clientIDToSecretToken, clientID) + return nil +} + func (s *DaggerServer) ClientCallContext(clientID string) (*core.ClientCallContext, bool) { s.clientCallMu.RLock() defer s.clientCallMu.RUnlock() From 67c105096061db66a00ee3fdf87aec6588ff9151 Mon Sep 17 00:00:00 2001 From: Erik Sipsma Date: Mon, 29 Apr 2024 20:42:20 -0700 Subject: [PATCH 5/5] add integ test coverage Signed-off-by: Erik Sipsma --- core/integration/module_test.go | 112 ++++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/core/integration/module_test.go b/core/integration/module_test.go index 7e1a2dcf04..b1ea971a8c 100644 --- a/core/integration/module_test.go +++ b/core/integration/module_test.go @@ -5628,6 +5628,118 @@ func TestModuleUnicodePath(t *testing.T) { require.JSONEq(t, `{"test":{"hello":"hello"}}`, out) } +func TestModuleStartServices(t *testing.T) { + t.Parallel() + + // regression test for https://github.com/dagger/dagger/pull/6914 + t.Run("use service in multiple functions", func(t *testing.T) { + t.Parallel() + c, ctx := connect(t) + + out, err := c.Container().From(golangImage). + WithMountedFile(testCLIBinPath, daggerCliFile(t, c)). + WithWorkdir("/work"). + With(daggerExec("init", "--source=.", "--name=test", "--sdk=go")). + WithNewFile("/work/main.go", dagger.ContainerWithNewFileOpts{ + Contents: `package main + import ( + "context" + "fmt" + ) + + type Test struct { + } + + func (m *Test) FnA(ctx context.Context) (*Sub, error) { + svc := dag.Container(). + From("python"). + WithMountedDirectory( + "/srv/www", + dag.Directory().WithNewFile("index.html", "hey there"), + ). + WithWorkdir("/srv/www"). + WithExposedPort(23457). + WithExec([]string{"python", "-m", "http.server", "23457"}). + AsService() + + ctr := dag.Container(). + From("alpine:3.18.6"). + WithServiceBinding("svc", svc). + WithExec([]string{"wget", "-O", "-", "http://svc:23457"}) + + out, err := ctr.Stdout(ctx) + if err != nil { + return nil, err + } + if out != "hey there" { + return nil, fmt.Errorf("unexpected output: %q", out) + } + return &Sub{Ctr: ctr}, nil + } + + type Sub struct { + Ctr *Container + } + + func (m *Sub) FnB(ctx context.Context) (string, error) { + return m.Ctr. + WithExec([]string{"wget", "-O", "-", "http://svc:23457"}). + Stdout(ctx) + } + `, + }). + With(daggerCall("fnA", "fnB")). + Stdout(ctx) + require.NoError(t, err) + require.Equal(t, "hey there", strings.TrimSpace(out)) + }) + + // regression test for https://github.com/dagger/dagger/issues/6951 + t.Run("service in multiple containers", func(t *testing.T) { + t.Parallel() + c, ctx := connect(t) + + _, err := c.Container().From(golangImage). + WithMountedFile(testCLIBinPath, daggerCliFile(t, c)). + WithWorkdir("/work"). + With(daggerExec("init", "--source=.", "--name=test", "--sdk=go")). + WithNewFile("/work/main.go", dagger.ContainerWithNewFileOpts{ + Contents: `package main +import ( + "context" +) + +type Test struct { +} + +func (m *Test) Fn(ctx context.Context) *Container { + redis := dag.Container(). + From("redis"). + WithExposedPort(6379). + AsService() + cli := dag.Container(). + From("redis"). + WithoutEntrypoint(). + WithServiceBinding("redis", redis) + + ctrA := cli.WithExec([]string{"sh", "-c", "redis-cli -h redis info >> /tmp/out.txt"}) + + file := ctrA.Directory("/tmp").File("/out.txt") + + ctrB := dag.Container(). + From("alpine"). + WithFile("/out.txt", file) + + return ctrB.WithExec([]string{"cat", "/out.txt"}) +} + `, + }). + With(daggerCall("fn", "stdout")). + Sync(ctx) + require.NoError(t, err) + }) +} + func daggerExec(args ...string) dagger.WithContainerFunc { return func(c *dagger.Container) *dagger.Container { return c.WithExec(append([]string{"dagger", "--debug"}, args...), dagger.ContainerWithExecOpts{