diff --git a/cmd/examples/sketch-to-image/main.go b/cmd/examples/sketch-to-image/main.go new file mode 100644 index 00000000..18b75aff --- /dev/null +++ b/cmd/examples/sketch-to-image/main.go @@ -0,0 +1,105 @@ +// Package main provides a small example on how to run the 'sketch-to-image' pipeline using the AI worker package. +package main + +import ( + "context" + "flag" + "log/slog" + "os" + "path" + "path/filepath" + "strconv" + "time" + + "github.com/livepeer/ai-worker/worker" + "github.com/oapi-codegen/runtime/types" +) + +func main() { + aiModelsDir := flag.String("aiModelsDir", "runner/models", "path to the models directory") + flag.Parse() + + containerName := "sketch-to-image" + baseOutputPath := "output" + + containerImageID := "livepeer/ai-runner:latest" + gpus := []string{"0"} + + modelsDir, err := filepath.Abs(*aiModelsDir) + if err != nil { + slog.Error("Error getting absolute path for 'aiModelsDir'", slog.String("error", err.Error())) + return + } + + modelID := "xinsir/controlnet-scribble-sdxl-1.0" + + w, err := worker.NewWorker(containerImageID, gpus, modelsDir) + if err != nil { + slog.Error("Error creating worker", slog.String("error", err.Error())) + return + } + + slog.Info("Warming container") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if err := w.Warm(ctx, containerName, modelID, worker.RunnerEndpoint{}, worker.OptimizationFlags{}); err != nil { + slog.Error("Error warming container", slog.String("error", err.Error())) + return + } + + slog.Info("Warm container is up") + + args := os.Args[1:] + runs, err := strconv.Atoi(args[0]) + if err != nil { + slog.Error("Invalid runs arg", slog.String("error", err.Error())) + return + } + + prompt := args[1] + imagePath := args[2] + + imageBytes, err := os.ReadFile(imagePath) + if err != nil { + slog.Error("Error reading image", slog.String("imagePath", imagePath)) + return + } + imageFile := types.File{} + imageFile.InitFromBytes(imageBytes, imagePath) + + req := worker.GenSketchToImageMultipartRequestBody{ + Image: imageFile, + ModelId: &modelID, + Prompt: prompt, + } + + for i := 0; i < runs; i++ { + slog.Info("Running sketch-to-image", slog.Int("num", i)) + + resp, err := w.SketchToImage(ctx, req) + if err != nil { + slog.Error("Error running sketch-to-image", slog.String("error", err.Error())) + return + } + + for j, media := range resp.Images { + outputPath := path.Join(baseOutputPath, strconv.Itoa(i)+"_"+strconv.Itoa(j)+".png") + if err := worker.SaveImageB64DataUrl(media.Url, outputPath); err != nil { + slog.Error("Error saving b64 data url as image", slog.String("error", err.Error())) + return + } + + slog.Info("Output written", slog.String("outputPath", outputPath)) + } + } + + slog.Info("Sleeping 2 seconds and then stopping container") + + time.Sleep(2 * time.Second) + + w.Stop(ctx) + + time.Sleep(1 * time.Second) +} diff --git a/runner/app/main.py b/runner/app/main.py index 52668990..c67799e1 100644 --- a/runner/app/main.py +++ b/runner/app/main.py @@ -56,11 +56,16 @@ def load_pipeline(pipeline: str, model_id: str) -> any: return SegmentAnything2Pipeline(model_id) case "llm": from app.pipelines.llm import LLMPipeline + return LLMPipeline(model_id) case "image-to-text": from app.pipelines.image_to_text import ImageToTextPipeline return ImageToTextPipeline(model_id) + case "sketch-to-image": + from app.pipelines.sketch_to_image import SketchToImagePipeline + + return SketchToImagePipeline(model_id) case _: raise EnvironmentError( f"{pipeline} is not a valid pipeline for model {model_id}" @@ -97,10 +102,16 @@ def load_route(pipeline: str) -> any: return segment_anything_2.router case "llm": from app.routes import llm + return llm.router case "image-to-text": from app.routes import image_to_text + return image_to_text.router + case "sketch-to-image": + from app.routes import sketch_to_image + + return sketch_to_image.router case _: raise EnvironmentError(f"{pipeline} is not a valid pipeline") diff --git a/runner/app/pipelines/sketch_to_image.py b/runner/app/pipelines/sketch_to_image.py new file mode 100644 index 00000000..b5aa40b1 --- /dev/null +++ b/runner/app/pipelines/sketch_to_image.py @@ -0,0 +1,123 @@ +import logging +import os +from enum import Enum +from typing import List, Optional, Tuple + +import PIL +import torch +from app.pipelines.base import Pipeline +from app.pipelines.utils import ( + LoraLoader, + SafetyChecker, + get_model_dir, + get_torch_device, + is_lightning_model, + is_turbo_model, +) +from diffusers import ( + AutoencoderKL, + ControlNetModel, + EulerAncestralDiscreteScheduler, + StableDiffusionXLControlNetPipeline, +) +from huggingface_hub import file_download, hf_hub_download +from PIL import ImageFile +from safetensors.torch import load_file + +ImageFile.LOAD_TRUNCATED_IMAGES = True + +logger = logging.getLogger(__name__) + + +class ModelName(Enum): + """Enumeration mapping model names to their corresponding IDs.""" + + SCRIBBLE_SDXL = "xinsir/controlnet-scribble-sdxl-1.0" + + @classmethod + def list(cls): + """Return a list of all model IDs.""" + return list(map(lambda c: c.value, cls)) + + +class SketchToImagePipeline(Pipeline): + def __init__(self, model_id: str): + self.model_id = model_id + kwargs = {"cache_dir": get_model_dir()} + + torch_device = get_torch_device() + folder_name = file_download.repo_folder_name( + repo_id=model_id, repo_type="model" + ) + folder_path = os.path.join(get_model_dir(), folder_name) + has_fp16_variant = ( + any( + ".fp16.safetensors" in fname + for _, _, files in os.walk(folder_path) + for fname in files + ) + ) + + torch_dtype = torch.float + if torch_device.type != "cpu" and has_fp16_variant: + logger.info("SketchToImagePipeline loading fp16 variant for %s", model_id) + torch_dtype = torch.float16 + kwargs["torch_dtype"] = torch.float16 + kwargs["variant"] = "fp16" + + eulera_scheduler = EulerAncestralDiscreteScheduler.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + subfolder="scheduler", + cache_dir=get_model_dir(), + ) + controlnet = ControlNetModel.from_pretrained( + self.model_id, + torch_dtype=torch_dtype, + cache_dir=get_model_dir(), + ) + vae = AutoencoderKL.from_pretrained( + "madebyollin/sdxl-vae-fp16-fix", + torch_dtype=torch_dtype + ) + self.ldm = StableDiffusionXLControlNetPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + controlnet=controlnet, + vae=vae, + safety_checker=None, + scheduler=eulera_scheduler, + ).to(torch_device) + + safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower() + self._safety_checker = SafetyChecker(device=safety_checker_device) + + def __call__( + self, prompt: str, image: PIL.Image, **kwargs + ) -> Tuple[List[PIL.Image], List[Optional[bool]]]: + seed = kwargs.pop("seed", None) + safety_check = kwargs.pop("safety_check", True) + + if seed is not None: + if isinstance(seed, int): + kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed( + seed + ) + elif isinstance(seed, list): + kwargs["generator"] = [ + torch.Generator(get_torch_device()).manual_seed(s) for s in seed + ] + if "num_inference_steps" in kwargs and ( + kwargs["num_inference_steps"] is None or kwargs["num_inference_steps"] < 1 + ): + del kwargs["num_inference_steps"] + + output = self.ldm(prompt, image=image, **kwargs) + + if safety_check: + _, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images) + else: + has_nsfw_concept = [None] * len(output.images) + + return output.images, has_nsfw_concept + + def __str__(self) -> str: + return f"SketchToImagePipeline model_id={self.model_id}" diff --git a/runner/app/routes/sketch_to_image.py b/runner/app/routes/sketch_to_image.py new file mode 100644 index 00000000..502f835e --- /dev/null +++ b/runner/app/routes/sketch_to_image.py @@ -0,0 +1,159 @@ +import logging +import os +from typing import Annotated + +import torch +from fastapi import APIRouter, Depends, status, UploadFile, File, Form +from fastapi.responses import JSONResponse +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from PIL import Image, ImageFile + +from app.dependencies import get_pipeline +from app.pipelines.base import Pipeline +from app.routes.utils import HTTPError, ImageResponse, image_to_data_url, http_error + +ImageFile.LOAD_TRUNCATED_IMAGES = True + +router = APIRouter() + +logger = logging.getLogger(__name__) + +RESPONSES = { + status.HTTP_200_OK: { + "content": { + "application/json": { + "schema": { + "x-speakeasy-name-override": "data", + } + } + }, + }, + status.HTTP_400_BAD_REQUEST: {"model": HTTPError}, + status.HTTP_401_UNAUTHORIZED: {"model": HTTPError}, + status.HTTP_413_REQUEST_ENTITY_TOO_LARGE: {"model": HTTPError}, + status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError}, +} + + +def handle_pipeline_error(e: Exception) -> JSONResponse: + if isinstance(e, torch.cuda.OutOfMemoryError): + torch.cuda.empty_cache() + logger.error(f"SketchToImagePipeline error: {e}") + logger.exception(e) + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content=http_error("SketchToImagePipeline error"), + ) + +@router.post( + "/sketch-to-image", + response_model=ImageResponse, + responses=RESPONSES, + description="Transform sketch to image.", + operation_id="genSketchToImage", + summary="Sketch To Image", + tags=["generate"], + openapi_extra={"x-speakeasy-name-override": "sketchToImage"}, +) +@router.post( + "/sketch-to-image/", + response_model=ImageResponse, + responses=RESPONSES, + include_in_schema=False, +) +async def sketch_to_image( + prompt: Annotated[ + str, + Form( + description=( + "Text prompt(s) to guide image generation. Separate multiple prompts " + "with '|' if supported by the model." + ) + ), + ], + image: Annotated[ + UploadFile, + File(description="Uploaded sketch image to generate a image from."), + ], + model_id: Annotated[ + str, + Form( + description="Hugging Face model ID used for image generation." + ), + ] = "", + height: Annotated[ + int, + Form(description="The height in pixels of the generated image."), + ] = 512, + width: Annotated[ + int, + Form(description="The width in pixels of the generated image."), + ] = 1024, + negative_prompt: Annotated[ + str, + Form( + description=( + "Text prompt(s) to guide what to exclude from image generation. " + "Ignored if guidance_scale < 1." + ), + ), + ] = "", + num_inference_steps: Annotated[ + int, + Form( + description=( + "Number of denoising steps. More steps usually lead to higher quality " + "images but slower inference. Modulated by strength." + ), + ), + ] = 8, + controlnet_conditioning_scale: Annotated[ + float, + Form(description="Encourages model to generate images follow the conditioning input more strictly"), + ] = 1.0, + pipeline: Pipeline = Depends(get_pipeline), + token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), +): + auth_token = os.environ.get("AUTH_TOKEN") + if auth_token: + if not token or token.credentials != auth_token: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + headers={"WWW-Authenticate": "Bearer"}, + content=http_error("Invalid bearer token"), + ) + + if model_id != "" and model_id != pipeline.model_id: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=http_error( + f"pipeline configured with {pipeline.model_id} but called with " + f"{model_id}" + ), + ) + + image = Image.open(image.file).convert("RGB") + + images = [] + has_nsfw_concept = [] + try: + imgs, nsfw_checks = pipeline( + prompt=prompt, + image=image, + height=height, + width=width, + negative_prompt=negative_prompt, + num_inference_steps=num_inference_steps, + controlnet_conditioning_scale=controlnet_conditioning_scale, + ) + images.extend(imgs) + has_nsfw_concept.extend(nsfw_checks) + except Exception as e: + handle_pipeline_error(e) + + output_images = [ + {"url": image_to_data_url(img), "seed": 0, "nsfw": nsfw or False} + for img, sd, nsfw in zip(images, [1], has_nsfw_concept) + ] + + return {"images": output_images} \ No newline at end of file diff --git a/runner/dl_checkpoints.sh b/runner/dl_checkpoints.sh index 4a03c134..e0ccb909 100755 --- a/runner/dl_checkpoints.sh +++ b/runner/dl_checkpoints.sh @@ -28,6 +28,7 @@ function download_beta_models() { huggingface-cli download SG161222/RealVisXL_V4.0_Lightning --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models huggingface-cli download ByteDance/SDXL-Lightning --include "*unet.safetensors" --cache-dir models huggingface-cli download timbrooks/instruct-pix2pix --include "*.fp16.safetensors" "*.json" "*.txt" --cache-dir models + huggingface-cli download xinsir/controlnet-scribble-sdxl-1.0 --include "*.safetensors" "*.json" --cache-dir models # Download upscale models huggingface-cli download stabilityai/stable-diffusion-x4-upscaler --include "*.fp16.safetensors" --cache-dir models diff --git a/runner/example_data/sketch.png b/runner/example_data/sketch.png new file mode 100644 index 00000000..089870ab Binary files /dev/null and b/runner/example_data/sketch.png differ diff --git a/runner/gateway.openapi.yaml b/runner/gateway.openapi.yaml index 2eaa0193..bde80654 100644 --- a/runner/gateway.openapi.yaml +++ b/runner/gateway.openapi.yaml @@ -3,7 +3,7 @@ openapi: 3.1.0 info: title: Livepeer AI Runner description: An application to run AI pipelines - version: v0.9.0 + version: '' servers: - url: https://dream-gateway.livepeer.cloud description: Livepeer Cloud Community Gateway @@ -411,6 +411,60 @@ paths: security: - HTTPBearer: [] x-speakeasy-name-override: imageToText + /sketch-to-image: + post: + tags: + - generate + summary: Sketch To Image + description: Transform sketch to image. + operationId: genSketchToImage + requestBody: + content: + multipart/form-data: + schema: + $ref: '#/components/schemas/Body_genSketchToImage' + required: true + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/ImageResponse' + x-speakeasy-name-override: data + '400': + description: Bad Request + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '413': + description: Request Entity Too Large + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '500': + description: Internal Server Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + security: + - HTTPBearer: [] + x-speakeasy-name-override: sketchToImage components: schemas: APIError: @@ -686,6 +740,57 @@ components: - image - model_id title: Body_genSegmentAnything2 + Body_genSketchToImage: + properties: + prompt: + type: string + title: Prompt + description: Text prompt(s) to guide image generation. Separate multiple + prompts with '|' if supported by the model. + image: + type: string + format: binary + title: Image + description: Uploaded sketch image to generate a image from. + model_id: + type: string + title: Model Id + description: Hugging Face model ID used for image generation. + default: '' + height: + type: integer + title: Height + description: The height in pixels of the generated image. + default: 512 + width: + type: integer + title: Width + description: The width in pixels of the generated image. + default: 1024 + negative_prompt: + type: string + title: Negative Prompt + description: Text prompt(s) to guide what to exclude from image generation. + Ignored if guidance_scale < 1. + default: '' + num_inference_steps: + type: integer + title: Num Inference Steps + description: Number of denoising steps. More steps usually lead to higher + quality images but slower inference. Modulated by strength. + default: 8 + controlnet_conditioning_scale: + type: number + title: Controlnet Conditioning Scale + description: Encourages model to generate images follow the conditioning + input more strictly + default: 1.0 + type: object + required: + - prompt + - image + - model_id + title: Body_genSketchToImage Body_genUpscale: properties: prompt: diff --git a/runner/gen_openapi.py b/runner/gen_openapi.py index ae7438c5..b11c1453 100644 --- a/runner/gen_openapi.py +++ b/runner/gen_openapi.py @@ -16,6 +16,7 @@ upscale, llm, image_to_text, + sketch_to_image, ) from fastapi.openapi.utils import get_openapi @@ -127,6 +128,7 @@ def write_openapi(fname: str, entrypoint: str = "runner", version: str = "0.0.0" app.include_router(segment_anything_2.router) app.include_router(llm.router) app.include_router(image_to_text.router) + app.include_router(sketch_to_image.router) logger.info(f"Generating OpenAPI schema for '{entrypoint}' entrypoint...") openapi = get_openapi( diff --git a/runner/openapi.yaml b/runner/openapi.yaml index b34f5dda..bb82b20c 100644 --- a/runner/openapi.yaml +++ b/runner/openapi.yaml @@ -3,7 +3,7 @@ openapi: 3.1.0 info: title: Livepeer AI Runner description: An application to run AI pipelines - version: v0.9.0 + version: '' servers: - url: https://dream-gateway.livepeer.cloud description: Livepeer Cloud Community Gateway @@ -422,6 +422,60 @@ paths: security: - HTTPBearer: [] x-speakeasy-name-override: imageToText + /sketch-to-image: + post: + tags: + - generate + summary: Sketch To Image + description: Transform sketch to image. + operationId: genSketchToImage + requestBody: + content: + multipart/form-data: + schema: + $ref: '#/components/schemas/Body_genSketchToImage' + required: true + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/ImageResponse' + x-speakeasy-name-override: data + '400': + description: Bad Request + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '413': + description: Request Entity Too Large + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '500': + description: Internal Server Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + security: + - HTTPBearer: [] + x-speakeasy-name-override: sketchToImage components: schemas: APIError: @@ -691,6 +745,56 @@ components: required: - image title: Body_genSegmentAnything2 + Body_genSketchToImage: + properties: + prompt: + type: string + title: Prompt + description: Text prompt(s) to guide image generation. Separate multiple + prompts with '|' if supported by the model. + image: + type: string + format: binary + title: Image + description: Uploaded sketch image to generate a image from. + model_id: + type: string + title: Model Id + description: Hugging Face model ID used for image generation. + default: '' + height: + type: integer + title: Height + description: The height in pixels of the generated image. + default: 512 + width: + type: integer + title: Width + description: The width in pixels of the generated image. + default: 1024 + negative_prompt: + type: string + title: Negative Prompt + description: Text prompt(s) to guide what to exclude from image generation. + Ignored if guidance_scale < 1. + default: '' + num_inference_steps: + type: integer + title: Num Inference Steps + description: Number of denoising steps. More steps usually lead to higher + quality images but slower inference. Modulated by strength. + default: 8 + controlnet_conditioning_scale: + type: number + title: Controlnet Conditioning Scale + description: Encourages model to generate images follow the conditioning + input more strictly + default: 1.0 + type: object + required: + - prompt + - image + title: Body_genSketchToImage Body_genUpscale: properties: prompt: diff --git a/worker/docker.go b/worker/docker.go index b94dce66..29c6744f 100644 --- a/worker/docker.go +++ b/worker/docker.go @@ -38,6 +38,7 @@ var containerHostPorts = map[string]string{ "llm": "8500", "segment-anything-2": "8600", "image-to-text": "8700", + "sketch-to-image": "8800", } // Mapping for per pipeline container images. diff --git a/worker/multipart.go b/worker/multipart.go index b0f57eb5..bb488c56 100644 --- a/worker/multipart.go +++ b/worker/multipart.go @@ -400,3 +400,63 @@ func NewImageToTextMultipartWriter(w io.Writer, req GenImageToTextMultipartReque return mw, nil } + +func NewSketchToImageMultipartWriter(w io.Writer, req GenSketchToImageMultipartRequestBody) (*multipart.Writer, error) { + mw := multipart.NewWriter(w) + writer, err := mw.CreateFormFile("image", req.Image.Filename()) + if err != nil { + return nil, err + } + imageSize := req.Image.FileSize() + imageRdr, err := req.Image.Reader() + if err != nil { + return nil, err + } + copied, err := io.Copy(writer, imageRdr) + if err != nil { + return nil, err + } + if copied != imageSize { + return nil, fmt.Errorf("failed to copy image to multipart request imageBytes=%v copiedBytes=%v", imageSize, copied) + } + + if err := mw.WriteField("prompt", req.Prompt); err != nil { + return nil, err + } + if req.ModelId != nil { + if err := mw.WriteField("model_id", *req.ModelId); err != nil { + return nil, err + } + } + if req.Width != nil { + if err := mw.WriteField("width", strconv.Itoa(*req.Width)); err != nil { + return nil, err + } + } + if req.Height != nil { + if err := mw.WriteField("height", strconv.Itoa(*req.Height)); err != nil { + return nil, err + } + } + if req.ControlnetConditioningScale != nil { + if err := mw.WriteField("controlnet_conditioning_scale", fmt.Sprintf("%f", *req.ControlnetConditioningScale)); err != nil { + return nil, err + } + } + if req.NegativePrompt != nil { + if err := mw.WriteField("negative_prompt", *req.NegativePrompt); err != nil { + return nil, err + } + } + if req.NumInferenceSteps != nil { + if err := mw.WriteField("num_inference_steps", strconv.Itoa(*req.NumInferenceSteps)); err != nil { + return nil, err + } + } + + if err := mw.Close(); err != nil { + return nil, err + } + + return mw, nil +} diff --git a/worker/runner.gen.go b/worker/runner.gen.go index 2ed17131..f551ed83 100644 --- a/worker/runner.gen.go +++ b/worker/runner.gen.go @@ -166,6 +166,33 @@ type BodyGenSegmentAnything2 struct { ReturnLogits *bool `json:"return_logits,omitempty"` } +// BodyGenSketchToImage defines model for Body_genSketchToImage. +type BodyGenSketchToImage struct { + // ControlnetConditioningScale Encourages model to generate images follow the conditioning input more strictly + ControlnetConditioningScale *float32 `json:"controlnet_conditioning_scale,omitempty"` + + // Height The height in pixels of the generated image. + Height *int `json:"height,omitempty"` + + // Image Uploaded sketch image to generate a image from. + Image openapi_types.File `json:"image"` + + // ModelId Hugging Face model ID used for image generation. + ModelId *string `json:"model_id,omitempty"` + + // NegativePrompt Text prompt(s) to guide what to exclude from image generation. Ignored if guidance_scale < 1. + NegativePrompt *string `json:"negative_prompt,omitempty"` + + // NumInferenceSteps Number of denoising steps. More steps usually lead to higher quality images but slower inference. Modulated by strength. + NumInferenceSteps *int `json:"num_inference_steps,omitempty"` + + // Prompt Text prompt(s) to guide image generation. Separate multiple prompts with '|' if supported by the model. + Prompt string `json:"prompt"` + + // Width The width in pixels of the generated image. + Width *int `json:"width,omitempty"` +} + // BodyGenUpscale defines model for Body_genUpscale. type BodyGenUpscale struct { // Image Uploaded image to modify with the pipeline. @@ -341,6 +368,9 @@ type GenLLMFormdataRequestBody = BodyGenLLM // GenSegmentAnything2MultipartRequestBody defines body for GenSegmentAnything2 for multipart/form-data ContentType. type GenSegmentAnything2MultipartRequestBody = BodyGenSegmentAnything2 +// GenSketchToImageMultipartRequestBody defines body for GenSketchToImage for multipart/form-data ContentType. +type GenSketchToImageMultipartRequestBody = BodyGenSketchToImage + // GenTextToImageJSONRequestBody defines body for GenTextToImage for application/json ContentType. type GenTextToImageJSONRequestBody = TextToImageParams @@ -505,6 +535,9 @@ type ClientInterface interface { // GenSegmentAnything2WithBody request with any body GenSegmentAnything2WithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + // GenSketchToImageWithBody request with any body + GenSketchToImageWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + // GenTextToImageWithBody request with any body GenTextToImageWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) @@ -610,6 +643,18 @@ func (c *Client) GenSegmentAnything2WithBody(ctx context.Context, contentType st return c.Client.Do(req) } +func (c *Client) GenSketchToImageWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewGenSketchToImageRequestWithBody(c.Server, contentType, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + func (c *Client) GenTextToImageWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { req, err := NewGenTextToImageRequestWithBody(c.Server, contentType, body) if err != nil { @@ -858,6 +903,35 @@ func NewGenSegmentAnything2RequestWithBody(server string, contentType string, bo return req, nil } +// NewGenSketchToImageRequestWithBody generates requests for GenSketchToImage with any type of body +func NewGenSketchToImageRequestWithBody(server string, contentType string, body io.Reader) (*http.Request, error) { + var err error + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/sketch-to-image") + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", queryURL.String(), body) + if err != nil { + return nil, err + } + + req.Header.Add("Content-Type", contentType) + + return req, nil +} + // NewGenTextToImageRequest calls the generic GenTextToImage builder with application/json body func NewGenTextToImageRequest(server string, body GenTextToImageJSONRequestBody) (*http.Request, error) { var bodyReader io.Reader @@ -993,6 +1067,9 @@ type ClientWithResponsesInterface interface { // GenSegmentAnything2WithBodyWithResponse request with any body GenSegmentAnything2WithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*GenSegmentAnything2Response, error) + // GenSketchToImageWithBodyWithResponse request with any body + GenSketchToImageWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*GenSketchToImageResponse, error) + // GenTextToImageWithBodyWithResponse request with any body GenTextToImageWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*GenTextToImageResponse, error) @@ -1183,6 +1260,33 @@ func (r GenSegmentAnything2Response) StatusCode() int { return 0 } +type GenSketchToImageResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *ImageResponse + JSON400 *HTTPError + JSON401 *HTTPError + JSON413 *HTTPError + JSON422 *HTTPValidationError + JSON500 *HTTPError +} + +// Status returns HTTPResponse.Status +func (r GenSketchToImageResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r GenSketchToImageResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + type GenTextToImageResponse struct { Body []byte HTTPResponse *http.Response @@ -1306,6 +1410,15 @@ func (c *ClientWithResponses) GenSegmentAnything2WithBodyWithResponse(ctx contex return ParseGenSegmentAnything2Response(rsp) } +// GenSketchToImageWithBodyWithResponse request with arbitrary body returning *GenSketchToImageResponse +func (c *ClientWithResponses) GenSketchToImageWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*GenSketchToImageResponse, error) { + rsp, err := c.GenSketchToImageWithBody(ctx, contentType, body, reqEditors...) + if err != nil { + return nil, err + } + return ParseGenSketchToImageResponse(rsp) +} + // GenTextToImageWithBodyWithResponse request with arbitrary body returning *GenTextToImageResponse func (c *ClientWithResponses) GenTextToImageWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*GenTextToImageResponse, error) { rsp, err := c.GenTextToImageWithBody(ctx, contentType, body, reqEditors...) @@ -1703,6 +1816,67 @@ func ParseGenSegmentAnything2Response(rsp *http.Response) (*GenSegmentAnything2R return response, nil } +// ParseGenSketchToImageResponse parses an HTTP response from a GenSketchToImageWithResponse call +func ParseGenSketchToImageResponse(rsp *http.Response) (*GenSketchToImageResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &GenSketchToImageResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest ImageResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 400: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON400 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 413: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON413 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 422: + var dest HTTPValidationError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON422 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 500: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON500 = &dest + + } + + return response, nil +} + // ParseGenTextToImageResponse parses an HTTP response from a GenTextToImageWithResponse call func ParseGenTextToImageResponse(rsp *http.Response) (*GenTextToImageResponse, error) { bodyBytes, err := io.ReadAll(rsp.Body) @@ -1834,6 +2008,9 @@ type ServerInterface interface { // Segment Anything 2 // (POST /segment-anything-2) GenSegmentAnything2(w http.ResponseWriter, r *http.Request) + // Sketch To Image + // (POST /sketch-to-image) + GenSketchToImage(w http.ResponseWriter, r *http.Request) // Text To Image // (POST /text-to-image) GenTextToImage(w http.ResponseWriter, r *http.Request) @@ -1888,6 +2065,12 @@ func (_ Unimplemented) GenSegmentAnything2(w http.ResponseWriter, r *http.Reques w.WriteHeader(http.StatusNotImplemented) } +// Sketch To Image +// (POST /sketch-to-image) +func (_ Unimplemented) GenSketchToImage(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotImplemented) +} + // Text To Image // (POST /text-to-image) func (_ Unimplemented) GenTextToImage(w http.ResponseWriter, r *http.Request) { @@ -2026,6 +2209,23 @@ func (siw *ServerInterfaceWrapper) GenSegmentAnything2(w http.ResponseWriter, r handler.ServeHTTP(w, r.WithContext(ctx)) } +// GenSketchToImage operation middleware +func (siw *ServerInterfaceWrapper) GenSketchToImage(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + ctx = context.WithValue(ctx, HTTPBearerScopes, []string{}) + + handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + siw.Handler.GenSketchToImage(w, r) + })) + + for _, middleware := range siw.HandlerMiddlewares { + handler = middleware(handler) + } + + handler.ServeHTTP(w, r.WithContext(ctx)) +} + // GenTextToImage operation middleware func (siw *ServerInterfaceWrapper) GenTextToImage(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -2194,6 +2394,9 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl r.Group(func(r chi.Router) { r.Post(options.BaseURL+"/segment-anything-2", wrapper.GenSegmentAnything2) }) + r.Group(func(r chi.Router) { + r.Post(options.BaseURL+"/sketch-to-image", wrapper.GenSketchToImage) + }) r.Group(func(r chi.Router) { r.Post(options.BaseURL+"/text-to-image", wrapper.GenTextToImage) }) @@ -2207,66 +2410,68 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/+xce28bNxL/KsTeAU0AyZLcurkz0D+cNE2Ms9PAlpsWiSFQu6MV611yy4clNefvfuBw", - "d8V96OXabi/VX3W0JOf9m+GQ7OcgFGkmOHCtguPPgQqnkFL88+T96WsphbR/R6BCyTLNBA+O7RcC9hOR", - "oDLBFZBURJAcBJ0gkyIDqRngGqmKm9OHU8inp6AUjcHO00wnEBwH5yq2/1pk9h9KS8bj4O6uE0j4zTAJ", - "UXD8EVe9Xk4pGS3nifGvEOrgrhO8FNFiFAM/MRETQzGEubYMVbmk9mOTz6ssETSCiOB3MmEJEC3IGIiW", - "lNuRY4gs7xMhU6qD42DMOJULTxok25SnE6C+RixyVCfUJHZ+0Kmx8NbEMeMx+YGGuY7J6ffEKIjIRMiS", - "Dxxe0aIbGm1UpRPdU2abwtbo9TSlMQwF/qep2NiwiPIQRiqkCVRkfXFwVBf2NQ+FkTQGlYuqBYmBg6Qa", - "CEvxQ5gIBcmCJIzfQGRH6CkQDXNNMinSTJNnUxZPQZJbmhi7El0QCZEJ8yXIb4YmTC+e++p6k/NJLpHP", - "Ul5u0jFIKy8rBFzhIm5tLSznbLIgM6anyFrGMkgYh/V+4vTX4ie47miNHgdNPX4PsQRkZjZloWOj0GPB", - "KVMkM2qKKpxRGSkcxTjTjCZuzEGdP7JZTYmQDjzW+PQJORMXJ+TZmZh1Lyi/IScRzTS1X5/nhqc8Ikwr", - "EgrpECayQTADFk81Or4TIhfK+j55PadplsAx+Uw+BQnVwHU3FFwxpYGHi14Spl3LXVdF8+RTcEwGB/0O", - "+RRwkOxX1cvYHJIulbpbfD288xVwhoI9WiA35NkqljsBh5hqdgsj5/wbmBguw+SZeo7hZVgEZDal2v4L", - "5mFiIiATKdIWFZ/GXEjrQRNSdUjyyfT7X4dk4LP9LmeNvHestXFv0pGL61EGsk2GQV2Ed+hqREwKQPAx", - "IgOZi1dhxKTk1A1+D7LBDuMaYue9yA+fgAQUTUNW9eVBv7+anwi4YMraGCcekHMhwf1NjDI0sagFFDEr", - "h6gcigpRxkYTlYgZSFJyYZeJTIKRO14QpSXwWE8b8hXjySVy3Sadr95tvGKdT662qaIT0ItROIXwpqI8", - "LQ3UtfcepMVEQombRnAauqLSLEXcn9Sxy8KCSSKbh8VkAlxZJxOSTKlMJybx2bx0q75CZkpmx0IkQDly", - "CxA1NXIJeVhKyiOREodvK1RhB7fqu7BVRQv9g3+tgGsxcencJQkmOKFZlrBlkpNQ2NhZ5lnffhlUEtll", - "QbOBzbW8nxUGdImtpQCoZPbNFUB7ZbV12ixFf7DM+YAVVmmSbWH5D6HxapKroq5m200m3bKm+4lFIJom", - "ndRA8dtOS3k/kTQFhYCsIBQ8Qveu1CG3dnlfuh9W4NYU036F5tGLVqpuJGGcYDpXWxB96xZvo7u175b5", - "h7r1MX/+qV7r2Ni9nEiFHT0am/AGdJ2LweGLOhtXBUFrYmZ/tExZldNUGK6tAdyaLpim1YICbeZSof2U", - "w6z9M7W5M585Y0liwZ5x/NQw4bkb9hKZrgjmp3bBFIyoiUcrYLl/2KhTSxFwMqFRtATjisCuXCZvKxuP", - "fNMhQUE6TrBsXjnXFbw8lEBVIXclxSMDJyYmqwF+c/lyePR/XL3s64pCEzMW1bx30D/8pg0PceROcPgB", - "125S3THDuNSxJsWcnZ03M8uUKS3kogp9H699tM5HtEEXnY+0uAFe9/lvPaSgczJ0Y9oUuxJ7d0v5W9TI", - "WgJNK2QmNFFQreNo2u5aC6UhHZWdtRY+L3EIaW2ldQINaWbtbyTUMPDFcomhN2jLWrLFHayZ13jBJcQp", - "cH3CF3rKeHzYdImxmLe0H0mCMEK+IVRKuiAxuwVOqCKUjMW8aATlaItW7dgo+PmXn38hLif7Pv9SzFd2", - "XprET4usrxzz983zVN2MGM+MbpVPzLoSlEgMpjY7mODgmlB6kbEQsRm37JRkEm6ZMMr+EbEQZzOdo0tn", - "WVsjOg7mb+cfyLO333347vDoWwSmy5Pzyn7i3FI+RTb/cr2P1CQWy9XNSBhdKnJNVji1OywDnaUGXW0h", - "QRtpiwu7DbMLKuSLpmMWG6tMp3rnVqpDxEQDt/+MTGjlGoPWIPOZekq5zTuMxwl4ZqhIVXBOfnSct8U5", - "t06VsN9hFAohI7WbeJlgXBOcyTjVoMoyqlx3ubGkPAbysd8ZXOcugrNzugTmGYTaDR+DGyBB2R/tT858", - "EUttxhRcVeuWnBZ55WRoE9Qn1gyGd/PDPMrFJJcqN0QtFmZTkECAhjn7hFnDkWc/d355vsyBle0UDqtz", - "5kE6MpbQMSQtjJ3h72VdW2Gt4GZAGI9YiPqndijEUhge5aNt1devDBnT8MYf0mTXkW1j17nxKBEx0zt4", - "i5umiOFdGwFqKhJb56J7urUI40rb2k9MLIuIcfjd5+7CBdGZo96087YVRCMnrMkfV1nZD79n2+GBu/UP", - "A4jGiRXdvyu8YSPw4uhv1MbcSpv7fuamfcfO/cMiOFvi9+1w+H7F0bL9tOXZcgSasgTPb5Pkx0lw/PFz", - "8E8Jk+A4+Edveardy4+0e+Ux8d11swVrl4Iop8y4136rS56T9SReirNC1p9owiJcrpR6lShMQ4o/rZOk", - "vt7dkhcnyZIRTJ0og89tfYE2voEmevqqcPsqv0pTbWrnfD/+p9KHxgFtncpl521JoIU+YuxF7gJNP7mo", - "OMfKOrIlLaj2Cwn1oLSztzLGOUSM+iZwZ01tJmgkQOW7UVXiVSpx7dudFIOn4+v0ovPe/Tqt2DE+TFSb", - "yCs60bhwXcSaBC2Cnp2d+wJWmZXel2XxUV/M2+7ihn9ks6s/xfUByJXaBu2kt763nCeZz3KLRHYfpXYy", - "mptb7DJXGM6v9eqmk3TWIYZ75f5yM6LIMzf1eVm/4u6leuJdreSqe9eNUdRYD1XQmmFDIVdFJerjK5sh", - "+YRFWBm44cg3FvtVkpVM5hbeeJsoZ0wVw3OtXtd4X2tfBIGWnXxqPxTGDAXXlLm2L/fO+cbC7uyr6rPz", - "mgbnajJrkvkwBV000R3BGVVkktA4hohQRd5d/vChUmvZZbavH6wl7BdXovonHiXFrTqXRibti19dnOU7", - "pqUIIeW2JKJhCEq5e1YFgSuZbLSqwTHKsYJq8+2J5mqx48Njazg1fHO04DJu6NaJB4f7ieeVI1VPPJ1H", - "xvdOIeN1dfa6eLHf87Pl91RSJ+yXenfsIQ8wGzez1hxg7i9j7S9jfbmXsY7+1nexyCVkFPWMbfAMe82u", - "LYrNrK/++5V1DWWyTMic4bJZuu98/Gknrg383vLEtXnG1kyhLXl2Y+chEWGl7UD5Im+l1P3hc4PF6zsf", - "kkMk01J95CeVy9oL7/y3btXwh+VQ5JkM7a+bKhErhyOVj/Q0tUW3Aw+sdyr82q7Y1C5K4R2oTXVXcWPI", - "jq2Ufjs2H+olX3GpyjGxoRmRs+rrrKKQFo256rNlx4Mf0PEtliEYUaJZCkrTNNul+YAL5BGEq26uT+33", - "nNKKNYvPjYULfXvKG5ZrbdCf9gfWeh5OUQ0NImSFRjK9uLTGdMp4Oxy+fwlUgiwf4yDOuZ/KRaZaZ8Gd", - "XcPuI1uskF8edTFpUVgaTk5Oy/MN5VdT7BYyAGm/XxjOkdAtSOXWuu0f/Pugb1UrMuA0Y8Fx8PXB4KBv", - "LUn1FPnu4TuSrhbdwpyZUG1mLR/OeI9q3Elevv8QWR5Qp5EtrusPUazWQemXIsJ7InZHDRwJuTxIpe7Z", - "RNSNqKbLB02b4qjt1ctd1co26+EPLiZQ7MN+v8aFp/ber8rKvC0LlS0T0q6lMoM74YlJyHJYJ/jmAVlY", - "NrJb6L+kEblw2nd0B09D94pTo6dCst8hQsKDr5+GcC4sec21LQyHQpAzKmOn9cHRU0m/LOEQ7F06tCwc", - "Hj4oC41DhSYzyyGkPHg4eir/O+UaJKcJuQR5C7LgwMNRLFt8BP14fXfdCZRJUyoXxWs8MhSkSB40Vha+", - "i2xsYXveVRnQG6Bq0eU0ha64BSlZhOBfQYdO0JviGQZ2KwBlr6KXO+IIHhE0/EOUbTHjzldJziJKg7Wp", - "xfDy+LwdxE+yLFkUZ+iVy+qI5NTuRGxZ41W7DVSvvS54ZFivUHtiXK8e6+yBfTWw7wFtV0BzlxGHgpQ3", - "UnZENFYNDB8EtijkcMfvcGBzHVd9fPI0Af9n1HFtZ5z7qP+Ll3N76Lk39NyzlmKVCPWB57Z8d9aKPG/a", - "XlvtVHQUrxOeBoMctScGoWr7Zg8/+6LjESK/fOVzv9AvAqMT9JIk3SLgsR9o8FiHkoTy2FhGyhONRri7", - "1yero9xX8bw7m826GO1GJsBDEbnzhN1i3pJ84lD3bx7tA30f6A8X6PnrrR2j28YyBnV+da1L86v83cPV", - "MZ7f+s8vSuHDDcrXZPKWVwKPnM0bFJ84zKtX0PaBvg/0hwv0IvoK5yaH94h71QyQTtCzOXuLluKb2lUt", - "rOm9m1mqFQW8I/CtE/3uhzLVQ/Z993Af9l9I2ONloz/QPNRe+GGwG+89XmuY52+CytxOxovif3uBl6K1", - "Istnz60hv3xV9Mj5viC0j/d9vH8h8e69yNsx0o0fDAoZUEiu9iS6uEnyKhEmIq9EmhrO9IK8oRpmdBHk", - "TwDw/oo67vUiCTTtxu7rQZJPPwjtdLxytmL9S40nuauWLRdSOK5HM9Ybg6a9Ut6767v/BQAA//8hSG/r", - "+lYAAA==", + "H4sIAAAAAAAC/+xca28bN7P+K8SeAzQBJEty66Yw0A+OmybGsdPAVpoWiSFQu6MV611yy4slNcf//QWH", + "uyvuRTfHdvum+hR5eZkb55nhkMznIBRpJjhwrYLjz4EKp5BS/Hny7uyVlELa3xGoULJMM8GDY9tCwDYR", + "CSoTXAFJRQTJQdAJMikykJoBzpGquDl8OIV8eApK0RjsOM10AsFxcKFi+9cis38oLRmPg7u7TiDhT8Mk", + "RMHxR5z1ejmkZLQcJ8Z/QKiDu07wUkSLUQz8xERMDMUQ5toyVOWS2sYmn++zRNAIIoLtZMISIFqQMRAt", + "Kbc9xxBZ3idCplQHx8GYcSoXnjRItilPJ0B9jVjkqE6oSez4oFNj4Y2JY8Zj8jMNcx2Ts5+IURCRiZAl", + "H9i9okXXNdqoSie6p8w2ha3R61lKYxgK/Kep2NiwiPIQRiqkCVRkfXFwVBf2FQ+FkTQGlYuqBYmBg6Qa", + "CEuxIUyEgmRBEsZvILI99BSIhrkmmRRppsmzKYunIMktTYydiS6IhMiE+RTkT0MTphfPfXW9zvkkV8hn", + "KS836RiklZcVAq5YIm5uLSznbLIgM6anyFrGMkgYh/XrxOmvZZ3gvKM1ehw09fgTxBKQmdmUhY6NQo8F", + "p0yRzKgpqnBGZaSwF+NMM5q4Pgd1/shmNSVCOvBYs6ZPyLm4PCHPzsWse0n5DTmJaKapbX2eG57yiDCt", + "SCikQ5jIOsEMWDzVuPCdELlQdu2TV3OaZgkck8/kU5BQDVx3Q8EVUxp4uOglYdq13HVVNE8+BcdkcNDv", + "kE8BB8n+UL2MzSHpUqm7Revhna+AcxTs0Ry5Ic9WvtwJOMRUs1sYucW/gYnh0k2eqefoXoZFQGZTqu1f", + "MA8TEwGZSJG2qPgs5kLaFTQh1QVJPpl+/9uQDHy23+askXeOtTbuTTpyfj3KQLbJMKiL8BaXGhGTAhB8", + "jMhA5uJVGDEpOXOd34FssMO4htitXuSHT0ACiqYhq67lQb+/mp8IuGDK2hgHHpALIcH9JkYZmljUAoqY", + "lUNUDkWFKGOjiUrEDCQpubDTRCZBzx0viNISeKynDfmK/uQKuW6TzlfvNqti3ZpcbVNFJ6AXo3AK4U1F", + "eVoaqGvvHUiLiYQSN4zgMFyKSrMUcX9Sxy4LCyaJbBwWkwlwZReZkGRKZToxic/mlZv1FJkpmR0LkQDl", + "yC1A1NTIFeRuKSmPREocvq1Qhe3cqu/CVhUt9A9+WAHXYuLCuQsSTHBCsyxhyyAnobCxs8yzvm0ZVALZ", + "VUGzgc21uJ8VBnSBrSUBqET2zRlAe2a1ddgsRX+wyPmAGVZpkm1h+YvQeDXJVV5Xs+0mk26Z0/3KIhBN", + "k05qoPh9pyW9n0iagkJAVhAKHuHyruQht3Z6X7qfV+DWFMN+hebRi1aqridhnGA4V1sQfeMmb6O79dot", + "4w9182P8/FtXrWNj93QiFbb3aGzCG9B1LgaHL+psvC8IWhMz+9EyZVVOU2G4tgZwczpnmlYTCrSZC4W2", + "KYdZ+zO1sTMfOWNJYsGecWxqmPDCdXuJTFcE80O7YApG1MSjFbDcP2zkqaUIOJjQKFqCcUVgly6TN5WN", + "R77pkKAgHSeYNq8c6xJeHkqgqpC7EuKRgRMTk9UAvzl9OTz6L85e9nlFoYkZi2qrd9A//K4ND7HnTnD4", + "AeduUt0xwrjQsSbEnJ9fNCPLlCkt5KIKfR+vfbTOe7RBF52PtLgBXl/z33tIQedk6Pq0KXYl9u4W8rfI", + "kbUEmlbITGiioJrH0bR9aS2UhnRUVtZa+LzCLqS1lNYJNKSZtb+RUMPAF8sphl6nLXPJluVgzbxmFVxB", + "nALXJ3yhp4zHh80lMRbzlvIjSRBGyHeESkkXJGa3wAlVhJKxmBeFoBxt0aod6wW//f7b78TFZH/NvxTz", + "lZWXJvGzIuorx/x94zxVNyPGM6Nb5ROzrgQlEoOhzXYm2LkmlF5kLERsxi07JZmEWyaMsj8iFuJopnN0", + "6Sxza0THwfzN/AN59ubHDz8eHn2PwHR1clHZT1xYymfI5j+u9pGaxGK5uhkJo0tFrokKZ3aHZaCz1KDL", + "LSRoI21yYbdhdkKFfNF0zGJjlelU75aV6hAx0cDtn5EJrVxj0BpkPlJPKbdxh/E4Ac8MFakKzskvjvM2", + "P+d2USXsLxiFQshI7SZeJhjXBEcyTjWoMo0q511uLCmPgXzsdwbX+RLB0TldAvMMQu26j8F1kKDsR/vJ", + "mS9iqY2Ygqtq3pLTIqdOhjZBfWJNZ3g7P8y9XExyqXJD1HxhNgUJBGiYs0+YNRx59lvn9+fLGFjZTmG3", + "OmcepCNjCR1D0sLYOX4v89oKawU3A8J4xELUP7VdIZbC8CjvbbO+fqXLmIY3fpcmu45sG7tuGY8SETO9", + "w2pxwxQxvGs9QE1FYvNcXJ5uLsK40jb3ExPLImIctvvcXTonOnfUm3beNoNoxIR18eMGdDhdefoQCq6l", + "SDjY1bVMuduK6Pc5ipiIJBGzlowe0SJ1+bNkoU78iHBaMkVO/VErS+pt29/B4f22v426/hdsfxUqv3UX", + "7L797bvgf3FRfc0u8Id/cwmbXEFGcZVi9pBhiHbRBOPaN///jbWEMlkmZM5wGWO2K4A/4Pas4a5bbs82", + "F3eryLkGYt9nJVres7L7wAeiDwMNxokV3R8jNnjZi6N/kZttpc39kdGm0s7OXlw4Z4v/vhkO3624vWOb", + "try+E4GmLMErMknyyyQ4/vg5+F8Jk+A4+J/e8uJQL7811Ctv4txdN0+57FQQ5ZQZ90446pLnZD2Jl+Ks", + "kPVXmrAIpyulXiUK05Dip3WS1Oe7W/LiJFkygrsTlMHntj5BG99AEz09LZZ9lV+lqTa1qxS//F/lqA87", + "tB0GLbO7JYEW+oixl/kSaK6Ty8riWJlRtYQF1X7nq+6UdvRWxriAiFHfBO44v80EjT2G8pdRVeJVKnEn", + "ZDspBi8grdOLzo9H12nF9vFhonpOt+KwDyeui1iToEXQ8/MLX8Aqs9JrWe7v6pN5FUWsqY5sdPWHuFIr", + "ea+2QTvpze9N50nms9wi0QVVN2ono7mxRSFvheH87XTddJLOOsRwr6KyrPco8swNfV6WCLBAVL1UVN0s", + "V8uDG72oMR+qoDXChkKu8krUxzc2QvIJizAzcN2Rb6ynVElWIpmbeOOFzZwxVXTPtXpd432tfREEWoql", + "qW0ojGn3/JTl+3DvKsVYGF3LrnFc0+BcTWZNMh+moItzSkdwRhWZJDSOISJUkbdXP3+o5Fp2mu3zB2sJ", + "2+JSVH87XVLc6nDIyKR98veX53lRailCSLlNiWgYglLuKmtB4L1MNlrVYB/lWEG1+fZEc7XY8eGxNZwa", + "vtlbcBrXdevAg939wHPqSNUDT+eR8b1TyHhdHb3OX2x7vsF7RyV1wn6t13Mf8o7ILkWy/X3XfWnu673v", + "etTf1woftVa4r3w8wqWWL6ya1kJsNYS2xNmNlYdEhJWyA+WLvJRSXw+fGyxe3/mQHCKZluwjvwyyzL3w", + "WVXrVg0/LLsiz2Rov27KRKwcjlTe09PUFtUOvBO0U+LXdouxdhcVr5luyruKS5m2byX127H4UE/5inur", + "jokNxYicVV9nFYW0aMxlny07HmzAhW+xDMGIEs1SUJqm2S7FB5wg9yCcdXN+attzSivmLJobExf69pQ3", + "LOfaoD/td6zVPJyiGhpEyAqNZHpxZY3plPFmOHz3EqgEWb53RJxzn8pJplpnwZ2dw+4jW6yQ3893PmlR", + "WBpOTs7K8w3lZ1PsFjIAadsvDedI6BakcnNZpYoMOM1YcBx8ezA46FsbUj1Fjnv4SK+rRbcwZCZUm0HL", + "V4nei0V3TSLfeYgsd6WzyKbV9Vd+Vt+g9EsRLYrzc+BIyEVAKnXPhqBuRDVdvhbd5EFtTwrvqva18Q4/", + "OG9AsQ/7/RoXnsJ7fygr87YsVDZLSLsWxAzugScmIctuneC7B2RhWcJuof+SRuTSad/RHTwN3fecGj0V", + "kv0FERIefPs0hHNhySuubUo4FIKcUxk7rQ+Onkr6ZfKGMO8CoWXh8PBBWWgcJzSZWXYh5ZHD0VOtvzOu", + "QXKakCuQtyALDjwExYTFx86P13fXnUCZNKVyUTx1JkNBirBBY2WBu4jDFrDnXZUBvQGqFl1OU+iKW5CS", + "RQj7FXToBL0pnl5gnQJQ9ip6ucON4BFBwz8+2RYz7nyV5CyiNJiVWgwvD87bQfwky5JFcXpeeQmESE7t", + "HsQmNF6e20D12tOtR4b1CrUnxvXqgc4e2FcD+x7QdgU0d9N7KEh5F2VHRGNVx/BBYItEDvf6+U26jXlc", + "9WXf0zj835HHtZ1u7r3+H57O7aHn3tBzz1yKVTzUB57b8lFvK/K8bnvKulPSUTz9ehoMctSeGISqhZs9", + "/OyTjkfw/PIJ5f1cv3CMTtBLknQLh8dKoMEDHUoSymNjGSnPMhru7p72rfZyX8Xz7mw266K3G5kAD0Xk", + "ThJ283lL8old3b9ztHf0vaM/nKPnT2N39G7ry+jU+aW1Ls3fSXUPV/t4/qQqvyKFr+IoXxPJW55gPXI0", + "b1B8YjevXj7bO/re0R/O0QvvKxY3ObyH36umg1gQwFc8WxQVlwWF/NVe8W633f9rj4Me2/kr5PYFxH0p", + "YV9KeED4cQ7/BWVMVfPPTtCzm4UtYOd1/cW0FKl/GVS1wo9362brHcbup8HVez171NnnG1+Jw+P9xi9w", + "d+25Hzq78Z4At7p5/gyx3FSQ8aL4z8zwHYZWZPmf2bS6/PIh4yPnGgWhvb/v/f0r8XfvEfCOnm58Z1DI", + "gEJytf/opri8dpoIE5FTkaaG29TsNdUwo4sgf3WEV+bUca8XSaBpN3atB0k+/CC0w/GW64r5rzReIVk1", + "bTmRwn49mrHeGDTtlfLeXd/9JwAA//9lcv3I0GAAAA==", } // GetSwagger returns the content of the embedded swagger specification file diff --git a/worker/worker.go b/worker/worker.go index f54f6b49..8247392b 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -516,6 +516,54 @@ func (w *Worker) ImageToText(ctx context.Context, req GenImageToTextMultipartReq return resp.JSON200, nil } +func (w *Worker) SketchToImage(ctx context.Context, req GenSketchToImageMultipartRequestBody) (*ImageResponse, error) { + c, err := w.borrowContainer(ctx, "sketch-to-image", *req.ModelId) + if err != nil { + return nil, err + } + defer w.returnContainer(c) + + var buf bytes.Buffer + mw, err := NewSketchToImageMultipartWriter(&buf, req) + if err != nil { + return nil, err + } + + resp, err := c.Client.GenSketchToImageWithBodyWithResponse(ctx, mw.FormDataContentType(), &buf) + if err != nil { + return nil, err + } + + if resp.JSON422 != nil { + val, err := json.Marshal(resp.JSON422) + if err != nil { + return nil, err + } + slog.Error("sketch-to-image container returned 422", slog.String("err", string(val))) + return nil, errors.New("sketch-to-image container returned 422") + } + + if resp.JSON400 != nil { + val, err := json.Marshal(resp.JSON400) + if err != nil { + return nil, err + } + slog.Error("sketch-to-image container returned 400", slog.String("err", string(val))) + return nil, errors.New("sketch-to-image container returned 400: " + resp.JSON400.Detail.Msg) + } + + if resp.JSON500 != nil { + val, err := json.Marshal(resp.JSON500) + if err != nil { + return nil, err + } + slog.Error("sketch-to-image container returned 500", slog.String("err", string(val))) + return nil, errors.New("sketch-to-image container returned 500") + } + + return resp.JSON200, nil +} + func (w *Worker) Warm(ctx context.Context, pipeline string, modelID string, endpoint RunnerEndpoint, optimizationFlags OptimizationFlags) error { if endpoint.URL == "" { return w.manager.Warm(ctx, pipeline, modelID, optimizationFlags)