diff --git a/cog.yaml b/cog.yaml
index a9f9a33..3752b5a 100644
--- a/cog.yaml
+++ b/cog.yaml
@@ -55,6 +55,7 @@ build:
     - "Pygments==2.16.1"
   run:
     - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
+    - pip install "git+https://github.com/Gothos/diffusers.git@flux-inpaint"
 
 # predict.py defines how predictions are run on your model
 predict: "predict.py:Predictor"
diff --git a/predict.py b/predict.py
index 342630e..6ff96d3 100644
--- a/predict.py
+++ b/predict.py
@@ -2,12 +2,14 @@
 import subprocess
 import time
 from dataclasses import dataclass
-from typing import List, cast
+from typing import List, cast, Tuple
 
 import numpy as np
 import torch
+from PIL import Image
 from cog import BasePredictor, Input, Path
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
+from diffusers.pipelines.flux.pipeline_flux_inpaint import FluxInpaintPipeline
 from diffusers.pipelines.stable_diffusion.safety_checker import (
     StableDiffusionSafetyChecker,
 )
@@ -50,7 +52,7 @@ class Predictor(BasePredictor):
     def setup(self) -> None:  # pyright: ignore
         """Load the model into memory to make running multiple predictions efficient"""
         start = time.time()
-        # Dont pull weights
+        # Don't pull weights
         os.environ["TRANSFORMERS_OFFLINE"] = "1"
 
         self.weights_cache = WeightsDownloadCache()
@@ -78,6 +80,7 @@ def setup(self) -> None:  # pyright: ignore
             download_base_weights(MODEL_URL_SCHNELL, FLUX_SCHNELL_PATH)
         schnell_pipe = FluxPipeline.from_pretrained(
             "FLUX.1-schnell",
+            vae=dev_pipe.vae,
             text_encoder=dev_pipe.text_encoder,
             text_encoder_2=dev_pipe.text_encoder_2,
             tokenizer=dev_pipe.tokenizer,
@@ -89,11 +92,48 @@ def setup(self) -> None:  # pyright: ignore
             "dev": dev_pipe,
             "schnell": schnell_pipe,
         }
+
+        # Load inpainting pipelines
+        print("Loading Flux dev inpaint pipeline")
+        dev_inpaint_pipe = FluxInpaintPipeline(
+            transformer=dev_pipe.transformer,
+            scheduler=dev_pipe.scheduler,
+            vae=dev_pipe.vae,
+            text_encoder=dev_pipe.text_encoder,
+            text_encoder_2=dev_pipe.text_encoder_2,
+            tokenizer=dev_pipe.tokenizer,
+            tokenizer_2=dev_pipe.tokenizer_2,
+        ).to("cuda")
+
+        print("Loading Flux schnell inpaint pipeline")
+        schnell_inpaint_pipe = FluxInpaintPipeline(
+            transformer=schnell_pipe.transformer,
+            scheduler=schnell_pipe.scheduler,
+            vae=schnell_pipe.vae,
+            text_encoder=schnell_pipe.text_encoder,
+            text_encoder_2=schnell_pipe.text_encoder_2,
+            tokenizer=schnell_pipe.tokenizer,
+            tokenizer_2=schnell_pipe.tokenizer_2,
+        ).to("cuda")
+
+        self.inpaint_pipes = {
+            "dev": dev_inpaint_pipe,
+            "schnell": schnell_inpaint_pipe,
+        }
+
         self.loaded_lora_urls = {
             "dev": LoadedLoRAs(main=None, extra=None),
             "schnell": LoadedLoRAs(main=None, extra=None),
         }
+        self.loaded_models = [
+            "safety_checker",
+            "dev",
+            "schnell",
+            "dev_inpaint",
+            "schnell_inpaint",
+        ]
+        print(f"[!] Loaded models: {self.loaded_models}")
         print("setup took: ", time.time() - start)
 
     @torch.inference_mode()
@@ -102,6 +142,13 @@ def predict(  # pyright: ignore
         prompt: str = Input(
             description="Prompt for generated image. If you include the `trigger_word` used in the training process you are more likely to activate the trained object, style, or concept in the resulting image."
         ),
+        image: Path = Input(
+            description="Input image for img2img or inpaint mode", default=None
+        ),
+        mask: Path = Input(
+            description="Input mask for inpaint mode. Black areas will be preserved, white areas will be inpainted.",
+            default=None,
+        ),
         aspect_ratio: str = Input(
             description="Aspect ratio for the generated image. The size will always be 1 megapixel, i.e. 1024x1024 if aspect ratio is 1:1. To use arbitrary width and height, set aspect ratio to 'custom'.",
             choices=list(ASPECT_RATIOS.keys()) + ["custom"],  # pyright: ignore
@@ -148,6 +195,12 @@ def predict(  # pyright: ignore
             le=10,
             default=3.5,
         ),
+        prompt_strength: float = Input(
+            description="Strength for img2img or inpaint. 1.0 corresponds to full destruction of information in image",
+            ge=0.0,
+            le=1.0,
+            default=0.8,
+        ),
         seed: int = Input(
             description="Random seed. Set for reproducible generation.", default=None
         ),
@@ -199,21 +252,39 @@ def predict(  # pyright: ignore
         flux_kwargs = {}
         print(f"Prompt: {prompt}")
-        print("txt2img mode")
+
+        inpaint_mode = image is not None and mask is not None
+
+        if inpaint_mode:
+            print("inpaint mode")
+            input_image = Image.open(image).convert("RGB")
+            mask_image = Image.open(mask).convert("RGB")
+            width, height = self.resize_image_dimensions(input_image.size)
+            flux_kwargs["image"] = input_image.resize((width, height), Image.LANCZOS)
+            flux_kwargs["mask_image"] = mask_image.resize(
+                (width, height), Image.LANCZOS
+            )
+            flux_kwargs["strength"] = prompt_strength
+            print(f"Using {model} model for inpainting")
+            pipe = self.inpaint_pipes[model]
+        else:
+            print("txt2img mode")
+            pipe = self.pipes[model]
+
         flux_kwargs["width"] = width
         flux_kwargs["height"] = height
 
         if replicate_weights:
             flux_kwargs["joint_attention_kwargs"] = {"scale": lora_scale}
+
+        assert model in ["dev", "schnell"]
         if model == "dev":
             print("Using dev model")
             max_sequence_length = 512
-        else:
+        else:  # model == "schnell":
             print("Using schnell model")
             max_sequence_length = 256
             guidance_scale = 0
 
-        pipe = self.pipes[model]
-
         if replicate_weights:
             start_time = time.time()
             if extra_lora:
@@ -232,7 +303,7 @@ def predict(  # pyright: ignore
             pipe.unload_lora_weights()
             self.loaded_lora_urls[model] = LoadedLoRAs(main=None, extra=None)
 
-        generator = torch.Generator("cuda").manual_seed(seed)
+        generator = torch.Generator(device="cuda").manual_seed(seed)
 
         common_args = {
             "prompt": [prompt] * num_outputs,
@@ -320,6 +391,26 @@ def run_safety_checker(self, image):
     def aspect_ratio_to_width_height(self, aspect_ratio: str) -> tuple[int, int]:
         return ASPECT_RATIOS[aspect_ratio]
 
+    def resize_image_dimensions(
+        self,
+        original_resolution_wh: Tuple[int, int],
+        maximum_dimension: int = 1024,
+    ) -> Tuple[int, int]:
+        width, height = original_resolution_wh
+
+        if width > height:
+            scaling_factor = maximum_dimension / width
+        else:
+            scaling_factor = maximum_dimension / height
+
+        new_width = int(width * scaling_factor)
+        new_height = int(height * scaling_factor)
+
+        new_width = new_width - (new_width % 32)
+        new_height = new_height - (new_height % 32)
+
+        return new_width, new_height
+
 
 def download_base_weights(url: str, dest: Path):
     start = time.time()
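For reviewers, here is a minimal standalone sketch of the rounding behaviour introduced by the new `resize_image_dimensions` helper (illustrative only, not part of the patch): the longer side is scaled to at most `maximum_dimension` (1024 px by default) and both sides are then rounded down to a multiple of 32 before the input image and mask are resized for the inpaint pipeline.

```python
from typing import Tuple


# Illustrative copy of the logic in Predictor.resize_image_dimensions, written as a
# free function so it can be checked in isolation; inputs are assumed to be plain
# (width, height) tuples in pixels, as returned by PIL's Image.size.
def resize_image_dimensions(
    original_resolution_wh: Tuple[int, int], maximum_dimension: int = 1024
) -> Tuple[int, int]:
    width, height = original_resolution_wh
    # Scale the longer side to maximum_dimension, preserving aspect ratio.
    scaling_factor = maximum_dimension / max(width, height)
    new_width = int(width * scaling_factor)
    new_height = int(height * scaling_factor)
    # Round both sides down to the nearest multiple of 32.
    return new_width - (new_width % 32), new_height - (new_height % 32)


print(resize_image_dimensions((1920, 1080)))  # (1024, 576)
print(resize_image_dimensions((512, 2048)))   # (256, 1024)
```

Note that, as written in both the patch and this sketch, inputs smaller than the maximum on both sides are scaled up by the same rule, since the scaling factor exceeds 1 in that case.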