-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathtryon_inference_lora.py
134 lines (114 loc) · 4.83 KB
/
tryon_inference_lora.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import argparse
import torch
from diffusers.utils import load_image, check_min_version
from diffusers import FluxPriorReduxPipeline, FluxFillPipeline
from diffusers import FluxTransformer2DModel
import numpy as np
from torchvision import transforms
# Hugging Face repositories used to assemble the try-on pipeline.
_FILL_TRANSFORMER_REPO = "xiaozaa/flux1-fill-dev-diffusers"  # The official Flux-Fill weights
_TRYON_LORA_REPO = "xiaozaa/catvton-flux-lora-alpha"  # The tryon LoRA weights
_BASE_PIPELINE_REPO = "black-forest-labs/FLUX.1-dev"

# Prompt describing the side-by-side [garment | model] canvas. The text is kept
# exactly as it has always been sent (note: there is no space before "[IMAGE2]"),
# since generated results may be sensitive to the exact prompt wording.
_TRYON_PROMPT = (
    "The pair of images highlights a clothing and its styling on a model, high resolution, 4K, 8K; "
    "[IMAGE1] Detailed product shot of a clothing"
    "[IMAGE2] The same cloth is worn by a model in a lifestyle setting."
)


def _load_tryon_pipeline(device):
    """Build a bfloat16 FluxFillPipeline whose transformer carries the try-on LoRA.

    Loads the Flux-Fill transformer, injects the try-on LoRA weights into it,
    wraps it in a ``FluxFillPipeline`` built from the FLUX.1-dev repository and
    moves the pipeline to ``device``.

    Raises:
        ValueError: if any LoRA state-dict key contains neither ``"lora"`` nor
            ``"dora_scale"``.
    """
    transformer = FluxTransformer2DModel.from_pretrained(
        _FILL_TRANSFORMER_REPO,
        torch_dtype=torch.bfloat16,
    )
    print("Start loading LoRA weights")
    state_dict, network_alphas = FluxFillPipeline.lora_state_dict(
        pretrained_model_name_or_path_or_dict=_TRYON_LORA_REPO,
        weight_name="pytorch_lora_weights.safetensors",
        return_alphas=True,
    )
    # Reject checkpoints that contain anything other than LoRA/DoRA parameters.
    if not all("lora" in key or "dora_scale" in key for key in state_dict):
        raise ValueError("Invalid LoRA checkpoint.")
    FluxFillPipeline.load_lora_into_transformer(
        state_dict=state_dict,
        network_alphas=network_alphas,
        transformer=transformer,
    )
    return FluxFillPipeline.from_pretrained(
        _BASE_PIPELINE_REPO,
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    ).to(device)


def _build_inpaint_inputs(image, mask, garment):
    """Place garment and model side by side and build the matching in-paint mask.

    Returns ``(inpaint_image, extended_mask)``:
      * ``inpaint_image`` - garment and model tensors concatenated along the
        width axis (garment on the left), each mapped from ToTensor's [0, 1]
        range to [-1, 1] by ``Normalize([0.5], [0.5])``;
      * ``extended_mask`` - the first channel of ``mask`` with an all-zero
        block of the same size prepended on the left, presumably so the
        garment half is left untouched by the in-painting (TODO confirm).
    """
    to_normalized = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),  # single-value mean/std broadcasts over RGB
    ])
    image_tensor = to_normalized(image)
    garment_tensor = to_normalized(garment)
    mask_tensor = transforms.ToTensor()(mask)[:1]  # keep only the first channel

    # dim=2 is the width axis of a (C, H, W) tensor.
    inpaint_image = torch.cat([garment_tensor, image_tensor], dim=2)
    extended_mask = torch.cat([torch.zeros_like(mask_tensor), mask_tensor], dim=2)
    return inpaint_image, extended_mask


