Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inpainting added, and works via api #13

Merged
merged 5 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cog.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ build:
- "Pygments==2.16.1"
run:
- curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
- pip install "git+https://github.com/Gothos/diffusers.git@flux-inpaint"

# predict.py defines how predictions are run on your model
predict: "predict.py:Predictor"
Expand Down
105 changes: 98 additions & 7 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
import subprocess
import time
from dataclasses import dataclass
from typing import List, cast
from typing import List, cast, Tuple

import numpy as np
import torch
from PIL import Image
from cog import BasePredictor, Input, Path
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from diffusers.pipelines.flux.pipeline_flux_inpaint import FluxInpaintPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
Expand Down Expand Up @@ -50,7 +52,7 @@ class Predictor(BasePredictor):
def setup(self) -> None: # pyright: ignore
"""Load the model into memory to make running multiple predictions efficient"""
start = time.time()
# Dont pull weights
# Don't pull weights
os.environ["TRANSFORMERS_OFFLINE"] = "1"

self.weights_cache = WeightsDownloadCache()
Expand Down Expand Up @@ -78,6 +80,7 @@ def setup(self) -> None: # pyright: ignore
download_base_weights(MODEL_URL_SCHNELL, FLUX_SCHNELL_PATH)
schnell_pipe = FluxPipeline.from_pretrained(
"FLUX.1-schnell",
vae=dev_pipe.vae,
text_encoder=dev_pipe.text_encoder,
text_encoder_2=dev_pipe.text_encoder_2,
tokenizer=dev_pipe.tokenizer,
Expand All @@ -89,11 +92,48 @@ def setup(self) -> None: # pyright: ignore
"dev": dev_pipe,
"schnell": schnell_pipe,
}

# Load inpainting pipelines
print("Loading Flux dev inpaint pipeline")
dev_inpaint_pipe = FluxInpaintPipeline(
transformer=dev_pipe.transformer,
scheduler=dev_pipe.scheduler,
vae=dev_pipe.vae,
text_encoder=dev_pipe.text_encoder,
text_encoder_2=dev_pipe.text_encoder_2,
tokenizer=dev_pipe.tokenizer,
tokenizer_2=dev_pipe.tokenizer_2,
).to("cuda")

print("Loading Flux schnell inpaint pipeline")
schnell_inpaint_pipe = FluxInpaintPipeline(
transformer=schnell_pipe.transformer,
scheduler=schnell_pipe.scheduler,
vae=schnell_pipe.vae,
text_encoder=schnell_pipe.text_encoder,
text_encoder_2=schnell_pipe.text_encoder_2,
tokenizer=schnell_pipe.tokenizer,
tokenizer_2=schnell_pipe.tokenizer_2,
).to("cuda")

self.inpaint_pipes = {
"dev": dev_inpaint_pipe,
"schnell": schnell_inpaint_pipe,
}

self.loaded_lora_urls = {
"dev": LoadedLoRAs(main=None, extra=None),
"schnell": LoadedLoRAs(main=None, extra=None),
}

self.loaded_models = [
"safety_checker",
"dev",
"schnell",
"dev_inpaint",
"schnell_inpaint",
]
print(f"[!] Loaded models: {self.loaded_models}")
print("setup took: ", time.time() - start)

