Skip to content

Commit

Permalink
inpainting added, and works via api
Browse files Browse the repository at this point in the history
  • Loading branch information
zsxkib committed Aug 29, 2024
1 parent 1b51b99 commit ad819de
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 8 deletions.
1 change: 1 addition & 0 deletions cog.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ build:
- "Pygments==2.16.1"
run:
- curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
- pip install "git+https://github.com/Gothos/diffusers.git@flux-inpaint"

# predict.py defines how predictions are run on your model
predict: "predict.py:Predictor"
Expand Down
149 changes: 141 additions & 8 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
import subprocess
import time
from dataclasses import dataclass
from typing import List, cast
from typing import List, cast, Tuple

import numpy as np
import torch
from PIL import Image
from cog import BasePredictor, Input, Path
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from diffusers.pipelines.flux.pipeline_flux_inpaint import FluxInpaintPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
Expand Down Expand Up @@ -83,7 +85,7 @@ def setup(self) -> None: # pyright: ignore
tokenizer=dev_pipe.tokenizer,
tokenizer_2=dev_pipe.tokenizer_2,
torch_dtype=torch.bfloat16,
).to("cuda")
).to("cpu") # Keep schnell on CPU initially

self.pipes = {
"dev": dev_pipe,
Expand All @@ -93,13 +95,29 @@ def setup(self) -> None: # pyright: ignore
"dev": LoadedLoRAs(main=None, extra=None),
"schnell": LoadedLoRAs(main=None, extra=None),
}
self.inpaint_pipes = {
"dev": None,
"schnell": None,
}
self.current_model = "dev"
self.current_inpaint = None

self.loaded_models = ["safety_checker", "dev"]
print(f"[!] Loaded models: {self.loaded_models}")

print("setup took: ", time.time() - start)

