Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inpainting added, and works via api #13

Merged
merged 5 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cog.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ build:
- "Pygments==2.16.1"
run:
- curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
- pip install "git+https://github.com/Gothos/diffusers.git@flux-inpaint"

# predict.py defines how predictions are run on your model
predict: "predict.py:Predictor"
Expand Down
149 changes: 141 additions & 8 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
import subprocess
import time
from dataclasses import dataclass
from typing import List, cast
from typing import List, cast, Tuple

import numpy as np
import torch
from PIL import Image
from cog import BasePredictor, Input, Path
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from diffusers.pipelines.flux.pipeline_flux_inpaint import FluxInpaintPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
Expand Down Expand Up @@ -83,7 +85,7 @@ def setup(self) -> None: # pyright: ignore
tokenizer=dev_pipe.tokenizer,
tokenizer_2=dev_pipe.tokenizer_2,
torch_dtype=torch.bfloat16,
).to("cuda")
).to("cpu") # Keep schnell on CPU initially

self.pipes = {
"dev": dev_pipe,
Expand All @@ -93,13 +95,29 @@ def setup(self) -> None: # pyright: ignore
"dev": LoadedLoRAs(main=None, extra=None),
"schnell": LoadedLoRAs(main=None, extra=None),
}
self.inpaint_pipes = {
"dev": None,
"schnell": None,
}
self.current_model = "dev"
self.current_inpaint = None

self.loaded_models = ["safety_checker", "dev"]
print(f"[!] Loaded models: {self.loaded_models}")

print("setup took: ", time.time() - start)

