Skip to content

Commit

Permalink
runner: Add seed in response
Browse files Browse the repository at this point in the history
  • Loading branch information
yondonfu committed Feb 8, 2024
1 parent 87e4f48 commit 4087840
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 29 deletions.
9 changes: 8 additions & 1 deletion runner/app/pipelines/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,14 @@ def __init__(self, model_id: str):
def __call__(self, prompt: str, image: PIL.Image, **kwargs) -> List[PIL.Image]:
seed = kwargs.pop("seed", None)
if seed is not None:
kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(seed)
if isinstance(seed, int):
kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(
seed
)
elif isinstance(seed, list):
kwargs["generator"] = [
torch.Generator(get_torch_device()).manual_seed(s) for s in seed
]

if (
self.model_id == "stabilityai/sdxl-turbo"
Expand Down
9 changes: 8 additions & 1 deletion runner/app/pipelines/image_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,14 @@ def __call__(self, image: PIL.Image, **kwargs) -> List[List[PIL.Image]]:

seed = kwargs.pop("seed", None)
if seed is not None:
kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(seed)
if isinstance(seed, int):
kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(
seed
)
elif isinstance(seed, list):
kwargs["generator"] = [
torch.Generator(get_torch_device()).manual_seed(s) for s in seed
]

return self.ldm(image, **kwargs).frames

Expand Down
9 changes: 8 additions & 1 deletion runner/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,14 @@ def __init__(self, model_id: str):
def __call__(self, prompt: str, **kwargs) -> List[PIL.Image]:
seed = kwargs.pop("seed", None)
if seed is not None:
kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(seed)
if isinstance(seed, int):
kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(
seed
)
elif isinstance(seed, list):
kwargs["generator"] = [
torch.Generator(get_torch_device()).manual_seed(s) for s in seed
]

if (
self.model_id == "stabilityai/sdxl-turbo"
Expand Down
27 changes: 24 additions & 3 deletions runner/app/routes/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import PIL
from typing import Annotated
import logging
import random
from typing import List

router = APIRouter()

Expand Down Expand Up @@ -42,10 +44,25 @@ async def image_to_image(
),
)

if seed is None:
init_seed = random.randint(0, 2**32 - 1)
if num_images_per_prompt > 1:
seed = [i for i in range(init_seed, init_seed + num_images_per_prompt)]
else:
seed = init_seed

img = PIL.Image.open(image.file).convert("RGB")
# If a list of seeds/generators is passed, diffusers wants a list of images
# https://github.com/huggingface/diffusers/blob/17808a091e2d5615c2ed8a63d7ae6f2baea11e1e/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L715
if isinstance(seed, list):
image = [img] * num_images_per_prompt
else:
image = img

try:
images = pipeline(
prompt,
PIL.Image.open(image.file).convert("RGB"),
image,
strength=strength,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
Expand All @@ -59,8 +76,12 @@ async def image_to_image(
status_code=500, content=http_error("ImageToImagePipeline error")
)

seeds = seed
if not isinstance(seeds, list):
seeds = [seeds]

output_images = []
for img in images:
output_images.append({"url": image_to_data_url(img)})
for img, s in zip(images, seeds):
output_images.append({"url": image_to_data_url(img), "seed": s})

return {"images": output_images}
8 changes: 7 additions & 1 deletion runner/app/routes/image_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import PIL
from typing import Annotated
import logging
import random

router = APIRouter()

Expand Down Expand Up @@ -44,6 +45,9 @@ async def image_to_video(
},
)

if seed is None:
seed = random.randint(0, 2**32 - 1)

try:
batch_frames = pipeline(
PIL.Image.open(image.file).convert("RGB"),
Expand All @@ -63,6 +67,8 @@ async def image_to_video(

output_frames = []
for frames in batch_frames:
output_frames.append([{"url": image_to_data_url(frame)} for frame in frames])
output_frames.append(
[{"url": image_to_data_url(frame), "seed": seed} for frame in frames]
)

return {"frames": output_frames}
19 changes: 17 additions & 2 deletions runner/app/routes/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from app.dependencies import get_pipeline
from app.routes.util import image_to_data_url, ImageResponse, HTTPError, http_error
import logging
from typing import List
import random

router = APIRouter()

Expand Down Expand Up @@ -40,6 +42,15 @@ async def text_to_image(
),
)

if params.seed is None:
init_seed = random.randint(0, 2**32 - 1)
if params.num_images_per_prompt > 1:
params.seed = [
i for i in range(init_seed, init_seed + params.num_images_per_prompt)
]
else:
params.seed = init_seed

try:
images = pipeline(**params.model_dump())
except Exception as e:
Expand All @@ -49,8 +60,12 @@ async def text_to_image(
status_code=500, content=http_error("TextToImagePipeline error")
)

seeds = params.seed
if not isinstance(seeds, list):
seeds = [seeds]

output_images = []
for img in images:
output_images.append({"url": image_to_data_url(img)})
for img, sd in zip(images, seeds):
output_images.append({"url": image_to_data_url(img), "seed": sd})

return {"images": output_images}
1 change: 1 addition & 0 deletions runner/app/routes/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

class Media(BaseModel):
url: str
seed: int


class ImageResponse(BaseModel):
Expand Down
7 changes: 6 additions & 1 deletion runner/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -367,11 +367,16 @@
"url": {
"type": "string",
"title": "Url"
},
"seed": {
"type": "integer",
"title": "Seed"
}
},
"type": "object",
"required": [
"url"
"url",
"seed"
],
"title": "Media"
},
Expand Down
40 changes: 21 additions & 19 deletions worker/runner.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 4087840

Please sign in to comment.