Skip to content

Commit

Permalink
inpainting added, and works via api
Browse files Browse the repository at this point in the history
  • Loading branch information
zsxkib committed Aug 23, 2024
1 parent 4f79b10 commit a890dd8
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 22 deletions.
1 change: 1 addition & 0 deletions cog.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ build:
- "Pygments==2.16.1"
run:
- curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
- pip install "git+https://github.com/Gothos/diffusers.git@flux-inpaint"

# predict.py defines how predictions are run on your model
predict: "predict.py:Predictor"
Expand Down
132 changes: 110 additions & 22 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
import time
import torch
import subprocess
from typing import List
from typing import List, Tuple
from cog import BasePredictor, Input, Path
import numpy as np
from diffusers import FluxPipeline
from diffusers import FluxPipeline, FluxInpaintPipeline
from transformers import CLIPImageProcessor
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from weights import WeightsDownloadCache
from PIL import Image


MODEL_URL_DEV = (
Expand Down Expand Up @@ -55,10 +56,12 @@ def setup(self) -> None:
print("Loading safety checker...")
if not os.path.exists(SAFETY_CACHE):
download_base_weights(SAFETY_URL, SAFETY_CACHE)
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
SAFETY_CACHE, torch_dtype=torch.float16
).to("cuda")
self.feature_extractor = CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR)

# TODO: implement safety checker w/ lazy loading
# self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
# SAFETY_CACHE, torch_dtype=torch.float16
# ).to("cuda")
# self.feature_extractor = CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR)

print("Loading Flux dev pipeline")
if not os.path.exists("FLUX.1-dev"):
Expand Down Expand Up @@ -88,13 +91,30 @@ def setup(self) -> None:
"dev": LoadedLoRAs(main=None, extra=None),
"schnell": LoadedLoRAs(main=None, extra=None),
}
self.inpaint_pipes = {}

print("setup took: ", time.time() - start)

def setup_inpaint_pipeline(self, model: str):
    """Create and cache the inpainting pipeline for ``model`` on first use.

    The pipeline is built from the local ``FLUX.1-{model}`` checkpoint and
    stored in ``self.inpaint_pipes[model]``. Calling this again for a model
    that is already cached is a no-op.

    Every heavy module (transformer, VAE, text encoders) and both tokenizers
    are taken from the already-loaded ``self.pipes[model]`` instead of being
    read from disk again, so the inpaint pipeline does not hold a second copy
    of the weights on the GPU.

    Args:
        model: Key into ``self.pipes`` (e.g. ``"dev"`` or ``"schnell"``).

    Raises:
        KeyError: If ``model`` has no entry in ``self.pipes``.
    """
    if model in self.inpaint_pipes:
        return

    print(f"Creating inpaint pipeline for {model}")
    base_pipe = self.pipes[model]
    # Sharing the transformer also means anything applied to the base
    # pipeline's transformer is visible here.
    # NOTE(review): the LoRA loading code is outside this view - confirm it
    # targets self.pipes[model] so adapters carry over to inpainting.
    self.inpaint_pipes[model] = FluxInpaintPipeline.from_pretrained(
        f"FLUX.1-{model}",
        transformer=base_pipe.transformer,
        vae=base_pipe.vae,
        text_encoder=base_pipe.text_encoder,
        text_encoder_2=base_pipe.text_encoder_2,
        tokenizer=base_pipe.tokenizer,
        tokenizer_2=base_pipe.tokenizer_2,
        torch_dtype=torch.bfloat16,
    ).to("cuda")

