-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #19 from nftblackmagic/anzh/tryoff
Anzh/tryoff
- Loading branch information
Showing
13 changed files
with
169 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,3 +53,6 @@ Thumbs.db | |
|
||
# Gradio cache | ||
.gradio/example/github.mp4 | ||
|
||
aws/ | ||
checkpoints/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,3 +11,4 @@ sentencepiece | |
peft==0.13.2 | ||
huggingface-hub | ||
spaces | ||
protobuf |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
python tryoff_inference.py \ | ||
--image ./example/person/00069_00.jpg \ | ||
--mask ./example/person/00069_00_mask.png \ | ||
--seed 41 \ | ||
--output_tryon test_original.png \ | ||
--output_garment restored_garment6.png \ | ||
--steps 30 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
import argparse | ||
import torch | ||
from diffusers.utils import load_image, check_min_version | ||
from diffusers import FluxPriorReduxPipeline, FluxFillPipeline | ||
from diffusers import FluxTransformer2DModel | ||
import numpy as np | ||
from torchvision import transforms | ||
|
||
def run_inference( | ||
image_path, | ||
mask_path, | ||
size=(576, 768), | ||
num_steps=50, | ||
guidance_scale=30, | ||
seed=42, | ||
pipe=None | ||
): | ||
# Build pipeline | ||
if pipe is None: | ||
transformer = FluxTransformer2DModel.from_pretrained( | ||
"xiaozaa/cat-tryoff-flux", | ||
torch_dtype=torch.bfloat16 | ||
) | ||
pipe = FluxFillPipeline.from_pretrained( | ||
"black-forest-labs/FLUX.1-dev", | ||
transformer=transformer, | ||
torch_dtype=torch.bfloat16 | ||
).to("cuda") | ||
else: | ||
pipe.to("cuda") | ||
|
||
pipe.transformer.to(torch.bfloat16) | ||
|
||
# Add transform | ||
transform = transforms.Compose([ | ||
transforms.ToTensor(), | ||
transforms.Normalize([0.5], [0.5]) # For RGB images | ||
]) | ||
mask_transform = transforms.Compose([ | ||
transforms.ToTensor() | ||
]) | ||
|
||
# Load and process images | ||
# print("image_path", image_path) | ||
image = load_image(image_path).convert("RGB").resize(size) | ||
mask = load_image(mask_path).convert("RGB").resize(size) | ||
|
||
# Transform images using the new preprocessing | ||
image_tensor = transform(image) | ||
mask_tensor = mask_transform(mask)[:1] # Take only first channel | ||
garment_tensor = torch.zeros_like(image_tensor) | ||
image_tensor = image_tensor * mask_tensor | ||
|
||
# Create concatenated images | ||
inpaint_image = torch.cat([garment_tensor, image_tensor], dim=2) # Concatenate along width | ||
garment_mask = torch.zeros_like(mask_tensor) | ||
extended_mask = torch.cat([1 - garment_mask, garment_mask], dim=2) | ||
|
||
prompt = f"The pair of images highlights a clothing and its styling on a model, high resolution, 4K, 8K; " \ | ||
f"[IMAGE1] Detailed product shot of a clothing" \ | ||
f"[IMAGE2] The same cloth is worn by a model in a lifestyle setting." | ||
|
||
generator = torch.Generator(device="cuda").manual_seed(seed) | ||
|
||
result = pipe( | ||
height=size[1], | ||
width=size[0] * 2, | ||
image=inpaint_image, | ||
mask_image=extended_mask, | ||
num_inference_steps=num_steps, | ||
generator=generator, | ||
max_sequence_length=512, | ||
guidance_scale=guidance_scale, | ||
prompt=prompt, | ||
).images[0] | ||
|
||
# Split and save results | ||
width = size[0] | ||
garment_result = result.crop((0, 0, width, size[1])) | ||
tryon_result = result.crop((width, 0, width * 2, size[1])) | ||
|
||
return garment_result, tryon_result | ||
|
||
def main(): | ||
parser = argparse.ArgumentParser(description='Run FLUX virtual try-on inference') | ||
parser.add_argument('--image', required=True, help='Path to the model image') | ||
parser.add_argument('--mask', required=True, help='Path to the agnostic mask') | ||
parser.add_argument('--output_garment', default='flux_inpaint_garment.png', help='Output path for garment result') | ||
parser.add_argument('--output_tryon', default='flux_inpaint_tryon.png', help='Output path for try-on result') | ||
parser.add_argument('--steps', type=int, default=50, help='Number of inference steps') | ||
parser.add_argument('--guidance_scale', type=float, default=30, help='Guidance scale') | ||
parser.add_argument('--seed', type=int, default=0, help='Random seed') | ||
parser.add_argument('--width', type=int, default=576, help='Width') | ||
parser.add_argument('--height', type=int, default=768, help='Height') | ||
|
||
args = parser.parse_args() | ||
|
||
check_min_version("0.30.2") | ||
|
||
garment_result, tryon_result = run_inference( | ||
image_path=args.image, | ||
mask_path=args.mask, | ||
num_steps=args.steps, | ||
guidance_scale=args.guidance_scale, | ||
seed=args.seed, | ||
size=(args.width, args.height) | ||
) | ||
output_tryon_path=args.output_tryon | ||
output_garment_path=args.output_garment | ||
|
||
tryon_result.save(output_tryon_path) | ||
garment_result.save(output_garment_path) | ||
|
||
print("Successfully saved garment and try-on images") | ||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
python tryon_inference.py \ | ||
--image ./example/person/00008_00.jpg \ | ||
--mask ./example/person/00008_00_mask.png \ | ||
--garment ./example/garment/00034_00.jpg \ | ||
--seed 42 \ | ||
--output_tryon test.png \ | ||
--steps 30 |