@torch.inference_mode()
def predict( # pyright: ignore
self,
prompt: str = Input(description="Prompt for generated image"),
image: Path = Input(
description="Input image for img2img or inpaint mode", default=None
),
mask: Path = Input(
description="Input mask for inpaint mode. Black areas will be preserved, white areas will be inpainted.",
default=None,
),
aspect_ratio: str = Input(
description="Aspect ratio for the generated image. The size will always be 1 megapixel, i.e. 1024x1024 if aspect ratio is 1:1. To use arbitrary width and height, set aspect ratio to 'custom'.",
choices=list(ASPECT_RATIOS.keys()) + ["custom"], # pyright: ignore
Expand Down Expand Up @@ -146,6 +164,12 @@ def predict( # pyright: ignore
le=10,
default=3.5,
),
prompt_strength: float = Input(
description="Strength for img2img or inpaint. 1.0 corresponds to full destruction of information in image",
ge=0.0,
le=1.0,
default=0.8,
),
seed: int = Input(
description="Random seed. Set for reproducible generation", default=None
),
Expand Down Expand Up @@ -197,21 +221,43 @@ def predict( # pyright: ignore

flux_kwargs = {}
print(f"Prompt: {prompt}")
print("txt2img mode")

inpaint_mode = image is not None and mask is not None
self.configure_active_model(model, inpaint_mode)

if inpaint_mode:
print("inpaint mode")
input_image = Image.open(image).convert("RGB")
mask_image = Image.open(mask).convert("RGB")
width, height = self.resize_image_dimensions(input_image.size)
zsxkib marked this conversation as resolved.
Show resolved Hide resolved
flux_kwargs["image"] = input_image.resize((width, height), Image.LANCZOS)
flux_kwargs["mask_image"] = mask_image.resize(
(width, height), Image.LANCZOS
)
flux_kwargs["strength"] = prompt_strength
print(f"Using {model} model for inpainting")
pipe = self.inpaint_pipes[model]
else:
# TODO add img2img mode (when we have just image and not mask)
print("txt2img mode")
pipe = self.pipes[model]

flux_kwargs["width"] = width
flux_kwargs["height"] = height
if replicate_weights:
flux_kwargs["joint_attention_kwargs"] = {"scale": lora_scale}

# Avoid a footgun in case we update the model input but forget to
# update clauses in this if statement
assert model in ["dev", "schnell"]
if model == "dev":
zsxkib marked this conversation as resolved.
Show resolved Hide resolved
print("Using dev model")
max_sequence_length = 512
else:
else: # model == "schnell":
print("Using schnell model")
max_sequence_length = 256
guidance_scale = 0
andreasjansson marked this conversation as resolved.
Show resolved Hide resolved

pipe = self.pipes[model]

if replicate_weights:
start_time = time.time()
if extra_lora:
Expand All @@ -230,7 +276,15 @@ def predict( # pyright: ignore
pipe.unload_lora_weights()
self.loaded_lora_urls[model] = LoadedLoRAs(main=None, extra=None)

generator = torch.Generator("cuda").manual_seed(seed)
# Ensure all model components are on the correct device
device = pipe.device
for component_name in ['unet', 'text_encoder', 'text_encoder_2', 'vae']:
if hasattr(pipe, component_name):
component = getattr(pipe, component_name)
if isinstance(component, torch.nn.Module):
component.to(device)

generator = torch.Generator(device=device).manual_seed(seed)

common_args = {
"prompt": [prompt] * num_outputs,
Expand Down Expand Up @@ -318,6 +372,85 @@ def run_safety_checker(self, image):
def aspect_ratio_to_width_height(self, aspect_ratio: str) -> tuple[int, int]:
return ASPECT_RATIOS[aspect_ratio]

def configure_active_model(self, model: str, inpaint: bool = False):
initial_models = set(self.loaded_models)

# Unload current model if it's different
if self.current_model != model:
if self.current_model:
self.pipes[self.current_model].to("cpu")
self.loaded_models.remove(self.current_model)

self.pipes[model].to("cuda")
self.current_model = model
self.loaded_models.append(model)

# Ensure the model and all its components are on CUDA
pipe = self.pipes[model]
if pipe.device.type != "cuda":
print(f"Moving {model} model to CUDA.")
pipe.to("cuda")

# Explicitly move specific model components to CUDA
for component_name in ['unet', 'text_encoder', 'text_encoder_2', 'vae']:
if hasattr(pipe, component_name):
component = getattr(pipe, component_name)
if isinstance(component, torch.nn.Module):
component.to("cuda")

# Handle inpainting models
if inpaint:
if self.current_inpaint != model:
if self.current_inpaint:
self.inpaint_pipes[self.current_inpaint].to("cpu")
self.loaded_models.remove(f"{self.current_inpaint}_inpaint")

if self.inpaint_pipes[model] is None:
base_pipe = self.pipes[model]
self.inpaint_pipes[model] = FluxInpaintPipeline.from_pretrained(
f"FLUX.1-{model}",
text_encoder=base_pipe.text_encoder,
text_encoder_2=base_pipe.text_encoder_2,
tokenizer=base_pipe.tokenizer,
tokenizer_2=base_pipe.tokenizer_2,
torch_dtype=torch.bfloat16,
).to("cuda")
else:
self.inpaint_pipes[model].to("cuda")

self.current_inpaint = model
self.loaded_models.append(f"{model}_inpaint")
else:
if self.current_inpaint:
self.inpaint_pipes[self.current_inpaint].to("cpu")
self.loaded_models.remove(f"{self.current_inpaint}_inpaint")
self.current_inpaint = None

torch.cuda.empty_cache()

if set(self.loaded_models) != initial_models:
print(f"[!] Loaded models: {self.loaded_models}")

def resize_image_dimensions(
self,
original_resolution_wh: Tuple[int, int],
maximum_dimension: int = 1024,
) -> Tuple[int, int]:
width, height = original_resolution_wh

if width > height:
scaling_factor = maximum_dimension / width
else:
scaling_factor = maximum_dimension / height

new_width = int(width * scaling_factor)
new_height = int(height * scaling_factor)

new_width = new_width - (new_width % 32)
new_height = new_height - (new_height % 32)

return new_width, new_height


def download_base_weights(url: str, dest: Path):
start = time.time()
Expand All @@ -328,4 +461,4 @@ def download_base_weights(url: str, dest: Path):


def make_multiple_of_16(n):
return ((n + 15) // 16) * 16
return ((n + 15) // 16) * 16
Loading