From 5c4895915580dce426ba935e5bb08f9045357bbc Mon Sep 17 00:00:00 2001 From: Yondon Fu Date: Mon, 5 Feb 2024 18:50:39 +0000 Subject: [PATCH] runner: Fix manual seeding for generators --- runner/app/pipelines/image_to_image.py | 2 +- runner/app/pipelines/image_to_video.py | 2 +- runner/app/pipelines/text_to_image.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/runner/app/pipelines/image_to_image.py b/runner/app/pipelines/image_to_image.py index 04b93481..2cd50720 100644 --- a/runner/app/pipelines/image_to_image.py +++ b/runner/app/pipelines/image_to_image.py @@ -37,7 +37,7 @@ def __init__(self, model_id: str): def __call__(self, prompt: str, image: PIL.Image, **kwargs) -> List[PIL.Image]: seed = kwargs.pop("seed", None) if seed is not None: - kwargs["generator"] = torch.Generator(seed) + kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(seed) if ( self.model_id == "stabilityai/sdxl-turbo" diff --git a/runner/app/pipelines/image_to_video.py b/runner/app/pipelines/image_to_video.py index 08a67311..d73cba33 100644 --- a/runner/app/pipelines/image_to_video.py +++ b/runner/app/pipelines/image_to_video.py @@ -40,7 +40,7 @@ def __call__(self, image: PIL.Image, **kwargs) -> List[List[PIL.Image]]: seed = kwargs.pop("seed", None) if seed is not None: - kwargs["generator"] = torch.Generator(seed) + kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(seed) return self.ldm(image, **kwargs).frames diff --git a/runner/app/pipelines/text_to_image.py b/runner/app/pipelines/text_to_image.py index 6fdb7aec..25480ff2 100644 --- a/runner/app/pipelines/text_to_image.py +++ b/runner/app/pipelines/text_to_image.py @@ -33,7 +33,7 @@ def __init__(self, model_id: str): def __call__(self, prompt: str, **kwargs) -> List[PIL.Image]: seed = kwargs.pop("seed", None) if seed is not None: - kwargs["generator"] = torch.Generator(seed) + kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(seed) if ( self.model_id == "stabilityai/sdxl-turbo"