@torch.inference_mode()
def predict( # pyright: ignore
self,
prompt: str = Input(description="Prompt for generated image"),
image: Path = Input(
description="Input image for img2img or inpaint mode", default=None
),
mask: Path = Input(
description="Input mask for inpaint mode. Black areas will be preserved, white areas will be inpainted.",
default=None,
),
aspect_ratio: str = Input(
description="Aspect ratio for the generated image. The size will always be 1 megapixel, i.e. 1024x1024 if aspect ratio is 1:1. To use arbitrary width and height, set aspect ratio to 'custom'.",
choices=list(ASPECT_RATIOS.keys()) + ["custom"], # pyright: ignore
Expand Down Expand Up @@ -146,6 +164,12 @@ def predict( # pyright: ignore
le=10,
default=3.5,
),
prompt_strength: float = Input(
description="Strength for img2img or inpaint. 1.0 corresponds to full destruction of information in image",
ge=0.0,
le=1.0,
default=0.8,
),
seed: int = Input(
description="Random seed. Set for reproducible generation", default=None
),
Expand Down Expand Up @@ -197,21 +221,43 @@ def predict( # pyright: ignore

flux_kwargs = {}
print(f"Prompt: {prompt}")
print("txt2img mode")

inpaint_mode = image is not None and mask is not None
self.configure_active_model(model, inpaint_mode)

if inpaint_mode:
print("inpaint mode")
input_image = Image.open(image).convert("RGB")
mask_image = Image.open(mask).convert("RGB")
width, height = self.resize_image_dimensions(input_image.size)
flux_kwargs["image"] = input_image.resize((width, height), Image.LANCZOS)
flux_kwargs["mask_image"] = mask_image.resize(
(width, height), Image.LANCZOS
)
flux_kwargs["strength"] = prompt_strength
print(f"Using {model} model for inpainting")
pipe = self.inpaint_pipes[model]
else:
# TODO add img2img mode (when we have just image and not mask)
print("txt2img mode")
pipe = self.pipes[model]

flux_kwargs["width"] = width
flux_kwargs["height"] = height
if replicate_weights:
flux_kwargs["joint_attention_kwargs"] = {"scale": lora_scale}

# Avoid a footgun in case we update the model input but forget to
# update clauses in this if statement
assert model in ["dev", "schnell"]
if model == "dev":
print("Using dev model")
max_sequence_length = 512
else:
else: # model == "schnell":
print("Using schnell model")
max_sequence_length = 256
guidance_scale = 0

pipe = self.pipes[model]

if replicate_weights:
start_time = time.time()
if extra_lora:
Expand All @@ -230,7 +276,15 @@ def predict( # pyright: ignore
pipe.unload_lora_weights()
self.loaded_lora_urls[model] = LoadedLoRAs(main=None, extra=None)

generator = torch.Generator("cuda").manual_seed(seed)
# Ensure all model components are on the correct device
device = pipe.device
for component_name in ['unet', 'text_encoder', 'text_encoder_2', 'vae']:
if hasattr(pipe, component_name):
component = getattr(pipe, component_name)
if isinstance(component, torch.nn.Module):
component.to(device)

generator = torch.Generator(device=device).manual_seed(seed)

common_args = {
"prompt": [prompt] * num_outputs,
Expand Down Expand Up @@ -318,6 +372,85 @@ def run_safety_checker(self, image):
def aspect_ratio_to_width_height(self, aspect_ratio: str) -> tuple[int, int]:
return ASPECT_RATIOS[aspect_ratio]

def configure_active_model(self, model: str, inpaint: bool = False):
initial_models = set(self.loaded_models)

# Unload current model if it's different
if self.current_model != model:
if self.current_model:
self.pipes[self.current_model].to("cpu")
self.loaded_models.remove(self.current_model)

self.pipes[model].to("cuda")
self.current_model = model
self.loaded_models.append(model)

# Ensure the model and all its components are on CUDA
pipe = self.pipes[model]
if pipe.device.type != "cuda":
print(f"Moving {model} model to CUDA.")
pipe.to("cuda")

# Explicitly move specific model components to CUDA
for component_name in ['unet', 'text_encoder', 'text_encoder_2', 'vae']:
if hasattr(pipe, component_name):
component = getattr(pipe, component_name)
if isinstance(component, torch.nn.Module):
component.to("cuda")

# Handle inpainting models
if inpaint:
if self.current_inpaint != model:
if self.current_inpaint:
self.inpaint_pipes[self.current_inpaint].to("cpu")
self.loaded_models.remove(f"{self.current_inpaint}_inpaint")

if self.inpaint_pipes[model] is None:
base_pipe = self.pipes[model]
self.inpaint_pipes[model] = FluxInpaintPipeline.from_pretrained(
f"FLUX.1-{model}",
text_encoder=base_pipe.text_encoder,
text_encoder_2=base_pipe.text_encoder_2,
tokenizer=base_pipe.tokenizer,
tokenizer_2=base_pipe.tokenizer_2,
torch_dtype=torch.bfloat16,
).to("cuda")
else:
self.inpaint_pipes[model].to("cuda")

self.current_inpaint = model
self.loaded_models.append(f"{model}_inpaint")
else:
if self.current_inpaint:
self.inpaint_pipes[self.current_inpaint].to("cpu")
self.loaded_models.remove(f"{self.current_inpaint}_inpaint")
self.current_inpaint = None

torch.cuda.empty_cache()

if set(self.loaded_models) != initial_models:
print(f"[!] Loaded models: {self.loaded_models}")

def resize_image_dimensions(
self,
original_resolution_wh: Tuple[int, int],
maximum_dimension: int = 1024,
) -> Tuple[int, int]:
width, height = original_resolution_wh

if width > height:
scaling_factor = maximum_dimension / width
else:
scaling_factor = maximum_dimension / height

new_width = int(width * scaling_factor)
new_height = int(height * scaling_factor)

new_width = new_width - (new_width % 32)
new_height = new_height - (new_height % 32)

return new_width, new_height


def download_base_weights(url: str, dest: Path):
start = time.time()
Expand All @@ -328,4 +461,4 @@ def download_base_weights(url: str, dest: Path):


def make_multiple_of_16(n):
return ((n + 15) // 16) * 16
return ((n + 15) // 16) * 16

0 comments on commit ad819de

Please sign in to comment.