From 408784061f26580a8c2de25cad01cf43cb9c7705 Mon Sep 17 00:00:00 2001 From: Yondon Fu Date: Thu, 8 Feb 2024 15:30:40 +0000 Subject: [PATCH] runner: Add seed in response --- runner/app/pipelines/image_to_image.py | 9 +++++- runner/app/pipelines/image_to_video.py | 9 +++++- runner/app/pipelines/text_to_image.py | 9 +++++- runner/app/routes/image_to_image.py | 27 +++++++++++++++-- runner/app/routes/image_to_video.py | 8 +++++- runner/app/routes/text_to_image.py | 19 ++++++++++-- runner/app/routes/util.py | 1 + runner/openapi.json | 7 ++++- worker/runner.gen.go | 40 ++++++++++++++------------ 9 files changed, 100 insertions(+), 29 deletions(-) diff --git a/runner/app/pipelines/image_to_image.py b/runner/app/pipelines/image_to_image.py index 2e009a83..a4a021aa 100644 --- a/runner/app/pipelines/image_to_image.py +++ b/runner/app/pipelines/image_to_image.py @@ -47,7 +47,14 @@ def __init__(self, model_id: str): def __call__(self, prompt: str, image: PIL.Image, **kwargs) -> List[PIL.Image]: seed = kwargs.pop("seed", None) if seed is not None: - kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(seed) + if isinstance(seed, int): + kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed( + seed + ) + elif isinstance(seed, list): + kwargs["generator"] = [ + torch.Generator(get_torch_device()).manual_seed(s) for s in seed + ] if ( self.model_id == "stabilityai/sdxl-turbo" diff --git a/runner/app/pipelines/image_to_video.py b/runner/app/pipelines/image_to_video.py index b6308823..cb2c07f1 100644 --- a/runner/app/pipelines/image_to_video.py +++ b/runner/app/pipelines/image_to_video.py @@ -50,7 +50,14 @@ def __call__(self, image: PIL.Image, **kwargs) -> List[List[PIL.Image]]: seed = kwargs.pop("seed", None) if seed is not None: - kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(seed) + if isinstance(seed, int): + kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed( + seed + ) + elif isinstance(seed, list): + kwargs["generator"] = [ + torch.Generator(get_torch_device()).manual_seed(s) for s in seed + ] return self.ldm(image, **kwargs).frames diff --git a/runner/app/pipelines/text_to_image.py b/runner/app/pipelines/text_to_image.py index eb6b5292..bb8b9998 100644 --- a/runner/app/pipelines/text_to_image.py +++ b/runner/app/pipelines/text_to_image.py @@ -43,7 +43,14 @@ def __init__(self, model_id: str): def __call__(self, prompt: str, **kwargs) -> List[PIL.Image]: seed = kwargs.pop("seed", None) if seed is not None: - kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(seed) + if isinstance(seed, int): + kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed( + seed + ) + elif isinstance(seed, list): + kwargs["generator"] = [ + torch.Generator(get_torch_device()).manual_seed(s) for s in seed + ] if ( self.model_id == "stabilityai/sdxl-turbo" diff --git a/runner/app/routes/image_to_image.py b/runner/app/routes/image_to_image.py index 6a756ea0..fb8a33ba 100644 --- a/runner/app/routes/image_to_image.py +++ b/runner/app/routes/image_to_image.py @@ -6,6 +6,8 @@ import PIL from typing import Annotated import logging +import random +from typing import List router = APIRouter() @@ -42,10 +44,25 @@ async def image_to_image( ), ) + if seed is None: + init_seed = random.randint(0, 2**32 - 1) + if num_images_per_prompt > 1: + seed = [i for i in range(init_seed, init_seed + num_images_per_prompt)] + else: + seed = init_seed + + img = PIL.Image.open(image.file).convert("RGB") + # If a list of seeds/generators is passed, diffusers wants a list of images + # https://github.com/huggingface/diffusers/blob/17808a091e2d5615c2ed8a63d7ae6f2baea11e1e/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L715 + if isinstance(seed, list): + image = [img] * num_images_per_prompt + else: + image = img + try: images = pipeline( prompt, - PIL.Image.open(image.file).convert("RGB"), + image, strength=strength, guidance_scale=guidance_scale, negative_prompt=negative_prompt, @@ -59,8 +76,12 @@ async def image_to_image( status_code=500, content=http_error("ImageToImagePipeline error") ) + seeds = seed + if not isinstance(seeds, list): + seeds = [seeds] + output_images = [] - for img in images: - output_images.append({"url": image_to_data_url(img)}) + for img, s in zip(images, seeds): + output_images.append({"url": image_to_data_url(img), "seed": s}) return {"images": output_images} diff --git a/runner/app/routes/image_to_video.py b/runner/app/routes/image_to_video.py index 0977a3c7..2bf5576c 100644 --- a/runner/app/routes/image_to_video.py +++ b/runner/app/routes/image_to_video.py @@ -6,6 +6,7 @@ import PIL from typing import Annotated import logging +import random router = APIRouter() @@ -44,6 +45,9 @@ async def image_to_video( }, ) + if seed is None: + seed = random.randint(0, 2**32 - 1) + try: batch_frames = pipeline( PIL.Image.open(image.file).convert("RGB"), @@ -63,6 +67,8 @@ async def image_to_video( output_frames = [] for frames in batch_frames: - output_frames.append([{"url": image_to_data_url(frame)} for frame in frames]) + output_frames.append( + [{"url": image_to_data_url(frame), "seed": seed} for frame in frames] + ) return {"frames": output_frames} diff --git a/runner/app/routes/text_to_image.py b/runner/app/routes/text_to_image.py index 78f0b089..3a4028ef 100644 --- a/runner/app/routes/text_to_image.py +++ b/runner/app/routes/text_to_image.py @@ -5,6 +5,8 @@ from app.dependencies import get_pipeline from app.routes.util import image_to_data_url, ImageResponse, HTTPError, http_error import logging +from typing import List +import random router = APIRouter() @@ -40,6 +42,15 @@ async def text_to_image( ), ) + if params.seed is None: + init_seed = random.randint(0, 2**32 - 1) + if params.num_images_per_prompt > 1: + params.seed = [ + i for i in range(init_seed, init_seed + params.num_images_per_prompt) + ] + else: + params.seed = init_seed + try: images = pipeline(**params.model_dump()) except Exception as e: @@ -49,8 +60,12 @@ async def text_to_image( status_code=500, content=http_error("TextToImagePipeline error") ) + seeds = params.seed + if not isinstance(seeds, list): + seeds = [seeds] + output_images = [] - for img in images: - output_images.append({"url": image_to_data_url(img)}) + for img, sd in zip(images, seeds): + output_images.append({"url": image_to_data_url(img), "seed": sd}) return {"images": output_images} diff --git a/runner/app/routes/util.py b/runner/app/routes/util.py index 9cb24f9e..d14ccb25 100644 --- a/runner/app/routes/util.py +++ b/runner/app/routes/util.py @@ -7,6 +7,7 @@ class Media(BaseModel): url: str + seed: int class ImageResponse(BaseModel): diff --git a/runner/openapi.json b/runner/openapi.json index 453733c7..a494c110 100644 --- a/runner/openapi.json +++ b/runner/openapi.json @@ -367,11 +367,16 @@ "url": { "type": "string", "title": "Url" + }, + "seed": { + "type": "integer", + "title": "Seed" } }, "type": "object", "required": [ - "url" + "url", + "seed" ], "title": "Media" }, diff --git a/worker/runner.gen.go b/worker/runner.gen.go index 1504859b..87d27693 100644 --- a/worker/runner.gen.go +++ b/worker/runner.gen.go @@ -73,7 +73,8 @@ type ImageResponse struct { // Media defines model for Media. type Media struct { - Url string `json:"url"` + Seed int `json:"seed"` + Url string `json:"url"` } // TextToImageParams defines model for TextToImageParams. @@ -1066,24 +1067,25 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/+xX227jNhD9FYLtoxM77qYp/Jb0tkabbhC724fFwmCkscxdiWTJkbuG4X8vSMoSdavs", - "YjdFgTzFkuZyZubwcLKnkcyUFCDQ0NmemmgDGXM/bx/mP2ottf2ttFSgkYP7kpnE/kGOKdAZvTcJHVHc", - "KftgUHOR0MNhRDX8mXMNMZ29cy7vR6VLGbv0k08fIEJ6GNE7Ge9WPGMJrFAWPxqPShpsw0pyHjMRwcpE", - "zGbZ0xjWLE+Rzm4ur6vkPxd2ZOHsSggiz55AWwguiw2wljpjSGf0iQumd7QKMncmrbJHNJMxpCse1/LT", - "wPPeGpB53OUsIGHIt7BSWmYKe2P8VtiRB2/XFSrPfLfMSoHuCngVxMsz4ioy5AF0KyoXCIlvTRXn6NsP", - "wQDEoeXCPncFNahBJLipwZtcflcBXBwtWtNqEE0d0fgZBpw7lVeDlNzyGGTzsZuSa2XqPKzg/KRMZy82", - "wJNNfVDXN99Wfq/99y7X/4y2mUQuxeopjz4CNoNcTW/CKNaS3DnLWrSgDiG5gRXLk1UPMSbTgLrWmNzm", - "CennyBlU/IvHjXRXk+mrKt0f7nvbs0HDAfb1U6iDfa+Xy4ceJY4BGU/tr681rOmMfjWu9HxciPm4VNsm", - "ysI9gFnl6gHylqU8ZnaIg5A4QmaGsDXjHSosP/hIJRCmNdu5GkK0zQBduIGluPl+A9HHNl6DDPP6KaVv", - "fqGh9DiDrhuuOpNVgo787tA9glFSGGgj8Cp9csfuIeYs7JMX7q4+tRhpwlnXYXXg9plaeHOdhkfpd50O", - "Xv/WJcjsA3dkXMInXEoH7IFp5pvxpW75SmlP0NaXa/38a73U0jPFswATEKbNiw7yDEpTKqPaKWNi92ZN", - "Z+/2rRr3LYjvgwP3q4xcmtaRG7VWYzCm58L1LypTh5ks7duh02Tr8KkKy6BTJ8jhW3vb9MvRWrOsIUdn", - "6lKjJ+XG4wMP6FSRPiyphrdVkA3AxVr6Q2AizZUbzozeCsKUSrmfFkFJdC7I7ZworiDlwoM5DpVvQQFo", - "+/0xFwJs77agjY81uby6nNhqpALBFKcz+o17NaKK4cZ1Z7xx94BTKXCHyfbVJZ/H5TVBbb2+GOc1nUzs", - "n0gKBOG8AtDjD8amP/5jNjSD8CJyjak3ZJFHERizzlNS9tNamTzL7J5YQrQvx05mLlBelHvlccmtl+WO", - "ZXE6qR8mGLQLT6OuLE+RK6ZxbBfUi5ghO720U9f3Q51QqHM4fMGO1y/RU3s+oq8+59TLpa0j/x2LyaMf", - "ics7nX7WvK39rY2gMiHljnf9XOXPBYIWLCUL0FvQpFqEK9K7GZKl9Hdlg/xuNx8kv9Oo5yJ//38Pz0z+", - "ujK/kP9/TX5PYUd+hE94gvAHW9k/Uv/fV9fe+17k/YXhZzLckihU98Ph7wAAAP//+Oaai/YWAAA=", + "H4sIAAAAAAAC/+xX32/bNhD+Vwhuj07seM0y+C3Zrxpb1iD2uoeiMBjpLLOVSI4/vBqG//eBpC1RojTJ", + "Q5thQJ5iSce77+6++3jZ44QXgjNgWuHZHqtkAwVxP28f5j9KyaX9LSQXIDUF96VQmf2jqc4Bz/C9yvAI", + "652wD0pLyjJ8OIywhD8NlZDi2Tt35P2oPFL6Ls/xpw+QaHwY4Tue7la0IBmsND/+aDwKrnQMKzM0JSyB", + "lUqIjbLHKayJyTWe3VxeV8F/PtqhhbMrITBTPIG0EFwU62DNZUE0nuEnyojc4crJ3JlEaY9wwVPIVzSt", + "xcfByXtrgOZp22EGGdF0CysheSF0p4/fjnbowdu1uTKFr5ZaCZBtDq8Cf6ZALiOFHkBGXinTkPnSVH5O", + "Z7shKIA0tFzY5zanSktgmd7U4E0uv6sALk4WUbcaRBMnNL6HAeeG8qqXkluaAm8+tlNyLVSdhxWcn4Rq", + "rcUGaLapN+r65tvq3Gv/ve3of0bbgmvK2erJJB9BN51cTW9CL9YS3TnLmrcgD8apghUx2aqDGJNpQF1r", + "jG5Nhro5cgYV/6JpI9zVZPqqCveH+x6fbNCwh33dFGph3+vl8qFDiVPQhOb219cS1niGvxpXej4+ivm4", + "VNsmyuPxAGYVqwPIW5LTlNgm9kKiGgrVh63p71Bh+cF7KoEQKcnO5RCibTpoww0k15vvN5B8jPEqTbSp", + "Tyl+8wsOpccZtN1w1UxWAVriu6F7BCU4UxAj8Co9uGL3kFIS1skLd1udIkaqsNd1WC24faS4YkNnycg8", + "tPtd5r17gnE2LkKA1ANpQbiET3rJXSIPRBJfvC+1FVTKPECLX9aA89eAUnvPFNsjmIAwMS9ayNMrZTlP", + "alNJ2O7NGs/e7aMc9xHE98GA/soTFyYa0VG0SoNSHRe0f1GZOsxoad/2DZXNw4c6WgaVGiCfb+3t1C1f", + "a0mKhnydqWONmpQbknfco2vH8GFKNbxRQtYBZWvuh0AlkgrXnBm+ZYgIkVPfLaQ5koah2zkSVEBOmQdz", + "airdggCQ9vujYQxs7bYglfc1uby6nNhsuABGBMUz/I17NcKC6I2rznjj7g2nUuCGydbVBZ+n5bWCbb4+", + "GXdqOpnYPwlnGpg7FYAef1A2/Okfub4ehBeXK0y9IAuTJKDU2uSorKe1UqYo7F5ZQrQvx05mLjS/KPfQ", + "01JcT8uN5XE6sW8mKG0XpEZehck1FUTqsV1oL1KiyfDUhq77hzqhtDRw+IIVr1+6Q2s+wq8+Z9fLJa8l", + "/h1J0aNviYs7nX7WuNG+FyOoTFC5E14/V/pzpkEykqMFyC1IVC3OFeldD9GS+7uyQX63y/eS32nUc5G/", + "+7+NZyZ/XZlfyP+/Jr+nsCO/hk96gPAHW9k/Uv/fZxfvfS/y/sLwMxluSRSq++HwdwAAAP//Jp1yTCYX", + "AAA=", } // GetSwagger returns the content of the embedded swagger specification file