Skip to content

Commit

Permalink
img2img added + update input descriptions (#23)
Browse files Browse the repository at this point in the history
* img2img added + update input descriptions

* Fix import bug in predict.py

* Run ruff format and check

* Adjusting the import somehow fixed the bug (root cause unclear, but it is fixed now)

* Use cropping instead of resizing for image processing
- Crop images for img2img and inpaint modes (from center)
- Make dimensions multiples of 16 (round down)
- Avoid resizing artifacts

* Run ruff format and check again (forgotten after the previous change)
  • Loading branch information
zsxkib authored Sep 5, 2024
1 parent 3a5056b commit a889701
Showing 1 changed file with 67 additions and 28 deletions.
95 changes: 67 additions & 28 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
import logging
from PIL import Image
from cog import BasePredictor, Input, Path
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from diffusers.pipelines.flux.pipeline_flux_inpaint import FluxInpaintPipeline
from diffusers.pipelines.flux import (
FluxPipeline,
FluxInpaintPipeline,
FluxImg2ImgPipeline,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
Expand Down Expand Up @@ -117,6 +120,34 @@ def setup(self) -> None: # pyright: ignore
"schnell": schnell_pipe,
}

# Load img2img pipelines
print("Loading Flux dev img2img pipeline")
dev_img2img_pipe = FluxImg2ImgPipeline(
transformer=dev_pipe.transformer,
scheduler=dev_pipe.scheduler,
vae=dev_pipe.vae,
text_encoder=dev_pipe.text_encoder,
text_encoder_2=dev_pipe.text_encoder_2,
tokenizer=dev_pipe.tokenizer,
tokenizer_2=dev_pipe.tokenizer_2,
).to("cuda")

print("Loading Flux schnell img2img pipeline")
schnell_img2img_pipe = FluxImg2ImgPipeline(
transformer=schnell_pipe.transformer,
scheduler=schnell_pipe.scheduler,
vae=schnell_pipe.vae,
text_encoder=schnell_pipe.text_encoder,
text_encoder_2=schnell_pipe.text_encoder_2,
tokenizer=schnell_pipe.tokenizer,
tokenizer_2=schnell_pipe.tokenizer_2,
).to("cuda")

self.img2img_pipes = {
"dev": dev_img2img_pipe,
"schnell": schnell_img2img_pipe,
}

# Load inpainting pipelines
print("Loading Flux dev inpaint pipeline")
dev_inpaint_pipe = FluxInpaintPipeline(
Expand Down Expand Up @@ -149,15 +180,6 @@ def setup(self) -> None: # pyright: ignore
"dev": LoadedLoRAs(main=None, extra=None),
"schnell": LoadedLoRAs(main=None, extra=None),
}

self.loaded_models = [
"safety_checker",
"dev",
"schnell",
"dev_inpaint",
"schnell_inpaint",
]
print(f"[!] Loaded models: {self.loaded_models}")
print("setup took: ", time.time() - start)

@torch.inference_mode()
Expand All @@ -172,18 +194,18 @@ def predict( # pyright: ignore
default=None,
),
aspect_ratio: str = Input(
description="Aspect ratio for the generated image. The size will always be 1 megapixel, i.e. 1024x1024 if aspect ratio is 1:1. To use arbitrary width and height, set aspect ratio to 'custom'.",
description="Aspect ratio for the generated image. The size will always be 1 megapixel, i.e. 1024x1024 if aspect ratio is 1:1. To use arbitrary width and height, set aspect ratio to 'custom'. Note: Ignored in img2img and inpainting modes.",
choices=list(ASPECT_RATIOS.keys()) + ["custom"], # pyright: ignore
default="1:1",
),
width: int = Input(
description="Width of the generated image. Optional, only used when aspect_ratio=custom. Must be a multiple of 16 (if it's not, it will be rounded to nearest multiple of 16)",
description="Width of the generated image. Optional, only used when aspect_ratio=custom. Must be a multiple of 16 (if it's not, it will be rounded to nearest multiple of 16). Note: Overridden by input image in img2img and inpainting modes.",
ge=256,
le=1440,
default=None,
),
height: int = Input(
description="Height of the generated image. Optional, only used when aspect_ratio=custom. Must be a multiple of 16 (if it's not, it will be rounded to nearest multiple of 16)",
description="Height of the generated image. Optional, only used when aspect_ratio=custom. Must be a multiple of 16 (if it's not, it will be rounded to nearest multiple of 16). Note: Overridden by input image in img2img and inpainting modes.",
ge=256,
le=1440,
default=None,
Expand Down Expand Up @@ -272,25 +294,41 @@ def predict( # pyright: ignore
width, height = self.aspect_ratio_to_width_height(aspect_ratio)
max_sequence_length = 512

is_img2img_mode = image is not None and mask is None
is_inpaint_mode = image is not None and mask is not None

flux_kwargs = {}
print(f"Prompt: {prompt}")

inpaint_mode = image is not None and mask is not None

if inpaint_mode:
print("inpaint mode")
if is_img2img_mode or is_inpaint_mode:
input_image = Image.open(image).convert("RGB")
mask_image = Image.open(mask).convert("RGB")
width, height = self.resize_image_dimensions(input_image.size)
flux_kwargs["image"] = input_image.resize((width, height), Image.LANCZOS)
flux_kwargs["mask_image"] = mask_image.resize(
(width, height), Image.LANCZOS
resized_width, resized_height = self.resize_image_dimensions(
input_image.size
)
flux_kwargs["strength"] = prompt_strength
print(f"Using {model} model for inpainting")
pipe = self.inpaint_pipes[model]
else:
print("txt2img mode")
# Crop to the nearest smaller multiple of 16
width = resized_width - (resized_width % 16)
height = resized_height - (resized_height % 16)
# Center crop
left = (input_image.width - width) // 2
top = (input_image.height - height) // 2
right = left + width
bottom = top + height
flux_kwargs["image"] = input_image.crop((left, top, right, bottom))

if is_img2img_mode:
print("[!] img2img mode")
flux_kwargs["strength"] = prompt_strength
print(f"Using {model} model for img2img")
pipe = self.img2img_pipes[model]
else: # is_inpaint_mode
print("[!] inpaint mode")
mask_image = Image.open(mask).convert("RGB")
flux_kwargs["mask_image"] = mask_image.crop((left, top, right, bottom))
flux_kwargs["strength"] = prompt_strength
print(f"Using {model} model for inpainting")
pipe = self.inpaint_pipes[model]
else: # is_txt2img_mode
print("[!] txt2img mode")
pipe = self.pipes[model]

flux_kwargs["width"] = width
Expand Down Expand Up @@ -460,4 +498,5 @@ def download_base_weights(url: str, dest: Path):


def make_multiple_of_16(n):
    """Return ``n`` rounded up to the next multiple of 16.

    A value that is already a multiple of 16 is returned unchanged.
    """
    # Adding 15 before the floor division pushes any non-multiple into the
    # next bucket of 16 while leaving exact multiples in their own bucket.
    bucket_count, _ = divmod(n + 15, 16)
    return bucket_count * 16

0 comments on commit a889701

Please sign in to comment.