Skip to content

Commit 5c48959

Browse files
committed
runner: Fix manual seeding for generators
1 parent 1a87744 commit 5c48959

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

runner/app/pipelines/image_to_image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(self, model_id: str):
3737
def __call__(self, prompt: str, image: PIL.Image, **kwargs) -> List[PIL.Image]:
3838
seed = kwargs.pop("seed", None)
3939
if seed is not None:
40-
kwargs["generator"] = torch.Generator(seed)
40+
kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(seed)
4141

4242
if (
4343
self.model_id == "stabilityai/sdxl-turbo"

runner/app/pipelines/image_to_video.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __call__(self, image: PIL.Image, **kwargs) -> List[List[PIL.Image]]:
4040

4141
seed = kwargs.pop("seed", None)
4242
if seed is not None:
43-
kwargs["generator"] = torch.Generator(seed)
43+
kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(seed)
4444

4545
return self.ldm(image, **kwargs).frames
4646

runner/app/pipelines/text_to_image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(self, model_id: str):
3333
def __call__(self, prompt: str, **kwargs) -> List[PIL.Image]:
3434
seed = kwargs.pop("seed", None)
3535
if seed is not None:
36-
kwargs["generator"] = torch.Generator(seed)
36+
kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(seed)
3737

3838
if (
3939
self.model_id == "stabilityai/sdxl-turbo"

0 commit comments

Comments
 (0)