dataset class and configuration for open pose training #118

Open · raulc0399 wants to merge 27 commits into main
Commits (27)
f3c7bf2  controlnet on cuda:1 (raulc0399, Sep 10, 2024)
4ec28f5  controlnet on cuda1 (raulc0399, Sep 10, 2024)
bef9d83  move each entry (raulc0399, Sep 10, 2024)
1a830da  use cuda:1 (raulc0399, Sep 13, 2024)
0a21852  multigpu (raulc0399, Sep 13, 2024)
a736257  multigpu params (raulc0399, Sep 14, 2024)
643542a  renamed 2 gpu pipeline (raulc0399, Sep 14, 2024)
3153419  add support for two gpus pipeline (raulc0399, Sep 14, 2024)
1e4cfaf  move results back to model's device (raulc0399, Sep 14, 2024)
20bd531  move controlnet args to controlnet's device (raulc0399, Sep 14, 2024)
a8b6049  fix param (raulc0399, Sep 14, 2024)
e08ab95  info when the diffusion process has started (raulc0399, Sep 14, 2024)
ffb5a60  rm print (raulc0399, Sep 14, 2024)
e81aa05  devices as params (raulc0399, Sep 14, 2024)
2375525  openpose dataset and config (raulc0399, Sep 23, 2024)
73fa768  read openpose controlnet (raulc0399, Sep 23, 2024)
2514658  ignore _pose files (raulc0399, Sep 23, 2024)
f0bbc5f  correct file (raulc0399, Sep 23, 2024)
a21366f  correct file (raulc0399, Sep 23, 2024)
3ed30ce  training config (raulc0399, Sep 23, 2024)
8524be4  checkpoints limit and rm print (raulc0399, Sep 23, 2024)
b4b339d  convert to rgb to support grayscale images as well (raulc0399, Sep 23, 2024)
f371340  device for annotator (raulc0399, Sep 25, 2024)
dde5074  Merge branch 'main_multigpu' into open_pose_training (raulc0399, Sep 25, 2024)
7ae33ca  Merge remote-tracking branch 'upstream/main' (raulc0399, Sep 25, 2024)
6ec028b  Merge branch 'main' into main_multigpu (raulc0399, Sep 25, 2024)
db45541  Merge branch 'main_multigpu' into open_pose_training (raulc0399, Sep 25, 2024)
1 change: 1 addition & 0 deletions .gitignore
@@ -165,3 +165,4 @@ cython_debug/
#.idea/

.DS_Store
.aider*
62 changes: 62 additions & 0 deletions image_datasets/openpose_dataset.py
@@ -0,0 +1,62 @@
import os
import pandas as pd
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import json
import random

def c_crop(image):
    width, height = image.size
    new_size = min(width, height)
    left = (width - new_size) / 2
    top = (height - new_size) / 2
    right = (width + new_size) / 2
    bottom = (height + new_size) / 2
    return image.crop((left, top, right, bottom))

class OpenPoseImageDataset(Dataset):
    def __init__(self, img_dir, img_size=512):
        self.img_dir = img_dir
        self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if ('.jpg' in i or '.png' in i) and not i.endswith('_pose.jpg') and not i.endswith('_pose.png')]
        self.images.sort()
        self.img_size = img_size

        print('OpenPoseImageDataset: ', len(self.images))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        try:
            json_path = self.images[idx].split('.')[0] + '.json'
            json_data = json.load(open(json_path))

            img = Image.open(self.images[idx])
            img = c_crop(img)
            img = img.resize((self.img_size, self.img_size))
            # support grayscale images as well
            if img.mode != 'RGB':
                img = img.convert('RGB')
            img = torch.from_numpy((np.array(img) / 127.5) - 1)
            img = img.permute(2, 0, 1)

            hint_path = os.path.join(self.img_dir, json_data['conditioning_image'])
            hint = Image.open(hint_path)
            hint = c_crop(hint)
            hint = hint.resize((self.img_size, self.img_size))
            hint = torch.from_numpy((np.array(hint) / 127.5) - 1)
            hint = hint.permute(2, 0, 1)

            prompt = json_data['caption']
            return img, hint, prompt

        except Exception as e:
            print(e)
            return self.__getitem__(random.randint(0, len(self.images) - 1))


def openpose_dataset_loader(train_batch_size, num_workers, **args):
    dataset = OpenPoseImageDataset(**args)
    return DataLoader(dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True)
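For orientation, here is a minimal usage sketch of the new loader. The directory layout and JSON values are hypothetical; the `conditioning_image` and `caption` keys, the `_pose` suffix filter, and the loader signature all come from the file above. Note that the sidecar JSON path is derived by splitting the image filename at its first dot, so paths containing extra dots will not resolve.

```python
# Hypothetical layout that OpenPoseImageDataset expects:
#   images/foo.jpg        training image
#   images/foo_pose.jpg   pose map, skipped by the image listing
#   images/foo.json       {"conditioning_image": "foo_pose.jpg",
#                          "caption": "a person dancing in a park"}
from image_datasets.openpose_dataset import openpose_dataset_loader

# values mirror train_configs/test_openpose_controlnet.yaml below
train_dataloader = openpose_dataset_loader(
    train_batch_size=2,
    num_workers=2,
    img_dir="images/",
    img_size=512,
)
img, hint, prompt = next(iter(train_dataloader))
# img and hint are float tensors in [-1, 1], shape (2, 3, 512, 512)
# (assuming RGB pose maps); prompt is a tuple of caption strings
```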
6 changes: 5 additions & 1 deletion main.py
@@ -131,6 +131,10 @@ def create_argparser():
    parser.add_argument(
        "--save_path", type=str, default='results', help="Path to save"
    )
    parser.add_argument(
        "--two_gpus_pipeline", action='store_true', default=False,
        help="Enable the two-GPU pipeline (cuda:0 and cuda:1); the transformer is loaded on the device specified by --device"
    )
    return parser


@@ -140,7 +144,7 @@ def main(args):
    else:
        image = None

    xflux_pipeline = XFluxPipeline(args.model_type, args.device, args.offload)
    xflux_pipeline = XFluxPipeline(args.model_type, args.device, args.offload, two_gpus_pipeline=args.two_gpus_pipeline)
    if args.use_ip:
        print('load ip-adapter:', args.ip_local_path, args.ip_repo_id, args.ip_name)
        xflux_pipeline.set_ip(args.ip_local_path, args.ip_repo_id, args.ip_name)
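A hedged sketch of what the new flag does at the pipeline level, using the constructor signature introduced in this PR (the device strings are the two the pipeline recognizes; "flux-dev" is the model name used by the training config below):

```python
from src.flux.xflux_pipeline import XFluxPipeline

# The transformer stays on --device (here cuda:0); with two_gpus_pipeline=True
# the ControlNet, text encoders, VAE, and annotator land on cuda:1.
pipe = XFluxPipeline("flux-dev", "cuda:0", offload=False, two_gpus_pipeline=True)
```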
60 changes: 43 additions & 17 deletions src/flux/sampling.py
@@ -172,23 +172,40 @@ def denoise_controlnet(
    image_proj: Tensor=None,
    neg_image_proj: Tensor=None,
    ip_scale: Tensor | float = 1,
    neg_ip_scale: Tensor | float = 1,
    controlnet_device: torch.device = "cuda:0",
    model_device: torch.device = "cuda:0"
):
    # this is ignored for schnell
    i = 0
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)

        # move controlnet params to controlnet's device
        img_controlnet_device = img.to(controlnet_device)
        img_ids_controlnet_device = img_ids.to(controlnet_device)
        controlnet_cond_controlnet_device = controlnet_cond.to(controlnet_device)
        txt_controlnet_device = txt.to(controlnet_device)
        txt_ids_controlnet_device = txt_ids.to(controlnet_device)
        vec_controlnet_device = vec.to(controlnet_device)
        t_vec_controlnet_device = t_vec.to(controlnet_device)
        guidance_vec_controlnet_device = guidance_vec.to(controlnet_device)

        block_res_samples = controlnet(
            img=img,
            img_ids=img_ids,
            controlnet_cond=controlnet_cond,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
            img=img_controlnet_device,
            img_ids=img_ids_controlnet_device,
            controlnet_cond=controlnet_cond_controlnet_device,
            txt=txt_controlnet_device,
            txt_ids=txt_ids_controlnet_device,
            y=vec_controlnet_device,
            timesteps=t_vec_controlnet_device,
            guidance=guidance_vec_controlnet_device,
        )

        # move results back to model's device
        block_res_samples = [i.to(model_device) for i in block_res_samples]

        pred = model(
            img=img,
            img_ids=img_ids,
@@ -202,16 +219,25 @@
            ip_scale=ip_scale,
        )
        if i >= timestep_to_start_cfg:
            # move negative prompt to controlnet's device
            neg_txt_controlnet_device = neg_txt.to(controlnet_device)
            neg_txt_ids_controlnet_device = neg_txt_ids.to(controlnet_device)
            neg_vec_controlnet_device = neg_vec.to(controlnet_device)

            neg_block_res_samples = controlnet(
                img=img,
                img_ids=img_ids,
                controlnet_cond=controlnet_cond,
                txt=neg_txt,
                txt_ids=neg_txt_ids,
                y=neg_vec,
                timesteps=t_vec,
                guidance=guidance_vec,
                img=img_controlnet_device,
                img_ids=img_ids_controlnet_device,
                controlnet_cond=controlnet_cond_controlnet_device,
                txt=neg_txt_controlnet_device,
                txt_ids=neg_txt_ids_controlnet_device,
                y=neg_vec_controlnet_device,
                timesteps=t_vec_controlnet_device,
                guidance=guidance_vec_controlnet_device,
            )

            # move results back to model's device
            neg_block_res_samples = [i.to(model_device) for i in neg_block_res_samples]

            neg_pred = model(
                img=img,
                img_ids=img_ids,
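The edits above repeat one pattern: every ControlNet input crosses to the ControlNet's GPU before the call, and each block residual crosses back before the transformer consumes it. A hypothetical helper (not part of this PR) expressing the same round trip:

```python
import torch
from torch import Tensor

def controlnet_round_trip(controlnet, inputs: dict[str, Tensor],
                          controlnet_device: torch.device,
                          model_device: torch.device) -> list[Tensor]:
    # copy every input tensor to the ControlNet's GPU for the residual pass
    inputs = {k: v.to(controlnet_device) for k, v in inputs.items()}
    block_res_samples = controlnet(**inputs)
    # bring the residuals back to the transformer's GPU
    return [r.to(model_device) for r in block_res_samples]
```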
55 changes: 33 additions & 22 deletions src/flux/xflux_pipeline.py
@@ -31,18 +31,23 @@
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor

class XFluxPipeline:
    def __init__(self, model_type, device, offload: bool = False):
        self.device = torch.device(device)
    def __init__(self, model_type, device, offload: bool = False, two_gpus_pipeline: bool = False):
        if two_gpus_pipeline:
            self.model_device = torch.device(device)
            self.other_device = torch.device("cuda:0" if device == "cuda:1" else "cuda:1")
        else:
            self.model_device = self.other_device = torch.device(device)

        self.offload = offload
        self.model_type = model_type

        self.clip = load_clip(self.device)
        self.t5 = load_t5(self.device, max_length=512)
        self.ae = load_ae(model_type, device="cpu" if offload else self.device)
        self.clip = load_clip(self.other_device)
        self.t5 = load_t5(self.other_device, max_length=512)
        self.ae = load_ae(model_type, device="cpu" if offload else self.other_device)
        if "fp8" in model_type:
            self.model = load_flow_model_quintized(model_type, device="cpu" if offload else self.device)
            self.model = load_flow_model_quintized(model_type, device="cpu" if offload else self.model_device)
        else:
            self.model = load_flow_model(model_type, device="cpu" if offload else self.device)
            self.model = load_flow_model(model_type, device="cpu" if offload else self.model_device)

        self.image_encoder_path = "openai/clip-vit-large-patch14"
        self.hf_lora_collection = "XLabs-AI/flux-lora-collection"
@@ -53,7 +58,7 @@ def __init__(self, model_type, device, offload: bool = False):
        self.ip_loaded = False

    def set_ip(self, local_path: str = None, repo_id = None, name: str = None):
        self.model.to(self.device)
        self.model.to(self.model_device)

        # unpack checkpoint
        checkpoint = load_checkpoint(local_path, repo_id, name)
@@ -69,14 +74,14 @@ def set_ip(self, local_path: str = None, repo_id = None, name: str = None):

        # load image encoder
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
            self.device, dtype=torch.float16
            self.other_device, dtype=torch.float16
        )
        self.clip_image_processor = CLIPImageProcessor()

        # setup image embedding projection model
        self.improj = ImageProjModel(4096, 768, 4)
        self.improj.load_state_dict(proj)
        self.improj = self.improj.to(self.device, dtype=torch.bfloat16)
        self.improj = self.improj.to(self.other_device, dtype=torch.bfloat16)

        ip_attn_procs = {}

@@ -88,7 +93,7 @@ def set_ip(self, local_path: str = None, repo_id = None, name: str = None):
            if ip_state_dict:
                ip_attn_procs[name] = IPDoubleStreamBlockProcessor(4096, 3072)
                ip_attn_procs[name].load_state_dict(ip_state_dict)
                ip_attn_procs[name].to(self.device, dtype=torch.bfloat16)
                ip_attn_procs[name].to(self.model_device, dtype=torch.bfloat16)
            else:
                ip_attn_procs[name] = self.model.attn_processors[name]

@@ -122,7 +127,7 @@ def update_model_with_lora(self, checkpoint, lora_weight):
                else:
                    lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank)
                lora_attn_procs[name].load_state_dict(lora_state_dict)
                lora_attn_procs[name].to(self.device)
                lora_attn_procs[name].to(self.model_device)
            else:
                if name.startswith("single_blocks"):
                    lora_attn_procs[name] = SingleStreamBlockProcessor()
@@ -132,12 +137,13 @@ def update_model_with_lora(self, checkpoint, lora_weight):
        self.model.set_attn_processor(lora_attn_procs)

    def set_controlnet(self, control_type: str, local_path: str = None, repo_id: str = None, name: str = None):
        self.model.to(self.device)
        self.controlnet = load_controlnet(self.model_type, self.device).to(torch.bfloat16)
        self.model.to(self.model_device)

        self.controlnet = load_controlnet(self.model_type, self.other_device).to(torch.bfloat16)

        checkpoint = load_checkpoint(local_path, repo_id, name)
        self.controlnet.load_state_dict(checkpoint, strict=False)
        self.annotator = Annotator(control_type, self.device)
        self.annotator = Annotator(control_type, self.other_device)
        self.controlnet_loaded = True
        self.control_type = control_type

@@ -154,7 +160,7 @@ def get_image_proj(
        image_prompt_embeds = self.image_encoder(
            image_prompt
        ).image_embeds.to(
            device=self.device, dtype=torch.bfloat16,
            device=self.model_device, dtype=torch.bfloat16,
        )
        # encode image
        image_proj = self.improj(image_prompt_embeds)
@@ -196,7 +202,7 @@ def __call__(self,
            controlnet_image = self.annotator(controlnet_image, width, height)
            controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
            controlnet_image = controlnet_image.permute(
                2, 0, 1).unsqueeze(0).to(torch.bfloat16).to(self.device)
                2, 0, 1).unsqueeze(0).to(torch.bfloat16).to(self.model_device)

        return self.forward(
            prompt,
@@ -277,25 +283,28 @@ def forward(
        ip_scale=1.0,
        neg_ip_scale=1.0,
    ):
        print("Starting the diffusion process...")

        x = get_noise(
            1, height, width, device=self.device,
            1, height, width, device=self.model_device,
            dtype=torch.bfloat16, seed=seed
        )
        timesteps = get_schedule(
            num_steps,
            (width // 8) * (height // 8) // (16 * 16),
            shift=True,
        )

        torch.manual_seed(seed)
        with torch.no_grad():
            if self.offload:
                self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
                self.t5, self.clip = self.t5.to(self.other_device), self.clip.to(self.other_device)
            inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt)
            neg_inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt)

            if self.offload:
                self.offload_model_to_cpu(self.t5, self.clip)
                self.model = self.model.to(self.device)
                self.model = self.model.to(self.model_device)
            if self.controlnet_loaded:
                x = denoise_controlnet(
                    self.model,
@@ -314,6 +323,8 @@ def forward(
                    neg_image_proj=neg_image_proj,
                    ip_scale=ip_scale,
                    neg_ip_scale=neg_ip_scale,
                    controlnet_device=self.other_device,
                    model_device=self.model_device,
                )
            else:
                x = denoise(
@@ -334,9 +345,9 @@

            if self.offload:
                self.offload_model_to_cpu(self.model)
            self.ae.decoder.to(x.device)
            self.ae.decoder.to(self.other_device)
            x = unpack(x.float(), height, width)
            x = self.ae.decode(x)
            x = self.ae.decode(x.to(self.other_device))
            self.offload_model_to_cpu(self.ae.decoder)

        x1 = x.clamp(-1, 1)
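Read together, the pipeline changes pin each component to one of the two devices. Below is a summary of the placement with `two_gpus_pipeline=True` and `device="cuda:0"`, read directly from the diff (the labels swap for `device="cuda:1"`); the IP-adapter image path is the one component split across both, with its attention processors on `model_device` and the image encoder and projection on `other_device`.

```python
# Placement implied by the diff above (device="cuda:0" shown):
placement = {
    "flow transformer (self.model)": "cuda:0",  # model_device
    "IP-adapter attn processors":    "cuda:0",
    "CLIP + T5 text encoders":       "cuda:1",  # other_device
    "autoencoder (self.ae)":         "cuda:1",
    "ControlNet + annotator":        "cuda:1",
    "CLIP image encoder + improj":   "cuda:1",
}
```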
26 changes: 26 additions & 0 deletions train_configs/test_openpose_controlnet.yaml
@@ -0,0 +1,26 @@
model_name: "flux-dev"
is_openpose: true
data_config:
  train_batch_size: 2
  num_workers: 2
  img_size: 512
  img_dir: images/
report_to: wandb
train_batch_size: 2
output_dir: saves_openpose/
max_train_steps: 100000
learning_rate: 2e-5
lr_scheduler: constant
lr_warmup_steps: 10
adam_beta1: 0.9
adam_beta2: 0.999
adam_weight_decay: 0.01
adam_epsilon: 1e-8
max_grad_norm: 1.0
logging_dir: logs
mixed_precision: "bf16"
checkpointing_steps: 2500
checkpoints_total_limit: 50
tracker_project_name: openpose_training
resume_from_checkpoint: latest
gradient_accumulation_steps: 2
9 changes: 7 additions & 2 deletions train_flux_deepspeed_controlnet.py
@@ -38,6 +38,7 @@
from src.flux.util import (configs, load_ae, load_clip,
                           load_flow_model2, load_controlnet, load_t5)
from image_datasets.canny_dataset import loader
from image_datasets.openpose_dataset import openpose_dataset_loader
if is_wandb_available():
    import wandb
logger = get_logger(__name__, log_level="INFO")
@@ -122,7 +123,11 @@ def main():
        eps=args.adam_epsilon,
    )

    train_dataloader = loader(**args.data_config)
    if args.is_openpose:
        train_dataloader = openpose_dataset_loader(**args.data_config)
    else:
        train_dataloader = loader(**args.data_config)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -219,7 +224,7 @@ def main():
                t = torch.sigmoid(torch.randn((bs,), device=accelerator.device))

                x_0 = torch.randn_like(x_1).to(accelerator.device)
                print(t.shape, x_1.shape, x_0.shape)
                # print(t.shape, x_1.shape, x_0.shape)
                x_t = (1 - t.unsqueeze(1).unsqueeze(2).repeat(1, x_1.shape[1], x_1.shape[2])) * x_1 + t.unsqueeze(1).unsqueeze(2).repeat(1, x_1.shape[1], x_1.shape[2]) * x_0
                bsz = x_1.shape[0]
                guidance_vec = torch.full((x_t.shape[0],), 4, device=x_t.device, dtype=x_t.dtype)
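For readers of the training loop above, the `x_t` line is the usual rectified-flow interpolation between the clean latent `x_1` and Gaussian noise `x_0` at a sigmoid-distributed timestep `t`. An equivalent, broadcast-based sketch (a hypothetical simplification, not part of this PR):

```python
# x_t = (1 - t) * x_1 + t * x_0, broadcasting t over token and channel dims;
# view() replaces the unsqueeze().repeat() chain with the same result
t_b = t.view(-1, 1, 1)
x_t = (1 - t_b) * x_1 + t_b * x_0
```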