diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py
index 29234f4d791..4866098b2c6 100644
--- a/invokeai/app/invocations/fields.py
+++ b/invokeai/app/invocations/fields.py
@@ -283,6 +283,12 @@ class FluxReduxConditioningField(BaseModel):
     )
 
 
+class FluxUnoReferenceField(BaseModel):
+    """A FLUX UNO reference image list primitive value."""
+
+    images: list[ImageField] = Field(description="The images to use as reference for FLUX UNO.")
+
+
 class FluxFillConditioningField(BaseModel):
     """A FLUX Fill conditioning field."""
 
diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py
index 3a7c15f949f..e5ffdf8303f 100644
--- a/invokeai/app/invocations/flux_denoise.py
+++ b/invokeai/app/invocations/flux_denoise.py
@@ -6,6 +6,7 @@
 import numpy.typing as npt
 import torch
 import torchvision.transforms as tv_transforms
+import torchvision.transforms.functional as TVF
 from PIL import Image
 from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
@@ -17,6 +18,7 @@
     FluxConditioningField,
     FluxFillConditioningField,
     FluxReduxConditioningField,
+    FluxUnoReferenceField,
     ImageField,
     Input,
     InputField,
@@ -25,6 +27,7 @@
     WithMetadata,
 )
 from invokeai.app.invocations.flux_controlnet import FluxControlNetField
+from invokeai.app.invocations.flux_uno import preprocess_ref
 from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
 from invokeai.app.invocations.ip_adapter import IPAdapterField
 from invokeai.app.invocations.model import ControlLoRAField, LoRAField, TransformerField, VAEField
@@ -45,6 +48,7 @@
     get_noise,
     get_schedule,
     pack,
+    prepare_multi_ip,
     unpack,
 )
 from invokeai.backend.flux.text_conditioning import FluxReduxConditioning, FluxTextConditioning
@@ -109,6 +113,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         description="FLUX Redux conditioning tensor.",
         input=Input.Connection,
     )
+    uno_ref: FluxUnoReferenceField | None = InputField(
+        default=None,
+        description="FLUX UNO reference images.",
+        input=Input.Connection,
+    )
     fill_conditioning: FluxFillConditioningField | None = InputField(
         default=None,
         description="FLUX Fill conditioning.",
@@ -284,6 +293,14 @@ def _run_diffusion(
 
         img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)
 
+        if self.uno_ref is not None:
+            # Encode the reference images and prepare their position ids
+            uno_ref_imgs = self._prep_uno_reference_imgs(context=context)
+            uno_ref_imgs, uno_ref_ids = prepare_multi_ip(x, uno_ref_imgs)
+        else:
+            uno_ref_imgs = None
+            uno_ref_ids = None
+
         # Pack all latent tensors.
         init_latents = pack(init_latents) if init_latents is not None else None
         inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
@@ -391,6 +408,8 @@ def _run_diffusion(
             pos_ip_adapter_extensions=pos_ip_adapter_extensions,
             neg_ip_adapter_extensions=neg_ip_adapter_extensions,
             img_cond=img_cond,
+            uno_ref_imgs=uno_ref_imgs,
+            uno_ref_ids=uno_ref_ids,
         )
 
         x = unpack(x.float(), self.height, self.width)
@@ -658,6 +677,30 @@ def _prep_controlnet_extensions(
 
         return controlnet_extensions
 
+    def _prep_uno_reference_imgs(self, context: InvocationContext) -> list[torch.Tensor]:
+        # Load the UNO reference images, resize them, and encode them to latents.
+
+        assert self.uno_ref is not None, "uno_ref must be set when using UNO."
+        ref_img_names = [i.image_name for i in self.uno_ref.images]
+
+        assert self.controlnet_vae is not None, "controlnet_vae must be set for UNO reference encoding"
+        vae_info = context.models.load(self.controlnet_vae.vae)
+
+        ref_latents: list[torch.Tensor] = []
+
+        # TODO: Maybe move the reference long side to the UNO node as a parameter
+        ref_long_side = 512 if len(ref_img_names) <= 1 else 320
+
+        for img_name in ref_img_names:
+            image_pil = context.images.get_pil(img_name, mode="RGB")
+            image_pil = preprocess_ref(image_pil, ref_long_side)  # resize and center crop
+
+            image_tensor = (TVF.to_tensor(image_pil) * 2.0 - 1.0).unsqueeze(0).float()
+            ref_latent = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
+            ref_latents.append(ref_latent)
+
+        return ref_latents
+
     def _prep_structural_control_img_cond(self, context: InvocationContext) -> torch.Tensor | None:
         if self.control_lora is None:
             return None
@@ -714,6 +757,7 @@ def _prep_flux_fill_img_cond(
         cond_img = context.images.get_pil(self.fill_conditioning.image.image_name, mode="RGB")
         cond_img = cond_img.resize((self.width, self.height), Image.Resampling.BICUBIC)
         cond_img = np.array(cond_img)
+        cond_img = torch.from_numpy(cond_img).float() / 127.5 - 1.0
         cond_img = einops.rearrange(cond_img, "h w c -> 1 c h w")
         cond_img = cond_img.to(device=device, dtype=dtype)
 
diff --git a/invokeai/app/invocations/flux_uno.py b/invokeai/app/invocations/flux_uno.py
new file mode 100644
index 00000000000..beba8b25619
--- /dev/null
+++ b/invokeai/app/invocations/flux_uno.py
@@ -0,0 +1,71 @@
+from PIL import Image
+
+from invokeai.app.invocations.baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    Classification,
+    invocation,
+    invocation_output,
+)
+from invokeai.app.invocations.fields import FluxUnoReferenceField, InputField, OutputField
+from invokeai.app.invocations.primitives import ImageField
+from invokeai.app.services.shared.invocation_context import InvocationContext
+
+
+def preprocess_ref(raw_image: Image.Image, long_size: int = 512) -> Image.Image:
+    """Resize and center-crop a reference image.
+    Code from https://github.com/bytedance/UNO/blob/main/uno/flux/pipeline.py
+    """
+    # Get the width and height of the original image
+    image_w, image_h = raw_image.size
+
+    # Scale so that the longer side equals long_size
+    if image_w >= image_h:
+        new_w = long_size
+        new_h = int((long_size / image_w) * image_h)
+    else:
+        new_h = long_size
+        new_w = int((long_size / image_h) * image_w)
+
+    # Resize proportionally to the new width and height
+    raw_image = raw_image.resize((new_w, new_h), resample=Image.Resampling.LANCZOS)
+    target_w = new_w // 16 * 16
+    target_h = new_h // 16 * 16
+
+    # Calculate the starting coordinates of the crop to achieve a center crop
+    left = (new_w - target_w) // 2
+    top = (new_h - target_h) // 2
+    right = left + target_w
+    bottom = top + target_h
+
+    # Center crop to multiples of 16
+    raw_image = raw_image.crop((left, top, right, bottom))
+
+    # Convert to RGB mode
+    raw_image = raw_image.convert("RGB")
+    return raw_image
+
+
+@invocation_output("flux_uno_output")
+class FluxUnoOutput(BaseInvocationOutput):
+    """The reference images output of a FLUX UNO invocation."""
+
+    uno_ref: FluxUnoReferenceField = OutputField(description="Reference images container", title="Reference images")
+
+
+@invocation(
+    "flux_uno",
+    title="FLUX UNO",
+    tags=["uno", "control"],
+    category="ip_adapter",
+    version="2.1.0",
+    classification=Classification.Beta,
+)
+class FluxUnoInvocation(BaseInvocation):
+    """Loads FLUX UNO reference images."""
+
+    images: list[ImageField] | None = InputField(default=None, description="The UNO reference images.")
+
+    def invoke(self, context: InvocationContext) -> FluxUnoOutput:
+        uno_ref = FluxUnoReferenceField(images=self.images or [])
+        return FluxUnoOutput(uno_ref=uno_ref)
diff --git a/invokeai/app/invocations/image_context_utils.py b/invokeai/app/invocations/image_context_utils.py
new file mode 100644
index 00000000000..01bc48f9004
--- /dev/null
+++ b/invokeai/app/invocations/image_context_utils.py
@@ -0,0 +1,238 @@
+"""Utility functions for ACE++ framework pipelines and various crop/merge operations.
+
+1. Create an empty image with a given size and color
+2. Concatenate images either horizontally or vertically
+3. Crop an image to a given size and position
+4. Paste a cropped image at a given position with resizing
+5. Create an empty mask with a given size and value
+
+Nodes in the Inpaint-Stitch pipeline.
+
+Tests for each node:
+Create Empty Image
+- Test creation with different sizes and colors
+- Test mask creation - zero mask / ones mask
+Image Resize Advanced
+- Resize to a fixed size
+- Resize keeping proportions
+- Conditional: downscale if bigger, upscale if smaller
+- Gaussian Blur Mask
+- Image Concatenate
+"""
+
+import math
+from typing import List, Literal, Optional
+
+import numpy as np
+import torch
+import torchvision.transforms as T
+from PIL import Image
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
+from invokeai.app.invocations.fields import (
+    ImageField,
+    Input,
+    InputField,
+    OutputField,
+    TensorField,
+    WithBoard,
+    WithMetadata,
+)
+from invokeai.app.invocations.primitives import ImageOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+
+DIRECTION_OPTIONS = Literal["right", "left", "down", "up"]
+
+
+def concat_images(
+    image1: Image.Image, image2: Image.Image, direction: str = "right", match_image_size=True
+) -> Image.Image:
+    """Concatenate two images either horizontally or vertically."""
+    # Ensure that the image modes match
+    if image1.mode != image2.mode:
+        image2 = image2.convert(image1.mode)
+
+    if direction == "right" or direction == "left":
+        if direction == "left":
+            image1, image2 = image2, image1
+        new_width = image1.width + image2.width
+        new_height = max(image1.height, image2.height)
+        new_image = Image.new(image1.mode, (new_width, new_height))
+        new_image.paste(image1, (0, 0))
+        new_image.paste(image2, (image1.width, 0))
+    elif direction == "down" or direction == "up":
+        if direction == "up":
+            image1, image2 = image2, image1
+        new_width = max(image1.width, image2.width)
+        new_height = image1.height + image2.height
+        new_image = Image.new(image1.mode, (new_width, new_height))
+        new_image.paste(image1, (0, 0))
+        new_image.paste(image2, (0, image1.height))
+    else:
+        raise ValueError("Direction must be one of 'right', 'left', 'up', or 'down'.")
+
+    return new_image
+
+
+@invocation(
+    "concat_images",
+    title="Concatenate Images",
+    tags=["image_processing"],
+    category="image_processing",
+    version="1.0.0",
+)
+class ConcatImagesInvocation(BaseInvocation, WithMetadata, WithBoard):
+    """Concatenate two images either horizontally or vertically."""
+
+    image1: ImageField = InputField(description="The first image to process")
+    image2: ImageField = InputField(description="The second image to process")
+    mode: DIRECTION_OPTIONS = InputField(
+        default="right", description="Direction of concatenation: 'right', 'left', 'up', or 'down'"
+    )
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        image1 = context.images.get_pil(self.image1.image_name)
+        image2 = context.images.get_pil(self.image2.image_name)
+        concatenated_image = concat_images(image1, image2, self.mode)
+        image_dto = context.images.save(image=concatenated_image)
+        return ImageOutput.build(image_dto)
+
+
+@invocation_output("inpaint_crop_output")
+class InpaintCropOutput(BaseInvocationOutput):
+    """The output of the Inpaint Crop invocation."""
+
+    image_crop: ImageField = OutputField(description="Cropped part of the image", title="Cropped Image")
+    stitcher: List[int] = OutputField(description="Parameters for stitching the image back after inpainting")
+
+
+@invocation(
+    "inpaint_crop",
+    title="Inpaint Crop",
+    tags=["image_processing"],
+    version="1.0.0",
+)
+class InpaintCropInvocation(BaseInvocation, WithMetadata, WithBoard):
+    """Crop the masked area from an image, with resize and expand options."""
+
+    image: ImageField = InputField(description="The source image")
+    mask: TensorField = InputField(description="Inpaint mask")
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        image = context.images.get_pil(self.image.image_name, "RGB")
+        mask = context.tensors.load(self.mask.tensor_name)
+
+        # TODO: Finish the InpaintCrop implementation
+        image_crop = Image.new("RGB", (256, 256))
+
+        image_dto = context.images.save(image=image_crop)
+        return ImageOutput.build(image_dto)
+
+
+@invocation_output("ace_plus_plus_output")
+class ACEppProcessorOutput(BaseInvocationOutput):
+    """The output of the ACE++ processor invocation."""
+
+    image: ImageField = OutputField(description="Concatenated image", title="Image")
+    mask: TensorField = OutputField(description="Inpaint mask")
+    crop_pad: int = OutputField(description="Padding to crop the result")
+    crop_width: int = OutputField(description="Width of the output crop area")
+    crop_height: int = OutputField(description="Height of the output crop area")
+
+
+@invocation(
+    "ace_plus_plus_processor",
+    title="ACE++ processor",
+    tags=["image_processing"],
+    version="1.0.0",
+)
+class ACEppProcessor(BaseInvocation):
+    """Prepare ACE++ reference and edit images as a combined conditioning image and inpaint mask."""
+
+    reference_image: ImageField = InputField(description="Reference Image")
+    edit_image: Optional[ImageField] = InputField(description="Edit Image", default=None, input=Input.Connection)
+    edit_mask: Optional[TensorField] = InputField(description="Edit Mask", default=None, input=Input.Connection)
+
+    width: int = InputField(default=512, gt=0, description="The width of the crop rectangle")
+    height: int = InputField(default=512, gt=0, description="The height of the crop rectangle")
+
+    max_seq_len: int = InputField(default=4096, gt=2048, le=5120, description="The maximum transformer sequence length")
+
+    def image_check(self, image_pil: Image.Image) -> torch.Tensor:
+        max_aspect_ratio = 4
+
+        image = self.transform_pil_tensor(image_pil)
+        image = image.unsqueeze(0)
+        # Limit the aspect ratio to max_aspect_ratio by center cropping
+        H, W = image.shape[2:]
+        if H / W > max_aspect_ratio:
+            image = T.CenterCrop([int(max_aspect_ratio * W), W])(image)
+        elif W / H > max_aspect_ratio:
+            image = T.CenterCrop([H, int(max_aspect_ratio * H)])(image)
+        return image[0]
+
+    def transform_pil_tensor(self, pil_image: Image.Image) -> torch.Tensor:
+        transform = T.Compose([T.ToTensor()])
+        tensor_image: torch.Tensor = transform(pil_image)
+        return tensor_image
+
+    def invoke(self, context: InvocationContext) -> ACEppProcessorOutput:
+        d = 16  # FLUX pixels per packed latent patch (8x VAE downscale * 2x2 packing)
+
+        image_pil = context.images.get_pil(self.reference_image.image_name, "RGB")
+        image = self.image_check(image_pil) - 0.5
+
+        if self.edit_image is None:
+            edit_image = torch.zeros((3, self.height, self.width))
+            edit_mask = torch.ones((1, self.height, self.width))
+        else:
+            # TODO: make a variant for editing
+            edit_image = context.images.get_pil(self.edit_image.image_name)
+            edit_image = self.image_check(edit_image) - 0.5
+            if self.edit_mask is None:
+                _, eH, eW = edit_image.shape
+                edit_mask = torch.ones((1, eH, eW))
+            else:
+                edit_mask = context.tensors.load(self.edit_mask.tensor_name)
+
+        out_H, out_W = edit_image.shape[-2:]
+
+        _, H, W = image.shape
+        _, eH, eW = edit_image.shape
+
+        # Align the reference height with edit_image
+        scale = eH / H
+        tH, tW = eH, int(W * scale)
+
+        reference_image = T.Resize((tH, tW), interpolation=T.InterpolationMode.BILINEAR, antialias=True)(image)
+        edit_image = torch.cat([reference_image, edit_image], dim=-1)
+        edit_mask = torch.cat([torch.zeros((1, reference_image.shape[1], reference_image.shape[2])), edit_mask], dim=-1)
+        slice_w = reference_image.shape[-1]
+
+        H, W = edit_image.shape[-2:]
+        scale = min(1.0, math.sqrt(self.max_seq_len * 2 / ((H / d) * (W / d))))
+        rH = int(H * scale) // d * d
+        rW = int(W * scale) // d * d
+        slice_w = int(slice_w * scale) // d * d
+
+        edit_image = T.Resize((rH, rW), interpolation=T.InterpolationMode.NEAREST_EXACT, antialias=True)(edit_image)
+        edit_mask = T.Resize((rH, rW), interpolation=T.InterpolationMode.NEAREST_EXACT, antialias=True)(edit_mask)
+
+        edit_image = edit_image.unsqueeze(0).permute(0, 2, 3, 1)
+        slice_w = slice_w if slice_w < 30 else slice_w + 30
+
+        # The -0.5/+0.5 manipulations are needed only for the gray color in the mask
+        # and are taken from the original author's implementation
+        # TODO: remove this -0.5/+0.5
+        edit_image += 0.5
+        # Convert to torch.bool
+        edit_mask = edit_mask > 0.5
+        image_out = Image.fromarray((edit_image[0].numpy() * 255).astype(np.uint8))
+
+        image_dto = context.images.save(image=image_out)
+        mask_name = context.tensors.save(edit_mask)
+        return ACEppProcessorOutput(
+            image=ImageField(image_name=image_dto.image_name),
+            mask=TensorField(tensor_name=mask_name),
+            crop_pad=slice_w,
+            crop_height=int(out_H * scale),
+            crop_width=int(out_W * scale),
+        )
diff --git a/invokeai/backend/flux/denoise.py b/invokeai/backend/flux/denoise.py
index 706f6941da0..9ce6fe19167 100644
--- a/invokeai/backend/flux/denoise.py
+++ b/invokeai/backend/flux/denoise.py
@@ -32,6 +32,8 @@ def denoise(
     neg_ip_adapter_extensions: list[XLabsIPAdapterExtension],
     # extra img tokens
     img_cond: torch.Tensor | None,
+    uno_ref_imgs: list[torch.Tensor] | None = None,
+    uno_ref_ids: list[torch.Tensor] | None = None,
 ):
     # step 0 is the initial state
     total_steps = len(timesteps) - 1
@@ -86,6 +88,8 @@ def denoise(
             controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
             ip_adapter_extensions=pos_ip_adapter_extensions,
             regional_prompting_extension=pos_regional_prompting_extension,
+            uno_ref_imgs=uno_ref_imgs,
+            uno_ref_ids=uno_ref_ids,
         )
 
         step_cfg_scale = cfg_scale[step_index]
diff --git a/invokeai/backend/flux/model.py b/invokeai/backend/flux/model.py
index cfa85691e94..91933cccaea 100644
--- a/invokeai/backend/flux/model.py
+++ b/invokeai/backend/flux/model.py
@@ -102,6 +102,8 @@ def forward(
         controlnet_single_block_residuals: list[Tensor] | None,
         ip_adapter_extensions: list[XLabsIPAdapterExtension],
         regional_prompting_extension: RegionalPromptingExtension,
+        uno_ref_imgs: list[torch.Tensor] | None = None,
+        uno_ref_ids: list[torch.Tensor] | None = None,
     ) -> Tensor:
         if img.ndim != 3 or txt.ndim != 3:
             raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -117,6 +119,14 @@ def forward(
         txt = self.txt_in(txt)
 
         ids = torch.cat((txt_ids, img_ids), dim=1)
+        # Concatenate UNO reference image tokens and their position ids onto the main sequence
+        img_end = img.shape[1]  # number of tokens in the original image sequence
+        if uno_ref_imgs is not None and uno_ref_ids is not None:
+            img_in = [img] + [self.img_in(ref) for ref in uno_ref_imgs]
+            ids_in = [ids] + uno_ref_ids
+            img = torch.cat(img_in, dim=1)
+            ids = torch.cat(ids_in, dim=1)
+
         pe = self.pe_embedder(ids)
 
         # Validate double_block_residuals shape.
@@ -164,5 +174,7 @@ def forward(
 
         img = img[:, txt.shape[1] :, ...]
+        # Keep only the original image tokens, dropping the UNO reference tokens
+        img = img[:, :img_end, ...]
         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
         return img
diff --git a/invokeai/backend/flux/sampling_utils.py b/invokeai/backend/flux/sampling_utils.py
index a4b36df9fde..f75a1b63b81 100644
--- a/invokeai/backend/flux/sampling_utils.py
+++ b/invokeai/backend/flux/sampling_utils.py
@@ -182,3 +182,45 @@ def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtyp
     img_ids.to(orig_dtype)
 
     return img_ids
+
+
+def prepare_multi_ip(img: torch.Tensor, ref_imgs: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
+    """Generate universal rotary position embeddings (UnoPE) for reference images.
+
+    Args:
+        img (torch.Tensor): latent image representation for denoising
+        ref_imgs (list[torch.Tensor]): list of reference image latents
+
+    Returns:
+        tuple[list[torch.Tensor], list[torch.Tensor]]: packed reference images and their position ids
+    """
+    bs, c, h, w = img.shape
+
+    ref_img_ids: list[torch.Tensor] = []
+    ref_imgs_list: list[torch.Tensor] = []
+    pe_shift_w, pe_shift_h = w // 2, h // 2
+    for ref_img in ref_imgs:
+        _, _, ref_h1, ref_w1 = ref_img.shape
+        ref_img = rearrange(ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+        if ref_img.shape[0] == 1 and bs > 1:
+            ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
+        ref_img_ids1 = torch.zeros(ref_h1 // 2, ref_w1 // 2, 3)
+        # Offset this reference image's position ids past the current maximum in height and width
+        h_offset = pe_shift_h
+        w_offset = pe_shift_w
+        ref_img_ids1[..., 1] = ref_img_ids1[..., 1] + torch.arange(ref_h1 // 2)[:, None] + h_offset
+        ref_img_ids1[..., 2] = ref_img_ids1[..., 2] + torch.arange(ref_w1 // 2)[None, :] + w_offset
+        ref_img_ids1 = repeat(ref_img_ids1, "h w c -> b (h w) c", b=bs)
+        ref_img_ids.append(ref_img_ids1)
+        ref_imgs_list.append(ref_img)
+
+        # Shift the next reference image's position ids past this one
+        pe_shift_h += ref_h1 // 2
+        pe_shift_w += ref_w1 // 2
+
+    return (
+        ref_imgs_list,
+        [ref_img_id.to(img.device) for ref_img_id in ref_img_ids],
+    )
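Note: the snippet below is a standalone sketch, not part of the diff. It illustrates the UnoPE scheme that prepare_multi_ip implements above: reference-image latents are packed into 2x2 patches and their position ids are shifted past the main latent grid so they never share rotary positions with the image being denoised. The shapes are dummy values chosen only for illustration; the real code derives them from the incoming latents.

# UnoPE illustration with dummy shapes (assumes torch and einops are installed).
import torch
from einops import rearrange, repeat

bs, c, h, w = 1, 16, 64, 64            # main packed-latent grid (e.g. a 512x512 image)
ref = torch.randn(bs, c, 40, 40)       # one reference latent (e.g. a 320x320 reference)

# Pack the reference latent into 2x2 patches, as prepare_multi_ip does.
ref_tokens = rearrange(ref, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

# Start the reference position ids at half the main grid size so they fall
# outside the main image's id range; each further reference shifts again.
pe_shift_h, pe_shift_w = h // 2, w // 2
ids = torch.zeros(ref.shape[2] // 2, ref.shape[3] // 2, 3)
ids[..., 1] += torch.arange(ref.shape[2] // 2)[:, None] + pe_shift_h
ids[..., 2] += torch.arange(ref.shape[3] // 2)[None, :] + pe_shift_w
ids = repeat(ids, "h w c -> b (h w) c", b=bs)

print(ref_tokens.shape, ids.shape)     # torch.Size([1, 400, 64]) torch.Size([1, 400, 3])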