Skip to content

Commit

Permalink
fix: getting rid of black borders in inpainting and img2img modes (#29)
Browse files Browse the repository at this point in the history
* img2img cropping bug fix
* again - img2img+inpainting cropping bug fixed
i.e. we go back to resizing
  • Loading branch information
zsxkib authored Sep 7, 2024
1 parent 011f7df commit 8c881c9
Showing 1 changed file with 37 additions and 45 deletions.
82 changes: 37 additions & 45 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import subprocess
import time
from dataclasses import dataclass
from typing import List, cast, Tuple
from typing import List, cast

import numpy as np
import torch
Expand Down Expand Up @@ -188,24 +188,27 @@ def predict( # pyright: ignore
prompt: str = Input(
description="Prompt for generated image. If you include the `trigger_word` used in the training process you are more likely to activate the trained object, style, or concept in the resulting image."
),
image: Path = Input(description="Input image for inpaint mode", default=None),
image: Path = Input(
description="Input image for img2img or inpainting mode. If provided, aspect_ratio, width, and height inputs are ignored.",
default=None,
),
mask: Path = Input(
description="Input mask for inpaint mode. Black areas will be preserved, white areas will be inpainted.",
description="Input mask for inpainting mode. Black areas will be preserved, white areas will be inpainted. Must be provided along with 'image' for inpainting mode.",
default=None,
),
aspect_ratio: str = Input(
description="Aspect ratio for the generated image. The size will always be 1 megapixel, i.e. 1024x1024 if aspect ratio is 1:1. To use arbitrary width and height, set aspect ratio to 'custom'. Note: Ignored in img2img and inpainting modes.",
description="Aspect ratio for the generated image in text-to-image mode. The size will always be 1 megapixel, i.e. 1024x1024 if aspect ratio is 1:1. To use arbitrary width and height, set aspect ratio to 'custom'. Note: Ignored in img2img and inpainting modes.",
choices=list(ASPECT_RATIOS.keys()) + ["custom"], # pyright: ignore
default="1:1",
),
width: int = Input(
description="Width of the generated image. Optional, only used when aspect_ratio=custom. Must be a multiple of 16 (if it's not, it will be rounded to nearest multiple of 16). Note: Overridden by input image in img2img and inpainting modes.",
description="Width of the generated image in text-to-image mode. Only used when aspect_ratio=custom. Must be a multiple of 16 (if it's not, it will be rounded to nearest multiple of 16). Note: Ignored in img2img and inpainting modes.",
ge=256,
le=1440,
default=None,
),
height: int = Input(
description="Height of the generated image. Optional, only used when aspect_ratio=custom. Must be a multiple of 16 (if it's not, it will be rounded to nearest multiple of 16). Note: Overridden by input image in img2img and inpainting modes.",
description="Height of the generated image in text-to-image mode. Only used when aspect_ratio=custom. Must be a multiple of 16 (if it's not, it will be rounded to nearest multiple of 16). Note: Ignored in img2img and inpainting modes.",
ge=256,
le=1440,
default=None,
Expand Down Expand Up @@ -240,7 +243,7 @@ def predict( # pyright: ignore
default=3.5,
),
prompt_strength: float = Input(
description="Strength for inpainting. 1.0 corresponds to full destruction of information in image",
description="Prompt strength when using img2img / inpaint. 1.0 corresponds to full destruction of information in image",
ge=0.0,
le=1.0,
default=0.8,
Expand Down Expand Up @@ -302,37 +305,46 @@ def predict( # pyright: ignore

if is_img2img_mode or is_inpaint_mode:
input_image = Image.open(image).convert("RGB")
resized_width, resized_height = self.resize_image_dimensions(
input_image.size
original_width, original_height = input_image.size

# Calculate dimensions that are multiples of 16
target_width = make_multiple_of_16(original_width)
target_height = make_multiple_of_16(original_height)
target_size = (target_width, target_height)

print(
f"[!] Resizing input image from {original_width}x{original_height} to {target_width}x{target_height}"
)
# Crop to the nearest smaller multiple of 16
width = resized_width - (resized_width % 16)
height = resized_height - (resized_height % 16)
# Center crop
left = (input_image.width - width) // 2
top = (input_image.height - height) // 2
right = left + width
bottom = top + height
flux_kwargs["image"] = input_image.crop((left, top, right, bottom))

# Determine if we should use highest quality settings
use_highest_quality = output_quality == 100 or output_format == "png"

# Resize the input image
resampling_method = Image.LANCZOS if use_highest_quality else Image.BICUBIC
input_image = input_image.resize(target_size, resampling_method)
flux_kwargs["image"] = input_image

# Set width and height to match the resized input image
flux_kwargs["width"], flux_kwargs["height"] = target_size

if is_img2img_mode:
print("[!] img2img mode")
flux_kwargs["strength"] = prompt_strength
print(f"Using {model} model for img2img")
pipe = self.img2img_pipes[model]
else: # is_inpaint_mode
print("[!] inpaint mode")
mask_image = Image.open(mask).convert("RGB")
flux_kwargs["mask_image"] = mask_image.crop((left, top, right, bottom))
flux_kwargs["strength"] = prompt_strength
print(f"Using {model} model for inpainting")
mask_image = mask_image.resize(target_size, Image.NEAREST)
flux_kwargs["mask_image"] = mask_image
pipe = self.inpaint_pipes[model]

flux_kwargs["strength"] = prompt_strength
print(
f"[!] Using {model} model for {'img2img' if is_img2img_mode else 'inpainting'}"
)
else: # is_txt2img_mode
print("[!] txt2img mode")
pipe = self.pipes[model]

flux_kwargs["width"] = width
flux_kwargs["height"] = height
if replicate_weights:
flux_kwargs["joint_attention_kwargs"] = {"scale": lora_scale}

Expand Down Expand Up @@ -468,26 +480,6 @@ def run_falcon_safety_checker(self, image):
def aspect_ratio_to_width_height(self, aspect_ratio: str) -> tuple[int, int]:
return ASPECT_RATIOS[aspect_ratio]

def resize_image_dimensions(
self,
original_resolution_wh: Tuple[int, int],
maximum_dimension: int = 1024,
) -> Tuple[int, int]:
width, height = original_resolution_wh

if width > height:
scaling_factor = maximum_dimension / width
else:
scaling_factor = maximum_dimension / height

new_width = int(width * scaling_factor)
new_height = int(height * scaling_factor)

new_width = new_width - (new_width % 32)
new_height = new_height - (new_height % 32)

return new_width, new_height


def download_base_weights(url: str, dest: Path):
start = time.time()
Expand Down

0 comments on commit 8c881c9

Please sign in to comment.