Skip to content

Commit

Permalink
runner: Fix manual seeding for generators
Browse files Browse the repository at this point in the history
  • Loading branch information
yondonfu committed Feb 5, 2024
1 parent 1a87744 commit 5c48959
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion runner/app/pipelines/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, model_id: str):
def __call__(self, prompt: str, image: PIL.Image, **kwargs) -> List[PIL.Image]:
seed = kwargs.pop("seed", None)
if seed is not None:
kwargs["generator"] = torch.Generator(seed)
kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(seed)

if (
self.model_id == "stabilityai/sdxl-turbo"
Expand Down
2 changes: 1 addition & 1 deletion runner/app/pipelines/image_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __call__(self, image: PIL.Image, **kwargs) -> List[List[PIL.Image]]:

seed = kwargs.pop("seed", None)
if seed is not None:
kwargs["generator"] = torch.Generator(seed)
kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(seed)

return self.ldm(image, **kwargs).frames

Expand Down
2 changes: 1 addition & 1 deletion runner/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, model_id: str):
def __call__(self, prompt: str, **kwargs) -> List[PIL.Image]:
seed = kwargs.pop("seed", None)
if seed is not None:
kwargs["generator"] = torch.Generator(seed)
kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(seed)

if (
self.model_id == "stabilityai/sdxl-turbo"
Expand Down

0 comments on commit 5c48959

Please sign in to comment.