From b872da9c169b3c811977048ee5680a1365f90fc7 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Thu, 28 Nov 2024 10:13:15 +0100 Subject: [PATCH] feat(worker): auto pull images if not found (#200) This commit provides users with the ability to pull the docker images if they are not found locally. --- go.mod | 3 +- go.sum | 7 +- worker/b64_test.go | 69 +++- worker/container.go | 5 +- worker/docker.go | 151 ++++++-- worker/docker_test.go | 786 ++++++++++++++++++++++++++++++++++++++++++ worker/worker.go | 13 +- 7 files changed, 988 insertions(+), 46 deletions(-) create mode 100644 worker/docker_test.go diff --git a/go.mod b/go.mod index e86fbdee..407c41d6 100644 --- a/go.mod +++ b/go.mod @@ -10,11 +10,13 @@ require ( github.com/getkin/kin-openapi v0.128.0 github.com/go-chi/chi/v5 v5.1.0 github.com/oapi-codegen/runtime v1.1.1 + github.com/opencontainers/image-spec v1.1.0 github.com/stretchr/testify v1.9.0 github.com/vincent-petithory/dataurl v1.0.0 ) require ( + github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect github.com/containerd/log v0.1.0 // indirect @@ -36,7 +38,6 @@ require ( github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect github.com/morikuni/aec v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect - github.com/opencontainers/image-spec v1.1.0 // indirect github.com/perimeterx/marshmallow v1.1.5 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index 27454361..96c9f9f1 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8= -github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= +github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 
h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= +github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk= @@ -10,6 +10,8 @@ github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK3 github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= +github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= +github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -134,6 +136,7 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= golang.org/x/sys v0.27.0/go.mod 
h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= diff --git a/worker/b64_test.go b/worker/b64_test.go index 2a18861e..44332963 100644 --- a/worker/b64_test.go +++ b/worker/b64_test.go @@ -3,6 +3,9 @@ package worker import ( "bytes" "encoding/base64" + "image" + "image/color" + "image/png" "os" "testing" @@ -10,31 +13,65 @@ import ( ) func TestReadImageB64DataUrl(t *testing.T) { - // Create a sample PNG image and encode it as a data URL - imgData := []byte{ - 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG header - // ... (rest of the PNG data) + tests := []struct { + name string + dataURL string + expectError bool + }{ + { + name: "Valid PNG Image", + dataURL: func() string { + img := image.NewRGBA(image.Rect(0, 0, 1, 1)) + img.Set(0, 0, color.RGBA{255, 0, 0, 255}) // Set a single red pixel + var imgBuf bytes.Buffer + err := png.Encode(&imgBuf, img) + require.NoError(t, err) + + return "data:image/png;base64," + base64.StdEncoding.EncodeToString(imgBuf.Bytes()) + }(), + expectError: false, + }, + { + name: "Unsupported Image Format", + dataURL: "data:image/bmp;base64," + base64.StdEncoding.EncodeToString([]byte{ + 0x42, 0x4D, // BMP header + // ... (rest of the BMP data) + }), + expectError: true, + }, + { + name: "Invalid Data URL", + dataURL: "invalid-data-url", + expectError: true, + }, } - dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imgData) - var buf bytes.Buffer - err := ReadImageB64DataUrl(dataURL, &buf) - require.NoError(t, err) - require.NotEmpty(t, buf.Bytes()) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + err := ReadImageB64DataUrl(tt.dataURL, &buf) + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + require.NotEmpty(t, buf.Bytes()) + } + }) + } } func TestSaveImageB64DataUrl(t *testing.T) { - // Create a sample PNG image and encode it as a data URL - imgData := []byte{ - 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG header - // ... 
(rest of the PNG data) - } - dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imgData) + img := image.NewRGBA(image.Rect(0, 0, 1, 1)) + img.Set(0, 0, color.RGBA{255, 0, 0, 255}) // Set a single red pixel + var imgBuf bytes.Buffer + err := png.Encode(&imgBuf, img) + require.NoError(t, err) + dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imgBuf.Bytes()) outputPath := "test_output.png" defer os.Remove(outputPath) - err := SaveImageB64DataUrl(dataURL, outputPath) + err = SaveImageB64DataUrl(dataURL, outputPath) require.NoError(t, err) // Verify that the file was created and is not empty diff --git a/worker/container.go b/worker/container.go index 5918d7fc..1396e7f8 100644 --- a/worker/container.go +++ b/worker/container.go @@ -42,6 +42,9 @@ type RunnerContainerConfig struct { containerTimeout time.Duration } +// Create global references to functions to allow for mocking in tests. +var runnerWaitUntilReadyFunc = runnerWaitUntilReady + func NewRunnerContainer(ctx context.Context, cfg RunnerContainerConfig, name string) (*RunnerContainer, error) { // Ensure that timeout is set to a non-zero value. 
timeout := cfg.containerTimeout @@ -66,7 +69,7 @@ func NewRunnerContainer(ctx context.Context, cfg RunnerContainerConfig, name str cctx, cancel := context.WithTimeout(ctx, cfg.containerTimeout) defer cancel() - if err := runnerWaitUntilReady(cctx, client, pollingInterval); err != nil { + if err := runnerWaitUntilReadyFunc(cctx, client, pollingInterval); err != nil { return nil, err } diff --git a/worker/docker.go b/worker/docker.go index 443862d1..f6dc49d6 100644 --- a/worker/docker.go +++ b/worker/docker.go @@ -2,20 +2,27 @@ package worker import ( "context" + "encoding/json" "errors" "fmt" + "io" "log/slog" "strings" "sync" "time" "github.com/docker/cli/opts" + "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/filters" + "github.com/docker/docker/api/types/image" "github.com/docker/docker/api/types/mount" + "github.com/docker/docker/api/types/network" docker "github.com/docker/docker/client" "github.com/docker/docker/errdefs" + "github.com/docker/docker/pkg/jsonmessage" "github.com/docker/go-connections/nat" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" ) const containerModelDir = "/models" @@ -27,7 +34,8 @@ const optFlagsContainerTimeout = 5 * time.Minute const containerRemoveTimeout = 30 * time.Second const containerCreatorLabel = "creator" const containerCreator = "ai-worker" -const containerWatchInterval = 10 * time.Second + +var containerWatchInterval = 10 * time.Second // This only works right now on a single GPU because if there is another container // using the GPU we stop it so we don't have to worry about having enough ports @@ -57,12 +65,31 @@ var livePipelineToImage = map[string]string{ "noop": "livepeer/ai-runner:live-app-noop", } +// DockerClient is an interface for the Docker client, allowing for mocking in tests. +// NOTE: ensure any docker.Client methods used in this package are added. 
+type DockerClient interface { + ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *ocispec.Platform, containerName string) (container.CreateResponse, error) + ContainerInspect(ctx context.Context, containerID string) (types.ContainerJSON, error) + ContainerList(ctx context.Context, options container.ListOptions) ([]types.Container, error) + ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error + ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error + ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error + ImageInspectWithRaw(ctx context.Context, imageID string) (types.ImageInspect, []byte, error) + ImagePull(ctx context.Context, ref string, options image.PullOptions) (io.ReadCloser, error) +} + +// Compile-time assertion to ensure docker.Client implements DockerClient. +var _ DockerClient = (*docker.Client)(nil) + +// Create global references to functions to allow for mocking in tests. 
+var dockerWaitUntilRunningFunc = dockerWaitUntilRunning + type DockerManager struct { defaultImage string gpus []string modelDir string - dockerClient *docker.Client + dockerClient DockerClient // gpu ID => container name gpuContainers map[string]string // container name => container @@ -70,28 +97,44 @@ type DockerManager struct { mu *sync.Mutex } -func NewDockerManager(defaultImage string, gpus []string, modelDir string) (*DockerManager, error) { - dockerClient, err := docker.NewClientWithOpts(docker.FromEnv, docker.WithAPIVersionNegotiation()) - if err != nil { - return nil, err - } - +func NewDockerManager(defaultImage string, gpus []string, modelDir string, client DockerClient) (*DockerManager, error) { ctx, cancel := context.WithTimeout(context.Background(), containerTimeout) - if err := removeExistingContainers(ctx, dockerClient); err != nil { + if err := removeExistingContainers(ctx, client); err != nil { cancel() return nil, err } cancel() - return &DockerManager{ + manager := &DockerManager{ defaultImage: defaultImage, gpus: gpus, modelDir: modelDir, - dockerClient: dockerClient, + dockerClient: client, gpuContainers: make(map[string]string), containers: make(map[string]*RunnerContainer), mu: &sync.Mutex{}, - }, nil + } + + return manager, nil +} + +// EnsureImageAvailable ensures the container image is available locally for the given pipeline and model ID. +func (m *DockerManager) EnsureImageAvailable(ctx context.Context, pipeline string, modelID string) error { + imageName, err := m.getContainerImageName(pipeline, modelID) + if err != nil { + return err + } + + // Pull the image if it is not available locally. 
+ if !m.isImageAvailable(ctx, pipeline, modelID) { + slog.Info(fmt.Sprintf("Pulling image for pipeline %s and modelID %s: %s", pipeline, modelID, imageName)) + err = m.pullImage(ctx, imageName) + if err != nil { + return err + } + } + + return nil } func (m *DockerManager) Warm(ctx context.Context, pipeline string, modelID string, optimizationFlags OptimizationFlags) error { @@ -157,6 +200,24 @@ func (m *DockerManager) returnContainer(rc *RunnerContainer) { m.containers[rc.Name] = rc } +// getContainerImageName returns the image name for the given pipeline and model ID. +// Returns an error if the image is not found for "live-video-to-video". +func (m *DockerManager) getContainerImageName(pipeline, modelID string) (string, error) { + if pipeline == "live-video-to-video" { + // We currently use the model ID as the live pipeline name for legacy reasons. + if image, ok := livePipelineToImage[modelID]; ok { + return image, nil + } + return "", fmt.Errorf("no container image found for live pipeline %s", modelID) + } + + if image, ok := pipelineToImage[pipeline]; ok { + return image, nil + } + + return m.defaultImage, nil +} + // HasCapacity checks if an unused managed container exists or if a GPU is available for a new container. func (m *DockerManager) HasCapacity(ctx context.Context, pipeline, modelID string) bool { m.mu.Lock() @@ -169,11 +230,57 @@ func (m *DockerManager) HasCapacity(ctx context.Context, pipeline, modelID strin } } + // TODO: This can be removed if we optimize the selection algorithm. + // Currently, using CreateContainer errors only can cause orchestrator reselection. + if !m.isImageAvailable(ctx, pipeline, modelID) { + return false + } + // Check for available GPU to allocate for a new container for the requested model. _, err := m.allocGPU(ctx) return err == nil } +// isImageAvailable checks if the specified image is available locally. 
+func (m *DockerManager) isImageAvailable(ctx context.Context, pipeline string, modelID string) bool { + imageName, err := m.getContainerImageName(pipeline, modelID) + if err != nil { + slog.Error(err.Error()) + return false + } + + _, _, err = m.dockerClient.ImageInspectWithRaw(ctx, imageName) + if err != nil { + slog.Error(fmt.Sprintf("Image for pipeline %s and modelID %s is not available locally: %s", pipeline, modelID, imageName)) + } + return err == nil +} + +// pullImage pulls the specified image from the registry. +func (m *DockerManager) pullImage(ctx context.Context, imageName string) error { + reader, err := m.dockerClient.ImagePull(ctx, imageName, image.PullOptions{}) + if err != nil { + return fmt.Errorf("failed to pull image: %w", err) + } + defer reader.Close() + + // Display progress messages from ImagePull reader. + decoder := json.NewDecoder(reader) + for { + var progress jsonmessage.JSONMessage + if err := decoder.Decode(&progress); err == io.EOF { + break + } else if err != nil { + return fmt.Errorf("error decoding progress message: %w", err) + } + if progress.Status != "" && progress.Progress != nil { + slog.Info(fmt.Sprintf("%s: %s", progress.Status, progress.Progress.String())) + } + } + + return nil +} + func (m *DockerManager) createContainer(ctx context.Context, pipeline string, modelID string, keepWarm bool, optimizationFlags OptimizationFlags) (*RunnerContainer, error) { gpu, err := m.allocGPU(ctx) if err != nil { @@ -183,15 +290,9 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo // NOTE: We currently allow only one container per GPU for each pipeline. 
containerHostPort := containerHostPorts[pipeline][:3] + gpu containerName := dockerContainerName(pipeline, modelID, containerHostPort) - containerImage := m.defaultImage - if pipelineSpecificImage, ok := pipelineToImage[pipeline]; ok { - containerImage = pipelineSpecificImage - } else if pipeline == "live-video-to-video" { - // We currently use the model ID as the live pipeline name for legacy reasons - containerImage = livePipelineToImage[modelID] - if containerImage == "" { - return nil, fmt.Errorf("no container image found for live pipeline %s", modelID) - } + containerImage, err := m.getContainerImageName(pipeline, modelID) + if err != nil { + return nil, err } slog.Info("Starting managed container", slog.String("gpu", gpu), slog.String("name", containerName), slog.String("modelID", modelID), slog.String("containerImage", containerImage)) @@ -258,7 +359,7 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo cancel() cctx, cancel = context.WithTimeout(ctx, containerTimeout) - if err := dockerWaitUntilRunning(cctx, m.dockerClient, resp.ID, pollingInterval); err != nil { + if err := dockerWaitUntilRunningFunc(cctx, m.dockerClient, resp.ID, pollingInterval); err != nil { cancel() dockerRemoveContainer(m.dockerClient, resp.ID) return nil, err @@ -390,7 +491,7 @@ func (m *DockerManager) watchContainer(rc *RunnerContainer, borrowCtx context.Co } } -func removeExistingContainers(ctx context.Context, client *docker.Client) error { +func removeExistingContainers(ctx context.Context, client DockerClient) error { filters := filters.NewArgs(filters.Arg("label", containerCreatorLabel+"="+containerCreator)) containers, err := client.ContainerList(ctx, container.ListOptions{All: true, Filters: filters}) if err != nil { @@ -416,7 +517,7 @@ func dockerContainerName(pipeline string, modelID string, suffix ...string) stri return fmt.Sprintf("%s_%s", pipeline, sanitizedModelID) } -func dockerRemoveContainer(client *docker.Client, containerID string) 
error { +func dockerRemoveContainer(client DockerClient, containerID string) error { ctx, cancel := context.WithTimeout(context.Background(), containerRemoveTimeout) defer cancel() @@ -449,7 +550,7 @@ func dockerRemoveContainer(client *docker.Client, containerID string) error { } } -func dockerWaitUntilRunning(ctx context.Context, client *docker.Client, containerID string, pollingInterval time.Duration) error { +func dockerWaitUntilRunning(ctx context.Context, client DockerClient, containerID string, pollingInterval time.Duration) error { ticker := time.NewTicker(pollingInterval) defer ticker.Stop() diff --git a/worker/docker_test.go b/worker/docker_test.go new file mode 100644 index 00000000..cbb20086 --- /dev/null +++ b/worker/docker_test.go @@ -0,0 +1,786 @@ +package worker + +import ( + "context" + "fmt" + "io" + "strings" + "sync" + "testing" + "time" + + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/image" + "github.com/docker/docker/api/types/network" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +type MockDockerClient struct { + mock.Mock +} + +func (m *MockDockerClient) ImagePull(ctx context.Context, ref string, options image.PullOptions) (io.ReadCloser, error) { + args := m.Called(ctx, ref, options) + return args.Get(0).(io.ReadCloser), args.Error(1) +} + +func (m *MockDockerClient) ImageInspectWithRaw(ctx context.Context, imageID string) (types.ImageInspect, []byte, error) { + args := m.Called(ctx, imageID) + return args.Get(0).(types.ImageInspect), args.Get(1).([]byte), args.Error(2) +} + +func (m *MockDockerClient) ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *ocispec.Platform, containerName string) (container.CreateResponse, error) { + args := m.Called(ctx, config, hostConfig, 
networkingConfig, platform, containerName) + return args.Get(0).(container.CreateResponse), args.Error(1) +} + +func (m *MockDockerClient) ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error { + args := m.Called(ctx, containerID, options) + return args.Error(0) +} + +func (m *MockDockerClient) ContainerInspect(ctx context.Context, containerID string) (types.ContainerJSON, error) { + args := m.Called(ctx, containerID) + return args.Get(0).(types.ContainerJSON), args.Error(1) +} + +func (m *MockDockerClient) ContainerList(ctx context.Context, options container.ListOptions) ([]types.Container, error) { + args := m.Called(ctx, options) + return args.Get(0).([]types.Container), args.Error(1) +} + +func (m *MockDockerClient) ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error { + args := m.Called(ctx, containerID, options) + return args.Error(0) +} + +func (m *MockDockerClient) ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error { + args := m.Called(ctx, containerID, options) + return args.Error(0) +} + +// createDockerManager creates a DockerManager with a mock DockerClient. 
+func createDockerManager(mockDockerClient *MockDockerClient) *DockerManager { + return &DockerManager{ + defaultImage: "default-image", + gpus: []string{"gpu0"}, + modelDir: "/models", + dockerClient: mockDockerClient, + gpuContainers: make(map[string]string), + containers: make(map[string]*RunnerContainer), + mu: &sync.Mutex{}, + } +} + +func TestNewDockerManager(t *testing.T) { + mockDockerClient := new(MockDockerClient) + + createAndVerifyManager := func() *DockerManager { + manager, err := NewDockerManager("default-image", []string{"gpu0"}, "/models", mockDockerClient) + require.NoError(t, err) + require.NotNil(t, manager) + require.Equal(t, "default-image", manager.defaultImage) + require.Equal(t, []string{"gpu0"}, manager.gpus) + require.Equal(t, "/models", manager.modelDir) + require.Equal(t, mockDockerClient, manager.dockerClient) + return manager + } + + t.Run("NoExistingContainers", func(t *testing.T) { + mockDockerClient.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{}, nil).Once() + createAndVerifyManager() + mockDockerClient.AssertNotCalled(t, "ContainerStop", mock.Anything, mock.Anything, mock.Anything) + mockDockerClient.AssertNotCalled(t, "ContainerRemove", mock.Anything, mock.Anything, mock.Anything) + mockDockerClient.AssertExpectations(t) + }) + + t.Run("ExistingContainers", func(t *testing.T) { + // Mock client methods to simulate the removal of existing containers. 
+ existingContainers := []types.Container{ + {ID: "container1", Names: []string{"/container1"}}, + {ID: "container2", Names: []string{"/container2"}}, + } + mockDockerClient.On("ContainerList", mock.Anything, mock.Anything).Return(existingContainers, nil) + mockDockerClient.On("ContainerStop", mock.Anything, "container1", mock.Anything).Return(nil) + mockDockerClient.On("ContainerStop", mock.Anything, "container2", mock.Anything).Return(nil) + mockDockerClient.On("ContainerRemove", mock.Anything, "container1", mock.Anything).Return(nil) + mockDockerClient.On("ContainerRemove", mock.Anything, "container2", mock.Anything).Return(nil) + + // Verify that existing containers were stopped and removed. + createAndVerifyManager() + mockDockerClient.AssertCalled(t, "ContainerStop", mock.Anything, "container1", mock.Anything) + mockDockerClient.AssertCalled(t, "ContainerStop", mock.Anything, "container2", mock.Anything) + mockDockerClient.AssertCalled(t, "ContainerRemove", mock.Anything, "container1", mock.Anything) + mockDockerClient.AssertCalled(t, "ContainerRemove", mock.Anything, "container2", mock.Anything) + mockDockerClient.AssertExpectations(t) + }) +} + +func TestDockerManager_EnsureImageAvailable(t *testing.T) { + mockDockerClient := new(MockDockerClient) + dockerManager := createDockerManager(mockDockerClient) + + ctx := context.Background() + pipeline := "text-to-image" + modelID := "test-model" + + tests := []struct { + name string + setup func(*DockerManager, *MockDockerClient) + expectedPull bool + }{ + { + name: "ImageAvailable", + setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) { + // Mock client methods to simulate the image being available locally. 
+ mockDockerClient.On("ImageInspectWithRaw", mock.Anything, "default-image").Return(types.ImageInspect{}, []byte{}, nil).Once() + }, + expectedPull: false, + }, + { + name: "ImageNotAvailable", + setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) { + // Mock client methods to simulate the image not being available locally. + mockDockerClient.On("ImageInspectWithRaw", mock.Anything, "default-image").Return(types.ImageInspect{}, []byte{}, fmt.Errorf("image not found")).Once() + }, + expectedPull: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setup(dockerManager, mockDockerClient) + + if tt.expectedPull { + mockDockerClient.On("ImagePull", mock.Anything, "default-image", mock.Anything).Return(io.NopCloser(strings.NewReader("")), nil).Once() + } + + err := dockerManager.EnsureImageAvailable(ctx, pipeline, modelID) + require.NoError(t, err) + + mockDockerClient.AssertExpectations(t) + }) + } +} + +func TestDockerManager_Warm(t *testing.T) { + mockDockerClient := new(MockDockerClient) + dockerManager := createDockerManager(mockDockerClient) + + ctx := context.Background() + pipeline := "text-to-image" + modelID := "test-model" + containerID := "container1" + optimizationFlags := OptimizationFlags{} + + // Mock nested functions. 
+ originalFunc := dockerWaitUntilRunningFunc + dockerWaitUntilRunningFunc = func(ctx context.Context, client DockerClient, containerID string, pollingInterval time.Duration) error { + return nil + } + defer func() { dockerWaitUntilRunningFunc = originalFunc }() + originalFunc2 := runnerWaitUntilReadyFunc + runnerWaitUntilReadyFunc = func(ctx context.Context, client *ClientWithResponses, pollingInterval time.Duration) error { + return nil + } + defer func() { runnerWaitUntilReadyFunc = originalFunc2 }() + + mockDockerClient.On("ContainerCreate", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(container.CreateResponse{ID: containerID}, nil) + mockDockerClient.On("ContainerStart", mock.Anything, containerID, mock.Anything).Return(nil) + err := dockerManager.Warm(ctx, pipeline, modelID, optimizationFlags) + require.NoError(t, err) + mockDockerClient.AssertExpectations(t) +} + +func TestDockerManager_Stop(t *testing.T) { + MockDockerClient := new(MockDockerClient) + dockerManager := createDockerManager(MockDockerClient) + + ctx, cancel := context.WithTimeout(context.Background(), containerRemoveTimeout) + defer cancel() + containerID := "container1" + dockerManager.containers[containerID] = &RunnerContainer{ + RunnerContainerConfig: RunnerContainerConfig{ + ID: containerID, + }, + } + + MockDockerClient.On("ContainerStop", mock.Anything, containerID, container.StopOptions{Timeout: nil}).Return(nil) + MockDockerClient.On("ContainerRemove", mock.Anything, containerID, container.RemoveOptions{}).Return(nil) + err := dockerManager.Stop(ctx) + require.NoError(t, err) + MockDockerClient.AssertExpectations(t) +} + +func TestDockerManager_Borrow(t *testing.T) { + mockDockerClient := new(MockDockerClient) + dockerManager := createDockerManager(mockDockerClient) + + ctx := context.Background() + pipeline := "text-to-image" + modelID := "model" + containerID, _ := dockerManager.getContainerImageName(pipeline, modelID) + + // Mock 
nested functions. + originalFunc := dockerWaitUntilRunningFunc + dockerWaitUntilRunningFunc = func(ctx context.Context, client DockerClient, containerID string, pollingInterval time.Duration) error { + return nil + } + defer func() { dockerWaitUntilRunningFunc = originalFunc }() + originalFunc2 := runnerWaitUntilReadyFunc + runnerWaitUntilReadyFunc = func(ctx context.Context, client *ClientWithResponses, pollingInterval time.Duration) error { + return nil + } + defer func() { runnerWaitUntilReadyFunc = originalFunc2 }() + + mockDockerClient.On("ContainerCreate", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(container.CreateResponse{ID: containerID}, nil) + mockDockerClient.On("ContainerStart", mock.Anything, containerID, mock.Anything).Return(nil) + rc, err := dockerManager.Borrow(ctx, pipeline, modelID) + require.NoError(t, err) + require.NotNil(t, rc) + require.Empty(t, dockerManager.containers, "containers map should be empty") + mockDockerClient.AssertExpectations(t) +} + +func TestDockerManager_returnContainer(t *testing.T) { + mockDockerClient := new(MockDockerClient) + dockerManager := createDockerManager(mockDockerClient) + + // Create a RunnerContainer to return to the pool + rc := &RunnerContainer{ + Name: "container1", + RunnerContainerConfig: RunnerContainerConfig{}, + } + + // Ensure the container is not in the pool initially. + _, exists := dockerManager.containers[rc.Name] + require.False(t, exists) + + // Return the container to the pool. + dockerManager.returnContainer(rc) + + // Verify the container is now in the pool. 
+ returnedContainer, exists := dockerManager.containers[rc.Name] + require.True(t, exists) + require.Equal(t, rc, returnedContainer) +} + +func TestDockerManager_getContainerImageName(t *testing.T) { + mockDockerClient := new(MockDockerClient) + manager := createDockerManager(mockDockerClient) + + tests := []struct { + name string + pipeline string + modelID string + expectedImage string + expectError bool + }{ + { + name: "live-video-to-video with valid modelID", + pipeline: "live-video-to-video", + modelID: "streamdiffusion", + expectedImage: "livepeer/ai-runner:live-app-streamdiffusion", + expectError: false, + }, + { + name: "live-video-to-video with invalid modelID", + pipeline: "live-video-to-video", + modelID: "invalid-model", + expectError: true, + }, + { + name: "valid pipeline", + pipeline: "text-to-speech", + modelID: "", + expectedImage: "livepeer/ai-runner:text-to-speech", + expectError: false, + }, + { + name: "invalid pipeline", + pipeline: "invalid-pipeline", + modelID: "", + expectedImage: "default-image", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + image, err := manager.getContainerImageName(tt.pipeline, tt.modelID) + if tt.expectError { + require.Error(t, err) + require.Equal(t, fmt.Sprintf("no container image found for live pipeline %s", tt.modelID), err.Error()) + } else { + require.NoError(t, err) + require.Equal(t, tt.expectedImage, image) + } + }) + } +} + +func TestDockerManager_HasCapacity(t *testing.T) { + ctx := context.Background() + pipeline := "text-to-image" + modelID := "test-model" + + tests := []struct { + name string + setup func(*DockerManager, *MockDockerClient) + expectedHasCapacity bool + }{ + { + name: "UnusedManagedContainerExists", + setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) { + // Add an unused managed container. 
+ dockerManager.containers["container1"] = &RunnerContainer{ + RunnerContainerConfig: RunnerContainerConfig{ + Pipeline: pipeline, + ModelID: modelID, + }} + }, + expectedHasCapacity: true, + }, + { + name: "ImageNotAvailable", + setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) { + // Mock client methods to simulate the image not being available locally. + mockDockerClient.On("ImageInspectWithRaw", mock.Anything, "default-image").Return(types.ImageInspect{}, []byte{}, fmt.Errorf("image not found")) + }, + expectedHasCapacity: false, + }, + { + name: "GPUAvailable", + setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) { + // Mock client methods to simulate the image being available locally. + mockDockerClient.On("ImageInspectWithRaw", mock.Anything, "default-image").Return(types.ImageInspect{}, []byte{}, nil) + // Ensure that the GPU is available by not setting any container for the GPU. + dockerManager.gpuContainers = make(map[string]string) + }, + expectedHasCapacity: true, + }, + { + name: "GPUUnavailable", + setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) { + // Mock client methods to simulate the image being available locally. + mockDockerClient.On("ImageInspectWithRaw", mock.Anything, "default-image").Return(types.ImageInspect{}, []byte{}, nil) + // Ensure that the GPU is not available by setting a container for the GPU. 
+ dockerManager.gpuContainers["gpu0"] = "container1" + }, + expectedHasCapacity: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockDockerClient := new(MockDockerClient) + dockerManager := createDockerManager(mockDockerClient) + + tt.setup(dockerManager, mockDockerClient) + + hasCapacity := dockerManager.HasCapacity(ctx, pipeline, modelID) + require.Equal(t, tt.expectedHasCapacity, hasCapacity) + + mockDockerClient.AssertExpectations(t) + }) + } +} + +func TestDockerManager_isImageAvailable(t *testing.T) { + mockDockerClient := new(MockDockerClient) + dockerManager := createDockerManager(mockDockerClient) + + ctx := context.Background() + pipeline := "text-to-image" + modelID := "test-model" + + t.Run("ImageNotFound", func(t *testing.T) { + mockDockerClient.On("ImageInspectWithRaw", mock.Anything, "default-image").Return(types.ImageInspect{}, []byte{}, fmt.Errorf("image not found")).Once() + + isAvailable := dockerManager.isImageAvailable(ctx, pipeline, modelID) + require.False(t, isAvailable) + mockDockerClient.AssertExpectations(t) + }) + + t.Run("ImageFound", func(t *testing.T) { + mockDockerClient.On("ImageInspectWithRaw", mock.Anything, "default-image").Return(types.ImageInspect{}, []byte{}, nil).Once() + + isAvailable := dockerManager.isImageAvailable(ctx, pipeline, modelID) + require.True(t, isAvailable) + mockDockerClient.AssertExpectations(t) + }) +} + +func TestDockerManager_pullImage(t *testing.T) { + mockDockerClient := new(MockDockerClient) + dockerManager := createDockerManager(mockDockerClient) + + ctx := context.Background() + imageName := "default-image" + + t.Run("ImagePullError", func(t *testing.T) { + mockDockerClient.On("ImagePull", mock.Anything, imageName, mock.Anything).Return(io.NopCloser(strings.NewReader("")), fmt.Errorf("failed to pull image: pull error")).Once() + + err := dockerManager.pullImage(ctx, imageName) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to pull image: pull 
error") + mockDockerClient.AssertExpectations(t) + }) + + t.Run("ImagePullSuccess", func(t *testing.T) { + mockDockerClient.On("ImagePull", mock.Anything, imageName, mock.Anything).Return(io.NopCloser(strings.NewReader("")), nil).Once() + + err := dockerManager.pullImage(ctx, imageName) + require.NoError(t, err) + mockDockerClient.AssertExpectations(t) + }) +} + +func TestDockerManager_createContainer(t *testing.T) { + mockDockerClient := new(MockDockerClient) + dockerManager := createDockerManager(mockDockerClient) + + ctx := context.Background() + pipeline := "text-to-image" + modelID := "test-model" + containerID := "container1" + gpu := "0" + containerHostPort := "8000" + containerName := dockerContainerName(pipeline, modelID, containerHostPort) + containerImage := "default-image" + optimizationFlags := OptimizationFlags{} + + // Mock nested functions. + originalFunc := dockerWaitUntilRunningFunc + dockerWaitUntilRunningFunc = func(ctx context.Context, client DockerClient, containerID string, pollingInterval time.Duration) error { + return nil + } + defer func() { dockerWaitUntilRunningFunc = originalFunc }() + originalFunc2 := runnerWaitUntilReadyFunc + runnerWaitUntilReadyFunc = func(ctx context.Context, client *ClientWithResponses, pollingInterval time.Duration) error { + return nil + } + defer func() { runnerWaitUntilReadyFunc = originalFunc2 }() + + // Mock allocGPU and getContainerImageName methods. 
+ dockerManager.gpus = []string{gpu} + dockerManager.gpuContainers = make(map[string]string) + dockerManager.containers = make(map[string]*RunnerContainer) + dockerManager.defaultImage = containerImage + + mockDockerClient.On("ContainerCreate", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(container.CreateResponse{ID: containerID}, nil) + mockDockerClient.On("ContainerStart", mock.Anything, containerID, mock.Anything).Return(nil) + + rc, err := dockerManager.createContainer(ctx, pipeline, modelID, false, optimizationFlags) + require.NoError(t, err) + require.NotNil(t, rc) + require.Equal(t, containerID, rc.ID) + require.Equal(t, gpu, rc.GPU) + require.Equal(t, pipeline, rc.Pipeline) + require.Equal(t, modelID, rc.ModelID) + require.Equal(t, containerName, rc.Name) + + mockDockerClient.AssertExpectations(t) +} + +func TestDockerManager_allocGPU(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + setup func(*DockerManager, *MockDockerClient) + expectedAllocatedGPU string + errorMessage string + }{ + { + name: "GPUAvailable", + setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) { + // Ensure that the GPU is available by not setting any container for the GPU. + dockerManager.gpuContainers = make(map[string]string) + }, + expectedAllocatedGPU: "gpu0", + errorMessage: "", + }, + { + name: "GPUUnavailable", + setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) { + // Ensure that the GPU is not available by setting a container for the GPU. + dockerManager.gpuContainers["gpu0"] = "container1" + }, + expectedAllocatedGPU: "", + errorMessage: "insufficient capacity", + }, + { + name: "GPUUnavailableAndWarm", + setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) { + // Ensure that the GPU is not available by setting a container for the GPU. 
+ dockerManager.gpuContainers["gpu0"] = "container1"
+ dockerManager.containers["container1"] = &RunnerContainer{
+ RunnerContainerConfig: RunnerContainerConfig{
+ ID: "container1",
+ KeepWarm: true,
+ },
+ }
+ },
+ expectedAllocatedGPU: "",
+ errorMessage: "insufficient capacity",
+ },
+ {
+ name: "GPUUnavailableButCold",
+ setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {
+ // Ensure that the GPU is not available by setting a container for the GPU.
+ dockerManager.gpuContainers["gpu0"] = "container1"
+ dockerManager.containers["container1"] = &RunnerContainer{
+ RunnerContainerConfig: RunnerContainerConfig{
+ ID: "container1",
+ KeepWarm: false,
+ },
+ }
+ // Mock client methods to simulate the removal of the cold (non-warm) container.
+ mockDockerClient.On("ContainerStop", mock.Anything, "container1", container.StopOptions{}).Return(nil)
+ mockDockerClient.On("ContainerRemove", mock.Anything, "container1", container.RemoveOptions{}).Return(nil)
+ },
+ expectedAllocatedGPU: "gpu0",
+ errorMessage: "",
+ },
+ }
+
+ for _, tt := range tests {
+ mockDockerClient := new(MockDockerClient)
+ dockerManager := createDockerManager(mockDockerClient)
+
+ tt.setup(dockerManager, mockDockerClient)
+
+ gpu, err := dockerManager.allocGPU(ctx)
+ if tt.errorMessage != "" {
+ require.Error(t, err)
+ require.Contains(t, err.Error(), tt.errorMessage)
+ } else {
+ require.NoError(t, err)
+ require.Equal(t, tt.expectedAllocatedGPU, gpu)
+ }
+ }
+}
+
+func TestDockerManager_destroyContainer(t *testing.T) {
+ mockDockerClient := new(MockDockerClient)
+ dockerManager := createDockerManager(mockDockerClient)
+
+ containerID := "container1"
+ gpu := "gpu0"
+
+ rc := &RunnerContainer{
+ Name: containerID,
+ RunnerContainerConfig: RunnerContainerConfig{
+ ID: containerID,
+ GPU: gpu,
+ },
+ }
+ dockerManager.gpuContainers[gpu] = containerID
+ dockerManager.containers[containerID] = rc
+
+ mockDockerClient.On("ContainerStop", mock.Anything, containerID, 
container.StopOptions{}).Return(nil) + mockDockerClient.On("ContainerRemove", mock.Anything, containerID, container.RemoveOptions{}).Return(nil) + + err := dockerManager.destroyContainer(rc, true) + require.NoError(t, err) + require.Empty(t, dockerManager.gpuContainers, "gpuContainers map should be empty") + require.Empty(t, dockerManager.containers, "containers map should be empty") + mockDockerClient.AssertExpectations(t) +} + +func TestDockerManager_watchContainer(t *testing.T) { + mockDockerClient := new(MockDockerClient) + dockerManager := createDockerManager(mockDockerClient) + + // Override the containerWatchInterval for testing purposes. + containerWatchInterval = 10 * time.Millisecond + + containerID := "container1" + rc := &RunnerContainer{ + Name: containerID, + RunnerContainerConfig: RunnerContainerConfig{ + ID: containerID, + }, + } + + t.Run("ReturnContainerOnContextDone", func(t *testing.T) { + borrowCtx, cancel := context.WithCancel(context.Background()) + + go dockerManager.watchContainer(rc, borrowCtx) + cancel() // Cancel the context. + time.Sleep(50 * time.Millisecond) // Ensure the ticker triggers. + + // Verify that the container was returned. + _, exists := dockerManager.containers[rc.Name] + require.True(t, exists) + }) + + t.Run("DestroyContainerOnNotRunning", func(t *testing.T) { + borrowCtx := context.Background() + + // Mock ContainerInspect to return a non-running state. + mockDockerClient.On("ContainerInspect", mock.Anything, containerID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Running: false, + }, + }, + }, nil).Once() + + // Mock destroyContainer to verify it is called. 
+ mockDockerClient.On("ContainerStop", mock.Anything, containerID, mock.Anything).Return(nil) + mockDockerClient.On("ContainerRemove", mock.Anything, containerID, mock.Anything).Return(nil) + + go dockerManager.watchContainer(rc, borrowCtx) + time.Sleep(50 * time.Millisecond) // Ensure the ticker triggers. + + // Verify that the container was destroyed. + _, exists := dockerManager.containers[rc.Name] + require.False(t, exists) + }) +} + +// Watch container + +func TestRemoveExistingContainers(t *testing.T) { + mockDockerClient := new(MockDockerClient) + + ctx := context.Background() + + // Mock client methods to simulate the removal of existing containers. + existingContainers := []types.Container{ + {ID: "container1", Names: []string{"/container1"}}, + {ID: "container2", Names: []string{"/container2"}}, + } + mockDockerClient.On("ContainerList", mock.Anything, mock.Anything).Return(existingContainers, nil) + mockDockerClient.On("ContainerStop", mock.Anything, "container1", mock.Anything).Return(nil) + mockDockerClient.On("ContainerStop", mock.Anything, "container2", mock.Anything).Return(nil) + mockDockerClient.On("ContainerRemove", mock.Anything, "container1", mock.Anything).Return(nil) + mockDockerClient.On("ContainerRemove", mock.Anything, "container2", mock.Anything).Return(nil) + + removeExistingContainers(ctx, mockDockerClient) + mockDockerClient.AssertExpectations(t) +} + +func TestDockerContainerName(t *testing.T) { + tests := []struct { + name string + pipeline string + modelID string + suffix []string + expectedName string + }{ + { + name: "with suffix", + pipeline: "text-to-speech", + modelID: "model1", + suffix: []string{"suffix1"}, + expectedName: "text-to-speech_model1_suffix1", + }, + { + name: "without suffix", + pipeline: "text-to-speech", + modelID: "model1", + expectedName: "text-to-speech_model1", + }, + { + name: "modelID with special characters", + pipeline: "text-to-speech", + modelID: "model/1_2", + suffix: []string{"suffix1"}, + 
expectedName: "text-to-speech_model-1-2_suffix1", + }, + { + name: "modelID with special characters without suffix", + pipeline: "text-to-speech", + modelID: "model/1_2", + expectedName: "text-to-speech_model-1-2", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + name := dockerContainerName(tt.pipeline, tt.modelID, tt.suffix...) + require.Equal(t, tt.expectedName, name) + }) + } +} + +func TestDockerRemoveContainer(t *testing.T) { + mockDockerClient := new(MockDockerClient) + + mockDockerClient.On("ContainerStop", mock.Anything, "container1", container.StopOptions{}).Return(nil) + mockDockerClient.On("ContainerRemove", mock.Anything, "container1", container.RemoveOptions{}).Return(nil) + + err := dockerRemoveContainer(mockDockerClient, "container1") + require.NoError(t, err) + mockDockerClient.AssertExpectations(t) +} + +func TestDockerWaitUntilRunning(t *testing.T) { + mockDockerClient := new(MockDockerClient) + containerID := "container1" + pollingInterval := 10 * time.Millisecond + ctx := context.Background() + + t.Run("ContainerRunning", func(t *testing.T) { + // Mock ContainerInspect to return a running container state. + mockDockerClient.On("ContainerInspect", mock.Anything, containerID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Running: true, + }, + }, + }, nil).Once() + + err := dockerWaitUntilRunning(ctx, mockDockerClient, containerID, pollingInterval) + require.NoError(t, err) + mockDockerClient.AssertExpectations(t) + }) + + t.Run("ContainerNotRunningInitially", func(t *testing.T) { + // Mock ContainerInspect to return a non-running state initially, then a running state. 
+ mockDockerClient.On("ContainerInspect", mock.Anything, containerID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Running: false, + }, + }, + }, nil).Once() + mockDockerClient.On("ContainerInspect", mock.Anything, containerID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Running: true, + }, + }, + }, nil).Once() + + err := dockerWaitUntilRunning(ctx, mockDockerClient, containerID, pollingInterval) + require.NoError(t, err) + mockDockerClient.AssertExpectations(t) + }) + + t.Run("ContextTimeout", func(t *testing.T) { + // Create a context that will timeout. + timeoutCtx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + // Mock ContainerInspect to always return a non-running state. + mockDockerClient.On("ContainerInspect", mock.Anything, containerID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Running: false, + }, + }, + }, nil) + + err := dockerWaitUntilRunning(timeoutCtx, mockDockerClient, containerID, pollingInterval) + require.Error(t, err) + require.Contains(t, err.Error(), "timed out waiting for managed container") + mockDockerClient.AssertExpectations(t) + }) +} diff --git a/worker/worker.go b/worker/worker.go index f9a6457c..e381e4be 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -12,6 +12,8 @@ import ( "strconv" "strings" "sync" + + docker "github.com/docker/docker/client" ) // EnvValue unmarshals JSON booleans as strings for compatibility with env variables. 
@@ -50,7 +52,12 @@ type Worker struct { } func NewWorker(defaultImage string, gpus []string, modelDir string) (*Worker, error) { - manager, err := NewDockerManager(defaultImage, gpus, modelDir) + dockerClient, err := docker.NewClientWithOpts(docker.FromEnv, docker.WithAPIVersionNegotiation()) + if err != nil { + return nil, err + } + + manager, err := NewDockerManager(defaultImage, gpus, modelDir, dockerClient) if err != nil { return nil, err } @@ -652,6 +659,10 @@ func (w *Worker) LiveVideoToVideo(ctx context.Context, req GenLiveVideoToVideoJS return resp.JSON200, nil } +func (w *Worker) EnsureImageAvailable(ctx context.Context, pipeline string, modelID string) error { + return w.manager.EnsureImageAvailable(ctx, pipeline, modelID) +} + func (w *Worker) Warm(ctx context.Context, pipeline string, modelID string, endpoint RunnerEndpoint, optimizationFlags OptimizationFlags) error { if endpoint.URL == "" { return w.manager.Warm(ctx, pipeline, modelID, optimizationFlags)