def run_inference(
    image_path,
    mask_path,
    garment_path,
    size=(576, 768),
    num_steps=50,
    guidance_scale=30,
    seed=42,
    pipe=None,
    device="cuda",
):
    """Run virtual try-on of ``garment_path`` onto the person in ``image_path``.

    The garment and the model image are placed side by side on one canvas
    (garment left, model right), and the Flux-Fill pipeline in-paints the canvas
    using the mask from ``mask_path`` on the model half.

    Args:
        image_path: Model (person) image, passed to ``load_image``.
        mask_path: Agnostic mask, passed to ``load_image``; only its first
            channel is used.
        garment_path: Garment image, passed to ``load_image``.
        size: ``(width, height)`` that every input is resized to; the generated
            canvas is ``(2 * width, height)``.
        num_steps: Number of inference steps forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.
        seed: Seed for the ``torch.Generator`` used for sampling.
        pipe: Optional pre-built ``FluxFillPipeline`` to reuse across calls. If
            ``None``, a new pipeline is built on every call (slow).
        device: Device for the pipeline and the generator (default ``"cuda"``).

    Returns:
        ``(garment_result, tryon_result)``: the left (garment) and right
        (try-on) halves cropped from the generated image.

    Raises:
        ValueError: if a pipeline has to be built and the LoRA checkpoint
            contains unexpected keys.
    """
    if pipe is None:
        pipe = _load_tryon_pipeline(device)
    else:
        pipe.to(device)
        pipe.transformer.to(torch.bfloat16)

    image = load_image(image_path).convert("RGB").resize(size)
    mask = load_image(mask_path).convert("RGB").resize(size)
    garment = load_image(garment_path).convert("RGB").resize(size)
    inpaint_image, extended_mask = _build_inpaint_inputs(image, mask, garment)

    width, height = size
    generator = torch.Generator(device=device).manual_seed(seed)
    result = pipe(
        height=height,
        width=width * 2,
        image=inpaint_image,
        mask_image=extended_mask,
        num_inference_steps=num_steps,
        generator=generator,
        max_sequence_length=512,
        guidance_scale=guidance_scale,
        prompt=_TRYON_PROMPT,
    ).images[0]

    # Split the canvas back into its garment (left) and try-on (right) halves.
    garment_result = result.crop((0, 0, width, height))
    tryon_result = result.crop((width, 0, width * 2, height))
    return garment_result, tryon_result
def main():
    """Command-line entry point: parse arguments, run try-on and save the results.

    Writes the garment half of the generated canvas to ``--output_garment`` and
    the try-on half to ``--output_tryon``.
    """
    parser = argparse.ArgumentParser(description='Run FLUX virtual try-on inference')
    parser.add_argument('--image', required=True, help='Path to the model image')
    parser.add_argument('--mask', required=True, help='Path to the agnostic mask')
    parser.add_argument('--garment', required=True, help='Path to the garment image')
    parser.add_argument('--output_garment', default='flux_inpaint_garment.png', help='Output path for garment result')
    parser.add_argument('--output_tryon', default='flux_inpaint_tryon.png', help='Output path for try-on result')
    parser.add_argument('--steps', type=int, default=50, help='Number of inference steps')
    parser.add_argument('--guidance_scale', type=float, default=30, help='Guidance scale')
    parser.add_argument('--seed', type=int, default=0, help='Random seed')
    parser.add_argument('--width', type=int, default=576, help='Width')
    parser.add_argument('--height', type=int, default=768, help='Height')
    args = parser.parse_args()

    # Check the installed diffusers version before the (slow) pipeline build.
    check_min_version("0.30.2")

    garment_result, tryon_result = run_inference(
        image_path=args.image,
        mask_path=args.mask,
        garment_path=args.garment,
        num_steps=args.steps,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
        size=(args.width, args.height),
    )

    # Save both halves so that --output_garment is honoured and the message
    # below is accurate.
    garment_result.save(args.output_garment)
    tryon_result.save(args.output_tryon)
    print("Successfully saved garment and try-on images")
# Run the command-line interface only when this file is executed as a script,
# not when it is imported (e.g. to reuse run_inference).
if __name__ == "__main__":
    main()