Skip to content

Commit

Permalink
cmd+worker: Impl ImageToVideo
Browse files Browse the repository at this point in the history
  • Loading branch information
yondonfu committed Jan 17, 2024
1 parent ad6b5bd commit dc36f66
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 2 deletions.
106 changes: 106 additions & 0 deletions cmd/examples/image-to-video/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package main

import (
"context"
"log/slog"
"os"
"path"
"path/filepath"
"strconv"
"time"

"github.com/livepeer/ai-worker/worker"
"github.com/oapi-codegen/runtime/types"
)

func main() {
containerName := "image-to-video"
baseOutputPath := "output"

containerImageID := "runner"
gpus := "all"

modelDir, err := filepath.Abs("runner/models")
if err != nil {
slog.Error("Error getting absolute path for modelDir", slog.String("error", err.Error()))
return
}

modelID := "stabilityai/stable-video-diffusion-img2vid-xt"

w, err := worker.NewWorker(containerImageID, gpus, modelDir)
if err != nil {
slog.Error("Error creating worker", slog.String("error", err.Error()))
return
}

slog.Info("Warming container")

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

if err := w.Warm(ctx, containerName, modelID); err != nil {
slog.Error("Error warming container", slog.String("error", err.Error()))
return
}

slog.Info("Warm container is up")

args := os.Args[1:]
runs, err := strconv.Atoi(args[0])
if err != nil {
slog.Error("Invalid runs arg", slog.String("error", err.Error()))
return
}

imagePath := args[1]

imageBytes, err := os.ReadFile(imagePath)
if err != nil {
slog.Error("Error reading image", slog.String("imagePath", imagePath))
return
}
imageFile := types.File{}
imageFile.InitFromBytes(imageBytes, imagePath)

req := worker.ImageToVideoMultipartRequestBody{
Image: imageFile,
ModelId: &modelID,
}

for i := 0; i < runs; i++ {
slog.Info("Running image-to-video", slog.Int("num", i))

resp, err := w.ImageToVideo(ctx, req)
if err != nil {
slog.Error("Error running image-to-video", slog.String("error", err.Error()))
return
}

for j, batch := range resp.Frames {
dirPath := path.Join(baseOutputPath, strconv.Itoa(i)+"_"+strconv.Itoa(j))

for frameNum, media := range batch {
if err := os.MkdirAll(dirPath, os.ModePerm); err != nil {
slog.Error("Error creating dir", slog.String("dir", dirPath))
return
}

outputPath := path.Join(dirPath, strconv.Itoa(frameNum)+".png")
if err := worker.SaveImageB64DataUrl(media.Url, outputPath); err != nil {
slog.Error("Error saving b64 data url as image", slog.String("error", err.Error()))
return
}
}
slog.Info("Outputs written", slog.String("dirPath", dirPath))
}
}

slog.Info("Sleeping 2 seconds and then stopping container")

time.Sleep(2 * time.Second)

w.Stop(ctx, containerName)

time.Sleep(1 * time.Second)
}
46 changes: 44 additions & 2 deletions worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,50 @@ func (w *Worker) ImageToImage(ctx context.Context, req ImageToImageMultipartRequ
return resp.JSON200, nil
}

func (w *Worker) ImageToVideo(ctx context.Context, req ImageToVideoMultipartRequestBody) ([]string, error) {
return nil, nil
func (w *Worker) ImageToVideo(ctx context.Context, req ImageToVideoMultipartRequestBody) (*VideoResponse, error) {
c, err := w.getWarmContainer(ctx, "image-to-video", *req.ModelId)
if err != nil {
return nil, err
}

var buf bytes.Buffer
mw := multipart.NewWriter(&buf)
writer, err := mw.CreateFormFile("image", req.Image.Filename())
if err != nil {
return nil, err
}
imageSize := req.Image.FileSize()
imageRdr, err := req.Image.Reader()
if err != nil {
return nil, err
}
copied, err := io.Copy(writer, imageRdr)
if err != nil {
return nil, err
}
if copied != imageSize {
return nil, fmt.Errorf("failed to copy image to multipart request imageBytes=%v copiedBytes=%v", imageSize, copied)
}

if err := mw.WriteField("model_id", *req.ModelId); err != nil {
return nil, err
}

if err := mw.Close(); err != nil {
return nil, err
}

resp, err := c.Client.ImageToVideoWithBodyWithResponse(ctx, mw.FormDataContentType(), &buf)
if err != nil {
return nil, err
}

if resp.JSON422 != nil {
// TODO: Handle JSON422 struct
return nil, errors.New("image-to-video container returned 422")
}

return resp.JSON200, nil
}

func (w *Worker) Warm(ctx context.Context, containerName, modelID string) error {
Expand Down

0 comments on commit dc36f66

Please sign in to comment.