@torch.inference_mode()
def predict(
self,
prompt: str = Input(description="Prompt for generated image"),
image: Path = Input(description="Input image for img2img or inpaint mode", default=None),
mask: Path = Input(description="Input mask for inpaint mode. Black areas will be preserved, white areas will be inpainted.", default=None),
aspect_ratio: str = Input(
description="Aspect ratio for the generated image. The size will always be 1 megapixel, i.e. 1024x1024 if aspect ratio is 1:1. To use arbitrary width and height, set aspect ratio to 'custom'.",
choices=list(ASPECT_RATIOS.keys()) + ["custom"],
Expand Down Expand Up @@ -141,6 +161,12 @@ def predict(
le=10,
default=3.5,
),
strength: float = Input(
description="Strength for img2img or inpaint. 1.0 corresponds to full destruction of information in image",
ge=0.0,
le=1.0,
default=0.8,
),
seed: int = Input(
description="Random seed. Set for reproducible generation", default=None
),
Expand Down Expand Up @@ -175,6 +201,9 @@ def predict(
),
) -> List[Path]:
"""Run a single prediction on the model"""
if not prompt:
raise ValueError("Please enter a text prompt.")

if seed is None or seed < 0:
seed = int.from_bytes(os.urandom(2), "big")
print(f"Using seed: {seed}")
Expand All @@ -192,20 +221,48 @@ def predict(

flux_kwargs = {}
print(f"Prompt: {prompt}")
print("txt2img mode")

if image and mask:
print("inpaint mode")
self.setup_inpaint_pipeline(model)
input_image = self.load_image(image)
mask_image = self.load_image(mask)
width, height = self.resize_image_dimensions(input_image.size)
flux_kwargs["image"] = input_image.resize((width, height), Image.LANCZOS)
flux_kwargs["mask_image"] = mask_image.resize((width, height), Image.LANCZOS)
flux_kwargs["strength"] = strength
print(f"Using {model} model for inpainting")
pipe = self.inpaint_pipes[model]
else:
# Unload inpainting pipelines if they exist
if self.inpaint_pipes:
print("Unloading inpaint pipelines to free CUDA memory")
self.inpaint_pipes.clear()
torch.cuda.empty_cache()

if image:
print("img2img mode")
flux_kwargs["image"] = self.load_image(image)
flux_kwargs["strength"] = strength
else:
print("txt2img mode")

pipe = self.pipes[model]

flux_kwargs["width"] = width
flux_kwargs["height"] = height

if replicate_weights:
flux_kwargs["joint_attention_kwargs"] = {"scale": lora_scale}
if model == "dev":
print("Using dev model")
max_sequence_length = 512
else:
elif model == "schnell":
print("Using schnell model")
max_sequence_length = 256
guidance_scale = 0

pipe = self.pipes[model]
print("Available pipelines:", list(self.pipes.keys()))

if replicate_weights:
start_time = time.time()
Expand Down Expand Up @@ -238,8 +295,14 @@ def predict(

output = pipe(**common_args, **flux_kwargs)

if not disable_safety_checker:
_, has_nsfw_content = self.run_safety_checker(output.images)
# TODO: implement safety checker w/ lazy loading
# if not disable_safety_checker:
# _, has_nsfw_content = self.run_safety_checker(output.images)
# else:
# has_nsfw_content = [False] * len(output.images)
disable_safety_checker = True # TODO: remove this when we have a safety checker back
has_nsfw_content = [False] * len(output.images) # TODO: remove this when we have a safety checker back


output_paths = []
for i, image in enumerate(output.images):
Expand All @@ -260,6 +323,9 @@ def predict(

return output_paths

def load_image(self, path):
    """Open the image file at ``path`` and return it converted to RGB.

    Args:
        path: Location of the image file, passed straight to ``Image.open``.

    Returns:
        A new ``PIL.Image.Image`` in ``"RGB"`` mode.
    """
    # convert() produces a separate, fully loaded image, so the source file
    # can be closed as soon as the conversion is done instead of staying open
    # until the garbage collector gets to it.
    with Image.open(path) as source:
        return source.convert("RGB")

def load_single_lora(self, lora_url: str, model: str):
# If no change, skip
if lora_url == self.loaded_lora_urls[model].main:
Expand Down Expand Up @@ -299,19 +365,41 @@ def load_multiple_loras(self, main_lora_url: str, extra_lora_url: str, model: st

@torch.amp.autocast("cuda")
def run_safety_checker(self, image):
safety_checker_input = self.feature_extractor(image, return_tensors="pt").to(
"cuda"
)
np_image = [np.array(val) for val in image]
image, has_nsfw_concept = self.safety_checker(
images=np_image,
clip_input=safety_checker_input.pixel_values.to(torch.float16),
)
return image, has_nsfw_concept

def aspect_ratio_to_width_height(self, aspect_ratio: str) -> tuple[int, int]:
# TODO: implement safety checker w/ lazy loading
# safety_checker_input = self.feature_extractor(image, return_tensors="pt").to(
# "cuda"
# )
# np_image = [np.array(val) for val in image]
# image, has_nsfw_concept = self.safety_checker(
# images=np_image,
# clip_input=safety_checker_input.pixel_values.to(torch.float16),
# )
# return image, has_nsfw_concept
return image, False # TODO: remove this when we have a safety checker back

def aspect_ratio_to_width_height(self, aspect_ratio: str) -> Tuple[int, int]:
    """Return the ``ASPECT_RATIOS`` entry stored under ``aspect_ratio``.

    The lookup is a plain subscript, so a key that is missing from
    ``ASPECT_RATIOS`` propagates whatever error that container raises.
    """
    dimensions = ASPECT_RATIOS[aspect_ratio]
    return dimensions

def resize_image_dimensions(
    self,
    original_resolution_wh: Tuple[int, int],
    maximum_dimension: int = 1024,
    multiple: int = 32,
) -> Tuple[int, int]:
    """Scale a resolution so its longest side equals ``maximum_dimension``.

    The aspect ratio is preserved (up to rounding) and both sides are then
    rounded down to a multiple of ``multiple``. Smaller inputs are scaled up
    and larger ones scaled down, so the longest side always ends up at
    ``maximum_dimension`` rounded down to that multiple.

    Args:
        original_resolution_wh: ``(width, height)`` of the source image.
        maximum_dimension: Target length of the longest side, in pixels.
        multiple: Both returned sides are divisible by this value.

    Returns:
        ``(new_width, new_height)``, each at least ``multiple``.

    Raises:
        ValueError: If a dimension is not positive, ``multiple`` is not
            positive, or ``maximum_dimension`` is smaller than ``multiple``.
    """
    width, height = original_resolution_wh
    if width <= 0 or height <= 0:
        raise ValueError(
            f"Image dimensions must be positive, got {width}x{height}."
        )
    if multiple <= 0 or maximum_dimension < multiple:
        raise ValueError(
            "maximum_dimension must be at least multiple, and multiple must "
            f"be positive (got maximum_dimension={maximum_dimension}, "
            f"multiple={multiple})."
        )

    longest_side = max(width, height)

    # Integer arithmetic keeps the longest side at exactly maximum_dimension.
    # Going through a float scaling factor can land a hair below it, and
    # truncation plus the rounding below would then drop a whole multiple.
    scaled_width = width * maximum_dimension // longest_side
    scaled_height = height * maximum_dimension // longest_side

    # Round down to the required multiple, but never to zero: a very
    # elongated image would otherwise yield an empty side.
    new_width = max(multiple, scaled_width - scaled_width % multiple)
    new_height = max(multiple, scaled_height - scaled_height % multiple)

    return new_width, new_height


def download_base_weights(url, dest):
start = time.time()
Expand Down

0 comments on commit a890dd8

Please sign in to comment.