Skip to content

Commit

Permalink
Update app.py
Browse files Browse the repository at this point in the history
  • Loading branch information
nftblackmagic authored Nov 26, 2024
1 parent ddc1268 commit 7b183da
Showing 1 changed file with 13 additions and 46 deletions.
59 changes: 13 additions & 46 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,55 +8,22 @@
import tempfile
import torch
from diffusers import FluxTransformer2DModel, FluxFillPipeline
import subprocess

import shutil
subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

def find_cuda():
# Check if CUDA_HOME or CUDA_PATH environment variables are set
cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')

if cuda_home and os.path.exists(cuda_home):
return cuda_home

# Search for the nvcc executable in the system's PATH
nvcc_path = shutil.which('nvcc')

if nvcc_path:
# Remove the 'bin/nvcc' part to get the CUDA installation path
cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
return cuda_path

return None

cuda_path = find_cuda()

if cuda_path:
print(f"CUDA installation found at: {cuda_path}")
else:
print("CUDA installation not found")

device = torch.device('cuda')

print("Start loading LoRA weights")
state_dict, network_alphas = FluxFillPipeline.lora_state_dict(
pretrained_model_name_or_path_or_dict="xiaozaa/catvton-flux-lora-alpha", ## The tryon Lora weights
weight_name="pytorch_lora_weights.safetensors",
return_alphas=True
)
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
print('Loading diffusion model ...')
transformer = FluxTransformer2DModel.from_pretrained(
"xiaozaa/catvton-flux-alpha",
torch_dtype=device
)
pipe = FluxFillPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Fill-dev",
torch_dtype=torch.bfloat16
"black-forest-labs/FLUX.1-dev",
transformer=transformer,
torch_dtype=device
).to(device)
FluxFillPipeline.load_lora_into_transformer(
state_dict=state_dict,
network_alphas=network_alphas,
transformer=pipe.transformer,
)

print('Loading Finished!')

@spaces.GPU
Expand Down Expand Up @@ -109,7 +76,7 @@ def gradio_inference(

with gr.Blocks() as demo:
gr.Markdown("""
# CATVTON FLUX Virtual Try-On Demo (by using LoRA weights)
# CATVTON FLUX Virtual Try-On Demo
Upload a model image, draw a mask, and a garment image to generate virtual try-on results.
[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/xiaozaa/catvton-flux-alpha)
Expand Down Expand Up @@ -222,4 +189,4 @@ def gradio_inference(
)


demo.launch()
demo.launch()

0 comments on commit 7b183da

Please sign in to comment.