@torch.inference_mode()
Expand All @@ -102,6 +142,13 @@ def predict( # pyright: ignore
prompt: str = Input(
description="Prompt for generated image. If you include the `trigger_word` used in the training process you are more likely to activate the trained object, style, or concept in the resulting image."
),
image: Path = Input(
description="Input image for img2img or inpaint mode", default=None
),
mask: Path = Input(
description="Input mask for inpaint mode. Black areas will be preserved, white areas will be inpainted.",
default=None,
),
aspect_ratio: str = Input(
description="Aspect ratio for the generated image. The size will always be 1 megapixel, i.e. 1024x1024 if aspect ratio is 1:1. To use arbitrary width and height, set aspect ratio to 'custom'.",
choices=list(ASPECT_RATIOS.keys()) + ["custom"], # pyright: ignore
Expand Down Expand Up @@ -148,6 +195,12 @@ def predict( # pyright: ignore
le=10,
default=3.5,
),
prompt_strength: float = Input(
description="Strength for img2img or inpaint. 1.0 corresponds to full destruction of information in image",
ge=0.0,
le=1.0,
default=0.8,
),
seed: int = Input(
description="Random seed. Set for reproducible generation.", default=None
),
Expand Down Expand Up @@ -199,21 +252,39 @@ def predict( # pyright: ignore

flux_kwargs = {}
print(f"Prompt: {prompt}")
print("txt2img mode")

inpaint_mode = image is not None and mask is not None

if inpaint_mode:
print("inpaint mode")
input_image = Image.open(image).convert("RGB")
mask_image = Image.open(mask).convert("RGB")
width, height = self.resize_image_dimensions(input_image.size)
zsxkib marked this conversation as resolved.
Show resolved Hide resolved
flux_kwargs["image"] = input_image.resize((width, height), Image.LANCZOS)
flux_kwargs["mask_image"] = mask_image.resize(
(width, height), Image.LANCZOS
)
flux_kwargs["strength"] = prompt_strength
print(f"Using {model} model for inpainting")
pipe = self.inpaint_pipes[model]
else:
print("txt2img mode")
pipe = self.pipes[model]

flux_kwargs["width"] = width
flux_kwargs["height"] = height
if replicate_weights:
flux_kwargs["joint_attention_kwargs"] = {"scale": lora_scale}

assert model in ["dev", "schnell"]
if model == "dev":
zsxkib marked this conversation as resolved.
Show resolved Hide resolved
print("Using dev model")
max_sequence_length = 512
else:
else: # model == "schnell":
print("Using schnell model")
max_sequence_length = 256
guidance_scale = 0
andreasjansson marked this conversation as resolved.
Show resolved Hide resolved

pipe = self.pipes[model]

if replicate_weights:
start_time = time.time()
if extra_lora:
Expand All @@ -232,7 +303,7 @@ def predict( # pyright: ignore
pipe.unload_lora_weights()
self.loaded_lora_urls[model] = LoadedLoRAs(main=None, extra=None)

generator = torch.Generator("cuda").manual_seed(seed)
generator = torch.Generator(device="cuda").manual_seed(seed)

common_args = {
"prompt": [prompt] * num_outputs,
Expand Down Expand Up @@ -320,6 +391,26 @@ def run_safety_checker(self, image):
def aspect_ratio_to_width_height(self, aspect_ratio: str) -> tuple[int, int]:
return ASPECT_RATIOS[aspect_ratio]

def resize_image_dimensions(
    self,
    original_resolution_wh: Tuple[int, int],
    maximum_dimension: int = 1024,
) -> Tuple[int, int]:
    """Scale (width, height) so the longer side equals `maximum_dimension`,
    then round each side down to a multiple of 32.

    The 32-pixel alignment is presumably required by the Flux pipeline's
    latent downsampling — TODO confirm against the diffusers implementation.

    Args:
        original_resolution_wh: Input image size as ``(width, height)``.
        maximum_dimension: Target length for the longer side (default 1024).

    Returns:
        ``(new_width, new_height)``, each a positive multiple of 32.
    """
    width, height = original_resolution_wh

    # Scale relative to the longer side so the result fits in a
    # maximum_dimension x maximum_dimension box (small images are upscaled).
    scaling_factor = maximum_dimension / max(width, height)

    new_width = int(width * scaling_factor)
    new_height = int(height * scaling_factor)

    # Round down to a multiple of 32, but never below 32: without the
    # clamp, extreme aspect ratios produced a zero dimension
    # (e.g. 2048x32 -> scaled 1024x16 -> height 16 - 16 = 0), which is an
    # invalid size for Image.resize and the inpaint pipeline.
    new_width = max(32, new_width - (new_width % 32))
    new_height = max(32, new_height - (new_height % 32))

    return new_width, new_height


def download_base_weights(url: str, dest: Path):
start = time.time()
Expand Down
Loading