From a890dd835d29546ec2b33d72e24e4107c3214035 Mon Sep 17 00:00:00 2001 From: zsxkib Date: Fri, 23 Aug 2024 22:36:31 +0000 Subject: [PATCH] inpainting added, and works via api --- cog.yaml | 1 + predict.py | 132 ++++++++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 111 insertions(+), 22 deletions(-) diff --git a/cog.yaml b/cog.yaml index a9f9a33..3752b5a 100644 --- a/cog.yaml +++ b/cog.yaml @@ -55,6 +55,7 @@ build: - "Pygments==2.16.1" run: - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget + - pip install "git+https://github.com/Gothos/diffusers.git@flux-inpaint" # predict.py defines how predictions are run on your model predict: "predict.py:Predictor" diff --git a/predict.py b/predict.py index 2d5912d..445fbce 100644 --- a/predict.py +++ b/predict.py @@ -3,15 +3,16 @@ import time import torch import subprocess -from typing import List +from typing import List, Tuple from cog import BasePredictor, Input, Path import numpy as np -from diffusers import FluxPipeline +from diffusers import FluxPipeline, FluxInpaintPipeline from transformers import CLIPImageProcessor from diffusers.pipelines.stable_diffusion.safety_checker import ( StableDiffusionSafetyChecker, ) from weights import WeightsDownloadCache +from PIL import Image MODEL_URL_DEV = ( @@ -55,10 +56,12 @@ def setup(self) -> None: print("Loading safety checker...") if not os.path.exists(SAFETY_CACHE): download_base_weights(SAFETY_URL, SAFETY_CACHE) - self.safety_checker = StableDiffusionSafetyChecker.from_pretrained( - SAFETY_CACHE, torch_dtype=torch.float16 - ).to("cuda") - self.feature_extractor = CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR) + + # TODO: implement safety checker w/ lazy loading + # self.safety_checker = StableDiffusionSafetyChecker.from_pretrained( + # SAFETY_CACHE, torch_dtype=torch.float16 + # ).to("cuda") + # self.feature_extractor = CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR) print("Loading Flux dev pipeline") if not os.path.exists("FLUX.1-dev"): @@ -88,13 +91,30 @@ def setup(self) -> None: "dev": LoadedLoRAs(main=None, extra=None), "schnell": LoadedLoRAs(main=None, extra=None), } + self.inpaint_pipes = {} print("setup took: ", time.time() - start) + def setup_inpaint_pipeline(self, model: str): + if model not in self.inpaint_pipes: + print(f"Creating inpaint pipeline for {model}") + base_pipe = self.pipes[model] + inpaint_pipe = FluxInpaintPipeline.from_pretrained( + f"FLUX.1-{model}", + text_encoder=base_pipe.text_encoder, + text_encoder_2=base_pipe.text_encoder_2, + tokenizer=base_pipe.tokenizer, + tokenizer_2=base_pipe.tokenizer_2, + torch_dtype=torch.bfloat16, + ).to("cuda") + self.inpaint_pipes[model] = inpaint_pipe + @torch.inference_mode() def predict( self, prompt: str = Input(description="Prompt for generated image"), + image: Path = Input(description="Input image for img2img or inpaint mode", default=None), + mask: Path = Input(description="Input mask for inpaint mode. Black areas will be preserved, white areas will be inpainted.", default=None), aspect_ratio: str = Input( description="Aspect ratio for the generated image. The size will always be 1 megapixel, i.e. 1024x1024 if aspect ratio is 1:1. To use arbitrary width and height, set aspect ratio to 'custom'.", choices=list(ASPECT_RATIOS.keys()) + ["custom"], @@ -141,6 +161,12 @@ def predict( le=10, default=3.5, ), + strength: float = Input( + description="Strength for img2img or inpaint. 1.0 corresponds to full destruction of information in image", + ge=0.0, + le=1.0, + default=0.8, + ), seed: int = Input( description="Random seed. Set for reproducible generation", default=None ), @@ -175,6 +201,9 @@ def predict( ), ) -> List[Path]: """Run a single prediction on the model""" + if not prompt: + raise ValueError("Please enter a text prompt.") + if seed is None or seed < 0: seed = int.from_bytes(os.urandom(2), "big") print(f"Using seed: {seed}") @@ -192,20 +221,48 @@ def predict( flux_kwargs = {} print(f"Prompt: {prompt}") - print("txt2img mode") + + if image and mask: + print("inpaint mode") + self.setup_inpaint_pipeline(model) + input_image = self.load_image(image) + mask_image = self.load_image(mask) + width, height = self.resize_image_dimensions(input_image.size) + flux_kwargs["image"] = input_image.resize((width, height), Image.LANCZOS) + flux_kwargs["mask_image"] = mask_image.resize((width, height), Image.LANCZOS) + flux_kwargs["strength"] = strength + print(f"Using {model} model for inpainting") + pipe = self.inpaint_pipes[model] + else: + # Unload inpainting pipelines if they exist + if self.inpaint_pipes: + print("Unloading inpaint pipelines to free CUDA memory") + self.inpaint_pipes.clear() + torch.cuda.empty_cache() + + if image: + print("img2img mode") + flux_kwargs["image"] = self.load_image(image) + flux_kwargs["strength"] = strength + else: + print("txt2img mode") + + pipe = self.pipes[model] + flux_kwargs["width"] = width flux_kwargs["height"] = height + if replicate_weights: flux_kwargs["joint_attention_kwargs"] = {"scale": lora_scale} if model == "dev": print("Using dev model") max_sequence_length = 512 - else: + elif model == "schnell": print("Using schnell model") max_sequence_length = 256 guidance_scale = 0 - pipe = self.pipes[model] + print("Available pipelines:", list(self.pipes.keys())) if replicate_weights: start_time = time.time() @@ -238,8 +295,14 @@ def predict( output = pipe(**common_args, **flux_kwargs) - if not disable_safety_checker: - _, has_nsfw_content = self.run_safety_checker(output.images) + # TODO: implement safety checker w/ lazy loading + # if not disable_safety_checker: + # _, has_nsfw_content = self.run_safety_checker(output.images) + # else: + # has_nsfw_content = [False] * len(output.images) + disable_safety_checker = True # TODO: remove this when we have a safety checker back + has_nsfw_content = [False] * len(output.images) # TODO: remove this when we have a safety checker back + output_paths = [] for i, image in enumerate(output.images): @@ -260,6 +323,9 @@ def predict( return output_paths + def load_image(self, path): + return Image.open(path).convert("RGB") + def load_single_lora(self, lora_url: str, model: str): # If no change, skip if lora_url == self.loaded_lora_urls[model].main: @@ -299,19 +365,41 @@ def load_multiple_loras(self, main_lora_url: str, extra_lora_url: str, model: st @torch.amp.autocast("cuda") def run_safety_checker(self, image): - safety_checker_input = self.feature_extractor(image, return_tensors="pt").to( - "cuda" - ) - np_image = [np.array(val) for val in image] - image, has_nsfw_concept = self.safety_checker( - images=np_image, - clip_input=safety_checker_input.pixel_values.to(torch.float16), - ) - return image, has_nsfw_concept - - def aspect_ratio_to_width_height(self, aspect_ratio: str) -> tuple[int, int]: + # TODO: implement safety checker w/ lazy loading + # safety_checker_input = self.feature_extractor(image, return_tensors="pt").to( + # "cuda" + # ) + # np_image = [np.array(val) for val in image] + # image, has_nsfw_concept = self.safety_checker( + # images=np_image, + # clip_input=safety_checker_input.pixel_values.to(torch.float16), + # ) + # return image, has_nsfw_concept + return image, False # TODO: remove this when we have a safety checker back + + def aspect_ratio_to_width_height(self, aspect_ratio: str) -> Tuple[int, int]: return ASPECT_RATIOS[aspect_ratio] + def resize_image_dimensions( + self, + original_resolution_wh: Tuple[int, int], + maximum_dimension: int = 1024, + ) -> Tuple[int, int]: + width, height = original_resolution_wh + + if width > height: + scaling_factor = maximum_dimension / width + else: + scaling_factor = maximum_dimension / height + + new_width = int(width * scaling_factor) + new_height = int(height * scaling_factor) + + new_width = new_width - (new_width % 32) + new_height = new_height - (new_height % 32) + + return new_width, new_height + def download_base_weights(url, dest): start = time.time()